# embed_build_muril.py
r"""
Produce answer embeddings for the dataset using a fine-tuned MuRIL model.

Saves:
  - muril_multilingual_dataset.csv (columns: question, answer, language)
  - answer_embeddings.pt (torch tensor shape [N, D], float32, on CPU)

Row i of the CSV corresponds to row i of the embedding tensor.

Usage:
  python embed_build_muril.py \
      --model_dir ./muril_multilang_out \
      --input_jsonl /path/to/legal_multilingual_QA_10k.jsonl \
      --out_dir ./export_artifacts \
      --batch_size 64
"""
import argparse
import math
import os
from pathlib import Path

import pandas as pd
import torch

# Column order of the stacked table (and of the exported CSV header).
_OUTPUT_COLUMNS = ["question", "answer", "language"]


def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: Optional argument list. ``None`` lets argparse read
            ``sys.argv[1:]``; passing a list makes the parser testable.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    p = argparse.ArgumentParser(
        description="Build MuRIL answer embeddings from a multilingual QA JSONL file."
    )
    p.add_argument("--model_dir", type=str, default="./muril_multilang_out",
                   help="Path or HF repo id of fine-tuned MuRIL")
    p.add_argument("--input_jsonl", type=str, required=True,
                   help="Path to legal_multilingual_QA_10k.jsonl")
    p.add_argument("--out_dir", type=str, default="./export_artifacts")
    p.add_argument("--langs", type=str, default="en,hi,mr,ta,bn,gu,kn,ml,pa,or,as,ur,sa,ne",
                   help="comma-separated languages to merge (will stack)")
    p.add_argument("--text_prefix", type=str, default="question_",
                   help="prefix for question columns in JSONL (question_<lang>)")
    p.add_argument("--answer_col_prefix", type=str, default="answer_",
                   help="prefix for per-language answer columns (answer_<lang>); "
                        "falls back to the shared 'answer' column when missing")
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--max_length", type=int, default=256,
                   help="maximum tokens per answer; longer answers are truncated")
    p.add_argument("--device", type=str,
                   default="cuda" if torch.cuda.is_available() else "cpu")
    return p.parse_args(argv)


def mean_pooling(last_hidden_state, attention_mask):
    """Average token embeddings over the non-padded positions.

    Args:
        last_hidden_state: Tensor of shape (B, L, H).
        attention_mask: Tensor of shape (B, L); non-zero marks real tokens.

    Returns:
        Tensor of shape (B, H). The token count is clamped to 1e-9, so a row
        whose mask is all zeros yields zeros instead of dividing by zero.
    """
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    summed = torch.sum(last_hidden_state * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / counts


def _clean_text(value):
    """Return ``value`` as a stripped string, or ``None`` if it is missing.

    Missing means ``None``, ``pd.NA``, a float NaN, an empty/whitespace-only
    string, or the literal text "nan" (a common artifact of stringified NaN).
    """
    if value is None or value is pd.NA:
        return None
    if isinstance(value, float) and math.isnan(value):
        return None
    text = str(value).strip()
    if not text or text.lower() == "nan":
        return None
    return text


def build_question_answer_rows(df, langs, text_prefix, answer_prefix="answer_"):
    """Stack per-language question/answer columns into long-format rows.

    For every input row and every language in ``langs``, the question is read
    from column ``f"{text_prefix}{lang}"``. The answer is read from
    ``f"{answer_prefix}{lang}"``. If that cell is absent or missing, the shared
    ``"answer"`` column is used instead. Pairs whose question or answer is
    missing (see ``_clean_text``) are skipped.

    Args:
        df: Input DataFrame (one source record per row).
        langs: Iterable of language codes, processed in the given order.
        text_prefix: Prefix of the question columns.
        answer_prefix: Prefix of the per-language answer columns.

    Returns:
        DataFrame with columns ``question``, ``answer``, ``language``, in
        input-row order and then ``langs`` order. Empty if nothing matched.
    """
    rows = []
    for record in df.to_dict("records"):
        for lang in langs:
            question = _clean_text(record.get(f"{text_prefix}{lang}"))
            if question is None:
                continue
            # Prefer the language-specific answer; otherwise use the shared one.
            answer = _clean_text(record.get(f"{answer_prefix}{lang}"))
            if answer is None:
                answer = _clean_text(record.get("answer"))
            if answer is None:
                continue
            rows.append({"question": question, "answer": answer, "language": lang})
    return pd.DataFrame(rows, columns=_OUTPUT_COLUMNS)


def main(argv=None):
    """Stack QA rows, save the CSV, embed every answer, and save the tensor.

    Args:
        argv: Optional argument list forwarded to ``parse_args``.

    Raises:
        SystemExit: if ``--batch_size`` is not positive or no rows were found.
    """
    args = parse_args(argv)
    batch_size = int(args.batch_size)
    if batch_size < 1:
        raise SystemExit("--batch_size must be a positive integer.")

    # Imported lazily so the data-prep helpers above can be imported (and
    # --help shown) without transformers/tqdm installed.
    from tqdm.auto import tqdm
    from transformers import AutoModel, AutoTokenizer

    os.makedirs(args.out_dir, exist_ok=True)

    print("Loading dataset:", args.input_jsonl)
    # dtype=False keeps raw JSON values. dtype=str would turn JSON nulls into
    # the literal string "None", which then looks like real text.
    df_in = pd.read_json(args.input_jsonl, lines=True, dtype=False)

    # Stack rows across languages (question_<lang>; answer_<lang> optional).
    langs = [lang.strip() for lang in args.langs.split(",") if lang.strip()]
    print("Merging language columns (stack)... langs:", langs)
    rows_df = build_question_answer_rows(
        df_in, langs, args.text_prefix, answer_prefix=args.answer_col_prefix
    )
    if rows_df.empty:
        raise SystemExit("No question/answer rows found after merging languages. Check your columns.")
    print(f"Total rows extracted: {len(rows_df)}")

    # Save the CSV first: its row order defines the row order of the embeddings.
    csv_path = Path(args.out_dir) / "muril_multilingual_dataset.csv"
    rows_df.to_csv(csv_path, index=False, encoding="utf-8")
    print("Saved merged CSV to:", csv_path)

    print("Loading tokenizer & model from:", args.model_dir, "device:", args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
    model = AutoModel.from_pretrained(args.model_dir)
    model.to(args.device)
    model.eval()

    answers = rows_df["answer"].astype(str).tolist()
    chunks = []
    with torch.inference_mode():
        for start in tqdm(range(0, len(answers), batch_size), desc="Encoding"):
            batch_texts = answers[start:start + batch_size]
            encoded = tokenizer(batch_texts, padding=True, truncation=True,
                                max_length=args.max_length, return_tensors="pt")
            input_ids = encoded["input_ids"].to(args.device)
            attention_mask = encoded["attention_mask"].to(args.device)
            out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            pooled = mean_pooling(out.last_hidden_state, attention_mask)  # (B, H)
            # L2-normalize so a plain dot product equals cosine similarity.
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
            chunks.append(pooled.float().cpu())

    all_embs = torch.cat(chunks, dim=0)  # (N, H)
    print("Embeddings shape:", all_embs.shape)

    embed_path = Path(args.out_dir) / "answer_embeddings.pt"
    torch.save(all_embs, embed_path)
    print("Saved embeddings to:", embed_path)
    print("Done. Artifacts in:", args.out_dir)


if __name__ == "__main__":
    main()