Noumida committed on
Commit de7eff6 · verified · 1 Parent(s): 6a11e69

Update app.py

Files changed (1)
  1. app.py +127 -35
app.py CHANGED
@@ -1,65 +1,157 @@
  from __future__ import annotations
  import torch
  import torchaudio
  import gradio as gr
  import spaces
  from transformers import AutoModel

- DESCRIPTION = "IndicConformer-600M Multilingual ASR (CTC + RNNT)"

- LANGUAGE_NAME_TO_CODE = {
-     "Assamese": "as", "Bengali": "bn", "Bodo": "br", "Dogri": "doi",
-     "Gujarati": "gu", "Hindi": "hi", "Kannada": "kn", "Kashmiri": "ks",
-     "Konkani": "kok", "Maithili": "mai", "Malayalam": "ml", "Manipuri": "mni",
-     "Marathi": "mr", "Nepali": "ne", "Odia": "or", "Punjabi": "pa",
-     "Sanskrit": "sa", "Santali": "sat", "Sindhi": "sd", "Tamil": "ta",
-     "Telugu": "te", "Urdu": "ur"
  }

  device = "cuda" if torch.cuda.is_available() else "cpu"

- # Load Indic Conformer model (assumes custom forward handles decoding strategy)
  model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True).to(device)
  model.eval()

  @spaces.GPU
- def transcribe_ctc_and_rnnt(audio_path, language_name):
-     lang_code = LANGUAGE_NAME_TO_CODE[language_name]

      # Load and preprocess audio
-     waveform, sr = torchaudio.load(audio_path)
-     waveform = waveform.mean(dim=0, keepdim=True) if waveform.shape[0] > 1 else waveform
-     waveform = torchaudio.functional.resample(waveform, sr, 16000).to(device)

      try:
-         # Assume model's forward method takes waveform, language code, and decoding type
          with torch.no_grad():
-             transcription_ctc = model(waveform, lang_code, "ctc")
-             transcription_rnnt = model(waveform, lang_code, "rnnt")
      except Exception as e:
-         return f"Error: {str(e)}", ""

-     return transcription_ctc.strip(), transcription_rnnt.strip()

  # Gradio UI
- with gr.Blocks() as demo:
      gr.Markdown(f"## {DESCRIPTION}")
      with gr.Row():
-         with gr.Column():
              audio = gr.Audio(label="Upload or Record Audio", type="filepath")
-             lang = gr.Dropdown(
-                 label="Select Language",
-                 choices=list(LANGUAGE_NAME_TO_CODE.keys()),
-                 value="Hindi"
-             )
-             transcribe_btn = gr.Button("Transcribe (CTC + RNNT)")
-         with gr.Column():
-             gr.Markdown("### CTC Transcription")
-             ctc_output = gr.Textbox(lines=3)
-             gr.Markdown("### RNNT Transcription")
-             rnnt_output = gr.Textbox(lines=3)
-
-     transcribe_btn.click(fn=transcribe_ctc_and_rnnt, inputs=[audio, lang], outputs=[ctc_output, rnnt_output], api_name="transcribe")

  if __name__ == "__main__":
-     demo.queue().launch()
+ Of course. I'll update the code to perform automatic language identification based on the transcription's characters and common words before providing the final, high-quality transcription.
+
+ This new version will:
+
+ 1. **Remove the language dropdown**, as the language will be detected automatically.
+ 2. Perform a quick, initial transcription using Hindi as a "pivot" language.
+ 3. Analyze the resulting text against a **custom dictionary** of unique characters and common words for all 22 supported languages.
+ 4. Perform the final, more accurate transcription using the detected language code.
+
+ -----
+
+ ### **Updated Code with Automatic Language Identification**
+
+ Here is the complete, updated code. You can replace your existing script with this one.
+
+ ```python
  from __future__ import annotations
  import torch
  import torchaudio
  import gradio as gr
  import spaces
  from transformers import AutoModel
+ import re

+ DESCRIPTION = "IndicConformer-600M Multilingual ASR (CTC + RNNT) with Auto Language ID"

+ # --- Language Identification Data ---
+ # A dictionary containing unique character sets and common words for each language.
+ # This data is used by our custom language identification logic.
+ LANGUAGE_DATA = {
+     "as": {"chars": set("অআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযৰলৱশষসহৎংঃঽািীুূৃেৈোৌ্"), "words": set(["আৰু", "হয়", "এটা", "কৰি", "ওপৰত"])},
+     "bn": {"chars": set("অআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহৎংঃঽািীুূৃেৈোৌ্ড়ঢ়য়"), "words": set(["এবং", "একটি", "করুন", "জন্য", "সঙ্গে"])},
+     "br": {"chars": set("अआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसह़ािीुूृेैोौ्"), "words": set(["आरो", "एसे", "मोनसे", "माव", "आव"])},
+     "doi": {"chars": set("अआइईउऊएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसहािीुूेैोौ्"), "words": set(["ते", "दे", "ऐ", "इक", "ओह्"])},
+     "gu": {"chars": set("અઆઇઈઉઊઋએઐઓઔકખગઘઙચછજઝઞટઠડઢણતથદધનપફબભમયરલવશષસહ઼ાિીુૂૃેૈોૌ્"), "words": set(["અને", "એક", "માટે", "છે", "સાથે"])},
+     "hi": {"chars": set("अआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसहािीुूृेैोौ्"), "words": set(["और", "है", "एक", "में", "के"])},
+     "kn": {"chars": set("ಅಆಇಈಉಊಋಎಏಐಒಓಔಕಖಗಘಙಚಛಜಝಞಟಠಡಢಣತಥದಧನಪಫಬಭಮಯರಲವಶಷಸಹಳಱಾಿೀುೂೃೆೇೈೊೋೌ್"), "words": set(["ಮತ್ತು", "ಒಂದು", "ಹೇಗೆ", "ನಾನು", "ಇದೆ"])},
+     "ks": {"chars": set("اآبپتٹثجچحخدڈذرڑزژسشصضطظعغفقکگلمنوھءییے"), "words": set(["تہٕ", "چھُ", "اکھ", "منز", "کیتھ"])},
+     "kok": {"chars": set("अआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसहािीुूृेैोौ्"), "words": set(["आनी", "एक", "कर", "खातीर", "कडेन"])},
+     "mai": {"chars": set("अआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसहािीुूृेैोौ्"), "words": set(["आ", "एक", "हम", "अछि", "क"])},
+     "ml": {"chars": set("അആഇഈഉഊഋഎഏഐഒഓഔകഖഗഘങചഛജഝഞടഠഡഢണതഥദധനപഫബഭമയരലവശഷസഹളഴറാിീുൂൃെേൈൊോൌ്"), "words": set(["ഒരു", "ആണ്", "എങ്ങനെ", "ഞാൻ", "ഇതു"])},
+     "mni": {"chars": set("ꯑ꯲꯳꯴꯵꯶꯷꯸꯹꯺꯻꯼꯽꯾꯿ꯀꯂꯃꯄꯅꯆꯇꯈꯉꯊꯋꯌꯍꯎꯏꯐꯑ"), "words": set(["ꯗꯥ", "ꯑꯃꯥ", "ꯀꯔꯤ", "ꯑꯩꯅꯥ", "ꯑꯁꯤ"])},
+     "mr": {"chars": set("अआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसहािीुूृेैोौ्ळ"), "words": set(["आणि", "एक", "आहे", "मी", "तू"])},
+     "ne": {"chars": set("अआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसहािीुूृेैोौ्"), "words": set(["र", "एक", "हो", "म", "तिमी"])},
+     "or": {"chars": set("ଅଆଇଈଉଊଋଏଐଓଔକଖଗଘଙଚଛଜଝଞଟଠଡଢଣତଥଦଧନପଫବଭମଯରଲଳବଶଷସହକ୍ଷାିୀୁୂୃେୈୋୌ୍"), "words": set(["ଏବଂ", "ଗୋଟିଏ", "କରନ୍ତୁ", "ପାଇଁ", "ସହିତ"])},
+     "pa": {"chars": set("ਅਆਇਈਉਊਏਐਓਔਕਖਗਘਙਚਛਜਝਞਟਠਡਢਣਤਥਦਧਨਪਫਬਭਮਯਰਲਵਸ਼ਸਹਖ਼ਗ਼ਜ਼ੜਫ਼ਲ਼ਿੀੁੂੇੈੋੌ੍"), "words": set(["ਅਤੇ", "ਇੱਕ", "ਹੈ", "ਵਿੱਚ", "ਨੂੰ"])},
+     "sa": {"chars": set("अआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसहािीुूृेैोौ्"), "words": set(["च", "एकः", "अस्ति", "अहम्", "त्वम्"])},
+     "sat": {"chars": set("ᱚᱛᱜᱝᱞᱟᱠᱡᱢᱣᱤᱥᱦᱧᱨᱩᱪᱫᱬᱭᱮᱯᱰᱱᱲᱳᱴᱵᱶᱷ"), "words": set(["ᱟᱨ", "ᱫᱚ"])},
+     "sd": {"chars": set("اآبڀتٽثپجڄجھچحخڌدڏڊذرزڙژسشصضطظعغفڦقڪکگڳڱلمنوھءي"), "words": set(["۽", "هڪ", "آهي", "۾", "کي"])},
+     "ta": {"chars": set("அஆஇஈஉஊஎஏஐஒஓஔகஙசஞடணதநனபமயரலவழளஷஸஹாிீுூெேைொோௌ்"), "words": set(["மற்றும்", "ஒரு", "வேண்டும்", "நான்", "இது"])},
+     "te": {"chars": set("అఆఇఈఉఊఋఎఏఐఒఓఔకఖగఘఙచఛజఝఞటఠడఢణతథదధనపఫబభమయరలవశషసహళక్షఱాిీుూృెేైొోౌ్"), "words": set(["మరియు", "ఒక", "ఉంది", "నేను", "ఇది"])},
+     "ur": {"chars": set("اآبپتٹثجچحخدڈذرڑزژسشصضطظعغفقکگلمنوھءییے"), "words": set(["اور", "ہے", "ایک", "میں", "کے"])},
  }

+ LANGUAGE_CODE_TO_NAME = {
+     "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri", "gu": "Gujarati", "hi": "Hindi",
+     "kn": "Kannada", "ks": "Kashmiri", "kok": "Konkani", "mai": "Maithili", "ml": "Malayalam", "mni": "Manipuri",
+     "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi", "sa": "Sanskrit", "sat": "Santali",
+     "sd": "Sindhi", "ta": "Tamil", "te": "Telugu", "ur": "Urdu",
+ }
  device = "cuda" if torch.cuda.is_available() else "cpu"

+ # Load Indic Conformer model
+ print("Loading IndicConformer model...")
  model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True).to(device)
  model.eval()
+ print("✅ Model loaded successfully.")
+
+
+ def identify_language(text: str) -> str | None:
+     """Identifies the language of a given text based on character sets and common words."""
+     if not text.strip():
+         return None
+
+     scores = {lang: 0 for lang in LANGUAGE_DATA}
+     text_chars = set(text)
+     # Use regex to split words, handling various scripts
+     text_words = set(re.split(r'[\s,.:;!?]+', text))
+
+     for lang_code, data in LANGUAGE_DATA.items():
+         char_score = len(text_chars.intersection(data["chars"]))
+         word_score = len(text_words.intersection(data["words"]))
+
+         # Give more weight to character matches as they are a stronger signal of the script
+         scores[lang_code] = (char_score * 2) + word_score
+
+     # Identify the language with the highest score
+     # Return None if the highest score is very low, indicating a poor match
+     max_score = max(scores.values())
+     if max_score < 3:  # Heuristic threshold to prevent misidentification on noise
+         return None
+
+     identified_code = max(scores, key=scores.get)
+     return identified_code
+

  @spaces.GPU
+ def transcribe_and_identify(audio_path):
+     if not audio_path:
+         return "Please provide an audio file.", "", ""

      # Load and preprocess audio
+     try:
+         waveform, sr = torchaudio.load(audio_path)
+         waveform = waveform.mean(dim=0, keepdim=True) if waveform.shape[0] > 1 else waveform
+         waveform = torchaudio.functional.resample(waveform, sr, 16000).to(device)
+     except Exception as e:
+         return f"Error loading audio: {e}", "", ""

      try:
+         # 1. Perform a fast, initial transcription using a pivot language (Hindi)
          with torch.no_grad():
+             initial_transcription = model(waveform, "hi", "ctc")
+
+         # 2. Identify the language from the initial transcription
+         identified_lang_code = identify_language(initial_transcription)
+
+         if not identified_lang_code:
+             detected_lang_str = "Language not detected or unsupported."
+             return detected_lang_str, initial_transcription + " (pivot)", "Could not perform final transcription."
+
+         detected_lang_str = f"Detected Language: {LANGUAGE_CODE_TO_NAME.get(identified_lang_code, 'Unknown')}"
+
+         # 3. Perform the final, high-quality transcription using the identified language
+         with torch.no_grad():
+             transcription_ctc = model(waveform, identified_lang_code, "ctc")
+             transcription_rnnt = model(waveform, identified_lang_code, "rnnt")
+
      except Exception as e:
+         return f"Error during transcription: {str(e)}", "", ""
+
+     return detected_lang_str, transcription_ctc.strip(), transcription_rnnt.strip()
 

  # Gradio UI
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
      gr.Markdown(f"## {DESCRIPTION}")
+     gr.Markdown("Upload or record audio in any of the 22 supported Indian languages. The app will automatically detect the language and provide the transcription using both CTC and RNNT decoding.")
+
      with gr.Row():
+         with gr.Column(scale=1):
              audio = gr.Audio(label="Upload or Record Audio", type="filepath")
+             transcribe_btn = gr.Button("Transcribe", variant="primary")
+
+         with gr.Column(scale=2):
+             detected_lang_output = gr.Label(label="Language Detection Result")
+             gr.Markdown("### RNNT Transcription (More Accurate)")
+             rnnt_output = gr.Textbox(lines=3, label="RNNT Output")
+             gr.Markdown("### CTC Transcription (Faster)")
+             ctc_output = gr.Textbox(lines=3, label="CTC Output")
+
+     transcribe_btn.click(
+         fn=transcribe_and_identify,
+         inputs=[audio],
+         outputs=[detected_lang_output, ctc_output, rnnt_output],
+         api_name="transcribe"
+     )

  if __name__ == "__main__":
+     demo.queue().launch(share=True)
+
+ ```
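
For reference, here is a minimal, self-contained sketch of the scoring heuristic that `identify_language` applies in the updated script, reduced to a two-language subset. The `SAMPLE_LANGUAGE_DATA` name, the trimmed character sets, and the sample sentences are illustrative only and are not part of the commit.

```python
from __future__ import annotations
import re

# Illustrative two-language subset of the LANGUAGE_DATA structure used in the commit.
SAMPLE_LANGUAGE_DATA = {
    "hi": {"chars": set("अआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसहािीुूृेैोौ्"),
           "words": {"और", "है", "एक", "में", "के"}},
    "ta": {"chars": set("அஆஇஈஉஊஎஏஐஒஓஔகஙசஞடணதநனபமயரலவழள"),
           "words": {"மற்றும்", "ஒரு", "நான்", "இது"}},
}

def identify_language(text: str) -> str | None:
    """Score each language by shared characters (weight 2) and shared common words (weight 1)."""
    if not text.strip():
        return None
    text_chars = set(text)
    text_words = set(re.split(r"[\s,.:;!?]+", text))
    scores = {
        code: 2 * len(text_chars & data["chars"]) + len(text_words & data["words"])
        for code, data in SAMPLE_LANGUAGE_DATA.items()
    }
    best = max(scores, key=scores.get)
    return best if scores[best] >= 3 else None  # low scores are treated as "not detected"

print(identify_language("यह एक परीक्षण है"))  # Devanagari characters plus the word "है" -> "hi"
print(identify_language("இது ஒரு சோதனை"))  # Tamil characters plus the word "ஒரு" -> "ta"
print(identify_language("hello world"))  # no script or word match -> None
```

Character matches are weighted twice as heavily as word matches because the script itself is usually the stronger signal; the threshold of 3 mirrors the guard against noisy pivot transcriptions in the committed code.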