Files
lora/train_memory_lora.py

172 lines
6.7 KiB
Python

#!/usr/bin/env python3
"""Train BT-7274 memory LoRA on Qwen2.5-7B-Instruct using Unsloth.
100 curated EEMS memories — knowledge injection.
Run on junkpile (RTX 2000 Ada 16GB).
Prerequisites:
1. Stop vLLM: systemctl --user stop vllm-poc
2. Activate: source ~/lora-train/bin/activate
3. Run: python3 train_memory_lora.py
4. Restart: systemctl --user start vllm-poc
"""
import os
import torch
from pathlib import Path
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template, standardize_sharegpt
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset
# ──────────────────────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────────────────────
# Locations. Relative paths are resolved against the current working directory.
MODEL_NAME = "unsloth/Qwen2.5-7B-Instruct-bnb-4bit"  # bitsandbytes 4-bit checkpoint (per its name)
DATASET_PATH = "bt7274_memory_100.jsonl"
OUTPUT_DIR = "./bt7274-memory-lora"

# Context window. Memories average ~1500 chars, with outliers up to ~7K chars.
MAX_SEQ_LEN = 2048

# LoRA adapter shape.
LORA_RANK = 16
LORA_ALPHA = 16

# Optimisation.
BATCH_SIZE = 1  # minimal per-device batch to stay within the 16GB card
GRAD_ACCUM = 8  # BATCH_SIZE * GRAD_ACCUM sequences per optimizer step; with packing
                # enabled in the trainer, one sequence may contain several examples
EPOCHS = 5      # tiny dataset, so several passes over it
LR = 2e-4
WARMUP_STEPS = 5
SEED = 42

# Logging / checkpoint cadence, in optimizer steps.
LOGGING_STEPS = 5
SAVE_STEPS = 50
# ──────────────────────────────────────────────────────────────
# LOAD MODEL
# ──────────────────────────────────────────────────────────────
print(f"Loading {MODEL_NAME}...")
# Load the quantized base model and its tokenizer. dtype=None defers the compute
# dtype choice to Unsloth (auto-detect, per Unsloth's documentation).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    load_in_4bit=True,
    max_seq_length=MAX_SEQ_LEN,
    dtype=None,
)
# Attach the Qwen2.5 chat template so apply_chat_template renders Qwen's format.
tokenizer = get_chat_template(tokenizer, chat_template="qwen-2.5")
# ──────────────────────────────────────────────────────────────
# PEFT CONFIG
# ──────────────────────────────────────────────────────────────
print("Applying LoRA...")
# Wrap the base model with LoRA adapters on all attention and MLP projections.
model = FastLanguageModel.get_peft_model(
    model,
    target_modules=[
        # attention projections
        "q_proj", "k_proj", "v_proj", "o_proj",
        # MLP projections
        "gate_proj", "up_proj", "down_proj",
    ],
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    lora_dropout=0,
    bias="none",
    random_state=SEED,
    use_gradient_checkpointing="unsloth",  # Unsloth's own checkpointing mode to save VRAM
)
# ──────────────────────────────────────────────────────────────
# DATASET
# ──────────────────────────────────────────────────────────────
print(f"Loading dataset from {DATASET_PATH}...")
dataset = load_dataset("json", data_files=DATASET_PATH, split="train")
n_examples = len(dataset)
print(f" {n_examples} examples loaded")
# Unsloth helper: normalises ShareGPT-style turns in the "conversations" column
# into role/content messages (per Unsloth's docs) for the chat template.
dataset = standardize_sharegpt(dataset)
def apply_template(examples):
    """Render every conversation in a batch to a single training string.

    Intended for ``Dataset.map(..., batched=True)``: ``examples`` is a dict of
    column name -> list of values, and only the ``"conversations"`` column is
    read.

    Args:
        examples: Batch dict whose ``"conversations"`` entry is a list of
            conversations, each a list of chat messages.

    Returns:
        ``{"text": [...]}`` with one rendered string per conversation, which
        ``Dataset.map`` adds as a ``text`` column.
    """
    # Uses the module-level ``tokenizer`` configured with the Qwen2.5 template.
    # tokenize=False returns the formatted string rather than token ids.
    # add_generation_prompt=False renders the conversation as-is (including the
    # assistant reply) without appending an open assistant turn.
    return {
        "text": [
            tokenizer.apply_chat_template(
                convo,
                tokenize=False,
                add_generation_prompt=False,
            )
            for convo in examples["conversations"]
        ]
    }
print("Applying chat template...")
# Produces the "text" column that the trainer reads via dataset_text_field.
dataset = dataset.map(
    apply_template,
    batched=True,  # apply_template consumes whole column batches
    num_proc=2,
)
# ──────────────────────────────────────────────────────────────
# TRAINER
# ──────────────────────────────────────────────────────────────
print("Setting up trainer...")
# Use bf16 when the GPU supports it, otherwise fp16; exactly one flag is True.
bf16_ok = torch.cuda.is_bf16_supported()
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    seed=SEED,
    # batching / schedule
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_train_epochs=EPOCHS,
    dataloader_num_workers=2,
    # optimizer
    optim="adamw_8bit",
    learning_rate=LR,
    lr_scheduler_type="cosine",
    warmup_steps=WARMUP_STEPS,
    weight_decay=0.01,
    max_grad_norm=1.0,
    # mixed precision
    bf16=bf16_ok,
    fp16=not bf16_ok,
    # logging / checkpoints
    logging_steps=LOGGING_STEPS,
    report_to="none",
    save_steps=SAVE_STEPS,
    save_total_limit=2,
)
# Supervised fine-tuning on the rendered "text" column.
# NOTE(review): with packing=True, TRL packs examples into MAX_SEQ_LEN-token
# sequences. With ~100 short examples an epoch may then be only a few optimizer
# steps, so SAVE_STEPS=50 might never trigger. Verify the step count.
# NOTE(review): passing tokenizer / dataset_text_field / max_seq_length / packing
# directly to SFTTrainer is the older TRL API (newer TRL uses SFTConfig).
# Confirm this against the installed trl/unsloth versions.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=training_args,
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LEN,
    dataset_num_proc=2,
    packing=True,
)
# ──────────────────────────────────────────────────────────────
# TRAIN
# ──────────────────────────────────────────────────────────────
print("Starting training...")
stats = trainer.train()  # TrainOutput: global_step, training_loss, metrics
print("\nTraining complete!")
print(f" Total steps: {stats.global_step}")
print(f" Train loss: {stats.training_loss:.4f}")
runtime_s = stats.metrics["train_runtime"]
print(f" Runtime: {runtime_s:.0f}s")
# ──────────────────────────────────────────────────────────────
# SAVE ADAPTER
# ──────────────────────────────────────────────────────────────
print(f"\nSaving adapter to {OUTPUT_DIR}...")
# `model` is the PEFT-wrapped model from get_peft_model, so this writes the
# LoRA adapter (not a merged copy of the base model).
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# Sanity check: confirm the adapter weights file actually landed on disk.
adapter_path = Path(OUTPUT_DIR) / "adapter_model.safetensors"
if adapter_path.exists():
    size_mb = adapter_path.stat().st_size / (1024 * 1024)
    print(f" Adapter saved: {size_mb:.1f} MB")
else:
    print(" WARNING: adapter_model.safetensors not found!")

# Print the vLLM flags needed to serve this adapter; the service unit must be
# edited by hand. --max-lora-rank has to be at least LORA_RANK.
print("\nDone. To serve with vLLM:")
print(" Update vllm-poc.service to add:")
print(" --enable-lora \\")
print(f" --lora-modules bt7274-memory={os.path.abspath(OUTPUT_DIR)} \\")
print(f" --max-lora-rank {LORA_RANK}")