import matplotlib.pyplot as plt

# Data for 405B model weights in GB.
# The three lists are parallel: entry i of each describes the same bar.
labels = ['Llama-3 (BF16)', 'INT8 Quant', 'JiRack Ternary (2-bit)']
# NOTE(review): 810 and 405 correspond to 405B parameters at 2 bytes and
# 1 byte each; 243 does not correspond to a flat 2 bits per parameter
# (~101 GB) -- confirm where this figure comes from.
vram_usage = [810, 405, 243]
colors = ['#ff9999', '#66b3ff', '#99ff99']


def generate_vram_chart(output_path='vram_benchmark_405b.png', show=True):
    """Render the VRAM footprint bar chart and save it as an image.

    Draws one bar per entry of the module-level ``labels`` / ``vram_usage``
    / ``colors`` lists, annotates each bar with its value, writes the
    figure to ``output_path`` and prints a confirmation message.

    Args:
        output_path: Destination passed to ``plt.savefig``. Defaults to
            ``'vram_benchmark_405b.png'``.
        show: When true (the default), call ``plt.show()`` after saving.
            Pass ``False`` for scripted or headless runs.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        bars = plt.bar(labels, vram_usage, color=colors)

        plt.title('VRAM Weight Footprint: 405B Model Comparison', fontsize=14)
        plt.ylabel('VRAM Usage (GB)', fontsize=12)
        plt.grid(axis='y', linestyle='--', alpha=0.7)

        # Add values on top of bars: centred horizontally on each bar and
        # lifted 10 data units above its top edge so the text clears it.
        for bar in bars:
            yval = bar.get_height()
            plt.text(
                bar.get_x() + bar.get_width() / 2,
                yval + 10,
                f'{yval} GB',
                ha='center',
                va='bottom',
                fontweight='bold',
            )

        plt.tight_layout()
        plt.savefig(output_path)
        print(f"Benchmark chart saved as '{output_path}'")
        if show:
            plt.show()
    finally:
        # Release the figure so repeated calls do not accumulate open
        # figures in pyplot's global registry.
        plt.close(fig)


if __name__ == "__main__":
    generate_vram_chart()