517 lines
18 KiB
Python
517 lines
18 KiB
Python
#!/usr/bin/env python3
|
||
"""Extract specialist training data from opencode session DB.
|
||
|
||
Classifies build-agent conversations by programming language and outputs
|
||
per-specialist JSONL files for LoRA training.
|
||
|
||
opencode DB schema:
|
||
- session: id, agent, title, time_created, ...
|
||
- message: id, session_id, time_created, data (JSON: role, finish, tokens, ...)
|
||
- part: id, message_id, session_id, time_created, data (JSON: type, text/tool/state, ...)
|
||
|
||
Part types:
|
||
- text: {type: "text", text: "..."}
|
||
- tool: {type: "tool", tool: "read", callID: "...", state: {status, input, output, ...}}
|
||
- step-start/step-finish: inference step boundaries
|
||
- reasoning: chain-of-thought (skip for training)
|
||
- patch: file diffs (skip — use tool output instead)
|
||
- compaction: summary (skip)
|
||
|
||
Usage:
|
||
python extract_specialists.py [--db PATH] [--agents-dir DIR] [--outdir data/] [--min-turns 1] [--max-turns 10]
|
||
python extract_specialists.py --lang python --outdir data/ # single language
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import sqlite3
|
||
from collections import defaultdict
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
# ── Language classification signals ──────────────────────────────────
#
# Per-language substring signals consumed by classify_text(), in four groups:
#   extensions – file suffixes (scored per occurrence)
#   files      – project/config file names (scored once if present)
#   commands   – CLI invocations (scored per occurrence)
#   errors     – characteristic compiler/runtime messages (scored once if present)

LANG_SIGNALS: dict[str, dict[str, list[str]]] = {
    "rust": dict(
        extensions=[".rs"],
        files=["Cargo.toml", "Cargo.lock", "build.rs", "clippy.toml", "rustfmt.toml"],
        commands=[
            "cargo ", "cargo build", "cargo test", "cargo clippy", "cargo fmt",
            "cargo add", "rustc ", "rustup ",
        ],
        errors=[
            "error[E", "rustc --explain", "cannot find value", "expected struct",
            "borrow checker",
        ],
    ),
    "typescript": dict(
        extensions=[".ts", ".tsx", ".mts", ".cts"],
        files=[
            "tsconfig.json", "package.json", "bun.lockb", "pnpm-lock.yaml",
            "next.config", "vite.config", "astro.config",
        ],
        commands=[
            "npm ", "pnpm ", "bun ", "npx ", "tsc ", "vitest ", "jest ",
            "biome ", "eslint ",
        ],
        errors=["error TS", "TS2", "TS7", "Cannot find module", "Type '"],
    ),
    "python": dict(
        extensions=[".py", ".pyi"],
        files=[
            "pyproject.toml", "setup.py", "setup.cfg", "requirements.txt",
            "ruff.toml", "mypy.ini", ".flake8", "noxfile.py", "tox.ini",
        ],
        commands=[
            "python ", "python3 ", "pip ", "uv ", "pytest ", "ruff ", "mypy ",
            "uvicorn ", "gunicorn ",
        ],
        errors=[
            "Traceback (most recent", "SyntaxError", "ImportError",
            "TypeError", "ModuleNotFoundError",
        ],
    ),
    "ruby": dict(
        extensions=[".rb", ".erb", ".haml", ".slim"],
        files=[
            "Gemfile", "Gemfile.lock", "Rakefile", ".ruby-version",
            ".rubocop.yml", ".standard.yml",
        ],
        commands=[
            "bundle ", "rails ", "rake ", "rspec ", "rubocop ",
            "standardrb ", "gem ",
        ],
        errors=[
            "NoMethodError", "NameError", "ArgumentError",
            "ActiveRecord::", "undefined method",
        ],
    ),
    "swift": dict(
        extensions=[".swift"],
        files=["Package.swift", "project.yml", ".xcodeproj", ".xcworkspace"],
        commands=[
            "swift build", "swift test", "swift run", "xcodebuild ",
            "swift-format ", "swift package ",
        ],
        errors=[
            "cannot convert value of type", "protocol conformance",
            "value of type", "has no member",
        ],
    ),
}
|
||
|
||
# Adapter codename per specialist language (used as the output JSONL file stem).
LANG_TO_NAME = dict(
    rust="oxidizer",
    typescript="prism",
    python="serpent",
    ruby="forge",
    swift="swiftblade",
)

# Per-specialist system prompts; starts empty and is populated at runtime by
# load_system_prompts().
SYSTEM_PROMPTS: dict[str, str] = dict()
|
||
|
||
|
||
def load_system_prompts(agents_dir: Path) -> None:
    """Load per-language agent system prompts from markdown files.

    For each specialist language, reads ``agents_dir / build-<lang>.md`` into
    the module-level ``SYSTEM_PROMPTS`` dict, stripped of surrounding
    whitespace. A missing file is not fatal: a warning is printed and a
    generic one-line prompt is stored instead.

    Args:
        agents_dir: Directory holding the ``build-*.md`` agent prompt files.
    """
    mapping = {
        "rust": "build-rust.md",
        "typescript": "build-ts.md",
        "python": "build-python.md",
        "ruby": "build-ruby.md",
        "swift": "build-swift.md",
    }
    for lang, filename in mapping.items():
        path = agents_dir / filename
        if path.exists():
            # Explicit UTF-8: the default text encoding is locale-dependent
            # (e.g. cp1252 on Windows), so non-ASCII prompt text could fail to
            # decode or be silently mangled.
            SYSTEM_PROMPTS[lang] = path.read_text(encoding="utf-8").strip()
        else:
            print(f" WARN: {path} not found, using default prompt for {lang}")
            SYSTEM_PROMPTS[lang] = f"You are a {lang} coding agent."
|
||
|
||
|
||
def classify_text(content: str) -> dict[str, float]:
    """Score how strongly ``content`` points at each language in LANG_SIGNALS.

    Matching is case-insensitive substring search against the lowercased text.
    Weights:
      * extensions: 3.0 per occurrence
      * commands:   2.0 per occurrence
      * files:      5.0 if the name appears at all
      * errors:     4.0 if the message appears at all

    Returns:
        Mapping of language -> score.
    """
    haystack = content.lower()
    totals: dict[str, float] = defaultdict(float)

    for lang, sig in LANG_SIGNALS.items():
        # Occurrence-counted signals: every mention adds weight.
        ext_hits = sum(haystack.count(ext) for ext in sig["extensions"])
        cmd_hits = sum(haystack.count(cmd.lower()) for cmd in sig["commands"])
        # Presence-only signals: each entry contributes at most once.
        file_hits = sum(1 for name in sig["files"] if name.lower() in haystack)
        err_hits = sum(1 for msg in sig["errors"] if msg.lower() in haystack)

        totals[lang] += ext_hits * 3.0
        totals[lang] += file_hits * 5.0
        totals[lang] += cmd_hits * 2.0
        totals[lang] += err_hits * 4.0

    return dict(totals)
|
||
|
||
|
||
def classify_conversation(
    all_text: str,
    *,
    min_score: float = 5.0,
    min_ratio: float = 1.5,
) -> str | None:
    """Classify concatenated conversation text to a single language.

    Scores the text with classify_text() and returns the top-scoring language
    unless the evidence is too weak or too ambiguous.

    Args:
        all_text: Concatenated conversation text (user, assistant, tool I/O).
        min_score: Minimum score the winning language must reach
            (default 5.0).
        min_ratio: When the runner-up scored above zero, the winner must score
            at least this many times the runner-up's score (default 1.5).

    Returns:
        The winning language key, or None if no language qualifies.
    """
    scores = classify_text(all_text)
    if not scores:
        return None

    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    winner, winner_score = ranked[0]
    if winner_score < min_score:
        return None  # too little evidence for any language

    if len(ranked) > 1:
        runner_up_score = ranked[1][1]
        if runner_up_score > 0 and winner_score / runner_up_score < min_ratio:
            return None  # ambiguous: two languages score too closely

    return winner
|
||
|
||
|
||
# ── Tool call tools we care about for training ──────────────────────
# Tool parts whose name is not listed here are dropped from the training data.
TRAINING_TOOLS = set(
    "bash read edit write glob grep todowrite question".split()
)

# Character cap for each tool output kept in a training example.
# Budget: 8192 tokens ≈ ~32K chars; system ~2K + user ~2K leaves ~28K for
# assistant turns and tool traffic. A call+result pair is roughly 500–2000
# chars, so capping output at 2000 leaves room for more exchanges.
MAX_OUTPUT_LEN = 2_000
|
||
|
||
|
||
def truncate_output(output: str, max_len: int = MAX_OUTPUT_LEN) -> str:
    """Cap ``output`` at ``max_len`` characters.

    Output that fits is returned unchanged. Longer output keeps its first
    ``max_len`` characters, followed by a marker line recording the original
    length.
    """
    total = len(output)
    if total > max_len:
        head = output[:max_len]
        return f"{head}\n... (truncated, {total} total chars)"
    return output
|
||
|
||
|
||
def extract_sessions(db_path: Path, target_lang: str | None = None) -> list[dict]:
    """Extract build-agent sessions from the opencode DB as chat transcripts.

    Only sessions whose agent is ``build`` or ``build-*`` are read, in
    creation order. Each session becomes a list of OpenAI-style chat messages:
      * user text;
      * assistant text;
      * assistant ``tool_calls`` followed by their ``tool`` results.

    Sessions with fewer than two DB messages, or fewer than three emitted chat
    messages, are skipped.

    Args:
        db_path: Path to an existing opencode SQLite database.
        target_lang: If given, keep only sessions that classify_conversation()
            assigns to this language.

    Returns:
        List of ``{session_id, title, messages, raw_text}`` dicts, where
        ``raw_text`` is the concatenated text used for language classification.

    Raises:
        FileNotFoundError: If ``db_path`` is not an existing file.
            (sqlite3.connect would otherwise silently create an empty
            database there.)
    """
    if not Path(db_path).is_file():
        raise FileNotFoundError(f"opencode database not found: {db_path}")

    conn = sqlite3.connect(db_path)
    try:
        session_rows = conn.execute("""
            SELECT id, title, time_created
            FROM session
            WHERE agent = 'build' OR agent LIKE 'build-%'
            ORDER BY time_created
        """).fetchall()
        print(f"Found {len(session_rows)} build sessions")

        all_conversations: list[dict] = []
        for s_id, s_title, _s_created in session_rows:
            conv = _session_conversation(conn, s_id, s_title)
            if conv is None:
                continue
            # Early classification filter when a single language was requested.
            if target_lang and classify_conversation(conv["raw_text"]) != target_lang:
                continue
            all_conversations.append(conv)
    finally:
        # Close even if a query or a JSON decode fails partway through.
        conn.close()

    return all_conversations


def _session_conversation(
    conn: sqlite3.Connection, s_id: str, s_title: str,
) -> dict | None:
    """Build the conversation dict for one session, or None if it is too short."""
    msg_rows = conn.execute("""
        SELECT m.id, json_extract(m.data, '$.role') AS role
        FROM message m
        WHERE m.session_id = ?
        ORDER BY m.time_created
    """, (s_id,)).fetchall()
    if len(msg_rows) < 2:
        return None

    part_rows = conn.execute("""
        SELECT p.message_id,
               json_extract(p.data, '$.type') AS ptype,
               p.data AS pdata
        FROM part p
        WHERE p.session_id = ?
        ORDER BY p.time_created
    """, (s_id,)).fetchall()

    # Group parts by message, keeping the order returned by the query.
    msg_parts: dict[str, list[tuple[str, str]]] = defaultdict(list)
    for p_msg_id, p_type, p_data in part_rows:
        msg_parts[p_msg_id].append((p_type, p_data))

    messages: list[dict[str, Any]] = []
    raw_texts: list[str] = []  # text fed to language classification

    for m_id, m_role in msg_rows:
        parts = msg_parts.get(m_id, [])
        if m_role == "user":
            user_text = "".join(
                json.loads(pdata).get("text", "")
                for ptype, pdata in parts
                if ptype == "text"
            )
            if user_text.strip():
                messages.append({"role": "user", "content": user_text.strip()})
                raw_texts.append(user_text)
        elif m_role == "assistant":
            messages.extend(_assistant_turns(parts, raw_texts))

    if len(messages) < 3:  # need at least user + assistant + something
        return None

    return {
        "session_id": s_id,
        "title": s_title,
        "messages": messages,
        "raw_text": " ".join(raw_texts),
    }


def _assistant_turns(
    parts: list[tuple[str, str]], raw_texts: list[str],
) -> list[dict[str, Any]]:
    """Convert one assistant message's parts into chat messages, in part order.

    Segmentation rules:
      * Text parts and completed training-tool calls are grouped into segments.
      * A text part that arrives after the current segment already holds tool
        calls starts a new segment.
      * As a result, text that precedes a tool call is emitted with that call
        (as the assistant ``content``), not after its results.

    Each segment becomes an assistant message (``content`` plus
    ``tool_calls`` when present) followed by its ``tool`` result messages.
    Text for classification is appended to ``raw_texts``.
    """
    segments: list[tuple[str, list[dict], list[dict]]] = []
    text = ""
    calls: list[dict] = []
    results: list[dict] = []

    for ptype, pdata in parts:
        if ptype == "text":
            if calls:
                # Narration after tool calls belongs to the next turn.
                segments.append((text, calls, results))
                text, calls, results = "", [], []
            text += json.loads(pdata).get("text", "")

        elif ptype == "tool":
            pd = json.loads(pdata)
            tool_name = pd.get("tool", "")
            state = pd.get("state") or {}  # tolerate a missing or null state
            if tool_name not in TRAINING_TOOLS or state.get("status") != "completed":
                continue

            call_id = pd.get("callID", "")
            tool_input = state.get("input", {})
            output_str = truncate_output(str(state.get("output", "")))

            # Tool call in OpenAI format; arguments are kept as a dict.
            calls.append({
                "type": "function",
                "id": call_id,
                "function": {"name": tool_name, "arguments": tool_input},
            })
            results.append({
                "role": "tool",
                "tool_call_id": call_id,
                "content": output_str,
            })
            raw_texts.append(json.dumps(tool_input))
            raw_texts.append(output_str)

    segments.append((text, calls, results))

    out: list[dict[str, Any]] = []
    for seg_text, seg_calls, seg_results in segments:
        content = seg_text.strip()
        if seg_calls:
            out.append({
                "role": "assistant",
                "content": content or None,
                "tool_calls": seg_calls,
            })
            out.extend(seg_results)
        elif content:
            out.append({"role": "assistant", "content": content})
        if content:
            raw_texts.append(seg_text)
    return out
|
||
|
||
|
||
def window_conversations(
    conversations: list[dict],
    min_turns: int = 2,
    max_turns: int = 10,
) -> list[dict]:
    """Split long conversations into training windows at user-turn boundaries.

    Each window starts at a user message and runs up to, but not including,
    the user message ``max_turns`` turns later. Every assistant reply, tool
    call and tool result of those turns therefore stays in one window.

    Conversations with fewer than ``min_turns`` user messages are emitted
    whole, including any messages before the first user turn. Conversations
    without any user message are dropped.

    NOTE(review): the CLI help calls ``min_turns`` the "Minimum user turns per
    training window", but it is not enforced per window. The last window of a
    long conversation may hold fewer turns. Confirm the intended semantics.

    Args:
        conversations: Dicts with ``session_id``, ``title``, ``messages`` and
            optionally ``raw_text``.
        min_turns: Conversations below this many user turns are kept unsplit.
        max_turns: Maximum user turns per window; must be >= 1.

    Returns:
        Window dicts with ``session_id``, ``title``, ``messages`` and
        ``raw_text`` (the text used later for language classification).

    Raises:
        ValueError: If ``max_turns`` < 1. A negative step would make ``range``
            empty and silently drop every windowed conversation.
    """
    if max_turns < 1:
        raise ValueError(f"max_turns must be >= 1, got {max_turns}")

    windows: list[dict] = []

    for conv in conversations:
        msgs = conv["messages"]
        user_indices = [i for i, m in enumerate(msgs) if m["role"] == "user"]
        if not user_indices:
            continue

        if len(user_indices) < min_turns:
            # Short enough to use as-is.
            windows.append({
                "session_id": conv["session_id"],
                "title": conv["title"],
                "messages": msgs,
                "raw_text": conv.get("raw_text", ""),
            })
            continue

        for start in range(0, len(user_indices), max_turns):
            end = start + max_turns
            first_msg = user_indices[start]
            # Stop just before the user message that opens the next window.
            last_msg = user_indices[end] if end < len(user_indices) else len(msgs)
            window_msgs = msgs[first_msg:last_msg]

            windows.append({
                "session_id": conv["session_id"],
                "title": conv["title"],
                "messages": window_msgs,
                # Tool-call-only assistant messages contribute their calls' JSON.
                "raw_text": " ".join(
                    m.get("content", "") or json.dumps(m.get("tool_calls", ""))
                    for m in window_msgs
                ),
            })

    return windows
|
||
|
||
|
||
def format_example(messages: list[dict], lang: str) -> dict:
    """Wrap a message window as a training example headed by a system prompt.

    Processing:
      * The prompt comes from SYSTEM_PROMPTS, with a generic fallback.
      * Input messages are not mutated: each message, tool call and
        ``function`` dict is shallow-copied.
      * String tool-call ``arguments`` are JSON-decoded. If decoding fails,
        they become ``{"raw": <string>}``.
      * Messages whose ``content`` is None and that carry no tool calls are
        dropped.
    """
    prompt = SYSTEM_PROMPTS.get(lang, f"You are a {lang} coding agent.")

    def _normalize_call(call: dict) -> dict:
        call = dict(call)
        if "function" not in call:
            return call
        fn = dict(call["function"])
        args = fn.get("arguments")
        if isinstance(args, str):
            try:
                fn["arguments"] = json.loads(args)
            except (ValueError, TypeError):
                fn["arguments"] = {"raw": args}
        call["function"] = fn
        return call

    out: list[dict] = [{"role": "system", "content": prompt}]
    for original in messages:
        msg = dict(original)
        if msg.get("tool_calls"):
            msg["tool_calls"] = [_normalize_call(tc) for tc in msg["tool_calls"]]
        if msg.get("content") is None and not msg.get("tool_calls"):
            continue
        out.append(msg)

    return {"messages": out}
|
||
|
||
|
||
def write_dataset(examples: list[dict], path: Path) -> None:
    """Write ``examples`` to ``path`` as JSON Lines, one object per line.

    Parent directories are created as needed and an existing file is
    overwritten. Non-ASCII characters are written verbatim
    (``ensure_ascii=False``), so the file is always encoded as UTF-8.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: the default text encoding is locale-dependent (e.g.
    # cp1252 on Windows) and would raise or mangle non-ASCII content.
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(ex, ensure_ascii=False) + "\n" for ex in examples)
    print(f" Wrote {len(examples)} examples → {path}")
|
||
|
||
|
||
def main() -> None:
    """Command-line entry point.

    Steps:
      1. Load agent system prompts.
      2. Extract build sessions from the opencode DB.
      3. Split them into training windows.
      4. Assign each window a language (``--lang`` or the classifier).
      5. Write one ``<codename>.jsonl`` per language into ``--outdir``.
    """
    parser = argparse.ArgumentParser(description="Extract specialist training data")
    parser.add_argument(
        "--db", type=Path,
        default=Path.home() / ".local/share/opencode/opencode.db",
        help="Path to opencode SQLite database",
    )
    parser.add_argument(
        "--agents-dir", type=Path,
        default=Path.home() / ".config/opencode/agents",
        help="Path to agent system prompt directory",
    )
    parser.add_argument(
        "--outdir", type=Path, default=Path("data"),
        help="Output directory for JSONL files",
    )
    parser.add_argument(
        "--lang", type=str, default=None,
        # Only languages the classifier knows: any other value could never be
        # matched and would silently produce no data.
        choices=list(LANG_SIGNALS),
        help="Extract single language only (rust, typescript, python, ruby, swift)",
    )
    parser.add_argument(
        "--min-turns", type=int, default=1,
        help="Minimum user turns per training window",
    )
    parser.add_argument(
        "--max-turns", type=int, default=10,
        help="Maximum user turns per training window",
    )
    args = parser.parse_args()

    # Fail fast with a usage error; sqlite3.connect() would otherwise create
    # an empty database file at a mistyped path.
    if not args.db.is_file():
        parser.error(f"database not found: {args.db}")
    # Windowing steps through user turns by max_turns; < 1 is meaningless.
    if args.max_turns < 1:
        parser.error("--max-turns must be >= 1")

    print("══ Specialist Data Extraction ══")
    print(f"DB: {args.db}")
    print(f"Agents: {args.agents_dir}")
    print(f"Output: {args.outdir}")
    if args.lang:
        print(f"Filter: {args.lang} only")
    print()

    # Load system prompts
    load_system_prompts(args.agents_dir)
    print(f"Loaded {len(SYSTEM_PROMPTS)} system prompts")

    # Extract sessions (already filtered to --lang when given)
    conversations = extract_sessions(args.db, target_lang=args.lang)
    print(f"Extracted {len(conversations)} conversations")

    # Window into training examples
    windows = window_conversations(
        conversations, min_turns=args.min_turns, max_turns=args.max_turns,
    )
    print(f"Created {len(windows)} training windows")

    # Classify and bucket; with --lang every window takes that language.
    buckets: dict[str, list[dict]] = defaultdict(list)
    unclassified = 0
    for window in windows:
        lang = args.lang or classify_conversation(window.get("raw_text", ""))
        if lang:
            buckets[lang].append(format_example(window["messages"], lang))
        else:
            unclassified += 1

    # Report
    print("\n── Classification Results ──")
    if not args.lang:
        print(f"Unclassified: {unclassified}")
    for lang, examples in sorted(buckets.items(), key=lambda x: -len(x[1])):
        name = LANG_TO_NAME.get(lang, lang)
        # Examples containing at least one tool-calling assistant message.
        tc_count = sum(
            1 for ex in examples
            if any(m.get("tool_calls") for m in ex["messages"])
        )
        print(f" {name} ({lang}): {len(examples)} examples ({tc_count} with tool calls)")

    # Write per-language datasets
    print("\n── Writing Datasets ──")
    for lang, examples in buckets.items():
        name = LANG_TO_NAME.get(lang, lang)
        write_dataset(examples, args.outdir / f"{name}.jsonl")

    print(f"\nDone. Review datasets in {args.outdir}/")
    print("Next steps:")
    print(" 1. python mine_repos.py --repos repos.json (add git diff examples)")
    print(" 2. Manual curation pass")
    print(" 3. python train_specialist.py --name <adapter>")
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry: run the CLI and exit with its (None -> 0) status.
    raise SystemExit(main())
|