diff --git a/train_v4.py b/train_v4.py index c94cb0c..1bed668 100644 --- a/train_v4.py +++ b/train_v4.py @@ -69,7 +69,6 @@ model = FastLanguageModel.get_peft_model( ) # ── Dataset ───────────────────────────────────────────────────────── -ds = load_dataset("json", data_files=DATA, split="train") def fix_tool_calls(messages): @@ -95,35 +94,41 @@ def fix_tool_calls(messages): return fixed -def to_chatml(ex): - """Apply Qwen3.5 base chat template with thinking enabled.""" - messages = fix_tool_calls(ex["messages"]) - try: - text = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=False, - # Enable thinking mode so blocks are properly formatted - enable_thinking=True, - ) - except TypeError: - # Fallback if enable_thinking not supported in this template version - text = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=False, - ) - return {"text": text} +def load_and_format(path): + """Load JSONL manually — pyarrow chokes on mixed tool_calls argument types.""" + from datasets import Dataset + texts = [] + skipped = 0 + with open(path) as f: + for line in f: + line = line.strip() + if not line: + continue + row = json.loads(line) + messages = fix_tool_calls(row["messages"]) + try: + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + enable_thinking=True, + ) + except TypeError: + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + ) + if len(tokenizer.encode(text)) <= MAX_SEQ: + texts.append(text) + else: + skipped += 1 + if skipped: + print(f"⚠ Filtered {skipped} examples exceeding {MAX_SEQ} tokens") + return Dataset.from_dict({"text": texts}) -ds = ds.map(to_chatml) - -# Filter out examples that exceed max sequence length -orig_len = len(ds) -ds = ds.filter(lambda ex: len(tokenizer.encode(ex["text"])) <= MAX_SEQ) -filtered = orig_len - len(ds) -if filtered > 0: - print(f"⚠ Filtered {filtered} examples exceeding {MAX_SEQ} tokens") +ds = load_and_format(DATA) steps = (len(ds) * EPOCHS) // (BATCH * GRAD_ACCUM) print(f"Dataset: {len(ds)} examples")