File size: 2,218 Bytes
df15b5f
 
 
 
 
 
 
fbc78b9
 
 
df15b5f
fbc78b9
 
df15b5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67b4bc4
df15b5f
 
 
 
 
fbc78b9
 
df15b5f
bfcd5d3
df15b5f
fbc78b9
df15b5f
 
 
bfcd5d3
df15b5f
fbc78b9
df15b5f
 
 
bfcd5d3
d2910e9
df15b5f
 
 
 
bfcd5d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import gradio as gr
import torchaudio
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoModel

# Run on GPU when available; both models and all tensors are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Whisper model for language detection
# NOTE: whisper-medium is used only to emit a single language token, not for
# transcription; weights are downloaded on first run.
lang_id_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium").to(device)
lang_id_processor = WhisperProcessor.from_pretrained("openai/whisper-medium")

# Indic Conformer model for transcription
# trust_remote_code=True: the model class is defined in the HF repo, not in transformers.
# NOTE(review): this model is not moved to `device` — presumably it runs on CPU; confirm.
model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)

def detect_language(audio_path):
    """Detect the spoken language of an audio file with Whisper.

    Args:
        audio_path: Path to an audio file readable by torchaudio.

    Returns:
        Tuple of (lang_code, waveform) where lang_code is Whisper's language
        tag with the ``<|`` / ``|>`` markers stripped (e.g. ``"hi"``) and
        waveform is the mono 16 kHz tensor on `device`, ready for the
        transcription model.
    """
    waveform, sr = torchaudio.load(audio_path)
    # Whisper's feature extractor expects 16 kHz mono audio.
    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
    if waveform.shape[0] > 1:
        # Downmix stereo/multi-channel to mono by averaging channels.
        waveform = waveform.mean(dim=0, keepdim=True)
    inputs = lang_id_processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt").to(device)

    # Prompt the decoder with <|startoftranscript|> so the first generated
    # token is the language token.
    start_token_id = lang_id_processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
    decoder_input_ids = torch.tensor([[start_token_id]], device=device)

    with torch.no_grad():
        outputs = lang_id_model.generate(
            inputs["input_features"],
            decoder_input_ids=decoder_input_ids,
            max_new_tokens=1,
        )

    # BUG FIX: generate() echoes the decoder prompt, so outputs[0] is
    # [<|startoftranscript|>, <|lang|>]. Decoding the full sequence and
    # stripping the markers produced e.g. "startoftranscripthi" instead of
    # "hi". Decode only the last (newly generated) token.
    lang_token = lang_id_processor.tokenizer.decode(outputs[0, -1], skip_special_tokens=False)
    lang_code = lang_token.replace("<|", "").replace("|>", "").strip()
    return lang_code, waveform.to(device)

def transcribe(audio_path):
    """Detect the language of an uploaded audio file and transcribe it.

    Runs Whisper-based language detection first, then produces two
    transcriptions (CTC and RNNT decoding) with the Indic Conformer model.

    Args:
        audio_path: Filesystem path to the uploaded audio clip.

    Returns:
        Tuple of (language code, CTC transcription, RNNT transcription).
        On any failure, returns ("Error", error message, "") so the UI
        displays the problem instead of crashing.
    """
    try:
        detected, wav = detect_language(audio_path)
        ctc_text = model(wav, detected, "ctc")
        rnnt_text = model(wav, detected, "rnnt")
    except Exception as err:
        return "Error", str(err), ""
    return detected, ctc_text, rnnt_text

# Gradio UI: one audio input (file path is passed to `transcribe`), three
# text outputs matching transcribe()'s 3-tuple return value.
demo = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath", label="Upload a WAV or MP3 file"),
    outputs=[
        gr.Textbox(label="Detected Language"),
        gr.Textbox(label="CTC Transcription"),
        gr.Textbox(label="RNNT Transcription")
    ],
    title="Language-Aware Transcription",
    description="Step 1: Detect language using Whisper-Medium. Step 2: Transcribe using AI4Bharat Indic Conformer."
)

if __name__ == "__main__":
    demo.launch()