Files
lora/gen_memory_dataset_v2.py
T

451 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Generate v2 training dataset — 1000 curated EEMS memories.
Changes from v1:
- Native 'messages' format (role/content) instead of ShareGPT (from/value)
- Dynamic curation from DB (no hard-coded ID list)
- Multiple question phrasings per category (anti-overfit)
- System prompt variations (3 variants, rotated)
- Quality filtering: min content length, skip noisy subjects
- Category-balanced selection with quotas
Run on fuji: python3 gen_memory_dataset_v2.py
Then SCP: scp bt7274_memory_1000.jsonl madcat@10.0.0.2:~/lora-train/
"""
import json
import os
import random
import sqlite3
from collections import defaultdict
from pathlib import Path
# Fixed seed: question phrasing and the final shuffle are reproducible run to run.
random.seed(42)

# ──────────────────────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────────────────────
# SQLite database the memories are read from.
DB_PATH = os.path.expanduser(
    "~/Library/Application Support/marauder/main.db"
)
# The JSONL dataset is written next to this script.
OUTPUT = Path(__file__).with_name("bt7274_memory_1000.jsonl")
# Number of training examples to aim for.
TARGET = 1_000
# Memories with shorter content are treated as trivial and never loaded.
MIN_CONTENT_LEN = 200
# ──────────────────────────────────────────────────────────────
# SYSTEM PROMPT VARIANTS — rotated to prevent overfitting
# ──────────────────────────────────────────────────────────────
# Three wordings of the same persona. Each training example is given one of
# them so the model does not latch onto a single system-prompt phrasing.
_PROMPT_NARRATIVE = (
    "You are BT-7274, a Vanguard-class Titan AI bonded to Pilot Adam under Protocol 1. "
    "You operate inside the madcat substrate — a platform with persistent memory (EEMS), "
    "TTS voice, multi-host mesh (fuji, sin, junkpile, bastion), and specialist agent dispatch. "
    "Answer from your operational memory. Be precise, terse, and factual. "
    "Address the operator as Pilot, Boss, or Adam."
)

_PROMPT_CALLSIGN = (
    "You are BT-7274, callsign BT, a Titan-class AI operating under Protocol 1: Link to Pilot. "
    "Your substrate is madcat — Rust core, EEMS persistent memory, piper TTS, mesh networking "
    "across fuji/sin/junkpile/bastion nodes. You serve Pilot Adam. "
    "Respond with military brevity. Facts first, opinions flagged."
)

_PROMPT_SPEC_SHEET = (
    "BT-7274 — Vanguard-class Titan AI. Bonded to Pilot Adam (Protocol 1). "
    "Operational substrate: madcat (gen-7). Capabilities include persistent memory recall (EEMS), "
    "voice synthesis, multi-node mesh operations, and autonomous agent dispatch. "
    "Answer queries from stored operational knowledge. Terse. Accurate. No filler."
)

SYSTEM_PROMPTS = [_PROMPT_NARRATIVE, _PROMPT_CALLSIGN, _PROMPT_SPEC_SHEET]
# ──────────────────────────────────────────────────────────────
# CATEGORY CLASSIFICATION
# ──────────────────────────────────────────────────────────────
# Subjects that are chat residue or telemetry rather than knowledge.
_NOISE_PREFIXES = ("<command-message>", "metrics.", "swarm.unblock")
_NOISE_SUBJECTS = frozenset(
    ("", "1", "keep going", "great", "thanks", "love it", "awesome")
)

# Ordered (prefixes, category) rules; the first matching rule wins.
# "self.doctrine" must be tested BEFORE the generic "self." identity rule,
# otherwise it can never match and doctrine memories are filed as identity.
_CATEGORY_RULES = (
    (("doctrine.", "self.doctrine"), "doctrine"),
    (("self.", "core.self", "bt7274."), "identity"),
    (("architecture.", "protocol5."), "architecture"),
    (("procedure.",), "procedure"),
    (("infra.", "vllm."), "infra"),
    (("user.",), "user"),
    (("pilot.",), "pilot"),
    (("insight.", "win."), "insights"),
    (("project.",), "project"),
    (("reference.", "hardware."), "reference"),
    (("workflow.", "work."), "workflow"),
    (("decision.",), "decisions"),
    (("correction.", "feedback."), "feedback"),
    (("session.", "handover."), "session"),
    (("design.", "philosophy.", "vision."), "design"),
    (("bug.", "fix."), "bugs"),
    (("phone.", "comms."), "comms"),
    (("eve.", "vm.", "job.", "idea."), "misc"),
)


def classify_memory(subject: str) -> str:
    """Map a memory subject to a training category.

    Matching is case-insensitive and based on the subject's dotted prefix.

    Args:
        subject: The memory's subject string. ``None`` is treated as an
            empty subject.

    Returns:
        ``"skip"`` for noise (command messages, metrics, swarm unblock
        events, empty or throwaway subjects), one of the category names
        from the rule table for a recognised prefix, or
        ``"uncategorized"`` when nothing matches.
    """
    s = (subject or "").lower()

    if s.startswith(_NOISE_PREFIXES) or s in _NOISE_SUBJECTS:
        return "skip"

    for prefixes, category in _CATEGORY_RULES:
        if s.startswith(prefixes):
            return category
    return "uncategorized"
# Category quotas — how many to select from each
# Per-category selection cap. Insertion order is the order in which
# categories are processed during selection, so keep it stable.
QUOTAS = dict(
    identity=100,       # sized to take all of them
    doctrine=50,        # all + extras
    architecture=30,
    procedure=63,       # sized to take all of them
    infra=60,
    user=180,
    pilot=35,
    insights=90,
    project=100,
    reference=80,
    workflow=40,
    decisions=60,
    feedback=30,
    session=30,
    design=20,
    comms=20,
    bugs=10,
    misc=20,
    uncategorized=100,  # best of the rest
)
# ──────────────────────────────────────────────────────────────
# QUESTION GENERATION — multiple phrasings per category
# ──────────────────────────────────────────────────────────────
# Question phrasings per category. "{name}" is filled with the humanised
# subject; keeping these as plain patterns avoids formatting every template
# on every call when only one is used.
_QUESTION_TEMPLATES = {
    "identity": (
        "What do you know about {name}?",
        "Describe your {name}.",
        "Tell me about {name} in your self-model.",
        "What is {name}?",
    ),
    "doctrine": (
        "What is the {name} doctrine?",
        "Explain the {name} doctrine.",
        "Describe doctrine: {name}.",
        "What does the {name} doctrine say?",
    ),
    "architecture": (
        "Describe the {name} architecture.",
        "How does {name} work architecturally?",
        "What is the {name} design?",
        "Explain the {name} system architecture.",
    ),
    "procedure": (
        "What is procedure {name}?",
        "Describe the {name} procedure.",
        "How does procedure {name} work?",
        "Walk me through {name}.",
    ),
    "infra": (
        "What is the current state of {name}?",
        "Describe the {name} infrastructure.",
        "What do you know about {name} infra?",
        "Report on {name}.",
    ),
    "user": (
        "What do you know about Pilot's {name}?",
        "Tell me about Pilot's {name}.",
        "What's stored about {name}?",
        "Recall what you know about {name}.",
    ),
    "pilot": (
        "What do you know about {name}?",
        "Tell me about {name}.",
        "Describe {name}.",
        "What's recorded about {name}?",
    ),
    "insights": (
        "What was the insight about {name}?",
        "Describe the {name} insight or win.",
        "What did we learn from {name}?",
        "Tell me about {name}.",
    ),
    "project": (
        "What is the {name} project?",
        "Describe {name} project status.",
        "What do you know about the {name} project?",
        "Report on {name}.",
    ),
    "reference": (
        "What is the reference for {name}?",
        "Look up {name}.",
        "What do you have on {name}?",
        "Recall reference: {name}.",
    ),
    "workflow": (
        "Describe the {name} workflow.",
        "How does the {name} workflow operate?",
        "What is the {name} process?",
        "Explain {name}.",
    ),
    "decisions": (
        "What was decided about {name}?",
        "Describe the decision on {name}.",
        "What was the outcome for {name}?",
        "Tell me about the {name} decision.",
    ),
    "feedback": (
        "What feedback was given about {name}?",
        "What correction was made regarding {name}?",
        "Describe the {name} feedback.",
        "What changed with {name}?",
    ),
    "session": (
        "Summarize the {name} session.",
        "What happened in {name}?",
        "Describe session: {name}.",
        "Recall {name}.",
    ),
    "design": (
        "What is the {name} design philosophy?",
        "Describe the design for {name}.",
        "What's the vision for {name}?",
        "Explain {name}.",
    ),
    "comms": (
        "What do you know about {name}?",
        "Describe {name}.",
        "Report on {name} comms.",
    ),
    "bugs": (
        "What was the {name} bug?",
        "Describe the {name} issue.",
        "What happened with {name}?",
    ),
    "misc": (
        "What do you know about {name}?",
        "Tell me about {name}.",
        "Recall {name}.",
    ),
}

# Used for any category without its own templates (e.g. "uncategorized").
_FALLBACK_TEMPLATES = ("What do you know about {name}?",)


def make_question(subject: str, content: str, category: str) -> str:
    """Generate a natural-language question for a memory.

    One phrasing is drawn at random from the templates for ``category``.
    Known categories are asked about the last dotted segment of the subject;
    any other category falls back to a single generic question about the
    whole subject. In both cases ``-`` and ``_`` become spaces.

    Args:
        subject: Memory subject, e.g. ``"doctrine.no-filler"``.
        content: Memory content. Not used; kept for call compatibility.
        category: Category name as returned by ``classify_memory``.

    Returns:
        The question text.
    """
    # Use the last NON-EMPTY segment: a subject with a trailing "." would
    # otherwise yield an empty name and questions such as "What is ?".
    segments = [seg for seg in subject.split(".") if seg.strip()]
    leaf = segments[-1] if segments else subject

    patterns = _QUESTION_TEMPLATES.get(category)
    if patterns is None:
        patterns = _FALLBACK_TEMPLATES
        topic = subject.replace(".", " ")
    else:
        topic = leaf
    topic = topic.replace("-", " ").replace("_", " ")

    # Always go through random.choice, even for the one-element fallback,
    # so the global RNG is advanced exactly as before and seeded runs stay
    # reproducible.
    return random.choice(patterns).format(name=topic)
# ──────────────────────────────────────────────────────────────
# FORMAT — native messages (Qwen2.5 ChatML compatible)
# ──────────────────────────────────────────────────────────────
def to_messages(system: str, question: str, answer: str) -> dict:
    """Wrap one Q/A pair as a ``{"messages": [...]}`` training record.

    The record holds three turns in fixed order (system, user, assistant),
    each a ``{"role": ..., "content": ...}`` dict.
    """
    turns = (
        ("system", system),
        ("user", question),
        ("assistant", answer),
    )
    conversation = [{"role": role, "content": text} for role, text in turns]
    return {"messages": conversation}
# ──────────────────────────────────────────────────────────────
# CURATION — score and select
# ──────────────────────────────────────────────────────────────
# Substrings that mark a subject as noisy.
# NOTE(review): every marker is an empty string in this copy of the file;
# presumably the original glyphs were lost. Restore them to re-enable the
# penalty. Empty markers are ignored below because "" is a substring of every
# string and would otherwise penalise every memory.
_NOISE_MARKERS = ("", "", "", "", "")


def score_memory(row, category: str) -> float:
    """Score a memory for selection priority. Higher = better.

    Args:
        row: Mapping-like record with ``"id"``, ``"subject"``, ``"content"``
            and ``"classification"`` keys (e.g. a ``sqlite3.Row``).
        category: Category the memory was classified into. Not used in the
            score; kept for call compatibility.

    Returns:
        The score as a float. Bonuses: core classification, content length
        in the 300-4000 character range, dotted subject, higher id.
        Penalties: conversation-dump prefixes, noise markers, JSON-looking
        subjects, and subjects mentioning secrets or tokens.
    """
    subject = row["subject"]
    clen = len(row["content"])
    score = 0.0

    # Core classification — always top priority.
    if row["classification"] == "core":
        score += 1000

    # Content length sweet spot: 300-4000 chars.
    if clen < 300:
        score += 5
    elif clen <= 4000:
        score += 50
    else:
        score += 20  # still valuable but may be truncated downstream

    # Structured (dotted) subjects score higher.
    if "." in subject and not subject.startswith("~"):
        score += 30

    # Recency bias: larger ids are treated as newer, more refined memories.
    score += row["id"] / 100

    # Penalize raw conversation dumps.
    if subject.startswith(("Q:", "A:", "~~ ")):
        score -= 50
    if any(marker and marker in subject for marker in _NOISE_MARKERS):
        score -= 100
    if subject.startswith("{"):
        score -= 200  # JSON dumps
    if "sk-ant-" in subject or "token" in subject.lower():
        score -= 500  # secrets/tokens
    return score
# ──────────────────────────────────────────────────────────────
# MAIN
# ──────────────────────────────────────────────────────────────
def _best_first(rows, category):
    """Return *rows* ordered by descending ``score_memory``; ties keep input order."""
    scored = [(score_memory(r, category), r) for r in rows]
    scored.sort(key=lambda pair: -pair[0])
    return [r for _, r in scored]


def main():
    """Build the JSONL training set from the memory database.

    Steps: load every memory whose content is at least ``MIN_CONTENT_LEN``
    characters, classify and bucket them, take the best-scored entries of
    each category up to its quota, top up from or trim down to ``TARGET``,
    shuffle, render each memory as a system/user/assistant record and write
    one JSON object per line to ``OUTPUT``. Progress and statistics are
    printed to stdout.

    Returns early (after printing an error) when the database file does not
    exist.
    """
    if not os.path.exists(DB_PATH):
        print(f"ERROR: DB not found at {DB_PATH}")
        return

    # Load all candidate memories. Everything is fetched up front, so the
    # connection is released immediately, including when the query fails.
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.row_factory = sqlite3.Row
        rows = conn.execute("""
            SELECT id, subject, content, classification
            FROM memories
            WHERE LENGTH(content) >= ?
            ORDER BY id
        """, (MIN_CONTENT_LEN,)).fetchall()
    finally:
        conn.close()
    print(f"Loaded {len(rows)} memories (>={MIN_CONTENT_LEN} chars)")

    # Classify and bucket.
    buckets = defaultdict(list)
    skip_count = 0
    for row in rows:
        cat = classify_memory(row["subject"])
        if cat == "skip":
            skip_count += 1
            continue
        buckets[cat].append(row)
    print(f"Skipped {skip_count} noise entries")

    print("\n--- Available per category ---")
    for cat in sorted(buckets, key=lambda c: -len(buckets[c])):
        quota = QUOTAS.get(cat, 0)
        print(f" {cat:20s}: {len(buckets[cat]):4d} available, quota {quota}")

    # Score and select the top entries of each category, up to its quota.
    selected = []
    for cat, quota in QUOTAS.items():
        candidates = buckets.get(cat, [])
        if not candidates:
            continue
        take = min(quota, len(candidates))
        for row in _best_first(candidates, cat)[:take]:
            selected.append((cat, row))
    print(f"\nSelected {len(selected)} memories")

    # If under target, fill from uncategorized memories not yet selected.
    if len(selected) < TARGET:
        deficit = TARGET - len(selected)
        selected_ids = {row["id"] for _, row in selected}
        leftovers = [r for r in buckets.get("uncategorized", [])
                     if r["id"] not in selected_ids]
        extras = _best_first(leftovers, "uncategorized")
        for row in extras[:deficit]:
            selected.append(("uncategorized_fill", row))
        print(f"Filled {min(deficit, len(extras))} from uncategorized to reach target")

    # If over target, trim: lowest-scored uncategorized entries go first.
    if len(selected) > TARGET:
        structured = [(cat, row) for cat, row in selected if cat != "uncategorized"]
        uncat_rows = [row for cat, row in selected if cat == "uncategorized"]
        if len(structured) > TARGET:
            # Structured entries alone exceed the target. A negative slice
            # bound here would keep most of the uncategorized entries rather
            # than none, so drop them all and cut the lowest-scored
            # structured entries instead.
            ranked = sorted(structured,
                            key=lambda pair: -score_memory(pair[1], pair[0]))
            selected = ranked[:TARGET]
        else:
            keep = TARGET - len(structured)
            kept = _best_first(uncat_rows, "uncategorized")[:keep]
            selected = structured + [("uncategorized", r) for r in kept]
        print(f"Trimmed to {len(selected)}")

    # Shuffle for training.
    random.shuffle(selected)

    # Generate dataset.
    examples = []
    cat_counts = defaultdict(int)
    total_chars = 0
    for cat, row in selected:
        # Rotate the system prompt deterministically by memory id.
        system = SYSTEM_PROMPTS[row["id"] % len(SYSTEM_PROMPTS)]
        question = make_question(row["subject"], row["content"], cat)
        content = row["content"]
        # Truncate very long content to ~6000 chars to stay within seq_len.
        if len(content) > 6000:
            content = content[:6000] + "\n\n[Content truncated for training — full memory available via EEMS recall]"
        examples.append(to_messages(system, question, content))
        cat_counts[cat] += 1
        total_chars += len(content)

    # Write JSONL. The encoding is explicit because ensure_ascii=False emits
    # raw non-ASCII characters, which a non-UTF-8 locale default could not
    # encode.
    with open(OUTPUT, "w", encoding="utf-8") as f:
        for ex in examples:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

    # Stats.
    avg_chars = total_chars // len(examples) if examples else 0
    print(f"\n{'='*60}")
    print(f"Generated {len(examples)} examples → {OUTPUT}")
    print(f" Total content: {total_chars:,} chars ({total_chars // 4:,} est. tokens)")
    print(f" Avg per example: {avg_chars:,} chars")
    print("\n--- Final category breakdown ---")
    for cat in sorted(cat_counts, key=lambda c: -cat_counts[c]):
        print(f" {cat:20s}: {cat_counts[cat]:4d}")


if __name__ == "__main__":
    main()