# lora/train_v4.py — Python, 201 lines, 7.3 KiB (header captured from the file viewer)

"""BT-7274 LoRA v4 — Qwen3.5-27B, bf16 LoRA (NOT QLoRA).

Key differences from the v3 train script:
- Uses the BASE Qwen3.5 tokenizer (Hermes tool format, NOT Coder XML).
- Dataset includes <think> blocks (enable_thinking in the template).
- Combined dataset: persona + agent tools + reformatted v3.
- No custom chat_template override — the base model template produces
  Hermes-format tool calls that vLLM's hermes parser can decode.

vLLM serving flags for v4:
    --tool-call-parser hermes
    --reasoning-parser deepseek_r1
    --enable-reasoning (or --enable-thinking via Qwen3 alias)

Usage:
    pip install --upgrade unsloth unsloth_zoo
    python train_v4.py
"""
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
import torch
import json
# ── Config ───────────────────────────────────────────────────────────
# Paths and model identity.
MODEL = "Qwen/Qwen3.5-27B"
DATA = "./bt7274_v4.jsonl"
OUT = "./bt7274-qwen35-27b-lora-v4"

# Sequence budget: raised from 4096 because the multi-turn conversations
# in this dataset are longer than before.
MAX_SEQ = 8192

# LoRA shape (alpha equals rank).
RANK = 16
ALPHA = 16

# Optimisation schedule. Effective batch size is BATCH * GRAD_ACCUM.
EPOCHS = 3
BATCH = 1
GRAD_ACCUM = 8
# Lowered from 1e-4: the larger dataset benefits from a gentler rate.
LR = 5e-5
# Warm up over 5% of training instead of a fixed number of steps.
WARMUP_RATIO = 0.05
# ── Load model (bf16, NOT 4-bit) ────────────────────────────────────
# Plain bf16 LoRA: 4-bit loading is switched off (QLoRA is not recommended
# for Qwen3.5) and full fine-tuning is disabled so only adapters train.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL,
    max_seq_length=MAX_SEQ,
    dtype=torch.bfloat16,
    load_in_16bit=True,
    load_in_4bit=False,
    full_finetuning=False,
)

# CRITICAL: eyeball the start of the chat template to confirm it is the
# BASE tokenizer's template and not a LoRA override. The base Qwen3.5
# template is expected to emit Hermes-format tool calls:
#   <tool_call>{"name":"...","arguments":{...}}</tool_call>
# and NOT the Coder XML format that v3 used.
template_preview = tokenizer.chat_template[:80] if tokenizer.chat_template else "NONE"
print(f"Chat template source: {template_preview}...")
# ── LoRA adapter ────────────────────────────────────────────────────
# Wrap the loaded model with LoRA adapters on the seven projection
# modules listed below; no bias terms are trained and dropout is off.
model = FastLanguageModel.get_peft_model(
    model,
    r=RANK,
    lora_alpha=ALPHA,
    lora_dropout=0,
    bias="none",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
    max_seq_length=MAX_SEQ,
    random_state=42,  # fixed seed, same value as the trainer seed
)
# ── Dataset ─────────────────────────────────────────────────────────
def _parse_tool_arguments(raw):
    """Turn a tool-call ``arguments`` string into a dict.

    * blank string            -> ``{}`` (a call with no arguments)
    * JSON object             -> the decoded dict
    * invalid JSON            -> ``{"raw": raw}``
    * valid JSON, not a dict  -> ``{"raw": raw}`` (e.g. ``"null"``, ``"[1]"``)

    The result is therefore always a dict.
    """
    if not raw.strip():
        return {}
    try:
        parsed = json.loads(raw)
    except (ValueError, TypeError):
        return {"raw": raw}
    # json.loads happily returns lists, numbers, None, ...; keep the
    # "arguments is a dict" invariant by wrapping anything else.
    return parsed if isinstance(parsed, dict) else {"raw": raw}


def fix_tool_calls(messages):
    """Parse tool_call arguments from JSON strings to dicts for the Qwen3.5 template.

    Args:
        messages: iterable of chat-message mappings. Messages may carry a
            ``tool_calls`` list whose entries may hold a ``function``
            mapping with an ``arguments`` value.

    Returns:
        A new list of message dicts. Each message, tool call and function
        mapping that is touched is shallow-copied first, so the input
        structures are never mutated. String ``arguments`` are replaced by
        dicts (see ``_parse_tool_arguments``); non-string ``arguments``
        and messages without tool calls pass through unchanged.
    """
    fixed = []
    for msg in messages:
        msg = dict(msg)
        if msg.get("tool_calls"):
            new_tcs = []
            for tc in msg["tool_calls"]:
                tc = dict(tc)
                if "function" in tc:
                    fn = dict(tc["function"])
                    if isinstance(fn.get("arguments"), str):
                        fn["arguments"] = _parse_tool_arguments(fn["arguments"])
                    tc["function"] = fn
                new_tcs.append(tc)
            msg["tool_calls"] = new_tcs
        fixed.append(msg)
    return fixed
def load_and_format(path):
    """Load a JSONL chat dataset and render each row to template text.

    The file is parsed by hand, one JSON object per line, because pyarrow
    chokes on mixed tool_calls argument types. Each row's ``messages`` are
    normalised with ``fix_tool_calls`` and rendered with the module-level
    ``tokenizer``'s chat template. Rows whose rendered text encodes to
    more than ``MAX_SEQ`` tokens are dropped and counted.

    Args:
        path: path to a UTF-8 JSONL file; every non-blank line must be a
            JSON object with a ``"messages"`` key.

    Returns:
        A ``datasets.Dataset`` with a single ``"text"`` column.

    Raises:
        json.JSONDecodeError: a line is not valid JSON; the message is
            prefixed with ``path:line_number`` to locate the bad row.
        KeyError: a row has no ``"messages"`` key.
    """
    from datasets import Dataset

    texts = []
    skipped = 0
    # Explicit UTF-8: the platform default encoding is not guaranteed to
    # decode non-ASCII conversation text.
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError as err:
                # Same exception type, but say which row is broken.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
            messages = fix_tool_calls(row["messages"])
            try:
                text = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=False,
                    enable_thinking=True,
                )
            except TypeError:
                # Retry without the flag; presumably covers a tokenizer
                # that does not accept `enable_thinking`.
                text = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=False,
                )
            if len(tokenizer.encode(text)) <= MAX_SEQ:
                texts.append(text)
            else:
                skipped += 1
    if skipped:
        print(f"⚠ Filtered {skipped} examples exceeding {MAX_SEQ} tokens")
    return Dataset.from_dict({"text": texts})
# Build the training set, then print a one-screen summary of the run.
ds = load_and_format(DATA)

# Rough optimiser-step count: total examples seen divided by the
# effective batch size (integer division).
steps = (len(ds) * EPOCHS) // (BATCH * GRAD_ACCUM)

print(
    f"Dataset: {len(ds)} examples",
    f"Epochs: {EPOCHS}, effective batch: {BATCH * GRAD_ACCUM}",
    f"Estimated steps: {steps}",
    f"LoRA: r={RANK}, alpha={ALPHA}",
    f"Max seq: {MAX_SEQ}",
    f"Model: {MODEL}",
    f"Learning rate: {LR}",
    f"Output: {OUT}",
    sep="\n",
)
# ── Train ───────────────────────────────────────────────────────────
# Hyper-parameters are collected first so the trainer call stays short.
training_args = SFTConfig(
    output_dir=OUT,
    # batch / schedule
    per_device_train_batch_size=BATCH,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_train_epochs=EPOCHS,
    max_seq_length=MAX_SEQ,
    # optimiser
    optim="adamw_torch",
    learning_rate=LR,
    lr_scheduler_type="cosine",  # cosine decay for smoother convergence
    warmup_ratio=WARMUP_RATIO,
    weight_decay=0.01,  # light regularization
    bf16=True,
    # logging / checkpoints
    logging_steps=5,
    save_steps=100,
    save_total_limit=2,
    report_to="none",
    # misc
    seed=42,
    dataset_num_proc=1,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=ds,
    args=training_args,
)
trainer.train()
# ── Save LoRA adapter ──────────────────────────────────────────────
# IMPORTANT: Do NOT save a custom chat_template.
# The base Qwen3.5 template is correct for Hermes format.
# v3's mistake was saving a Coder XML template with the adapter.
model.save_pretrained(OUT)
# Save tokenizer WITHOUT overriding the chat template.
# This ensures the base model's template is used at inference time.
tokenizer.save_pretrained(OUT)

# Verify no chat_template.jinja was saved (or if it was, it's the base one).
import os

template_path = os.path.join(OUT, "chat_template.jinja")
if os.path.exists(template_path):
    # Read as UTF-8 explicitly so the check does not depend on the
    # platform's default encoding.
    with open(template_path, encoding="utf-8") as f:
        content = f.read()
    # These two substrings are treated as markers of the Coder XML format.
    if "function=" in content or "<parameter=" in content:
        print("⚠ WARNING: Saved template contains Coder XML format!")
        print(" This will cause hermes parser failures at inference.")
        print(" Delete chat_template.jinja from the adapter directory.")
    else:
        print("✓ Saved template uses Hermes JSON format (correct)")
else:
    print("✓ No chat_template.jinja saved — will use base model template")

print(f"\nSaved LoRA adapter to {OUT}/")
print("\nDeploy with:")
print(f" --lora-modules bt7274={OUT}")
print(" --tool-call-parser hermes")
print(" --reasoning-parser deepseek_r1")