| import matplotlib.pyplot as plt |
|
|
| |
# One row per chart bar: (x-axis label, weight VRAM in GB, bar colour).
# Keeping each bar's data on a single row stops the three series drifting
# out of sync when entries are added or reordered.
# NOTE(review): for a 405B-parameter model, BF16 (2 bytes/param) -> 810 GB and
# INT8 (1 byte/param) -> 405 GB match the arithmetic, but a "2-bit" format would
# be ~101 GB (405e9 * 2 / 8 bytes); 243 GB implies ~4.8 bits/param.
# Confirm the intended figure or label.
_CHART_ROWS = [
    ('Llama-3 (BF16)', 810, '#ff9999'),
    ('INT8 Quant', 405, '#66b3ff'),
    ('JiRack Ternary (2-bit)', 243, '#99ff99'),
]

# Transpose the row table into the parallel lists the plotting code consumes.
labels, vram_usage, colors = (list(column) for column in zip(*_CHART_ROWS))
|
|
def generate_vram_chart(
    *,
    chart_labels=None,
    usage_gb=None,
    bar_colors=None,
    output_path='vram_benchmark_405b.png',
    show=True,
):
    """Draw a bar chart comparing weight VRAM usage and save it as an image.

    With no arguments this plots the module-level ``labels``, ``vram_usage``
    and ``colors`` series, saves to ``vram_benchmark_405b.png`` and opens the
    interactive window, exactly as before.

    Args:
        chart_labels: x-axis labels, one per bar. Defaults to module ``labels``.
        usage_gb: bar heights in GB, one per label. Defaults to module
            ``vram_usage``.
        bar_colors: bar colours. Defaults to module ``colors``.
        output_path: file the chart is written to (format from the extension).
        show: if True, call ``plt.show()`` after saving; if False, close the
            figure instead so batch/headless callers don't accumulate figures.

    Raises:
        ValueError: if the number of labels and values differ.
    """
    chart_labels = labels if chart_labels is None else chart_labels
    usage_gb = vram_usage if usage_gb is None else usage_gb
    bar_colors = colors if bar_colors is None else bar_colors

    if len(chart_labels) != len(usage_gb):
        raise ValueError(
            f'got {len(chart_labels)} labels but {len(usage_gb)} values'
        )

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(chart_labels, usage_gb, color=bar_colors)

    ax.set_title('VRAM Weight Footprint: 405B Model Comparison', fontsize=14)
    ax.set_ylabel('VRAM Usage (GB)', fontsize=12)
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Value label just above each bar. The gap is given in points rather than
    # data units, so it stays constant whatever the scale of the values.
    for bar in bars:
        height = bar.get_height()
        ax.annotate(
            f'{height} GB',
            xy=(bar.get_x() + bar.get_width() / 2, height),
            xytext=(0, 3),
            textcoords='offset points',
            ha='center',
            va='bottom',
            fontweight='bold',
        )

    fig.tight_layout()
    fig.savefig(output_path)
    print(f"Benchmark chart saved as '{output_path}'")

    if show:
        plt.show()
    else:
        # Not displaying: release the figure so pyplot doesn't keep it alive.
        plt.close(fig)
|
|
if __name__ == "__main__":
    # Script entry point: render, save and display the VRAM comparison chart
    # using the default module-level data.
    generate_vram_chart()