fix null params when no model family detected in tsr generate
Co-Authored-By: marauder-os <marauder@saiden.dev>
This commit is contained in:
+363
-213
@@ -21,6 +21,12 @@ from tensors.api import (
|
||||
search_civitai,
|
||||
)
|
||||
from tensors.config import (
|
||||
COMFYUI_DEFAULT_CFG,
|
||||
COMFYUI_DEFAULT_HEIGHT,
|
||||
COMFYUI_DEFAULT_SAMPLER,
|
||||
COMFYUI_DEFAULT_SCHEDULER,
|
||||
COMFYUI_DEFAULT_STEPS,
|
||||
COMFYUI_DEFAULT_WIDTH,
|
||||
CONFIG_FILE,
|
||||
MODEL_FAMILY_DEFAULTS,
|
||||
BaseModel,
|
||||
@@ -787,6 +793,10 @@ def generate( # noqa: PLR0915
|
||||
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
||||
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
|
||||
negative: Annotated[str, typer.Option("-n", "--negative-prompt", help="Negative prompt")] = "",
|
||||
count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1,
|
||||
rating: Annotated[str | None, typer.Option("--rating", "-R", help="Content rating: safe, questionable, explicit (Pony/Illustrious)")] = None,
|
||||
no_quality: Annotated[bool, typer.Option("--no-quality", help="Disable auto quality tags")] = False,
|
||||
no_negative: Annotated[bool, typer.Option("--no-negative", help="Disable auto negative prompt")] = False,
|
||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Save path (default: current dir)")] = None,
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
@@ -794,15 +804,19 @@ def generate( # noqa: PLR0915
|
||||
) -> None:
|
||||
"""Generate an image using text-to-image.
|
||||
|
||||
Auto-detects optimal sampler, scheduler, CFG, resolution, and VAE from the checkpoint
|
||||
model family. All auto-detected values can be overridden with explicit flags.
|
||||
|
||||
Calls ComfyUI directly when local, or the remote tensors API when --remote is given.
|
||||
Accepts --input with a JSON object whose keys match CLI option names. CLI flags override JSON values.
|
||||
|
||||
Examples:
|
||||
tsr generate "a cat on a windowsill"
|
||||
tsr generate "portrait photo" -m "flux1-dev-fp8.safetensors" --steps 30
|
||||
tsr generate "cyberpunk city" -o output.png
|
||||
tsr generate "portrait photo" -m ponyDiffusionV6XL_v6.safetensors -O portrait
|
||||
tsr generate "cyberpunk city" -o output.png --count 4
|
||||
tsr generate "landscape" --remote junkpile
|
||||
tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors", "steps": 30}'
|
||||
tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors"}'
|
||||
tsr generate "raw prompt" --no-quality --no-negative
|
||||
"""
|
||||
import random as rng # noqa: PLC0415
|
||||
|
||||
@@ -877,16 +891,133 @@ def generate( # noqa: PLR0915
|
||||
output = Path(mapped["output"])
|
||||
if "remote" in mapped and "remote" not in explicit:
|
||||
remote = mapped["remote"]
|
||||
if "count" in mapped and "count" not in explicit:
|
||||
count = int(mapped["count"])
|
||||
if "orientation" in mapped and "orientation" not in explicit:
|
||||
orientation = mapped["orientation"]
|
||||
if "no_quality" in mapped and "no_quality" not in explicit:
|
||||
no_quality = bool(mapped["no_quality"])
|
||||
if "no_negative" in mapped and "no_negative" not in explicit:
|
||||
no_negative = bool(mapped["no_negative"])
|
||||
if "rating" in mapped and "rating" not in explicit:
|
||||
rating = mapped["rating"]
|
||||
|
||||
if not prompt:
|
||||
console.print("[red]Prompt is required (as argument or in --input JSON)[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415
|
||||
# ---- Detect model family and enhance prompt/negative ----
|
||||
family_defaults: dict[str, Any] = {}
|
||||
model_family: str | None = None
|
||||
base_model_str: str | None = None
|
||||
if model:
|
||||
try:
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
base_model_str = db.get_base_model_by_filename(model)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
model_family = detect_model_family(model, base_model_str)
|
||||
if model_family:
|
||||
family_defaults = MODEL_FAMILY_DEFAULTS.get(model_family, {})
|
||||
if not json_output:
|
||||
console.print(f"[dim]Detected model family: {model_family}[/dim]")
|
||||
|
||||
# Build enhanced prompt with quality prefix and LoRA trigger words
|
||||
prompt_parts: list[str] = []
|
||||
|
||||
# Add LoRA trigger words if using LoRA
|
||||
if lora:
|
||||
try:
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
trigger_words = db.get_trigger_words_by_filename(lora)
|
||||
if trigger_words:
|
||||
prompt_parts.extend(trigger_words)
|
||||
if not json_output:
|
||||
console.print(f"[dim]LoRA trigger words: {', '.join(trigger_words)}[/dim]")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Add quality prefix based on model family
|
||||
if not no_quality and family_defaults.get("quality_prefix"):
|
||||
prompt_parts.append(family_defaults["quality_prefix"])
|
||||
|
||||
# Add rating tag based on model family (Pony/Illustrious)
|
||||
if rating:
|
||||
from tensors.config import get_rating_tag # noqa: PLC0415
|
||||
|
||||
rating_tag = get_rating_tag(model_family, rating.lower())
|
||||
if rating_tag:
|
||||
prompt_parts.append(rating_tag)
|
||||
if not json_output:
|
||||
console.print(f"[dim]Rating tag: {rating_tag}[/dim]")
|
||||
elif not json_output:
|
||||
console.print(f"[dim]Rating '{rating}' not applicable for {model_family or 'unknown'} family[/dim]")
|
||||
|
||||
# Add user prompt
|
||||
prompt_parts.append(prompt)
|
||||
enhanced_prompt = ", ".join(prompt_parts) if len(prompt_parts) > 1 else prompt
|
||||
|
||||
# Build enhanced negative prompt
|
||||
enhanced_negative = negative
|
||||
if not no_negative and family_defaults.get("negative_prompt"):
|
||||
family_negative = family_defaults["negative_prompt"]
|
||||
enhanced_negative = f"{negative}, {family_negative}" if negative else family_negative
|
||||
|
||||
if not json_output and (enhanced_prompt != prompt or enhanced_negative != negative):
|
||||
if enhanced_prompt != prompt:
|
||||
truncated = enhanced_prompt[:100] + "..." if len(enhanced_prompt) > 100 else enhanced_prompt # noqa: PLR2004
|
||||
console.print(f"[dim]Enhanced prompt: {truncated}[/dim]")
|
||||
if enhanced_negative != negative:
|
||||
truncated = enhanced_negative[:80] + "..." if len(enhanced_negative) > 80 else enhanced_negative # noqa: PLR2004
|
||||
console.print(f"[dim]Enhanced negative: {truncated}[/dim]")
|
||||
|
||||
# ---- Resolve preset defaults for None params (both remote and local need these) ----
|
||||
from tensors.config import resolve_orientation, resolve_remote as do_resolve_remote # noqa: PLC0415
|
||||
|
||||
# Use already-detected family_defaults from DB lookup above (not filename guessing)
|
||||
if family_defaults:
|
||||
res_w, res_h = resolve_orientation(model_family, orientation)
|
||||
if width is None:
|
||||
width = res_w
|
||||
if height is None:
|
||||
height = res_h
|
||||
if steps is None:
|
||||
steps = family_defaults.get("steps", 20)
|
||||
if cfg is None:
|
||||
cfg = family_defaults.get("cfg", 7.0)
|
||||
if sampler is None:
|
||||
sampler = family_defaults.get("sampler", "euler")
|
||||
if scheduler is None:
|
||||
scheduler = family_defaults.get("scheduler", "normal")
|
||||
if vae is None:
|
||||
vae = family_defaults.get("vae")
|
||||
|
||||
# Fallback to global defaults when no model family was detected
|
||||
if width is None:
|
||||
width = COMFYUI_DEFAULT_WIDTH
|
||||
if height is None:
|
||||
height = COMFYUI_DEFAULT_HEIGHT
|
||||
if steps is None:
|
||||
steps = COMFYUI_DEFAULT_STEPS
|
||||
if cfg is None:
|
||||
cfg = COMFYUI_DEFAULT_CFG
|
||||
if sampler is None:
|
||||
sampler = COMFYUI_DEFAULT_SAMPLER
|
||||
if scheduler is None:
|
||||
scheduler = COMFYUI_DEFAULT_SCHEDULER
|
||||
|
||||
# ---- Determine base seed ----
|
||||
base_seed = seed if seed >= 0 else rng.randint(0, 2**32 - 1)
|
||||
|
||||
# Resolve remote (explicit flag, or default from config)
|
||||
remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None)
|
||||
|
||||
all_results: list[dict[str, Any]] = []
|
||||
all_saved: list[Path] = []
|
||||
|
||||
if remote_url:
|
||||
# ---- Remote mode: HTTP call to tensors server ----
|
||||
if not json_output:
|
||||
@@ -894,14 +1025,14 @@ def generate( # noqa: PLR0915
|
||||
|
||||
result = remote_generate(
|
||||
remote or remote_url,
|
||||
prompt,
|
||||
negative_prompt=negative,
|
||||
enhanced_prompt,
|
||||
negative_prompt=enhanced_negative,
|
||||
model=model,
|
||||
width=width,
|
||||
height=height,
|
||||
steps=steps,
|
||||
cfg=cfg,
|
||||
seed=seed,
|
||||
seed=base_seed,
|
||||
sampler=sampler,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
@@ -948,69 +1079,188 @@ def generate( # noqa: PLR0915
|
||||
# ---- Local mode: direct library call ----
|
||||
from tensors.comfyui import generate_image, get_image # noqa: PLC0415
|
||||
|
||||
actual_seed = seed if seed >= 0 else rng.randint(0, 2**32 - 1)
|
||||
|
||||
result_local = generate_image(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative,
|
||||
prompt=enhanced_prompt,
|
||||
negative_prompt=enhanced_negative,
|
||||
model=model,
|
||||
width=width,
|
||||
height=height,
|
||||
steps=steps,
|
||||
cfg=cfg,
|
||||
seed=actual_seed,
|
||||
seed=base_seed,
|
||||
sampler=sampler,
|
||||
scheduler=scheduler,
|
||||
console=console if not json_output else None,
|
||||
lora_name=lora,
|
||||
lora_strength=lora_strength,
|
||||
batch_size=count,
|
||||
vae=vae,
|
||||
orientation=orientation,
|
||||
)
|
||||
|
||||
if not result_local:
|
||||
if json_output:
|
||||
console.print_json(data={"success": False, "errors": {"generation": "Failed to generate"}})
|
||||
all_results.append({"success": False, "index": 0, "errors": {"generation": "Failed to generate"}})
|
||||
else:
|
||||
console.print("[red]Generation failed[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if not result_local.success:
|
||||
raise typer.Exit(1)
|
||||
elif not result_local.success:
|
||||
if json_output:
|
||||
console.print_json(data={"success": False, "errors": result_local.node_errors})
|
||||
all_results.append({"success": False, "index": 0, "errors": result_local.node_errors})
|
||||
else:
|
||||
console.print("[red]Generation failed[/red]")
|
||||
for node_id, errors in result_local.node_errors.items():
|
||||
console.print(f" [yellow]Node {node_id}:[/yellow] {errors}")
|
||||
raise typer.Exit(1)
|
||||
raise typer.Exit(1)
|
||||
else:
|
||||
# Save all output images
|
||||
for i, img_path in enumerate(result_local.images):
|
||||
saved_path: Path | None = None
|
||||
if output:
|
||||
img_data = get_image(str(img_path))
|
||||
if img_data:
|
||||
save_path = (
|
||||
output
|
||||
if count == 1
|
||||
else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
|
||||
)
|
||||
save_path.write_bytes(img_data)
|
||||
saved_path = save_path
|
||||
all_saved.append(save_path)
|
||||
if not json_output:
|
||||
console.print(f"[green]Saved:[/green] {save_path}")
|
||||
elif not json_output:
|
||||
console.print(f"[yellow]Could not download image: {img_path}[/yellow]")
|
||||
|
||||
# Save images
|
||||
saved_paths: list[Path] = []
|
||||
for i, img_path in enumerate(result_local.images):
|
||||
if output:
|
||||
img_data = get_image(str(img_path))
|
||||
if img_data:
|
||||
save_path = (
|
||||
output if len(result_local.images) == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
|
||||
)
|
||||
save_path.write_bytes(img_data)
|
||||
saved_paths.append(save_path)
|
||||
if not json_output:
|
||||
console.print(f"[green]Saved:[/green] {save_path}")
|
||||
all_results.append(
|
||||
{
|
||||
"success": True,
|
||||
"index": i,
|
||||
"prompt_id": result_local.prompt_id,
|
||||
"image": str(img_path),
|
||||
"saved": str(saved_path) if saved_path else None,
|
||||
}
|
||||
)
|
||||
|
||||
if json_output:
|
||||
console.print_json(
|
||||
data={
|
||||
"success": True,
|
||||
"prompt_id": result_local.prompt_id,
|
||||
"images": [str(p) for p in result_local.images],
|
||||
"saved": [str(p) for p in saved_paths],
|
||||
}
|
||||
)
|
||||
return
|
||||
if json_output:
|
||||
console.print_json(
|
||||
data={
|
||||
"success": all(r.get("success", False) for r in all_results),
|
||||
"count": len(all_results),
|
||||
"results": all_results,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
console.print("[bold green]Generation complete![/bold green]")
|
||||
console.print(f"[dim]Prompt ID: {result_local.prompt_id}[/dim]")
|
||||
console.print("[bold green]Generation complete![/bold green]")
|
||||
if count > 1:
|
||||
successful = sum(1 for r in all_results if r.get("success", False))
|
||||
console.print(f"[dim]Generated {successful}/{count} images[/dim]")
|
||||
if all_saved:
|
||||
console.print(f"[dim]Saved to: {all_saved[0].parent}/[/dim]")
|
||||
elif all_results and all_results[0].get("prompt_id"):
|
||||
console.print(f"[dim]Prompt ID: {all_results[0]['prompt_id']}[/dim]")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Template Dump
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@app.command()
|
||||
def template(
|
||||
model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")],
|
||||
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
||||
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
|
||||
orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "square",
|
||||
rating: Annotated[str | None, typer.Option("--rating", "-R", help="Content rating: safe, questionable, explicit")] = None,
|
||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Save template to file")] = None,
|
||||
) -> None:
|
||||
"""Dump a JSON generation template with resolved defaults for a model.
|
||||
|
||||
Outputs a ready-to-use JSON object with all parameters auto-resolved from the
|
||||
checkpoint family. Pipe to 'tsr generate --input' or save to a file for reuse.
|
||||
|
||||
Examples:
|
||||
tsr template -m ponyDiffusionV6XL_v6StartWithThisOne.safetensors
|
||||
tsr template -m beautifulRealistic_v7.safetensors -O portrait
|
||||
tsr template -m waiIllustriousSDXL_v160.safetensors -l "Elvira iIlluLoRA.safetensors"
|
||||
tsr template -m ponyRealism_V22.safetensors -o pony_preset.json
|
||||
tsr generate --input "$(tsr template -m ponyRealism_V22.safetensors)" "a portrait"
|
||||
"""
|
||||
from tensors.config import get_model_generation_defaults, resolve_orientation # noqa: PLC0415
|
||||
|
||||
# Look up base_model from DB for accurate family detection
|
||||
base_model_str: str | None = None
|
||||
try:
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
base_model_str = db.get_base_model_by_filename(model)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
family = detect_model_family(model, base_model_str)
|
||||
defaults = get_model_generation_defaults(model, base_model_str)
|
||||
res_w, res_h = resolve_orientation(family, orientation)
|
||||
|
||||
# Build template
|
||||
tpl: dict[str, Any] = {
|
||||
"prompt": "",
|
||||
"negative_prompt": defaults.get("negative_prompt", ""),
|
||||
"model": model,
|
||||
"width": res_w,
|
||||
"height": res_h,
|
||||
"steps": defaults.get("steps"),
|
||||
"cfg": defaults.get("cfg"),
|
||||
"sampler": defaults.get("sampler"),
|
||||
"scheduler": defaults.get("scheduler"),
|
||||
"vae": defaults.get("vae"),
|
||||
"orientation": orientation,
|
||||
"seed": -1,
|
||||
"count": 1,
|
||||
}
|
||||
|
||||
# Add quality prefix if the family has one
|
||||
quality_prefix = defaults.get("quality_prefix", "")
|
||||
if quality_prefix:
|
||||
tpl["quality_prefix"] = quality_prefix
|
||||
|
||||
# Add rating tag if specified
|
||||
if rating:
|
||||
from tensors.config import get_rating_tag # noqa: PLC0415
|
||||
|
||||
rating_tag = get_rating_tag(family, rating.lower())
|
||||
if rating_tag:
|
||||
tpl["rating"] = rating
|
||||
tpl["rating_tag"] = rating_tag
|
||||
|
||||
# Add LoRA info
|
||||
if lora:
|
||||
tpl["lora"] = lora
|
||||
tpl["lora_strength"] = lora_strength
|
||||
|
||||
# Look up trigger words
|
||||
try:
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
trigger_words = db.get_trigger_words_by_filename(lora)
|
||||
if trigger_words:
|
||||
tpl["lora_triggers"] = trigger_words
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Add metadata (not used by generate, but informational)
|
||||
tpl["_family"] = family or "unknown"
|
||||
if base_model_str:
|
||||
tpl["_base_model"] = base_model_str
|
||||
|
||||
json_str = json.dumps(tpl, indent=2)
|
||||
|
||||
if output:
|
||||
output.write_text(json_str + "\n")
|
||||
console.print(f"[green]Saved template:[/green] {output}")
|
||||
else:
|
||||
console.print(json_str)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -1231,22 +1481,49 @@ def db_cache(
|
||||
|
||||
@db_app.command("list")
|
||||
def db_list(
|
||||
model_type: Annotated[str | None, typer.Option("-t", "--type", help="Filter by model type (Checkpoint, LORA, VAE, etc.)")] = None,
|
||||
base: Annotated[str | None, typer.Option("-b", "--base", help="Filter by base model (Pony, Illustrious, SDXL 1.0, SD 1.5, etc.)")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""List local files with CivitAI info."""
|
||||
"""List local files with CivitAI info.
|
||||
|
||||
Examples:
|
||||
tsr db list # All local files
|
||||
tsr db list -t Checkpoint # Only checkpoints
|
||||
tsr db list -t LORA # Only LoRAs
|
||||
tsr db list -t Checkpoint -b Pony # Pony checkpoints only
|
||||
tsr db list -b "SDXL 1.0" # All SDXL 1.0 models
|
||||
"""
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
files = db.list_local_files()
|
||||
|
||||
# Apply filters (case-insensitive substring match)
|
||||
if model_type:
|
||||
mt_lower = model_type.lower()
|
||||
files = [f for f in files if (f.get("model_type") or "").lower() == mt_lower]
|
||||
if base:
|
||||
base_lower = base.lower()
|
||||
files = [f for f in files if base_lower in (f.get("base_model") or "").lower()]
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=files)
|
||||
return
|
||||
|
||||
if not files:
|
||||
console.print("[yellow]No files in database. Run 'tsr db scan' first.[/yellow]")
|
||||
console.print("[yellow]No files found. Try 'tsr db scan' or adjust filters.[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="Local Files", show_header=True, header_style="bold magenta")
|
||||
title = "Local Files"
|
||||
if model_type or base:
|
||||
parts = []
|
||||
if model_type:
|
||||
parts.append(model_type)
|
||||
if base:
|
||||
parts.append(base)
|
||||
title = f"Local Files ({', '.join(parts)})"
|
||||
|
||||
table = Table(title=title, show_header=True, header_style="bold magenta")
|
||||
table.add_column("Path", style="cyan", max_width=50)
|
||||
table.add_column("Model", style="green")
|
||||
table.add_column("Version", style="white")
|
||||
@@ -1257,9 +1534,9 @@ def db_list(
|
||||
path = Path(f["file_path"]).name
|
||||
model = f.get("model_name") or "[dim]unlinked[/dim]"
|
||||
version = f.get("version_name") or ""
|
||||
model_type = f.get("model_type") or ""
|
||||
base = f.get("base_model") or ""
|
||||
table.add_row(path, model, version, model_type, base)
|
||||
ft = f.get("model_type") or ""
|
||||
base_model = f.get("base_model") or ""
|
||||
table.add_row(path, model, version, ft, base_model)
|
||||
|
||||
console.print(table)
|
||||
|
||||
@@ -1653,183 +1930,56 @@ def comfy_history(
|
||||
console.print(table)
|
||||
|
||||
|
||||
@comfy_app.command("generate")
|
||||
def comfy_generate( # noqa: PLR0915
|
||||
@comfy_app.command("generate", deprecated=True)
|
||||
def comfy_generate(
|
||||
prompt: Annotated[str, typer.Argument(help="Positive prompt text")],
|
||||
url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None,
|
||||
negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "",
|
||||
model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None,
|
||||
width: Annotated[int | None, typer.Option("-W", "--width", help="Image width (auto from checkpoint)")] = None,
|
||||
height: Annotated[int | None, typer.Option("-H", "--height", help="Image height (auto from checkpoint)")] = None,
|
||||
steps: Annotated[int | None, typer.Option("--steps", help="Sampling steps (auto from checkpoint)")] = None,
|
||||
cfg: Annotated[float | None, typer.Option("--cfg", help="CFG scale (auto from checkpoint)")] = None,
|
||||
seed: Annotated[int, typer.Option("--seed", "-s", help="Random seed (-1 for random)")] = -1,
|
||||
sampler: Annotated[str | None, typer.Option("--sampler", help="Sampler name (auto from checkpoint)")] = None,
|
||||
scheduler: Annotated[str | None, typer.Option("--scheduler", help="Scheduler name (auto from checkpoint)")] = None,
|
||||
orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "square",
|
||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None,
|
||||
count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1,
|
||||
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
||||
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 1.0,
|
||||
no_quality: Annotated[bool, typer.Option("--no-quality", help="Disable auto quality tags")] = False,
|
||||
no_negative: Annotated[bool, typer.Option("--no-negative", help="Disable auto negative prompt")] = False,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "",
|
||||
width: Annotated[int | None, typer.Option("-W", "--width")] = None,
|
||||
height: Annotated[int | None, typer.Option("-H", "--height")] = None,
|
||||
steps: Annotated[int | None, typer.Option("--steps")] = None,
|
||||
cfg: Annotated[float | None, typer.Option("--cfg")] = None,
|
||||
seed: Annotated[int, typer.Option("--seed", "-s")] = -1,
|
||||
sampler: Annotated[str | None, typer.Option("--sampler")] = None,
|
||||
scheduler: Annotated[str | None, typer.Option("--scheduler")] = None,
|
||||
orientation: Annotated[str, typer.Option("-O", "--orientation")] = "square",
|
||||
output: Annotated[Path | None, typer.Option("-o", "--output")] = None,
|
||||
count: Annotated[int, typer.Option("-c", "--count")] = 1,
|
||||
lora: Annotated[str | None, typer.Option("-l", "--lora")] = None,
|
||||
lora_strength: Annotated[float, typer.Option("--lora-strength")] = 0.8,
|
||||
no_quality: Annotated[bool, typer.Option("--no-quality")] = False,
|
||||
no_negative: Annotated[bool, typer.Option("--no-negative")] = False,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j")] = False,
|
||||
) -> None:
|
||||
"""Generate an image with a simple text-to-image workflow.
|
||||
|
||||
Examples:
|
||||
tsr comfy generate "a cat sitting on a windowsill"
|
||||
tsr comfy generate "portrait photo" -n "blurry, bad quality" --steps 30
|
||||
tsr comfy generate "landscape" -m "flux1-dev-fp8.safetensors" -W 1024 -H 768
|
||||
tsr comfy generate "cyberpunk city" --count 4 -o batch.png
|
||||
tsr comfy generate "girl" --lora spumcostyle.safetensors --lora-strength 0.8
|
||||
tsr comfy generate "raw prompt" --no-quality --no-negative
|
||||
"""
|
||||
import random # noqa: PLC0415
|
||||
|
||||
from tensors.comfyui import generate_image, get_image # noqa: PLC0415
|
||||
|
||||
all_results: list[dict[str, Any]] = []
|
||||
all_saved: list[Path] = []
|
||||
|
||||
# Determine base seed for batch
|
||||
base_seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
|
||||
|
||||
# Detect model family and apply defaults
|
||||
family_defaults: dict[str, Any] = {}
|
||||
model_family: str | None = None
|
||||
if model:
|
||||
# Try to get base_model from database
|
||||
base_model_str: str | None = None
|
||||
try:
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
base_model_str = db.get_base_model_by_filename(model)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
model_family = detect_model_family(model, base_model_str)
|
||||
if model_family:
|
||||
family_defaults = MODEL_FAMILY_DEFAULTS.get(model_family, {})
|
||||
if not json_output:
|
||||
console.print(f"[dim]Detected model family: {model_family}[/dim]")
|
||||
|
||||
# Build enhanced prompt with quality prefix and LoRA trigger words
|
||||
enhanced_prompt = prompt
|
||||
prompt_parts: list[str] = []
|
||||
|
||||
# Add LoRA trigger words if using LoRA
|
||||
if lora:
|
||||
try:
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
trigger_words = db.get_trigger_words_by_filename(lora)
|
||||
if trigger_words:
|
||||
prompt_parts.extend(trigger_words)
|
||||
if not json_output:
|
||||
console.print(f"[dim]LoRA trigger words: {', '.join(trigger_words)}[/dim]")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Add quality prefix based on model family
|
||||
if not no_quality and family_defaults.get("quality_prefix"):
|
||||
prompt_parts.append(family_defaults["quality_prefix"])
|
||||
|
||||
# Add user prompt
|
||||
prompt_parts.append(prompt)
|
||||
enhanced_prompt = ", ".join(prompt_parts) if len(prompt_parts) > 1 else prompt
|
||||
|
||||
# Build enhanced negative prompt
|
||||
enhanced_negative = negative
|
||||
if not no_negative and family_defaults.get("negative_prompt"):
|
||||
family_negative = family_defaults["negative_prompt"]
|
||||
enhanced_negative = f"{negative}, {family_negative}" if negative else family_negative
|
||||
|
||||
if not json_output and (enhanced_prompt != prompt or enhanced_negative != negative):
|
||||
if enhanced_prompt != prompt:
|
||||
truncated = enhanced_prompt[:100] + "..." if len(enhanced_prompt) > 100 else enhanced_prompt # noqa: PLR2004
|
||||
console.print(f"[dim]Enhanced prompt: {truncated}[/dim]")
|
||||
if enhanced_negative != negative:
|
||||
truncated = enhanced_negative[:80] + "..." if len(enhanced_negative) > 80 else enhanced_negative # noqa: PLR2004
|
||||
console.print(f"[dim]Enhanced negative: {truncated}[/dim]")
|
||||
|
||||
# Use native ComfyUI batching - single workflow generates all images
|
||||
result = generate_image(
|
||||
prompt=enhanced_prompt,
|
||||
url=url,
|
||||
negative_prompt=enhanced_negative,
|
||||
"""[Deprecated] Use 'tsr generate' instead. All features have been merged into the top-level command."""
|
||||
console.print("[yellow]Warning: 'tsr comfy generate' is deprecated. Use 'tsr generate' instead.[/yellow]")
|
||||
# Delegate to the unified generate command via context invocation
|
||||
ctx = typer.Context(generate)
|
||||
generate(
|
||||
ctx=ctx,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
width=width,
|
||||
height=height,
|
||||
steps=steps,
|
||||
cfg=cfg,
|
||||
seed=base_seed,
|
||||
seed=seed,
|
||||
sampler=sampler,
|
||||
scheduler=scheduler,
|
||||
console=console if not json_output else None,
|
||||
lora_name=lora,
|
||||
lora_strength=lora_strength,
|
||||
batch_size=count,
|
||||
vae=None,
|
||||
orientation=orientation,
|
||||
lora=lora,
|
||||
lora_strength=lora_strength,
|
||||
negative=negative,
|
||||
count=count,
|
||||
no_quality=no_quality,
|
||||
no_negative=no_negative,
|
||||
output=output,
|
||||
remote=None,
|
||||
json_output=json_output,
|
||||
json_input=None,
|
||||
)
|
||||
|
||||
if not result:
|
||||
if json_output:
|
||||
all_results.append({"success": False, "index": 0, "errors": {"generation": "Failed to generate"}})
|
||||
else:
|
||||
console.print("[red]Generation failed[/red]")
|
||||
elif not result.success:
|
||||
if json_output:
|
||||
all_results.append({"success": False, "index": 0, "errors": result.node_errors})
|
||||
else:
|
||||
console.print("[red]Generation failed[/red]")
|
||||
for node_id, errors in result.node_errors.items():
|
||||
console.print(f" [yellow]Node {node_id}:[/yellow] {errors}")
|
||||
else:
|
||||
# Save all output images
|
||||
for i, img_path in enumerate(result.images):
|
||||
saved_path: Path | None = None
|
||||
if output:
|
||||
img_data = get_image(str(img_path), url=url)
|
||||
if img_data:
|
||||
save_path = output if count == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
|
||||
save_path.write_bytes(img_data)
|
||||
saved_path = save_path
|
||||
all_saved.append(save_path)
|
||||
if not json_output:
|
||||
console.print(f"[green]Saved:[/green] {save_path}")
|
||||
elif not json_output:
|
||||
console.print(f"[yellow]Could not download image: {img_path}[/yellow]")
|
||||
|
||||
all_results.append(
|
||||
{
|
||||
"success": True,
|
||||
"index": i,
|
||||
"prompt_id": result.prompt_id,
|
||||
"image": str(img_path),
|
||||
"saved": str(saved_path) if saved_path else None,
|
||||
}
|
||||
)
|
||||
|
||||
if json_output:
|
||||
console.print_json(
|
||||
data={
|
||||
"success": all(r.get("success", False) for r in all_results),
|
||||
"count": len(all_results),
|
||||
"results": all_results,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
console.print("\n[bold green]Generation complete![/bold green]")
|
||||
if count > 1:
|
||||
successful = sum(1 for r in all_results if r.get("success", False))
|
||||
console.print(f"[dim]Generated {successful}/{count} images[/dim]")
|
||||
if all_saved:
|
||||
console.print(f"[dim]Saved to: {all_saved[0].parent}/[/dim]")
|
||||
elif all_results and all_results[0].get("prompt_id"):
|
||||
console.print(f"[dim]Prompt ID: {all_results[0]['prompt_id']}[/dim]")
|
||||
|
||||
|
||||
@comfy_app.command("run")
|
||||
def comfy_run(
|
||||
|
||||
@@ -506,6 +506,42 @@ COMFYUI_DEFAULT_SCHEDULER = "normal"
|
||||
# Model Family Defaults (Quality Tags, Negative Prompts, etc.)
|
||||
# ============================================================================
|
||||
|
||||
# Rating tags per model family — maps (family, rating) to the tag to inject
|
||||
# Families not listed here have no rating tag system (prompt-driven only)
|
||||
RATING_TAGS: dict[str, dict[str, str]] = {
|
||||
"pony": {
|
||||
"safe": "rating_safe",
|
||||
"questionable": "rating_questionable",
|
||||
"explicit": "rating_explicit",
|
||||
},
|
||||
"illustrious": {
|
||||
"safe": "rating:safe",
|
||||
"questionable": "rating:questionable",
|
||||
"explicit": "rating:explicit",
|
||||
},
|
||||
}
|
||||
# NoobAI uses same tags as Illustrious
|
||||
RATING_TAGS["noobai"] = RATING_TAGS["illustrious"]
|
||||
|
||||
|
||||
def get_rating_tag(family: str | None, rating: str) -> str | None:
|
||||
"""Get the rating tag for a model family and rating level.
|
||||
|
||||
Args:
|
||||
family: Model family key (e.g. "pony", "illustrious") or None
|
||||
rating: One of "safe", "questionable", "explicit"
|
||||
|
||||
Returns:
|
||||
Rating tag string to inject into prompt, or None if family has no rating system
|
||||
"""
|
||||
if not family:
|
||||
return None
|
||||
tags = RATING_TAGS.get(family)
|
||||
if not tags:
|
||||
return None
|
||||
return tags.get(rating)
|
||||
|
||||
|
||||
MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"pony": {
|
||||
"quality_prefix": "score_9, score_8_up, score_7_up",
|
||||
|
||||
Reference in New Issue
Block a user