fix null params when no model family detected in tsr generate

Co-Authored-By: marauder-os <marauder@saiden.dev>
This commit is contained in:
2026-04-24 19:02:16 +02:00
parent 55358e7b5a
commit 329b0d849e
3 changed files with 399 additions and 213 deletions
BIN
View File
Binary file not shown.
+346 -196
View File
@@ -21,6 +21,12 @@ from tensors.api import (
search_civitai, search_civitai,
) )
from tensors.config import ( from tensors.config import (
COMFYUI_DEFAULT_CFG,
COMFYUI_DEFAULT_HEIGHT,
COMFYUI_DEFAULT_SAMPLER,
COMFYUI_DEFAULT_SCHEDULER,
COMFYUI_DEFAULT_STEPS,
COMFYUI_DEFAULT_WIDTH,
CONFIG_FILE, CONFIG_FILE,
MODEL_FAMILY_DEFAULTS, MODEL_FAMILY_DEFAULTS,
BaseModel, BaseModel,
@@ -787,6 +793,10 @@ def generate( # noqa: PLR0915
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None, lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8, lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
negative: Annotated[str, typer.Option("-n", "--negative-prompt", help="Negative prompt")] = "", negative: Annotated[str, typer.Option("-n", "--negative-prompt", help="Negative prompt")] = "",
count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1,
rating: Annotated[str | None, typer.Option("--rating", "-R", help="Content rating: safe, questionable, explicit (Pony/Illustrious)")] = None,
no_quality: Annotated[bool, typer.Option("--no-quality", help="Disable auto quality tags")] = False,
no_negative: Annotated[bool, typer.Option("--no-negative", help="Disable auto negative prompt")] = False,
output: Annotated[Path | None, typer.Option("-o", "--output", help="Save path (default: current dir)")] = None, output: Annotated[Path | None, typer.Option("-o", "--output", help="Save path (default: current dir)")] = None,
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
@@ -794,15 +804,19 @@ def generate( # noqa: PLR0915
) -> None: ) -> None:
"""Generate an image using text-to-image. """Generate an image using text-to-image.
Auto-detects optimal sampler, scheduler, CFG, resolution, and VAE from the checkpoint
model family. All auto-detected values can be overridden with explicit flags.
Calls ComfyUI directly when local, or the remote tensors API when --remote is given. Calls ComfyUI directly when local, or the remote tensors API when --remote is given.
Accepts --input with a JSON object whose keys match CLI option names. CLI flags override JSON values. Accepts --input with a JSON object whose keys match CLI option names. CLI flags override JSON values.
Examples: Examples:
tsr generate "a cat on a windowsill" tsr generate "a cat on a windowsill"
tsr generate "portrait photo" -m "flux1-dev-fp8.safetensors" --steps 30 tsr generate "portrait photo" -m ponyDiffusionV6XL_v6.safetensors -O portrait
tsr generate "cyberpunk city" -o output.png tsr generate "cyberpunk city" -o output.png --count 4
tsr generate "landscape" --remote junkpile tsr generate "landscape" --remote junkpile
tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors", "steps": 30}' tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors"}'
tsr generate "raw prompt" --no-quality --no-negative
""" """
import random as rng # noqa: PLC0415 import random as rng # noqa: PLC0415
@@ -877,16 +891,133 @@ def generate( # noqa: PLR0915
output = Path(mapped["output"]) output = Path(mapped["output"])
if "remote" in mapped and "remote" not in explicit: if "remote" in mapped and "remote" not in explicit:
remote = mapped["remote"] remote = mapped["remote"]
if "count" in mapped and "count" not in explicit:
count = int(mapped["count"])
if "orientation" in mapped and "orientation" not in explicit:
orientation = mapped["orientation"]
if "no_quality" in mapped and "no_quality" not in explicit:
no_quality = bool(mapped["no_quality"])
if "no_negative" in mapped and "no_negative" not in explicit:
no_negative = bool(mapped["no_negative"])
if "rating" in mapped and "rating" not in explicit:
rating = mapped["rating"]
if not prompt: if not prompt:
console.print("[red]Prompt is required (as argument or in --input JSON)[/red]") console.print("[red]Prompt is required (as argument or in --input JSON)[/red]")
raise typer.Exit(1) raise typer.Exit(1)
from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415 # ---- Detect model family and enhance prompt/negative ----
family_defaults: dict[str, Any] = {}
model_family: str | None = None
base_model_str: str | None = None
if model:
try:
with Database() as db:
db.init_schema()
base_model_str = db.get_base_model_by_filename(model)
except Exception:
pass
model_family = detect_model_family(model, base_model_str)
if model_family:
family_defaults = MODEL_FAMILY_DEFAULTS.get(model_family, {})
if not json_output:
console.print(f"[dim]Detected model family: {model_family}[/dim]")
# Build enhanced prompt with quality prefix and LoRA trigger words
prompt_parts: list[str] = []
# Add LoRA trigger words if using LoRA
if lora:
try:
with Database() as db:
db.init_schema()
trigger_words = db.get_trigger_words_by_filename(lora)
if trigger_words:
prompt_parts.extend(trigger_words)
if not json_output:
console.print(f"[dim]LoRA trigger words: {', '.join(trigger_words)}[/dim]")
except Exception:
pass
# Add quality prefix based on model family
if not no_quality and family_defaults.get("quality_prefix"):
prompt_parts.append(family_defaults["quality_prefix"])
# Add rating tag based on model family (Pony/Illustrious)
if rating:
from tensors.config import get_rating_tag # noqa: PLC0415
rating_tag = get_rating_tag(model_family, rating.lower())
if rating_tag:
prompt_parts.append(rating_tag)
if not json_output:
console.print(f"[dim]Rating tag: {rating_tag}[/dim]")
elif not json_output:
console.print(f"[dim]Rating '{rating}' not applicable for {model_family or 'unknown'} family[/dim]")
# Add user prompt
prompt_parts.append(prompt)
enhanced_prompt = ", ".join(prompt_parts) if len(prompt_parts) > 1 else prompt
# Build enhanced negative prompt
enhanced_negative = negative
if not no_negative and family_defaults.get("negative_prompt"):
family_negative = family_defaults["negative_prompt"]
enhanced_negative = f"{negative}, {family_negative}" if negative else family_negative
if not json_output and (enhanced_prompt != prompt or enhanced_negative != negative):
if enhanced_prompt != prompt:
truncated = enhanced_prompt[:100] + "..." if len(enhanced_prompt) > 100 else enhanced_prompt # noqa: PLR2004
console.print(f"[dim]Enhanced prompt: {truncated}[/dim]")
if enhanced_negative != negative:
truncated = enhanced_negative[:80] + "..." if len(enhanced_negative) > 80 else enhanced_negative # noqa: PLR2004
console.print(f"[dim]Enhanced negative: {truncated}[/dim]")
# ---- Resolve preset defaults for None params (both remote and local need these) ----
from tensors.config import resolve_orientation, resolve_remote as do_resolve_remote # noqa: PLC0415
# Use already-detected family_defaults from DB lookup above (not filename guessing)
if family_defaults:
res_w, res_h = resolve_orientation(model_family, orientation)
if width is None:
width = res_w
if height is None:
height = res_h
if steps is None:
steps = family_defaults.get("steps", 20)
if cfg is None:
cfg = family_defaults.get("cfg", 7.0)
if sampler is None:
sampler = family_defaults.get("sampler", "euler")
if scheduler is None:
scheduler = family_defaults.get("scheduler", "normal")
if vae is None:
vae = family_defaults.get("vae")
# Fallback to global defaults when no model family was detected
if width is None:
width = COMFYUI_DEFAULT_WIDTH
if height is None:
height = COMFYUI_DEFAULT_HEIGHT
if steps is None:
steps = COMFYUI_DEFAULT_STEPS
if cfg is None:
cfg = COMFYUI_DEFAULT_CFG
if sampler is None:
sampler = COMFYUI_DEFAULT_SAMPLER
if scheduler is None:
scheduler = COMFYUI_DEFAULT_SCHEDULER
# ---- Determine base seed ----
base_seed = seed if seed >= 0 else rng.randint(0, 2**32 - 1)
# Resolve remote (explicit flag, or default from config) # Resolve remote (explicit flag, or default from config)
remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None) remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None)
all_results: list[dict[str, Any]] = []
all_saved: list[Path] = []
if remote_url: if remote_url:
# ---- Remote mode: HTTP call to tensors server ---- # ---- Remote mode: HTTP call to tensors server ----
if not json_output: if not json_output:
@@ -894,14 +1025,14 @@ def generate( # noqa: PLR0915
result = remote_generate( result = remote_generate(
remote or remote_url, remote or remote_url,
prompt, enhanced_prompt,
negative_prompt=negative, negative_prompt=enhanced_negative,
model=model, model=model,
width=width, width=width,
height=height, height=height,
steps=steps, steps=steps,
cfg=cfg, cfg=cfg,
seed=seed, seed=base_seed,
sampler=sampler, sampler=sampler,
scheduler=scheduler, scheduler=scheduler,
vae=vae, vae=vae,
@@ -948,69 +1079,188 @@ def generate( # noqa: PLR0915
# ---- Local mode: direct library call ---- # ---- Local mode: direct library call ----
from tensors.comfyui import generate_image, get_image # noqa: PLC0415 from tensors.comfyui import generate_image, get_image # noqa: PLC0415
actual_seed = seed if seed >= 0 else rng.randint(0, 2**32 - 1)
result_local = generate_image( result_local = generate_image(
prompt=prompt, prompt=enhanced_prompt,
negative_prompt=negative, negative_prompt=enhanced_negative,
model=model, model=model,
width=width, width=width,
height=height, height=height,
steps=steps, steps=steps,
cfg=cfg, cfg=cfg,
seed=actual_seed, seed=base_seed,
sampler=sampler, sampler=sampler,
scheduler=scheduler, scheduler=scheduler,
console=console if not json_output else None, console=console if not json_output else None,
lora_name=lora, lora_name=lora,
lora_strength=lora_strength, lora_strength=lora_strength,
batch_size=count,
vae=vae, vae=vae,
orientation=orientation, orientation=orientation,
) )
if not result_local: if not result_local:
if json_output: if json_output:
console.print_json(data={"success": False, "errors": {"generation": "Failed to generate"}}) all_results.append({"success": False, "index": 0, "errors": {"generation": "Failed to generate"}})
else: else:
console.print("[red]Generation failed[/red]") console.print("[red]Generation failed[/red]")
raise typer.Exit(1) raise typer.Exit(1)
elif not result_local.success:
if not result_local.success:
if json_output: if json_output:
console.print_json(data={"success": False, "errors": result_local.node_errors}) all_results.append({"success": False, "index": 0, "errors": result_local.node_errors})
else: else:
console.print("[red]Generation failed[/red]") console.print("[red]Generation failed[/red]")
for node_id, errors in result_local.node_errors.items(): for node_id, errors in result_local.node_errors.items():
console.print(f" [yellow]Node {node_id}:[/yellow] {errors}") console.print(f" [yellow]Node {node_id}:[/yellow] {errors}")
raise typer.Exit(1) raise typer.Exit(1)
else:
# Save images # Save all output images
saved_paths: list[Path] = []
for i, img_path in enumerate(result_local.images): for i, img_path in enumerate(result_local.images):
saved_path: Path | None = None
if output: if output:
img_data = get_image(str(img_path)) img_data = get_image(str(img_path))
if img_data: if img_data:
save_path = ( save_path = (
output if len(result_local.images) == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}" output
if count == 1
else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
) )
save_path.write_bytes(img_data) save_path.write_bytes(img_data)
saved_paths.append(save_path) saved_path = save_path
all_saved.append(save_path)
if not json_output: if not json_output:
console.print(f"[green]Saved:[/green] {save_path}") console.print(f"[green]Saved:[/green] {save_path}")
elif not json_output:
console.print(f"[yellow]Could not download image: {img_path}[/yellow]")
all_results.append(
{
"success": True,
"index": i,
"prompt_id": result_local.prompt_id,
"image": str(img_path),
"saved": str(saved_path) if saved_path else None,
}
)
if json_output: if json_output:
console.print_json( console.print_json(
data={ data={
"success": True, "success": all(r.get("success", False) for r in all_results),
"prompt_id": result_local.prompt_id, "count": len(all_results),
"images": [str(p) for p in result_local.images], "results": all_results,
"saved": [str(p) for p in saved_paths],
} }
) )
return return
console.print("[bold green]Generation complete![/bold green]") console.print("[bold green]Generation complete![/bold green]")
console.print(f"[dim]Prompt ID: {result_local.prompt_id}[/dim]") if count > 1:
successful = sum(1 for r in all_results if r.get("success", False))
console.print(f"[dim]Generated {successful}/{count} images[/dim]")
if all_saved:
console.print(f"[dim]Saved to: {all_saved[0].parent}/[/dim]")
elif all_results and all_results[0].get("prompt_id"):
console.print(f"[dim]Prompt ID: {all_results[0]['prompt_id']}[/dim]")
# =============================================================================
# Template Dump
# =============================================================================
@app.command()
def template(
model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")],
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "square",
rating: Annotated[str | None, typer.Option("--rating", "-R", help="Content rating: safe, questionable, explicit")] = None,
output: Annotated[Path | None, typer.Option("-o", "--output", help="Save template to file")] = None,
) -> None:
"""Dump a JSON generation template with resolved defaults for a model.
Outputs a ready-to-use JSON object with all parameters auto-resolved from the
checkpoint family. Pipe to 'tsr generate --input' or save to a file for reuse.
Examples:
tsr template -m ponyDiffusionV6XL_v6StartWithThisOne.safetensors
tsr template -m beautifulRealistic_v7.safetensors -O portrait
tsr template -m waiIllustriousSDXL_v160.safetensors -l "Elvira iIlluLoRA.safetensors"
tsr template -m ponyRealism_V22.safetensors -o pony_preset.json
tsr generate --input "$(tsr template -m ponyRealism_V22.safetensors)" "a portrait"
"""
from tensors.config import get_model_generation_defaults, resolve_orientation # noqa: PLC0415
# Look up base_model from DB for accurate family detection
base_model_str: str | None = None
try:
with Database() as db:
db.init_schema()
base_model_str = db.get_base_model_by_filename(model)
except Exception:
pass
family = detect_model_family(model, base_model_str)
defaults = get_model_generation_defaults(model, base_model_str)
res_w, res_h = resolve_orientation(family, orientation)
# Build template
tpl: dict[str, Any] = {
"prompt": "",
"negative_prompt": defaults.get("negative_prompt", ""),
"model": model,
"width": res_w,
"height": res_h,
"steps": defaults.get("steps"),
"cfg": defaults.get("cfg"),
"sampler": defaults.get("sampler"),
"scheduler": defaults.get("scheduler"),
"vae": defaults.get("vae"),
"orientation": orientation,
"seed": -1,
"count": 1,
}
# Add quality prefix if the family has one
quality_prefix = defaults.get("quality_prefix", "")
if quality_prefix:
tpl["quality_prefix"] = quality_prefix
# Add rating tag if specified
if rating:
from tensors.config import get_rating_tag # noqa: PLC0415
rating_tag = get_rating_tag(family, rating.lower())
if rating_tag:
tpl["rating"] = rating
tpl["rating_tag"] = rating_tag
# Add LoRA info
if lora:
tpl["lora"] = lora
tpl["lora_strength"] = lora_strength
# Look up trigger words
try:
with Database() as db:
db.init_schema()
trigger_words = db.get_trigger_words_by_filename(lora)
if trigger_words:
tpl["lora_triggers"] = trigger_words
except Exception:
pass
# Add metadata (not used by generate, but informational)
tpl["_family"] = family or "unknown"
if base_model_str:
tpl["_base_model"] = base_model_str
json_str = json.dumps(tpl, indent=2)
if output:
output.write_text(json_str + "\n")
console.print(f"[green]Saved template:[/green] {output}")
else:
console.print(json_str)
# ============================================================================= # =============================================================================
@@ -1231,22 +1481,49 @@ def db_cache(
@db_app.command("list") @db_app.command("list")
def db_list( def db_list(
model_type: Annotated[str | None, typer.Option("-t", "--type", help="Filter by model type (Checkpoint, LORA, VAE, etc.)")] = None,
base: Annotated[str | None, typer.Option("-b", "--base", help="Filter by base model (Pony, Illustrious, SDXL 1.0, SD 1.5, etc.)")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None: ) -> None:
"""List local files with CivitAI info.""" """List local files with CivitAI info.
Examples:
tsr db list # All local files
tsr db list -t Checkpoint # Only checkpoints
tsr db list -t LORA # Only LoRAs
tsr db list -t Checkpoint -b Pony # Pony checkpoints only
tsr db list -b "SDXL 1.0" # All SDXL 1.0 models
"""
with Database() as db: with Database() as db:
db.init_schema() db.init_schema()
files = db.list_local_files() files = db.list_local_files()
# Apply filters (case-insensitive substring match)
if model_type:
mt_lower = model_type.lower()
files = [f for f in files if (f.get("model_type") or "").lower() == mt_lower]
if base:
base_lower = base.lower()
files = [f for f in files if base_lower in (f.get("base_model") or "").lower()]
if json_output: if json_output:
console.print_json(data=files) console.print_json(data=files)
return return
if not files: if not files:
console.print("[yellow]No files in database. Run 'tsr db scan' first.[/yellow]") console.print("[yellow]No files found. Try 'tsr db scan' or adjust filters.[/yellow]")
return return
table = Table(title="Local Files", show_header=True, header_style="bold magenta") title = "Local Files"
if model_type or base:
parts = []
if model_type:
parts.append(model_type)
if base:
parts.append(base)
title = f"Local Files ({', '.join(parts)})"
table = Table(title=title, show_header=True, header_style="bold magenta")
table.add_column("Path", style="cyan", max_width=50) table.add_column("Path", style="cyan", max_width=50)
table.add_column("Model", style="green") table.add_column("Model", style="green")
table.add_column("Version", style="white") table.add_column("Version", style="white")
@@ -1257,9 +1534,9 @@ def db_list(
path = Path(f["file_path"]).name path = Path(f["file_path"]).name
model = f.get("model_name") or "[dim]unlinked[/dim]" model = f.get("model_name") or "[dim]unlinked[/dim]"
version = f.get("version_name") or "" version = f.get("version_name") or ""
model_type = f.get("model_type") or "" ft = f.get("model_type") or ""
base = f.get("base_model") or "" base_model = f.get("base_model") or ""
table.add_row(path, model, version, model_type, base) table.add_row(path, model, version, ft, base_model)
console.print(table) console.print(table)
@@ -1653,183 +1930,56 @@ def comfy_history(
console.print(table) console.print(table)
@comfy_app.command("generate") @comfy_app.command("generate", deprecated=True)
def comfy_generate( # noqa: PLR0915 def comfy_generate(
prompt: Annotated[str, typer.Argument(help="Positive prompt text")], prompt: Annotated[str, typer.Argument(help="Positive prompt text")],
url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None,
negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "",
model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None, model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None,
width: Annotated[int | None, typer.Option("-W", "--width", help="Image width (auto from checkpoint)")] = None, negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "",
height: Annotated[int | None, typer.Option("-H", "--height", help="Image height (auto from checkpoint)")] = None, width: Annotated[int | None, typer.Option("-W", "--width")] = None,
steps: Annotated[int | None, typer.Option("--steps", help="Sampling steps (auto from checkpoint)")] = None, height: Annotated[int | None, typer.Option("-H", "--height")] = None,
cfg: Annotated[float | None, typer.Option("--cfg", help="CFG scale (auto from checkpoint)")] = None, steps: Annotated[int | None, typer.Option("--steps")] = None,
seed: Annotated[int, typer.Option("--seed", "-s", help="Random seed (-1 for random)")] = -1, cfg: Annotated[float | None, typer.Option("--cfg")] = None,
sampler: Annotated[str | None, typer.Option("--sampler", help="Sampler name (auto from checkpoint)")] = None, seed: Annotated[int, typer.Option("--seed", "-s")] = -1,
scheduler: Annotated[str | None, typer.Option("--scheduler", help="Scheduler name (auto from checkpoint)")] = None, sampler: Annotated[str | None, typer.Option("--sampler")] = None,
orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "square", scheduler: Annotated[str | None, typer.Option("--scheduler")] = None,
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None, orientation: Annotated[str, typer.Option("-O", "--orientation")] = "square",
count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1, output: Annotated[Path | None, typer.Option("-o", "--output")] = None,
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None, count: Annotated[int, typer.Option("-c", "--count")] = 1,
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 1.0, lora: Annotated[str | None, typer.Option("-l", "--lora")] = None,
no_quality: Annotated[bool, typer.Option("--no-quality", help="Disable auto quality tags")] = False, lora_strength: Annotated[float, typer.Option("--lora-strength")] = 0.8,
no_negative: Annotated[bool, typer.Option("--no-negative", help="Disable auto negative prompt")] = False, no_quality: Annotated[bool, typer.Option("--no-quality")] = False,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, no_negative: Annotated[bool, typer.Option("--no-negative")] = False,
json_output: Annotated[bool, typer.Option("--json", "-j")] = False,
) -> None: ) -> None:
"""Generate an image with a simple text-to-image workflow. """[Deprecated] Use 'tsr generate' instead. All features have been merged into the top-level command."""
console.print("[yellow]Warning: 'tsr comfy generate' is deprecated. Use 'tsr generate' instead.[/yellow]")
Examples: # Delegate to the unified generate command via context invocation
tsr comfy generate "a cat sitting on a windowsill" ctx = typer.Context(generate)
tsr comfy generate "portrait photo" -n "blurry, bad quality" --steps 30 generate(
tsr comfy generate "landscape" -m "flux1-dev-fp8.safetensors" -W 1024 -H 768 ctx=ctx,
tsr comfy generate "cyberpunk city" --count 4 -o batch.png prompt=prompt,
tsr comfy generate "girl" --lora spumcostyle.safetensors --lora-strength 0.8
tsr comfy generate "raw prompt" --no-quality --no-negative
"""
import random # noqa: PLC0415
from tensors.comfyui import generate_image, get_image # noqa: PLC0415
all_results: list[dict[str, Any]] = []
all_saved: list[Path] = []
# Determine base seed for batch
base_seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
# Detect model family and apply defaults
family_defaults: dict[str, Any] = {}
model_family: str | None = None
if model:
# Try to get base_model from database
base_model_str: str | None = None
try:
with Database() as db:
db.init_schema()
base_model_str = db.get_base_model_by_filename(model)
except Exception:
pass
model_family = detect_model_family(model, base_model_str)
if model_family:
family_defaults = MODEL_FAMILY_DEFAULTS.get(model_family, {})
if not json_output:
console.print(f"[dim]Detected model family: {model_family}[/dim]")
# Build enhanced prompt with quality prefix and LoRA trigger words
enhanced_prompt = prompt
prompt_parts: list[str] = []
# Add LoRA trigger words if using LoRA
if lora:
try:
with Database() as db:
db.init_schema()
trigger_words = db.get_trigger_words_by_filename(lora)
if trigger_words:
prompt_parts.extend(trigger_words)
if not json_output:
console.print(f"[dim]LoRA trigger words: {', '.join(trigger_words)}[/dim]")
except Exception:
pass
# Add quality prefix based on model family
if not no_quality and family_defaults.get("quality_prefix"):
prompt_parts.append(family_defaults["quality_prefix"])
# Add user prompt
prompt_parts.append(prompt)
enhanced_prompt = ", ".join(prompt_parts) if len(prompt_parts) > 1 else prompt
# Build enhanced negative prompt
enhanced_negative = negative
if not no_negative and family_defaults.get("negative_prompt"):
family_negative = family_defaults["negative_prompt"]
enhanced_negative = f"{negative}, {family_negative}" if negative else family_negative
if not json_output and (enhanced_prompt != prompt or enhanced_negative != negative):
if enhanced_prompt != prompt:
truncated = enhanced_prompt[:100] + "..." if len(enhanced_prompt) > 100 else enhanced_prompt # noqa: PLR2004
console.print(f"[dim]Enhanced prompt: {truncated}[/dim]")
if enhanced_negative != negative:
truncated = enhanced_negative[:80] + "..." if len(enhanced_negative) > 80 else enhanced_negative # noqa: PLR2004
console.print(f"[dim]Enhanced negative: {truncated}[/dim]")
# Use native ComfyUI batching - single workflow generates all images
result = generate_image(
prompt=enhanced_prompt,
url=url,
negative_prompt=enhanced_negative,
model=model, model=model,
width=width, width=width,
height=height, height=height,
steps=steps, steps=steps,
cfg=cfg, cfg=cfg,
seed=base_seed, seed=seed,
sampler=sampler, sampler=sampler,
scheduler=scheduler, scheduler=scheduler,
console=console if not json_output else None, vae=None,
lora_name=lora,
lora_strength=lora_strength,
batch_size=count,
orientation=orientation, orientation=orientation,
lora=lora,
lora_strength=lora_strength,
negative=negative,
count=count,
no_quality=no_quality,
no_negative=no_negative,
output=output,
remote=None,
json_output=json_output,
json_input=None,
) )
if not result:
if json_output:
all_results.append({"success": False, "index": 0, "errors": {"generation": "Failed to generate"}})
else:
console.print("[red]Generation failed[/red]")
elif not result.success:
if json_output:
all_results.append({"success": False, "index": 0, "errors": result.node_errors})
else:
console.print("[red]Generation failed[/red]")
for node_id, errors in result.node_errors.items():
console.print(f" [yellow]Node {node_id}:[/yellow] {errors}")
else:
# Save all output images
for i, img_path in enumerate(result.images):
saved_path: Path | None = None
if output:
img_data = get_image(str(img_path), url=url)
if img_data:
save_path = output if count == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
save_path.write_bytes(img_data)
saved_path = save_path
all_saved.append(save_path)
if not json_output:
console.print(f"[green]Saved:[/green] {save_path}")
elif not json_output:
console.print(f"[yellow]Could not download image: {img_path}[/yellow]")
all_results.append(
{
"success": True,
"index": i,
"prompt_id": result.prompt_id,
"image": str(img_path),
"saved": str(saved_path) if saved_path else None,
}
)
if json_output:
console.print_json(
data={
"success": all(r.get("success", False) for r in all_results),
"count": len(all_results),
"results": all_results,
}
)
return
console.print("\n[bold green]Generation complete![/bold green]")
if count > 1:
successful = sum(1 for r in all_results if r.get("success", False))
console.print(f"[dim]Generated {successful}/{count} images[/dim]")
if all_saved:
console.print(f"[dim]Saved to: {all_saved[0].parent}/[/dim]")
elif all_results and all_results[0].get("prompt_id"):
console.print(f"[dim]Prompt ID: {all_results[0]['prompt_id']}[/dim]")
@comfy_app.command("run") @comfy_app.command("run")
def comfy_run( def comfy_run(
+36
View File
@@ -506,6 +506,42 @@ COMFYUI_DEFAULT_SCHEDULER = "normal"
# Model Family Defaults (Quality Tags, Negative Prompts, etc.) # Model Family Defaults (Quality Tags, Negative Prompts, etc.)
# ============================================================================ # ============================================================================
# Rating tags per model family — maps (family, rating) to the tag to inject
# Families not listed here have no rating tag system (prompt-driven only)
RATING_TAGS: dict[str, dict[str, str]] = {
"pony": {
"safe": "rating_safe",
"questionable": "rating_questionable",
"explicit": "rating_explicit",
},
"illustrious": {
"safe": "rating:safe",
"questionable": "rating:questionable",
"explicit": "rating:explicit",
},
}
# NoobAI uses same tags as Illustrious
RATING_TAGS["noobai"] = RATING_TAGS["illustrious"]
def get_rating_tag(family: str | None, rating: str) -> str | None:
"""Get the rating tag for a model family and rating level.
Args:
family: Model family key (e.g. "pony", "illustrious") or None
rating: One of "safe", "questionable", "explicit"
Returns:
Rating tag string to inject into prompt, or None if family has no rating system
"""
if not family:
return None
tags = RATING_TAGS.get(family)
if not tags:
return None
return tags.get(rating)
MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"pony": { "pony": {
"quality_prefix": "score_9, score_8_up, score_7_up", "quality_prefix": "score_9, score_8_up, score_7_up",