Files
lora/train_specialist.py

217 lines
8.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Specialist LoRA trainer — parameterized for all adapters.
Same architecture as train_qwen35_27b.py (bt7274 persona) but configurable
per specialist: CLI args take precedence over the built-in per-adapter
overrides, which in turn take precedence over the global defaults.
Usage:
    # Rust specialist
    python train_specialist.py --name oxidizer --data data/oxidizer.jsonl --max-seq 8192

    # TypeScript specialist
    python train_specialist.py --name prism --data data/prism.jsonl --max-seq 8192

    # TTS cleanup (smaller sequences, more epochs)
    python train_specialist.py --name trace --data data/trace.jsonl --max-seq 2048 --epochs 5 --lr 1e-4

    # All defaults
    python train_specialist.py --name oxidizer
"""
import argparse
import os
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
import torch
# ── Defaults ─────────────────────────────────────────────────────────
# Global fallback values, used when neither the CLI nor the per-adapter
# override table supplies a setting.
DEFAULTS = {
    # Base checkpoint and context window
    "model": "Qwen/Qwen3.5-27B",
    "max_seq": 8192,

    # LoRA shape (passed as r / lora_alpha when the adapter is attached)
    "rank": 16,
    "alpha": 16,

    # Optimisation schedule
    "epochs": 3,
    "batch": 1,
    "grad_accum": 8,
    "lr": 5e-5,
    "warmup": 10,

    # Checkpointing (no CLI flag; always resolved from overrides/defaults)
    "save_steps": 50,
    "save_total_limit": 2,
}
# Per-adapter overrides
# Settings keyed by adapter name. A value here beats DEFAULTS but loses to an
# explicit CLI flag. "data" is only consulted for the dataset path.
ADAPTER_OVERRIDES = {
    "bt7274": {
        "max_seq": 4096,
        "lr": 1e-4,
        "data": "bt7274_v3.jsonl",
    },
    "oxidizer": {"data": "data/oxidizer.jsonl"},
    "serpent": {"data": "data/serpent.jsonl"},
    "prism": {"data": "data/prism.jsonl"},
    "forge": {"data": "data/forge.jsonl"},
    "swiftblade": {"data": "data/swiftblade.jsonl"},
    "trace": {
        "max_seq": 2048,
        "lr": 1e-4,
        "epochs": 5,
        "data": "data/trace.jsonl",
    },
}
def fix_tool_calls(messages):
    """Return a copy of *messages* whose tool-call arguments are dicts.

    Training data commonly stores ``tool_calls[i].function.arguments`` as a
    JSON-encoded string, while this script feeds the messages to the Qwen3.5
    chat template, which is given dict arguments here.

    Args:
        messages: Iterable of chat-message mappings. Messages without a truthy
            ``"tool_calls"`` entry are shallow-copied and otherwise untouched.

    Returns:
        A new list of new message dicts. For every tool call that carries a
        ``"function"`` mapping with a string ``"arguments"``:

        * a string that decodes to a JSON object is replaced by that dict;
        * anything else (invalid JSON, or valid JSON that is not an object,
          such as ``"null"`` or ``"[1, 2]"``) is replaced by
          ``{"raw": <original string>}`` so the result is always a dict.

        Arguments that are not strings are passed through unchanged. The
        input messages, tool calls and function mappings are never mutated.
    """
    # Local import: the module's import block does not bring in json.
    import json as _json

    fixed = []
    for msg in messages:
        msg = dict(msg)  # shallow copy so the caller's message stays intact
        if msg.get("tool_calls"):
            new_tcs = []
            for tc in msg["tool_calls"]:
                tc = dict(tc)
                fn = tc.get("function")
                # A missing or None "function" is left as-is instead of
                # crashing on dict(None).
                if fn:
                    fn = dict(fn)
                    raw = fn.get("arguments")
                    if isinstance(raw, str):
                        try:
                            parsed = _json.loads(raw)
                        except (ValueError, TypeError):
                            parsed = None
                        # Only a JSON object is usable as arguments; keep the
                        # original text under "raw" for every other outcome.
                        fn["arguments"] = (
                            parsed if isinstance(parsed, dict) else {"raw": raw}
                        )
                    tc["function"] = fn
                new_tcs.append(tc)
            msg["tool_calls"] = new_tcs
        fixed.append(msg)
    return fixed
def _build_arg_parser():
    """Build the CLI parser; tunables default to None so 'not given' is detectable."""
    parser = argparse.ArgumentParser(description="Train specialist LoRA adapter")
    parser.add_argument("--name", required=True, help="Adapter name (oxidizer, serpent, prism, forge, swiftblade, trace)")
    parser.add_argument("--model", default=None, help=f"Base model (default: {DEFAULTS['model']})")
    parser.add_argument("--data", default=None, help="Training data JSONL path")
    parser.add_argument("--out", default=None, help="Output directory (default: adapters/<name>)")
    parser.add_argument("--max-seq", type=int, default=None, help="Max sequence length")
    parser.add_argument("--rank", type=int, default=None, help="LoRA rank")
    parser.add_argument("--alpha", type=int, default=None, help="LoRA alpha")
    parser.add_argument("--epochs", type=int, default=None, help="Training epochs")
    parser.add_argument("--batch", type=int, default=None, help="Batch size")
    parser.add_argument("--grad-accum", type=int, default=None, help="Gradient accumulation steps")
    parser.add_argument("--lr", type=float, default=None, help="Learning rate")
    parser.add_argument("--warmup", type=int, default=None, help="Warmup steps")
    parser.add_argument("--resume", default=None, help="Resume from checkpoint path")
    return parser


def main():
    """Train one specialist LoRA adapter and save it to the output directory.

    Steps, in order:
      1. Parse CLI flags and resolve each setting (CLI > adapter override >
         default).
      2. Load the JSONL dataset and check it has a ``messages`` column. This
         happens before the model load so bad data fails quickly.
      3. Load the base model in bf16 and attach the LoRA adapter.
      4. Render every example to a single ``text`` string with the tokenizer's
         chat template (after normalising tool-call arguments).
      5. Run SFT training, optionally resuming from ``--resume``.
      6. Save the adapter and tokenizer to the output directory.

    Raises:
        SystemExit: on argparse errors, or when the dataset has no
            ``messages`` column.
    """
    args = _build_arg_parser().parse_args()

    # Resolve config: CLI > adapter overrides > defaults
    overrides = ADAPTER_OVERRIDES.get(args.name, {})

    def resolve(key, cli_val):
        """Pick the CLI value if given, else the adapter override, else the default."""
        if cli_val is not None:
            return cli_val
        if key in overrides:
            return overrides[key]
        return DEFAULTS[key]

    model_name = resolve("model", args.model)
    max_seq = resolve("max_seq", args.max_seq)
    rank = resolve("rank", args.rank)
    alpha = resolve("alpha", args.alpha)
    epochs = resolve("epochs", args.epochs)
    batch = resolve("batch", args.batch)
    grad_accum = resolve("grad_accum", args.grad_accum)
    lr = resolve("lr", args.lr)
    warmup = resolve("warmup", args.warmup)
    # "data" has no entry in DEFAULTS; fall back to a path derived from the name.
    data_path = args.data or overrides.get("data", f"data/{args.name}.jsonl")
    out_dir = args.out or f"adapters/{args.name}"

    print(f"══ Specialist LoRA Training: {args.name} ══")
    print(f"Base model: {model_name}")
    print(f"Data: {data_path}")
    print(f"Output: {out_dir}")
    print(f"Max seq: {max_seq}")
    print(f"LoRA: r={rank}, α={alpha}")
    print(f"Training: {epochs} epochs, batch {batch}, grad_accum {grad_accum}")
    print(f"LR: {lr}")
    print(f"Warmup: {warmup} steps")
    print()

    # ── Dataset (raw) ────────────────────────────────────────────────
    # Read the data before the model so a bad path or malformed file fails
    # before the (much slower) base-model load.
    print(f"Loading dataset: {data_path}")
    raw_ds = load_dataset("json", data_files=data_path, split="train")
    if "messages" not in raw_ds.column_names:
        raise SystemExit(
            f"error: {data_path} has no 'messages' column "
            f"(found: {raw_ds.column_names})"
        )

    # ── Load model ───────────────────────────────────────────────────
    print("Loading model (bf16, no quantization)...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq,
        load_in_4bit=False,
        load_in_16bit=True,
        full_finetuning=False,
        dtype=torch.bfloat16,
    )

    # ── LoRA adapter ─────────────────────────────────────────────────
    print("Applying LoRA...")
    model = FastLanguageModel.get_peft_model(
        model,
        r=rank,
        lora_alpha=alpha,
        lora_dropout=0,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=42,
        max_seq_length=max_seq,
    )

    # ── Dataset (rendered) ───────────────────────────────────────────
    def to_chatml(ex):
        """Render one example's messages into a single training string."""
        messages = fix_tool_calls(ex["messages"])
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        return {"text": text}

    # Keep only the rendered "text" column. NOTE(review): depending on the TRL
    # version, a leftover "messages" column can make the trainer treat the
    # dataset as conversational and re-apply the chat template itself,
    # bypassing fix_tool_calls — confirm against the installed TRL.
    ds = raw_ds.map(to_chatml, remove_columns=raw_ds.column_names)

    # Rough estimate only; the trainer computes the real step count.
    steps = (len(ds) * epochs) // (batch * grad_accum)
    print(f"Dataset: {len(ds)} examples")
    print(f"Epochs: {epochs}, effective batch: {batch * grad_accum}")
    print(f"Est. steps: {steps}")

    # ── Train ────────────────────────────────────────────────────────
    print("\nStarting training...")
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds,
        args=SFTConfig(
            output_dir=out_dir,
            per_device_train_batch_size=batch,
            gradient_accumulation_steps=grad_accum,
            num_train_epochs=epochs,
            learning_rate=lr,
            bf16=True,
            logging_steps=5,
            # No CLI flags for these two: overrides/defaults only.
            save_steps=resolve("save_steps", None),
            save_total_limit=resolve("save_total_limit", None),
            warmup_steps=warmup,
            optim="adamw_8bit",
            seed=42,
            report_to="none",
            max_seq_length=max_seq,
            dataset_num_proc=1,
        ),
    )
    if args.resume:
        print(f"Resuming from checkpoint: {args.resume}")
        trainer.train(resume_from_checkpoint=args.resume)
    else:
        trainer.train()

    # ── Save ─────────────────────────────────────────────────────────
    model.save_pretrained(out_dir)
    tokenizer.save_pretrained(out_dir)
    print(f"\n✓ Saved {args.name} adapter to {out_dir}/")
    print(f" Transfer to sin: ~/models/loras/{args.name}/")


if __name__ == "__main__":
    main()