feat(templates): add bulk template extraction from CivitAI showcase

Adds the `tsr templates` subapp with extract / list / show / delete
commands. The headline command is:

    tsr templates extract <model> [--generate]

which pulls a model's CivitAI showcase, deduplicates prompts, derives
recommended generation params (sampler / scheduler / steps / cfg /
guidance) from the *mode* of the showcase image metadata, and writes one
JSON template per unique prompt to
~/.local/share/tensors/templates/<model_stem>/<scene_name>.json.

Each emitted template has the same shape as `tsr template -m <model>`
output and feeds straight into `tsr generate --input`. `--generate`
chains the generation step end-to-end.

New module `tensors/templates.py` carries the storage + extraction
helpers (`save_template`, `load_template`, `list_templates`,
`build_template`, `param_from_civitai_meta`,
`derive_overrides_from_images`) plus the CivitAI A1111 → ComfyUI
sampler/scheduler name normalization tables.

Replaces the external generate_templates.py wrapper script that was
maintaining a hand-curated MODEL_OVERRIDES dict — overrides now derive
automatically from each model's own showcase data.
This commit is contained in:
2026-05-18 21:04:21 +02:00
parent 2cbef237df
commit aa25d31ca9
2 changed files with 537 additions and 0 deletions
+256
View File
@@ -2909,6 +2909,262 @@ def scene_delete(
raise typer.Exit(1)
# ---- templates ----
templates_app = typer.Typer(
name="templates",
help="Bulk-extract, list, and run generation templates derived from CivitAI showcase data.",
no_args_is_help=True,
)
app.add_typer(templates_app)
@templates_app.command("extract")
def templates_extract(
model: Annotated[str, typer.Argument(help="Local model name (e.g. lust_v10.safetensors)")],
orientation: Annotated[
str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")
] = "portrait",
no_overrides: Annotated[
bool,
typer.Option(
"--no-overrides",
help="Skip auto-derived params from showcase image meta; use family defaults only",
),
] = False,
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
limit: Annotated[
int, typer.Option("--limit", "-L", help="Max templates to write (0 = all unique prompts)")
] = 0,
overwrite: Annotated[
bool, typer.Option("--overwrite", help="Overwrite existing template files (default: skip)")
] = False,
do_generate: Annotated[
bool,
typer.Option("--generate", help="After writing, run `tsr generate --input` for each emitted template"),
] = False,
output_dir: Annotated[
Path | None,
typer.Option(
"--output-dir", help="Where to write generated images when --generate (default: ComfyUI output dir)"
),
] = None,
dry_run: Annotated[
bool, typer.Option("--dry-run", help="Print what would be done; write nothing")
] = False,
) -> None:
"""Bulk-extract templates from a model's CivitAI showcase.
Pulls showcase images, deduplicates prompts, derives recommended generation
params (sampler / scheduler / steps / cfg / guidance) from the *mode* of the
showcase image metadata, and writes one JSON template per unique prompt to
``~/.local/share/tensors/templates/<model_stem>/<scene_name>.json``.
Each emitted template is ready to feed straight to ``tsr generate --input``.
Examples:
tsr templates extract lust_v10.safetensors
tsr templates extract bodySliderFitness_v10 -O portrait --generate
tsr templates extract ultrasenseInfinity_v10 --dry-run
tsr templates extract getphat_v5 --no-overrides # use tsr family defaults only
"""
import subprocess # noqa: PLC0415
from tensors.api import fetch_civitai_model_version # noqa: PLC0415
from tensors.config import ( # noqa: PLC0415
detect_model_family,
get_model_generation_defaults,
load_api_key,
resolve_orientation,
)
from tensors.fragments import parse_elements # noqa: PLC0415
from tensors.templates import ( # noqa: PLC0415
build_template,
derive_overrides_from_images,
save_template,
template_path,
)
with Database() as db:
files = db.list_local_files()
target_file = None
for f in files:
file_path = Path(f["file_path"])
if model in (file_path.name, file_path.stem):
target_file = f
break
if not target_file:
console.print(f"[red]Model '{model}' not found in local database. Run 'tsr db scan' first.[/red]")
raise typer.Exit(1)
vid = target_file["civitai_version_id"]
if not vid:
console.print(f"[red]Model '{model}' is not linked to CivitAI. Run 'tsr db link' first.[/red]")
raise typer.Exit(1)
model_stem = Path(target_file["file_path"]).stem
model_filename = Path(target_file["file_path"]).name
base_model_str = target_file.get("base_model")
console.print(f"[cyan]Fetching showcase for {model_stem} (version {vid})...[/cyan]")
data = fetch_civitai_model_version(vid, api_key or load_api_key(), console=console)
if not data:
console.print("[red]Failed to fetch CivitAI data.[/red]")
raise typer.Exit(1)
images = data.get("images", [])
enriched = sum(1 for img in images if img.get("meta"))
overrides = {} if no_overrides else derive_overrides_from_images(images)
if overrides:
console.print(f"[cyan]Derived overrides from {enriched} enriched image(s):[/cyan] {overrides}")
elif not no_overrides:
console.print("[yellow]No usable param meta in showcase; using family defaults only.[/yellow]")
family = detect_model_family(model_filename, base_model_str)
defaults = get_model_generation_defaults(model_filename, base_model_str)
res_w, res_h = resolve_orientation(family, orientation)
seen_prompts: set[str] = set()
emitted: list[Path] = []
skipped_existing = 0
skipped_no_prompt = 0
for img in images:
meta = img.get("meta") or {}
prompt = meta.get("prompt")
if not prompt:
skipped_no_prompt += 1
continue
normalized = prompt.lower().strip()
if normalized in seen_prompts:
continue
seen_prompts.add(normalized)
scene_elements = parse_elements(prompt)
if not scene_elements:
continue
idx = len(emitted) + skipped_existing + 1
name = f"{model_stem}_{idx:02d}"
out_path = template_path(model_stem, name)
if out_path.is_file() and not overwrite:
console.print(f"[yellow]Skip (exists, use --overwrite to replace):[/yellow] {out_path}")
skipped_existing += 1
continue
tpl = build_template(
model_filename=model_filename,
family=family,
defaults=defaults,
base_model_str=base_model_str,
width=res_w,
height=res_h,
orientation=orientation,
scene_elements=scene_elements,
scene_name=name,
overrides=overrides,
)
if dry_run:
console.print(f"[dim](dry-run) Would write:[/dim] {out_path}")
emitted.append(out_path)
else:
saved = save_template(model_stem, name, tpl)
console.print(f"[green]Saved ({len(scene_elements)} scene elements):[/green] {saved}")
emitted.append(saved)
if limit and len(emitted) >= limit:
break
console.print(
f"\n[bold]Extract summary:[/bold] emitted={len(emitted)} skipped_existing={skipped_existing} "
f"images_no_prompt={skipped_no_prompt}"
)
if not emitted:
return
if do_generate:
if dry_run:
console.print("[yellow]--dry-run is set; skipping --generate phase.[/yellow]")
return
console.print(f"\n[cyan]Generating images for {len(emitted)} template(s)...[/cyan]")
for tpl_path in emitted:
out_arg = []
if output_dir:
output_dir.mkdir(parents=True, exist_ok=True)
out_arg = ["-o", str(output_dir / f"{tpl_path.stem}.png")]
cmd = ["tsr", "generate", "--input", str(tpl_path), *out_arg]
console.print(f"\n[cyan]$ {' '.join(cmd)}[/cyan]")
subprocess.run(cmd, check=False)
@templates_app.command("list")
def templates_list(
model: Annotated[str | None, typer.Argument(help="Filter by model stem (optional)")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""List saved templates, grouped by model."""
from tensors.templates import TEMPLATES_DIR, list_templates # noqa: PLC0415
items = list_templates(model)
if json_output:
console.print_json(
data={"dir": str(TEMPLATES_DIR), "templates": [{"model": m, "name": n} for m, n in items]}
)
return
if not items:
scope = f" for model '{model}'" if model else ""
console.print(f"[yellow]No templates{scope} in {TEMPLATES_DIR}.[/yellow]")
console.print("[dim]Create some with: tsr templates extract <model>[/dim]")
return
cur_model = None
for m, n in items:
if m != cur_model:
console.print(f"\n[cyan]{m}[/cyan]")
cur_model = m
console.print(f" {n}")
@templates_app.command("show")
def templates_show(
model: Annotated[str, typer.Argument(help="Model stem (directory name under templates/)")],
name: Annotated[str, typer.Argument(help="Template name (filename without .json)")],
) -> None:
"""Print a saved template as JSON."""
from tensors.templates import load_template # noqa: PLC0415
try:
data = load_template(model, name)
except FileNotFoundError as e:
console.print(f"[red]{e}[/red]")
raise typer.Exit(1) from e
console.print_json(data=data)
@templates_app.command("delete")
def templates_delete(
model: Annotated[str, typer.Argument(help="Model stem")],
name: Annotated[str, typer.Argument(help="Template name")],
yes: Annotated[bool, typer.Option("--yes", "-y", help="Skip confirmation")] = False,
) -> None:
"""Delete a saved template."""
from tensors.templates import delete_template, template_path # noqa: PLC0415
path = template_path(model, name)
if not path.is_file():
console.print(f"[yellow]Template not found: {path}[/yellow]")
raise typer.Exit(1)
if not yes:
typer.confirm(f"Delete {path}?", abort=True)
if delete_template(model, name):
console.print(f"[green]Deleted:[/green] {path}")
# =============================================================================
# ComfyUI Commands
# =============================================================================
+281
View File
@@ -0,0 +1,281 @@
"""Template library: full generation configs derived from models + scenes.
A *template* is a complete `tsr generate --input` payload for a specific model
and prompt: dimensions, sampler, scheduler, steps, cfg, guidance, vae, and the
scene/character lists. Templates extend scenes (which are just lists of prompt
elements) with all the family-resolved generation params, so they can be fed
straight to `tsr generate` without re-resolving anything.
Each template lives at
``~/.local/share/tensors/templates/<model_stem>/<name>.json``
and uses the same JSON shape as ``tsr template -m <model>`` standalone output,
so the on-disk format and the ad-hoc one-shot template format are identical.
Module-level :data:`TEMPLATES_DIR` is read via :func:`globals` on every call so
tests can monkeypatch it without re-importing.
"""
from __future__ import annotations
import json
import re
from collections import Counter
from pathlib import Path # noqa: TC003 # used in runtime return annotations exposed to typer
from typing import Any
from tensors.config import DATA_DIR
__all__ = [
"META_KEY_MAP",
"SAMPLER_NORMALIZE",
"SCHEDULER_NORMALIZE",
"TEMPLATES_DIR",
"build_template",
"delete_template",
"derive_overrides_from_images",
"list_templates",
"load_template",
"param_from_civitai_meta",
"save_template",
"template_dir_for",
"template_path",
]
# Default storage location. Tests may monkeypatch this; every helper below
# dereferences via globals() so overrides are picked up live.
TEMPLATES_DIR = DATA_DIR / "templates"
# Restrict template + model names to the same safe subset used by FragmentLibrary
# so they can't escape the storage dir via path traversal.
_NAME_RE = re.compile(r"^[A-Za-z0-9_.-]+$")
# CivitAI A1111-style image meta → tsr template key mapping.
# Each entry maps a source key in the `meta` dict of a CivitAI image to a
# (tsr_key, converter) pair. The converter is either a callable (applied to the
# raw value) or a literal sentinel string ("sampler" / "scheduler") that
# triggers the corresponding normalize-and-translate path below.
META_KEY_MAP: dict[str, tuple[str, Any]] = {
"sampler": ("sampler", "sampler"),
"Schedule type": ("scheduler", "scheduler"),
"steps": ("steps", int),
"cfgScale": ("cfg", float),
"Distilled CFG Scale": ("guidance", float),
}
# A1111 / CivitAI sampler labels → ComfyUI canonical sampler names.
# Lookup is case-folded; unknown labels fall through with whitespace replaced
# by underscores (so e.g. "DPM++ 2M Karras" we don't know about still passes
# through as "dpm++_2m_karras" — wrong but loud).
SAMPLER_NORMALIZE: dict[str, str] = {
"euler": "euler",
"euler a": "euler_ancestral",
"euler ancestral": "euler_ancestral",
"dpm++ 2m": "dpmpp_2m",
"dpm++ 2m karras": "dpmpp_2m",
"dpm++ 2m sde": "dpmpp_2m_sde",
"dpm++ 2m sde karras": "dpmpp_2m_sde",
"dpm++ 3m sde": "dpmpp_3m_sde",
"dpm++ 3m sde karras": "dpmpp_3m_sde",
"dpm++ sde": "dpmpp_sde",
"dpm++ sde karras": "dpmpp_sde",
"dpm++ 2s a": "dpmpp_2s_ancestral",
"dpm++ 2s ancestral": "dpmpp_2s_ancestral",
"heun": "heun",
"ddim": "ddim",
"lms": "lms",
"unipc": "uni_pc",
"uni_pc": "uni_pc",
"lcm": "lcm",
"dpmpp_2m": "dpmpp_2m",
"dpmpp_2m_sde": "dpmpp_2m_sde",
}
SCHEDULER_NORMALIZE: dict[str, str] = {
"simple": "simple",
"normal": "normal",
"karras": "karras",
"sgm uniform": "sgm_uniform",
"sgm_uniform": "sgm_uniform",
"beta": "beta",
"ddim_uniform": "ddim_uniform",
"exponential": "exponential",
}
def _validate_name(name: str, kind: str = "template") -> None:
if not name or not _NAME_RE.match(name):
raise ValueError(f"Invalid {kind} name {name!r}: only letters, digits, '.', '_', '-' allowed")
def _root() -> Path:
"""Return the live TEMPLATES_DIR (allows tests to monkeypatch the module attr)."""
root: Path = globals()["TEMPLATES_DIR"]
return root
def template_dir_for(model_stem: str) -> Path:
"""Return the per-model template directory (without ensuring existence)."""
_validate_name(model_stem, "model")
return _root() / model_stem
def template_path(model_stem: str, name: str) -> Path:
"""Return the on-disk path for a template (without ensuring existence)."""
_validate_name(name)
return template_dir_for(model_stem) / f"{name}.json"
def save_template(model_stem: str, name: str, data: dict[str, Any]) -> Path:
"""Persist a template dict to disk as JSON and return its path."""
path = template_path(model_stem, name)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(data, indent=2, ensure_ascii=False) + "\n")
return path
def load_template(model_stem: str, name: str) -> dict[str, Any]:
"""Load a template dict. Raises FileNotFoundError if missing."""
path = template_path(model_stem, name)
if not path.is_file():
raise FileNotFoundError(f"Template {model_stem}/{name!r} not found at {path}")
data: dict[str, Any] = json.loads(path.read_text())
return data
def list_templates(model_stem: str | None = None) -> list[tuple[str, str]]:
"""List saved templates as (model_stem, template_name) pairs, sorted.
With ``model_stem`` set, restrict to that one model's subdirectory.
"""
root = _root()
if not root.is_dir():
return []
if model_stem is not None:
_validate_name(model_stem, "model")
sub = root / model_stem
if not sub.is_dir():
return []
return [(model_stem, p.stem) for p in sorted(sub.glob("*.json")) if p.is_file()]
out: list[tuple[str, str]] = []
for d in sorted(p for p in root.iterdir() if p.is_dir()):
for p in sorted(d.glob("*.json")):
if p.is_file():
out.append((d.name, p.stem))
return out
def delete_template(model_stem: str, name: str) -> bool:
"""Delete a template file. Returns True on success, False if missing."""
path = template_path(model_stem, name)
if not path.is_file():
return False
path.unlink()
return True
def param_from_civitai_meta(meta: dict[str, Any]) -> dict[str, Any]:
"""Extract tsr-canonical generation params from a CivitAI image meta dict.
Returns only keys that were present in the input and successfully converted.
Unknown sampler / scheduler labels are still surfaced (with whitespace →
underscores) rather than silently dropped — the calling layer can decide
whether to use or ignore them.
"""
out: dict[str, Any] = {}
for src_key, (dst_key, converter) in META_KEY_MAP.items():
if src_key not in meta:
continue
raw = meta[src_key]
try:
if converter == "sampler":
normalized = str(raw).strip().lower()
out[dst_key] = SAMPLER_NORMALIZE.get(normalized, normalized.replace(" ", "_"))
elif converter == "scheduler":
normalized = str(raw).strip().lower()
out[dst_key] = SCHEDULER_NORMALIZE.get(normalized, normalized.replace(" ", "_"))
else:
if isinstance(raw, str):
# CivitAI emits "Distilled CFG Scale" as a string with comma
# or dot decimal separator depending on locale; normalize.
raw = raw.strip().replace(",", ".")
out[dst_key] = converter(raw)
except (ValueError, TypeError):
# Conversion failed for this image; skip the key, keep going.
continue
return out
def build_template(
*,
model_filename: str,
family: str | None,
defaults: dict[str, Any],
base_model_str: str | None,
width: int,
height: int,
orientation: str,
scene_elements: list[str],
scene_name: str,
overrides: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Assemble a complete generation template dict.
Shape mirrors ``tsr template -m <model>`` output exactly so the result is
a drop-in for ``tsr generate --input``. Showcase-derived ``overrides`` win
over family-resolved ``defaults``.
"""
tpl: dict[str, Any] = {
"prompt": "",
"negative_prompt": defaults.get("negative_prompt", ""),
"model": model_filename,
"width": width,
"height": height,
"steps": defaults.get("steps"),
"cfg": defaults.get("cfg"),
"sampler": defaults.get("sampler"),
"scheduler": defaults.get("scheduler"),
"vae": defaults.get("vae"),
"orientation": orientation,
"seed": -1,
"count": 1,
}
# Flux models carry an explicit guidance dial; default to tsr's own 3.5
# when no override supplies one.
if (family or "").startswith("flux"):
tpl["guidance"] = 3.5
quality_prefix = defaults.get("quality_prefix", "")
if quality_prefix:
tpl["quality_prefix"] = quality_prefix
# Showcase-derived overrides win over family defaults.
if overrides:
for k, v in overrides.items():
tpl[k] = v
tpl["scene"] = scene_elements
tpl["_scene_name"] = scene_name
tpl["_family"] = family or "unknown"
if base_model_str:
tpl["_base_model"] = base_model_str
return tpl
def derive_overrides_from_images(images: list[dict[str, Any]]) -> dict[str, Any]:
"""Mode-of-each-param across showcase images that carry generation meta.
Returns a dict of ``{tsr_key: most_common_value}`` suitable for merging
on top of a base template. Skips images without a ``meta`` dict and any
image meta whose param-extraction yielded nothing.
"""
counters: dict[str, Counter[Any]] = {}
for img in images:
meta = img.get("meta") or {}
if not meta:
continue
params = param_from_civitai_meta(meta)
for k, v in params.items():
counters.setdefault(k, Counter())[v] += 1
overrides: dict[str, Any] = {}
for k, ctr in counters.items():
if ctr:
overrides[k] = ctr.most_common(1)[0][0]
return overrides