fix: manual JSONL loader — pyarrow chokes on mixed tool_calls types

This commit is contained in:
marauder-actual
2026-05-26 14:20:37 +02:00
parent a562d753de
commit 5388df0075
+33 -28
View File
@@ -69,7 +69,6 @@ model = FastLanguageModel.get_peft_model(
) )
# ── Dataset ───────────────────────────────────────────────────────── # ── Dataset ─────────────────────────────────────────────────────────
ds = load_dataset("json", data_files=DATA, split="train")
def fix_tool_calls(messages): def fix_tool_calls(messages):
@@ -95,35 +94,41 @@ def fix_tool_calls(messages):
return fixed return fixed
def _render_chat_text(messages):
    """Render *messages* to a single string with the tokenizer's chat template.

    Thinking mode is requested first so that <think> blocks are formatted.
    If the call raises ``TypeError`` (``enable_thinking`` not supported by
    this template version), the template is applied again without it.
    """
    try:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
            enable_thinking=True,
        )
    except TypeError:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )


def load_and_format(path):
    """Load JSONL manually — pyarrow chokes on mixed tool_calls argument types.

    Reads *path* one line at a time, skipping blank lines. Each remaining
    line is parsed as JSON, its ``"messages"`` entry is passed through
    ``fix_tool_calls`` and rendered with the chat template. Examples whose
    encoded length exceeds ``MAX_SEQ`` tokens are dropped and counted.

    Args:
        path: Location of the JSONL file; each non-blank line must be a
            JSON object with a ``"messages"`` key.

    Returns:
        A ``datasets.Dataset`` with one column, ``"text"``, holding the
        rendered examples that fit within ``MAX_SEQ`` tokens.

    Raises:
        json.JSONDecodeError: If a line is not valid JSON. The message is
            prefixed with the file path and 1-based line number.
        KeyError: If a row has no ``"messages"`` key.
    """
    from datasets import Dataset

    texts = []
    skipped = 0
    # Decode explicitly as UTF-8 rather than relying on the platform default,
    # so non-ASCII message content reads the same on every machine.
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError as err:
                # Same exception type, but say which line of the file is bad.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
            text = _render_chat_text(fix_tool_calls(row["messages"]))
            # Keep only examples that fit in the maximum sequence length.
            if len(tokenizer.encode(text)) <= MAX_SEQ:
                texts.append(text)
            else:
                skipped += 1
    if skipped:
        print(f"⚠ Filtered {skipped} examples exceeding {MAX_SEQ} tokens")
    return Dataset.from_dict({"text": texts})
ds = ds.map(to_chatml) ds = load_and_format(DATA)
# Filter out examples that exceed max sequence length
orig_len = len(ds)
ds = ds.filter(lambda ex: len(tokenizer.encode(ex["text"])) <= MAX_SEQ)
filtered = orig_len - len(ds)
if filtered > 0:
print(f"⚠ Filtered {filtered} examples exceeding {MAX_SEQ} tokens")
steps = (len(ds) * EPOCHS) // (BATCH * GRAD_ACCUM) steps = (len(ds) * EPOCHS) // (BATCH * GRAD_ACCUM)
print(f"Dataset: {len(ds)} examples") print(f"Dataset: {len(ds)} examples")