codemichaeld committed
Commit 7776d1d · verified · 1 Parent(s): 2e940ab

Update app.py

Files changed (1): app.py +94 -579
app.py CHANGED
@@ -21,140 +21,6 @@ try:
 except ImportError:
     MODELScope_AVAILABLE = False
 
-def low_rank_decomposition(weight, rank=64):
-    """
-    Correct LoRA decomposition supporting 2D and 4D tensors.
-    Returns (lora_A, lora_B) such that weight ≈ lora_B @ lora_A for 2D,
-    or appropriate conv form for 4D.
-    """
-    original_shape = weight.shape
-    original_dtype = weight.dtype
-    try:
-        if weight.ndim == 2:
-            actual_rank = min(rank, min(weight.shape) // 2)
-            if actual_rank < 4:
-                return None, None
-            U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
-            S_sqrt = torch.sqrt(S[:actual_rank])
-            # Standard LoRA factorization: W ≈ W_B @ W_A
-            W_A = (Vh[:actual_rank, :] * S_sqrt.unsqueeze(1)).contiguous()  # [rank, in_features]
-            W_B = (U[:, :actual_rank] * S_sqrt.unsqueeze(0)).contiguous()  # [out_features, rank]
-            return W_A.to(original_dtype), W_B.to(original_dtype)
-        elif weight.ndim == 4:
-            out_ch, in_ch, k_h, k_w = weight.shape
-            if k_h * k_w <= 9:  # small conv kernels (e.g., 3x3)
-                # Reshape to 2D: [out_ch, in_ch * k_h * k_w]
-                weight_2d = weight.view(out_ch, -1)
-                actual_rank = min(rank, min(weight_2d.shape) // 2)
-                if actual_rank < 4:
-                    return None, None
-                U, S, Vh = torch.linalg.svd(weight_2d.float(), full_matrices=False)
-                S_sqrt = torch.sqrt(S[:actual_rank])
-                W_A_2d = (Vh[:actual_rank, :] * S_sqrt.unsqueeze(1)).contiguous()
-                W_B_2d = (U[:, :actual_rank] * S_sqrt.unsqueeze(0)).contiguous()
-                # Reshape back to conv format
-                W_A = W_A_2d.view(actual_rank, in_ch, k_h, k_w).contiguous()
-                W_B = W_B_2d.view(out_ch, actual_rank, 1, 1).contiguous()
-                return W_A.to(original_dtype), W_B.to(original_dtype)
-        return None, None
-    except Exception as e:
-        print(f"Decomposition error for {original_shape}: {e}")
-        traceback.print_exc()
-        return None, None
-
-def extract_correction_factors(original_weight, fp8_weight):
-    """Extract per-channel/tensor correction factors (difference method)."""
-    with torch.no_grad():
-        # Convert to float32 for precision
-        orig = original_weight.float()
-        quant = fp8_weight.float()
-
-        # Compute error (what needs to be added to FP8 to recover original)
-        error = orig - quant
-
-        # Skip if error is negligible
-        error_norm = torch.norm(error)
-        orig_norm = torch.norm(orig)
-        if orig_norm > 1e-6 and error_norm / orig_norm < 0.01:
-            return None
-
-        # For 4D tensors (common in VAE, CNNs)
-        if orig.ndim == 4:
-            # Channel dimension is typically dimension 0 (output channels)
-            channel_dim = 0
-            # Compute mean error per output channel
-            channel_mean = error.mean(dim=tuple(i for i in range(1, orig.ndim)), keepdim=True)
-            return channel_mean.to(original_weight.dtype)
-
-        # For 2D tensors (linear layers)
-        elif orig.ndim == 2:
-            # Compute mean error per output row
-            row_mean = error.mean(dim=1, keepdim=True)
-            return row_mean.to(original_weight.dtype)
-
-        # For 1D tensors (bias, batchnorm)
-        else:
-            return error.mean().to(original_weight.dtype)
-
-def get_tensor_info(tensor):
-    """Get detailed tensor information for pattern matching."""
-    shape = list(tensor.shape)
-    dim = tensor.dim()
-    numel = tensor.numel()
-    dtype = str(tensor.dtype)
-
-    # Determine tensor type based on shape
-    tensor_type = "other"
-    if dim == 4 and shape[2] == shape[3]:  # Convolutional layer with square kernel
-        tensor_type = "conv"
-    elif dim == 2:
-        if shape[0] > shape[1] * 4:  # More likely to be output projection
-            tensor_type = "output_proj"
-        elif shape[1] > shape[0] * 4:  # More likely to be input projection
-            tensor_type = "input_proj"
-        else:
-            tensor_type = "linear"
-    elif dim == 1:
-        tensor_type = "bias"
-
-    return {
-        "shape": shape,
-        "dim": dim,
-        "numel": numel,
-        "type": tensor_type,
-        "dtype": dtype
-    }
-
-def matches_pattern(key, tensor_info, pattern):
-    """Check if a tensor matches a pattern definition."""
-    key_lower = key.lower()
-
-    # Match by key name pattern
-    if "key_pattern" in pattern:
-        key_pattern = pattern["key_pattern"].lower()
-        if key_pattern != "all" and key_pattern not in key_lower:
-            return False
-
-    # Match by tensor dimension
-    if "dim" in pattern and tensor_info["dim"] != pattern["dim"]:
-        return False
-
-    # Match by tensor type
-    if "type" in pattern and tensor_info["type"] != pattern["type"]:
-        return False
-
-    # Match by minimum tensor size
-    if "min_size" in pattern and tensor_info["numel"] < pattern["min_size"]:
-        return False
-
-    # Match by shape constraints
-    if "shape_contains" in pattern:
-        shape_contains = pattern["shape_contains"]
-        if not any(shape_contains == dim for dim in tensor_info["shape"]):
-            return False
-
-    return True
-
 def load_model_files(model_paths, model_format="safetensors", progress_callback=None):
     """
     Load model weights from one or more files, supporting sharded safetensors and other formats.
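
The deleted `low_rank_decomposition` helper factored each 2D weight with a truncated SVD so that `weight ≈ W_B @ W_A`, splitting `sqrt(S)` across both factors. A minimal standalone sketch of that contract on a toy tensor (illustrative only, not code from app.py):

```python
import torch

weight = torch.randn(256, 512)  # toy 2D weight, [out_features, in_features]
rank = 64
U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
S_sqrt = torch.sqrt(S[:rank])
W_A = Vh[:rank, :] * S_sqrt.unsqueeze(1)  # [rank, in_features]
W_B = U[:, :rank] * S_sqrt.unsqueeze(0)   # [out_features, rank]
approx = W_B @ W_A
rel_err = torch.norm(weight - approx) / torch.norm(weight)
print(f"relative error at rank {rank}: {rel_err:.4f}")  # shrinks as rank grows
```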
@@ -279,10 +145,10 @@ def extract_base_name_from_sharded_files(model_paths):
 
     return base_name
 
-def convert_model_to_fp8_with_recovery(model_paths, output_dir, fp8_format, recovery_rules,
-                                       model_format="safetensors", progress=gr.Progress()):
-    """Convert model to FP8 with customizable per-tensor recovery strategies."""
-    progress(0.05, desc=f"Starting FP8 conversion with precision recovery for {model_format}...")
     try:
         metadata = read_model_metadata(model_paths, model_format)
         progress(0.1, desc="Loaded metadata.")
@@ -300,121 +166,63 @@ def convert_model_to_fp8_with_recovery(model_paths, output_dir, fp8_format, reco
 
         # Initialize outputs
         sd_fp8 = {}
-        recovery_weights = {}
-        stats = {
-            "total_layers": len(state_dict),
-            "processed_layers": 0,
-            "skipped_layers": [],
-            "recovery_counts": {"lora": 0, "diff": 0},
-            "rule_matches": {i: 0 for i in range(len(recovery_rules))}
         }
 
         # Process each tensor
         total = len(state_dict)
         for i, key in enumerate(state_dict):
-            progress(0.3 + 0.5 * (i / total), desc=f"Processing {i+1}/{total}: {key.split('.')[-1]}")
             weight = state_dict[key]
-            tensor_info = get_tensor_info(weight)
 
             if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
                 fp8_weight = weight.to(fp8_dtype)
                 sd_fp8[key] = fp8_weight
             else:
                 sd_fp8[key] = weight
-                stats["skipped_layers"].append(f"{key}: non-float dtype")
-                continue
-
-            # Find matching rule for this tensor
-            recovery_applied = False
-            matched_rule_index = -1
-
-            for rule_idx, rule in enumerate(recovery_rules):
-                if matches_pattern(key, tensor_info, rule):
-                    matched_rule_index = rule_idx
-                    recovery_method = rule["method"]
-
-                    try:
-                        if recovery_method == "lora" and weight.ndim == 2:
-                            # LoRA recovery for 2D tensors only
-                            rank = rule.get("rank", 64)
-                            # Adjust rank for smaller matrices
-                            adjusted_rank = min(rank, min(weight.shape) // 2)
-                            if adjusted_rank >= 4:
-                                A, B = low_rank_decomposition(weight, rank=adjusted_rank)
-                                if A is not None and B is not None:
-                                    recovery_weights[f"lora_A.{key}"] = A
-                                    recovery_weights[f"lora_B.{key}"] = B
-                                    stats["processed_layers"] += 1
-                                    stats["recovery_counts"]["lora"] += 1
-                                    stats["rule_matches"][rule_idx] += 1
-                                    recovery_applied = True
-                                    break
-
-                        elif recovery_method == "diff":
-                            # Difference/correction recovery for any tensor type
-                            corr = extract_correction_factors(weight, fp8_weight)
-                            if corr is not None:
-                                recovery_weights[f"diff.{key}"] = corr
-                                stats["processed_layers"] += 1
-                                stats["recovery_counts"]["diff"] += 1
-                                stats["rule_matches"][rule_idx] += 1
-                                recovery_applied = True
-                                break
-
-                        # If method is "none" or recovery failed, continue to next rule
-                        if recovery_method == "none":
-                            break
-
-                    except Exception as e:
-                        stats["skipped_layers"].append(f"{key}: error with rule {rule_idx} - {str(e)}")
-
-            if not recovery_applied:
-                reason = "no matching rule" if matched_rule_index == -1 else f"recovery failed with rule {matched_rule_index}"
-                stats["skipped_layers"].append(f"{key}: {reason}")
 
         # Extract base name for output files
         base_name = extract_base_name_from_sharded_files(model_paths)
 
         # Save FP8 model
         fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
-        save_file(sd_fp8, fp8_path, metadata={"format": model_format, "fp8_format": fp8_format, **metadata})
-
-        # Save recovery weights if any were generated
-        recovery_path = None
-        if recovery_weights:
-            recovery_path = os.path.join(output_dir, f"{base_name}-recovery.safetensors")
-            recovery_metadata = {
-                "format": model_format,
-                "fp8_format": fp8_format,
-                "recovery_rules": json.dumps(recovery_rules),
-                "stats": json.dumps(stats)
-            }
-            save_file(recovery_weights, recovery_path, metadata=recovery_metadata)
 
-        progress(0.9, desc="Saved FP8 and recovery files.")
 
         # Generate stats message
-        stats_msg = f"FP8 ({fp8_format}) conversion complete with precision recovery:\n"
-        stats_msg += f"- Total layers: {stats['total_layers']}\n"
-        stats_msg += f"- Layers with recovery: {stats['processed_layers']}\n"
-        stats_msg += f"  - LoRA recovery: {stats['recovery_counts']['lora']}\n"
-        stats_msg += f"  - Difference recovery: {stats['recovery_counts']['diff']}\n"
-
-        # Show rule effectiveness
-        stats_msg += "\nRule effectiveness:\n"
-        for rule_idx, rule in enumerate(recovery_rules):
-            matches = stats["rule_matches"][rule_idx]
-            if matches > 0:
-                method = rule["method"]
-                pattern = rule.get("key_pattern", "no pattern")
-                rank_info = f" (rank {rule.get('rank', 'N/A')})" if method == "lora" else ""
-                stats_msg += f"- Rule {rule_idx}: {matches} layers matched pattern '{pattern}' with {method}{rank_info}\n"
-
-        if not recovery_weights:
-            stats_msg += "\n⚠️ No recovery weights were generated. All layers use pure FP8."
-
-        progress(1.0, desc="✅ FP8 conversion with precision recovery complete!")
-        return True, stats_msg, stats, fp8_path, recovery_path
 
     except Exception as e:
         traceback.print_exc()
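
The loop above resolved one recovery method per tensor with first-match-wins semantics, which is why the rule presets below always end in a catch-all. A condensed, self-contained sketch of that dispatch (simplified from the deleted `matches_pattern`; the names and example keys are illustrative):

```python
def pick_rule(key, ndim, numel, rules):
    """Return the first matching rule, mirroring the deleted matching checks."""
    for rule in rules:
        pat = rule.get("key_pattern", "all").lower()
        if pat != "all" and pat not in key.lower():
            continue
        if "dim" in rule and rule["dim"] != ndim:
            continue
        if "min_size" in rule and numel < rule["min_size"]:
            continue
        return rule
    return None

rules = [
    {"key_pattern": "attn", "dim": 2, "min_size": 10000, "method": "lora", "rank": 128},
    {"key_pattern": "all", "method": "none"},
]
print(pick_rule("unet.attn1.to_q.weight", 2, 320 * 320, rules)["method"])  # -> "lora"
print(pick_rule("unet.conv_in.weight", 4, 4 * 320 * 9, rules)["method"])   # -> "none"
```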
@@ -625,167 +433,12 @@ def upload_to_target(target_type, new_repo_id, output_dir, fp8_format, hf_token=
     else:
         raise ValueError("Unknown target")
 
-def generate_default_rules(architecture="auto"):
-    """Generate default recovery rules based on architecture."""
-    if architecture == "vae":
-        return """[
-    {
-        "key_pattern": "vae",
-        "dim": 4,
-        "method": "diff"
-    },
-    {
-        "key_pattern": "encoder",
-        "dim": 4,
-        "method": "diff"
-    },
-    {
-        "key_pattern": "decoder",
-        "dim": 4,
-        "method": "diff"
-    },
-    {
-        "key_pattern": "all",
-        "method": "none"
-    }
-]"""
-    elif architecture == "text_encoder":
-        return """[
-    {
-        "key_pattern": "text",
-        "dim": 2,
-        "min_size": 10000,
-        "method": "lora",
-        "rank": 64
-    },
-    {
-        "key_pattern": "emb",
-        "dim": 2,
-        "min_size": 10000,
-        "method": "lora",
-        "rank": 64
-    },
-    {
-        "key_pattern": "attn",
-        "dim": 2,
-        "min_size": 10000,
-        "method": "lora",
-        "rank": 128
-    },
-    {
-        "key_pattern": "all",
-        "method": "none"
-    }
-]"""
-    elif architecture == "unet_transformer":
-        return """[
-    {
-        "key_pattern": "attn",
-        "dim": 2,
-        "min_size": 10000,
-        "method": "lora",
-        "rank": 128
-    },
-    {
-        "key_pattern": "transformer",
-        "dim": 2,
-        "min_size": 10000,
-        "method": "lora",
-        "rank": 96
-    },
-    {
-        "key_pattern": "all",
-        "method": "none"
-    }
-]"""
-    elif architecture == "unet_conv":
-        return """[
-    {
-        "key_pattern": "conv",
-        "dim": 4,
-        "method": "diff"
-    },
-    {
-        "key_pattern": "resnet",
-        "dim": 4,
-        "method": "diff"
-    },
-    {
-        "key_pattern": "down",
-        "dim": 4,
-        "method": "diff"
-    },
-    {
-        "key_pattern": "up",
-        "dim": 4,
-        "method": "diff"
-    },
-    {
-        "key_pattern": "all",
-        "method": "none"
-    }
-]"""
-    else:  # "all" or "auto"
-        return """[
-    {
-        "key_pattern": "vae",
-        "dim": 4,
-        "method": "diff"
-    },
-    {
-        "key_pattern": "encoder",
-        "dim": 4,
-        "method": "diff"
-    },
-    {
-        "key_pattern": "decoder",
-        "dim": 4,
-        "method": "diff"
-    },
-    {
-        "key_pattern": "text",
-        "dim": 2,
-        "min_size": 10000,
-        "method": "lora",
-        "rank": 64
-    },
-    {
-        "key_pattern": "emb",
-        "dim": 2,
-        "min_size": 10000,
-        "method": "lora",
-        "rank": 64
-    },
-    {
-        "key_pattern": "attn",
-        "dim": 2,
-        "min_size": 10000,
-        "method": "lora",
-        "rank": 128
-    },
-    {
-        "key_pattern": "conv",
-        "dim": 4,
-        "method": "diff"
-    },
-    {
-        "key_pattern": "resnet",
-        "dim": 4,
-        "method": "diff"
-    },
-    {
-        "key_pattern": "all",
-        "method": "none"
-    }
-]"""
-
 def process_and_upload_fp8(
     source_type,
     repo_url,
     filename_pattern,
     model_format,
     fp8_format,
-    recovery_rules_json,
     target_type,
     new_repo_id,
     hf_token,
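
A detail worth noting when reading those presets: they are JSON *strings* (they populated a Gradio Textbox), and the handler parsed them back into rule dicts with `json.loads`, as the next hunk shows. A quick sketch of that round trip (illustrative values):

```python
import json

rules_text = """[
    {"key_pattern": "vae", "dim": 4, "method": "diff"},
    {"key_pattern": "all", "method": "none"}
]"""
rules = json.loads(rules_text)  # back to a list of rule dicts
assert rules[0]["method"] == "diff" and rules[-1]["key_pattern"] == "all"
```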
@@ -800,20 +453,6 @@ def process_and_upload_fp8(
     if target_type == "huggingface" and not hf_token:
         return None, "❌ Hugging Face token required for target.", "", ""
 
-    # Parse recovery rules
-    try:
-        recovery_rules = json.loads(recovery_rules_json)
-    except json.JSONDecodeError:
-        return None, "❌ Invalid recovery rules JSON.", "", ""
-
-    # Validate rules
-    valid_methods = ["none", "lora", "diff"]
-    for rule in recovery_rules:
-        if "method" not in rule or rule["method"] not in valid_methods:
-            return None, f"❌ Invalid method in rule. Use 'none', 'lora', or 'diff'", "", ""
-        if rule["method"] == "lora" and "rank" not in rule:
-            return None, "❌ LoRA method requires 'rank' parameter", "", ""
-
     temp_dir = None
     output_dir = tempfile.mkdtemp()
     try:
@@ -822,9 +461,9 @@
             source_type, repo_url, filename_pattern, model_format, hf_token, progress
         )
 
-        progress(0.8, desc="Converting to FP8 with precision recovery...")
-        success, msg, stats, fp8_path, recovery_path = convert_model_to_fp8_with_recovery(
-            model_paths, output_dir, fp8_format, recovery_rules, model_format, progress
         )
 
         if not success:
@@ -845,68 +484,40 @@ def process_and_upload_fp8(
             original_filename += f" matching '{filename_pattern}'"
 
         fp8_filename = os.path.basename(fp8_path)
-        recovery_filename = os.path.basename(recovery_path) if recovery_path else ""
 
         readme = f"""---
 library_name: diffusers
 tags:
 - fp8
 - safetensors
-- precision-recovery
-- mixed-method
 - converted-by-gradio
 ---
-# FP8 Model with Per-Tensor Precision Recovery
 - **Source**: `{repo_url}`
 - **Original File(s)**: `{original_filename}`
 - **Original Format**: `{model_format}`
 - **FP8 Format**: `{fp8_format.upper()}`
 - **FP8 File**: `{fp8_filename}`
-- **Recovery File**: `{recovery_filename if recovery_filename else "None"}`
-## Recovery Rules Used
-```json
-{json.dumps(recovery_rules, indent=2)}
-```
-## Usage (Inference)
 ```python
 from safetensors.torch import load_file
 import torch
 
 # Load FP8 model
 fp8_state = load_file("{fp8_filename}")
-# Load recovery weights if available
-recovery_state = load_file("{recovery_filename}") if "{recovery_filename}" and os.path.exists("{recovery_filename}") else {{}}
-# Reconstruct high-precision weights
-reconstructed = {{}}
-for key in fp8_state:
-    fp8_weight = fp8_state[key].to(torch.float32)  # Convert to float32 for computation
-
-    # Apply LoRA recovery if available
-    lora_a_key = f"lora_A.{{key}}"
-    lora_b_key = f"lora_B.{{key}}"
-    if lora_a_key in recovery_state and lora_b_key in recovery_state:
-        A = recovery_state[lora_a_key].to(torch.float32)
-        B = recovery_state[lora_b_key].to(torch.float32)
-        # Reconstruct the low-rank approximation
-        lora_weight = B @ A
-        fp8_weight = fp8_weight + lora_weight
-
-    # Apply difference recovery if available
-    diff_key = f"diff.{{key}}"
-    if diff_key in recovery_state:
-        diff = recovery_state[diff_key].to(torch.float32)
-        fp8_weight = fp8_weight + diff
-
-    reconstructed[key] = fp8_weight
-# Use reconstructed weights in your model
-model.load_state_dict(reconstructed)
 ```
-> **Note**: For best results, use the same recovery configuration during inference as was used during extraction.
 
 > Requires PyTorch ≥ 2.1 for FP8 support.
 
 ## Statistics
-- **Total layers**: {stats['total_layers']}
-- **Layers with recovery**: {stats['processed_layers']}
-  - LoRA recovery: {stats['recovery_counts']['lora']}
-  - Difference recovery: {stats['recovery_counts']['diff']}
 """
 
         with open(os.path.join(output_dir, "README.md"), "w") as f:
@@ -924,23 +535,17 @@ model.load_state_dict(reconstructed)
         progress(1.0, desc="✅ Done!")
 
         # Generate result HTML
-        recovery_links = []
-        if recovery_path:
-            recovery_links.append(f"- **Recovery weights**: `{recovery_filename}`")
-
         result_html = f"""
 ✅ Success!
 Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a>
-Includes:
 - FP8 model: `{fp8_filename}`
-- {chr(10).join(recovery_links)}
 """
 
-        recovery_details = f"Recovery file: {recovery_filename}" if recovery_filename else "No recovery weights generated"
         return (gr.HTML(result_html),
-                "✅ FP8 conversion with precision recovery successful!",
                 msg,
-                recovery_details)
 
     except Exception as e:
         traceback.print_exc()
@@ -951,9 +556,9 @@ Includes:
         shutil.rmtree(temp_dir, ignore_errors=True)
         shutil.rmtree(output_dir, ignore_errors=True)
 
-with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery") as demo:
-    gr.Markdown("# 🔄 Advanced FP8 Quantizer with Per-Tensor Precision Recovery")
-    gr.Markdown("Convert model files (safetensors, pth, ckpt) → **FP8** + **customizable precision recovery**. Supports any number of sharded files.")
 
     with gr.Row():
         with gr.Column():
@@ -975,70 +580,6 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery"
     with gr.Accordion("FP8 Settings", open=True):
         fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
 
-    with gr.Accordion("Per-Tensor Recovery Rules", open=True):
-        gr.Markdown("""
-        ### Configure recovery strategy for each tensor pattern
-
-        Format: JSON array of rule objects:
-        ```json
-        [
-            {
-                "key_pattern": "vae",
-                "dim": 4,
-                "method": "diff"
-            },
-            {
-                "key_pattern": "attn",
-                "dim": 2,
-                "min_size": 10000,
-                "method": "lora",
-                "rank": 64
-            },
-            {
-                "key_pattern": "all",
-                "method": "none"
-            }
-        ]
-        ```
-
-        ### Rule Fields (all optional except "method"):
-        - `key_pattern`: Substring to match in weight keys (case-insensitive). Use "all" to match everything.
-        - `dim`: Tensor dimension (e.g., 2 for linear layers, 4 for convolutions)
-        - `type`: Tensor type ("conv", "linear", "bias", "input_proj", "output_proj")
-        - `min_size`: Minimum number of elements in tensor
-        - `shape_contains`: Specific dimension size that must be present in shape
-        - `method`: "none" (pure FP8), "lora" (low-rank adaptation), or "diff" (difference/correction)
-        - `rank`: Required for "lora" method (higher = better quality but larger file)
-
-        **Rules are applied in order** - first match wins. Always end with a catch-all rule.
-        """)
-
-    recovery_rules_json = gr.Textbox(
-        value=generate_default_rules("all"),
-        lines=15,
-        label="Recovery Rules (JSON)",
-        interactive=True
-    )
-
-    architecture_preset = gr.Dropdown(
-        choices=[
-            ("Auto-detect architecture", "auto"),
-            ("VAE (Difference method)", "vae"),
-            ("Text Encoder (LoRA)", "text_encoder"),
-            ("UNet Transformers (LoRA)", "unet_transformer"),
-            ("UNet Convolutions (Difference)", "unet_conv"),
-            ("All Components (Mixed)", "all")
-        ],
-        value="auto",
-        label="Architecture Preset"
-    )
-
-    architecture_preset.change(
-        fn=generate_default_rules,
-        inputs=architecture_preset,
-        outputs=recovery_rules_json
-    )
-
     with gr.Accordion("Authentication", open=False):
         hf_token = gr.Textbox(label="Hugging Face Token", type="password")
         modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password", visible=MODELScope_AVAILABLE)
@@ -1050,7 +591,7 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery"
 
     status_output = gr.Markdown()
     detailed_log = gr.Textbox(label="Processing Log", interactive=False, lines=10)
-    recovery_summary = gr.Textbox(label="Recovery Files Generated", interactive=False, lines=3)
 
     convert_btn = gr.Button("🚀 Convert & Upload", variant="primary")
     repo_link_output = gr.HTML()
@@ -1063,7 +604,6 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery"
             filename_pattern,
             model_format,
             fp8_format,
-            recovery_rules_json,
             target_type,
             new_repo_id,
             hf_token,
@@ -1082,7 +622,6 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery"
                 "auto",
                 "safetensors",
                 "e4m3fn",
-                generate_default_rules("vae"),
                 "huggingface"
             ],
             [
@@ -1091,7 +630,6 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery"
                 "auto",
                 "safetensors",
                 "e5m2",
-                generate_default_rules("text_encoder"),
                 "huggingface"
             ],
            [
@@ -1100,7 +638,6 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery"
                 "auto",
                 "safetensors",
                 "e5m2",
-                generate_default_rules("unet_transformer"),
                 "huggingface"
             ],
             [
@@ -1109,7 +646,6 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery"
                 "model-*.safetensors",
                 "safetensors",
                 "e5m2",
-                generate_default_rules("all"),
                 "huggingface"
             ],
             [
@@ -1118,70 +654,49 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery"
                 "sd-v1-4.ckpt",
                 "ckpt",
                 "e5m2",
-                generate_default_rules("all"),
                 "huggingface"
             ]
         ],
-        inputs=[source_type, repo_url, filename_pattern, model_format, fp8_format, recovery_rules_json, target_type],
         label="Example Conversions",
         cache_examples=False
     )
 
     gr.Markdown("""
-## 💡 Tensor Pattern Matching Guide
-
-This tool uses **advanced tensor pattern matching** to determine which recovery method to apply to each layer:
 
-### **Key Patterns**
-- Match by substring in weight key name
-- Case-insensitive matching
-- Special keyword "all" matches everything
 
-### **Tensor Properties**
-- **Dimension (dim)**: Use `dim: 2` for linear layers, `dim: 4` for convolutions
-- **Type**: Automatic classification based on shape:
-  - `conv`: 4D tensors with equal spatial dimensions
-  - `linear`: 2D tensors without extreme aspect ratio
-  - `input_proj`: 2D tensors with much larger second dimension
-  - `output_proj`: 2D tensors with much larger first dimension
-  - `bias`: 1D tensors
 
-### **Size Constraints**
-- **min_size**: Only apply to tensors with at least N elements
-- **shape_contains**: Match tensors containing a specific dimension size
-
-### **Rule Processing**
-- Rules are evaluated **in order**
-- First matching rule wins
-- Always include a catch-all rule at the end
-
-> **Pro Tip for VAE**: Use `"dim": 4` combined with `"key_pattern": "vae"` to reliably target VAE convolutional layers with difference recovery.
-
-## 📁 File Format Support
-
-This tool supports multiple model formats:
-
-- **Safetensors**: Modern, secure format for storing tensors. Supports sharded files (e.g., `model-00001-of-00005.safetensors`).
-- **PTH/PT**: PyTorch checkpoint files. Can contain state dicts or full model objects.
-- **CKPT**: Checkpoint files, commonly used for stable diffusion models.
-
-### Shard Support:
 - **Unlimited Shards**: Supports any number of sharded files (2, 5, 10, 20+)
 - **Auto-Detection**: Automatically finds all shards when using "auto" pattern
-- **Parallel Downloads**: Downloads multiple shards simultaneously for faster processing
-- **Memory Efficient**: Processes shards one at a time to manage memory usage
-- **Progress Tracking**: Shows detailed progress for each shard download and processing
-
-### Filename Patterns:
-- **Auto-detection**: Use "auto" to automatically find all sharded safetensors files
-- **Wildcard patterns**: Use `model-*.safetensors` to match sharded files
-- **Specific file**: Use exact filename for single files
-
-For models with many shards (e.g., 5+ files), the tool will:
-1. Automatically detect all shards
-2. Download them in parallel (up to 4 simultaneous downloads)
-3. Load them sequentially to manage memory
-4. Merge them into a single FP8 model
     """)
 
-demo.launch()
 
 
     return base_name
 
+def convert_model_to_fp8(model_paths, output_dir, fp8_format,
+                         model_format="safetensors", progress=gr.Progress()):
+    """Simple and fast FP8 conversion without recovery strategies."""
+    progress(0.05, desc=f"Starting FP8 conversion for {model_format}...")
     try:
         metadata = read_model_metadata(model_paths, model_format)
         progress(0.1, desc="Loaded metadata.")
 
         # Initialize outputs
         sd_fp8 = {}
+        conversion_stats = {
+            "total_tensors": len(state_dict),
+            "converted_tensors": 0,
+            "skipped_tensors": 0,
+            "skipped_reasons": []
         }
 
         # Process each tensor
         total = len(state_dict)
         for i, key in enumerate(state_dict):
+            if i % 100 == 0:  # Update progress every 100 tensors for speed
+                progress(0.3 + 0.6 * (i / total), desc=f"Converting {i}/{total} tensors...")
+
             weight = state_dict[key]
 
+            # Convert only float tensors to FP8
             if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
                 fp8_weight = weight.to(fp8_dtype)
                 sd_fp8[key] = fp8_weight
+                conversion_stats["converted_tensors"] += 1
             else:
+                # Keep non-float tensors as-is (e.g., ints, bools)
                 sd_fp8[key] = weight
+                conversion_stats["skipped_tensors"] += 1
+                conversion_stats["skipped_reasons"].append(f"{key}: {weight.dtype}")
 
         # Extract base name for output files
         base_name = extract_base_name_from_sharded_files(model_paths)
 
         # Save FP8 model
         fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
+        save_file(sd_fp8, fp8_path, metadata={
+            "format": model_format,
+            "fp8_format": fp8_format,
+            "original_files": str(len(model_paths)),
+            "conversion_stats": json.dumps(conversion_stats),
+            **metadata
+        })
 
+        progress(0.95, desc="Saved FP8 file.")
 
         # Generate stats message
+        stats_msg = f"✅ FP8 ({fp8_format}) conversion complete!\n"
+        stats_msg += f"- Total tensors: {conversion_stats['total_tensors']}\n"
+        stats_msg += f"- Converted to FP8: {conversion_stats['converted_tensors']}\n"
+        stats_msg += f"- Skipped (non-float): {conversion_stats['skipped_tensors']}\n"
+        stats_msg += f"- Output file: {os.path.basename(fp8_path)}\n"
+
+        if conversion_stats["skipped_tensors"] > 0:
+            stats_msg += "\n⚠️ Some tensors were skipped (non-float types):\n"
+            for i, reason in enumerate(conversion_stats["skipped_reasons"][:5]):  # Show first 5
+                stats_msg += f"  - {reason}\n"
+            if len(conversion_stats["skipped_reasons"]) > 5:
+                stats_msg += f"  - ... and {len(conversion_stats['skipped_reasons']) - 5} more\n"
+
+        progress(1.0, desc="✅ FP8 conversion complete!")
+        return True, stats_msg, conversion_stats, fp8_path, None
 
     except Exception as e:
         traceback.print_exc()
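
The new function relies on `fp8_dtype`, which is resolved outside this excerpt. A plausible sketch of that mapping, assuming PyTorch ≥ 2.1 and that app.py derives it from the `fp8_format` radio value (the helper name here is hypothetical):

```python
import torch

def resolve_fp8_dtype(fp8_format):
    """Map the UI's format string to a torch float8 dtype (PyTorch >= 2.1)."""
    if fp8_format == "e4m3fn":
        return torch.float8_e4m3fn  # more mantissa bits, narrower range
    if fp8_format == "e5m2":
        return torch.float8_e5m2    # fewer mantissa bits, wider range
    raise ValueError(f"Unknown FP8 format: {fp8_format}")

w = torch.randn(8, 8, dtype=torch.float16)
w_fp8 = w.to(resolve_fp8_dtype("e5m2"))  # the same cast the loop above performs
print(w_fp8.dtype)  # torch.float8_e5m2
```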
 
     else:
         raise ValueError("Unknown target")
 
 def process_and_upload_fp8(
     source_type,
     repo_url,
     filename_pattern,
     model_format,
     fp8_format,
     target_type,
     new_repo_id,
     hf_token,
 
     if target_type == "huggingface" and not hf_token:
         return None, "❌ Hugging Face token required for target.", "", ""
 
     temp_dir = None
     output_dir = tempfile.mkdtemp()
     try:
 
             source_type, repo_url, filename_pattern, model_format, hf_token, progress
         )
 
+        progress(0.8, desc="Converting to FP8...")
+        success, msg, stats, fp8_path, _ = convert_model_to_fp8(
+            model_paths, output_dir, fp8_format, model_format, progress
         )
 
         if not success:
             original_filename += f" matching '{filename_pattern}'"
 
         fp8_filename = os.path.basename(fp8_path)
 
         readme = f"""---
 library_name: diffusers
 tags:
 - fp8
 - safetensors
 - converted-by-gradio
 ---
+# FP8 Model Conversion
 - **Source**: `{repo_url}`
 - **Original File(s)**: `{original_filename}`
 - **Original Format**: `{model_format}`
 - **FP8 Format**: `{fp8_format.upper()}`
 - **FP8 File**: `{fp8_filename}`
+
+## Usage
 ```python
 from safetensors.torch import load_file
 import torch
+
 # Load FP8 model
 fp8_state = load_file("{fp8_filename}")
+
+# Convert tensors back to float32 for computation (auto-converted by PyTorch)
+model.load_state_dict(fp8_state)
 ```
+
+> **Note**: FP8 tensors are automatically converted to float32 when loaded in PyTorch.
 > Requires PyTorch ≥ 2.1 for FP8 support.
+
 ## Statistics
+- **Total tensors**: {stats['total_tensors']}
+- **Converted to FP8**: {stats['converted_tensors']}
+- **Skipped (non-float)**: {stats['skipped_tensors']}
 """
 
         with open(os.path.join(output_dir, "README.md"), "w") as f:
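
One caveat on the README snippet in this hunk: `load_file` returns tensors still in float8; the upcast only happens when `load_state_dict` copies them into higher-precision parameters. A sketch of doing the cast explicitly instead (the filename is hypothetical):

```python
import torch
from safetensors.torch import load_file

fp8_state = load_file("model-fp8-e5m2.safetensors")  # hypothetical output name
fp32_state = {
    k: v.to(torch.float32) if v.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) else v
    for k, v in fp8_state.items()
}
# model.load_state_dict(fp32_state)
```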
 
         progress(1.0, desc="✅ Done!")
 
         # Generate result HTML
         result_html = f"""
 ✅ Success!
 Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a>
 - FP8 model: `{fp8_filename}`
+- Converted {stats['converted_tensors']} tensors to {fp8_format.upper()}
 """
 
         return (gr.HTML(result_html),
+                "✅ FP8 conversion successful!",
                 msg,
+                "")
 
     except Exception as e:
         traceback.print_exc()
 
         shutil.rmtree(temp_dir, ignore_errors=True)
         shutil.rmtree(output_dir, ignore_errors=True)
 
+with gr.Blocks(title="Fast FP8 Model Converter") as demo:
+    gr.Markdown("# ⚡ Fast FP8 Model Converter")
+    gr.Markdown("Convert model files (safetensors, pth, ckpt) → **FP8**. Supports sharded files with auto-discovery. Simple and fast!")
 
     with gr.Row():
         with gr.Column():
 
     with gr.Accordion("FP8 Settings", open=True):
         fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
 
     with gr.Accordion("Authentication", open=False):
         hf_token = gr.Textbox(label="Hugging Face Token", type="password")
         modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password", visible=MODELScope_AVAILABLE)
 
     status_output = gr.Markdown()
     detailed_log = gr.Textbox(label="Processing Log", interactive=False, lines=10)
+    recovery_summary = gr.Textbox(label="Additional Info", interactive=False, lines=3)
 
     convert_btn = gr.Button("🚀 Convert & Upload", variant="primary")
     repo_link_output = gr.HTML()
 
             filename_pattern,
             model_format,
             fp8_format,
             target_type,
             new_repo_id,
             hf_token,
 
                 "auto",
                 "safetensors",
                 "e4m3fn",
                 "huggingface"
             ],
             [
                 "auto",
                 "safetensors",
                 "e5m2",
                 "huggingface"
             ],
             [
                 "auto",
                 "safetensors",
                 "e5m2",
                 "huggingface"
             ],
             [
                 "model-*.safetensors",
                 "safetensors",
                 "e5m2",
                 "huggingface"
             ],
             [
                 "sd-v1-4.ckpt",
                 "ckpt",
                 "e5m2",
                 "huggingface"
             ]
         ],
+        inputs=[source_type, repo_url, filename_pattern, model_format, fp8_format, target_type],
         label="Example Conversions",
         cache_examples=False
     )
 
     gr.Markdown("""
+## 📁 Fast FP8 Conversion Tool
+
+This tool provides **fast and simple FP8 conversion** for various model formats:
+
+### **Supported Formats:**
+- **Safetensors**: Modern, secure format. Supports sharded files (e.g., `model-00001-of-00005.safetensors`)
+- **PTH/PT**: PyTorch checkpoint files
+- **CKPT**: Checkpoint files (commonly used for stable diffusion models)
+
+### **Shard Support:**
 - **Unlimited Shards**: Supports any number of sharded files (2, 5, 10, 20+)
 - **Auto-Detection**: Automatically finds all shards when using "auto" pattern
+- **Parallel Downloads**: Downloads multiple shards simultaneously (up to 4 at once)
+- **Memory Efficient**: Processes files efficiently to manage memory
+
+### **Performance Features:**
+- **Fast Conversion**: Simple dtype conversion without complex recovery strategies
+- **Batch Processing**: Processes tensors in batches for better performance
+- **Progress Tracking**: Shows detailed progress for each step
+
+### **How It Works:**
+1. **Discovery**: Automatically detects sharded files or uses your specified pattern
+2. **Download**: Downloads files in parallel for maximum speed
+3. **Conversion**: Converts float tensors to FP8, leaves other types unchanged
+4. **Upload**: Uploads the converted model to your target repository
+
+### **Usage Tips:**
+- Use "auto" pattern to automatically detect all sharded safetensors files
+- Use `model-*.safetensors` to match specific shard patterns
+- For single files, just enter the filename (e.g., `model.safetensors`)
+- FP8 conversion reduces model size by ~4x compared to FP32
+- FP8 tensors are automatically converted to float32 when loaded in PyTorch
+
+> **Note**: This is a simple conversion tool. For precision recovery options, use the advanced version.
     """)
 
+demo.launch()
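
Since the fast path drops the recovery machinery entirely, the quantization error is whatever the FP8 round trip costs. A short sketch for estimating that per format on a random tensor (requires PyTorch ≥ 2.1; illustrative, not part of the commit):

```python
import torch

w = torch.randn(1024, 1024, dtype=torch.float32)
for fmt, dtype in [("e4m3fn", torch.float8_e4m3fn), ("e5m2", torch.float8_e5m2)]:
    roundtrip = w.to(dtype).to(torch.float32)       # quantize, then upcast back
    rel_err = torch.norm(w - roundtrip) / torch.norm(w)
    print(f"{fmt}: relative error = {rel_err.item():.3%}")
```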