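"""Gradio demo for the 422_MTDDP project: audio, computer vision, and multimodal models."""
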
import tempfile
from typing import List, Tuple, Any
import gradio as gr
import soundfile as sf
import torch
import torch.nn.functional as torch_functional
from gtts import gTTS
from PIL import Image, ImageDraw
from transformers import (
AutoTokenizer,
CLIPModel,
CLIPProcessor,
SamModel,
SamProcessor,
VitsModel,
pipeline,
BlipForQuestionAnswering,
BlipProcessor,
)
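
# Global cache of lazily loaded models/pipelines, keyed by a short model identifier,
# so each checkpoint is downloaded and instantiated at most once per process.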
MODEL_STORE = {}
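
# Gradio can hand gallery/file values over in several shapes (PIL images, file paths,
# (image, caption) tuples, or dicts); this helper normalizes them all to RGB PIL images.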
def _normalize_gallery_images(gallery_value: Any) -> List[Image.Image]:
if not gallery_value:
return []
normalized_images: List[Image.Image] = []
for item in gallery_value:
if isinstance(item, Image.Image):
normalized_images.append(item)
continue
        if isinstance(item, str):
            try:
                normalized_images.append(Image.open(item).convert("RGB"))
            except Exception:
                # Skip paths that cannot be opened as images.
                pass
            continue
if isinstance(item, (list, tuple)) and item:
candidate = item[0]
if isinstance(candidate, Image.Image):
normalized_images.append(candidate)
continue
if isinstance(item, dict):
candidate = item.get("image") or item.get("value")
if isinstance(candidate, Image.Image):
normalized_images.append(candidate)
continue
return normalized_images
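
# Lazily build and cache the audio pipelines. Note that the "wav2vec2" key is wired to
# openai/whisper-small (as stated in the UI), not to an actual wav2vec2 checkpoint.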
def get_audio_pipeline(model_key: str):
if model_key in MODEL_STORE:
return MODEL_STORE[model_key]
if model_key == "whisper":
audio_pipeline = pipeline(
task="automatic-speech-recognition",
model="distil-whisper/distil-small.en",
)
elif model_key == "wav2vec2":
audio_pipeline = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-small",
)
elif model_key == "audio_classifier":
audio_pipeline = pipeline(
task="audio-classification",
model="MIT/ast-finetuned-audioset-10-10-0.4593",
)
elif model_key == "emotion_classifier":
audio_pipeline = pipeline(
task="audio-classification",
model="superb/hubert-large-superb-er",
)
else:
raise ValueError(f"Неизвестный тип аудио модели: {model_key}")
MODEL_STORE[model_key] = audio_pipeline
return audio_pipeline
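
# CLAP pipeline for zero-shot audio classification against free-form text labels.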
def get_zero_shot_audio_pipeline():
if "audio_zero_shot_clap" not in MODEL_STORE:
zero_shot_pipeline = pipeline(
task="zero-shot-audio-classification",
model="laion/clap-htsat-unfused",
)
MODEL_STORE["audio_zero_shot_clap"] = zero_shot_pipeline
return MODEL_STORE["audio_zero_shot_clap"]
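
# BLIP VQA model and processor, cached as a pair for the "vqa_blip_base" branch.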
def get_blip_vqa_components() -> Tuple[BlipForQuestionAnswering, BlipProcessor]:
if "blip_vqa_model" not in MODEL_STORE or "blip_vqa_processor" not in MODEL_STORE:
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
blip_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
MODEL_STORE["blip_vqa_model"] = blip_model
MODEL_STORE["blip_vqa_processor"] = blip_processor
blip_model = MODEL_STORE["blip_vqa_model"]
blip_processor = MODEL_STORE["blip_vqa_processor"]
return blip_model, blip_processor
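
# Lazily build and cache the computer-vision pipelines (detection, segmentation,
# depth estimation, captioning, VQA), keyed by the dropdown value from the UI.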
def get_vision_pipeline(model_key: str):
if model_key in MODEL_STORE:
return MODEL_STORE[model_key]
if model_key == "object_detection_conditional_detr":
vision_pipeline = pipeline(
task="object-detection",
model="microsoft/conditional-detr-resnet-50",
)
elif model_key == "object_detection_yolos_small":
vision_pipeline = pipeline(
task="object-detection",
model="hustvl/yolos-small",
)
elif model_key == "segmentation":
vision_pipeline = pipeline(
task="image-segmentation",
model="nvidia/segformer-b0-finetuned-ade-512-512",
)
elif model_key == "depth_estimation":
vision_pipeline = pipeline(
task="depth-estimation",
model="Intel/dpt-hybrid-midas",
)
elif model_key == "captioning_blip_base":
vision_pipeline = pipeline(
task="image-to-text",
model="Salesforce/blip-image-captioning-base",
)
elif model_key == "captioning_blip_large":
vision_pipeline = pipeline(
task="image-to-text",
model="Salesforce/blip-image-captioning-large",
)
elif model_key == "vqa_blip_base":
vision_pipeline = pipeline(
task="visual-question-answering",
model="Salesforce/blip-vqa-base",
)
elif model_key == "vqa_vilt_b32":
vision_pipeline = pipeline(
task="visual-question-answering",
model="dandelin/vilt-b32-finetuned-vqa",
)
else:
raise ValueError(f"Неизвестный тип визуальной модели: {model_key}")
MODEL_STORE[model_key] = vision_pipeline
return vision_pipeline
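
# CLIP model/processor pairs used for zero-shot image classification and text-to-image retrieval.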
def get_clip_components(clip_key: str) -> Tuple[CLIPModel, CLIPProcessor]:
model_store_key_model = f"clip_model_{clip_key}"
model_store_key_processor = f"clip_processor_{clip_key}"
if model_store_key_model not in MODEL_STORE or model_store_key_processor not in MODEL_STORE:
if clip_key == "clip_large_patch14":
clip_name = "openai/clip-vit-large-patch14"
elif clip_key == "clip_base_patch32":
clip_name = "openai/clip-vit-base-patch32"
else:
raise ValueError(f"Неизвестный вариант CLIP модели: {clip_key}")
clip_model = CLIPModel.from_pretrained(clip_name)
clip_processor = CLIPProcessor.from_pretrained(clip_name)
MODEL_STORE[model_store_key_model] = clip_model
MODEL_STORE[model_store_key_processor] = clip_processor
clip_model = MODEL_STORE[model_store_key_model]
clip_processor = MODEL_STORE[model_store_key_processor]
return clip_model, clip_processor
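
# Silero Russian TTS loaded via torch.hub; cached but not currently wired into the UI.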
def get_silero_tts_model():
if "silero_tts_model" not in MODEL_STORE:
silero_model, _ = torch.hub.load(
repo_or_dir="snakers4/silero-models",
model="silero_tts",
language="ru",
speaker="ru_v3",
)
MODEL_STORE["silero_tts_model"] = silero_model
return MODEL_STORE["silero_tts_model"]
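
# MMS-TTS (facebook/mms-tts-rus) VITS model and tokenizer, cached for the "mms" TTS option.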
def get_mms_tts_components():
    if "mms_tts_model" not in MODEL_STORE or "mms_tts_tokenizer" not in MODEL_STORE:
        MODEL_STORE["mms_tts_model"] = VitsModel.from_pretrained("facebook/mms-tts-rus")
        MODEL_STORE["mms_tts_tokenizer"] = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
    return MODEL_STORE["mms_tts_model"], MODEL_STORE["mms_tts_tokenizer"]
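
# SlimSAM model and processor used for point-prompted segmentation.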
def get_sam_components() -> Tuple[SamModel, SamProcessor]:
if "sam_model" not in MODEL_STORE or "sam_processor" not in MODEL_STORE:
sam_model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
sam_processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
MODEL_STORE["sam_model"] = sam_model
MODEL_STORE["sam_processor"] = sam_processor
sam_model = MODEL_STORE["sam_model"]
sam_processor = MODEL_STORE["sam_processor"]
return sam_model, sam_processor
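
# Audio classification: return the top-5 labels with scores as a formatted string.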
def classify_audio_file(audio_path: str, model_key: str) -> str:
audio_classifier = get_audio_pipeline(model_key)
prediction_list = audio_classifier(audio_path)
result_lines = ["Топ-5 предсказаний:"]
for prediction_index, prediction_item in enumerate(prediction_list[:5], start=1):
label_value = prediction_item["label"]
score_value = prediction_item["score"]
result_lines.append(
f"{prediction_index}. {label_value}: {score_value:.4f}"
)
return "\n".join(result_lines)
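
# Zero-shot audio classification: labels are given as a comma-separated string in the UI.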
def classify_audio_zero_shot_clap(audio_path: str, label_texts: str) -> str:
clap_pipeline = get_zero_shot_audio_pipeline()
label_list = [
label_item.strip()
for label_item in label_texts.split(",")
if label_item.strip()
]
if not label_list:
return "Не задано ни одной текстовой метки для zero-shot классификации."
prediction_list = clap_pipeline(
audio_path,
candidate_labels=label_list,
)
result_lines = ["Zero-Shot Audio Classification (CLAP):"]
for prediction_index, prediction_item in enumerate(prediction_list, start=1):
label_value = prediction_item["label"]
score_value = prediction_item["score"]
result_lines.append(
f"{prediction_index}. {label_value}: {score_value:.4f}"
)
return "\n".join(result_lines)
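
# Speech recognition: return the transcription produced by the selected ASR pipeline.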
def recognize_speech(audio_path: str, model_key: str) -> str:
speech_pipeline = get_audio_pipeline(model_key)
prediction_result = speech_pipeline(audio_path)
return prediction_result["text"]
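
# Text-to-speech: gTTS writes an MP3 file, the MMS branch synthesizes a waveform with VITS
# and saves it as WAV; both return a file path for the Gradio Audio output.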
def synthesize_speech(text_value: str, model_key: str):
if model_key == "Google TTS":
        # gTTS writes MP3 data, so the temp file should use an .mp3 extension.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as file_object:
text_to_speech_engine = gTTS(text=text_value, lang="ru")
text_to_speech_engine.save(file_object.name)
return file_object.name
    elif model_key == "mms":
        # Reuse the cached VITS model and tokenizer instead of reloading them on every request.
        model, tokenizer = get_mms_tts_components()
        inputs = tokenizer(text_value, return_tensors="pt")
        with torch.no_grad():
            output = model(**inputs).waveform
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            sf.write(f.name, output.numpy().squeeze(), model.config.sampling_rate)
            return f.name
raise ValueError(f"Неизвестная модель: {model_key}")
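
# Object detection: draw each predicted box and "label: score" annotation onto the input image.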
def detect_objects_on_image(image_object, model_key: str):
detector_pipeline = get_vision_pipeline(model_key)
detection_results = detector_pipeline(image_object)
drawer_object = ImageDraw.Draw(image_object)
for detection_item in detection_results:
box_data = detection_item["box"]
label_value = detection_item["label"]
score_value = detection_item["score"]
drawer_object.rectangle(
[
box_data["xmin"],
box_data["ymin"],
box_data["xmax"],
box_data["ymax"],
],
outline="red",
width=3,
)
drawer_object.text(
(box_data["xmin"], box_data["ymin"]),
f"{label_value}: {score_value:.2f}",
fill="red",
)
return image_object
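
# Semantic segmentation: return the mask of the first segment found by the SegFormer pipeline.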
def segment_image(image_object):
segmentation_pipeline = get_vision_pipeline("segmentation")
segmentation_results = segmentation_pipeline(image_object)
return segmentation_results[0]["mask"]
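
# Depth estimation: resize the raw DPT depth map to the input resolution with bicubic
# interpolation and rescale it to an 8-bit grayscale image.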
def estimate_image_depth(image_object):
depth_pipeline = get_vision_pipeline("depth_estimation")
depth_output = depth_pipeline(image_object)
predicted_depth_tensor = depth_output["predicted_depth"]
if predicted_depth_tensor.ndim == 3:
predicted_depth_tensor = predicted_depth_tensor.unsqueeze(1)
elif predicted_depth_tensor.ndim == 2:
predicted_depth_tensor = predicted_depth_tensor.unsqueeze(0).unsqueeze(0)
else:
raise ValueError(
f"Неожиданная размерность predicted_depth: {predicted_depth_tensor.shape}"
)
resized_depth_tensor = torch_functional.interpolate(
predicted_depth_tensor,
size=image_object.size[::-1],
mode="bicubic",
align_corners=False,
)
depth_array = resized_depth_tensor.squeeze().cpu().numpy()
max_value = float(depth_array.max())
if max_value <= 0.0:
return Image.new("L", image_object.size, color=0)
normalized_depth_array = (depth_array * 255.0 / max_value).astype("uint8")
depth_image = Image.fromarray(normalized_depth_array, mode="L")
return depth_image
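
# Image captioning with the selected BLIP checkpoint.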
def generate_image_caption(image_object, model_key: str) -> str:
caption_pipeline = get_vision_pipeline(model_key)
caption_result = caption_pipeline(image_object)
return caption_result[0]["generated_text"]
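
# Visual question answering: the BLIP branch generates free-form text, while the
# pipeline branch (ViLT) returns a ranked answer with a confidence score.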
def answer_visual_question(image_object, question_text: str, model_key: str) -> str:
if image_object is None:
return "Пожалуйста, сначала загрузите изображение."
if not question_text.strip():
return "Пожалуйста, введите вопрос об изображении."
if model_key == "vqa_blip_base":
blip_model, blip_processor = get_blip_vqa_components()
inputs = blip_processor(
images=image_object,
text=question_text,
return_tensors="pt",
)
with torch.no_grad():
output_ids = blip_model.generate(**inputs)
decoded_answers = blip_processor.batch_decode(
output_ids,
skip_special_tokens=True,
)
answer_text = decoded_answers[0] if decoded_answers else ""
return answer_text or "Модель не смогла сгенерировать ответ."
vqa_pipeline = get_vision_pipeline(model_key)
vqa_result = vqa_pipeline(
image=image_object,
question=question_text,
)
top_item = vqa_result[0]
answer_text = top_item["answer"]
confidence_value = top_item["score"]
return f"{answer_text} (confidence: {confidence_value:.3f})"
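
# Zero-shot image classification: compare the image against comma-separated class names
# with CLIP and report the softmax probabilities.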
def perform_zero_shot_classification(
image_object,
class_texts: str,
clip_key: str,
) -> str:
clip_model, clip_processor = get_clip_components(clip_key)
class_list = [
class_name.strip()
for class_name in class_texts.split(",")
if class_name.strip()
]
if not class_list:
return "Не задано ни одного класса для классификации."
input_batch = clip_processor(
text=class_list,
images=image_object,
return_tensors="pt",
padding=True,
)
with torch.no_grad():
clip_outputs = clip_model(**input_batch)
logits_per_image = clip_outputs.logits_per_image
probability_tensor = logits_per_image.softmax(dim=1)
result_lines = ["Zero-Shot Classification Results:"]
for class_index, class_name in enumerate(class_list):
probability_value = probability_tensor[0][class_index].item()
result_lines.append(f"{class_name}: {probability_value:.4f}")
return "\n".join(result_lines)
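
# Text-to-image retrieval: embed the gallery images and the query with CLIP, L2-normalize
# both, and return the image with the highest cosine similarity.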
def retrieve_best_image(
gallery_value: Any,
query_text: str,
clip_key: str,
) -> Tuple[str, Image.Image | None]:
image_list = _normalize_gallery_images(gallery_value)
if not image_list or not query_text.strip():
return "Пожалуйста, загрузите изображения и введите запрос", None
clip_model, clip_processor = get_clip_components(clip_key)
image_inputs = clip_processor(
images=image_list,
return_tensors="pt",
padding=True,
)
with torch.no_grad():
image_features = clip_model.get_image_features(**image_inputs)
image_features = image_features / image_features.norm(
dim=-1,
keepdim=True,
)
text_inputs = clip_processor(
text=[query_text],
return_tensors="pt",
padding=True,
)
with torch.no_grad():
text_features = clip_model.get_text_features(**text_inputs)
text_features = text_features / text_features.norm(
dim=-1,
keepdim=True,
)
similarity_tensor = image_features @ text_features.T
best_index_tensor = similarity_tensor.argmax()
best_index_value = best_index_tensor.item()
best_score_value = similarity_tensor[best_index_value].item()
description_text = (
f"Лучшее изображение: #{best_index_value + 1} "
f"(схожесть: {best_score_value:.4f})"
)
return description_text, image_list[best_index_value]
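
# Point-prompted segmentation with SlimSAM: all points are treated as foreground (label 1).
# This helper is not currently exposed as a tab in the Gradio UI.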
def segment_image_with_sam_points(
image_object,
point_coordinates_list: List[List[int]],
) -> Image.Image:
if image_object is None:
raise ValueError("Изображение не передано в segment_image_with_sam_points")
if not point_coordinates_list:
return Image.new("L", image_object.size, color=0)
sam_model, sam_processor = get_sam_components()
batched_points: List[List[List[int]]] = [point_coordinates_list]
batched_labels: List[List[int]] = [[1 for _ in point_coordinates_list]]
sam_inputs = sam_processor(
        images=image_object,
input_points=batched_points,
input_labels=batched_labels,
return_tensors="pt",
)
with torch.no_grad():
sam_outputs = sam_model(**sam_inputs, multimask_output=True)
    # pred_masks has shape (batch, point_batch, num_masks, H, W); post_process_masks
    # upsamples it back to the original image resolution and binarizes it.
    processed_masks_list = sam_processor.image_processor.post_process_masks(
        sam_outputs.pred_masks.cpu(),
        sam_inputs["original_sizes"].cpu(),
        sam_inputs["reshaped_input_sizes"].cpu(),
    )
    batch_masks_tensor = processed_masks_list[0][0]
if batch_masks_tensor.ndim != 3 or batch_masks_tensor.shape[0] == 0:
return Image.new("L", image_object.size, color=0)
first_mask_tensor = batch_masks_tensor[0]
mask_array = first_mask_tensor.numpy()
binary_mask_array = (mask_array > 0.5).astype("uint8") * 255
mask_image = Image.fromarray(binary_mask_array, mode="L")
return mask_image
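
# UI wrapper: parse "x,y" pairs (separated by ";" or newlines) and run SAM point segmentation.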
def segment_image_with_sam_points_ui(image_object, coordinates_text: str) -> Image.Image:
if image_object is None:
return None
coordinates_text_clean = coordinates_text.strip()
if not coordinates_text_clean:
return Image.new("L", image_object.size, color=0)
    # Delegate parsing to parse_point_coordinates_text; newlines act as pair separators too.
    point_coordinates_list = parse_point_coordinates_text(
        coordinates_text_clean.replace("\n", ";")
    )
if not point_coordinates_list:
return Image.new("L", image_object.size, color=0)
return segment_image_with_sam_points(image_object, point_coordinates_list)
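
# Parse "x1,y1; x2,y2; ..." into a list of [x, y] points, silently skipping malformed pairs.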
def parse_point_coordinates_text(coordinates_text: str) -> List[List[int]]:
if not coordinates_text.strip():
return []
point_list: List[List[int]] = []
for raw_pair in coordinates_text.split(";"):
cleaned_pair = raw_pair.strip()
if not cleaned_pair:
continue
coordinate_parts = cleaned_pair.split(",")
if len(coordinate_parts) != 2:
continue
try:
x_value = int(coordinate_parts[0].strip())
y_value = int(coordinate_parts[1].strip())
except ValueError:
continue
point_list.append([x_value, y_value])
return point_list
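
# Build the Gradio Blocks interface: one tab per task, each wiring its inputs to the
# corresponding handler above.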
def build_interface():
with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo_block:
gr.Markdown("# AI модели")
with gr.Tab("Классификация аудио"):
gr.Markdown("## Классификация аудио")
with gr.Row():
audio_input_component = gr.Audio(
label="Загрузите аудиофайл",
type="filepath",
)
audio_model_selector = gr.Dropdown(
choices=["audio_classifier", "emotion_classifier"],
label="Выберите модель",
value="audio_classifier",
                    info=(
                        "audio_classifier - общая классификация (курс),\n"
                        "emotion_classifier - эмоции в речи"
                    ),
)
audio_classify_button = gr.Button("Применить")
audio_output_component = gr.Textbox(
label="Результаты классификации",
lines=10,
)
audio_classify_button.click(
fn=classify_audio_file,
inputs=[audio_input_component, audio_model_selector],
outputs=audio_output_component,
)
with gr.Tab("Zero-Shot аудио"):
gr.Markdown("## Zero-Shot аудио классификатор")
with gr.Row():
clap_audio_input_component = gr.Audio(
label="Загрузите аудиофайл",
type="filepath",
)
clap_label_texts_component = gr.Textbox(
label="Кандидатные метки (через запятую)",
placeholder="лай собаки, шум дождя, музыка, разговор",
lines=2,
)
clap_button = gr.Button("Применить")
clap_output_component = gr.Textbox(
label="Результаты zero-shot классификации",
lines=10,
)
clap_button.click(
fn=classify_audio_zero_shot_clap,
inputs=[clap_audio_input_component, clap_label_texts_component],
outputs=clap_output_component,
)
with gr.Tab("Распознавание речи"):
            gr.Markdown("## Распознавание речи")
with gr.Row():
asr_audio_input_component = gr.Audio(
label="Загрузите аудио с речью",
type="filepath",
)
asr_model_selector = gr.Dropdown(
choices=["whisper", "wav2vec2"],
label="Выберите модель",
value="whisper",
info=(
"whisper - distil-whisper/distil-small.en (курс),\n"
"wav2vec2 - openai/whisper-small"
),
)
asr_button = gr.Button("Применить")
asr_output_component = gr.Textbox(
label="Транскрипция",
lines=5,
)
asr_button.click(
fn=recognize_speech,
inputs=[asr_audio_input_component, asr_model_selector],
outputs=asr_output_component,
)
with gr.Tab("Синтез речи"):
gr.Markdown("## Text-to-Speech")
with gr.Row():
tts_text_component = gr.Textbox(
label="Введите текст для синтеза",
placeholder="Введите текст на русском или английском языке...",
lines=3,
)
tts_model_selector = gr.Dropdown(
choices=["mms", "Google TTS"],
label="Выберите модель",
value="mms",
                    info=(
                        "mms - facebook/mms-tts-rus,\n"
                        "Google TTS - gTTS"
                    ),
)
tts_button = gr.Button("Применить")
tts_audio_output_component = gr.Audio(
label="Синтезированная речь",
type="filepath",
)
tts_button.click(
fn=synthesize_speech,
inputs=[tts_text_component, tts_model_selector],
outputs=tts_audio_output_component,
)
with gr.Tab("Детекция объектов"):
gr.Markdown("## Детекция объектов")
with gr.Row():
object_input_image = gr.Image(
label="Загрузите изображение",
type="pil",
)
object_model_selector = gr.Dropdown(
choices=[
"object_detection_conditional_detr",
"object_detection_yolos_small",
],
label="Модель",
value="object_detection_conditional_detr",
info=(
"object_detection_conditional_detr - microsoft/conditional-detr-resnet-50\n"
"object_detection_yolos_small - hustvl/yolos-small"
),
)
object_detect_button = gr.Button("Применить")
object_output_image = gr.Image(
label="Результат",
)
object_detect_button.click(
fn=detect_objects_on_image,
inputs=[object_input_image, object_model_selector],
outputs=object_output_image,
)
with gr.Tab("Сегментация"):
gr.Markdown("## Сегментация")
with gr.Row():
segmentation_input_image = gr.Image(
label="Загрузите изображение",
type="pil",
)
segmentation_button = gr.Button("Применить")
segmentation_output_image = gr.Image(
label="Маска",
)
segmentation_button.click(
fn=segment_image,
inputs=segmentation_input_image,
outputs=segmentation_output_image,
)
with gr.Tab("Глубина"):
gr.Markdown("## Глубина (Depth Estimation)")
with gr.Row():
depth_input_image = gr.Image(
label="Загрузите изображение",
type="pil",
)
depth_button = gr.Button("Применить")
depth_output_image = gr.Image(
                label="Карта глубины",
)
depth_button.click(
fn=estimate_image_depth,
inputs=depth_input_image,
outputs=depth_output_image,
)
with gr.Tab("Описание изображений"):
gr.Markdown("## Описание изображений")
with gr.Row():
caption_input_image = gr.Image(
label="Загрузите изображение",
type="pil",
)
caption_model_selector = gr.Dropdown(
choices=[
"captioning_blip_base",
"captioning_blip_large",
],
label="Модель",
value="captioning_blip_base",
info=(
"captioning_blip_base - Salesforce/blip-image-captioning-base (курс)\n"
"captioning_blip_large - Salesforce/blip-image-captioning-large"
),
)
caption_button = gr.Button("Применить")
caption_output_text = gr.Textbox(
label="Описание изображения",
lines=3,
)
caption_button.click(
fn=generate_image_caption,
inputs=[caption_input_image, caption_model_selector],
outputs=caption_output_text,
)
with gr.Tab("Визуальные вопросы"):
gr.Markdown("## Visual Question Answering")
with gr.Row():
vqa_input_image = gr.Image(
label="Загрузите изображение",
type="pil",
)
vqa_question_text = gr.Textbox(
label="Вопрос",
placeholder="Вопрос",
lines=2,
)
vqa_model_selector = gr.Dropdown(
choices=[
"vqa_blip_base",
"vqa_vilt_b32",
],
label="Модель",
value="vqa_blip_base",
info=(
"vqa_blip_base - Salesforce/blip-vqa-base (курс)\n"
"vqa_vilt_b32 - dandelin/vilt-b32-finetuned-vqa"
),
)
vqa_button = gr.Button("Ответить на вопрос")
vqa_output_text = gr.Textbox(
label="Ответ",
lines=3,
)
vqa_button.click(
fn=answer_visual_question,
inputs=[vqa_input_image, vqa_question_text, vqa_model_selector],
outputs=vqa_output_text,
)
with gr.Tab("Zero-Shot классификация"):
gr.Markdown("## Zero-Shot классификация")
with gr.Row():
zero_shot_input_image = gr.Image(
label="Загрузите изображение",
type="pil",
)
zero_shot_classes_text = gr.Textbox(
label="Классы для классификации (через запятую)",
placeholder="человек, машина, дерево, здание, животное",
lines=2,
)
clip_model_selector = gr.Dropdown(
choices=[
"clip_large_patch14",
"clip_base_patch32",
],
                    label="Модель",
value="clip_large_patch14",
info=(
"clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
"clip_base_patch32 - openai/clip-vit-base-patch32"
),
)
zero_shot_button = gr.Button("Применить")
zero_shot_output_text = gr.Textbox(
label="Результаты",
lines=10,
)
zero_shot_button.click(
fn=perform_zero_shot_classification,
inputs=[zero_shot_input_image, zero_shot_classes_text, clip_model_selector],
outputs=zero_shot_output_text,
)
with gr.Tab("Поиск изображений"):
gr.Markdown("## Поиск изображений")
with gr.Row():
retrieval_dir = gr.File(
label="Загрузите папку с изображениями",
file_count="directory",
file_types=["image"],
type="filepath",
)
retrieval_query_text = gr.Textbox(
label="Текстовый запрос",
placeholder="описание того, что вы ищете...",
lines=2,
)
retrieval_clip_selector = gr.Dropdown(
choices=[
"clip_large_patch14",
"clip_base_patch32",
],
                    label="Модель",
value="clip_large_patch14",
info=(
"clip_large_patch14 - openai/clip-vit-large-patch14 (курс)\n"
"clip_base_patch32 - openai/clip-vit-base-patch32 (альтернатива)"
),
)
retrieval_button = gr.Button("Поиск")
retrieval_output_text = gr.Textbox(
label="Результат",
)
retrieval_output_image = gr.Image(
label="Наиболее подходящее изображение",
)
retrieval_button.click(
fn=retrieve_best_image,
inputs=[retrieval_dir, retrieval_query_text, retrieval_clip_selector],
outputs=[retrieval_output_text, retrieval_output_image],
)
gr.Markdown("---")
gr.Markdown("### Задачи:")
gr.Markdown(
"""
- Аудио: классификация, распознавание речи, синтез речи
- Компьютерное зрение: детекция объектов, сегментация, оценка глубины, генерация описаний изображений
- Мультимодальные задачи: вопросы к изображению, zero-shot классификация изображений, поиск по изображениям по текстовому запросу
"""
)
return demo_block
if __name__ == "__main__":
interface_block = build_interface()
interface_block.launch(share=True)