"""CLI application and commands for tsr.""" from __future__ import annotations import json import sys from importlib.metadata import version from pathlib import Path from typing import Annotated, Any import click import typer from rich.console import Console from rich.table import Table from tensors.api import ( download_model, fetch_civitai_by_hash, fetch_civitai_model, fetch_civitai_model_version, search_civitai, ) from tensors.config import ( COMFYUI_DEFAULT_CFG, COMFYUI_DEFAULT_HEIGHT, COMFYUI_DEFAULT_SAMPLER, COMFYUI_DEFAULT_SCHEDULER, COMFYUI_DEFAULT_STEPS, COMFYUI_DEFAULT_WIDTH, CONFIG_FILE, MODEL_FAMILY_DEFAULTS, VALID_PATH_TYPES, BaseModel, CommercialUse, ModelType, NsfwLevel, Period, Provider, SortOrder, detect_model_family, get_default_output_path, get_model_paths, load_api_key, load_config, save_config, ) from tensors.db import DB_PATH, Database from tensors.display import ( _format_size, display_civitai_data, display_file_info, display_hf_model_info, display_hf_search_results, display_local_metadata, display_model_info, display_search_results, ) from tensors.hf import ( download_all_safetensors, download_hf_safetensor, get_hf_model, list_safetensor_files, search_hf_models, ) from tensors.remote import ( remote_download, remote_download_status, remote_generate, remote_get_image, remote_models, remote_search, ) from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata # Key masking threshold MIN_KEY_LENGTH_FOR_MASKING = 8 # Display truncation limits MAX_QUEUE_DISPLAY = 10 MAX_MODEL_LIST_DISPLAY = 20 MAX_PROMPT_ID_DISPLAY = 36 # File extensions that force YAML parsing for `tsr generate --input `. _YAML_INPUT_EXTENSIONS = frozenset({".yml", ".yaml"}) def _parse_generate_input(value: str) -> dict[str, Any]: """Parse a ``--input`` argument into a dict of generation params. Accepts either: * a path to a ``.json`` / ``.yml`` / ``.yaml`` file, * a raw JSON object string (``{"prompt": ...}``), * a raw YAML document string (anything else that doesn't start with ``{``). File extension wins when reading from disk. For inline strings we try JSON first (current behaviour) and fall back to YAML so existing callers keep working without surprises. Raises ``typer.Exit(1)`` with a rich error on every failure path so callers don't need to repeat the diagnostics. """ import yaml # noqa: PLC0415 — keep yaml a soft import path # ---- locate source text + decide format ---- path = Path(value) if path.is_file(): text = path.read_text() suffix = path.suffix.lower() if suffix in _YAML_INPUT_EXTENSIONS: fmt = "yaml" elif suffix == ".json": fmt = "json" else: # Unknown extension: sniff the content. Leading '{' or '[' → JSON. fmt = "json" if text.lstrip().startswith(("{", "[")) else "yaml" elif value.lstrip().startswith("{"): text = value fmt = "json" else: # Last-resort: treat as inline YAML. This is unusual but lets the user # pass ``--input 'prompt: foo\nmodel: bar.safetensors'`` without quoting # a JSON object on the shell. text = value fmt = "yaml" # ---- parse ---- parsed: Any try: parsed = json.loads(text) if fmt == "json" else yaml.safe_load(text) except json.JSONDecodeError as e: console.print(f"[red]Invalid JSON input:[/red] {e}") raise typer.Exit(1) from e except yaml.YAMLError as e: console.print(f"[red]Invalid YAML input:[/red] {e}") raise typer.Exit(1) from e if not isinstance(parsed, dict): console.print(f"[red]{fmt.upper()} input must be a mapping/object[/red]") raise typer.Exit(1) return parsed def _cache_model_quietly(model_data: dict[str, Any]) -> None: """Cache model data to database without output.""" try: with Database() as db: db.init_schema() db.cache_model(model_data) except Exception: pass # Silently ignore cache failures def _cache_models_quietly(models: list[dict[str, Any]]) -> None: """Cache multiple models to database without output.""" if not models: return try: with Database() as db: db.init_schema() for model_data in models: db.cache_model(model_data) except Exception: pass # Silently ignore cache failures def _version_callback(value: bool) -> None: if value: print(f"tsr {version('tensors')}") raise typer.Exit app = typer.Typer( name="tsr", help="Read safetensor metadata, search and download CivitAI models.", no_args_is_help=True, ) @app.callback() def _main( _version: Annotated[ bool, typer.Option("--version", "-V", callback=_version_callback, is_eager=True, help="Show version and exit."), ] = False, ) -> None: """Read safetensor metadata, search and download CivitAI models.""" console = Console() @app.command() def info( file: Annotated[Path, typer.Argument(help="Path to the safetensor file")], meta: Annotated[list[str] | None, typer.Option("--meta", "-m", help="Show specific metadata key(s) in full")] = None, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, skip_civitai: Annotated[bool, typer.Option("--skip-civitai", help="Skip CivitAI API lookup")] = False, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, save_to: Annotated[Path | None, typer.Option("--save-to", help="Save metadata to directory")] = None, ) -> None: """Read safetensor metadata and fetch CivitAI info.""" file_path = file.resolve() if not file_path.exists(): console.print(f"[red]Error: File not found: {file_path}[/red]") raise typer.Exit(1) if file_path.suffix.lower() not in (".safetensors", ".sft"): console.print("[yellow]Warning: File does not have .safetensors extension[/yellow]") try: local_metadata = read_safetensor_metadata(file_path) if meta: display_local_metadata(local_metadata, console, keys_filter=meta) return console.print(f"[bold]Reading safetensor file:[/bold] {file_path.name}") sha256_hash = compute_sha256(file_path, console) civitai_data = None if not skip_civitai: key = api_key or load_api_key() civitai_data = fetch_civitai_by_hash(sha256_hash, key, console) if json_output: _output_info_json(file_path, sha256_hash, local_metadata, civitai_data) else: display_file_info(file_path, local_metadata, sha256_hash, console) display_local_metadata(local_metadata, console) display_civitai_data(civitai_data, console) if save_to: _save_metadata(save_to, file_path, sha256_hash, local_metadata, civitai_data) except ValueError as e: console.print(f"[red]Error reading safetensor: {e}[/red]") raise typer.Exit(1) from e def _output_info_json( file_path: Path, sha256_hash: str, local_metadata: dict[str, Any], civitai_data: dict[str, Any] | None, ) -> None: """Output info command result as JSON.""" output = { "file": str(file_path), "sha256": sha256_hash, "header_size": local_metadata["header_size"], "tensor_count": local_metadata["tensor_count"], "metadata": local_metadata["metadata"], "civitai": civitai_data, } console.print_json(data=output) def _save_metadata( save_to: Path, file_path: Path, sha256_hash: str, local_metadata: dict[str, Any], civitai_data: dict[str, Any] | None, ) -> None: """Save metadata to directory.""" output_dir = save_to.resolve() if not output_dir.exists() or not output_dir.is_dir(): console.print(f"[red]Error: Invalid directory: {output_dir}[/red]") raise typer.Exit(1) base_name = get_base_name(file_path) json_path = output_dir / f"{base_name}.json" sha_path = output_dir / f"{base_name}.sha256" output = { "file": str(file_path), "sha256": sha256_hash, "header_size": local_metadata["header_size"], "tensor_count": local_metadata["tensor_count"], "metadata": local_metadata["metadata"], "civitai": civitai_data, } json_path.write_text(json.dumps(output, indent=2)) sha_path.write_text(f"{sha256_hash} {file_path.name}\n") console.print() console.print(f"[green]Saved:[/green] {json_path}") console.print(f"[green]Saved:[/green] {sha_path}") @app.command() def search( query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None, provider: Annotated[Provider, typer.Option("--provider", "-P", help="Search provider")] = Provider.all, model_type: Annotated[ModelType | None, typer.Option("-t", "--type", help="Model type filter (CivitAI)")] = None, base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter (CivitAI)")] = None, sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads, limit: Annotated[int, typer.Option("-n", "--limit", help="Max results per provider")] = 20, period: Annotated[Period | None, typer.Option("-p", "--period", help="Time period (CivitAI)")] = None, tag: Annotated[str | None, typer.Option("--tag", help="Filter by tag")] = None, username: Annotated[str | None, typer.Option("-u", "--user", "-a", "--author", help="Filter by creator/author")] = None, page: Annotated[int | None, typer.Option("--page", help="Page number (CivitAI)")] = None, nsfw: Annotated[NsfwLevel | None, typer.Option("--nsfw", help="NSFW filter level (CivitAI)")] = None, sfw: Annotated[bool, typer.Option("--sfw", help="Exclude NSFW content (CivitAI)")] = False, commercial: Annotated[CommercialUse | None, typer.Option("--commercial", help="Commercial use filter (CivitAI)")] = None, pipeline: Annotated[str | None, typer.Option("--pipeline", help="Pipeline tag (HuggingFace)")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, ) -> None: """Search models on CivitAI and/or Hugging Face. Examples: tsr search "flux" # Search both providers tsr search "anime" -P civitai # CivitAI only tsr search "llama" -P hf # Hugging Face only tsr search -t lora -b pony # CivitAI LoRAs for Pony tsr search -a stabilityai -P hf # HF by author tsr search --sfw -P civitai # CivitAI SFW only tsr search "pony" --remote junkpile # Search via remote server """ # Remote mode: delegate to remote tensors server if remote: civitai_results = remote_search( remote, query=query, model_type=model_type.to_api() if model_type else None, base_model=base.to_api() if base else None, sort=sort.value, limit=limit, page=page, nsfw=nsfw.value if nsfw else None, sfw=sfw, console=console, ) if not civitai_results: console.print("[red]Remote search failed.[/red]") raise typer.Exit(1) if json_output: console.print_json(data={"civitai": civitai_results}) else: display_search_results(civitai_results, console) return key = api_key or load_api_key() civitai_results: dict[str, Any] | None = None hf_results: list[dict[str, Any]] | None = None # Search CivitAI if provider in (Provider.civitai, Provider.all): nsfw_filter: NsfwLevel | bool | None = NsfwLevel.none if sfw else nsfw civitai_results = search_civitai( query=query, model_type=model_type, base_model=base, sort=sort, limit=limit, api_key=key, console=console if provider == Provider.civitai else None, period=period, nsfw=nsfw_filter, tag=tag, username=username, page=page, commercial_use=commercial, ) if civitai_results: _cache_models_quietly(civitai_results.get("items", [])) # Search Hugging Face if provider in (Provider.hf, Provider.all): tags = [tag] if tag else None hf_results = search_hf_models( query=query, author=username, tags=tags, pipeline_tag=pipeline, sort="downloads" if sort == SortOrder.downloads else "likes" if sort == SortOrder.rating else "created_at", limit=limit, console=console if provider == Provider.hf else None, ) # Output results if json_output: output: dict[str, Any] = {} if civitai_results: output["civitai"] = civitai_results if hf_results: output["huggingface"] = hf_results console.print_json(data=output) return # Display based on provider if provider == Provider.civitai: if not civitai_results: console.print("[red]CivitAI search failed.[/red]") raise typer.Exit(1) display_search_results(civitai_results, console) elif provider == Provider.hf: if hf_results is None: console.print("[red]Hugging Face search failed.[/red]") raise typer.Exit(1) display_hf_search_results(hf_results, console) else: # Both providers if civitai_results and civitai_results.get("items"): console.print("\n[bold cyan]═══ CivitAI Results ═══[/bold cyan]") display_search_results(civitai_results, console) if hf_results: console.print("\n[bold cyan]═══ Hugging Face Results ═══[/bold cyan]") display_hf_search_results(hf_results, console) if not (civitai_results and civitai_results.get("items")) and not hf_results: console.print("[yellow]No results found on either provider.[/yellow]") @app.command() def get( id_value: Annotated[int, typer.Argument(help="CivitAI model ID or version ID")], version: Annotated[bool, typer.Option("-v", "--version", help="Treat ID as version ID instead of model ID")] = False, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, no_cache: Annotated[bool, typer.Option("--no-cache", help="Don't cache to local database")] = False, ) -> None: """Fetch model information from CivitAI by model ID or version ID.""" key = api_key or load_api_key() if version: version_data = fetch_civitai_model_version(id_value, key, console) if not version_data: console.print(f"[red]Error: Version {id_value} not found on CivitAI.[/red]") raise typer.Exit(1) # Auto-cache version data (need to fetch full model for complete cache) if not no_cache: model_id = version_data.get("modelId") if model_id: model_data = fetch_civitai_model(model_id, key) if model_data: _cache_model_quietly(model_data) if json_output: console.print_json(data=version_data) else: display_civitai_data(version_data, console) else: model_data = fetch_civitai_model(id_value, key, console) if not model_data: console.print(f"[red]Error: Model {id_value} not found on CivitAI.[/red]") raise typer.Exit(1) # Auto-cache model data if not no_cache: _cache_model_quietly(model_data) if json_output: console.print_json(data=model_data) else: display_model_info(model_data, console) def _resolve_by_hash(hash_val: str, api_key: str | None) -> int | None: """Resolve version ID from SHA256 hash.""" console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]") civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key, console) if not civitai_data: console.print("[red]Error: Model not found on CivitAI for this hash.[/red]") return None vid: int | None = civitai_data.get("id") if vid: console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}") return vid def _resolve_by_model_id(model_id: int, api_key: str | None) -> int | None: """Resolve latest version ID from model ID.""" console.print(f"[cyan]Looking up model {model_id}...[/cyan]") model_data = fetch_civitai_model(model_id, api_key, console) if not model_data: console.print(f"[red]Error: Model {model_id} not found.[/red]") return None versions = model_data.get("modelVersions", []) if not versions: console.print("[red]Error: Model has no versions.[/red]") return None latest = versions[0] latest_vid: int | None = latest.get("id") if latest_vid: console.print(f"[green]Found latest:[/green] {latest.get('name', 'N/A')} (ID: {latest_vid})") return latest_vid def _resolve_version_id( version_id: int | None, hash_val: str | None, model_id: int | None, api_key: str | None, ) -> int | None: """Resolve version ID from direct ID, hash, or model ID.""" if version_id: return version_id if hash_val: return _resolve_by_hash(hash_val, api_key) if model_id: return _resolve_by_model_id(model_id, api_key) return None def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Path | None: """Prepare output directory for download.""" if output is None: output_dir = get_default_output_path(model_type_str) if output_dir is None: console.print(f"[red]Error: No default path for type '{model_type_str}'. Use --output to specify.[/red]") return None console.print(f"[dim]Using default path for {model_type_str}: {output_dir}[/dim]") else: output_dir = output.resolve() if not output_dir.exists(): console.print(f"[cyan]Creating directory: {output_dir}[/cyan]") output_dir.mkdir(parents=True, exist_ok=True) elif not output_dir.is_dir(): console.print(f"[red]Error: Not a directory: {output_dir}[/red]") return None return output_dir def _poll_remote_download(remote_name: str, download_id: str) -> None: """Poll a remote download for completion with a progress bar.""" import time # noqa: PLC0415 from rich.progress import BarColumn, DownloadColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn # noqa: PLC0415 status: dict[str, Any] | None = None with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), DownloadColumn(), console=console, ) as progress: task = progress.add_task("[cyan]Downloading...", total=100) while True: status = remote_download_status(remote_name, download_id) if not status: break dl_status = status.get("status", "") pct = status.get("progress", 0) progress.update(task, completed=pct, description=f"[cyan]{dl_status.title()}...") if dl_status in ("completed", "failed"): break time.sleep(1) if status and status.get("status") == "completed": console.print(f"[green]Download complete:[/green] {status.get('path', 'N/A')}") elif status and status.get("status") == "failed": console.print(f"[red]Download failed:[/red] {status.get('error', 'Unknown error')}") def _download_remote( remote_name: str, version_id: int | None, model_id: int | None, hash_val: str | None, output: Path | None, ) -> None: """Handle remote download flow.""" if not version_id and not model_id: if hash_val: console.print("[yellow]Remote download does not support --hash. Use --version-id or --model-id.[/yellow]") else: console.print("[red]Error: Must specify --version-id or --model-id for remote download[/red]") raise typer.Exit(1) console.print("[dim]Starting download on remote server...[/dim]") result = remote_download( remote_name, version_id=version_id, model_id=model_id, output_dir=str(output) if output else None, console=console, ) if not result: raise typer.Exit(1) console.print(f"[green]Download started:[/green] {result.get('model_name', 'N/A')}") console.print(f"[dim]Version: {result.get('version_name', 'N/A')}[/dim]") console.print(f"[dim]Destination: {result.get('destination', 'N/A')}[/dim]") download_id = result.get("download_id") if download_id: _poll_remote_download(remote_name, download_id) @app.command("dl") def download( version_id: Annotated[int | None, typer.Option("-v", "--version-id", help="Model version ID")] = None, model_id: Annotated[int | None, typer.Option("-m", "--model-id", help="Model ID (downloads latest)")] = None, hash_val: Annotated[str | None, typer.Option("-H", "--hash", help="SHA256 hash to look up")] = None, output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None, no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, ) -> None: """Download a model from CivitAI. When --remote is specified, the download happens on the remote server. Examples: tsr dl -v 12345 # Download by version ID tsr dl -m 67890 # Download latest version of model tsr dl -v 12345 --remote junkpile # Download on remote server """ if remote: _download_remote(remote, version_id, model_id, hash_val, output) return key = api_key or load_api_key() resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key) if not resolved_version_id: if not version_id and not hash_val and not model_id: console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") raise typer.Exit(1) console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]") version_info = fetch_civitai_model_version(resolved_version_id, key, console) if not version_info: console.print("[red]Error: Could not fetch model version info.[/red]") raise typer.Exit(1) model_type_str: str | None = version_info.get("model", {}).get("type") output_dir = _prepare_download_dir(output, model_type_str) if not output_dir: raise typer.Exit(1) files: list[dict[str, Any]] = version_info.get("files", []) primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) if not primary_file: console.print("[red]Error: No files found for this version.[/red]") raise typer.Exit(1) filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors") dest_path = output_dir / filename _display_download_info(version_info, filename, primary_file, dest_path) success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume) if not success: raise typer.Exit(1) # Add downloaded file to database and link to CivitAI _add_downloaded_file_to_db(dest_path, version_info) def _add_downloaded_file_to_db(dest_path: Path, version_info: dict[str, Any]) -> None: """Add a downloaded file to the database, link to CivitAI, and cache full model data. Args: dest_path: Path to the downloaded file version_info: CivitAI version info response """ console.print("[dim]Adding to database...[/dim]") api_key = load_api_key() with Database() as db: db.init_schema() result = db.register_downloaded_file(dest_path, version_info, api_key=api_key, console=console) if result["error"]: console.print(f"[yellow]Warning: Could not add to database: {result['error']}[/yellow]") return console.print(f"[green]Added to database (id={result['file_id']})[/green]") if result["linked"]: civitai_version_id = version_info.get("id") civitai_model_id = version_info.get("modelId") or version_info.get("model", {}).get("id") console.print(f"[green]Linked to CivitAI model={civitai_model_id} version={civitai_version_id}[/green]") if result["cached"]: console.print("[green]Cached model metadata[/green]") def _display_download_info( version_info: dict[str, Any], filename: str, primary_file: dict[str, Any], dest_path: Path, ) -> None: """Display download info table.""" table = Table(title="Model Download", show_header=True, header_style="bold magenta") table.add_column("Property", style="cyan") table.add_column("Value", style="green") table.add_row("Version", version_info.get("name", "N/A")) table.add_row("Base Model", version_info.get("baseModel", "N/A")) table.add_row("File", filename) table.add_row("Size", _format_size(primary_file.get("sizeKB", 0))) table.add_row("Destination", str(dest_path)) console.print() console.print(table) console.print() @app.command() def config( show: Annotated[bool, typer.Option("--show", help="Show current config")] = False, set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None, set_path: Annotated[str | None, typer.Option("--set-path", help="Set model path (TYPE=PATH)")] = None, ) -> None: """Manage configuration.""" if set_key: cfg = load_config() if "api" not in cfg: cfg["api"] = {} cfg["api"]["civitai_key"] = set_key save_config(cfg) console.print(f"[green]API key saved to {CONFIG_FILE}[/green]") return if set_path: # Parse TYPE=PATH format if "=" not in set_path: console.print("[red]Error: Use format TYPE=PATH (e.g., checkpoints=/opt/models/checkpoints)[/red]") raise typer.Exit(1) path_type, path_value = set_path.split("=", 1) path_type = path_type.lower().strip() if path_type not in VALID_PATH_TYPES: console.print(f"[red]Error: Invalid type '{path_type}'. Valid: {', '.join(VALID_PATH_TYPES)}[/red]") raise typer.Exit(1) cfg = load_config() if "paths" not in cfg: cfg["paths"] = {} cfg["paths"][path_type] = path_value.strip() save_config(cfg) console.print(f"[green]Path for {path_type} set to: {path_value}[/green]") return if show or (not set_key and not set_path): console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}") console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}") key = load_api_key() if key: masked = key[:4] + "..." + key[-4:] if len(key) > MIN_KEY_LENGTH_FOR_MASKING else "***" console.print(f"[bold]API key:[/bold] {masked}") else: console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]") console.print() console.print("[bold]Model paths:[/bold]") paths = get_model_paths() # Group by unique paths to show cleanly shown_paths: dict[str, list[str]] = {} for model_type, path in paths.items(): path_str = str(path) if path_str not in shown_paths: shown_paths[path_str] = [] shown_paths[path_str].append(model_type) cfg = load_config() configured_paths = cfg.get("paths", {}) for path_str, types in sorted(shown_paths.items(), key=lambda x: x[0]): is_custom = any(path_str == configured_paths.get(k) for k in VALID_PATH_TYPES) marker = " [green](custom)[/green]" if is_custom else " [dim](default)[/dim]" console.print(f" {', '.join(sorted(types))}: {path_str}{marker}") console.print() console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]") console.print("[dim]Set paths with: tsr config --set-path checkpoints=/path/to/models[/dim]") @app.command() def serve( host: Annotated[str, typer.Option(help="Listen address.")] = "127.0.0.1", port: Annotated[int, typer.Option(help="Listen port.")] = 51200, log_level: Annotated[str, typer.Option(help="Log level.")] = "info", ) -> None: """Start the tensors server (gallery and CivitAI management).""" try: import uvicorn # noqa: PLC0415 from tensors.server import create_app # noqa: PLC0415 except ImportError: console.print("[red]Missing server dependencies. Install with:[/red]") console.print(" pip install tensors[server]") raise typer.Exit(1) from None uvicorn.run(create_app(), host=host, port=port, log_level=log_level) # ============================================================================= # Top-Level Generate Command # ============================================================================= @app.command(context_settings={"allow_extra_args": False}) def generate( # noqa: PLR0915 ctx: typer.Context, prompt: Annotated[str | None, typer.Argument(help="Positive prompt text", show_default=False)] = None, model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None, width: Annotated[int | None, typer.Option("-W", "--width", help="Image width (auto from checkpoint)")] = None, height: Annotated[int | None, typer.Option("-H", "--height", help="Image height (auto from checkpoint)")] = None, steps: Annotated[int | None, typer.Option("--steps", help="Sampling steps (auto from checkpoint)")] = None, cfg: Annotated[float | None, typer.Option("--cfg", help="CFG scale (auto from checkpoint)")] = None, guidance: Annotated[ float | None, typer.Option( "--guidance", "-g", help="FluxGuidance value (Flux only; default 3.5). Ignored for non-Flux models.", ), ] = None, seed: Annotated[int, typer.Option("--seed", "-s", help="Random seed (-1 for random)")] = -1, sampler: Annotated[str | None, typer.Option("--sampler", help="Sampler name (auto from checkpoint)")] = None, scheduler: Annotated[str | None, typer.Option("--scheduler", help="Scheduler name (auto from checkpoint)")] = None, vae: Annotated[str | None, typer.Option("--vae", help="VAE model name (auto from checkpoint)")] = None, orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "square", lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None, lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8, negative: Annotated[str, typer.Option("-n", "--negative-prompt", help="Negative prompt")] = "", count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1, rating: Annotated[ str | None, typer.Option("--rating", "-R", help="Content rating: safe, questionable, explicit (Pony/Illustrious)") ] = None, no_quality: Annotated[bool, typer.Option("--no-quality", help="Disable auto quality tags")] = False, no_negative: Annotated[bool, typer.Option("--no-negative", help="Disable auto negative prompt")] = False, character: Annotated[ str | None, typer.Option("-C", "--character", help="Saved character name (loaded from ~/.local/share/tensors/characters/)"), ] = None, character_prompt: Annotated[ str | None, typer.Option("--character-prompt", help='Inline character fragment, comma-separated (e.g. "blond hair, blue eyes")'), ] = None, scene: Annotated[ str | None, typer.Option("-S", "--scene", help="Saved scene name (loaded from ~/.local/share/tensors/scenes/)"), ] = None, scene_prompt: Annotated[ str | None, typer.Option( "--scene-prompt", help='Inline scene fragment, comma-separated (e.g. "luxury penthouse, volumetric lighting")' ), ] = None, family: Annotated[ str | None, typer.Option( "--family", "-F", help=( "Override detected model family " "(pony, illustrious, sdxl, sdxl_lightning, sdxl_turbo, " "sd15, sd15_lcm, flux, flux_schnell, flux_unet, flux2_klein, zimage)" ), ), ] = None, output: Annotated[Path | None, typer.Option("-o", "--output", help="Save path (default: current dir)")] = None, remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, json_input: Annotated[ str | None, typer.Option("--input", "-I", help="JSON or YAML params (file path or inline; keys match CLI options)"), ] = None, parallel_queue: Annotated[ int, typer.Option( "--parallel-queue", "-P", help=( "Concurrent ComfyUI submissions (default 1). When >1 with --count N, " "splits the request into N independent jobs (batch_size=1 each) with " "incrementing seeds, executed P-at-a-time via thread pool. The GPU " "still processes one prompt at a time, but HTTP queue / init / " "download phases pipeline for a ~5-15%% speedup. Per-task output " "interleaves; final summary lists all saved files. Ignored when " "--count is 1." ), ), ] = 1, ) -> None: """Generate an image using text-to-image. Auto-detects optimal sampler, scheduler, CFG, resolution, and VAE from the checkpoint model family. All auto-detected values can be overridden with explicit flags. Calls ComfyUI directly when local, or the remote tensors API when --remote is given. Accepts --input with a JSON or YAML object whose keys match CLI option names. Files ending in ``.yml`` / ``.yaml`` are parsed as YAML; ``.json`` (or any other extension whose contents start with ``{``/``[``) as JSON. Inline strings starting with ``{`` are JSON, everything else is YAML. CLI flags override --input values. With --count > 1, images are generated as a single ComfyUI batch by default (one workflow, sequential on GPU). Use --parallel-queue N to instead split into N independent batch_size=1 jobs queued in parallel, each with its own seed — useful for overlapping the HTTP/download phase across requests. Examples: tsr generate "a cat on a windowsill" tsr generate "portrait photo" -m ponyDiffusionV6XL_v6.safetensors -O portrait tsr generate "cyberpunk city" -o output.png --count 4 tsr generate "landscape" --remote junkpile tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors"}' tsr generate --input scene.yml tsr generate "raw prompt" --no-quality --no-negative tsr generate "city" -c 8 -P 4 -o out.png # 8 distinct seeds, 4 in flight """ if parallel_queue < 1: console.print("[red]--parallel-queue must be >= 1[/red]") raise typer.Exit(1) if parallel_queue > 1 and json_output: # _run_generation short-circuits the disk-save when json_output=True # (it dumps JSON and returns). For the parallel fanout to actually save # files, each task must take the non-JSON path. We render our own JSON # at the end, so the per-task --json is incompatible. console.print( "[red]--json is not supported with --parallel-queue > 1 (would skip the file-save step). Drop one or the other.[/red]" ) raise typer.Exit(1) # ---- --input merging (JSON or YAML) ---- if json_input is not None: ji = _parse_generate_input(json_input) # Map source keys to parameter names (handle aliases) key_map = {"negative_prompt": "negative", "lora_name": "lora"} mapped: dict[str, Any] = {} for k, v in ji.items(): mapped[key_map.get(k, k)] = v # Determine which CLI params the user explicitly set click_ctx = ctx._context if hasattr(ctx, "_context") else ctx explicit = ( { p.name for p in click_ctx.command.params if click_ctx.get_parameter_source(p.name) == click.core.ParameterSource.COMMANDLINE } if hasattr(click_ctx, "get_parameter_source") else set() ) # Apply JSON values for anything not explicitly set on CLI if "prompt" in mapped and ("prompt" not in explicit and prompt is None): prompt = mapped["prompt"] if "model" in mapped and "model" not in explicit: model = mapped["model"] if "width" in mapped and "width" not in explicit: width = int(mapped["width"]) if "height" in mapped and "height" not in explicit: height = int(mapped["height"]) if "steps" in mapped and "steps" not in explicit: steps = int(mapped["steps"]) if "cfg" in mapped and "cfg" not in explicit: cfg = float(mapped["cfg"]) if "seed" in mapped and "seed" not in explicit: seed = int(mapped["seed"]) if "sampler" in mapped and "sampler" not in explicit: sampler = mapped["sampler"] if "scheduler" in mapped and "scheduler" not in explicit: scheduler = mapped["scheduler"] if "vae" in mapped and "vae" not in explicit: vae = mapped["vae"] if "lora" in mapped and "lora" not in explicit: lora = mapped["lora"] if "lora_strength" in mapped and "lora_strength" not in explicit: lora_strength = float(mapped["lora_strength"]) if "negative" in mapped and "negative" not in explicit: negative = mapped["negative"] if "output" in mapped and "output" not in explicit: output = Path(mapped["output"]) if "remote" in mapped and "remote" not in explicit: remote = mapped["remote"] if "count" in mapped and "count" not in explicit: count = int(mapped["count"]) if "orientation" in mapped and "orientation" not in explicit: orientation = mapped["orientation"] if "no_quality" in mapped and "no_quality" not in explicit: no_quality = bool(mapped["no_quality"]) if "no_negative" in mapped and "no_negative" not in explicit: no_negative = bool(mapped["no_negative"]) if "character" in mapped and "character" not in explicit: # Accept either a saved-name string or an already-resolved list/tuple # (templates may carry the resolved list inline). For lists we stage # them into character_prompt by joining with commas so the existing # CLI splitting/dedup path applies uniformly. val = mapped["character"] if isinstance(val, str): character = val elif isinstance(val, (list, tuple)): character_prompt = ", ".join(str(x) for x in val if str(x).strip()) if "character_prompt" in mapped and "character_prompt" not in explicit: cp_val = mapped["character_prompt"] character_prompt = cp_val if isinstance(cp_val, str) else ", ".join(str(x) for x in cp_val if str(x).strip()) if "scene" in mapped and "scene" not in explicit: sv = mapped["scene"] if isinstance(sv, str): scene = sv elif isinstance(sv, (list, tuple)): scene_prompt = ", ".join(str(x) for x in sv if str(x).strip()) if "scene_prompt" in mapped and "scene_prompt" not in explicit: sp_val = mapped["scene_prompt"] scene_prompt = sp_val if isinstance(sp_val, str) else ", ".join(str(x) for x in sp_val if str(x).strip()) if "rating" in mapped and "rating" not in explicit: rating = mapped["rating"] if "parallel_queue" in mapped and "parallel_queue" not in explicit: parallel_queue = int(mapped["parallel_queue"]) has_content = bool(prompt or character or character_prompt or scene or scene_prompt) if not has_content: console.print("[red]Prompt (or character/scene) is required[/red]") raise typer.Exit(1) # Effective parallelism is bounded by count — running 4 threads for 1 image # is silly. count=1 always goes through the sequential path regardless of -P. effective_parallel = min(parallel_queue, count) if count > 1 else 1 if effective_parallel <= 1: # Sequential path: single _run_generation call with batch_size=count. # Unchanged from pre-parallel behavior — preserves existing output naming, # JSON shape, and log lines exactly. _run_generation( prompt=prompt, model=model, width=width, height=height, steps=steps, cfg=cfg, guidance=guidance, seed=seed, sampler=sampler, scheduler=scheduler, vae=vae, orientation=orientation, lora=lora, lora_strength=lora_strength, negative=negative, count=count, rating=rating, no_quality=no_quality, no_negative=no_negative, character=character, character_prompt=character_prompt, scene=scene, scene_prompt=scene_prompt, family=family, output=output, remote=remote, json_output=json_output, ) return # ---- Parallel fanout path ---- # Split count into `count` independent jobs (batch_size=1), executed # `effective_parallel` at a time. Each job gets a distinct seed and a # distinct output path so writes don't clobber each other. import random as _rng # noqa: PLC0415 import time as _time # noqa: PLC0415 from concurrent.futures import ThreadPoolExecutor, as_completed # noqa: PLC0415 # Resolve bare model/lora names ONCE in the parent before fanout. Each # parallel _run_generation call silences its own console (json_output=True) # which also skips the validation/resolution step in that path. Doing it # here means each task receives a canonical filename and ComfyUI's strict # loaders accept the request first try. if model and not remote: # Detect family for the right loader bucket (checkpoints vs diffusion_models). # Mirrors the lookup _run_generation does on entry. from tensors.db import Database # noqa: PLC0415 _base_model: str | None = None try: with Database() as _db: _db.init_schema() _base_model = _db.get_base_model_by_filename(model) except Exception: pass _detected = detect_model_family(model, _base_model) _fam = family or _detected try: model, lora = _validate_model_available(model, _fam, lora) except typer.Exit: raise # surface the same error path as sequential # Seed strategy: # --seed >= 0 → use as base, increment per job (reproducible series) # --seed == -1 → pick a fresh random seed PER JOB so parallel runs aren't # accidentally correlated (each thread gets variety) seeds = [seed + i for i in range(count)] if seed >= 0 else [_rng.randint(0, 2**32 - 1) for _ in range(count)] # Output paths: mirror the existing `count > 1` naming convention from # _run_generation (stem_NNN.ext). When --output is omitted, leave per-task # output as None — _run_generation will skip the disk write and the user # gets only the console listing of generated image refs. out_paths: list[Path | None] = [] for i in range(count): if output is None: out_paths.append(None) else: out_paths.append(output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}") if not json_output: console.print( f"[dim]Parallel queue: {effective_parallel} concurrent submissions x {count} images (output may interleave)[/dim]" ) common_kwargs: dict[str, Any] = { "prompt": prompt, "model": model, "width": width, "height": height, "steps": steps, "cfg": cfg, "guidance": guidance, "sampler": sampler, "scheduler": scheduler, "vae": vae, "orientation": orientation, "lora": lora, "lora_strength": lora_strength, "negative": negative, "count": 1, # each task generates exactly one image "rating": rating, "no_quality": no_quality, "no_negative": no_negative, "character": character, "character_prompt": character_prompt, "scene": scene, "scene_prompt": scene_prompt, "family": family, "remote": remote, # NOTE: json_output stays False so _run_generation's disk-save path runs. # Setting True would short-circuit before saving files. Per-task console # chatter is the trade-off; the final summary still shows clean per-task # status lines. "json_output": False, } def _run_one(idx: int) -> dict[str, Any]: """Run a single batch_size=1 job. Returns a result dict (success captured).""" start = _time.perf_counter() result: dict[str, Any] = { "index": idx, "seed": seeds[idx], "output": str(out_paths[idx]) if out_paths[idx] is not None else None, "duration_sec": 0.0, "success": False, "error": None, } try: _run_generation(seed=seeds[idx], output=out_paths[idx], **common_kwargs) result["duration_sec"] = round(_time.perf_counter() - start, 2) result["success"] = True except typer.Exit as ex: result["duration_sec"] = round(_time.perf_counter() - start, 2) result["error"] = f"generate exited with code {ex.exit_code}" except Exception as ex: result["duration_sec"] = round(_time.perf_counter() - start, 2) result["error"] = str(ex) return result fan_results: list[dict[str, Any]] = [] with ThreadPoolExecutor(max_workers=effective_parallel) as pool: futures = {pool.submit(_run_one, i): i for i in range(count)} for completed, fut in enumerate(as_completed(futures), start=1): try: res = fut.result() except Exception as ex: # Defensive — _run_one already swallows, but if the executor itself # raises (e.g. pickling failure) we still want a well-formed result # in the manifest rather than a crash. res = { "index": futures[fut], "seed": seeds[futures[fut]], "output": str(out_paths[futures[fut]]) if out_paths[futures[fut]] is not None else None, "duration_sec": 0.0, "success": False, "error": f"executor exception: {ex}", } fan_results.append(res) if not json_output: if res["success"]: where = res["output"] or "(no --output set)" console.print( f"[green]\\[{completed}/{count}] seed={res['seed']} ok in {res['duration_sec']:.1f}s → {where}[/green]" ) else: console.print(f"[red]\\[{completed}/{count}] seed={res['seed']} FAIL: {res['error']}[/red]") # Reorder by original index so JSON output / final summary list is stable. fan_results.sort(key=lambda r: r["index"]) successful = sum(1 for r in fan_results if r["success"]) if json_output: console.print_json( data={ "success": successful == count, "count": count, "parallel_queue": effective_parallel, "results": fan_results, } ) return console.print("[bold green]Generation complete![/bold green]") console.print(f"[dim]Generated {successful}/{count} images at parallelism={effective_parallel}[/dim]") if successful < count: raise typer.Exit(1) # Map model family → which ComfyUI loader directory the checkpoint must live in. # Used by _validate_model_available() to query the right slot from get_loaded_models(). _FAMILY_TO_LOADER_BUCKET: dict[str, str] = { "flux_unet": "diffusion_models", "flux2_klein": "diffusion_models", } # Extensions tried in order when the user passes a bare name (no suffix) and # the exact lookup misses. Safetensors first since that's the modern default. _MODEL_EXTENSIONS: tuple[str, ...] = (".safetensors", ".ckpt", ".gguf", ".pt", ".bin") def _resolve_with_extension(name: str, available: list[str]) -> str | None: """If `name` is missing an extension but exactly one suffixed variant exists in `available`, return that variant. Otherwise return None. Examples: _resolve_with_extension("lust_v10", ["lust_v10.safetensors", ...]) -> "lust_v10.safetensors" _resolve_with_extension("lust_v10", ["lust_v10.safetensors", "lust_v10.ckpt"]) -> None # ambiguous _resolve_with_extension("lust_v10.safetensors", [...]) -> None # already has ext """ if "." in name: # user already provided an extension; don't second-guess return None candidates = [f"{name}{ext}" for ext in _MODEL_EXTENSIONS if f"{name}{ext}" in available] if len(candidates) == 1: return candidates[0] return None # zero matches → real miss; multiple matches → ambiguous, force user to disambiguate def _validate_model_available(model: str, family: str | None, lora: str | None) -> tuple[str, str | None]: """Verify model + LoRA exist on the live ComfyUI host before queueing. Returns (resolved_model, resolved_lora) so the caller can substitute the canonical filename when the user passed a bare name without extension (e.g. -m lust_v10 → lust_v10.safetensors). Otherwise returns the inputs unchanged. Fails fast with typer.Exit(1) and a "did you mean" suggestion when the requested file isn't loaded. Bucket lookup respects family: - flux_unet / flux2_klein → diffusion_models/ (UNETLoader) - everything else → checkpoints/ (CheckpointLoaderSimple) Network failures are non-fatal — we'd rather forward to ComfyUI and let its 400 surface than block on a stale comfyui endpoint. """ from difflib import get_close_matches # noqa: PLC0415 from tensors.comfyui import get_loaded_models # noqa: PLC0415 try: loaded = get_loaded_models(console=None) except Exception: return model, lora # network down — let ComfyUI itself handle it if not loaded: return model, lora bucket = _FAMILY_TO_LOADER_BUCKET.get(family or "", "checkpoints") available = loaded.get(bucket, []) if model not in available: # Try implicit-extension resolution before failing: bare names like # `lust_v10` should silently resolve to `lust_v10.safetensors` when # there's an unambiguous match. resolved = _resolve_with_extension(model, available) if resolved is not None: console.print(f"[dim]Resolved '{model}' → '{resolved}'[/dim]") model = resolved else: console.print(f"[red]Model '{model}' not available on ComfyUI host[/red]") console.print(f"[dim](looked in {bucket}/ — {len(available)} entries)[/dim]") matches = get_close_matches(model, available, n=3, cutoff=0.5) if matches: console.print("[yellow]Did you mean:[/yellow]") for m in matches: console.print(f" [cyan]{m}[/cyan]") else: console.print(f"[dim]Run `tsr models` to see what's installed in {bucket}/.[/dim]") # Suggest symlink fix if the file exists in checkpoints/ but family wants diffusion_models/ if bucket == "diffusion_models" and ( model in loaded.get("checkpoints", []) or _resolve_with_extension(model, loaded.get("checkpoints", [])) is not None ): console.print( f"[yellow]Hint:[/yellow] '{model}' is in checkpoints/ but UNet-only " "checkpoints need to be in diffusion_models/. On the ComfyUI host: " f"[cyan]ln -s ../checkpoints/{model} " f"/models/diffusion_models/{model}[/cyan]" ) raise typer.Exit(1) if lora is not None: loras_available = loaded.get("loras", []) if lora not in loras_available: resolved_lora = _resolve_with_extension(lora, loras_available) if resolved_lora is not None: console.print(f"[dim]Resolved LoRA '{lora}' → '{resolved_lora}'[/dim]") lora = resolved_lora else: console.print(f"[red]LoRA '{lora}' not available on ComfyUI host[/red]") matches = get_close_matches(lora, loras_available, n=3, cutoff=0.5) if matches: console.print("[yellow]Did you mean:[/yellow]") for m in matches: console.print(f" [cyan]{m}[/cyan]") raise typer.Exit(1) return model, lora def _run_generation( # noqa: PLR0915 *, prompt: str | None = None, model: str | None = None, width: int | None = None, height: int | None = None, steps: int | None = None, cfg: float | None = None, guidance: float | None = None, seed: int = -1, sampler: str | None = None, scheduler: str | None = None, vae: str | None = None, orientation: str = "square", lora: str | None = None, lora_strength: float = 0.8, negative: str = "", count: int = 1, rating: str | None = None, no_quality: bool = False, no_negative: bool = False, character: str | None = None, character_prompt: str | None = None, scene: str | None = None, scene_prompt: str | None = None, family: str | None = None, output: Path | None = None, remote: str | None = None, json_output: bool = False, ) -> None: """Core generation routine shared by `generate` and `style-sweep`. All parameters are fully resolved (no CLI/JSON merging here). Raises typer.Exit on failure. Prints to console unless json_output is True (then prints JSON). """ import random as rng # noqa: PLC0415 # ---- Detect model family and enhance prompt/negative ---- family_defaults: dict[str, Any] = {} model_family: str | None = None base_model_str: str | None = None if model: try: with Database() as db: db.init_schema() base_model_str = db.get_base_model_by_filename(model) except Exception: pass detected_family = detect_model_family(model, base_model_str) model_family = family or detected_family if model_family: family_defaults = MODEL_FAMILY_DEFAULTS.get(model_family, {}) if not json_output: if family and detected_family and family != detected_family: console.print(f"[dim]Model family: {model_family} (override; detected: {detected_family})[/dim]") elif family: console.print(f"[dim]Model family: {model_family} (override)[/dim]") else: console.print(f"[dim]Detected model family: {model_family}[/dim]") # ---- Validate the requested model exists on the target host ---- # Catches mismatches between local intent ("v5Hardcore") and what's actually # available remotely ("v11Softcore"), and offers a fuzzy "did you mean" hint # instead of forwarding the request to ComfyUI for a generic 400 rejection. # Skipped in --json mode and for remote dispatches (server already validates). if model and not json_output and not remote: # Returns possibly-rewritten names so bare inputs like `-m lust_v10` # silently resolve to the canonical `lust_v10.safetensors` filename # before being forwarded to ComfyUI's strict CLIPLoader / UNETLoader. model, lora = _validate_model_available(model, model_family, lora) # Build enhanced prompt with quality prefix (no automatic LoRA trigger injection) prompt_parts: list[str] = [] # Add quality prefix based on model family if not no_quality and family_defaults.get("quality_prefix"): prompt_parts.append(family_defaults["quality_prefix"]) # Resolve character (named lookup + inline --character-prompt, merged + deduped) character_elements: list[str] = [] if character or character_prompt: from tensors.characters import resolve_character # noqa: PLC0415 try: character_elements = resolve_character(character=character, character_prompt=character_prompt) except FileNotFoundError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e except ValueError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e if character_elements: prompt_parts.extend(character_elements) if not json_output: origin = f"'{character}'" if character else "inline" console.print( f"[dim]Character ({origin}, {len(character_elements)} elements): {', '.join(character_elements)}[/dim]" ) # Resolve scene (named lookup + inline --scene-prompt, merged + deduped). # Scene sits between character (who) and rating/user prompt (what's happening) # so the natural reading order is: quality → character → scene → rating → user. scene_elements: list[str] = [] if scene or scene_prompt: from tensors.scenes import resolve_scene # noqa: PLC0415 try: scene_elements = resolve_scene(scene=scene, scene_prompt=scene_prompt) except FileNotFoundError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e except ValueError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e if scene_elements: prompt_parts.extend(scene_elements) if not json_output: origin = f"'{scene}'" if scene else "inline" console.print(f"[dim]Scene ({origin}, {len(scene_elements)} elements): {', '.join(scene_elements)}[/dim]") # Add rating tag based on model family (Pony/Illustrious) if rating: from tensors.config import get_rating_tag # noqa: PLC0415 rating_tag = get_rating_tag(model_family, rating.lower()) if rating_tag: prompt_parts.append(rating_tag) if not json_output: console.print(f"[dim]Rating tag: {rating_tag}[/dim]") elif not json_output: console.print(f"[dim]Rating '{rating}' not applicable for {model_family or 'unknown'} family[/dim]") # Add user prompt if prompt: prompt_parts.append(prompt) enhanced_prompt = ", ".join(prompt_parts) if prompt_parts else "" # Build enhanced negative prompt enhanced_negative = negative if not no_negative and family_defaults.get("negative_prompt"): family_negative = family_defaults["negative_prompt"] enhanced_negative = f"{negative}, {family_negative}" if negative else family_negative if not json_output and (enhanced_prompt != prompt or enhanced_negative != negative): if enhanced_prompt != prompt: truncated = enhanced_prompt[:100] + "..." if len(enhanced_prompt) > 100 else enhanced_prompt # noqa: PLR2004 console.print(f"[dim]Enhanced prompt: {truncated}[/dim]") if enhanced_negative != negative: truncated = enhanced_negative[:80] + "..." if len(enhanced_negative) > 80 else enhanced_negative # noqa: PLR2004 console.print(f"[dim]Enhanced negative: {truncated}[/dim]") # ---- Resolve preset defaults for None params (both remote and local need these) ---- from tensors.config import resolve_orientation # noqa: PLC0415 from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415 # Use already-detected family_defaults from DB lookup above (not filename guessing) if family_defaults: res_w, res_h = resolve_orientation(model_family, orientation) if width is None: width = res_w if height is None: height = res_h if steps is None: steps = family_defaults.get("steps", 20) if cfg is None: cfg = family_defaults.get("cfg", 7.0) if sampler is None: sampler = family_defaults.get("sampler", "euler") if scheduler is None: scheduler = family_defaults.get("scheduler", "normal") if vae is None: vae = family_defaults.get("vae") # Fallback to global defaults when no model family was detected if width is None: width = COMFYUI_DEFAULT_WIDTH if height is None: height = COMFYUI_DEFAULT_HEIGHT if steps is None: steps = COMFYUI_DEFAULT_STEPS if cfg is None: cfg = COMFYUI_DEFAULT_CFG if sampler is None: sampler = COMFYUI_DEFAULT_SAMPLER if scheduler is None: scheduler = COMFYUI_DEFAULT_SCHEDULER # ---- Determine base seed ---- base_seed = seed if seed >= 0 else rng.randint(0, 2**32 - 1) # Resolve remote (explicit flag, or default from config) remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None) all_results: list[dict[str, Any]] = [] all_saved: list[Path] = [] if remote_url: # ---- Remote mode: HTTP call to tensors server ---- if not json_output: console.print(f"[dim]Remote: {remote_url}[/dim]") result = remote_generate( remote or remote_url, enhanced_prompt, negative_prompt=enhanced_negative, model=model, width=width, height=height, steps=steps, cfg=cfg, seed=base_seed, sampler=sampler, scheduler=scheduler, vae=vae, lora_name=lora, lora_strength=lora_strength, guidance=guidance, console=console, ) if not result: if not json_output: console.print("[red]Generation failed[/red]") raise typer.Exit(1) if json_output: console.print_json(data=result) return if not result.get("success"): console.print("[red]Generation failed[/red]") errors = result.get("errors", {}) for node_id, err in errors.items(): console.print(f" [yellow]Node {node_id}:[/yellow] {err}") raise typer.Exit(1) images = result.get("images", []) console.print(f"[green]Generated {len(images)} image(s)[/green]") console.print(f"[dim]Prompt ID: {result.get('prompt_id', 'N/A')}[/dim]") # Download and save images if --output specified if output and images: for i, img_name in enumerate(images): img_data = remote_get_image(remote or remote_url, img_name) if img_data: save_path = output if len(images) == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}" save_path.write_bytes(img_data) console.print(f"[green]Saved:[/green] {save_path}") else: console.print(f"[yellow]Could not download image: {img_name}[/yellow]") elif images: for img_name in images: console.print(f" [dim]{img_name}[/dim]") else: # ---- Local mode: direct library call ---- from tensors.comfyui import generate_image, get_image # noqa: PLC0415 result_local = generate_image( prompt=enhanced_prompt, negative_prompt=enhanced_negative, model=model, width=width, height=height, steps=steps, cfg=cfg, seed=base_seed, sampler=sampler, scheduler=scheduler, console=console if not json_output else None, lora_name=lora, lora_strength=lora_strength, batch_size=count, vae=vae, orientation=orientation, guidance=guidance, ) if not result_local: if json_output: all_results.append({"success": False, "index": 0, "errors": {"generation": "Failed to generate"}}) else: console.print("[red]Generation failed[/red]") raise typer.Exit(1) elif not result_local.success: if json_output: all_results.append({"success": False, "index": 0, "errors": result_local.node_errors}) else: console.print("[red]Generation failed[/red]") for node_id, errors in result_local.node_errors.items(): console.print(f" [yellow]Node {node_id}:[/yellow] {errors}") raise typer.Exit(1) else: # Save all output images for i, img_path in enumerate(result_local.images): saved_path: Path | None = None if output: img_data = get_image(str(img_path)) if img_data: save_path = output if count == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}" save_path.write_bytes(img_data) saved_path = save_path all_saved.append(save_path) if not json_output: console.print(f"[green]Saved:[/green] {save_path}") elif not json_output: console.print(f"[yellow]Could not download image: {img_path}[/yellow]") all_results.append( { "success": True, "index": i, "prompt_id": result_local.prompt_id, "image": str(img_path), "saved": str(saved_path) if saved_path else None, } ) if json_output: console.print_json( data={ "success": all(r.get("success", False) for r in all_results), "count": len(all_results), "results": all_results, } ) return console.print("[bold green]Generation complete![/bold green]") if count > 1: successful = sum(1 for r in all_results if r.get("success", False)) console.print(f"[dim]Generated {successful}/{count} images[/dim]") if all_saved: console.print(f"[dim]Saved to: {all_saved[0].parent}/[/dim]") elif all_results and all_results[0].get("prompt_id"): console.print(f"[dim]Prompt ID: {all_results[0]['prompt_id']}[/dim]") # ============================================================================= # Style Sweep # ============================================================================= # Keys that style-sweep templates accept (mirror of `generate --input` keys, plus # two sweep-specific keys: output_dir and styles). _STYLE_SWEEP_TEMPLATE_KEYS = { "prompt", "model", "width", "height", "steps", "cfg", "guidance", "seed", "sampler", "scheduler", "vae", "lora", "lora_strength", "negative", "negative_prompt", "orientation", "no_quality", "no_negative", "character", "character_prompt", "scene", "scene_prompt", "rating", "family", "remote", # sweep-specific "output_dir", "styles", } def _load_json_file_or_inline(value: str | list | dict, *, what: str) -> Any: """Load JSON from a file path or accept already-parsed inline data. `value` may be a path string, a JSON string, or an already-parsed list/dict (e.g. when read out of a template). Raises typer.Exit on failure. """ if isinstance(value, (list, dict)): return value if not isinstance(value, str): console.print(f"[red]Invalid {what} value (expected path, JSON string, or inline data)[/red]") raise typer.Exit(1) path = Path(value) if path.is_file(): try: return json.loads(path.read_text()) except json.JSONDecodeError as e: console.print(f"[red]Invalid JSON in {what} file {path}:[/red] {e}") raise typer.Exit(1) from e stripped = value.lstrip() if stripped.startswith(("{", "[")): try: return json.loads(value) except json.JSONDecodeError as e: console.print(f"[red]Invalid inline JSON for {what}:[/red] {e}") raise typer.Exit(1) from e console.print(f"[red]{what.capitalize()} is neither a readable file nor inline JSON:[/red] {value}") raise typer.Exit(1) def _normalize_styles(styles_data: Any) -> list[dict[str, str]]: """Coerce styles data into a flat list of {slug, suffix} dicts.""" if isinstance(styles_data, dict): entries = styles_data.get("styles") if entries is None: console.print("[red]Styles object missing 'styles' key[/red]") raise typer.Exit(1) elif isinstance(styles_data, list): entries = styles_data else: console.print("[red]Styles data must be an object with 'styles' key or a list[/red]") raise typer.Exit(1) if not isinstance(entries, list) or not entries: console.print("[red]Styles list is empty or not a list[/red]") raise typer.Exit(1) normalized: list[dict[str, str]] = [] for i, entry in enumerate(entries): if not isinstance(entry, dict): console.print(f"[red]Style entry #{i} is not an object[/red]") raise typer.Exit(1) slug = entry.get("slug") suffix = entry.get("suffix") if not slug or not isinstance(slug, str): console.print(f"[red]Style entry #{i} missing/invalid 'slug'[/red]") raise typer.Exit(1) if suffix is None or not isinstance(suffix, str): console.print(f"[red]Style entry #{i} ({slug}) missing/invalid 'suffix'[/red]") raise typer.Exit(1) normalized.append({"slug": slug, "suffix": suffix}) return normalized @app.command(name="style-sweep") def style_sweep( # noqa: PLR0915 template: Annotated[ Path | None, typer.Option("--template", "-t", help="Path to template JSON (mirrors `generate --input` keys + output_dir/styles)"), ] = None, styles: Annotated[ str | None, typer.Option("--styles", help="Styles source: path to JSON or inline JSON list/object (overrides template's styles)"), ] = None, output_dir: Annotated[ Path | None, typer.Option("--output-dir", help="Override output directory from template"), ] = None, limit: Annotated[ int | None, typer.Option("--limit", help="Stop after N styles (applied after --style filter)"), ] = None, style_filter: Annotated[ list[str] | None, typer.Option("--style", "-S", help="Only run the named slug(s); repeatable for multiple"), ] = None, list_styles: Annotated[ bool, typer.Option("--list", "-L", help="Print resolved styles list and exit; no generation"), ] = False, skip_existing: Annotated[ bool, typer.Option("--skip-existing/--no-skip-existing", help="Skip styles whose output file already exists"), ] = True, dry_run: Annotated[ bool, typer.Option("--dry-run", help="Print planned prompts/paths without invoking generate"), ] = False, continue_on_error: Annotated[ bool, typer.Option("--continue-on-error/--abort-on-error", help="Keep going after individual style failures"), ] = True, remote: Annotated[ str | None, typer.Option("-r", "--remote", help="Remote server name or URL (overrides template)"), ] = None, parallel_queue: Annotated[ int, typer.Option( "--parallel-queue", "-P", help=( "Concurrent ComfyUI submissions (default 1). Values >1 submit N " "prompts to ComfyUI's HTTP queue in parallel; the GPU still " "processes one at a time, but HTTP/init/download overhead is " "pipelined for a ~5-15%% speedup. Per-task console output will " "interleave; use the manifest for accurate per-slug timing." ), ), ] = 1, ) -> None: """Sweep a base prompt across a list of style suffixes, one image per style. Loads a template JSON with the base prompt + generation params, plus a styles JSON listing {slug, suffix} entries. For each style, composes "{prompt}, {suffix}" and renders to {output_dir}/{slug}.png. Writes a manifest at {output_dir}/_sweep.json with per-style results. With --list, just prints the resolved styles list (template optional in that case if --styles is provided directly). Examples: tsr style-sweep --template woman-black-dress.json tsr style-sweep -t template.json --styles styles.json --limit 3 tsr style-sweep -t template.json --dry-run tsr style-sweep -t template.json --remote junkpile tsr style-sweep -t template.json --list tsr style-sweep --styles styles.json --list tsr style-sweep -t template.json -S 38-manara -S 40-elder-kurtzman tsr style-sweep -t template.json -P 4 # 4 concurrent submissions """ # ---- Validate required inputs ---- # Template is required for generation, but optional when --list is paired # with an explicit --styles source. if template is None and not (list_styles and styles is not None): console.print("[red]--template is required (or use --list with --styles to inspect a styles file)[/red]") raise typer.Exit(1) if parallel_queue < 1: console.print("[red]--parallel-queue must be >= 1[/red]") raise typer.Exit(1) # ---- Load template (if provided) ---- tpl_data: dict[str, Any] = {} if template is not None: if not template.is_file(): console.print(f"[red]Template file not found:[/red] {template}") raise typer.Exit(1) try: tpl_data = json.loads(template.read_text()) except json.JSONDecodeError as e: console.print(f"[red]Invalid JSON in template {template}:[/red] {e}") raise typer.Exit(1) from e if not isinstance(tpl_data, dict): console.print("[red]Template JSON must be an object[/red]") raise typer.Exit(1) # Warn on unknown keys (don't error — forward-compat) unknown = {k for k in tpl_data if not k.startswith("_") and k not in _STYLE_SWEEP_TEMPLATE_KEYS} if unknown: console.print(f"[yellow]Unknown template keys ignored:[/yellow] {sorted(unknown)}") # base_prompt is optional if character or scene fields are provided base_prompt = tpl_data.get("prompt") if template is not None else None has_content = bool( base_prompt or tpl_data.get("character") or tpl_data.get("character_prompt") or tpl_data.get("scene") or tpl_data.get("scene_prompt") ) if not list_styles and not has_content: console.print("[red]Template missing required 'prompt', 'character', or 'scene'[/red]") raise typer.Exit(1) # ---- Resolve styles source ---- # Relative paths inside the template are resolved against the template's # directory (so templates can ship next to their styles files). tpl_dir = template.resolve().parent if template is not None else None def _resolve_relative_to_template(val: str) -> str: if tpl_dir is None: return val p = Path(val) if not p.is_absolute() and not p.exists(): alt = tpl_dir / p if alt.exists(): return str(alt) return val styles_source: Any styles_origin: str if styles is not None: styles_origin = styles styles_source = _load_json_file_or_inline(styles, what="styles") elif "styles" in tpl_data: tpl_styles = tpl_data["styles"] if isinstance(tpl_styles, list): styles_origin = "" styles_source = tpl_styles else: resolved = _resolve_relative_to_template(tpl_styles) styles_origin = resolved styles_source = _load_json_file_or_inline(resolved, what="styles") else: console.print("[red]No styles specified (use --styles or set 'styles' in template)[/red]") raise typer.Exit(1) style_entries = _normalize_styles(styles_source) # ---- Apply --style filter (exact slug match) ---- if style_filter: available = [e["slug"] for e in style_entries] wanted = list(style_filter) unknown_slugs = [s for s in wanted if s not in available] if unknown_slugs: console.print(f"[red]Unknown style slug(s):[/red] {', '.join(unknown_slugs)}") console.print(f"[dim]Available slugs ({len(available)}):[/dim] {', '.join(available)}") raise typer.Exit(1) # Preserve order of the original styles list, but only keep wanted slugs wanted_set = set(wanted) style_entries = [e for e in style_entries if e["slug"] in wanted_set] # ---- Apply --limit (after filter) ---- if limit is not None: if limit < 0: console.print("[red]--limit must be >= 0[/red]") raise typer.Exit(1) style_entries = style_entries[:limit] # ---- --list short-circuit: print and exit ---- if list_styles: _print_styles_list(styles_origin, style_entries) return # ---- Resolve output directory ---- out_dir: Path if output_dir is not None: out_dir = output_dir elif "output_dir" in tpl_data: out_dir = Path(tpl_data["output_dir"]) else: console.print("[red]No output_dir specified (use --output-dir or set 'output_dir' in template)[/red]") raise typer.Exit(1) if not dry_run: out_dir.mkdir(parents=True, exist_ok=True) # ---- Resolve generate params from template ---- def _t(key: str, *, cast: Any = None, default: Any = None) -> Any: val = tpl_data.get(key, default) if val is None or cast is None: return val try: return cast(val) except (TypeError, ValueError): return val # Accept both "negative" and "negative_prompt" keys negative_val = tpl_data.get("negative", tpl_data.get("negative_prompt", "")) or "" gen_remote = remote if remote is not None else tpl_data.get("remote") # ---- Execute sweep ---- import time # noqa: PLC0415 total = len(style_entries) console.print(f"[bold]Style sweep:[/bold] {total} styles → {out_dir}") console.print(f"[dim]Template: {template}[/dim]") console.print(f"[dim]Styles: {styles_origin}[/dim]") if dry_run: console.print("[yellow]DRY RUN — no generation calls will be made[/yellow]") results: list[dict[str, Any]] = [] failed_slugs: list[str] = [] # Pre-compute per-style work items and short-circuit skip/dry-run cases # synchronously (no point pipelining no-ops). Only real generation tasks # go through the executor path. pending_tasks: list[tuple[int, dict[str, str], dict[str, Any], Path]] = [] for i, entry in enumerate(style_entries, start=1): slug = entry["slug"] suffix = entry["suffix"] composed_prompt = f"{base_prompt}, {suffix}" out_path = out_dir / f"{slug}.png" result: dict[str, Any] = { "slug": slug, "prompt": composed_prompt, "output": str(out_path), "seed": _t("seed", cast=int, default=-1), "duration_sec": 0.0, "success": False, "error": None, } if skip_existing and out_path.exists(): console.print(f"[dim]\\[{i}/{total}] {slug} skip (exists)[/dim]") result["success"] = True result["skipped"] = True results.append(result) continue if dry_run: console.print(f"\\[{i}/{total}] {slug}") console.print(f" [dim]prompt:[/dim] {composed_prompt}") console.print(f" [dim]output:[/dim] {out_path}") result["success"] = True result["dry_run"] = True results.append(result) continue pending_tasks.append((i, entry, result, out_path)) # Character / scene resolution: templates may carry either a name string # (look up at run-time) or an inline list of resolved elements (e.g. produced # by `tsr template -C ... -S ...`). Lists are joined into the *_prompt arg # so _run_generation sees a uniform CSV string and skips the disk lookup. def _split_fragment(name_val: Any, prompt_val: Any) -> tuple[str | None, str | None]: name_out: str | None = None inline_out: str | None = None if isinstance(name_val, str): name_out = name_val elif isinstance(name_val, (list, tuple)): inline_out = ", ".join(str(x) for x in name_val if str(x).strip()) if prompt_val is not None: inline_out = prompt_val if isinstance(prompt_val, str) else ", ".join(str(x) for x in prompt_val if str(x).strip()) return name_out, inline_out char_name, char_inline = _split_fragment(tpl_data.get("character"), tpl_data.get("character_prompt")) scene_name, scene_inline = _split_fragment(tpl_data.get("scene"), tpl_data.get("scene_prompt")) # Common kwargs for every _run_generation call — extracted from the # template once, reused across sequential and parallel paths. base_gen_kwargs: dict[str, Any] = { "model": _t("model"), "width": _t("width", cast=int), "height": _t("height", cast=int), "steps": _t("steps", cast=int), "cfg": _t("cfg", cast=float), "guidance": _t("guidance", cast=float), "seed": _t("seed", cast=int, default=-1), "sampler": _t("sampler"), "scheduler": _t("scheduler"), "vae": _t("vae"), "orientation": _t("orientation", default="square"), "lora": _t("lora"), "lora_strength": _t("lora_strength", cast=float, default=0.8), "negative": negative_val, "count": 1, "rating": _t("rating"), "no_quality": bool(_t("no_quality", default=False)), "no_negative": bool(_t("no_negative", default=False)), "character": char_name, "character_prompt": char_inline, "scene": scene_name, "scene_prompt": scene_inline, "family": _t("family"), "remote": gen_remote, "json_output": False, } def _run_one(task: tuple[int, dict[str, str], dict[str, Any], Path]) -> dict[str, Any]: """Run a single style. Returns the result dict (success or error captured).""" _idx, _entry_in, res, opath = task composed = res["prompt"] start = time.perf_counter() try: _run_generation(prompt=composed, output=opath, **base_gen_kwargs) res["duration_sec"] = round(time.perf_counter() - start, 2) res["success"] = True except typer.Exit as ex: res["duration_sec"] = round(time.perf_counter() - start, 2) res["error"] = f"generate exited with code {ex.exit_code}" except Exception as ex: res["duration_sec"] = round(time.perf_counter() - start, 2) res["error"] = str(ex) return res if parallel_queue == 1: # Sequential path — preserves the original "ok in Xs" / "FAIL" lines # exactly so existing log-scraping stays valid. for task in pending_tasks: idx, _entry, result, _out_path = task slug = result["slug"] res = _run_one(task) if res["success"]: console.print(f"[green]\\[{idx}/{total}] {slug} ok in {res['duration_sec']:.1f}s[/green]") else: failed_slugs.append(slug) console.print(f"[red]\\[{idx}/{total}] {slug} FAIL: {res['error']}[/red]") if not continue_on_error: results.append(res) _write_sweep_manifest(out_dir, template, styles_origin, results) raise typer.Exit(1) results.append(res) else: # Parallel path — N concurrent ComfyUI submissions. The GPU still # processes one prompt at a time, but the HTTP queueing, websocket # polling, image download, and disk write phases overlap with the # next prompt's submission. Net effect: 5-15%% speedup vs sequential. # Per-task console output WILL interleave (each _run_generation # prints its own progress); use the manifest for clean per-slug # timing data. from concurrent.futures import ThreadPoolExecutor, as_completed # noqa: PLC0415 console.print(f"[dim]Parallel queue: {parallel_queue} concurrent submissions (output may interleave)[/dim]") # abort-on-error is incompatible with parallelism — we can't reliably # stop in-flight workers without losing their state. Warn and continue. if not continue_on_error: console.print( "[yellow]Note: --abort-on-error is ignored when --parallel-queue > 1; in-flight tasks always complete[/yellow]" ) with ThreadPoolExecutor(max_workers=parallel_queue) as pool: futures = {pool.submit(_run_one, task): task for task in pending_tasks} for completed, fut in enumerate(as_completed(futures), start=1): task = futures[fut] idx, _entry, _res, _out_path = task # idx used in log message below try: res = fut.result() except Exception as ex: # Pathological — _run_one is supposed to catch everything. # Re-build a result dict so the manifest is still well-formed. res = { "slug": task[2]["slug"], "prompt": task[2]["prompt"], "output": task[2]["output"], "seed": task[2]["seed"], "duration_sec": 0.0, "success": False, "error": f"executor exception: {ex}", } if res["success"]: console.print( f"[green]\\[{completed}/{len(pending_tasks)}] " f"{res['slug']} ok in {res['duration_sec']:.1f}s " f"(submit #{idx})[/green]" ) else: failed_slugs.append(res["slug"]) console.print(f"[red]\\[{completed}/{len(pending_tasks)}] {res['slug']} FAIL: {res['error']}[/red]") results.append(res) # Reorder results to match the original styles list order so the manifest # is human-readable. Skipped/dry-run entries already in `results` keep # their position from the pre-loop walk. slug_order = {e["slug"]: i for i, e in enumerate(style_entries)} results.sort(key=lambda r: slug_order.get(r["slug"], 1_000_000)) # ---- Manifest ---- if not dry_run: manifest_path = _write_sweep_manifest(out_dir, template, styles_origin, results) console.print(f"[dim]Manifest: {manifest_path}[/dim]") # ---- Summary ---- successful = sum(1 for r in results if r.get("success")) console.print(f"[bold]Sweep complete:[/bold] {successful}/{len(results)} ok") if failed_slugs: console.print(f"[red]Failed slugs ({len(failed_slugs)}):[/red] {', '.join(failed_slugs)}") raise typer.Exit(1) def _write_sweep_manifest( out_dir: Path, template_path: Path, styles_origin: str, results: list[dict[str, Any]], ) -> Path: """Write the per-sweep manifest JSON. Returns the path.""" manifest_path = out_dir / "_sweep.json" manifest: dict[str, Any] = { "template": str(template_path), "styles_source": styles_origin, "results": results, } manifest_path.write_text(json.dumps(manifest, indent=2) + "\n") return manifest_path def _print_styles_list(styles_origin: str, entries: list[dict[str, str]]) -> None: """Render the resolved styles as a two-column table. Suffixes truncated to ~80 chars.""" max_suffix = 80 console.print(f"[bold]Styles:[/bold] {styles_origin} ({len(entries)} entries)") table = Table(show_header=True, header_style="bold", box=None, pad_edge=False) table.add_column("SLUG", style="cyan", no_wrap=True) table.add_column("SUFFIX", overflow="fold") for entry in entries: suffix = entry["suffix"] if len(suffix) > max_suffix: suffix = suffix[: max_suffix - 1].rstrip() + "…" table.add_row(entry["slug"], suffix) console.print(table) # ============================================================================= # Template Dump # ============================================================================= @app.command() def template( # noqa: PLR0915 model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")], lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None, lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8, orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "square", rating: Annotated[str | None, typer.Option("--rating", "-R", help="Content rating: safe, questionable, explicit")] = None, character: Annotated[ str | None, typer.Option("-C", "--character", help="Saved character name (resolved into the `character` list field)"), ] = None, character_prompt: Annotated[ str | None, typer.Option( "--character-prompt", help="Inline character fragment, comma-separated (merged with --character into `character`)", ), ] = None, scene: Annotated[ str | None, typer.Option("-S", "--scene", help="Saved scene name (resolved into the `scene` list field)"), ] = None, scene_prompt: Annotated[ str | None, typer.Option( "--scene-prompt", help="Inline scene fragment, comma-separated (merged with --scene into `scene`)", ), ] = None, output: Annotated[Path | None, typer.Option("-o", "--output", help="Save template to file")] = None, ) -> None: """Dump a JSON generation template with resolved defaults for a model. Outputs a ready-to-use JSON object with all parameters auto-resolved from the checkpoint family. Pipe to 'tsr generate --input' or save to a file for reuse. ``--character`` / ``--character-prompt`` append a ``character`` list to the template; ``--scene`` / ``--scene-prompt`` append a ``scene`` list (named elements first, inline elements appended, deduped within each list). Examples: tsr template -m ponyDiffusionV6XL_v6StartWithThisOne.safetensors tsr template -m beautifulRealistic_v7.safetensors -O portrait tsr template -m waiIllustriousSDXL_v160.safetensors -l "Elvira iIlluLoRA.safetensors" tsr template -m ponyRealism_V22.safetensors -o pony_preset.json tsr template -m flux1-dev.safetensors -C cassie_cage -S penthouse tsr template -m flux1-dev.safetensors --character-prompt "blond hair, blue eyes" \\ --scene-prompt "luxury penthouse, volumetric lighting, Canon R5" tsr generate --input "$(tsr template -m ponyRealism_V22.safetensors)" "a portrait" """ from tensors.config import get_model_generation_defaults, resolve_orientation # noqa: PLC0415 # Look up base_model from DB for accurate family detection base_model_str: str | None = None try: with Database() as db: db.init_schema() base_model_str = db.get_base_model_by_filename(model) except Exception: pass family = detect_model_family(model, base_model_str) defaults = get_model_generation_defaults(model, base_model_str) res_w, res_h = resolve_orientation(family, orientation) # Build template tpl: dict[str, Any] = { "prompt": "", "negative_prompt": defaults.get("negative_prompt", ""), "model": model, "width": res_w, "height": res_h, "steps": defaults.get("steps"), "cfg": defaults.get("cfg"), "sampler": defaults.get("sampler"), "scheduler": defaults.get("scheduler"), "vae": defaults.get("vae"), "orientation": orientation, "seed": -1, "count": 1, } # Add quality prefix if the family has one quality_prefix = defaults.get("quality_prefix", "") if quality_prefix: tpl["quality_prefix"] = quality_prefix # Add rating tag if specified if rating: from tensors.config import get_rating_tag # noqa: PLC0415 rating_tag = get_rating_tag(family, rating.lower()) if rating_tag: tpl["rating"] = rating tpl["rating_tag"] = rating_tag # Add LoRA info if lora: tpl["lora"] = lora tpl["lora_strength"] = lora_strength # Resolve character / scene into flat lists embedded in the template. When # the template is later fed to `tsr generate --input`, _run_generation will # treat the lists under `character` / `scene` as inline elements (no # re-lookup needed). The `_*_name` fields are informational only. if character or character_prompt: from tensors.characters import resolve_character # noqa: PLC0415 try: resolved = resolve_character(character=character, character_prompt=character_prompt) except FileNotFoundError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e except ValueError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e if resolved: tpl["character"] = resolved if character: tpl["_character_name"] = character if scene or scene_prompt: from tensors.scenes import resolve_scene # noqa: PLC0415 try: resolved_scene = resolve_scene(scene=scene, scene_prompt=scene_prompt) except FileNotFoundError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e except ValueError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e if resolved_scene: tpl["scene"] = resolved_scene if scene: tpl["_scene_name"] = scene # Add metadata (not used by generate, but informational) tpl["_family"] = family or "unknown" if base_model_str: tpl["_base_model"] = base_model_str json_str = json.dumps(tpl, indent=2) if output: output.write_text(json_str + "\n") console.print(f"[green]Saved template:[/green] {output}") else: console.print(json_str) # ============================================================================= # Top-Level Models Command # ============================================================================= @app.command() def models( model_type: Annotated[str | None, typer.Option("-t", "--type", help="Filter by type (checkpoints, loras, vae)")] = None, remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """List available models from ComfyUI. Shows checkpoints, LoRAs, VAEs, and other model types loaded in ComfyUI. Uses --remote to query a remote tensors server instead of local ComfyUI. Examples: tsr models tsr models -t checkpoints tsr models --remote junkpile tsr models --json """ from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415 remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None) if remote_url: if not json_output: console.print(f"[dim]Remote: {remote_url}[/dim]") result = remote_models(remote or remote_url, console=console) else: from tensors.comfyui import get_loaded_models # noqa: PLC0415 result = get_loaded_models(console=console if not json_output else None) if not result: console.print("[red]Error: Could not fetch models[/red]") raise typer.Exit(1) # Filter by type if requested if model_type: key = model_type.lower() filtered = {k: v for k, v in result.items() if k.lower() == key} if not filtered: console.print(f"[yellow]No models found for type '{model_type}'[/yellow]") console.print(f"[dim]Available types: {', '.join(sorted(result.keys()))}[/dim]") raise typer.Exit(1) result = filtered if json_output: console.print_json(data=result) return console.print("[bold cyan]Available Models[/bold cyan]") for mtype, model_list in sorted(result.items()): console.print() console.print(f"[bold]{mtype}:[/bold] ({len(model_list)})") for name in model_list[:MAX_MODEL_LIST_DISPLAY]: console.print(f" {name}") if len(model_list) > MAX_MODEL_LIST_DISPLAY: console.print(f" ... and {len(model_list) - MAX_MODEL_LIST_DISPLAY} more") # ============================================================================= # Database Commands # ============================================================================= db_app = typer.Typer( name="db", help="Manage local models database and CivitAI cache.", no_args_is_help=True, ) app.add_typer(db_app, name="db") @db_app.command("scan") def db_scan( directory: Annotated[Path, typer.Argument(help="Directory to scan for safetensor files")], json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Scan directory for safetensor files and add to database.""" path = directory.resolve() if not path.exists() or not path.is_dir(): console.print(f"[red]Error: Directory not found: {path}[/red]") raise typer.Exit(1) with Database() as db: db.init_schema() console.print(f"[cyan]Scanning {path}...[/cyan]") results = db.scan_directory(path, console if not json_output else None) if json_output: console.print_json(data=results) else: console.print(f"[green]Scanned {len(results)} file(s)[/green]") for f in results: console.print(f" • {f['file_path']}") @db_app.command("link") def db_link( api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Link unlinked local files to CivitAI by hash lookup.""" key = api_key or load_api_key() with Database() as db: db.init_schema() unlinked = db.get_unlinked_files() if not unlinked: console.print("[green]All files already linked.[/green]") return console.print(f"[cyan]Found {len(unlinked)} unlinked file(s)[/cyan]") linked: list[dict[str, Any]] = [] for file_info in unlinked: sha256 = file_info["sha256"] console.print(f"[dim]Looking up {sha256[:16]}...[/dim]") civitai_data = fetch_civitai_by_hash(sha256, key, console if not json_output else None) if civitai_data: version_id: int = civitai_data.get("id", 0) model_id: int = civitai_data.get("modelId", 0) if version_id and model_id: db.link_file_to_civitai(file_info["id"], model_id, version_id) linked.append( { "file": file_info["file_path"], "model_id": model_id, "version_id": version_id, "name": civitai_data.get("name", ""), } ) if not json_output: console.print(f" [green]✓[/green] {civitai_data.get('name', 'N/A')}") if json_output: console.print_json(data=linked) else: console.print(f"[green]Linked {len(linked)} file(s)[/green]") # Cache model data for newly linked files if linked: _cache_linked_models(db, key, linked, json_output) def _cache_linked_models( db: Database, api_key: str | None, linked: list[dict[str, Any]], json_output: bool, ) -> None: """Fetch and cache full model data for linked files. Args: db: Database instance (already initialized) api_key: CivitAI API key linked: List of linked file info dicts with model_id json_output: Whether to suppress console output """ # Collect unique model IDs model_ids: set[int] = {item["model_id"] for item in linked if item.get("model_id")} # Find which models are not yet cached uncached_ids: list[int] = [] for model_id in model_ids: if db.get_model(model_id) is None: uncached_ids.append(model_id) if not uncached_ids: return if not json_output: console.print(f"[cyan]Caching {len(uncached_ids)} model(s)...[/cyan]") cached: list[dict[str, Any]] = [] for model_id in uncached_ids: model_data = fetch_civitai_model(model_id, api_key, console if not json_output else None) if model_data: db.cache_model(model_data) cached.append({"model_id": model_id, "name": model_data.get("name", "")}) if not json_output: console.print(f" [green]✓[/green] Cached: {model_data.get('name', 'N/A')}") if not json_output and cached: console.print(f"[green]Cached {len(cached)} model(s)[/green]") @db_app.command("cache") def db_cache( model_id: Annotated[int, typer.Argument(help="CivitAI model ID to cache")], api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Fetch and cache full CivitAI model data.""" key = api_key or load_api_key() model_data = fetch_civitai_model(model_id, key, console if not json_output else None) if not model_data: console.print(f"[red]Error: Model {model_id} not found on CivitAI.[/red]") raise typer.Exit(1) with Database() as db: db.init_schema() internal_id = db.cache_model(model_data) if json_output: console.print_json(data={"model_id": model_id, "internal_id": internal_id, "name": model_data.get("name")}) else: console.print(f"[green]Cached:[/green] {model_data.get('name')} (internal ID: {internal_id})") @db_app.command("list") def db_list( model_type: Annotated[ str | None, typer.Option("-t", "--type", help="Filter by model type (Checkpoint, LORA, VAE, etc.)") ] = None, base: Annotated[ str | None, typer.Option("-b", "--base", help="Filter by base model (Pony, Illustrious, SDXL 1.0, SD 1.5, etc.)") ] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """List local files with CivitAI info. Examples: tsr db list # All local files tsr db list -t Checkpoint # Only checkpoints tsr db list -t LORA # Only LoRAs tsr db list -t Checkpoint -b Pony # Pony checkpoints only tsr db list -b "SDXL 1.0" # All SDXL 1.0 models """ with Database() as db: db.init_schema() files = db.list_local_files() # Apply filters (case-insensitive substring match) if model_type: mt_lower = model_type.lower() files = [f for f in files if (f.get("model_type") or "").lower() == mt_lower] if base: base_lower = base.lower() files = [f for f in files if base_lower in (f.get("base_model") or "").lower()] if json_output: console.print_json(data=files) return if not files: console.print("[yellow]No files found. Try 'tsr db scan' or adjust filters.[/yellow]") return title = "Local Files" if model_type or base: parts = [] if model_type: parts.append(model_type) if base: parts.append(base) title = f"Local Files ({', '.join(parts)})" table = Table(title=title, show_header=True, header_style="bold magenta") table.add_column("Path", style="cyan", max_width=50) table.add_column("Model", style="green") table.add_column("Version", style="white") table.add_column("Type", style="yellow") table.add_column("Base", style="dim") for f in files: path = Path(f["file_path"]).name model = f.get("model_name") or "[dim]unlinked[/dim]" version = f.get("version_name") or "" ft = f.get("model_type") or "" base_model = f.get("base_model") or "" table.add_row(path, model, version, ft, base_model) console.print(table) @db_app.command("search") def db_search( query: Annotated[str | None, typer.Argument(help="Search query")] = None, model_type: Annotated[str | None, typer.Option("-t", "--type", help="Model type filter")] = None, base_model: Annotated[str | None, typer.Option("-b", "--base", help="Base model filter")] = None, limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Search cached models offline.""" with Database() as db: db.init_schema() results = db.search_models(query=query, model_type=model_type, base_model=base_model, limit=limit) if json_output: console.print_json(data=results) return if not results: console.print("[yellow]No models found.[/yellow]") return table = Table(title="Cached Models", show_header=True, header_style="bold magenta") table.add_column("ID", style="dim") table.add_column("Name", style="cyan") table.add_column("Type", style="yellow") table.add_column("Base", style="green") table.add_column("Creator", style="dim") table.add_column("Downloads", justify="right") for m in results: table.add_row( str(m.get("civitai_id", "")), m.get("name", ""), m.get("type", ""), m.get("base_model", ""), m.get("creator", ""), str(m.get("download_count", 0)), ) console.print(table) @db_app.command("triggers") def db_triggers( file: Annotated[Path, typer.Argument(help="Path to safetensor file")], json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Show trigger words for a LoRA file.""" file_path = file.resolve() if not file_path.exists(): console.print(f"[red]Error: File not found: {file_path}[/red]") raise typer.Exit(1) with Database() as db: db.init_schema() triggers = db.get_triggers(str(file_path)) if json_output: console.print_json(data=triggers) return if not triggers: console.print("[yellow]No trigger words found. File may not be linked to CivitAI.[/yellow]") console.print("[dim]Run 'tsr db link' to link files to CivitAI.[/dim]") return console.print(f"[bold]Trigger words for {file_path.name}:[/bold]") for word in triggers: console.print(f" • {word}") @db_app.command("stats") def db_stats( json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Show database statistics.""" with Database() as db: db.init_schema() stats = db.get_stats() if json_output: console.print_json(data={"db_path": str(DB_PATH), "stats": stats}) return table = Table(title="Database Statistics", show_header=True, header_style="bold magenta") table.add_column("Table", style="cyan") table.add_column("Count", style="green", justify="right") for table_name, count in stats.items(): table.add_row(table_name, str(count)) console.print(f"[dim]Database: {DB_PATH}[/dim]") console.print(table) # ============================================================================= # Hugging Face Commands # ============================================================================= hf_app = typer.Typer(name="hf", help="Hugging Face Hub commands for safetensor files.") app.add_typer(hf_app) @hf_app.command("get") def hf_get( model_id: Annotated[str, typer.Argument(help="Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)")], json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Get Hugging Face model info and list safetensor files.""" model = get_hf_model(model_id, console=console) if not model: raise typer.Exit(1) if json_output: console.print_json(data=model) return display_hf_model_info(model, console) @hf_app.command("files") def hf_files( model_id: Annotated[str, typer.Argument(help="Model ID")], json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """List safetensor files in a Hugging Face model.""" files = list_safetensor_files(model_id, console=console) if json_output: console.print_json(data=files) return if not files: console.print("[yellow]No safetensor files found.[/yellow]") return console.print(f"[bold]Safetensor files in {model_id}:[/bold]") for i, f in enumerate(files, 1): console.print(f" {i}. {f}") @hf_app.command("dl") def hf_download( model_id: Annotated[str, typer.Argument(help="Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)")], filename: Annotated[str | None, typer.Option("-f", "--file", help="Specific file to download")] = None, output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None, all_files: Annotated[bool, typer.Option("--all", "-a", help="Download all safetensor files")] = False, ) -> None: """Download safetensor files from Hugging Face. Examples: tsr hf dl stabilityai/stable-diffusion-xl-base-1.0 -f sd_xl_base_1.0.safetensors tsr hf dl author/model --all """ output_dir = output or Path.cwd() if all_files: downloaded = download_all_safetensors(model_id, output_dir, console=console) if downloaded: console.print(f"[green]Downloaded {len(downloaded)} files[/green]") else: console.print("[red]No files downloaded[/red]") raise typer.Exit(1) return if not filename: # List files and prompt or show help files = list_safetensor_files(model_id, console=console) if not files: console.print("[red]No safetensor files found in model[/red]") raise typer.Exit(1) if len(files) == 1: filename = files[0] console.print(f"[dim]Downloading only safetensor file: {filename}[/dim]") else: console.print("[yellow]Multiple safetensor files found. Specify one with -f or use --all:[/yellow]") for i, f in enumerate(files, 1): console.print(f" {i}. {f}") raise typer.Exit(1) result = download_hf_safetensor(model_id, filename, output_dir, console=console) if not result: raise typer.Exit(1) # ============================================================================= # Fragment Commands (character + scene) # ============================================================================= # Characters and scenes are named, comma-split prompt fragments stored as YAML # lists in ~/.local/share/tensors//.yml. They are injected into the # positive prompt by `tsr generate --character ` / `--scene ` (or # inline via `--character-prompt "..."` / `--scene-prompt`). # # Both subcommand groups share the underlying tensors.fragments.FragmentLibrary, # but the CLI commands are spelled out per-kind to keep Typer's signature # introspection happy under `from __future__ import annotations` (closures # referencing per-kind labels break typer's eval_str=True resolution). # ---- character ---- character_app = typer.Typer( name="character", help="Manage saved character prompts (~/.local/share/tensors/characters/).", no_args_is_help=True, ) app.add_typer(character_app) @character_app.command("save") def character_save( elements: Annotated[str, typer.Argument(help='Comma-separated prompt elements (e.g. "blond hair, blue eyes")')], name: Annotated[str, typer.Option("-o", "--output", help="Character name (used as filename)")], json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Save a character as a YAML list of prompt elements. Examples: tsr character save -o cassie_cage "blond hair, broad chin, skin imperfections" tsr character save -o elvira "long black hair, pale skin, gothic dress" """ from tensors.characters import parse_elements, save_character # noqa: PLC0415 parsed = parse_elements(elements) if not parsed: console.print("[red]No usable elements after splitting on commas[/red]") raise typer.Exit(1) try: path = save_character(name, parsed) except ValueError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e if json_output: console.print_json(data={"name": name, "path": str(path), "elements": parsed, "kind": "characters"}) return console.print(f"[green]Saved character '{name}' ({len(parsed)} elements):[/green] {path}") for elem in parsed: console.print(f" • {elem}") @character_app.command("list") def character_list( json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """List saved characters.""" from tensors.characters import CHARACTERS_DIR, list_characters # noqa: PLC0415 names = list_characters() if json_output: console.print_json(data={"dir": str(CHARACTERS_DIR), "characters": names}) return if not names: console.print(f"[yellow]No characters saved in {CHARACTERS_DIR}.[/yellow]") console.print('[dim]Create one with: tsr character save -o "elem1, elem2"[/dim]') return console.print(f"[bold]Characters[/bold] ({len(names)}) [dim]in {CHARACTERS_DIR}[/dim]") for n in names: console.print(f" • {n}") @character_app.command("show") def character_show( name: Annotated[str, typer.Argument(help="Character name")], json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Show a character's elements.""" from tensors.characters import character_path, load_character # noqa: PLC0415 try: elements = load_character(name) except FileNotFoundError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e except ValueError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e if json_output: console.print_json(data={"name": name, "path": str(character_path(name)), "elements": elements}) return console.print(f"[bold]{name}[/bold] [dim]({character_path(name)})[/dim]") for elem in elements: console.print(f" • {elem}") @character_app.command("delete") def character_delete( name: Annotated[str, typer.Argument(help="Character name")], json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Delete a saved character.""" from tensors.characters import delete_character # noqa: PLC0415 try: deleted = delete_character(name) except ValueError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e if json_output: console.print_json(data={"name": name, "deleted": deleted}) return if deleted: console.print(f"[green]Deleted character '{name}'[/green]") else: console.print(f"[yellow]Character '{name}' does not exist[/yellow]") raise typer.Exit(1) # ---- scene ---- scene_app = typer.Typer( name="scene", help="Manage saved scene prompts (~/.local/share/tensors/scenes/).", no_args_is_help=True, ) app.add_typer(scene_app) @scene_app.command("save") def scene_save( elements: Annotated[ str, typer.Argument(help='Comma-separated prompt elements (e.g. "luxury penthouse, volumetric lighting")') ], name: Annotated[str, typer.Option("-o", "--output", help="Scene name (used as filename)")], json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Save a scene as a YAML list of prompt elements. Examples: tsr scene save -o penthouse "luxury penthouse, volumetric lighting, Canon R5, 85mm" tsr scene save -o forest "deep forest, dappled sunlight, moss-covered rocks" """ from tensors.fragments import parse_elements # noqa: PLC0415 from tensors.scenes import save_scene # noqa: PLC0415 parsed = parse_elements(elements) if not parsed: console.print("[red]No usable elements after splitting on commas[/red]") raise typer.Exit(1) try: path = save_scene(name, parsed) except ValueError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e if json_output: console.print_json(data={"name": name, "path": str(path), "elements": parsed, "kind": "scenes"}) return console.print(f"[green]Saved scene '{name}' ({len(parsed)} elements):[/green] {path}") for elem in parsed: console.print(f" • {elem}") @scene_app.command("extract") def scene_extract( model: Annotated[str, typer.Argument(help="Local model name (e.g. lust_v10.safetensors)")], api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, ) -> None: """Extract example prompts from a model's CivitAI showcase and save as scenes.""" from pathlib import Path # noqa: PLC0415 from tensors.api import fetch_civitai_model_version # noqa: PLC0415 from tensors.config import load_api_key # noqa: PLC0415 from tensors.db import Database # noqa: PLC0415 from tensors.fragments import parse_elements # noqa: PLC0415 from tensors.scenes import save_scene # noqa: PLC0415 with Database() as db: files = db.list_local_files() target_file = None for f in files: file_path = Path(f["file_path"]) if model in (file_path.name, file_path.stem): target_file = f break if not target_file: console.print(f"[red]Model '{model}' not found in local database. Run 'tsr db scan' first.[/red]") raise typer.Exit(1) vid = target_file["civitai_version_id"] if not vid: console.print(f"[red]Model '{model}' is not linked to CivitAI. Run 'tsr db link' first.[/red]") raise typer.Exit(1) console.print(f"[cyan]Fetching showcase images for version ID {vid}...[/cyan]") data = fetch_civitai_model_version(vid, api_key or load_api_key(), console=console) if not data: console.print("[red]Failed to fetch model data from CivitAI.[/red]") raise typer.Exit(1) images = data.get("images", []) seen_prompts = set() idx = 1 base_name = Path(target_file["file_path"]).stem for img in images: meta = img.get("meta", {}) prompt = meta.get("prompt") if not prompt: continue normalized = prompt.lower().strip() if normalized not in seen_prompts: seen_prompts.add(normalized) parsed = parse_elements(prompt) if parsed: scene_name = f"{base_name}_{idx:02d}" path = save_scene(scene_name, parsed) console.print(f"[green]Saved {scene_name} ({len(parsed)} elements):[/green] {path}") idx += 1 if idx == 1: console.print("[yellow]No example prompts found in showcase images.[/yellow]") @scene_app.command("list") def scene_list( json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """List saved scenes.""" from tensors.scenes import SCENES_DIR, list_scenes # noqa: PLC0415 names = list_scenes() if json_output: console.print_json(data={"dir": str(SCENES_DIR), "scenes": names}) return if not names: console.print(f"[yellow]No scenes saved in {SCENES_DIR}.[/yellow]") console.print('[dim]Create one with: tsr scene save -o "elem1, elem2"[/dim]') return console.print(f"[bold]Scenes[/bold] ({len(names)}) [dim]in {SCENES_DIR}[/dim]") for n in names: console.print(f" • {n}") @scene_app.command("show") def scene_show( name: Annotated[str, typer.Argument(help="Scene name")], json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Show a scene's elements.""" from tensors.scenes import load_scene, scene_path # noqa: PLC0415 try: elements = load_scene(name) except FileNotFoundError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e except ValueError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e if json_output: console.print_json(data={"name": name, "path": str(scene_path(name)), "elements": elements}) return console.print(f"[bold]{name}[/bold] [dim]({scene_path(name)})[/dim]") for elem in elements: console.print(f" • {elem}") @scene_app.command("delete") def scene_delete( name: Annotated[str, typer.Argument(help="Scene name")], json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Delete a saved scene.""" from tensors.scenes import delete_scene # noqa: PLC0415 try: deleted = delete_scene(name) except ValueError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e if json_output: console.print_json(data={"name": name, "deleted": deleted}) return if deleted: console.print(f"[green]Deleted scene '{name}'[/green]") else: console.print(f"[yellow]Scene '{name}' does not exist[/yellow]") raise typer.Exit(1) # ---- templates ---- templates_app = typer.Typer( name="templates", help="Bulk-extract, list, and run generation templates derived from CivitAI showcase data.", no_args_is_help=True, ) app.add_typer(templates_app) @templates_app.command("extract") def templates_extract( # noqa: PLR0915 model: Annotated[str, typer.Argument(help="Local model name (e.g. lust_v10.safetensors)")], orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "portrait", no_overrides: Annotated[ bool, typer.Option( "--no-overrides", help="Skip auto-derived params from showcase image meta; use family defaults only", ), ] = False, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, limit: Annotated[int, typer.Option("--limit", "-L", help="Max templates to write (0 = all unique prompts)")] = 0, overwrite: Annotated[bool, typer.Option("--overwrite", help="Overwrite existing template files (default: skip)")] = False, do_generate: Annotated[ bool, typer.Option("--generate", help="After writing, run `tsr generate --input` for each emitted template"), ] = False, output_dir: Annotated[ Path | None, typer.Option("--output-dir", help="Where to write generated images when --generate (default: ComfyUI output dir)"), ] = None, dry_run: Annotated[bool, typer.Option("--dry-run", help="Print what would be done; write nothing")] = False, ) -> None: """Bulk-extract templates from a model's CivitAI showcase. Pulls showcase images, deduplicates prompts, derives recommended generation params (sampler / scheduler / steps / cfg / guidance) from the *mode* of the showcase image metadata, and writes one JSON template per unique prompt to ``~/.local/share/tensors/templates//.json``. Each emitted template is ready to feed straight to ``tsr generate --input``. Examples: tsr templates extract lust_v10.safetensors tsr templates extract bodySliderFitness_v10 -O portrait --generate tsr templates extract ultrasenseInfinity_v10 --dry-run tsr templates extract getphat_v5 --no-overrides # use tsr family defaults only """ import subprocess # noqa: PLC0415 from tensors.api import fetch_civitai_model_version # noqa: PLC0415 from tensors.config import ( # noqa: PLC0415 detect_model_family, get_model_generation_defaults, load_api_key, resolve_orientation, ) from tensors.fragments import parse_elements # noqa: PLC0415 from tensors.templates import ( # noqa: PLC0415 build_template, derive_overrides_from_images, save_template, template_path, ) with Database() as db: files = db.list_local_files() target_file = None for f in files: file_path = Path(f["file_path"]) if model in (file_path.name, file_path.stem): target_file = f break if not target_file: console.print(f"[red]Model '{model}' not found in local database. Run 'tsr db scan' first.[/red]") raise typer.Exit(1) vid = target_file["civitai_version_id"] if not vid: console.print(f"[red]Model '{model}' is not linked to CivitAI. Run 'tsr db link' first.[/red]") raise typer.Exit(1) model_stem = Path(target_file["file_path"]).stem model_filename = Path(target_file["file_path"]).name base_model_str = target_file.get("base_model") console.print(f"[cyan]Fetching showcase for {model_stem} (version {vid})...[/cyan]") data = fetch_civitai_model_version(vid, api_key or load_api_key(), console=console) if not data: console.print("[red]Failed to fetch CivitAI data.[/red]") raise typer.Exit(1) images = data.get("images", []) enriched = sum(1 for img in images if img.get("meta")) overrides = {} if no_overrides else derive_overrides_from_images(images) if overrides: console.print(f"[cyan]Derived overrides from {enriched} enriched image(s):[/cyan] {overrides}") elif not no_overrides: console.print("[yellow]No usable param meta in showcase; using family defaults only.[/yellow]") family = detect_model_family(model_filename, base_model_str) defaults = get_model_generation_defaults(model_filename, base_model_str) res_w, res_h = resolve_orientation(family, orientation) seen_prompts: set[str] = set() emitted: list[Path] = [] skipped_existing = 0 skipped_no_prompt = 0 for img in images: meta = img.get("meta") or {} prompt = meta.get("prompt") if not prompt: skipped_no_prompt += 1 continue normalized = prompt.lower().strip() if normalized in seen_prompts: continue seen_prompts.add(normalized) scene_elements = parse_elements(prompt) if not scene_elements: continue idx = len(emitted) + skipped_existing + 1 name = f"{model_stem}_{idx:02d}" out_path = template_path(model_stem, name) if out_path.is_file() and not overwrite: console.print(f"[yellow]Skip (exists, use --overwrite to replace):[/yellow] {out_path}") skipped_existing += 1 continue tpl = build_template( model_filename=model_filename, family=family, defaults=defaults, base_model_str=base_model_str, width=res_w, height=res_h, orientation=orientation, scene_elements=scene_elements, scene_name=name, overrides=overrides, ) if dry_run: console.print(f"[dim](dry-run) Would write:[/dim] {out_path}") emitted.append(out_path) else: saved = save_template(model_stem, name, tpl) console.print(f"[green]Saved ({len(scene_elements)} scene elements):[/green] {saved}") emitted.append(saved) if limit and len(emitted) >= limit: break console.print( f"\n[bold]Extract summary:[/bold] emitted={len(emitted)} skipped_existing={skipped_existing} " f"images_no_prompt={skipped_no_prompt}" ) if not emitted: return if do_generate: if dry_run: console.print("[yellow]--dry-run is set; skipping --generate phase.[/yellow]") return console.print(f"\n[cyan]Generating images for {len(emitted)} template(s)...[/cyan]") for tpl_path in emitted: out_arg = [] if output_dir: output_dir.mkdir(parents=True, exist_ok=True) out_arg = ["-o", str(output_dir / f"{tpl_path.stem}.png")] cmd = ["tsr", "generate", "--input", str(tpl_path), *out_arg] console.print(f"\n[cyan]$ {' '.join(cmd)}[/cyan]") subprocess.run(cmd, check=False) @templates_app.command("list") def templates_list( model: Annotated[str | None, typer.Argument(help="Filter by model stem (optional)")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """List saved templates, grouped by model.""" from tensors.templates import TEMPLATES_DIR, list_templates # noqa: PLC0415 items = list_templates(model) if json_output: console.print_json(data={"dir": str(TEMPLATES_DIR), "templates": [{"model": m, "name": n} for m, n in items]}) return if not items: scope = f" for model '{model}'" if model else "" console.print(f"[yellow]No templates{scope} in {TEMPLATES_DIR}.[/yellow]") console.print("[dim]Create some with: tsr templates extract [/dim]") return cur_model = None for m, n in items: if m != cur_model: console.print(f"\n[cyan]{m}[/cyan]") cur_model = m console.print(f" {n}") @templates_app.command("show") def templates_show( model: Annotated[str, typer.Argument(help="Model stem (directory name under templates/)")], name: Annotated[str, typer.Argument(help="Template name (filename without .json)")], ) -> None: """Print a saved template as JSON.""" from tensors.templates import load_template # noqa: PLC0415 try: data = load_template(model, name) except FileNotFoundError as e: console.print(f"[red]{e}[/red]") raise typer.Exit(1) from e console.print_json(data=data) @templates_app.command("delete") def templates_delete( model: Annotated[str, typer.Argument(help="Model stem")], name: Annotated[str, typer.Argument(help="Template name")], yes: Annotated[bool, typer.Option("--yes", "-y", help="Skip confirmation")] = False, ) -> None: """Delete a saved template.""" from tensors.templates import delete_template, template_path # noqa: PLC0415 path = template_path(model, name) if not path.is_file(): console.print(f"[yellow]Template not found: {path}[/yellow]") raise typer.Exit(1) if not yes: typer.confirm(f"Delete {path}?", abort=True) if delete_template(model, name): console.print(f"[green]Deleted:[/green] {path}") # ============================================================================= # ComfyUI Commands # ============================================================================= comfy_app = typer.Typer(name="comfy", help="ComfyUI integration for image generation.") app.add_typer(comfy_app) @comfy_app.command("status") def comfy_status( url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Show ComfyUI system status (GPU, RAM, queue).""" from tensors.comfyui import get_queue_status, get_system_stats # noqa: PLC0415 stats = get_system_stats(url=url, console=console if not json_output else None) if not stats: console.print("[red]Error: Could not connect to ComfyUI[/red]") raise typer.Exit(1) queue = get_queue_status(url=url) if json_output: output = {"system": stats, "queue": queue} console.print_json(data=output) return # Display system stats console.print("[bold cyan]ComfyUI System Status[/bold cyan]") console.print() # System info system_info = stats.get("system", {}) console.print(f"[bold]OS:[/bold] {system_info.get('os', 'N/A')}") console.print(f"[bold]Python:[/bold] {system_info.get('python_version', 'N/A')}") console.print(f"[bold]PyTorch:[/bold] {system_info.get('pytorch_version', 'N/A')}") # GPU info devices = stats.get("devices", []) if devices: console.print() console.print("[bold]GPU Devices:[/bold]") for i, device in enumerate(devices): name = device.get("name", "Unknown") vram_total = device.get("vram_total", 0) vram_free = device.get("vram_free", 0) vram_used = vram_total - vram_free vram_pct = (vram_used / vram_total * 100) if vram_total > 0 else 0 console.print(f" [{i}] {name}") console.print(f" VRAM: {vram_used / 1024**3:.1f} / {vram_total / 1024**3:.1f} GB ({vram_pct:.0f}%)") # Queue info if queue: running = len(queue.get("queue_running", [])) pending = len(queue.get("queue_pending", [])) console.print() console.print(f"[bold]Queue:[/bold] {running} running, {pending} pending") @comfy_app.command("queue") def comfy_queue( url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None, clear: Annotated[bool, typer.Option("--clear", "-c", help="Clear the queue")] = False, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Show or clear the ComfyUI queue.""" from tensors.comfyui import clear_queue as do_clear_queue # noqa: PLC0415 from tensors.comfyui import get_queue_status # noqa: PLC0415 if clear: success = do_clear_queue(url=url, console=console) if not success: raise typer.Exit(1) return queue = get_queue_status(url=url, console=console if not json_output else None) if not queue: console.print("[red]Error: Could not connect to ComfyUI[/red]") raise typer.Exit(1) if json_output: console.print_json(data=queue) return running = queue.get("queue_running", []) pending = queue.get("queue_pending", []) console.print("[bold cyan]ComfyUI Queue[/bold cyan]") console.print() console.print(f"[bold]Running:[/bold] {len(running)}") console.print(f"[bold]Pending:[/bold] {len(pending)}") if running: console.print() console.print("[bold]Running Jobs:[/bold]") for job in running: prompt_id = job[1] if len(job) > 1 else "unknown" console.print(f" • {prompt_id}") if pending: console.print() console.print("[bold]Pending Jobs:[/bold]") for job in pending[:MAX_QUEUE_DISPLAY]: prompt_id = job[1] if len(job) > 1 else "unknown" console.print(f" • {prompt_id}") if len(pending) > MAX_QUEUE_DISPLAY: console.print(f" ... and {len(pending) - MAX_QUEUE_DISPLAY} more") @comfy_app.command("models") def comfy_models( url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """List available models in ComfyUI.""" from tensors.comfyui import get_loaded_models # noqa: PLC0415 models = get_loaded_models(url=url, console=console if not json_output else None) if not models: console.print("[red]Error: Could not fetch models from ComfyUI[/red]") raise typer.Exit(1) if json_output: console.print_json(data=models) return console.print("[bold cyan]ComfyUI Available Models[/bold cyan]") for model_type, model_list in sorted(models.items()): console.print() console.print(f"[bold]{model_type}:[/bold] ({len(model_list)})") for name in model_list[:MAX_MODEL_LIST_DISPLAY]: console.print(f" • {name}") if len(model_list) > MAX_MODEL_LIST_DISPLAY: console.print(f" ... and {len(model_list) - MAX_MODEL_LIST_DISPLAY} more") @comfy_app.command("history") def comfy_history( prompt_id: Annotated[str | None, typer.Argument(help="Specific prompt ID to view")] = None, url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None, limit: Annotated[int, typer.Option("-n", "--limit", help="Max history items")] = 20, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """View ComfyUI generation history.""" from tensors.comfyui import get_history # noqa: PLC0415 history = get_history(url=url, prompt_id=prompt_id, max_items=limit, console=console if not json_output else None) if history is None: console.print("[red]Error: Could not fetch history from ComfyUI[/red]") raise typer.Exit(1) if json_output: console.print_json(data=history) return if not history: console.print("[yellow]No history found.[/yellow]") return if prompt_id: # Show single entry details if prompt_id not in history: console.print(f"[yellow]Prompt {prompt_id} not found in history.[/yellow]") return entry = history[prompt_id] console.print(f"[bold cyan]Prompt: {prompt_id}[/bold cyan]") console.print() status = entry.get("status", {}) console.print(f"[bold]Status:[/bold] {status.get('status_str', 'unknown')}") outputs = entry.get("outputs", {}) if outputs: console.print() console.print("[bold]Outputs:[/bold]") for node_id, output in outputs.items(): if "images" in output: for img in output["images"]: console.print(f" [{node_id}] {img.get('filename', 'unknown')}") else: # Show list of history entries console.print("[bold cyan]ComfyUI History[/bold cyan]") console.print() table = Table(show_header=True, header_style="bold magenta") table.add_column("Prompt ID", style="cyan", max_width=40) table.add_column("Status", style="green") table.add_column("Images", justify="right") for pid, entry in list(history.items())[:limit]: status = entry.get("status", {}).get("status_str", "unknown") outputs = entry.get("outputs", {}) image_count = sum(len(o.get("images", [])) for o in outputs.values()) display_pid = pid[:MAX_PROMPT_ID_DISPLAY] + "..." if len(pid) > MAX_PROMPT_ID_DISPLAY else pid table.add_row(display_pid, status, str(image_count)) console.print(table) @comfy_app.command("generate", deprecated=True) def comfy_generate( prompt: Annotated[str, typer.Argument(help="Positive prompt text")], model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None, negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "", width: Annotated[int | None, typer.Option("-W", "--width")] = None, height: Annotated[int | None, typer.Option("-H", "--height")] = None, steps: Annotated[int | None, typer.Option("--steps")] = None, cfg: Annotated[float | None, typer.Option("--cfg")] = None, seed: Annotated[int, typer.Option("--seed", "-s")] = -1, sampler: Annotated[str | None, typer.Option("--sampler")] = None, scheduler: Annotated[str | None, typer.Option("--scheduler")] = None, orientation: Annotated[str, typer.Option("-O", "--orientation")] = "square", output: Annotated[Path | None, typer.Option("-o", "--output")] = None, count: Annotated[int, typer.Option("-c", "--count")] = 1, lora: Annotated[str | None, typer.Option("-l", "--lora")] = None, lora_strength: Annotated[float, typer.Option("--lora-strength")] = 0.8, no_quality: Annotated[bool, typer.Option("--no-quality")] = False, no_negative: Annotated[bool, typer.Option("--no-negative")] = False, json_output: Annotated[bool, typer.Option("--json", "-j")] = False, ) -> None: """[Deprecated] Use 'tsr generate' instead. All features have been merged into the top-level command.""" console.print("[yellow]Warning: 'tsr comfy generate' is deprecated. Use 'tsr generate' instead.[/yellow]") # Delegate to the unified generate command via context invocation ctx = typer.Context(generate) generate( ctx=ctx, prompt=prompt, model=model, width=width, height=height, steps=steps, cfg=cfg, seed=seed, sampler=sampler, scheduler=scheduler, vae=None, orientation=orientation, lora=lora, lora_strength=lora_strength, negative=negative, count=count, no_quality=no_quality, no_negative=no_negative, rating=None, family=None, guidance=None, output=output, remote=None, json_output=json_output, json_input=None, ) @comfy_app.command("run") def comfy_run( workflow_file: Annotated[Path, typer.Argument(help="Path to workflow JSON file")], url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Run an arbitrary ComfyUI workflow from a JSON file. The workflow should be in ComfyUI API format (exported via "Save (API Format)"). """ from tensors.comfyui import run_workflow # noqa: PLC0415 if not workflow_file.exists(): console.print(f"[red]Error: Workflow file not found: {workflow_file}[/red]") raise typer.Exit(1) result = run_workflow( workflow=workflow_file, url=url, console=console if not json_output else None, ) if not result: console.print("[red]Failed to queue workflow[/red]") raise typer.Exit(1) if not result.success: if json_output: console.print_json(data={"success": False, "prompt_id": result.prompt_id, "errors": result.node_errors}) else: console.print("[red]Workflow execution failed[/red]") for node_id, errors in result.node_errors.items(): console.print(f" [yellow]Node {node_id}:[/yellow] {errors}") raise typer.Exit(1) if json_output: console.print_json(data={"success": True, "prompt_id": result.prompt_id, "outputs": result.outputs}) return console.print("[bold green]Workflow complete![/bold green]") console.print(f"[dim]Prompt ID: {result.prompt_id}[/dim]") # Show output images for _node_id, output in result.outputs.items(): if "images" in output: for img in output["images"]: console.print(f" [green]Image:[/green] {img.get('filename', 'unknown')}") def main() -> int: """Main entry point.""" # Handle legacy invocation: tsr -> tsr info known_commands = ( "info", "search", "get", "dl", "download", "generate", "models", "config", "serve", "db", "hf", "comfy", ) if len(sys.argv) > 1 and not sys.argv[1].startswith("-"): arg = sys.argv[1] if arg not in known_commands and (arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists()): sys.argv = [sys.argv[0], "info", *sys.argv[1:]] app() return 0 if __name__ == "__main__": sys.exit(main())