feat: tts-norm LoRA — dataset generator + training script
gen_tts_dataset.py: 4960 synthetic examples, 22 categories (numbers, currencies, dates, times, temperatures, acronyms, NATO phonetic, URLs, markdown, etc). Bilingual EN/PL with explicit [lang] tag prefix. train_tts_norm.py: Unsloth LoRA training for Qwen2.5-7B-Instruct. Rank 16, 3 epochs, packing, max_seq 768. Trained on H100 in 20m38s, final loss 0.091. Adapter: 154MB.
This commit is contained in:
@@ -0,0 +1,173 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Train TTS normalization LoRA on Qwen2.5-7B-Instruct using Unsloth.
|
||||
|
||||
Reads: tts_norm_dataset.jsonl (ShareGPT format)
|
||||
Output: tts-norm-lora/ adapter (vLLM-compatible)
|
||||
|
||||
Run on junkpile — RTX 2000 Ada 16GB.
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from unsloth import FastLanguageModel
|
||||
from unsloth.chat_templates import get_chat_template, standardize_sharegpt
|
||||
from trl import SFTTrainer
|
||||
from transformers import TrainingArguments
|
||||
from datasets import load_dataset
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# CONFIG
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
MODEL_NAME = "unsloth/Qwen2.5-7B-Instruct-bnb-4bit"
|
||||
DATASET_PATH = "tts_norm_dataset.jsonl"
|
||||
OUTPUT_DIR = "./tts-norm-lora"
|
||||
MAX_SEQ_LEN = 768 # TTS normalization is short text
|
||||
LORA_RANK = 16
|
||||
LORA_ALPHA = 16
|
||||
BATCH_SIZE = 2 # small GPU — use grad accumulation
|
||||
GRAD_ACCUM = 8 # effective batch = 16
|
||||
EPOCHS = 3
|
||||
LR = 2e-4
|
||||
WARMUP_STEPS = 30
|
||||
SAVE_STEPS = 100
|
||||
LOGGING_STEPS = 10
|
||||
SEED = 42
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# LOAD MODEL
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
print(f"Loading {MODEL_NAME}...")
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name=MODEL_NAME,
|
||||
max_seq_length=MAX_SEQ_LEN,
|
||||
load_in_4bit=True,
|
||||
dtype=None, # auto-detect
|
||||
)
|
||||
|
||||
# Apply chat template
|
||||
tokenizer = get_chat_template(
|
||||
tokenizer,
|
||||
chat_template="qwen-2.5",
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# PEFT CONFIG
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
print("Applying LoRA...")
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=LORA_RANK,
|
||||
lora_alpha=LORA_ALPHA,
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",
|
||||
],
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
use_gradient_checkpointing="unsloth",
|
||||
random_state=SEED,
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# DATASET
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
print(f"Loading dataset from {DATASET_PATH}...")
|
||||
dataset = load_dataset("json", data_files=DATASET_PATH, split="train")
|
||||
print(f" {len(dataset)} examples loaded")
|
||||
|
||||
# Standardize to ShareGPT format (handles from/value vs role/content)
|
||||
dataset = standardize_sharegpt(dataset)
|
||||
|
||||
# Pre-apply chat template via map — avoids formatting_func signature issues
|
||||
def apply_template(examples):
|
||||
"""Apply Qwen2.5 chat template to conversations."""
|
||||
convos = examples["conversations"]
|
||||
texts = []
|
||||
for convo in convos:
|
||||
text = tokenizer.apply_chat_template(
|
||||
convo,
|
||||
tokenize=False,
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
texts.append(text)
|
||||
return {"text": texts}
|
||||
|
||||
print("Applying chat template...")
|
||||
dataset = dataset.map(apply_template, batched=True, num_proc=2)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# TRAINER
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
print("Setting up trainer...")
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
args=TrainingArguments(
|
||||
output_dir=OUTPUT_DIR,
|
||||
per_device_train_batch_size=BATCH_SIZE,
|
||||
gradient_accumulation_steps=GRAD_ACCUM,
|
||||
num_train_epochs=EPOCHS,
|
||||
learning_rate=LR,
|
||||
lr_scheduler_type="cosine",
|
||||
warmup_steps=WARMUP_STEPS,
|
||||
fp16=not torch.cuda.is_bf16_supported(),
|
||||
bf16=torch.cuda.is_bf16_supported(),
|
||||
logging_steps=LOGGING_STEPS,
|
||||
save_steps=SAVE_STEPS,
|
||||
save_total_limit=3,
|
||||
seed=SEED,
|
||||
optim="adamw_8bit",
|
||||
weight_decay=0.01,
|
||||
max_grad_norm=1.0,
|
||||
report_to="none",
|
||||
dataloader_num_workers=2,
|
||||
),
|
||||
max_seq_length=MAX_SEQ_LEN,
|
||||
dataset_num_proc=2,
|
||||
packing=True, # pack short examples for efficiency
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# TRAIN
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
print("Starting training...")
|
||||
stats = trainer.train()
|
||||
print(f"\nTraining complete!")
|
||||
print(f" Total steps: {stats.global_step}")
|
||||
print(f" Train loss: {stats.training_loss:.4f}")
|
||||
print(f" Runtime: {stats.metrics['train_runtime']:.0f}s")
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# SAVE ADAPTER
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
print(f"\nSaving adapter to {OUTPUT_DIR}...")
|
||||
model.save_pretrained(OUTPUT_DIR)
|
||||
tokenizer.save_pretrained(OUTPUT_DIR)
|
||||
|
||||
# Verify
|
||||
adapter_path = Path(OUTPUT_DIR) / "adapter_model.safetensors"
|
||||
if adapter_path.exists():
|
||||
size_mb = adapter_path.stat().st_size / (1024 * 1024)
|
||||
print(f" Adapter saved: {size_mb:.1f} MB")
|
||||
else:
|
||||
print(" WARNING: adapter_model.safetensors not found!")
|
||||
|
||||
config_path = Path(OUTPUT_DIR) / "adapter_config.json"
|
||||
if config_path.exists():
|
||||
print(f" Config saved: {config_path}")
|
||||
|
||||
print(f"\nDone. Serve with:")
|
||||
print(f" vllm serve Qwen/Qwen2.5-7B-Instruct \\")
|
||||
print(f" --enable-lora \\")
|
||||
print(f" --lora-modules tts-norm={os.path.abspath(OUTPUT_DIR)} \\")
|
||||
print(f" --max-lora-rank {LORA_RANK}")
|
||||
Reference in New Issue
Block a user