217 lines
8.3 KiB
Python
217 lines
8.3 KiB
Python
"""Specialist LoRA trainer — parameterized for all adapters.
|
||
|
||
Same architecture as train_qwen35_27b.py (bt7274 persona) but configurable
|
||
per specialist via CLI args or the ADAPTER_OVERRIDES table (CLI args take precedence).
|
||
|
||
Usage:
|
||
# Rust specialist
|
||
python train_specialist.py --name oxidizer --data data/oxidizer.jsonl --max-seq 8192
|
||
|
||
# TypeScript specialist
|
||
python train_specialist.py --name prism --data data/prism.jsonl --max-seq 8192
|
||
|
||
# TTS cleanup (smaller sequences, more epochs)
|
||
python train_specialist.py --name trace --data data/trace.jsonl \
|
||
--max-seq 2048 --epochs 5 --lr 1e-4
|
||
|
||
# All defaults
|
||
python train_specialist.py --name oxidizer
|
||
"""
|
||
|
||
import argparse
|
||
import os
|
||
|
||
from unsloth import FastLanguageModel
|
||
from trl import SFTTrainer, SFTConfig
|
||
from datasets import load_dataset
|
||
import torch
|
||
|
||
# ── Defaults ─────────────────────────────────────────────────────────
# Baseline hyperparameters shared by every adapter. main() resolves each key
# with precedence: CLI flag > ADAPTER_OVERRIDES[name] > this table.
DEFAULTS = dict(
    model="Qwen/Qwen3.5-27B",  # base checkpoint for FastLanguageModel.from_pretrained
    max_seq=8192,              # max sequence length (model load, LoRA setup, SFTConfig)
    rank=16,                   # LoRA rank r
    alpha=16,                  # LoRA alpha
    epochs=3,                  # num_train_epochs
    batch=1,                   # per_device_train_batch_size
    grad_accum=8,              # gradient_accumulation_steps
    lr=5e-5,                   # learning_rate
    warmup=10,                 # warmup_steps
    save_steps=50,             # SFTConfig.save_steps (checkpoint interval)
    save_total_limit=2,        # SFTConfig.save_total_limit (checkpoints kept)
)
|
||
|
||
# Per-adapter overrides: each entry replaces only the DEFAULTS keys it names.
# The extra "data" key is the training JSONL used when --data is not given.
ADAPTER_OVERRIDES = dict(
    # bt7274 persona adapter.
    # NOTE(review): unlike the other entries this path has no "data/" prefix —
    # presumably the file sits at the repo root; confirm.
    bt7274=dict(max_seq=4096, lr=1e-4, data="bt7274_v3.jsonl"),
    oxidizer=dict(data="data/oxidizer.jsonl"),
    serpent=dict(data="data/serpent.jsonl"),
    prism=dict(data="data/prism.jsonl"),
    forge=dict(data="data/forge.jsonl"),
    swiftblade=dict(data="data/swiftblade.jsonl"),
    # TTS cleanup: shorter sequences, more epochs, higher learning rate.
    trace=dict(max_seq=2048, lr=1e-4, epochs=5, data="data/trace.jsonl"),
)
|
||
|
||
|
||
def fix_tool_calls(messages):
    """Return a copy of *messages* with tool-call arguments decoded to dicts.

    Training data stores ``tool_calls[i]["function"]["arguments"]`` as a
    JSON-encoded string, while the Qwen3.5 chat template is given the
    arguments as a mapping. Each *string* argument payload is converted:

    * JSON object                -> the decoded dict;
    * blank / whitespace string  -> ``{}`` (a call with no arguments);
    * anything else (invalid JSON, or valid JSON that is not an object such
      as a list, bare string or ``null``) -> ``{"raw": <original string>}``,
      so the template always receives a dict.

    Non-string ``arguments`` values are left untouched, and tool calls whose
    ``function`` entry is missing or ``None`` are passed through as-is.
    Messages, tool calls and function dicts are shallow-copied; the input is
    never mutated.

    Args:
        messages: Iterable of chat-message mappings.

    Returns:
        A new list of message dicts.
    """
    import json as _json

    def _decode_arguments(raw):
        # Map one JSON-encoded argument string onto the dict the template needs.
        if not raw.strip():
            return {}
        try:
            parsed = _json.loads(raw)
        except (ValueError, TypeError):  # json.JSONDecodeError subclasses ValueError
            return {"raw": raw}
        # Valid JSON that is not an object would not be a mapping; keep it verbatim.
        return parsed if isinstance(parsed, dict) else {"raw": raw}

    fixed = []
    for msg in messages:
        msg = dict(msg)
        if msg.get("tool_calls"):
            new_tcs = []
            for tc in msg["tool_calls"]:
                tc = dict(tc)
                # HF datasets can fill a missing struct field with None; skip it
                # rather than crashing in dict(None).
                if tc.get("function") is not None:
                    fn = dict(tc["function"])
                    if isinstance(fn.get("arguments"), str):
                        fn["arguments"] = _decode_arguments(fn["arguments"])
                    tc["function"] = fn
                new_tcs.append(tc)
            msg["tool_calls"] = new_tcs
        fixed.append(msg)
    return fixed
|
||
|
||
|
||
def main():
    """Train and save one specialist LoRA adapter.

    Parses CLI args and resolves every hyperparameter with the precedence
    CLI flag > ``ADAPTER_OVERRIDES[name]`` > ``DEFAULTS``. It then loads the
    base model in bf16 and attaches a LoRA adapter. Next it renders the JSONL
    chat data to text with the tokenizer's chat template and runs SFT.
    Finally it saves the adapter and tokenizer to the output directory.

    The data path falls back to ``data/<name>.jsonl`` when neither the CLI nor
    the override table supplies one. A missing local data file is reported as
    a usage error before the (slow) model load.
    """
    parser = argparse.ArgumentParser(
        description="Train specialist LoRA adapter",
        epilog="Unset options fall back to the adapter's entry in "
               "ADAPTER_OVERRIDES, then to the global defaults shown.",
    )
    parser.add_argument("--name", required=True,
                        help=f"Adapter name ({', '.join(ADAPTER_OVERRIDES)})")
    parser.add_argument("--model", default=None,
                        help=f"Base model (default: {DEFAULTS['model']})")
    parser.add_argument("--data", default=None,
                        help="Training data JSONL path (default: per-adapter override or data/<name>.jsonl)")
    parser.add_argument("--out", default=None,
                        help="Output directory (default: adapters/<name>)")
    parser.add_argument("--max-seq", type=int, default=None,
                        help=f"Max sequence length (default: {DEFAULTS['max_seq']})")
    parser.add_argument("--rank", type=int, default=None,
                        help=f"LoRA rank (default: {DEFAULTS['rank']})")
    parser.add_argument("--alpha", type=int, default=None,
                        help=f"LoRA alpha (default: {DEFAULTS['alpha']})")
    parser.add_argument("--epochs", type=int, default=None,
                        help=f"Training epochs (default: {DEFAULTS['epochs']})")
    parser.add_argument("--batch", type=int, default=None,
                        help=f"Batch size (default: {DEFAULTS['batch']})")
    parser.add_argument("--grad-accum", type=int, default=None,
                        help=f"Gradient accumulation steps (default: {DEFAULTS['grad_accum']})")
    parser.add_argument("--lr", type=float, default=None,
                        help=f"Learning rate (default: {DEFAULTS['lr']})")
    parser.add_argument("--warmup", type=int, default=None,
                        help=f"Warmup steps (default: {DEFAULTS['warmup']})")
    parser.add_argument("--resume", default=None, help="Resume from checkpoint path")
    args = parser.parse_args()

    # Resolve config: CLI > adapter overrides > defaults
    overrides = ADAPTER_OVERRIDES.get(args.name, {})

    def resolve(key, cli_val):
        if cli_val is not None:
            return cli_val
        if key in overrides:
            return overrides[key]
        return DEFAULTS[key]

    model_name = resolve("model", args.model)
    max_seq = resolve("max_seq", args.max_seq)
    rank = resolve("rank", args.rank)
    alpha = resolve("alpha", args.alpha)
    epochs = resolve("epochs", args.epochs)
    batch = resolve("batch", args.batch)
    grad_accum = resolve("grad_accum", args.grad_accum)
    lr = resolve("lr", args.lr)
    warmup = resolve("warmup", args.warmup)
    data_path = args.data or overrides.get("data", f"data/{args.name}.jsonl")
    out_dir = args.out or f"adapters/{args.name}"

    # Fail fast: loading the base model takes minutes, so reject a missing data
    # file first. Glob patterns and remote URLs are left for load_dataset.
    is_plain_local_path = "://" not in data_path and not any(c in data_path for c in "*?[")
    if is_plain_local_path and not os.path.isfile(data_path):
        parser.error(f"training data not found: {data_path}")

    print(f"══ Specialist LoRA Training: {args.name} ══")
    print(f"Base model: {model_name}")
    print(f"Data: {data_path}")
    print(f"Output: {out_dir}")
    print(f"Max seq: {max_seq}")
    print(f"LoRA: r={rank}, α={alpha}")
    print(f"Training: {epochs} epochs, batch {batch}, grad_accum {grad_accum}")
    print(f"LR: {lr}")
    print(f"Warmup: {warmup} steps")
    print()

    # ── Load model ───────────────────────────────────────────────────
    print("Loading model (bf16, no quantization)...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq,
        load_in_4bit=False,
        load_in_16bit=True,
        full_finetuning=False,
        dtype=torch.bfloat16,
    )

    # ── LoRA adapter ─────────────────────────────────────────────────
    print("Applying LoRA...")
    model = FastLanguageModel.get_peft_model(
        model,
        r=rank,
        lora_alpha=alpha,
        lora_dropout=0,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=42,
        max_seq_length=max_seq,
    )

    # ── Dataset ──────────────────────────────────────────────────────
    print(f"Loading dataset: {data_path}")
    ds = load_dataset("json", data_files=data_path, split="train")

    def to_chatml(ex):
        # Render one conversation (tool-call args decoded) to a single string.
        messages = fix_tool_calls(ex["messages"])
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        return {"text": text}

    # Keep only the rendered "text" column. NOTE(review): recent TRL versions
    # treat a "messages" column as conversational data and may re-apply the
    # chat template themselves, which would bypass fix_tool_calls above.
    ds = ds.map(to_chatml, remove_columns=ds.column_names)

    steps = (len(ds) * epochs) // (batch * grad_accum)
    print(f"Dataset: {len(ds)} examples")
    print(f"Epochs: {epochs}, effective batch: {batch * grad_accum}")
    print(f"Est. steps: {steps}")

    # ── Train ────────────────────────────────────────────────────────
    print("\nStarting training...")
    trainer = SFTTrainer(
        model=model,
        # NOTE(review): newer TRL releases call this `processing_class`; this
        # assumes the pinned (Unsloth-patched) SFTTrainer still accepts `tokenizer`.
        tokenizer=tokenizer,
        train_dataset=ds,
        args=SFTConfig(
            output_dir=out_dir,
            per_device_train_batch_size=batch,
            gradient_accumulation_steps=grad_accum,
            num_train_epochs=epochs,
            learning_rate=lr,
            bf16=True,
            logging_steps=5,
            save_steps=resolve("save_steps", None),
            save_total_limit=resolve("save_total_limit", None),
            warmup_steps=warmup,
            optim="adamw_8bit",
            seed=42,
            report_to="none",
            max_seq_length=max_seq,
            dataset_num_proc=1,
        ),
    )

    if args.resume:
        print(f"Resuming from checkpoint: {args.resume}")
        trainer.train(resume_from_checkpoint=args.resume)
    else:
        trainer.train()

    # ── Save ─────────────────────────────────────────────────────────
    model.save_pretrained(out_dir)
    tokenizer.save_pretrained(out_dir)
    print(f"\n✓ Saved {args.name} adapter to {out_dir}/")
    print(f" Transfer to sin: ~/models/loras/{args.name}/")
|
||
|
||
|
||
if __name__ == "__main__":
    # Run training only when executed as a script, never on import.
    # main() returns None, so this exits with status 0 on success.
    raise SystemExit(main())
|