#!/usr/bin/env python3
"""Extract the encoder-only part of a full T5 checkpoint and save it.

Loads a ``T5ForConditionalGeneration`` checkpoint and builds an empty
``T5EncoderModel`` from the same config. It copies every weight that does not
belong to the decoder or the LM head into the new model and writes the result
with ``save_pretrained``.

The tokenizer is NOT written to the output directory. Load it from the source
model name; the encoder model is built from the source model's config.
"""
import argparse

DEFAULT_SRC_MODEL_NAME = "google/t5-v1_1-xl"
DEFAULT_DST_DIR = "./t5-v1_1-xl-encoder-only"

# State-dict key prefixes that belong only to the full encoder-decoder model.
_NON_ENCODER_PREFIXES = ("decoder.", "lm_head.")


def encoder_only_state_dict(state_dict):
    """Return a new dict containing the entries of *state_dict* whose keys do
    not start with ``"decoder."`` or ``"lm_head."``.

    The input mapping is not modified, and values are passed through as-is.
    """
    return {
        key: value
        for key, value in state_dict.items()
        if not key.startswith(_NON_ENCODER_PREFIXES)
    }


def parse_args(argv=None):
    """Parse command-line options; the defaults reproduce the original hard-coded paths."""
    parser = argparse.ArgumentParser(
        description="Extract the encoder-only part of a full T5 model."
    )
    parser.add_argument(
        "--src",
        default=DEFAULT_SRC_MODEL_NAME,
        help="source T5ForConditionalGeneration model name or path",
    )
    parser.add_argument(
        "--dst",
        default=DEFAULT_DST_DIR,
        help="output directory for the encoder-only model",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Load the full model, copy its encoder weights into a T5EncoderModel, and save it."""
    args = parse_args(argv)
    # Imported here so the helpers above stay importable without transformers.
    from transformers import T5EncoderModel, T5ForConditionalGeneration

    full_model = T5ForConditionalGeneration.from_pretrained(args.src)
    # Build an encoder-only model from the same config.
    encoder_model = T5EncoderModel(full_model.config)
    # load_state_dict is strict by default: a missing or unexpected key raises
    # instead of silently leaving randomly initialised weights behind.
    encoder_model.load_state_dict(encoder_only_state_dict(full_model.state_dict()))
    # The encoder now holds its own copies of the weights, so drop the full
    # model to let its memory be reclaimed before saving.
    del full_model
    encoder_model.save_pretrained(args.dst)
    print(f"Encoder-only model saved to {args.dst}")


if __name__ == "__main__":
    main()