fix: manual JSONL loader — pyarrow chokes on mixed tool_calls types

This commit is contained in:
marauder-actual
2026-05-26 14:20:37 +02:00
parent a562d753de
commit 5388df0075
+20 -15
View File
@@ -69,7 +69,6 @@ model = FastLanguageModel.get_peft_model(
) )
# ── Dataset ───────────────────────────────────────────────────────── # ── Dataset ─────────────────────────────────────────────────────────
ds = load_dataset("json", data_files=DATA, split="train")
def fix_tool_calls(messages): def fix_tool_calls(messages):
@@ -95,35 +94,41 @@ def fix_tool_calls(messages):
return fixed return fixed
def to_chatml(ex): def load_and_format(path):
"""Apply Qwen3.5 base chat template with thinking enabled.""" """Load JSONL manually — pyarrow chokes on mixed tool_calls argument types."""
messages = fix_tool_calls(ex["messages"]) from datasets import Dataset
texts = []
skipped = 0
with open(path) as f:
for line in f:
line = line.strip()
if not line:
continue
row = json.loads(line)
messages = fix_tool_calls(row["messages"])
try: try:
text = tokenizer.apply_chat_template( text = tokenizer.apply_chat_template(
messages, messages,
tokenize=False, tokenize=False,
add_generation_prompt=False, add_generation_prompt=False,
# Enable thinking mode so <think> blocks are properly formatted
enable_thinking=True, enable_thinking=True,
) )
except TypeError: except TypeError:
# Fallback if enable_thinking not supported in this template version
text = tokenizer.apply_chat_template( text = tokenizer.apply_chat_template(
messages, messages,
tokenize=False, tokenize=False,
add_generation_prompt=False, add_generation_prompt=False,
) )
return {"text": text} if len(tokenizer.encode(text)) <= MAX_SEQ:
texts.append(text)
else:
skipped += 1
if skipped:
print(f"⚠ Filtered {skipped} examples exceeding {MAX_SEQ} tokens")
return Dataset.from_dict({"text": texts})
ds = ds.map(to_chatml) ds = load_and_format(DATA)
# Filter out examples that exceed max sequence length
orig_len = len(ds)
ds = ds.filter(lambda ex: len(tokenizer.encode(ex["text"])) <= MAX_SEQ)
filtered = orig_len - len(ds)
if filtered > 0:
print(f"⚠ Filtered {filtered} examples exceeding {MAX_SEQ} tokens")
steps = (len(ds) * EPOCHS) // (BATCH * GRAD_ACCUM) steps = (len(ds) * EPOCHS) // (BATCH * GRAD_ACCUM)
print(f"Dataset: {len(ds)} examples") print(f"Dataset: {len(ds)} examples")