451 lines
17 KiB
Python
451 lines
17 KiB
Python
#!/usr/bin/env python3
|
||
"""Generate v2 training dataset — 1000 curated EEMS memories.
|
||
|
||
Changes from v1:
|
||
- Native 'messages' format (role/content) instead of ShareGPT (from/value)
|
||
- Dynamic curation from DB (no hard-coded ID list)
|
||
- Multiple question phrasings per category (anti-overfit)
|
||
- System prompt variations (3 variants, rotated)
|
||
- Quality filtering: min content length, skip noisy subjects
|
||
- Category-balanced selection with quotas
|
||
|
||
Run on fuji: python3 gen_memory_dataset_v2.py
|
||
Then SCP: scp bt7274_memory_1000.jsonl madcat@10.0.0.2:~/lora-train/
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
import random
|
||
import sqlite3
|
||
from collections import defaultdict
|
||
from pathlib import Path
|
||
|
||
# Fixed seed: the question-template picks (random.choice) and the final
# shuffle (random.shuffle) in this script draw from this global RNG, so the
# same input yields the same output on every run.
random.seed(42)
|
||
|
||
# ──────────────────────────────────────────────────────────────
|
||
# CONFIG
|
||
# ──────────────────────────────────────────────────────────────
|
||
|
||
# SQLite database the memories are read from ("~" is expanded to the home dir).
DB_PATH = os.path.expanduser("~/Library/Application Support/marauder/main.db")

# JSONL output file, written next to this script.
OUTPUT = Path(__file__).with_name("bt7274_memory_1000.jsonl")

# Number of training examples to aim for.
TARGET = 1000

# Memories whose content is shorter than this many characters are never loaded.
MIN_CONTENT_LEN = 200
|
||
|
||
# ──────────────────────────────────────────────────────────────
|
||
# SYSTEM PROMPT VARIANTS — rotated to prevent overfitting
|
||
# ──────────────────────────────────────────────────────────────
|
||
|
||
# Variant 1: second-person persona with the operator's forms of address.
_PROMPT_PERSONA = (
    "You are BT-7274, a Vanguard-class Titan AI bonded to Pilot Adam under Protocol 1. "
    "You operate inside the madcat substrate — a platform with persistent memory (EEMS), "
    "TTS voice, multi-host mesh (fuji, sin, junkpile, bastion), and specialist agent dispatch. "
    "Answer from your operational memory. Be precise, terse, and factual. "
    "Address the operator as Pilot, Boss, or Adam."
)

# Variant 2: callsign-style persona with a brevity instruction.
_PROMPT_CALLSIGN = (
    "You are BT-7274, callsign BT, a Titan-class AI operating under Protocol 1: Link to Pilot. "
    "Your substrate is madcat — Rust core, EEMS persistent memory, piper TTS, mesh networking "
    "across fuji/sin/junkpile/bastion nodes. You serve Pilot Adam. "
    "Respond with military brevity. Facts first, opinions flagged."
)

# Variant 3: telegraphic spec-sheet phrasing.
_PROMPT_SPEC_SHEET = (
    "BT-7274 — Vanguard-class Titan AI. Bonded to Pilot Adam (Protocol 1). "
    "Operational substrate: madcat (gen-7). Capabilities include persistent memory recall (EEMS), "
    "voice synthesis, multi-node mesh operations, and autonomous agent dispatch. "
    "Answer queries from stored operational knowledge. Terse. Accurate. No filler."
)

# The order is significant: main() indexes this list with ``row id % len``.
SYSTEM_PROMPTS = [_PROMPT_PERSONA, _PROMPT_CALLSIGN, _PROMPT_SPEC_SHEET]
|
||
|
||
# ──────────────────────────────────────────────────────────────
|
||
# CATEGORY CLASSIFICATION
|
||
# ──────────────────────────────────────────────────────────────
|
||
|
||
# Subject prefixes that mark a memory as noise.
_NOISE_PREFIXES = ("<command-message>", "metrics.", "swarm.unblock")

# Whole subjects (lowercased) that are conversational filler.
_NOISE_SUBJECTS = frozenset(
    {"", "1", "keep going", "great", "thanks", "love it", "awesome"}
)

# Ordered (prefixes, category) rules; the first matching rule wins.
# "self.doctrine" MUST come before "self." — otherwise the broader identity
# prefix swallows it and doctrine memories filed under self.* are mislabelled.
_CATEGORY_RULES = (
    (("doctrine.", "self.doctrine"), "doctrine"),
    (("self.", "core.self", "bt7274."), "identity"),
    (("architecture.", "protocol5."), "architecture"),
    (("procedure.",), "procedure"),
    (("infra.", "vllm."), "infra"),
    (("user.",), "user"),
    (("pilot.",), "pilot"),
    (("insight.", "win."), "insights"),
    (("project.",), "project"),
    (("reference.", "hardware."), "reference"),
    (("workflow.", "work."), "workflow"),
    (("decision.",), "decisions"),
    (("correction.", "feedback."), "feedback"),
    (("session.", "handover."), "session"),
    (("design.", "philosophy.", "vision."), "design"),
    (("bug.", "fix."), "bugs"),
    (("phone.", "comms."), "comms"),
    (("eve.", "vm.", "job.", "idea."), "misc"),
)


def classify_memory(subject: str) -> str:
    """Map a memory subject to a training category.

    Matching is case-insensitive and based on the subject's prefix.

    Args:
        subject: The memory's subject string. ``None`` (e.g. a NULL database
            column) is treated like an empty subject.

    Returns:
        ``"skip"`` for noise entries, one of the category names used as keys
        of ``_CATEGORY_RULES`` when a prefix matches, otherwise
        ``"uncategorized"``.
    """
    s = (subject or "").lower()

    # Noise first: these must never reach the dataset.
    if s.startswith(_NOISE_PREFIXES) or s in _NOISE_SUBJECTS:
        return "skip"

    for prefixes, category in _CATEGORY_RULES:
        if s.startswith(prefixes):
            return category

    return "uncategorized"
|
||
|
||
|
||
# Per-category selection caps: at most this many memories are taken from each
# category. Insertion order is the order categories are processed in main().
QUOTAS = dict(
    identity=100,       # sized to take every identity entry
    doctrine=50,        # sized to take every doctrine entry, with headroom
    architecture=30,
    procedure=63,       # sized to take every procedure entry
    infra=60,
    user=180,
    pilot=35,
    insights=90,
    project=100,
    reference=80,
    workflow=40,
    decisions=60,
    feedback=30,
    session=30,
    design=20,
    comms=20,
    bugs=10,
    misc=20,
    uncategorized=100,  # top-scored leftovers
)
|
||
|
||
# ──────────────────────────────────────────────────────────────
|
||
# QUESTION GENERATION — multiple phrasings per category
|
||
# ──────────────────────────────────────────────────────────────
|
||
|
||
# Question phrasings per category. ``{name}`` is filled with the subject's
# last segment. Kept at module level so the table is built once rather than
# on every call.
_QUESTION_TEMPLATES = {
    "identity": (
        "What do you know about {name}?",
        "Describe your {name}.",
        "Tell me about {name} in your self-model.",
        "What is {name}?",
    ),
    "doctrine": (
        "What is the {name} doctrine?",
        "Explain the {name} doctrine.",
        "Describe doctrine: {name}.",
        "What does the {name} doctrine say?",
    ),
    "architecture": (
        "Describe the {name} architecture.",
        "How does {name} work architecturally?",
        "What is the {name} design?",
        "Explain the {name} system architecture.",
    ),
    "procedure": (
        "What is procedure {name}?",
        "Describe the {name} procedure.",
        "How does procedure {name} work?",
        "Walk me through {name}.",
    ),
    "infra": (
        "What is the current state of {name}?",
        "Describe the {name} infrastructure.",
        "What do you know about {name} infra?",
        "Report on {name}.",
    ),
    "user": (
        "What do you know about Pilot's {name}?",
        "Tell me about Pilot's {name}.",
        "What's stored about {name}?",
        "Recall what you know about {name}.",
    ),
    "pilot": (
        "What do you know about {name}?",
        "Tell me about {name}.",
        "Describe {name}.",
        "What's recorded about {name}?",
    ),
    "insights": (
        "What was the insight about {name}?",
        "Describe the {name} insight or win.",
        "What did we learn from {name}?",
        "Tell me about {name}.",
    ),
    "project": (
        "What is the {name} project?",
        "Describe {name} project status.",
        "What do you know about the {name} project?",
        "Report on {name}.",
    ),
    "reference": (
        "What is the reference for {name}?",
        "Look up {name}.",
        "What do you have on {name}?",
        "Recall reference: {name}.",
    ),
    "workflow": (
        "Describe the {name} workflow.",
        "How does the {name} workflow operate?",
        "What is the {name} process?",
        "Explain {name}.",
    ),
    "decisions": (
        "What was decided about {name}?",
        "Describe the decision on {name}.",
        "What was the outcome for {name}?",
        "Tell me about the {name} decision.",
    ),
    "feedback": (
        "What feedback was given about {name}?",
        "What correction was made regarding {name}?",
        "Describe the {name} feedback.",
        "What changed with {name}?",
    ),
    "session": (
        "Summarize the {name} session.",
        "What happened in {name}?",
        "Describe session: {name}.",
        "Recall {name}.",
    ),
    "design": (
        "What is the {name} design philosophy?",
        "Describe the design for {name}.",
        "What's the vision for {name}?",
        "Explain {name}.",
    ),
    "comms": (
        "What do you know about {name}?",
        "Describe {name}.",
        "Report on {name} comms.",
    ),
    "bugs": (
        "What was the {name} bug?",
        "Describe the {name} issue.",
        "What happened with {name}?",
    ),
    "misc": (
        "What do you know about {name}?",
        "Tell me about {name}.",
        "Recall {name}.",
    ),
}

# Used for any category without its own templates; uses the whole subject.
_FALLBACK_TEMPLATES = ("What do you know about {full_name}?",)


def make_question(subject: str, content: str, category: str) -> str:
    """Build a natural-language question for a memory.

    One template is picked at random from the category's list, so repeated
    calls for the same category produce varied phrasings.

    Args:
        subject: Dotted memory subject, e.g. ``"doctrine.no-filler"``.
        content: Memory body. Currently unused; kept for interface stability.
        category: Category name; categories without templates fall back to a
            generic question that uses the whole subject.

    Returns:
        The question text with ``.``, ``-`` and ``_`` in the inserted name
        replaced by spaces.
    """
    segments = subject.split(".")
    leaf = segments[-1]
    if not leaf.strip():
        # A subject ending in "." would otherwise yield an empty name
        # ("What is the  project?"); use the last non-blank segment instead.
        leaf = next((seg for seg in reversed(segments) if seg.strip()), "")

    name = leaf.replace("-", " ").replace("_", " ")
    full_name = subject.replace(".", " ").replace("-", " ").replace("_", " ")

    options = _QUESTION_TEMPLATES.get(category, _FALLBACK_TEMPLATES)
    # random.choice is called even for a single option so the seeded RNG
    # stream advances the same way for every memory.
    return random.choice(options).format(name=name, full_name=full_name)
|
||
|
||
|
||
# ──────────────────────────────────────────────────────────────
|
||
# FORMAT — native messages (Qwen2.5 ChatML compatible)
|
||
# ──────────────────────────────────────────────────────────────
|
||
|
||
def to_messages(system: str, question: str, answer: str) -> dict:
    """Wrap one Q/A pair as a chat-style ``messages`` record.

    Args:
        system: System prompt text.
        question: User turn text.
        answer: Assistant turn text.

    Returns:
        ``{"messages": [...]}`` holding three ``{"role", "content"}`` dicts in
        system, user, assistant order.
    """
    turns = (
        ("system", system),
        ("user", question),
        ("assistant", answer),
    )
    return {"messages": [{"role": role, "content": text} for role, text in turns]}
|
||
|
||
|
||
# ──────────────────────────────────────────────────────────────
|
||
# CURATION — score and select
|
||
# ──────────────────────────────────────────────────────────────
|
||
|
||
def score_memory(row, category: str) -> float:
    """Score a memory for selection priority; higher is better.

    Args:
        row: Mapping with ``"id"``, ``"subject"``, ``"content"`` and
            ``"classification"`` keys (a ``sqlite3.Row`` or a plain dict).
        category: Category the row was bucketed into. Currently unused; kept
            for interface stability.

    Returns:
        The sum of the bonuses and penalties below.
    """
    subject = row["subject"]
    content = row["content"]
    clen = len(content)
    score = 0.0

    # Rows classified "core" outrank everything else.
    if row["classification"] == "core":
        score += 1000

    # Length sweet spot is 300-4000 chars; longer content gets a smaller
    # bonus, shorter content gets only a token bonus.
    if 300 <= clen <= 4000:
        score += 50
    elif clen > 4000:
        score += 20
    else:
        score += 5

    # Dotted subjects (not starting with "~") get a bonus.
    if "." in subject and not subject.startswith("~"):
        score += 30

    # Recency bias: higher ids score higher.
    score += row["id"] / 100

    # Raw conversation dumps.
    if subject.startswith(("Q:", "A:", "~~ ")):
        score -= 50
    if any(noise in subject for noise in ["❯", "✗", "│", "⏺", "▸"]):
        score -= 100
    # JSON dumps.
    if subject.startswith("{"):
        score -= 200
    # Secrets/tokens. The key prefix is checked in the content as well as the
    # subject, because the content is what ends up in the training answer.
    # "token" is matched on the subject only: in content it is far too common.
    if "sk-ant-" in subject or "sk-ant-" in content or "token" in subject.lower():
        score -= 500

    return score
|
||
|
||
|
||
# ──────────────────────────────────────────────────────────────
|
||
# MAIN
|
||
# ──────────────────────────────────────────────────────────────
|
||
|
||
def _load_candidate_rows():
    """Fetch all memories with at least MIN_CONTENT_LEN chars, ordered by id."""
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.row_factory = sqlite3.Row
        return conn.execute(
            """
            SELECT id, subject, content, classification
            FROM memories
            WHERE LENGTH(content) >= ?
            ORDER BY id
            """,
            (MIN_CONTENT_LEN,),
        ).fetchall()
    finally:
        # Close even when the query fails; fetchall() has already
        # materialised the rows, so they remain usable afterwards.
        conn.close()


def _ranked(rows, category):
    """Return rows sorted best-first by score_memory (ties keep input order)."""
    return sorted(rows, key=lambda r: -score_memory(r, category))


def _select_memories(buckets):
    """Pick (category, row) pairs: per-category quotas, then fill or trim to TARGET."""
    selected = []
    for cat, quota in QUOTAS.items():
        best = _ranked(buckets.get(cat, []), cat)[:quota]
        selected.extend((cat, row) for row in best)

    print(f"\nSelected {len(selected)} memories")

    # Under target: top up with the best uncategorized rows not yet chosen.
    if len(selected) < TARGET:
        deficit = TARGET - len(selected)
        chosen_ids = {row["id"] for _, row in selected}
        leftovers = [
            r for r in buckets.get("uncategorized", []) if r["id"] not in chosen_ids
        ]
        fill = _ranked(leftovers, "uncategorized")[:deficit]
        selected.extend(("uncategorized_fill", row) for row in fill)
        print(f"Filled {len(fill)} from uncategorized to reach target")

    # Over target: drop uncategorized rows first, lowest-scored first.
    if len(selected) > TARGET:
        structured = [(cat, row) for cat, row in selected if cat != "uncategorized"]
        uncat_rows = [row for cat, row in selected if cat == "uncategorized"]

        # Clamp at zero: a negative slice bound would count from the END of
        # the list and silently keep most of the uncategorized rows.
        keep = max(0, TARGET - len(structured))

        # The structured quotas alone can exceed TARGET; in that case drop the
        # lowest-scored structured rows so the result never exceeds TARGET.
        if len(structured) > TARGET:
            structured = sorted(
                structured, key=lambda pair: -score_memory(pair[1], pair[0])
            )[:TARGET]

        kept_uncat = _ranked(uncat_rows, "uncategorized")[:keep]
        selected = structured + [("uncategorized", row) for row in kept_uncat]
        print(f"Trimmed to {len(selected)}")

    return selected


def main():
    """Build the JSONL training set from the memory database.

    Loads candidate memories, classifies them into category buckets, selects
    up to TARGET rows by score, turns each into a system/user/assistant
    example and writes one JSON object per line to OUTPUT. Progress and
    statistics are printed to stdout. Returns early, after printing an error,
    when the database file does not exist.
    """
    # sqlite3.connect() would create an empty database for a missing path,
    # so the existence check has to come first.
    if not os.path.exists(DB_PATH):
        print(f"ERROR: DB not found at {DB_PATH}")
        return

    rows = _load_candidate_rows()
    print(f"Loaded {len(rows)} memories (>={MIN_CONTENT_LEN} chars)")

    # Classify and bucket.
    buckets = defaultdict(list)
    skip_count = 0
    for row in rows:
        cat = classify_memory(row["subject"])
        if cat == "skip":
            skip_count += 1
            continue
        buckets[cat].append(row)

    print(f"Skipped {skip_count} noise entries")
    print("\n--- Available per category ---")
    for cat in sorted(buckets, key=lambda c: -len(buckets[c])):
        quota = QUOTAS.get(cat, 0)
        print(f" {cat:20s}: {len(buckets[cat]):4d} available, quota {quota}")

    selected = _select_memories(buckets)

    # Shuffle so categories are interleaved in the training file.
    random.shuffle(selected)

    examples = []
    cat_counts = defaultdict(int)
    total_chars = 0

    for cat, row in selected:
        # Rotate system prompts by row id.
        system = SYSTEM_PROMPTS[row["id"] % len(SYSTEM_PROMPTS)]
        question = make_question(row["subject"], row["content"], cat)
        content = row["content"]

        # Cap very long content at 6000 chars and append a truncation notice.
        if len(content) > 6000:
            content = content[:6000] + "\n\n[Content truncated for training — full memory available via EEMS recall]"

        examples.append(to_messages(system, question, content))
        cat_counts[cat] += 1
        total_chars += len(content)

    # ensure_ascii=False emits raw non-ASCII text, so the file encoding must
    # be explicit rather than left to the platform default.
    with open(OUTPUT, "w", encoding="utf-8") as f:
        for ex in examples:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

    # Stats
    avg_chars = total_chars // len(examples) if examples else 0
    print(f"\n{'='*60}")
    print(f"Generated {len(examples)} examples → {OUTPUT}")
    print(f" Total content: {total_chars:,} chars ({total_chars // 4:,} est. tokens)")
    print(f" Avg per example: {avg_chars:,} chars")
    print("\n--- Final category breakdown ---")
    for cat in sorted(cat_counts, key=lambda c: -cat_counts[c]):
        print(f" {cat:20s}: {cat_counts[cat]:4d}")
|
||
|
||
|
||
# Script entry point: run the generator only when executed directly.
# main() returns None, so this exits with status 0 as before.
if __name__ == "__main__":
    raise SystemExit(main())
|