manoskary commited on
Commit
c250d8c
·
1 Parent(s): 139449e

Add .gitignore and implement model weight loading with quantization support

Browse files
Files changed (2) hide show
  1. .gitignore +8 -0
  2. app.py +34 -3
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ env/
2
+ .cache/
3
+ squashfs-root/
4
+ __pycache__/
5
+ *.pyc
6
+ *.pyo
7
+ *.pyd
8
+ *.AppImage
app.py CHANGED
@@ -21,7 +21,7 @@ import traceback
21
  import os
22
  from smolagents import Tool
23
  from typing import Optional
24
- from weavemuse.models.notagen.inference import inference_patch, load_model_weights
25
 
26
 
27
  logging.basicConfig(level=logging.INFO)
@@ -37,6 +37,34 @@ device = "cuda"
37
  logger.info(f"Preparing NotaGen tool on device: {device}")
38
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  class SimpleNotaGenTool(Tool):
41
  """
42
  Simple tool for symbolic music generation using NotaGen model.
@@ -120,12 +148,15 @@ class SimpleNotaGenTool(Tool):
120
 
121
  Returns:
122
  Path to generated ABC file or error message
123
- """
 
 
 
124
  logger.info(f"Generating music: {period}-{composer}-{instrumentation}")
125
 
126
  # Create prompt for NotaGen
127
  prompt = f"{period}-{composer}-{instrumentation}"
128
-
129
  # Use the inference function
130
  inference_fn = inference_patch
131
  if inference_fn is None:
 
21
  import os
22
  from smolagents import Tool
23
  from typing import Optional
24
+ from weavemuse.models.notagen.inference import inference_patch, download_model_weights, model
25
 
26
 
27
  logging.basicConfig(level=logging.INFO)
 
37
  logger.info(f"Preparing NotaGen tool on device: {device}")
38
 
39
 
40
+ def load_model_weights(model_id=None):
41
+ """Load model weights with intelligent quantization support."""
42
+ global model
43
+
44
+ # Fall back to original weights
45
+ try:
46
+ logger.info("Loading original full-precision model...")
47
+ original_path = download_model_weights(repo_id="ElectricAlexis/NotaGen")
48
+
49
+ checkpoint = torch.load(original_path, map_location=device, weights_only=False)
50
+ if 'model' in checkpoint:
51
+ state_dict = checkpoint['model']
52
+ else:
53
+ state_dict = checkpoint
54
+
55
+ model.load_state_dict(state_dict)
56
+ model.eval()
57
+ logger.info("✅ Original model loaded successfully!")
58
+ return True
59
+
60
+ except Exception as e:
61
+ logger.error(f"Failed to load model weights: {e}")
62
+ import traceback
63
+ logger.error(traceback.format_exc())
64
+ raise
65
+
66
+
67
+
68
  class SimpleNotaGenTool(Tool):
69
  """
70
  Simple tool for symbolic music generation using NotaGen model.
 
148
 
149
  Returns:
150
  Path to generated ABC file or error message
151
+ """
152
+ global model
153
+ global device
154
+
155
  logger.info(f"Generating music: {period}-{composer}-{instrumentation}")
156
 
157
  # Create prompt for NotaGen
158
  prompt = f"{period}-{composer}-{instrumentation}"
159
+ model = model.to(device)
160
  # Use the inference function
161
  inference_fn = inference_patch
162
  if inference_fn is None: