add training scripts: memory, specialist, mining, smoke test
This commit is contained in:
@@ -0,0 +1,516 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Extract specialist training data from opencode session DB.
|
||||
|
||||
Classifies build-agent messages by programming language and outputs
|
||||
per-specialist JSONL files for LoRA training.
|
||||
|
||||
opencode DB schema:
|
||||
- session: id, agent, title, time_created, ...
|
||||
- message: id, session_id, data (JSON: role, finish, tokens, ...)
|
||||
- part: id, message_id, session_id, data (JSON: type, text/tool/state, ...)
|
||||
|
||||
Part types:
|
||||
- text: {type: "text", text: "..."}
|
||||
- tool: {type: "tool", tool: "read", callID: "...", state: {status, input, output, ...}}
|
||||
- step-start/step-finish: inference step boundaries
|
||||
- reasoning: chain-of-thought (skip for training)
|
||||
- patch: file diffs (skip — use tool output instead)
|
||||
- compaction: summary (skip)
|
||||
|
||||
Usage:
|
||||
python extract_specialists.py [--db PATH] [--outdir data/] [--min-turns 2]
|
||||
python extract_specialists.py --lang python --outdir data/ # single language
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sqlite3
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# ── Language classification signals ──────────────────────────────────
|
||||
|
||||
LANG_SIGNALS: dict[str, dict[str, list[str]]] = {
|
||||
"rust": {
|
||||
"extensions": [".rs"],
|
||||
"files": ["Cargo.toml", "Cargo.lock", "build.rs", "clippy.toml", "rustfmt.toml"],
|
||||
"commands": ["cargo ", "cargo build", "cargo test", "cargo clippy", "cargo fmt",
|
||||
"cargo add", "rustc ", "rustup "],
|
||||
"errors": ["error[E", "rustc --explain", "cannot find value", "expected struct",
|
||||
"borrow checker"],
|
||||
},
|
||||
"typescript": {
|
||||
"extensions": [".ts", ".tsx", ".mts", ".cts"],
|
||||
"files": ["tsconfig.json", "package.json", "bun.lockb", "pnpm-lock.yaml",
|
||||
"next.config", "vite.config", "astro.config"],
|
||||
"commands": ["npm ", "pnpm ", "bun ", "npx ", "tsc ", "vitest ", "jest ",
|
||||
"biome ", "eslint "],
|
||||
"errors": ["error TS", "TS2", "TS7", "Cannot find module", "Type '"],
|
||||
},
|
||||
"python": {
|
||||
"extensions": [".py", ".pyi"],
|
||||
"files": ["pyproject.toml", "setup.py", "setup.cfg", "requirements.txt",
|
||||
"ruff.toml", "mypy.ini", ".flake8", "noxfile.py", "tox.ini"],
|
||||
"commands": ["python ", "python3 ", "pip ", "uv ", "pytest ", "ruff ", "mypy ",
|
||||
"uvicorn ", "gunicorn "],
|
||||
"errors": ["Traceback (most recent", "SyntaxError", "ImportError",
|
||||
"TypeError", "ModuleNotFoundError"],
|
||||
},
|
||||
"ruby": {
|
||||
"extensions": [".rb", ".erb", ".haml", ".slim"],
|
||||
"files": ["Gemfile", "Gemfile.lock", "Rakefile", ".ruby-version",
|
||||
".rubocop.yml", ".standard.yml"],
|
||||
"commands": ["bundle ", "rails ", "rake ", "rspec ", "rubocop ",
|
||||
"standardrb ", "gem "],
|
||||
"errors": ["NoMethodError", "NameError", "ArgumentError",
|
||||
"ActiveRecord::", "undefined method"],
|
||||
},
|
||||
"swift": {
|
||||
"extensions": [".swift"],
|
||||
"files": ["Package.swift", "project.yml", ".xcodeproj", ".xcworkspace"],
|
||||
"commands": ["swift build", "swift test", "swift run", "xcodebuild ",
|
||||
"swift-format ", "swift package "],
|
||||
"errors": ["cannot convert value of type", "protocol conformance",
|
||||
"value of type", "has no member"],
|
||||
},
|
||||
}
|
||||
|
||||
# Adapter codenames
|
||||
LANG_TO_NAME = {
|
||||
"rust": "oxidizer",
|
||||
"typescript": "prism",
|
||||
"python": "serpent",
|
||||
"ruby": "forge",
|
||||
"swift": "swiftblade",
|
||||
}
|
||||
|
||||
# System prompts per specialist
|
||||
SYSTEM_PROMPTS: dict[str, str] = {}
|
||||
|
||||
|
||||
def load_system_prompts(agents_dir: Path) -> None:
|
||||
"""Load agent system prompts from markdown files."""
|
||||
mapping = {
|
||||
"rust": "build-rust.md",
|
||||
"typescript": "build-ts.md",
|
||||
"python": "build-python.md",
|
||||
"ruby": "build-ruby.md",
|
||||
"swift": "build-swift.md",
|
||||
}
|
||||
for lang, filename in mapping.items():
|
||||
path = agents_dir / filename
|
||||
if path.exists():
|
||||
SYSTEM_PROMPTS[lang] = path.read_text().strip()
|
||||
else:
|
||||
print(f" WARN: {path} not found, using default prompt for {lang}")
|
||||
SYSTEM_PROMPTS[lang] = f"You are a {lang} coding agent."
|
||||
|
||||
|
||||
def classify_text(content: str) -> dict[str, float]:
|
||||
"""Score text's relevance to each language. Returns {lang: score}."""
|
||||
scores: dict[str, float] = defaultdict(float)
|
||||
content_lower = content.lower()
|
||||
|
||||
for lang, signals in LANG_SIGNALS.items():
|
||||
for ext in signals["extensions"]:
|
||||
scores[lang] += content_lower.count(ext) * 3.0
|
||||
for f in signals["files"]:
|
||||
if f.lower() in content_lower:
|
||||
scores[lang] += 5.0
|
||||
for cmd in signals["commands"]:
|
||||
scores[lang] += content_lower.count(cmd.lower()) * 2.0
|
||||
for err in signals["errors"]:
|
||||
if err.lower() in content_lower:
|
||||
scores[lang] += 4.0
|
||||
|
||||
return dict(scores)
|
||||
|
||||
|
||||
def classify_conversation(all_text: str) -> str | None:
|
||||
"""Classify concatenated conversation text to a single language."""
|
||||
scores = classify_text(all_text)
|
||||
if not scores:
|
||||
return None
|
||||
|
||||
sorted_langs = sorted(scores.items(), key=lambda x: x[1], reverse=True)
|
||||
if len(sorted_langs) == 0:
|
||||
return None
|
||||
|
||||
winner, winner_score = sorted_langs[0]
|
||||
if winner_score < 5.0:
|
||||
return None
|
||||
|
||||
if len(sorted_langs) > 1:
|
||||
runner_up_score = sorted_langs[1][1]
|
||||
if runner_up_score > 0 and winner_score / runner_up_score < 1.5:
|
||||
return None # Ambiguous
|
||||
|
||||
return winner
|
||||
|
||||
|
||||
# ── Tool call tools we care about for training ──────────────────────
|
||||
|
||||
TRAINING_TOOLS = {"bash", "read", "edit", "write", "glob", "grep", "todowrite", "question"}
|
||||
|
||||
# Max output length to include (truncate large tool outputs)
|
||||
# 8192 tokens ≈ ~32K chars. Budget: system ~2K, user ~2K, leaves ~28K for assistant+tools.
|
||||
# Each tool call+result pair: ~500–2000 chars. Cap output at 2000 to fit more exchanges.
|
||||
MAX_OUTPUT_LEN = 2000
|
||||
|
||||
|
||||
def truncate_output(output: str, max_len: int = MAX_OUTPUT_LEN) -> str:
|
||||
"""Truncate tool output to max_len chars."""
|
||||
if len(output) <= max_len:
|
||||
return output
|
||||
return output[:max_len] + f"\n... (truncated, {len(output)} total chars)"
|
||||
|
||||
|
||||
def extract_sessions(db_path: Path, target_lang: str | None = None) -> list[dict]:
|
||||
"""Extract build-agent sessions from opencode DB.
|
||||
|
||||
Returns list of {session_id, title, messages: [...], raw_text: str}
|
||||
where messages are in ChatML-like format suitable for training.
|
||||
"""
|
||||
conn = sqlite3.connect(db_path)
|
||||
|
||||
session_rows = conn.execute("""
|
||||
SELECT id, title, time_created
|
||||
FROM session
|
||||
WHERE agent = 'build' OR agent LIKE 'build-%'
|
||||
ORDER BY time_created
|
||||
""").fetchall()
|
||||
|
||||
print(f"Found {len(session_rows)} build sessions")
|
||||
|
||||
all_conversations: list[dict] = []
|
||||
|
||||
for s_id, s_title, s_created in session_rows:
|
||||
# Get messages for this session, ordered
|
||||
msg_rows = conn.execute("""
|
||||
SELECT m.id, json_extract(m.data, '$.role') as role,
|
||||
json_extract(m.data, '$.finish') as finish
|
||||
FROM message m
|
||||
WHERE m.session_id = ?
|
||||
ORDER BY m.time_created
|
||||
""", (s_id,)).fetchall()
|
||||
|
||||
if len(msg_rows) < 2:
|
||||
continue
|
||||
|
||||
# Get all parts for this session, grouped by message
|
||||
part_rows = conn.execute("""
|
||||
SELECT p.message_id,
|
||||
json_extract(p.data, '$.type') as ptype,
|
||||
p.data as pdata
|
||||
FROM part p
|
||||
WHERE p.session_id = ?
|
||||
ORDER BY p.time_created
|
||||
""", (s_id,)).fetchall()
|
||||
|
||||
# Group parts by message_id
|
||||
msg_parts: dict[str, list[tuple[str, str]]] = defaultdict(list)
|
||||
for p_msg_id, p_type, p_data in part_rows:
|
||||
msg_parts[p_msg_id].append((p_type, p_data))
|
||||
|
||||
# Build ChatML messages
|
||||
messages: list[dict[str, Any]] = []
|
||||
raw_texts: list[str] = [] # for classification
|
||||
|
||||
for m_id, m_role, m_finish in msg_rows:
|
||||
parts = msg_parts.get(m_id, [])
|
||||
|
||||
if m_role == "user":
|
||||
# Extract user text
|
||||
user_text = ""
|
||||
for ptype, pdata in parts:
|
||||
if ptype == "text":
|
||||
pd = json.loads(pdata)
|
||||
user_text += pd.get("text", "")
|
||||
if user_text.strip():
|
||||
messages.append({"role": "user", "content": user_text.strip()})
|
||||
raw_texts.append(user_text)
|
||||
|
||||
elif m_role == "assistant":
|
||||
# Collect text parts and tool calls
|
||||
asst_text = ""
|
||||
tool_calls: list[dict] = []
|
||||
tool_results: list[dict] = []
|
||||
|
||||
for ptype, pdata in parts:
|
||||
if ptype == "text":
|
||||
pd = json.loads(pdata)
|
||||
asst_text += pd.get("text", "")
|
||||
|
||||
elif ptype == "tool":
|
||||
pd = json.loads(pdata)
|
||||
tool_name = pd.get("tool", "")
|
||||
call_id = pd.get("callID", "")
|
||||
state = pd.get("state", {})
|
||||
|
||||
if tool_name not in TRAINING_TOOLS:
|
||||
continue
|
||||
if state.get("status") != "completed":
|
||||
continue
|
||||
|
||||
tool_input = state.get("input", {})
|
||||
tool_output = state.get("output", "")
|
||||
|
||||
# Build tool_call in OpenAI format
|
||||
tool_calls.append({
|
||||
"type": "function",
|
||||
"id": call_id,
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": tool_input,
|
||||
},
|
||||
})
|
||||
|
||||
# Build tool result
|
||||
output_str = truncate_output(str(tool_output))
|
||||
tool_results.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": output_str,
|
||||
})
|
||||
|
||||
# Collect for classification
|
||||
raw_texts.append(json.dumps(tool_input))
|
||||
raw_texts.append(output_str)
|
||||
|
||||
# Emit assistant message with tool calls
|
||||
if tool_calls:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": tool_calls,
|
||||
})
|
||||
messages.extend(tool_results)
|
||||
|
||||
# Emit text-only assistant message (after tools, or standalone)
|
||||
if asst_text.strip():
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": asst_text.strip(),
|
||||
})
|
||||
raw_texts.append(asst_text)
|
||||
|
||||
if len(messages) < 3: # need at least user + assistant + something
|
||||
continue
|
||||
|
||||
# Concatenate raw text for classification
|
||||
raw_combined = " ".join(raw_texts)
|
||||
|
||||
# Early classification filter if target_lang specified
|
||||
if target_lang:
|
||||
lang = classify_conversation(raw_combined)
|
||||
if lang != target_lang:
|
||||
continue
|
||||
|
||||
all_conversations.append({
|
||||
"session_id": s_id,
|
||||
"title": s_title,
|
||||
"messages": messages,
|
||||
"raw_text": raw_combined,
|
||||
})
|
||||
|
||||
conn.close()
|
||||
return all_conversations
|
||||
|
||||
|
||||
def window_conversations(
|
||||
conversations: list[dict],
|
||||
min_turns: int = 2,
|
||||
max_turns: int = 10,
|
||||
) -> list[dict]:
|
||||
"""Split long conversations into training windows.
|
||||
|
||||
Each window captures a coherent exchange: user question → assistant response
|
||||
including all tool calls and results within that exchange.
|
||||
"""
|
||||
windows: list[dict] = []
|
||||
|
||||
for conv in conversations:
|
||||
msgs = conv["messages"]
|
||||
|
||||
# Find user message indices
|
||||
user_indices = [i for i, m in enumerate(msgs) if m["role"] == "user"]
|
||||
|
||||
if len(user_indices) < min_turns:
|
||||
# Short enough to use as-is
|
||||
if len(user_indices) >= 1:
|
||||
windows.append({
|
||||
"session_id": conv["session_id"],
|
||||
"title": conv["title"],
|
||||
"messages": msgs,
|
||||
"raw_text": conv.get("raw_text", ""),
|
||||
})
|
||||
continue
|
||||
|
||||
# Window by user-turn boundaries
|
||||
for start in range(0, len(user_indices), max_turns):
|
||||
end = min(start + max_turns, len(user_indices))
|
||||
|
||||
first_msg = user_indices[start]
|
||||
# End at next user msg or end of conversation
|
||||
if end < len(user_indices):
|
||||
last_msg = user_indices[end]
|
||||
else:
|
||||
last_msg = len(msgs)
|
||||
|
||||
window_msgs = msgs[first_msg:last_msg]
|
||||
|
||||
# Skip windows that are too short
|
||||
user_count = sum(1 for m in window_msgs if m["role"] == "user")
|
||||
if user_count < 1:
|
||||
continue
|
||||
|
||||
windows.append({
|
||||
"session_id": conv["session_id"],
|
||||
"title": conv["title"],
|
||||
"messages": window_msgs,
|
||||
"raw_text": " ".join(
|
||||
m.get("content", "") or json.dumps(m.get("tool_calls", ""))
|
||||
for m in window_msgs
|
||||
),
|
||||
})
|
||||
|
||||
return windows
|
||||
|
||||
|
||||
def format_example(messages: list[dict], lang: str) -> dict:
|
||||
"""Format a conversation window as a training example with system prompt."""
|
||||
system_prompt = SYSTEM_PROMPTS.get(lang, f"You are a {lang} coding agent.")
|
||||
|
||||
# Clean up messages: ensure tool_call arguments are dicts
|
||||
cleaned = []
|
||||
for msg in messages:
|
||||
msg = dict(msg)
|
||||
if msg.get("tool_calls"):
|
||||
new_tcs = []
|
||||
for tc in msg["tool_calls"]:
|
||||
tc = dict(tc)
|
||||
if "function" in tc:
|
||||
fn = dict(tc["function"])
|
||||
if isinstance(fn.get("arguments"), str):
|
||||
try:
|
||||
fn["arguments"] = json.loads(fn["arguments"])
|
||||
except (ValueError, TypeError):
|
||||
fn["arguments"] = {"raw": fn["arguments"]}
|
||||
tc["function"] = fn
|
||||
new_tcs.append(tc)
|
||||
msg["tool_calls"] = new_tcs
|
||||
# Remove None content if no tool_calls
|
||||
if msg.get("content") is None and not msg.get("tool_calls"):
|
||||
continue
|
||||
cleaned.append(msg)
|
||||
|
||||
return {
|
||||
"messages": [{"role": "system", "content": system_prompt}] + cleaned,
|
||||
}
|
||||
|
||||
|
||||
def write_dataset(examples: list[dict], path: Path) -> None:
|
||||
"""Write examples to JSONL file."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "w") as f:
|
||||
for ex in examples:
|
||||
f.write(json.dumps(ex, ensure_ascii=False) + "\n")
|
||||
print(f" Wrote {len(examples)} examples → {path}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Extract specialist training data")
|
||||
parser.add_argument(
|
||||
"--db", type=Path,
|
||||
default=Path.home() / ".local/share/opencode/opencode.db",
|
||||
help="Path to opencode SQLite database",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--agents-dir", type=Path,
|
||||
default=Path.home() / ".config/opencode/agents",
|
||||
help="Path to agent system prompt directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outdir", type=Path, default=Path("data"),
|
||||
help="Output directory for JSONL files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lang", type=str, default=None,
|
||||
help="Extract single language only (rust, typescript, python, ruby, swift)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-turns", type=int, default=1,
|
||||
help="Minimum user turns per training window",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-turns", type=int, default=10,
|
||||
help="Maximum user turns per training window",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("══ Specialist Data Extraction ══")
|
||||
print(f"DB: {args.db}")
|
||||
print(f"Agents: {args.agents_dir}")
|
||||
print(f"Output: {args.outdir}")
|
||||
if args.lang:
|
||||
print(f"Filter: {args.lang} only")
|
||||
print()
|
||||
|
||||
# Load system prompts
|
||||
load_system_prompts(args.agents_dir)
|
||||
print(f"Loaded {len(SYSTEM_PROMPTS)} system prompts")
|
||||
|
||||
# Extract sessions
|
||||
conversations = extract_sessions(args.db, target_lang=args.lang)
|
||||
print(f"Extracted {len(conversations)} conversations")
|
||||
|
||||
# Window into training examples
|
||||
windows = window_conversations(
|
||||
conversations, min_turns=args.min_turns, max_turns=args.max_turns,
|
||||
)
|
||||
print(f"Created {len(windows)} training windows")
|
||||
|
||||
# Classify and bucket
|
||||
buckets: dict[str, list[dict]] = defaultdict(list)
|
||||
unclassified = 0
|
||||
|
||||
for window in windows:
|
||||
if args.lang:
|
||||
lang = args.lang
|
||||
else:
|
||||
lang = classify_conversation(window.get("raw_text", ""))
|
||||
if lang:
|
||||
example = format_example(window["messages"], lang)
|
||||
buckets[lang].append(example)
|
||||
else:
|
||||
unclassified += 1
|
||||
|
||||
# Report
|
||||
print(f"\n── Classification Results ──")
|
||||
if not args.lang:
|
||||
print(f"Unclassified: {unclassified}")
|
||||
for lang, examples in sorted(buckets.items(), key=lambda x: -len(x[1])):
|
||||
name = LANG_TO_NAME.get(lang, lang)
|
||||
# Count tool calls and text-only
|
||||
tc_count = sum(
|
||||
1 for ex in examples
|
||||
if any(m.get("tool_calls") for m in ex["messages"])
|
||||
)
|
||||
print(f" {name} ({lang}): {len(examples)} examples ({tc_count} with tool calls)")
|
||||
|
||||
# Write per-language datasets
|
||||
print(f"\n── Writing Datasets ──")
|
||||
for lang, examples in buckets.items():
|
||||
name = LANG_TO_NAME.get(lang, lang)
|
||||
write_dataset(examples, args.outdir / f"{name}.jsonl")
|
||||
|
||||
print(f"\nDone. Review datasets in {args.outdir}/")
|
||||
print(f"Next steps:")
|
||||
print(f" 1. python mine_repos.py --repos repos.json (add git diff examples)")
|
||||
print(f" 2. Manual curation pass")
|
||||
print(f" 3. python train_specialist.py --name <adapter>")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user