201 lines
7.3 KiB
Python
201 lines
7.3 KiB
Python
"""BT-7274 LoRA v4 — Qwen3.5-27B, bf16 LoRA (NOT QLoRA).
|
|
|
|
Key differences from v3 train script:
|
|
- Uses BASE Qwen3.5 tokenizer (Hermes tool format, NOT Coder XML)
|
|
- Dataset includes <think> blocks (enable_thinking in template)
|
|
- Combined dataset: persona + agent tools + reformatted v3
|
|
- No custom chat_template override — base model template produces
|
|
Hermes-format tool calls that vLLM's hermes parser can decode
|
|
|
|
vLLM serving flags for v4:
|
|
--tool-call-parser hermes
|
|
--reasoning-parser deepseek_r1
|
|
--enable-reasoning (or --enable-thinking via Qwen3 alias)
|
|
|
|
Usage:
|
|
pip install --upgrade unsloth unsloth_zoo
|
|
python train_v4.py
|
|
"""
|
|
|
|
from unsloth import FastLanguageModel
|
|
from trl import SFTTrainer, SFTConfig
|
|
from datasets import load_dataset
|
|
import torch
|
|
import json
|
|
|
|
# ── Config ───────────────────────────────────────────────────────────
# Where the base weights come from, where the training data lives, and
# where the trained adapter is written.
MODEL = "Qwen/Qwen3.5-27B"
DATA = "./bt7274_v4.jsonl"
OUT = "./bt7274-qwen35-27b-lora-v4"

# Token budget per example: bumped from 4096 because multi-turn
# conversations are longer now. Longer rendered examples are filtered out.
MAX_SEQ = 8_192

# LoRA shape: alpha kept equal to the rank.
RANK = 16
ALPHA = 16

# Optimisation schedule.
EPOCHS = 3
BATCH = 1           # per-device micro-batch
GRAD_ACCUM = 8      # effective batch = BATCH * GRAD_ACCUM
LR = 5e-5           # lowered from 1e-4 — larger dataset benefits from gentler lr
WARMUP_RATIO = 0.05  # 5% warmup instead of fixed steps

# ── Load model (bf16, NOT 4-bit) ────────────────────────────────────
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL,
    max_seq_length=MAX_SEQ,
    load_in_4bit=False,  # QLoRA not recommended for Qwen3.5
    load_in_16bit=True,  # bf16 LoRA
    full_finetuning=False,
    dtype=torch.bfloat16,
)

# CRITICAL: Verify we're using the BASE tokenizer, not a LoRA override.
# The base Qwen3.5 template is expected to produce Hermes-format tool calls:
#   <tool_call>{"name":"...","arguments":{...}}</tool_call>
# NOT the Coder XML format (<function=...><parameter=...>) that v3 used.
# A short prefix of the template rarely shows the tool-call section, so the
# whole template is scanned. Problems are reported as warnings (not
# exceptions) so an unexpected template can still be inspected deliberately.
_base_template = tokenizer.chat_template or ""
print(f"Chat template source: {_base_template[:80] if _base_template else 'NONE'}...")
if "<function=" in _base_template or "<parameter=" in _base_template:
    print("⚠ WARNING: Loaded chat template contains Coder XML tool-call markup!")
    print(" Rendered training data will not match vLLM's hermes parser.")
elif "<tool_call>" not in _base_template:
    print("⚠ WARNING: Loaded chat template has no <tool_call> markup.")
    print(" Tool calls may not render in Hermes format — inspect the template.")

# ── LoRA adapter ────────────────────────────────────────────────────
# Adapters are attached to the attention projections and the MLP projections.
# NOTE(review): these are the conventional Qwen/Llama module names; any layer
# whose projections are named differently receives no adapter — confirm
# against model.named_modules() if the architecture changes.
_lora_targets = [
    "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
    "gate_proj", "up_proj", "down_proj",     # MLP projections
]
model = FastLanguageModel.get_peft_model(
    model,
    target_modules=_lora_targets,
    r=RANK,
    lora_alpha=ALPHA,
    lora_dropout=0,
    bias="none",
    max_seq_length=MAX_SEQ,
    use_gradient_checkpointing="unsloth",  # Unsloth's checkpointing mode
    random_state=42,  # fixed seed
)

# ── Dataset ─────────────────────────────────────────────────────────


def fix_tool_calls(messages):
    """Return a copy of *messages* with tool-call arguments decoded into dicts.

    Training rows may store ``tool_calls[i]["function"]["arguments"]`` as a
    JSON-encoded string; the Qwen3.5 chat template is fed decoded objects
    instead. An arguments string that does not decode to a JSON *object*
    (malformed JSON, or valid JSON that is a list/number/null/etc.) is kept
    verbatim as ``{"raw": <original string>}`` so the template always
    receives a mapping. Arguments that are already non-string are left as-is,
    and tool calls with a missing or null ``"function"`` pass through
    unchanged.

    The caller's message, tool-call, and function dicts are never mutated;
    each level is shallow-copied before modification.

    Args:
        messages: Iterable of chat message mappings.

    Returns:
        A new list of message dicts.
    """
    fixed = []
    for msg in messages:
        msg = dict(msg)  # shallow copy so the caller's row is untouched
        if msg.get("tool_calls"):
            new_tcs = []
            for tc in msg["tool_calls"]:
                tc = dict(tc)
                # Only entries with a real "function" payload carry arguments;
                # a null "function" previously crashed on dict(None).
                if tc.get("function") is not None:
                    fn = dict(tc["function"])
                    raw = fn.get("arguments")
                    if isinstance(raw, str):
                        try:
                            parsed = json.loads(raw)
                        except (ValueError, TypeError):
                            parsed = None
                        # Valid-but-non-object JSON ("[1, 2]", "42", "null")
                        # would otherwise reach the template as a non-mapping.
                        fn["arguments"] = parsed if isinstance(parsed, dict) else {"raw": raw}
                    tc["function"] = fn
                new_tcs.append(tc)
            msg["tool_calls"] = new_tcs
        fixed.append(msg)
    return fixed

def load_and_format(path):
    """Render a chat-format JSONL file into a ``datasets.Dataset`` of text.

    The file is read manually (not via ``load_dataset``) because pyarrow
    chokes on mixed ``tool_calls`` argument types. Each non-blank line must be
    a JSON object with a ``"messages"`` list. Tool-call arguments are
    normalised with ``fix_tool_calls``, the conversation is rendered with the
    module-level ``tokenizer``'s chat template (``enable_thinking=True`` when
    accepted), and the rendered text is kept only if it encodes to at most
    ``MAX_SEQ`` tokens.

    Args:
        path: Path to the UTF-8 encoded JSONL file.

    Returns:
        A ``Dataset`` with a single ``"text"`` column.

    Raises:
        ValueError: If a line is not valid JSON (the message includes the
            path and line number), or if no example survives filtering.
        KeyError: If a row has no ``"messages"`` key.
    """
    from datasets import Dataset

    texts = []
    skipped = 0
    # Explicit UTF-8: the JSONL may contain non-ASCII text, and the platform
    # default encoding is not UTF-8 everywhere.
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(f"{path}:{lineno}: invalid JSON ({err})") from err
            messages = fix_tool_calls(row["messages"])
            try:
                text = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=False,
                    enable_thinking=True,
                )
            except TypeError:
                # Retry without enable_thinking, e.g. for a tokenizer whose
                # apply_chat_template rejects that keyword.
                text = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=False,
                )
            # NOTE(review): Qwen-family templates may omit <think> content from
            # assistant turns that precede the final user message — spot-check
            # rendered text if multi-turn reasoning must be trained.
            if len(tokenizer.encode(text)) <= MAX_SEQ:
                texts.append(text)
            else:
                skipped += 1
    if skipped:
        print(f"⚠ Filtered {skipped} examples exceeding {MAX_SEQ} tokens")
    if not texts:
        # Fail here with a clear cause instead of handing an empty dataset
        # to the trainer.
        raise ValueError(
            f"No training examples left from {path} "
            f"({skipped} filtered for exceeding {MAX_SEQ} tokens)"
        )
    return Dataset.from_dict({"text": texts})

ds = load_and_format(DATA)

# Rough optimizer-step estimate: samples seen across all epochs divided by
# the effective batch. The Trainer's own count may differ by rounding.
effective_batch = BATCH * GRAD_ACCUM
steps = (len(ds) * EPOCHS) // effective_batch

for _summary in (
    f"Dataset: {len(ds)} examples",
    f"Epochs: {EPOCHS}, effective batch: {effective_batch}",
    f"Estimated steps: {steps}",
    f"LoRA: r={RANK}, alpha={ALPHA}",
    f"Max seq: {MAX_SEQ}",
    f"Model: {MODEL}",
    f"Learning rate: {LR}",
    f"Output: {OUT}",
):
    print(_summary)

# ── Train ───────────────────────────────────────────────────────────
# NOTE(review): newer TRL releases rename SFTConfig.max_seq_length to
# max_length and SFTTrainer's tokenizer= to processing_class=; presumably
# Unsloth's patched classes still accept these names — confirm on upgrade.
# NOTE(review): the dataset is a plain "text" column, so the loss likely
# covers every rendered token (system/user/tool turns included) — confirm
# this is intended.
training_args = SFTConfig(
    # Output, logging, checkpoints
    output_dir=OUT,
    logging_steps=5,
    save_steps=100,
    save_total_limit=2,
    report_to="none",
    # Batching and data
    per_device_train_batch_size=BATCH,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_train_epochs=EPOCHS,
    max_seq_length=MAX_SEQ,
    dataset_num_proc=1,
    # Optimiser and schedule
    optim="adamw_torch",
    learning_rate=LR,
    lr_scheduler_type="cosine",  # cosine decay for smoother convergence
    warmup_ratio=WARMUP_RATIO,
    weight_decay=0.01,  # light regularization
    # Precision and reproducibility
    bf16=True,
    seed=42,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=ds,
    args=training_args,
)

trainer.train()

# ── Save LoRA adapter ──────────────────────────────────────────────
# IMPORTANT: Do NOT save a custom chat_template.
# The base Qwen3.5 template is correct for Hermes format.
# v3's mistake was saving a Coder XML template with the adapter.
model.save_pretrained(OUT)

# Save the tokenizer without overriding its chat template, so any template
# that lands in the adapter directory is the base model's own.
tokenizer.save_pretrained(OUT)

# Inspect chat_template.jinja if one was written. Depending on the
# transformers version, save_pretrained may write the template to this file
# or embed it in tokenizer_config.json instead — the latter is not checked.
import os

template_path = os.path.join(OUT, "chat_template.jinja")
if os.path.exists(template_path):
    with open(template_path, encoding="utf-8") as f:
        content = f.read()
    # "<function=" / "<parameter=" are the Coder XML tool-call tags; matching
    # the opening "<" avoids false hits on unrelated "function=" text.
    if "<function=" in content or "<parameter=" in content:
        print("⚠ WARNING: Saved template contains Coder XML format!")
        print(" This will cause hermes parser failures at inference.")
        print(" Delete chat_template.jinja from the adapter directory.")
    elif "<tool_call>" in content:
        print("✓ Saved template uses Hermes JSON format (correct)")
    else:
        # Neither format detected — do not claim Hermes without evidence.
        print("⚠ WARNING: Saved template has no <tool_call> markup — inspect it before deploying.")
else:
    print("✓ No chat_template.jinja saved — will use base model template")

# vLLM needs --enable-lora for --lora-modules to be honoured, and
# --enable-auto-tool-choice alongside --tool-call-parser for automatic
# tool-call parsing.
print(f"\nSaved LoRA adapter to {OUT}/")
print("\nDeploy with:")
print(" --enable-lora")
print(f" --lora-modules bt7274={OUT}")
print(" --enable-auto-tool-choice")
print(" --tool-call-parser hermes")
print(" --reasoning-parser deepseek_r1")