fix: manual JSONL loader — pyarrow chokes on mixed tool_calls types

This commit is contained in:
marauder-actual
2026-05-26 14:20:37 +02:00
parent a562d753de
commit 5388df0075
+33 -28
View File
@@ -69,7 +69,6 @@ model = FastLanguageModel.get_peft_model(
)
# ── Dataset ─────────────────────────────────────────────────────────
ds = load_dataset("json", data_files=DATA, split="train")
def fix_tool_calls(messages):
@@ -95,35 +94,41 @@ def fix_tool_calls(messages):
return fixed
def to_chatml(ex):
"""Apply Qwen3.5 base chat template with thinking enabled."""
messages = fix_tool_calls(ex["messages"])
try:
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
# Enable thinking mode so <think> blocks are properly formatted
enable_thinking=True,
)
except TypeError:
# Fallback if enable_thinking not supported in this template version
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
)
return {"text": text}
def load_and_format(path):
"""Load JSONL manually — pyarrow chokes on mixed tool_calls argument types."""
from datasets import Dataset
texts = []
skipped = 0
with open(path) as f:
for line in f:
line = line.strip()
if not line:
continue
row = json.loads(line)
messages = fix_tool_calls(row["messages"])
try:
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
enable_thinking=True,
)
except TypeError:
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
)
if len(tokenizer.encode(text)) <= MAX_SEQ:
texts.append(text)
else:
skipped += 1
if skipped:
print(f"⚠ Filtered {skipped} examples exceeding {MAX_SEQ} tokens")
return Dataset.from_dict({"text": texts})
ds = ds.map(to_chatml)
# Filter out examples that exceed max sequence length
orig_len = len(ds)
ds = ds.filter(lambda ex: len(tokenizer.encode(ex["text"])) <= MAX_SEQ)
filtered = orig_len - len(ds)
if filtered > 0:
print(f"⚠ Filtered {filtered} examples exceeding {MAX_SEQ} tokens")
ds = load_and_format(DATA)
steps = (len(ds) * EPOCHS) // (BATCH * GRAD_ACCUM)
print(f"Dataset: {len(ds)} examples")