Files
lora/extract_specialists.py

517 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Extract specialist training data from opencode session DB.

Classifies build-agent messages by programming language and outputs
per-specialist JSONL files for LoRA training.

opencode DB schema:
    - session: id, agent, title, time_created, ...
    - message: id, session_id, data (JSON: role, finish, tokens, ...)
    - part: id, message_id, session_id, data (JSON: type, text/tool/state, ...)

Part types:
    - text: {type: "text", text: "..."}
    - tool: {type: "tool", tool: "read", callID: "...", state: {status, input, output, ...}}
    - step-start/step-finish: inference step boundaries
    - reasoning: chain-of-thought (skip for training)
    - patch: file diffs (skip — use tool output instead)
    - compaction: summary (skip)

Usage:
    python extract_specialists.py [--db PATH] [--outdir data/] [--min-turns 2]
    python extract_specialists.py --lang python --outdir data/  # single language
"""
import argparse
import json
import sqlite3
from collections import defaultdict
from pathlib import Path
from typing import Any
# ── Language classification signals ──────────────────────────────────
# For each language, four lists of substrings that classify_text() looks for
# (case-insensitively) in conversation text:
#   extensions – file suffixes, scored per occurrence
#   files      – well-known project file names, scored once if present
#   commands   – CLI invocations, scored per occurrence
#   errors     – characteristic diagnostics, scored once if present
LANG_SIGNALS: dict[str, dict[str, list[str]]] = {
    "rust": {
        "extensions": [".rs"],
        "files": [
            "Cargo.toml",
            "Cargo.lock",
            "build.rs",
            "clippy.toml",
            "rustfmt.toml",
        ],
        "commands": [
            "cargo ",
            "cargo build",
            "cargo test",
            "cargo clippy",
            "cargo fmt",
            "cargo add",
            "rustc ",
            "rustup ",
        ],
        "errors": [
            "error[E",
            "rustc --explain",
            "cannot find value",
            "expected struct",
            "borrow checker",
        ],
    },
    "typescript": {
        "extensions": [".ts", ".tsx", ".mts", ".cts"],
        "files": [
            "tsconfig.json",
            "package.json",
            "bun.lockb",
            "pnpm-lock.yaml",
            "next.config",
            "vite.config",
            "astro.config",
        ],
        "commands": [
            "npm ",
            "pnpm ",
            "bun ",
            "npx ",
            "tsc ",
            "vitest ",
            "jest ",
            "biome ",
            "eslint ",
        ],
        "errors": [
            "error TS",
            "TS2",
            "TS7",
            "Cannot find module",
            "Type '",
        ],
    },
    "python": {
        "extensions": [".py", ".pyi"],
        "files": [
            "pyproject.toml",
            "setup.py",
            "setup.cfg",
            "requirements.txt",
            "ruff.toml",
            "mypy.ini",
            ".flake8",
            "noxfile.py",
            "tox.ini",
        ],
        "commands": [
            "python ",
            "python3 ",
            "pip ",
            "uv ",
            "pytest ",
            "ruff ",
            "mypy ",
            "uvicorn ",
            "gunicorn ",
        ],
        "errors": [
            "Traceback (most recent",
            "SyntaxError",
            "ImportError",
            "TypeError",
            "ModuleNotFoundError",
        ],
    },
    "ruby": {
        "extensions": [".rb", ".erb", ".haml", ".slim"],
        "files": [
            "Gemfile",
            "Gemfile.lock",
            "Rakefile",
            ".ruby-version",
            ".rubocop.yml",
            ".standard.yml",
        ],
        "commands": [
            "bundle ",
            "rails ",
            "rake ",
            "rspec ",
            "rubocop ",
            "standardrb ",
            "gem ",
        ],
        "errors": [
            "NoMethodError",
            "NameError",
            "ArgumentError",
            "ActiveRecord::",
            "undefined method",
        ],
    },
    "swift": {
        "extensions": [".swift"],
        "files": [
            "Package.swift",
            "project.yml",
            ".xcodeproj",
            ".xcworkspace",
        ],
        "commands": [
            "swift build",
            "swift test",
            "swift run",
            "xcodebuild ",
            "swift-format ",
            "swift package ",
        ],
        "errors": [
            "cannot convert value of type",
            "protocol conformance",
            "value of type",
            "has no member",
        ],
    },
}
# Codename of the LoRA adapter trained for each language; main() uses it as
# the stem of the per-language output file ("<codename>.jsonl").
LANG_TO_NAME = {
    "rust": "oxidizer",
    "typescript": "prism",
    "python": "serpent",
    "ruby": "forge",
    "swift": "swiftblade",
}
# System prompt per specialist language; filled in by load_system_prompts().
SYSTEM_PROMPTS: dict[str, str] = {}


def load_system_prompts(agents_dir: Path) -> None:
    """Populate ``SYSTEM_PROMPTS`` from the agent markdown files.

    For each supported language the matching ``build-*.md`` file inside
    ``agents_dir`` is read (as UTF-8) and stored with surrounding whitespace
    stripped. When a file is missing a warning is printed and a generic
    one-line prompt is stored instead, so every language always has an entry.

    Args:
        agents_dir: Directory expected to contain the agent prompt files.
    """
    mapping = {
        "rust": "build-rust.md",
        "typescript": "build-ts.md",
        "python": "build-python.md",
        "ruby": "build-ruby.md",
        "swift": "build-swift.md",
    }
    for lang, filename in mapping.items():
        path = agents_dir / filename
        if path.exists():
            # Pin the encoding: the platform default is locale-dependent and
            # may fail on (or mis-decode) non-ASCII characters in the prompts.
            SYSTEM_PROMPTS[lang] = path.read_text(encoding="utf-8").strip()
        else:
            print(f" WARN: {path} not found, using default prompt for {lang}")
            SYSTEM_PROMPTS[lang] = f"You are a {lang} coding agent."
def classify_text(content: str) -> dict[str, float]:
    """Score how strongly ``content`` points at each known language.

    Matching is done on the lower-cased text against ``LANG_SIGNALS``:

    * every occurrence of a file extension adds 3.0
    * every occurrence of a command adds 2.0
    * a known project file name adds 5.0 once if it appears at all
    * a known error marker adds 4.0 once if it appears at all

    Returns:
        A plain dict mapping language name to its accumulated score.
    """
    haystack = content.lower()
    totals: dict[str, float] = defaultdict(float)

    for language, spec in LANG_SIGNALS.items():
        # Extensions are compared as written (they are not lower-cased here).
        for suffix in spec["extensions"]:
            totals[language] += 3.0 * haystack.count(suffix)

        for project_file in spec["files"]:
            if project_file.lower() in haystack:
                totals[language] += 5.0

        for command in spec["commands"]:
            totals[language] += 2.0 * haystack.count(command.lower())

        # Error markers count once each; only touch the total when one matched.
        if matched := sum(1 for marker in spec["errors"] if marker.lower() in haystack):
            totals[language] += 4.0 * matched

    return dict(totals)
def classify_conversation(all_text: str) -> str | None:
    """Pick the single language a conversation is about, if there is one.

    The text is scored with ``classify_text``. A language is returned only
    when its score is at least 5.0 and, if a second language also scored
    above zero, the winner leads it by a factor of at least 1.5.

    Args:
        all_text: Concatenated text of the whole conversation (or window).

    Returns:
        The winning language key, or ``None`` when nothing scored, the best
        score is too weak, or the top two languages are too close to call.
    """
    min_score = 5.0   # weakest winning score that still counts as a match
    min_lead = 1.5    # required winner / runner-up ratio

    scores = classify_text(all_text)
    if not scores:
        return None

    # Non-empty from here on, so ranked[0] always exists.
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    winner, winner_score = ranked[0]
    if winner_score < min_score:
        return None

    if len(ranked) > 1:
        runner_up_score = ranked[1][1]
        # The > 0 guard also protects the division below.
        if runner_up_score > 0 and winner_score / runner_up_score < min_lead:
            return None  # Ambiguous

    return winner
# ── Tools whose calls are kept in training data ─────────────────────
# Tool parts naming any other tool are dropped during extraction.
TRAINING_TOOLS = {
    "bash",
    "read",
    "edit",
    "write",
    "glob",
    "grep",
    "todowrite",
    "question",
}
# Longest tool output (in characters) kept verbatim; anything longer is cut.
# 8192 tokens ≈ ~32K chars. Budget: system ~2K, user ~2K, leaves ~28K for assistant+tools.
# Each tool call+result pair: ~500–2000 chars. Cap output at 2000 to fit more exchanges.
MAX_OUTPUT_LEN = 2000


def truncate_output(output: str, max_len: int = MAX_OUTPUT_LEN) -> str:
    """Return ``output`` cut to its first ``max_len`` characters.

    Text that already fits is returned unchanged. Otherwise the kept prefix is
    followed by a marker line stating the original length, so the result is
    slightly longer than ``max_len``.
    """
    total = len(output)
    if total > max_len:
        return f"{output[:max_len]}\n... (truncated, {total} total chars)"
    return output
def extract_sessions(db_path: Path, target_lang: str | None = None) -> list[dict]:
"""Extract build-agent sessions from opencode DB.
Returns list of {session_id, title, messages: [...], raw_text: str}
where messages are in ChatML-like format suitable for training.
"""
conn = sqlite3.connect(db_path)
session_rows = conn.execute("""
SELECT id, title, time_created
FROM session
WHERE agent = 'build' OR agent LIKE 'build-%'
ORDER BY time_created
""").fetchall()
print(f"Found {len(session_rows)} build sessions")
all_conversations: list[dict] = []
for s_id, s_title, s_created in session_rows:
# Get messages for this session, ordered
msg_rows = conn.execute("""
SELECT m.id, json_extract(m.data, '$.role') as role,
json_extract(m.data, '$.finish') as finish
FROM message m
WHERE m.session_id = ?
ORDER BY m.time_created
""", (s_id,)).fetchall()
if len(msg_rows) < 2:
continue
# Get all parts for this session, grouped by message
part_rows = conn.execute("""
SELECT p.message_id,
json_extract(p.data, '$.type') as ptype,
p.data as pdata
FROM part p
WHERE p.session_id = ?
ORDER BY p.time_created
""", (s_id,)).fetchall()
# Group parts by message_id
msg_parts: dict[str, list[tuple[str, str]]] = defaultdict(list)
for p_msg_id, p_type, p_data in part_rows:
msg_parts[p_msg_id].append((p_type, p_data))
# Build ChatML messages
messages: list[dict[str, Any]] = []
raw_texts: list[str] = [] # for classification
for m_id, m_role, m_finish in msg_rows:
parts = msg_parts.get(m_id, [])
if m_role == "user":
# Extract user text
user_text = ""
for ptype, pdata in parts:
if ptype == "text":
pd = json.loads(pdata)
user_text += pd.get("text", "")
if user_text.strip():
messages.append({"role": "user", "content": user_text.strip()})
raw_texts.append(user_text)
elif m_role == "assistant":
# Collect text parts and tool calls
asst_text = ""
tool_calls: list[dict] = []
tool_results: list[dict] = []
for ptype, pdata in parts:
if ptype == "text":
pd = json.loads(pdata)
asst_text += pd.get("text", "")
elif ptype == "tool":
pd = json.loads(pdata)
tool_name = pd.get("tool", "")
call_id = pd.get("callID", "")
state = pd.get("state", {})
if tool_name not in TRAINING_TOOLS:
continue
if state.get("status") != "completed":
continue
tool_input = state.get("input", {})
tool_output = state.get("output", "")
# Build tool_call in OpenAI format
tool_calls.append({
"type": "function",
"id": call_id,
"function": {
"name": tool_name,
"arguments": tool_input,
},
})
# Build tool result
output_str = truncate_output(str(tool_output))
tool_results.append({
"role": "tool",
"tool_call_id": call_id,
"content": output_str,
})
# Collect for classification
raw_texts.append(json.dumps(tool_input))
raw_texts.append(output_str)
# Emit assistant message with tool calls
if tool_calls:
messages.append({
"role": "assistant",
"content": None,
"tool_calls": tool_calls,
})
messages.extend(tool_results)
# Emit text-only assistant message (after tools, or standalone)
if asst_text.strip():
messages.append({
"role": "assistant",
"content": asst_text.strip(),
})
raw_texts.append(asst_text)
if len(messages) < 3: # need at least user + assistant + something
continue
# Concatenate raw text for classification
raw_combined = " ".join(raw_texts)
# Early classification filter if target_lang specified
if target_lang:
lang = classify_conversation(raw_combined)
if lang != target_lang:
continue
all_conversations.append({
"session_id": s_id,
"title": s_title,
"messages": messages,
"raw_text": raw_combined,
})
conn.close()
return all_conversations
def window_conversations(
    conversations: list[dict],
    min_turns: int = 2,
    max_turns: int = 10,
) -> list[dict]:
    """Split long conversations into training windows.

    A conversation with fewer than ``min_turns`` user messages is kept whole
    (provided it has at least one user message). Longer conversations are cut
    at user-message boundaries into windows of at most ``max_turns`` user
    turns; each window starts on a user message and runs up to, but not
    including, the first user message of the next window, so tool calls and
    their results stay with the turn that triggered them. Messages before the
    first user message of a windowed conversation are dropped.

    Args:
        conversations: Dicts with ``session_id``, ``title``, ``messages`` and
            optionally ``raw_text``.
        min_turns: Conversations with fewer user turns than this are not
            windowed.
        max_turns: Maximum number of user turns per window; must be >= 1.

    Returns:
        A list of ``{session_id, title, messages, raw_text}`` window dicts.

    Raises:
        ValueError: If ``max_turns`` is less than 1.
    """
    # range() rejects a zero step with an unrelated-looking message and a
    # negative step silently yields no windows at all, so fail loudly instead.
    if max_turns < 1:
        raise ValueError(f"max_turns must be >= 1, got {max_turns}")

    windows: list[dict] = []
    for conv in conversations:
        msgs = conv["messages"]
        user_indices = [i for i, m in enumerate(msgs) if m["role"] == "user"]

        if len(user_indices) < min_turns:
            # Short enough to use as-is
            if user_indices:
                windows.append({
                    "session_id": conv["session_id"],
                    "title": conv["title"],
                    "messages": msgs,
                    "raw_text": conv.get("raw_text", ""),
                })
            continue

        # Window by user-turn boundaries
        for start in range(0, len(user_indices), max_turns):
            end = start + max_turns
            first_msg = user_indices[start]
            # Stop at the next window's first user message, or at the end.
            last_msg = user_indices[end] if end < len(user_indices) else len(msgs)
            window_msgs = msgs[first_msg:last_msg]
            windows.append({
                "session_id": conv["session_id"],
                "title": conv["title"],
                "messages": window_msgs,
                # Rebuild classification text from this window only; messages
                # without content contribute their JSON-encoded tool calls.
                "raw_text": " ".join(
                    m.get("content", "") or json.dumps(m.get("tool_calls", ""))
                    for m in window_msgs
                ),
            })
    return windows
def _coerce_tool_arguments(arguments: str) -> Any:
    """Decode string-encoded tool arguments into a dict.

    Anything that is not valid JSON, or that decodes to something other than
    a dict, is preserved verbatim as ``{"raw": <original string>}``.
    """
    try:
        parsed = json.loads(arguments)
    except (ValueError, TypeError):
        return {"raw": arguments}
    return parsed if isinstance(parsed, dict) else {"raw": arguments}


def format_example(messages: list[dict], lang: str) -> dict:
    """Format a conversation window as a training example with system prompt.

    The input messages are shallow-copied, never mutated. String-valued tool
    call arguments are decoded so they are always dicts, and messages that
    have neither content nor tool calls are dropped.

    Args:
        messages: ChatML-like messages of one window.
        lang: Language key used to look up the system prompt in
            ``SYSTEM_PROMPTS``; a generic prompt is used if none is loaded.

    Returns:
        ``{"messages": [system message, *cleaned messages]}``.
    """
    system_prompt = SYSTEM_PROMPTS.get(lang, f"You are a {lang} coding agent.")

    cleaned: list[dict] = []
    for original in messages:
        msg = dict(original)
        tool_calls = msg.get("tool_calls")
        if tool_calls:
            normalised = []
            for call in tool_calls:
                call = dict(call)
                if "function" in call:
                    fn = dict(call["function"])
                    if isinstance(fn.get("arguments"), str):
                        fn["arguments"] = _coerce_tool_arguments(fn["arguments"])
                    call["function"] = fn
                normalised.append(call)
            msg["tool_calls"] = normalised
        elif msg.get("content") is None:
            # Neither content nor tool calls: nothing to train on.
            continue
        cleaned.append(msg)

    return {
        "messages": [{"role": "system", "content": system_prompt}] + cleaned,
    }
def write_dataset(examples: list[dict], path: Path) -> None:
    """Write examples to a JSONL file, one JSON object per line.

    Parent directories are created as needed and an existing file is
    overwritten. The file is always written as UTF-8: ``ensure_ascii=False``
    emits non-ASCII characters verbatim, which a locale-dependent default
    encoding might be unable to encode.

    Args:
        examples: JSON-serialisable training examples.
        path: Destination file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(ex, ensure_ascii=False) + "\n" for ex in examples)
    print(f" Wrote {len(examples)} examples → {path}")
def main() -> None:
    """Command-line entry point.

    Parses the arguments, loads the system prompts, extracts build sessions
    from the database, windows them, assigns each window a language (the
    ``--lang`` value when given, otherwise by classification) and writes one
    ``<codename>.jsonl`` file per language into the output directory.

    Exits with an argparse usage error when ``--lang`` is not a known
    language or the database file does not exist.
    """
    parser = argparse.ArgumentParser(description="Extract specialist training data")
    parser.add_argument(
        "--db", type=Path,
        default=Path.home() / ".local/share/opencode/opencode.db",
        help="Path to opencode SQLite database",
    )
    parser.add_argument(
        "--agents-dir", type=Path,
        default=Path.home() / ".config/opencode/agents",
        help="Path to agent system prompt directory",
    )
    parser.add_argument(
        "--outdir", type=Path, default=Path("data"),
        help="Output directory for JSONL files",
    )
    parser.add_argument(
        # Reject unknown languages up front; otherwise a typo would simply
        # match no session and silently produce no output.
        "--lang", type=str, default=None, choices=sorted(LANG_SIGNALS),
        help="Extract single language only (rust, typescript, python, ruby, swift)",
    )
    parser.add_argument(
        "--min-turns", type=int, default=1,
        help="Minimum user turns per training window",
    )
    parser.add_argument(
        "--max-turns", type=int, default=10,
        help="Maximum user turns per training window",
    )
    args = parser.parse_args()

    # sqlite3.connect() creates an empty database when the path does not
    # exist; stop here rather than leave a stray file and fail on the query.
    if not args.db.is_file():
        parser.error(f"opencode database not found: {args.db}")

    print("══ Specialist Data Extraction ══")
    print(f"DB: {args.db}")
    print(f"Agents: {args.agents_dir}")
    print(f"Output: {args.outdir}")
    if args.lang:
        print(f"Filter: {args.lang} only")
    print()

    # Load system prompts
    load_system_prompts(args.agents_dir)
    print(f"Loaded {len(SYSTEM_PROMPTS)} system prompts")

    # Extract sessions
    conversations = extract_sessions(args.db, target_lang=args.lang)
    print(f"Extracted {len(conversations)} conversations")

    # Window into training examples
    windows = window_conversations(
        conversations, min_turns=args.min_turns, max_turns=args.max_turns,
    )
    print(f"Created {len(windows)} training windows")

    # Classify and bucket. With --lang the sessions were already filtered by
    # language, so every window is assigned to it without re-classifying.
    buckets: dict[str, list[dict]] = defaultdict(list)
    unclassified = 0
    for window in windows:
        if args.lang:
            lang = args.lang
        else:
            lang = classify_conversation(window.get("raw_text", ""))
        if lang:
            buckets[lang].append(format_example(window["messages"], lang))
        else:
            unclassified += 1

    # Report, largest bucket first
    print("\n── Classification Results ──")
    if not args.lang:
        print(f"Unclassified: {unclassified}")
    for lang, examples in sorted(buckets.items(), key=lambda x: -len(x[1])):
        name = LANG_TO_NAME.get(lang, lang)
        # Number of examples containing at least one tool call
        tc_count = sum(
            1 for ex in examples
            if any(m.get("tool_calls") for m in ex["messages"])
        )
        print(f" {name} ({lang}): {len(examples)} examples ({tc_count} with tool calls)")

    # Write per-language datasets
    print("\n── Writing Datasets ──")
    for lang, examples in buckets.items():
        name = LANG_TO_NAME.get(lang, lang)
        write_dataset(examples, args.outdir / f"{name}.jsonl")

    print(f"\nDone. Review datasets in {args.outdir}/")
    print("Next steps:")
    print(" 1. python mine_repos.py --repos repos.json (add git diff examples)")
    print(" 2. Manual curation pass")
    print(" 3. python train_specialist.py --name <adapter>")


if __name__ == "__main__":
    main()