diff --git a/.coverage b/.coverage index d2c62bd..8679539 100644 Binary files a/.coverage and b/.coverage differ diff --git a/tensors/cli.py b/tensors/cli.py index 5d61b70..cf44825 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -21,6 +21,12 @@ from tensors.api import ( search_civitai, ) from tensors.config import ( + COMFYUI_DEFAULT_CFG, + COMFYUI_DEFAULT_HEIGHT, + COMFYUI_DEFAULT_SAMPLER, + COMFYUI_DEFAULT_SCHEDULER, + COMFYUI_DEFAULT_STEPS, + COMFYUI_DEFAULT_WIDTH, CONFIG_FILE, MODEL_FAMILY_DEFAULTS, BaseModel, @@ -787,6 +793,10 @@ def generate( # noqa: PLR0915 lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None, lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8, negative: Annotated[str, typer.Option("-n", "--negative-prompt", help="Negative prompt")] = "", + count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1, + rating: Annotated[str | None, typer.Option("--rating", "-R", help="Content rating: safe, questionable, explicit (Pony/Illustrious)")] = None, + no_quality: Annotated[bool, typer.Option("--no-quality", help="Disable auto quality tags")] = False, + no_negative: Annotated[bool, typer.Option("--no-negative", help="Disable auto negative prompt")] = False, output: Annotated[Path | None, typer.Option("-o", "--output", help="Save path (default: current dir)")] = None, remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, @@ -794,15 +804,19 @@ def generate( # noqa: PLR0915 ) -> None: """Generate an image using text-to-image. + Auto-detects optimal sampler, scheduler, CFG, resolution, and VAE from the checkpoint + model family. All auto-detected values can be overridden with explicit flags. + Calls ComfyUI directly when local, or the remote tensors API when --remote is given. Accepts --input with a JSON object whose keys match CLI option names. CLI flags override JSON values. Examples: tsr generate "a cat on a windowsill" - tsr generate "portrait photo" -m "flux1-dev-fp8.safetensors" --steps 30 - tsr generate "cyberpunk city" -o output.png + tsr generate "portrait photo" -m ponyDiffusionV6XL_v6.safetensors -O portrait + tsr generate "cyberpunk city" -o output.png --count 4 tsr generate "landscape" --remote junkpile - tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors", "steps": 30}' + tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors"}' + tsr generate "raw prompt" --no-quality --no-negative """ import random as rng # noqa: PLC0415 @@ -877,16 +891,133 @@ def generate( # noqa: PLR0915 output = Path(mapped["output"]) if "remote" in mapped and "remote" not in explicit: remote = mapped["remote"] + if "count" in mapped and "count" not in explicit: + count = int(mapped["count"]) + if "orientation" in mapped and "orientation" not in explicit: + orientation = mapped["orientation"] + if "no_quality" in mapped and "no_quality" not in explicit: + no_quality = bool(mapped["no_quality"]) + if "no_negative" in mapped and "no_negative" not in explicit: + no_negative = bool(mapped["no_negative"]) + if "rating" in mapped and "rating" not in explicit: + rating = mapped["rating"] if not prompt: console.print("[red]Prompt is required (as argument or in --input JSON)[/red]") raise typer.Exit(1) - from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415 + # ---- Detect model family and enhance prompt/negative ---- + family_defaults: dict[str, Any] = {} + model_family: str | None = None + base_model_str: str | None = None + if model: + try: + with Database() as db: + db.init_schema() + base_model_str = db.get_base_model_by_filename(model) + except Exception: + pass + + model_family = detect_model_family(model, base_model_str) + if model_family: + family_defaults = MODEL_FAMILY_DEFAULTS.get(model_family, {}) + if not json_output: + console.print(f"[dim]Detected model family: {model_family}[/dim]") + + # Build enhanced prompt with quality prefix and LoRA trigger words + prompt_parts: list[str] = [] + + # Add LoRA trigger words if using LoRA + if lora: + try: + with Database() as db: + db.init_schema() + trigger_words = db.get_trigger_words_by_filename(lora) + if trigger_words: + prompt_parts.extend(trigger_words) + if not json_output: + console.print(f"[dim]LoRA trigger words: {', '.join(trigger_words)}[/dim]") + except Exception: + pass + + # Add quality prefix based on model family + if not no_quality and family_defaults.get("quality_prefix"): + prompt_parts.append(family_defaults["quality_prefix"]) + + # Add rating tag based on model family (Pony/Illustrious) + if rating: + from tensors.config import get_rating_tag # noqa: PLC0415 + + rating_tag = get_rating_tag(model_family, rating.lower()) + if rating_tag: + prompt_parts.append(rating_tag) + if not json_output: + console.print(f"[dim]Rating tag: {rating_tag}[/dim]") + elif not json_output: + console.print(f"[dim]Rating '{rating}' not applicable for {model_family or 'unknown'} family[/dim]") + + # Add user prompt + prompt_parts.append(prompt) + enhanced_prompt = ", ".join(prompt_parts) if len(prompt_parts) > 1 else prompt + + # Build enhanced negative prompt + enhanced_negative = negative + if not no_negative and family_defaults.get("negative_prompt"): + family_negative = family_defaults["negative_prompt"] + enhanced_negative = f"{negative}, {family_negative}" if negative else family_negative + + if not json_output and (enhanced_prompt != prompt or enhanced_negative != negative): + if enhanced_prompt != prompt: + truncated = enhanced_prompt[:100] + "..." if len(enhanced_prompt) > 100 else enhanced_prompt # noqa: PLR2004 + console.print(f"[dim]Enhanced prompt: {truncated}[/dim]") + if enhanced_negative != negative: + truncated = enhanced_negative[:80] + "..." if len(enhanced_negative) > 80 else enhanced_negative # noqa: PLR2004 + console.print(f"[dim]Enhanced negative: {truncated}[/dim]") + + # ---- Resolve preset defaults for None params (both remote and local need these) ---- + from tensors.config import resolve_orientation, resolve_remote as do_resolve_remote # noqa: PLC0415 + + # Use already-detected family_defaults from DB lookup above (not filename guessing) + if family_defaults: + res_w, res_h = resolve_orientation(model_family, orientation) + if width is None: + width = res_w + if height is None: + height = res_h + if steps is None: + steps = family_defaults.get("steps", 20) + if cfg is None: + cfg = family_defaults.get("cfg", 7.0) + if sampler is None: + sampler = family_defaults.get("sampler", "euler") + if scheduler is None: + scheduler = family_defaults.get("scheduler", "normal") + if vae is None: + vae = family_defaults.get("vae") + + # Fallback to global defaults when no model family was detected + if width is None: + width = COMFYUI_DEFAULT_WIDTH + if height is None: + height = COMFYUI_DEFAULT_HEIGHT + if steps is None: + steps = COMFYUI_DEFAULT_STEPS + if cfg is None: + cfg = COMFYUI_DEFAULT_CFG + if sampler is None: + sampler = COMFYUI_DEFAULT_SAMPLER + if scheduler is None: + scheduler = COMFYUI_DEFAULT_SCHEDULER + + # ---- Determine base seed ---- + base_seed = seed if seed >= 0 else rng.randint(0, 2**32 - 1) # Resolve remote (explicit flag, or default from config) remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None) + all_results: list[dict[str, Any]] = [] + all_saved: list[Path] = [] + if remote_url: # ---- Remote mode: HTTP call to tensors server ---- if not json_output: @@ -894,14 +1025,14 @@ def generate( # noqa: PLR0915 result = remote_generate( remote or remote_url, - prompt, - negative_prompt=negative, + enhanced_prompt, + negative_prompt=enhanced_negative, model=model, width=width, height=height, steps=steps, cfg=cfg, - seed=seed, + seed=base_seed, sampler=sampler, scheduler=scheduler, vae=vae, @@ -948,69 +1079,188 @@ def generate( # noqa: PLR0915 # ---- Local mode: direct library call ---- from tensors.comfyui import generate_image, get_image # noqa: PLC0415 - actual_seed = seed if seed >= 0 else rng.randint(0, 2**32 - 1) - result_local = generate_image( - prompt=prompt, - negative_prompt=negative, + prompt=enhanced_prompt, + negative_prompt=enhanced_negative, model=model, width=width, height=height, steps=steps, cfg=cfg, - seed=actual_seed, + seed=base_seed, sampler=sampler, scheduler=scheduler, console=console if not json_output else None, lora_name=lora, lora_strength=lora_strength, + batch_size=count, vae=vae, orientation=orientation, ) if not result_local: if json_output: - console.print_json(data={"success": False, "errors": {"generation": "Failed to generate"}}) + all_results.append({"success": False, "index": 0, "errors": {"generation": "Failed to generate"}}) else: console.print("[red]Generation failed[/red]") - raise typer.Exit(1) - - if not result_local.success: + raise typer.Exit(1) + elif not result_local.success: if json_output: - console.print_json(data={"success": False, "errors": result_local.node_errors}) + all_results.append({"success": False, "index": 0, "errors": result_local.node_errors}) else: console.print("[red]Generation failed[/red]") for node_id, errors in result_local.node_errors.items(): console.print(f" [yellow]Node {node_id}:[/yellow] {errors}") - raise typer.Exit(1) + raise typer.Exit(1) + else: + # Save all output images + for i, img_path in enumerate(result_local.images): + saved_path: Path | None = None + if output: + img_data = get_image(str(img_path)) + if img_data: + save_path = ( + output + if count == 1 + else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}" + ) + save_path.write_bytes(img_data) + saved_path = save_path + all_saved.append(save_path) + if not json_output: + console.print(f"[green]Saved:[/green] {save_path}") + elif not json_output: + console.print(f"[yellow]Could not download image: {img_path}[/yellow]") - # Save images - saved_paths: list[Path] = [] - for i, img_path in enumerate(result_local.images): - if output: - img_data = get_image(str(img_path)) - if img_data: - save_path = ( - output if len(result_local.images) == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}" - ) - save_path.write_bytes(img_data) - saved_paths.append(save_path) - if not json_output: - console.print(f"[green]Saved:[/green] {save_path}") + all_results.append( + { + "success": True, + "index": i, + "prompt_id": result_local.prompt_id, + "image": str(img_path), + "saved": str(saved_path) if saved_path else None, + } + ) - if json_output: - console.print_json( - data={ - "success": True, - "prompt_id": result_local.prompt_id, - "images": [str(p) for p in result_local.images], - "saved": [str(p) for p in saved_paths], - } - ) - return + if json_output: + console.print_json( + data={ + "success": all(r.get("success", False) for r in all_results), + "count": len(all_results), + "results": all_results, + } + ) + return - console.print("[bold green]Generation complete![/bold green]") - console.print(f"[dim]Prompt ID: {result_local.prompt_id}[/dim]") + console.print("[bold green]Generation complete![/bold green]") + if count > 1: + successful = sum(1 for r in all_results if r.get("success", False)) + console.print(f"[dim]Generated {successful}/{count} images[/dim]") + if all_saved: + console.print(f"[dim]Saved to: {all_saved[0].parent}/[/dim]") + elif all_results and all_results[0].get("prompt_id"): + console.print(f"[dim]Prompt ID: {all_results[0]['prompt_id']}[/dim]") + + +# ============================================================================= +# Template Dump +# ============================================================================= + + +@app.command() +def template( + model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")], + lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None, + lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8, + orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "square", + rating: Annotated[str | None, typer.Option("--rating", "-R", help="Content rating: safe, questionable, explicit")] = None, + output: Annotated[Path | None, typer.Option("-o", "--output", help="Save template to file")] = None, +) -> None: + """Dump a JSON generation template with resolved defaults for a model. + + Outputs a ready-to-use JSON object with all parameters auto-resolved from the + checkpoint family. Pipe to 'tsr generate --input' or save to a file for reuse. + + Examples: + tsr template -m ponyDiffusionV6XL_v6StartWithThisOne.safetensors + tsr template -m beautifulRealistic_v7.safetensors -O portrait + tsr template -m waiIllustriousSDXL_v160.safetensors -l "Elvira iIlluLoRA.safetensors" + tsr template -m ponyRealism_V22.safetensors -o pony_preset.json + tsr generate --input "$(tsr template -m ponyRealism_V22.safetensors)" "a portrait" + """ + from tensors.config import get_model_generation_defaults, resolve_orientation # noqa: PLC0415 + + # Look up base_model from DB for accurate family detection + base_model_str: str | None = None + try: + with Database() as db: + db.init_schema() + base_model_str = db.get_base_model_by_filename(model) + except Exception: + pass + + family = detect_model_family(model, base_model_str) + defaults = get_model_generation_defaults(model, base_model_str) + res_w, res_h = resolve_orientation(family, orientation) + + # Build template + tpl: dict[str, Any] = { + "prompt": "", + "negative_prompt": defaults.get("negative_prompt", ""), + "model": model, + "width": res_w, + "height": res_h, + "steps": defaults.get("steps"), + "cfg": defaults.get("cfg"), + "sampler": defaults.get("sampler"), + "scheduler": defaults.get("scheduler"), + "vae": defaults.get("vae"), + "orientation": orientation, + "seed": -1, + "count": 1, + } + + # Add quality prefix if the family has one + quality_prefix = defaults.get("quality_prefix", "") + if quality_prefix: + tpl["quality_prefix"] = quality_prefix + + # Add rating tag if specified + if rating: + from tensors.config import get_rating_tag # noqa: PLC0415 + + rating_tag = get_rating_tag(family, rating.lower()) + if rating_tag: + tpl["rating"] = rating + tpl["rating_tag"] = rating_tag + + # Add LoRA info + if lora: + tpl["lora"] = lora + tpl["lora_strength"] = lora_strength + + # Look up trigger words + try: + with Database() as db: + db.init_schema() + trigger_words = db.get_trigger_words_by_filename(lora) + if trigger_words: + tpl["lora_triggers"] = trigger_words + except Exception: + pass + + # Add metadata (not used by generate, but informational) + tpl["_family"] = family or "unknown" + if base_model_str: + tpl["_base_model"] = base_model_str + + json_str = json.dumps(tpl, indent=2) + + if output: + output.write_text(json_str + "\n") + console.print(f"[green]Saved template:[/green] {output}") + else: + console.print(json_str) # ============================================================================= @@ -1231,22 +1481,49 @@ def db_cache( @db_app.command("list") def db_list( + model_type: Annotated[str | None, typer.Option("-t", "--type", help="Filter by model type (Checkpoint, LORA, VAE, etc.)")] = None, + base: Annotated[str | None, typer.Option("-b", "--base", help="Filter by base model (Pony, Illustrious, SDXL 1.0, SD 1.5, etc.)")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: - """List local files with CivitAI info.""" + """List local files with CivitAI info. + + Examples: + tsr db list # All local files + tsr db list -t Checkpoint # Only checkpoints + tsr db list -t LORA # Only LoRAs + tsr db list -t Checkpoint -b Pony # Pony checkpoints only + tsr db list -b "SDXL 1.0" # All SDXL 1.0 models + """ with Database() as db: db.init_schema() files = db.list_local_files() + # Apply filters (case-insensitive substring match) + if model_type: + mt_lower = model_type.lower() + files = [f for f in files if (f.get("model_type") or "").lower() == mt_lower] + if base: + base_lower = base.lower() + files = [f for f in files if base_lower in (f.get("base_model") or "").lower()] + if json_output: console.print_json(data=files) return if not files: - console.print("[yellow]No files in database. Run 'tsr db scan' first.[/yellow]") + console.print("[yellow]No files found. Try 'tsr db scan' or adjust filters.[/yellow]") return - table = Table(title="Local Files", show_header=True, header_style="bold magenta") + title = "Local Files" + if model_type or base: + parts = [] + if model_type: + parts.append(model_type) + if base: + parts.append(base) + title = f"Local Files ({', '.join(parts)})" + + table = Table(title=title, show_header=True, header_style="bold magenta") table.add_column("Path", style="cyan", max_width=50) table.add_column("Model", style="green") table.add_column("Version", style="white") @@ -1257,9 +1534,9 @@ def db_list( path = Path(f["file_path"]).name model = f.get("model_name") or "[dim]unlinked[/dim]" version = f.get("version_name") or "" - model_type = f.get("model_type") or "" - base = f.get("base_model") or "" - table.add_row(path, model, version, model_type, base) + ft = f.get("model_type") or "" + base_model = f.get("base_model") or "" + table.add_row(path, model, version, ft, base_model) console.print(table) @@ -1653,183 +1930,56 @@ def comfy_history( console.print(table) -@comfy_app.command("generate") -def comfy_generate( # noqa: PLR0915 +@comfy_app.command("generate", deprecated=True) +def comfy_generate( prompt: Annotated[str, typer.Argument(help="Positive prompt text")], - url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None, - negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "", model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None, - width: Annotated[int | None, typer.Option("-W", "--width", help="Image width (auto from checkpoint)")] = None, - height: Annotated[int | None, typer.Option("-H", "--height", help="Image height (auto from checkpoint)")] = None, - steps: Annotated[int | None, typer.Option("--steps", help="Sampling steps (auto from checkpoint)")] = None, - cfg: Annotated[float | None, typer.Option("--cfg", help="CFG scale (auto from checkpoint)")] = None, - seed: Annotated[int, typer.Option("--seed", "-s", help="Random seed (-1 for random)")] = -1, - sampler: Annotated[str | None, typer.Option("--sampler", help="Sampler name (auto from checkpoint)")] = None, - scheduler: Annotated[str | None, typer.Option("--scheduler", help="Scheduler name (auto from checkpoint)")] = None, - orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "square", - output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None, - count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1, - lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None, - lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 1.0, - no_quality: Annotated[bool, typer.Option("--no-quality", help="Disable auto quality tags")] = False, - no_negative: Annotated[bool, typer.Option("--no-negative", help="Disable auto negative prompt")] = False, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, + negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "", + width: Annotated[int | None, typer.Option("-W", "--width")] = None, + height: Annotated[int | None, typer.Option("-H", "--height")] = None, + steps: Annotated[int | None, typer.Option("--steps")] = None, + cfg: Annotated[float | None, typer.Option("--cfg")] = None, + seed: Annotated[int, typer.Option("--seed", "-s")] = -1, + sampler: Annotated[str | None, typer.Option("--sampler")] = None, + scheduler: Annotated[str | None, typer.Option("--scheduler")] = None, + orientation: Annotated[str, typer.Option("-O", "--orientation")] = "square", + output: Annotated[Path | None, typer.Option("-o", "--output")] = None, + count: Annotated[int, typer.Option("-c", "--count")] = 1, + lora: Annotated[str | None, typer.Option("-l", "--lora")] = None, + lora_strength: Annotated[float, typer.Option("--lora-strength")] = 0.8, + no_quality: Annotated[bool, typer.Option("--no-quality")] = False, + no_negative: Annotated[bool, typer.Option("--no-negative")] = False, + json_output: Annotated[bool, typer.Option("--json", "-j")] = False, ) -> None: - """Generate an image with a simple text-to-image workflow. - - Examples: - tsr comfy generate "a cat sitting on a windowsill" - tsr comfy generate "portrait photo" -n "blurry, bad quality" --steps 30 - tsr comfy generate "landscape" -m "flux1-dev-fp8.safetensors" -W 1024 -H 768 - tsr comfy generate "cyberpunk city" --count 4 -o batch.png - tsr comfy generate "girl" --lora spumcostyle.safetensors --lora-strength 0.8 - tsr comfy generate "raw prompt" --no-quality --no-negative - """ - import random # noqa: PLC0415 - - from tensors.comfyui import generate_image, get_image # noqa: PLC0415 - - all_results: list[dict[str, Any]] = [] - all_saved: list[Path] = [] - - # Determine base seed for batch - base_seed = seed if seed >= 0 else random.randint(0, 2**32 - 1) - - # Detect model family and apply defaults - family_defaults: dict[str, Any] = {} - model_family: str | None = None - if model: - # Try to get base_model from database - base_model_str: str | None = None - try: - with Database() as db: - db.init_schema() - base_model_str = db.get_base_model_by_filename(model) - except Exception: - pass - - model_family = detect_model_family(model, base_model_str) - if model_family: - family_defaults = MODEL_FAMILY_DEFAULTS.get(model_family, {}) - if not json_output: - console.print(f"[dim]Detected model family: {model_family}[/dim]") - - # Build enhanced prompt with quality prefix and LoRA trigger words - enhanced_prompt = prompt - prompt_parts: list[str] = [] - - # Add LoRA trigger words if using LoRA - if lora: - try: - with Database() as db: - db.init_schema() - trigger_words = db.get_trigger_words_by_filename(lora) - if trigger_words: - prompt_parts.extend(trigger_words) - if not json_output: - console.print(f"[dim]LoRA trigger words: {', '.join(trigger_words)}[/dim]") - except Exception: - pass - - # Add quality prefix based on model family - if not no_quality and family_defaults.get("quality_prefix"): - prompt_parts.append(family_defaults["quality_prefix"]) - - # Add user prompt - prompt_parts.append(prompt) - enhanced_prompt = ", ".join(prompt_parts) if len(prompt_parts) > 1 else prompt - - # Build enhanced negative prompt - enhanced_negative = negative - if not no_negative and family_defaults.get("negative_prompt"): - family_negative = family_defaults["negative_prompt"] - enhanced_negative = f"{negative}, {family_negative}" if negative else family_negative - - if not json_output and (enhanced_prompt != prompt or enhanced_negative != negative): - if enhanced_prompt != prompt: - truncated = enhanced_prompt[:100] + "..." if len(enhanced_prompt) > 100 else enhanced_prompt # noqa: PLR2004 - console.print(f"[dim]Enhanced prompt: {truncated}[/dim]") - if enhanced_negative != negative: - truncated = enhanced_negative[:80] + "..." if len(enhanced_negative) > 80 else enhanced_negative # noqa: PLR2004 - console.print(f"[dim]Enhanced negative: {truncated}[/dim]") - - # Use native ComfyUI batching - single workflow generates all images - result = generate_image( - prompt=enhanced_prompt, - url=url, - negative_prompt=enhanced_negative, + """[Deprecated] Use 'tsr generate' instead. All features have been merged into the top-level command.""" + console.print("[yellow]Warning: 'tsr comfy generate' is deprecated. Use 'tsr generate' instead.[/yellow]") + # Delegate to the unified generate command via context invocation + ctx = typer.Context(generate) + generate( + ctx=ctx, + prompt=prompt, model=model, width=width, height=height, steps=steps, cfg=cfg, - seed=base_seed, + seed=seed, sampler=sampler, scheduler=scheduler, - console=console if not json_output else None, - lora_name=lora, - lora_strength=lora_strength, - batch_size=count, + vae=None, orientation=orientation, + lora=lora, + lora_strength=lora_strength, + negative=negative, + count=count, + no_quality=no_quality, + no_negative=no_negative, + output=output, + remote=None, + json_output=json_output, + json_input=None, ) - if not result: - if json_output: - all_results.append({"success": False, "index": 0, "errors": {"generation": "Failed to generate"}}) - else: - console.print("[red]Generation failed[/red]") - elif not result.success: - if json_output: - all_results.append({"success": False, "index": 0, "errors": result.node_errors}) - else: - console.print("[red]Generation failed[/red]") - for node_id, errors in result.node_errors.items(): - console.print(f" [yellow]Node {node_id}:[/yellow] {errors}") - else: - # Save all output images - for i, img_path in enumerate(result.images): - saved_path: Path | None = None - if output: - img_data = get_image(str(img_path), url=url) - if img_data: - save_path = output if count == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}" - save_path.write_bytes(img_data) - saved_path = save_path - all_saved.append(save_path) - if not json_output: - console.print(f"[green]Saved:[/green] {save_path}") - elif not json_output: - console.print(f"[yellow]Could not download image: {img_path}[/yellow]") - - all_results.append( - { - "success": True, - "index": i, - "prompt_id": result.prompt_id, - "image": str(img_path), - "saved": str(saved_path) if saved_path else None, - } - ) - - if json_output: - console.print_json( - data={ - "success": all(r.get("success", False) for r in all_results), - "count": len(all_results), - "results": all_results, - } - ) - return - - console.print("\n[bold green]Generation complete![/bold green]") - if count > 1: - successful = sum(1 for r in all_results if r.get("success", False)) - console.print(f"[dim]Generated {successful}/{count} images[/dim]") - if all_saved: - console.print(f"[dim]Saved to: {all_saved[0].parent}/[/dim]") - elif all_results and all_results[0].get("prompt_id"): - console.print(f"[dim]Prompt ID: {all_results[0]['prompt_id']}[/dim]") - @comfy_app.command("run") def comfy_run( diff --git a/tensors/config.py b/tensors/config.py index 068d55c..7ea9d12 100644 --- a/tensors/config.py +++ b/tensors/config.py @@ -506,6 +506,42 @@ COMFYUI_DEFAULT_SCHEDULER = "normal" # Model Family Defaults (Quality Tags, Negative Prompts, etc.) # ============================================================================ +# Rating tags per model family — maps (family, rating) to the tag to inject +# Families not listed here have no rating tag system (prompt-driven only) +RATING_TAGS: dict[str, dict[str, str]] = { + "pony": { + "safe": "rating_safe", + "questionable": "rating_questionable", + "explicit": "rating_explicit", + }, + "illustrious": { + "safe": "rating:safe", + "questionable": "rating:questionable", + "explicit": "rating:explicit", + }, +} +# NoobAI uses same tags as Illustrious +RATING_TAGS["noobai"] = RATING_TAGS["illustrious"] + + +def get_rating_tag(family: str | None, rating: str) -> str | None: + """Get the rating tag for a model family and rating level. + + Args: + family: Model family key (e.g. "pony", "illustrious") or None + rating: One of "safe", "questionable", "explicit" + + Returns: + Rating tag string to inject into prompt, or None if family has no rating system + """ + if not family: + return None + tags = RATING_TAGS.get(family) + if not tags: + return None + return tags.get(rating) + + MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "pony": { "quality_prefix": "score_9, score_8_up, score_7_up",