From aa25d31ca941d64016040b23d44d2b7a571c3dbb Mon Sep 17 00:00:00 2001 From: aladac Date: Mon, 18 May 2026 21:04:21 +0200 Subject: [PATCH] feat(templates): add bulk template extraction from CivitAI showcase MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the `tsr templates` subapp with extract / list / show / delete commands. The headline command is: tsr templates extract [--generate] which pulls a model's CivitAI showcase, deduplicates prompts, derives recommended generation params (sampler / scheduler / steps / cfg / guidance) from the *mode* of the showcase image metadata, and writes one JSON template per unique prompt to ~/.local/share/tensors/templates//.json. Each emitted template has the same shape as `tsr template -m ` output and feeds straight into `tsr generate --input`. `--generate` chains the generation step end-to-end. New module `tensors/templates.py` carries the storage + extraction helpers (`save_template`, `load_template`, `list_templates`, `build_template`, `param_from_civitai_meta`, `derive_overrides_from_images`) plus the CivitAI A1111 → ComfyUI sampler/scheduler name normalization tables. Replaces the external generate_templates.py wrapper script that was maintaining a hand-curated MODEL_OVERRIDES dict — overrides now derive automatically from each model's own showcase data. --- tensors/cli.py | 256 +++++++++++++++++++++++++++++++++++++++ tensors/templates.py | 281 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 537 insertions(+) create mode 100644 tensors/templates.py diff --git a/tensors/cli.py b/tensors/cli.py index a342d19..81215c6 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -2909,6 +2909,262 @@ def scene_delete( raise typer.Exit(1) +# ---- templates ---- + +templates_app = typer.Typer( + name="templates", + help="Bulk-extract, list, and run generation templates derived from CivitAI showcase data.", + no_args_is_help=True, +) +app.add_typer(templates_app) + + +@templates_app.command("extract") +def templates_extract( + model: Annotated[str, typer.Argument(help="Local model name (e.g. lust_v10.safetensors)")], + orientation: Annotated[ + str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape") + ] = "portrait", + no_overrides: Annotated[ + bool, + typer.Option( + "--no-overrides", + help="Skip auto-derived params from showcase image meta; use family defaults only", + ), + ] = False, + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, + limit: Annotated[ + int, typer.Option("--limit", "-L", help="Max templates to write (0 = all unique prompts)") + ] = 0, + overwrite: Annotated[ + bool, typer.Option("--overwrite", help="Overwrite existing template files (default: skip)") + ] = False, + do_generate: Annotated[ + bool, + typer.Option("--generate", help="After writing, run `tsr generate --input` for each emitted template"), + ] = False, + output_dir: Annotated[ + Path | None, + typer.Option( + "--output-dir", help="Where to write generated images when --generate (default: ComfyUI output dir)" + ), + ] = None, + dry_run: Annotated[ + bool, typer.Option("--dry-run", help="Print what would be done; write nothing") + ] = False, +) -> None: + """Bulk-extract templates from a model's CivitAI showcase. + + Pulls showcase images, deduplicates prompts, derives recommended generation + params (sampler / scheduler / steps / cfg / guidance) from the *mode* of the + showcase image metadata, and writes one JSON template per unique prompt to + ``~/.local/share/tensors/templates//.json``. + + Each emitted template is ready to feed straight to ``tsr generate --input``. + + Examples: + tsr templates extract lust_v10.safetensors + tsr templates extract bodySliderFitness_v10 -O portrait --generate + tsr templates extract ultrasenseInfinity_v10 --dry-run + tsr templates extract getphat_v5 --no-overrides # use tsr family defaults only + """ + import subprocess # noqa: PLC0415 + + from tensors.api import fetch_civitai_model_version # noqa: PLC0415 + from tensors.config import ( # noqa: PLC0415 + detect_model_family, + get_model_generation_defaults, + load_api_key, + resolve_orientation, + ) + from tensors.fragments import parse_elements # noqa: PLC0415 + from tensors.templates import ( # noqa: PLC0415 + build_template, + derive_overrides_from_images, + save_template, + template_path, + ) + + with Database() as db: + files = db.list_local_files() + + target_file = None + for f in files: + file_path = Path(f["file_path"]) + if model in (file_path.name, file_path.stem): + target_file = f + break + + if not target_file: + console.print(f"[red]Model '{model}' not found in local database. Run 'tsr db scan' first.[/red]") + raise typer.Exit(1) + + vid = target_file["civitai_version_id"] + if not vid: + console.print(f"[red]Model '{model}' is not linked to CivitAI. Run 'tsr db link' first.[/red]") + raise typer.Exit(1) + + model_stem = Path(target_file["file_path"]).stem + model_filename = Path(target_file["file_path"]).name + base_model_str = target_file.get("base_model") + + console.print(f"[cyan]Fetching showcase for {model_stem} (version {vid})...[/cyan]") + data = fetch_civitai_model_version(vid, api_key or load_api_key(), console=console) + if not data: + console.print("[red]Failed to fetch CivitAI data.[/red]") + raise typer.Exit(1) + + images = data.get("images", []) + enriched = sum(1 for img in images if img.get("meta")) + + overrides = {} if no_overrides else derive_overrides_from_images(images) + if overrides: + console.print(f"[cyan]Derived overrides from {enriched} enriched image(s):[/cyan] {overrides}") + elif not no_overrides: + console.print("[yellow]No usable param meta in showcase; using family defaults only.[/yellow]") + + family = detect_model_family(model_filename, base_model_str) + defaults = get_model_generation_defaults(model_filename, base_model_str) + res_w, res_h = resolve_orientation(family, orientation) + + seen_prompts: set[str] = set() + emitted: list[Path] = [] + skipped_existing = 0 + skipped_no_prompt = 0 + + for img in images: + meta = img.get("meta") or {} + prompt = meta.get("prompt") + if not prompt: + skipped_no_prompt += 1 + continue + normalized = prompt.lower().strip() + if normalized in seen_prompts: + continue + seen_prompts.add(normalized) + + scene_elements = parse_elements(prompt) + if not scene_elements: + continue + + idx = len(emitted) + skipped_existing + 1 + name = f"{model_stem}_{idx:02d}" + out_path = template_path(model_stem, name) + + if out_path.is_file() and not overwrite: + console.print(f"[yellow]Skip (exists, use --overwrite to replace):[/yellow] {out_path}") + skipped_existing += 1 + continue + + tpl = build_template( + model_filename=model_filename, + family=family, + defaults=defaults, + base_model_str=base_model_str, + width=res_w, + height=res_h, + orientation=orientation, + scene_elements=scene_elements, + scene_name=name, + overrides=overrides, + ) + + if dry_run: + console.print(f"[dim](dry-run) Would write:[/dim] {out_path}") + emitted.append(out_path) + else: + saved = save_template(model_stem, name, tpl) + console.print(f"[green]Saved ({len(scene_elements)} scene elements):[/green] {saved}") + emitted.append(saved) + + if limit and len(emitted) >= limit: + break + + console.print( + f"\n[bold]Extract summary:[/bold] emitted={len(emitted)} skipped_existing={skipped_existing} " + f"images_no_prompt={skipped_no_prompt}" + ) + + if not emitted: + return + + if do_generate: + if dry_run: + console.print("[yellow]--dry-run is set; skipping --generate phase.[/yellow]") + return + console.print(f"\n[cyan]Generating images for {len(emitted)} template(s)...[/cyan]") + for tpl_path in emitted: + out_arg = [] + if output_dir: + output_dir.mkdir(parents=True, exist_ok=True) + out_arg = ["-o", str(output_dir / f"{tpl_path.stem}.png")] + cmd = ["tsr", "generate", "--input", str(tpl_path), *out_arg] + console.print(f"\n[cyan]$ {' '.join(cmd)}[/cyan]") + subprocess.run(cmd, check=False) + + +@templates_app.command("list") +def templates_list( + model: Annotated[str | None, typer.Argument(help="Filter by model stem (optional)")] = None, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """List saved templates, grouped by model.""" + from tensors.templates import TEMPLATES_DIR, list_templates # noqa: PLC0415 + + items = list_templates(model) + if json_output: + console.print_json( + data={"dir": str(TEMPLATES_DIR), "templates": [{"model": m, "name": n} for m, n in items]} + ) + return + if not items: + scope = f" for model '{model}'" if model else "" + console.print(f"[yellow]No templates{scope} in {TEMPLATES_DIR}.[/yellow]") + console.print("[dim]Create some with: tsr templates extract [/dim]") + return + cur_model = None + for m, n in items: + if m != cur_model: + console.print(f"\n[cyan]{m}[/cyan]") + cur_model = m + console.print(f" {n}") + + +@templates_app.command("show") +def templates_show( + model: Annotated[str, typer.Argument(help="Model stem (directory name under templates/)")], + name: Annotated[str, typer.Argument(help="Template name (filename without .json)")], +) -> None: + """Print a saved template as JSON.""" + from tensors.templates import load_template # noqa: PLC0415 + + try: + data = load_template(model, name) + except FileNotFoundError as e: + console.print(f"[red]{e}[/red]") + raise typer.Exit(1) from e + console.print_json(data=data) + + +@templates_app.command("delete") +def templates_delete( + model: Annotated[str, typer.Argument(help="Model stem")], + name: Annotated[str, typer.Argument(help="Template name")], + yes: Annotated[bool, typer.Option("--yes", "-y", help="Skip confirmation")] = False, +) -> None: + """Delete a saved template.""" + from tensors.templates import delete_template, template_path # noqa: PLC0415 + + path = template_path(model, name) + if not path.is_file(): + console.print(f"[yellow]Template not found: {path}[/yellow]") + raise typer.Exit(1) + if not yes: + typer.confirm(f"Delete {path}?", abort=True) + if delete_template(model, name): + console.print(f"[green]Deleted:[/green] {path}") + + # ============================================================================= # ComfyUI Commands # ============================================================================= diff --git a/tensors/templates.py b/tensors/templates.py new file mode 100644 index 0000000..4759d50 --- /dev/null +++ b/tensors/templates.py @@ -0,0 +1,281 @@ +"""Template library: full generation configs derived from models + scenes. + +A *template* is a complete `tsr generate --input` payload for a specific model +and prompt: dimensions, sampler, scheduler, steps, cfg, guidance, vae, and the +scene/character lists. Templates extend scenes (which are just lists of prompt +elements) with all the family-resolved generation params, so they can be fed +straight to `tsr generate` without re-resolving anything. + +Each template lives at +``~/.local/share/tensors/templates//.json`` +and uses the same JSON shape as ``tsr template -m `` standalone output, +so the on-disk format and the ad-hoc one-shot template format are identical. + +Module-level :data:`TEMPLATES_DIR` is read via :func:`globals` on every call so +tests can monkeypatch it without re-importing. +""" + +from __future__ import annotations + +import json +import re +from collections import Counter +from pathlib import Path # noqa: TC003 # used in runtime return annotations exposed to typer +from typing import Any + +from tensors.config import DATA_DIR + +__all__ = [ + "META_KEY_MAP", + "SAMPLER_NORMALIZE", + "SCHEDULER_NORMALIZE", + "TEMPLATES_DIR", + "build_template", + "delete_template", + "derive_overrides_from_images", + "list_templates", + "load_template", + "param_from_civitai_meta", + "save_template", + "template_dir_for", + "template_path", +] + +# Default storage location. Tests may monkeypatch this; every helper below +# dereferences via globals() so overrides are picked up live. +TEMPLATES_DIR = DATA_DIR / "templates" + +# Restrict template + model names to the same safe subset used by FragmentLibrary +# so they can't escape the storage dir via path traversal. +_NAME_RE = re.compile(r"^[A-Za-z0-9_.-]+$") + +# CivitAI A1111-style image meta → tsr template key mapping. +# Each entry maps a source key in the `meta` dict of a CivitAI image to a +# (tsr_key, converter) pair. The converter is either a callable (applied to the +# raw value) or a literal sentinel string ("sampler" / "scheduler") that +# triggers the corresponding normalize-and-translate path below. +META_KEY_MAP: dict[str, tuple[str, Any]] = { + "sampler": ("sampler", "sampler"), + "Schedule type": ("scheduler", "scheduler"), + "steps": ("steps", int), + "cfgScale": ("cfg", float), + "Distilled CFG Scale": ("guidance", float), +} + +# A1111 / CivitAI sampler labels → ComfyUI canonical sampler names. +# Lookup is case-folded; unknown labels fall through with whitespace replaced +# by underscores (so e.g. "DPM++ 2M Karras" we don't know about still passes +# through as "dpm++_2m_karras" — wrong but loud). +SAMPLER_NORMALIZE: dict[str, str] = { + "euler": "euler", + "euler a": "euler_ancestral", + "euler ancestral": "euler_ancestral", + "dpm++ 2m": "dpmpp_2m", + "dpm++ 2m karras": "dpmpp_2m", + "dpm++ 2m sde": "dpmpp_2m_sde", + "dpm++ 2m sde karras": "dpmpp_2m_sde", + "dpm++ 3m sde": "dpmpp_3m_sde", + "dpm++ 3m sde karras": "dpmpp_3m_sde", + "dpm++ sde": "dpmpp_sde", + "dpm++ sde karras": "dpmpp_sde", + "dpm++ 2s a": "dpmpp_2s_ancestral", + "dpm++ 2s ancestral": "dpmpp_2s_ancestral", + "heun": "heun", + "ddim": "ddim", + "lms": "lms", + "unipc": "uni_pc", + "uni_pc": "uni_pc", + "lcm": "lcm", + "dpmpp_2m": "dpmpp_2m", + "dpmpp_2m_sde": "dpmpp_2m_sde", +} + +SCHEDULER_NORMALIZE: dict[str, str] = { + "simple": "simple", + "normal": "normal", + "karras": "karras", + "sgm uniform": "sgm_uniform", + "sgm_uniform": "sgm_uniform", + "beta": "beta", + "ddim_uniform": "ddim_uniform", + "exponential": "exponential", +} + + +def _validate_name(name: str, kind: str = "template") -> None: + if not name or not _NAME_RE.match(name): + raise ValueError(f"Invalid {kind} name {name!r}: only letters, digits, '.', '_', '-' allowed") + + +def _root() -> Path: + """Return the live TEMPLATES_DIR (allows tests to monkeypatch the module attr).""" + root: Path = globals()["TEMPLATES_DIR"] + return root + + +def template_dir_for(model_stem: str) -> Path: + """Return the per-model template directory (without ensuring existence).""" + _validate_name(model_stem, "model") + return _root() / model_stem + + +def template_path(model_stem: str, name: str) -> Path: + """Return the on-disk path for a template (without ensuring existence).""" + _validate_name(name) + return template_dir_for(model_stem) / f"{name}.json" + + +def save_template(model_stem: str, name: str, data: dict[str, Any]) -> Path: + """Persist a template dict to disk as JSON and return its path.""" + path = template_path(model_stem, name) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(data, indent=2, ensure_ascii=False) + "\n") + return path + + +def load_template(model_stem: str, name: str) -> dict[str, Any]: + """Load a template dict. Raises FileNotFoundError if missing.""" + path = template_path(model_stem, name) + if not path.is_file(): + raise FileNotFoundError(f"Template {model_stem}/{name!r} not found at {path}") + data: dict[str, Any] = json.loads(path.read_text()) + return data + + +def list_templates(model_stem: str | None = None) -> list[tuple[str, str]]: + """List saved templates as (model_stem, template_name) pairs, sorted. + + With ``model_stem`` set, restrict to that one model's subdirectory. + """ + root = _root() + if not root.is_dir(): + return [] + if model_stem is not None: + _validate_name(model_stem, "model") + sub = root / model_stem + if not sub.is_dir(): + return [] + return [(model_stem, p.stem) for p in sorted(sub.glob("*.json")) if p.is_file()] + out: list[tuple[str, str]] = [] + for d in sorted(p for p in root.iterdir() if p.is_dir()): + for p in sorted(d.glob("*.json")): + if p.is_file(): + out.append((d.name, p.stem)) + return out + + +def delete_template(model_stem: str, name: str) -> bool: + """Delete a template file. Returns True on success, False if missing.""" + path = template_path(model_stem, name) + if not path.is_file(): + return False + path.unlink() + return True + + +def param_from_civitai_meta(meta: dict[str, Any]) -> dict[str, Any]: + """Extract tsr-canonical generation params from a CivitAI image meta dict. + + Returns only keys that were present in the input and successfully converted. + Unknown sampler / scheduler labels are still surfaced (with whitespace → + underscores) rather than silently dropped — the calling layer can decide + whether to use or ignore them. + """ + out: dict[str, Any] = {} + for src_key, (dst_key, converter) in META_KEY_MAP.items(): + if src_key not in meta: + continue + raw = meta[src_key] + try: + if converter == "sampler": + normalized = str(raw).strip().lower() + out[dst_key] = SAMPLER_NORMALIZE.get(normalized, normalized.replace(" ", "_")) + elif converter == "scheduler": + normalized = str(raw).strip().lower() + out[dst_key] = SCHEDULER_NORMALIZE.get(normalized, normalized.replace(" ", "_")) + else: + if isinstance(raw, str): + # CivitAI emits "Distilled CFG Scale" as a string with comma + # or dot decimal separator depending on locale; normalize. + raw = raw.strip().replace(",", ".") + out[dst_key] = converter(raw) + except (ValueError, TypeError): + # Conversion failed for this image; skip the key, keep going. + continue + return out + + +def build_template( + *, + model_filename: str, + family: str | None, + defaults: dict[str, Any], + base_model_str: str | None, + width: int, + height: int, + orientation: str, + scene_elements: list[str], + scene_name: str, + overrides: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Assemble a complete generation template dict. + + Shape mirrors ``tsr template -m `` output exactly so the result is + a drop-in for ``tsr generate --input``. Showcase-derived ``overrides`` win + over family-resolved ``defaults``. + """ + tpl: dict[str, Any] = { + "prompt": "", + "negative_prompt": defaults.get("negative_prompt", ""), + "model": model_filename, + "width": width, + "height": height, + "steps": defaults.get("steps"), + "cfg": defaults.get("cfg"), + "sampler": defaults.get("sampler"), + "scheduler": defaults.get("scheduler"), + "vae": defaults.get("vae"), + "orientation": orientation, + "seed": -1, + "count": 1, + } + # Flux models carry an explicit guidance dial; default to tsr's own 3.5 + # when no override supplies one. + if (family or "").startswith("flux"): + tpl["guidance"] = 3.5 + quality_prefix = defaults.get("quality_prefix", "") + if quality_prefix: + tpl["quality_prefix"] = quality_prefix + # Showcase-derived overrides win over family defaults. + if overrides: + for k, v in overrides.items(): + tpl[k] = v + tpl["scene"] = scene_elements + tpl["_scene_name"] = scene_name + tpl["_family"] = family or "unknown" + if base_model_str: + tpl["_base_model"] = base_model_str + return tpl + + +def derive_overrides_from_images(images: list[dict[str, Any]]) -> dict[str, Any]: + """Mode-of-each-param across showcase images that carry generation meta. + + Returns a dict of ``{tsr_key: most_common_value}`` suitable for merging + on top of a base template. Skips images without a ``meta`` dict and any + image meta whose param-extraction yielded nothing. + """ + counters: dict[str, Counter[Any]] = {} + for img in images: + meta = img.get("meta") or {} + if not meta: + continue + params = param_from_civitai_meta(meta) + for k, v in params.items(): + counters.setdefault(k, Counter())[v] += 1 + + overrides: dict[str, Any] = {} + for k, ctr in counters.items(): + if ctr: + overrides[k] = ctr.most_common(1)[0][0] + return overrides