# training-scripts / train_sft_demo.py
# Uploaded by passagereptile455 with huggingface_hub
# Commit: c5c3c89 (verified)
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "datasets", "transformers", "torch", "accelerate"]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import os
# Minimal LoRA supervised fine-tuning (SFT) demo: fine-tunes Qwen2.5-0.5B on a
# 500-example slice of trl-lib/Capybara and pushes the result to a private
# Hub repository.

MODEL_ID = "Qwen/Qwen2.5-0.5B"
DATASET_ID = "trl-lib/Capybara"
OUTPUT_DIR = "qwen-demo-sft"
HUB_MODEL_ID = "passagereptile455/qwen-demo-sft"


def main():
    """Fine-tune MODEL_ID with a LoRA adapter and push it to the Hub.

    Loads the first 500 rows of DATASET_ID's ``train`` split, trains for
    100 steps, then uploads the result to the private repo HUB_MODEL_ID.
    """
    # Load a small dataset so the demo finishes quickly.
    dataset = load_dataset(DATASET_ID, split="train[:500]")

    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules="all-linear",
        # Declare the task so PEFT wraps the model as a causal-LM adapter and
        # records the task type in the pushed adapter config.
        task_type="CAUSAL_LM",
    )

    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        max_steps=100,
        per_device_train_batch_size=2,
        # 2 examples x 4 accumulation steps = 8 examples per device per step.
        gradient_accumulation_steps=4,
        logging_steps=10,
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        hub_private_repo=True,
    )

    # The model is given as a Hub id string; SFTTrainer loads it itself.
    trainer = SFTTrainer(
        model=MODEL_ID,
        train_dataset=dataset,
        peft_config=peft_config,
        args=training_args,
    )
    trainer.train()
    # Upload the final model explicitly rather than relying on pushes that
    # are tied to intermediate checkpoint saves.
    trainer.push_to_hub()
    print("Training complete! Model pushed to Hub.")


# Guard so importing this module (e.g. from spawned worker processes or
# tooling) does not start a training run and Hub upload.
if __name__ == "__main__":
    main()