Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,7 +7,7 @@ import boto3
|
|
| 7 |
import uvicorn
|
| 8 |
import soundfile as sf
|
| 9 |
import imageio
|
| 10 |
-
from typing import Dict
|
| 11 |
|
| 12 |
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
|
| 13 |
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
|
|
@@ -42,7 +42,7 @@ class GenerateRequest(BaseModel):
|
|
| 42 |
do_sample: bool = True
|
| 43 |
stop_sequences: list[str] = []
|
| 44 |
no_repeat_ngram_size: int = 2
|
| 45 |
-
continuation_id: str = None
|
| 46 |
|
| 47 |
@field_validator("model_name")
|
| 48 |
def model_name_cannot_be_empty(cls, v):
|
|
@@ -115,8 +115,10 @@ async def generate(request: GenerateRequest, model_resources: tuple = Depends(ge
|
|
| 115 |
if continuation_id:
|
| 116 |
if continuation_id not in active_generations:
|
| 117 |
raise HTTPException(status_code=404, detail="Continuation ID not found.")
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
| 120 |
|
| 121 |
generation_config = GenerationConfig(
|
| 122 |
temperature=temperature,
|
|
@@ -132,13 +134,10 @@ async def generate(request: GenerateRequest, model_resources: tuple = Depends(ge
|
|
| 132 |
|
| 133 |
generated_text = generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences)
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
active_generations[continuation_id] = {"model_name": model_name, "output": generated_text}
|
| 138 |
-
else:
|
| 139 |
-
active_generations[continuation_id]["output"] = generated_text
|
| 140 |
|
| 141 |
-
return JSONResponse({"text": generated_text, "continuation_id":
|
| 142 |
|
| 143 |
except HTTPException as http_err:
|
| 144 |
raise http_err
|
|
@@ -186,9 +185,9 @@ async def generate_image(request: GenerateRequest):
|
|
| 186 |
|
| 187 |
image_generator = await load_pipeline_from_s3("text-to-image", request.model_name)
|
| 188 |
image = image_generator(request.input_text)[0]
|
| 189 |
-
|
| 190 |
-
active_generations[
|
| 191 |
-
return JSONResponse({"url": "Image generated successfully", "continuation_id":
|
| 192 |
|
| 193 |
except HTTPException as http_err:
|
| 194 |
raise http_err
|
|
@@ -203,9 +202,9 @@ async def generate_text_to_speech(request: GenerateRequest):
|
|
| 203 |
|
| 204 |
tts_pipeline = await load_pipeline_from_s3("text-to-speech", request.model_name)
|
| 205 |
output = tts_pipeline(request.input_text)
|
| 206 |
-
|
| 207 |
-
active_generations[
|
| 208 |
-
return JSONResponse({"url": "Audio generated successfully", "continuation_id":
|
| 209 |
|
| 210 |
except HTTPException as http_err:
|
| 211 |
raise http_err
|
|
@@ -220,9 +219,9 @@ async def generate_video(request: GenerateRequest):
|
|
| 220 |
|
| 221 |
video_pipeline = await load_pipeline_from_s3("text-to-video", request.model_name)
|
| 222 |
output = video_pipeline(request.input_text)
|
| 223 |
-
|
| 224 |
-
active_generations[
|
| 225 |
-
return JSONResponse({"url": "Video generated successfully", "continuation_id":
|
| 226 |
|
| 227 |
except HTTPException as http_err:
|
| 228 |
raise http_err
|
|
|
|
| 7 |
import uvicorn
|
| 8 |
import soundfile as sf
|
| 9 |
import imageio
|
| 10 |
+
from typing import Dict, Optional
|
| 11 |
|
| 12 |
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
|
| 13 |
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
|
|
|
|
| 42 |
do_sample: bool = True
|
| 43 |
stop_sequences: list[str] = []
|
| 44 |
no_repeat_ngram_size: int = 2
|
| 45 |
+
continuation_id: Optional[str] = None
|
| 46 |
|
| 47 |
@field_validator("model_name")
|
| 48 |
def model_name_cannot_be_empty(cls, v):
|
|
|
|
| 115 |
if continuation_id:
|
| 116 |
if continuation_id not in active_generations:
|
| 117 |
raise HTTPException(status_code=404, detail="Continuation ID not found.")
|
| 118 |
+
previous_data = active_generations[continuation_id]
|
| 119 |
+
if previous_data["model_name"] != model_name:
|
| 120 |
+
raise HTTPException(status_code=400, detail="Model mismatch for continuation.")
|
| 121 |
+
input_text = previous_data["output"]
|
| 122 |
|
| 123 |
generation_config = GenerationConfig(
|
| 124 |
temperature=temperature,
|
|
|
|
| 134 |
|
| 135 |
generated_text = generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences)
|
| 136 |
|
| 137 |
+
new_continuation_id = continuation_id if continuation_id else os.urandom(16).hex()
|
| 138 |
+
active_generations[new_continuation_id] = {"model_name": model_name, "output": generated_text}
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
+
return JSONResponse({"text": generated_text, "continuation_id": new_continuation_id})
|
| 141 |
|
| 142 |
except HTTPException as http_err:
|
| 143 |
raise http_err
|
|
|
|
| 185 |
|
| 186 |
image_generator = await load_pipeline_from_s3("text-to-image", request.model_name)
|
| 187 |
image = image_generator(request.input_text)[0]
|
| 188 |
+
new_continuation_id = os.urandom(16).hex()
|
| 189 |
+
active_generations[new_continuation_id] = {"model_name": request.model_name, "output": "Image generated successfully"}
|
| 190 |
+
return JSONResponse({"url": "Image generated successfully", "continuation_id": new_continuation_id})
|
| 191 |
|
| 192 |
except HTTPException as http_err:
|
| 193 |
raise http_err
|
|
|
|
| 202 |
|
| 203 |
tts_pipeline = await load_pipeline_from_s3("text-to-speech", request.model_name)
|
| 204 |
output = tts_pipeline(request.input_text)
|
| 205 |
+
new_continuation_id = os.urandom(16).hex()
|
| 206 |
+
active_generations[new_continuation_id] = {"model_name": request.model_name, "output": "Audio generated successfully"}
|
| 207 |
+
return JSONResponse({"url": "Audio generated successfully", "continuation_id": new_continuation_id})
|
| 208 |
|
| 209 |
except HTTPException as http_err:
|
| 210 |
raise http_err
|
|
|
|
| 219 |
|
| 220 |
video_pipeline = await load_pipeline_from_s3("text-to-video", request.model_name)
|
| 221 |
output = video_pipeline(request.input_text)
|
| 222 |
+
new_continuation_id = os.urandom(16).hex()
|
| 223 |
+
active_generations[new_continuation_id] = {"model_name": request.model_name, "output": "Video generated successfully"}
|
| 224 |
+
return JSONResponse({"url": "Video generated successfully", "continuation_id": new_continuation_id})
|
| 225 |
|
| 226 |
except HTTPException as http_err:
|
| 227 |
raise http_err
|