add training scripts: memory, specialist, mining, smoke test

This commit is contained in:
marauder-actual
2026-05-31 11:38:42 +02:00
parent df0d4a6eac
commit 4678816795
9 changed files with 2256 additions and 0 deletions
+516
View File
@@ -0,0 +1,516 @@
#!/usr/bin/env python3
"""Extract specialist training data from opencode session DB.
Classifies build-agent messages by programming language and outputs
per-specialist JSONL files for LoRA training.
opencode DB schema:
- session: id, agent, title, time_created, ...
- message: id, session_id, data (JSON: role, finish, tokens, ...)
- part: id, message_id, session_id, data (JSON: type, text/tool/state, ...)
Part types:
- text: {type: "text", text: "..."}
- tool: {type: "tool", tool: "read", callID: "...", state: {status, input, output, ...}}
- step-start/step-finish: inference step boundaries
- reasoning: chain-of-thought (skip for training)
- patch: file diffs (skip — use tool output instead)
- compaction: summary (skip)
Usage:
python extract_specialists.py [--db PATH] [--outdir data/] [--min-turns 2]
python extract_specialists.py --lang python --outdir data/ # single language
"""
import argparse
import json
import sqlite3
from collections import defaultdict
from pathlib import Path
from typing import Any
# ── Language classification signals ──────────────────────────────────
LANG_SIGNALS: dict[str, dict[str, list[str]]] = {
"rust": {
"extensions": [".rs"],
"files": ["Cargo.toml", "Cargo.lock", "build.rs", "clippy.toml", "rustfmt.toml"],
"commands": ["cargo ", "cargo build", "cargo test", "cargo clippy", "cargo fmt",
"cargo add", "rustc ", "rustup "],
"errors": ["error[E", "rustc --explain", "cannot find value", "expected struct",
"borrow checker"],
},
"typescript": {
"extensions": [".ts", ".tsx", ".mts", ".cts"],
"files": ["tsconfig.json", "package.json", "bun.lockb", "pnpm-lock.yaml",
"next.config", "vite.config", "astro.config"],
"commands": ["npm ", "pnpm ", "bun ", "npx ", "tsc ", "vitest ", "jest ",
"biome ", "eslint "],
"errors": ["error TS", "TS2", "TS7", "Cannot find module", "Type '"],
},
"python": {
"extensions": [".py", ".pyi"],
"files": ["pyproject.toml", "setup.py", "setup.cfg", "requirements.txt",
"ruff.toml", "mypy.ini", ".flake8", "noxfile.py", "tox.ini"],
"commands": ["python ", "python3 ", "pip ", "uv ", "pytest ", "ruff ", "mypy ",
"uvicorn ", "gunicorn "],
"errors": ["Traceback (most recent", "SyntaxError", "ImportError",
"TypeError", "ModuleNotFoundError"],
},
"ruby": {
"extensions": [".rb", ".erb", ".haml", ".slim"],
"files": ["Gemfile", "Gemfile.lock", "Rakefile", ".ruby-version",
".rubocop.yml", ".standard.yml"],
"commands": ["bundle ", "rails ", "rake ", "rspec ", "rubocop ",
"standardrb ", "gem "],
"errors": ["NoMethodError", "NameError", "ArgumentError",
"ActiveRecord::", "undefined method"],
},
"swift": {
"extensions": [".swift"],
"files": ["Package.swift", "project.yml", ".xcodeproj", ".xcworkspace"],
"commands": ["swift build", "swift test", "swift run", "xcodebuild ",
"swift-format ", "swift package "],
"errors": ["cannot convert value of type", "protocol conformance",
"value of type", "has no member"],
},
}
# Adapter codenames
LANG_TO_NAME = {
"rust": "oxidizer",
"typescript": "prism",
"python": "serpent",
"ruby": "forge",
"swift": "swiftblade",
}
# System prompts per specialist
SYSTEM_PROMPTS: dict[str, str] = {}
def load_system_prompts(agents_dir: Path) -> None:
"""Load agent system prompts from markdown files."""
mapping = {
"rust": "build-rust.md",
"typescript": "build-ts.md",
"python": "build-python.md",
"ruby": "build-ruby.md",
"swift": "build-swift.md",
}
for lang, filename in mapping.items():
path = agents_dir / filename
if path.exists():
SYSTEM_PROMPTS[lang] = path.read_text().strip()
else:
print(f" WARN: {path} not found, using default prompt for {lang}")
SYSTEM_PROMPTS[lang] = f"You are a {lang} coding agent."
def classify_text(content: str) -> dict[str, float]:
"""Score text's relevance to each language. Returns {lang: score}."""
scores: dict[str, float] = defaultdict(float)
content_lower = content.lower()
for lang, signals in LANG_SIGNALS.items():
for ext in signals["extensions"]:
scores[lang] += content_lower.count(ext) * 3.0
for f in signals["files"]:
if f.lower() in content_lower:
scores[lang] += 5.0
for cmd in signals["commands"]:
scores[lang] += content_lower.count(cmd.lower()) * 2.0
for err in signals["errors"]:
if err.lower() in content_lower:
scores[lang] += 4.0
return dict(scores)
def classify_conversation(all_text: str) -> str | None:
"""Classify concatenated conversation text to a single language."""
scores = classify_text(all_text)
if not scores:
return None
sorted_langs = sorted(scores.items(), key=lambda x: x[1], reverse=True)
if len(sorted_langs) == 0:
return None
winner, winner_score = sorted_langs[0]
if winner_score < 5.0:
return None
if len(sorted_langs) > 1:
runner_up_score = sorted_langs[1][1]
if runner_up_score > 0 and winner_score / runner_up_score < 1.5:
return None # Ambiguous
return winner
# ── Tool call tools we care about for training ──────────────────────
TRAINING_TOOLS = {"bash", "read", "edit", "write", "glob", "grep", "todowrite", "question"}
# Max output length to include (truncate large tool outputs)
# 8192 tokens ≈ ~32K chars. Budget: system ~2K, user ~2K, leaves ~28K for assistant+tools.
# Each tool call+result pair: ~5002000 chars. Cap output at 2000 to fit more exchanges.
MAX_OUTPUT_LEN = 2000
def truncate_output(output: str, max_len: int = MAX_OUTPUT_LEN) -> str:
"""Truncate tool output to max_len chars."""
if len(output) <= max_len:
return output
return output[:max_len] + f"\n... (truncated, {len(output)} total chars)"
def extract_sessions(db_path: Path, target_lang: str | None = None) -> list[dict]:
"""Extract build-agent sessions from opencode DB.
Returns list of {session_id, title, messages: [...], raw_text: str}
where messages are in ChatML-like format suitable for training.
"""
conn = sqlite3.connect(db_path)
session_rows = conn.execute("""
SELECT id, title, time_created
FROM session
WHERE agent = 'build' OR agent LIKE 'build-%'
ORDER BY time_created
""").fetchall()
print(f"Found {len(session_rows)} build sessions")
all_conversations: list[dict] = []
for s_id, s_title, s_created in session_rows:
# Get messages for this session, ordered
msg_rows = conn.execute("""
SELECT m.id, json_extract(m.data, '$.role') as role,
json_extract(m.data, '$.finish') as finish
FROM message m
WHERE m.session_id = ?
ORDER BY m.time_created
""", (s_id,)).fetchall()
if len(msg_rows) < 2:
continue
# Get all parts for this session, grouped by message
part_rows = conn.execute("""
SELECT p.message_id,
json_extract(p.data, '$.type') as ptype,
p.data as pdata
FROM part p
WHERE p.session_id = ?
ORDER BY p.time_created
""", (s_id,)).fetchall()
# Group parts by message_id
msg_parts: dict[str, list[tuple[str, str]]] = defaultdict(list)
for p_msg_id, p_type, p_data in part_rows:
msg_parts[p_msg_id].append((p_type, p_data))
# Build ChatML messages
messages: list[dict[str, Any]] = []
raw_texts: list[str] = [] # for classification
for m_id, m_role, m_finish in msg_rows:
parts = msg_parts.get(m_id, [])
if m_role == "user":
# Extract user text
user_text = ""
for ptype, pdata in parts:
if ptype == "text":
pd = json.loads(pdata)
user_text += pd.get("text", "")
if user_text.strip():
messages.append({"role": "user", "content": user_text.strip()})
raw_texts.append(user_text)
elif m_role == "assistant":
# Collect text parts and tool calls
asst_text = ""
tool_calls: list[dict] = []
tool_results: list[dict] = []
for ptype, pdata in parts:
if ptype == "text":
pd = json.loads(pdata)
asst_text += pd.get("text", "")
elif ptype == "tool":
pd = json.loads(pdata)
tool_name = pd.get("tool", "")
call_id = pd.get("callID", "")
state = pd.get("state", {})
if tool_name not in TRAINING_TOOLS:
continue
if state.get("status") != "completed":
continue
tool_input = state.get("input", {})
tool_output = state.get("output", "")
# Build tool_call in OpenAI format
tool_calls.append({
"type": "function",
"id": call_id,
"function": {
"name": tool_name,
"arguments": tool_input,
},
})
# Build tool result
output_str = truncate_output(str(tool_output))
tool_results.append({
"role": "tool",
"tool_call_id": call_id,
"content": output_str,
})
# Collect for classification
raw_texts.append(json.dumps(tool_input))
raw_texts.append(output_str)
# Emit assistant message with tool calls
if tool_calls:
messages.append({
"role": "assistant",
"content": None,
"tool_calls": tool_calls,
})
messages.extend(tool_results)
# Emit text-only assistant message (after tools, or standalone)
if asst_text.strip():
messages.append({
"role": "assistant",
"content": asst_text.strip(),
})
raw_texts.append(asst_text)
if len(messages) < 3: # need at least user + assistant + something
continue
# Concatenate raw text for classification
raw_combined = " ".join(raw_texts)
# Early classification filter if target_lang specified
if target_lang:
lang = classify_conversation(raw_combined)
if lang != target_lang:
continue
all_conversations.append({
"session_id": s_id,
"title": s_title,
"messages": messages,
"raw_text": raw_combined,
})
conn.close()
return all_conversations
def window_conversations(
conversations: list[dict],
min_turns: int = 2,
max_turns: int = 10,
) -> list[dict]:
"""Split long conversations into training windows.
Each window captures a coherent exchange: user question → assistant response
including all tool calls and results within that exchange.
"""
windows: list[dict] = []
for conv in conversations:
msgs = conv["messages"]
# Find user message indices
user_indices = [i for i, m in enumerate(msgs) if m["role"] == "user"]
if len(user_indices) < min_turns:
# Short enough to use as-is
if len(user_indices) >= 1:
windows.append({
"session_id": conv["session_id"],
"title": conv["title"],
"messages": msgs,
"raw_text": conv.get("raw_text", ""),
})
continue
# Window by user-turn boundaries
for start in range(0, len(user_indices), max_turns):
end = min(start + max_turns, len(user_indices))
first_msg = user_indices[start]
# End at next user msg or end of conversation
if end < len(user_indices):
last_msg = user_indices[end]
else:
last_msg = len(msgs)
window_msgs = msgs[first_msg:last_msg]
# Skip windows that are too short
user_count = sum(1 for m in window_msgs if m["role"] == "user")
if user_count < 1:
continue
windows.append({
"session_id": conv["session_id"],
"title": conv["title"],
"messages": window_msgs,
"raw_text": " ".join(
m.get("content", "") or json.dumps(m.get("tool_calls", ""))
for m in window_msgs
),
})
return windows
def format_example(messages: list[dict], lang: str) -> dict:
"""Format a conversation window as a training example with system prompt."""
system_prompt = SYSTEM_PROMPTS.get(lang, f"You are a {lang} coding agent.")
# Clean up messages: ensure tool_call arguments are dicts
cleaned = []
for msg in messages:
msg = dict(msg)
if msg.get("tool_calls"):
new_tcs = []
for tc in msg["tool_calls"]:
tc = dict(tc)
if "function" in tc:
fn = dict(tc["function"])
if isinstance(fn.get("arguments"), str):
try:
fn["arguments"] = json.loads(fn["arguments"])
except (ValueError, TypeError):
fn["arguments"] = {"raw": fn["arguments"]}
tc["function"] = fn
new_tcs.append(tc)
msg["tool_calls"] = new_tcs
# Remove None content if no tool_calls
if msg.get("content") is None and not msg.get("tool_calls"):
continue
cleaned.append(msg)
return {
"messages": [{"role": "system", "content": system_prompt}] + cleaned,
}
def write_dataset(examples: list[dict], path: Path) -> None:
"""Write examples to JSONL file."""
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w") as f:
for ex in examples:
f.write(json.dumps(ex, ensure_ascii=False) + "\n")
print(f" Wrote {len(examples)} examples → {path}")
def main() -> None:
parser = argparse.ArgumentParser(description="Extract specialist training data")
parser.add_argument(
"--db", type=Path,
default=Path.home() / ".local/share/opencode/opencode.db",
help="Path to opencode SQLite database",
)
parser.add_argument(
"--agents-dir", type=Path,
default=Path.home() / ".config/opencode/agents",
help="Path to agent system prompt directory",
)
parser.add_argument(
"--outdir", type=Path, default=Path("data"),
help="Output directory for JSONL files",
)
parser.add_argument(
"--lang", type=str, default=None,
help="Extract single language only (rust, typescript, python, ruby, swift)",
)
parser.add_argument(
"--min-turns", type=int, default=1,
help="Minimum user turns per training window",
)
parser.add_argument(
"--max-turns", type=int, default=10,
help="Maximum user turns per training window",
)
args = parser.parse_args()
print("══ Specialist Data Extraction ══")
print(f"DB: {args.db}")
print(f"Agents: {args.agents_dir}")
print(f"Output: {args.outdir}")
if args.lang:
print(f"Filter: {args.lang} only")
print()
# Load system prompts
load_system_prompts(args.agents_dir)
print(f"Loaded {len(SYSTEM_PROMPTS)} system prompts")
# Extract sessions
conversations = extract_sessions(args.db, target_lang=args.lang)
print(f"Extracted {len(conversations)} conversations")
# Window into training examples
windows = window_conversations(
conversations, min_turns=args.min_turns, max_turns=args.max_turns,
)
print(f"Created {len(windows)} training windows")
# Classify and bucket
buckets: dict[str, list[dict]] = defaultdict(list)
unclassified = 0
for window in windows:
if args.lang:
lang = args.lang
else:
lang = classify_conversation(window.get("raw_text", ""))
if lang:
example = format_example(window["messages"], lang)
buckets[lang].append(example)
else:
unclassified += 1
# Report
print(f"\n── Classification Results ──")
if not args.lang:
print(f"Unclassified: {unclassified}")
for lang, examples in sorted(buckets.items(), key=lambda x: -len(x[1])):
name = LANG_TO_NAME.get(lang, lang)
# Count tool calls and text-only
tc_count = sum(
1 for ex in examples
if any(m.get("tool_calls") for m in ex["messages"])
)
print(f" {name} ({lang}): {len(examples)} examples ({tc_count} with tool calls)")
# Write per-language datasets
print(f"\n── Writing Datasets ──")
for lang, examples in buckets.items():
name = LANG_TO_NAME.get(lang, lang)
write_dataset(examples, args.outdir / f"{name}.jsonl")
print(f"\nDone. Review datasets in {args.outdir}/")
print(f"Next steps:")
print(f" 1. python mine_repos.py --repos repos.json (add git diff examples)")
print(f" 2. Manual curation pass")
print(f" 3. python train_specialist.py --name <adapter>")
if __name__ == "__main__":
main()