fix null params when no model family detected in tsr generate
Co-Authored-By: marauder-os <marauder@saiden.dev>
This commit is contained in:
+346
-196
@@ -21,6 +21,12 @@ from tensors.api import (
|
|||||||
search_civitai,
|
search_civitai,
|
||||||
)
|
)
|
||||||
from tensors.config import (
|
from tensors.config import (
|
||||||
|
COMFYUI_DEFAULT_CFG,
|
||||||
|
COMFYUI_DEFAULT_HEIGHT,
|
||||||
|
COMFYUI_DEFAULT_SAMPLER,
|
||||||
|
COMFYUI_DEFAULT_SCHEDULER,
|
||||||
|
COMFYUI_DEFAULT_STEPS,
|
||||||
|
COMFYUI_DEFAULT_WIDTH,
|
||||||
CONFIG_FILE,
|
CONFIG_FILE,
|
||||||
MODEL_FAMILY_DEFAULTS,
|
MODEL_FAMILY_DEFAULTS,
|
||||||
BaseModel,
|
BaseModel,
|
||||||
@@ -787,6 +793,10 @@ def generate( # noqa: PLR0915
|
|||||||
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
||||||
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
|
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
|
||||||
negative: Annotated[str, typer.Option("-n", "--negative-prompt", help="Negative prompt")] = "",
|
negative: Annotated[str, typer.Option("-n", "--negative-prompt", help="Negative prompt")] = "",
|
||||||
|
count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1,
|
||||||
|
rating: Annotated[str | None, typer.Option("--rating", "-R", help="Content rating: safe, questionable, explicit (Pony/Illustrious)")] = None,
|
||||||
|
no_quality: Annotated[bool, typer.Option("--no-quality", help="Disable auto quality tags")] = False,
|
||||||
|
no_negative: Annotated[bool, typer.Option("--no-negative", help="Disable auto negative prompt")] = False,
|
||||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Save path (default: current dir)")] = None,
|
output: Annotated[Path | None, typer.Option("-o", "--output", help="Save path (default: current dir)")] = None,
|
||||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
@@ -794,15 +804,19 @@ def generate( # noqa: PLR0915
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Generate an image using text-to-image.
|
"""Generate an image using text-to-image.
|
||||||
|
|
||||||
|
Auto-detects optimal sampler, scheduler, CFG, resolution, and VAE from the checkpoint
|
||||||
|
model family. All auto-detected values can be overridden with explicit flags.
|
||||||
|
|
||||||
Calls ComfyUI directly when local, or the remote tensors API when --remote is given.
|
Calls ComfyUI directly when local, or the remote tensors API when --remote is given.
|
||||||
Accepts --input with a JSON object whose keys match CLI option names. CLI flags override JSON values.
|
Accepts --input with a JSON object whose keys match CLI option names. CLI flags override JSON values.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
tsr generate "a cat on a windowsill"
|
tsr generate "a cat on a windowsill"
|
||||||
tsr generate "portrait photo" -m "flux1-dev-fp8.safetensors" --steps 30
|
tsr generate "portrait photo" -m ponyDiffusionV6XL_v6.safetensors -O portrait
|
||||||
tsr generate "cyberpunk city" -o output.png
|
tsr generate "cyberpunk city" -o output.png --count 4
|
||||||
tsr generate "landscape" --remote junkpile
|
tsr generate "landscape" --remote junkpile
|
||||||
tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors", "steps": 30}'
|
tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors"}'
|
||||||
|
tsr generate "raw prompt" --no-quality --no-negative
|
||||||
"""
|
"""
|
||||||
import random as rng # noqa: PLC0415
|
import random as rng # noqa: PLC0415
|
||||||
|
|
||||||
@@ -877,16 +891,133 @@ def generate( # noqa: PLR0915
|
|||||||
output = Path(mapped["output"])
|
output = Path(mapped["output"])
|
||||||
if "remote" in mapped and "remote" not in explicit:
|
if "remote" in mapped and "remote" not in explicit:
|
||||||
remote = mapped["remote"]
|
remote = mapped["remote"]
|
||||||
|
if "count" in mapped and "count" not in explicit:
|
||||||
|
count = int(mapped["count"])
|
||||||
|
if "orientation" in mapped and "orientation" not in explicit:
|
||||||
|
orientation = mapped["orientation"]
|
||||||
|
if "no_quality" in mapped and "no_quality" not in explicit:
|
||||||
|
no_quality = bool(mapped["no_quality"])
|
||||||
|
if "no_negative" in mapped and "no_negative" not in explicit:
|
||||||
|
no_negative = bool(mapped["no_negative"])
|
||||||
|
if "rating" in mapped and "rating" not in explicit:
|
||||||
|
rating = mapped["rating"]
|
||||||
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
console.print("[red]Prompt is required (as argument or in --input JSON)[/red]")
|
console.print("[red]Prompt is required (as argument or in --input JSON)[/red]")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
||||||
from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415
|
# ---- Detect model family and enhance prompt/negative ----
|
||||||
|
family_defaults: dict[str, Any] = {}
|
||||||
|
model_family: str | None = None
|
||||||
|
base_model_str: str | None = None
|
||||||
|
if model:
|
||||||
|
try:
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
base_model_str = db.get_base_model_by_filename(model)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
model_family = detect_model_family(model, base_model_str)
|
||||||
|
if model_family:
|
||||||
|
family_defaults = MODEL_FAMILY_DEFAULTS.get(model_family, {})
|
||||||
|
if not json_output:
|
||||||
|
console.print(f"[dim]Detected model family: {model_family}[/dim]")
|
||||||
|
|
||||||
|
# Build enhanced prompt with quality prefix and LoRA trigger words
|
||||||
|
prompt_parts: list[str] = []
|
||||||
|
|
||||||
|
# Add LoRA trigger words if using LoRA
|
||||||
|
if lora:
|
||||||
|
try:
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
trigger_words = db.get_trigger_words_by_filename(lora)
|
||||||
|
if trigger_words:
|
||||||
|
prompt_parts.extend(trigger_words)
|
||||||
|
if not json_output:
|
||||||
|
console.print(f"[dim]LoRA trigger words: {', '.join(trigger_words)}[/dim]")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Add quality prefix based on model family
|
||||||
|
if not no_quality and family_defaults.get("quality_prefix"):
|
||||||
|
prompt_parts.append(family_defaults["quality_prefix"])
|
||||||
|
|
||||||
|
# Add rating tag based on model family (Pony/Illustrious)
|
||||||
|
if rating:
|
||||||
|
from tensors.config import get_rating_tag # noqa: PLC0415
|
||||||
|
|
||||||
|
rating_tag = get_rating_tag(model_family, rating.lower())
|
||||||
|
if rating_tag:
|
||||||
|
prompt_parts.append(rating_tag)
|
||||||
|
if not json_output:
|
||||||
|
console.print(f"[dim]Rating tag: {rating_tag}[/dim]")
|
||||||
|
elif not json_output:
|
||||||
|
console.print(f"[dim]Rating '{rating}' not applicable for {model_family or 'unknown'} family[/dim]")
|
||||||
|
|
||||||
|
# Add user prompt
|
||||||
|
prompt_parts.append(prompt)
|
||||||
|
enhanced_prompt = ", ".join(prompt_parts) if len(prompt_parts) > 1 else prompt
|
||||||
|
|
||||||
|
# Build enhanced negative prompt
|
||||||
|
enhanced_negative = negative
|
||||||
|
if not no_negative and family_defaults.get("negative_prompt"):
|
||||||
|
family_negative = family_defaults["negative_prompt"]
|
||||||
|
enhanced_negative = f"{negative}, {family_negative}" if negative else family_negative
|
||||||
|
|
||||||
|
if not json_output and (enhanced_prompt != prompt or enhanced_negative != negative):
|
||||||
|
if enhanced_prompt != prompt:
|
||||||
|
truncated = enhanced_prompt[:100] + "..." if len(enhanced_prompt) > 100 else enhanced_prompt # noqa: PLR2004
|
||||||
|
console.print(f"[dim]Enhanced prompt: {truncated}[/dim]")
|
||||||
|
if enhanced_negative != negative:
|
||||||
|
truncated = enhanced_negative[:80] + "..." if len(enhanced_negative) > 80 else enhanced_negative # noqa: PLR2004
|
||||||
|
console.print(f"[dim]Enhanced negative: {truncated}[/dim]")
|
||||||
|
|
||||||
|
# ---- Resolve preset defaults for None params (both remote and local need these) ----
|
||||||
|
from tensors.config import resolve_orientation, resolve_remote as do_resolve_remote # noqa: PLC0415
|
||||||
|
|
||||||
|
# Use already-detected family_defaults from DB lookup above (not filename guessing)
|
||||||
|
if family_defaults:
|
||||||
|
res_w, res_h = resolve_orientation(model_family, orientation)
|
||||||
|
if width is None:
|
||||||
|
width = res_w
|
||||||
|
if height is None:
|
||||||
|
height = res_h
|
||||||
|
if steps is None:
|
||||||
|
steps = family_defaults.get("steps", 20)
|
||||||
|
if cfg is None:
|
||||||
|
cfg = family_defaults.get("cfg", 7.0)
|
||||||
|
if sampler is None:
|
||||||
|
sampler = family_defaults.get("sampler", "euler")
|
||||||
|
if scheduler is None:
|
||||||
|
scheduler = family_defaults.get("scheduler", "normal")
|
||||||
|
if vae is None:
|
||||||
|
vae = family_defaults.get("vae")
|
||||||
|
|
||||||
|
# Fallback to global defaults when no model family was detected
|
||||||
|
if width is None:
|
||||||
|
width = COMFYUI_DEFAULT_WIDTH
|
||||||
|
if height is None:
|
||||||
|
height = COMFYUI_DEFAULT_HEIGHT
|
||||||
|
if steps is None:
|
||||||
|
steps = COMFYUI_DEFAULT_STEPS
|
||||||
|
if cfg is None:
|
||||||
|
cfg = COMFYUI_DEFAULT_CFG
|
||||||
|
if sampler is None:
|
||||||
|
sampler = COMFYUI_DEFAULT_SAMPLER
|
||||||
|
if scheduler is None:
|
||||||
|
scheduler = COMFYUI_DEFAULT_SCHEDULER
|
||||||
|
|
||||||
|
# ---- Determine base seed ----
|
||||||
|
base_seed = seed if seed >= 0 else rng.randint(0, 2**32 - 1)
|
||||||
|
|
||||||
# Resolve remote (explicit flag, or default from config)
|
# Resolve remote (explicit flag, or default from config)
|
||||||
remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None)
|
remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None)
|
||||||
|
|
||||||
|
all_results: list[dict[str, Any]] = []
|
||||||
|
all_saved: list[Path] = []
|
||||||
|
|
||||||
if remote_url:
|
if remote_url:
|
||||||
# ---- Remote mode: HTTP call to tensors server ----
|
# ---- Remote mode: HTTP call to tensors server ----
|
||||||
if not json_output:
|
if not json_output:
|
||||||
@@ -894,14 +1025,14 @@ def generate( # noqa: PLR0915
|
|||||||
|
|
||||||
result = remote_generate(
|
result = remote_generate(
|
||||||
remote or remote_url,
|
remote or remote_url,
|
||||||
prompt,
|
enhanced_prompt,
|
||||||
negative_prompt=negative,
|
negative_prompt=enhanced_negative,
|
||||||
model=model,
|
model=model,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
steps=steps,
|
steps=steps,
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
seed=seed,
|
seed=base_seed,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
vae=vae,
|
vae=vae,
|
||||||
@@ -948,69 +1079,188 @@ def generate( # noqa: PLR0915
|
|||||||
# ---- Local mode: direct library call ----
|
# ---- Local mode: direct library call ----
|
||||||
from tensors.comfyui import generate_image, get_image # noqa: PLC0415
|
from tensors.comfyui import generate_image, get_image # noqa: PLC0415
|
||||||
|
|
||||||
actual_seed = seed if seed >= 0 else rng.randint(0, 2**32 - 1)
|
|
||||||
|
|
||||||
result_local = generate_image(
|
result_local = generate_image(
|
||||||
prompt=prompt,
|
prompt=enhanced_prompt,
|
||||||
negative_prompt=negative,
|
negative_prompt=enhanced_negative,
|
||||||
model=model,
|
model=model,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
steps=steps,
|
steps=steps,
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
seed=actual_seed,
|
seed=base_seed,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
console=console if not json_output else None,
|
console=console if not json_output else None,
|
||||||
lora_name=lora,
|
lora_name=lora,
|
||||||
lora_strength=lora_strength,
|
lora_strength=lora_strength,
|
||||||
|
batch_size=count,
|
||||||
vae=vae,
|
vae=vae,
|
||||||
orientation=orientation,
|
orientation=orientation,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result_local:
|
if not result_local:
|
||||||
if json_output:
|
if json_output:
|
||||||
console.print_json(data={"success": False, "errors": {"generation": "Failed to generate"}})
|
all_results.append({"success": False, "index": 0, "errors": {"generation": "Failed to generate"}})
|
||||||
else:
|
else:
|
||||||
console.print("[red]Generation failed[/red]")
|
console.print("[red]Generation failed[/red]")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
elif not result_local.success:
|
||||||
if not result_local.success:
|
|
||||||
if json_output:
|
if json_output:
|
||||||
console.print_json(data={"success": False, "errors": result_local.node_errors})
|
all_results.append({"success": False, "index": 0, "errors": result_local.node_errors})
|
||||||
else:
|
else:
|
||||||
console.print("[red]Generation failed[/red]")
|
console.print("[red]Generation failed[/red]")
|
||||||
for node_id, errors in result_local.node_errors.items():
|
for node_id, errors in result_local.node_errors.items():
|
||||||
console.print(f" [yellow]Node {node_id}:[/yellow] {errors}")
|
console.print(f" [yellow]Node {node_id}:[/yellow] {errors}")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
else:
|
||||||
# Save images
|
# Save all output images
|
||||||
saved_paths: list[Path] = []
|
|
||||||
for i, img_path in enumerate(result_local.images):
|
for i, img_path in enumerate(result_local.images):
|
||||||
|
saved_path: Path | None = None
|
||||||
if output:
|
if output:
|
||||||
img_data = get_image(str(img_path))
|
img_data = get_image(str(img_path))
|
||||||
if img_data:
|
if img_data:
|
||||||
save_path = (
|
save_path = (
|
||||||
output if len(result_local.images) == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
|
output
|
||||||
|
if count == 1
|
||||||
|
else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
|
||||||
)
|
)
|
||||||
save_path.write_bytes(img_data)
|
save_path.write_bytes(img_data)
|
||||||
saved_paths.append(save_path)
|
saved_path = save_path
|
||||||
|
all_saved.append(save_path)
|
||||||
if not json_output:
|
if not json_output:
|
||||||
console.print(f"[green]Saved:[/green] {save_path}")
|
console.print(f"[green]Saved:[/green] {save_path}")
|
||||||
|
elif not json_output:
|
||||||
|
console.print(f"[yellow]Could not download image: {img_path}[/yellow]")
|
||||||
|
|
||||||
|
all_results.append(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"index": i,
|
||||||
|
"prompt_id": result_local.prompt_id,
|
||||||
|
"image": str(img_path),
|
||||||
|
"saved": str(saved_path) if saved_path else None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if json_output:
|
if json_output:
|
||||||
console.print_json(
|
console.print_json(
|
||||||
data={
|
data={
|
||||||
"success": True,
|
"success": all(r.get("success", False) for r in all_results),
|
||||||
"prompt_id": result_local.prompt_id,
|
"count": len(all_results),
|
||||||
"images": [str(p) for p in result_local.images],
|
"results": all_results,
|
||||||
"saved": [str(p) for p in saved_paths],
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
console.print("[bold green]Generation complete![/bold green]")
|
console.print("[bold green]Generation complete![/bold green]")
|
||||||
console.print(f"[dim]Prompt ID: {result_local.prompt_id}[/dim]")
|
if count > 1:
|
||||||
|
successful = sum(1 for r in all_results if r.get("success", False))
|
||||||
|
console.print(f"[dim]Generated {successful}/{count} images[/dim]")
|
||||||
|
if all_saved:
|
||||||
|
console.print(f"[dim]Saved to: {all_saved[0].parent}/[/dim]")
|
||||||
|
elif all_results and all_results[0].get("prompt_id"):
|
||||||
|
console.print(f"[dim]Prompt ID: {all_results[0]['prompt_id']}[/dim]")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Template Dump
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def template(
|
||||||
|
model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")],
|
||||||
|
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
||||||
|
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
|
||||||
|
orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "square",
|
||||||
|
rating: Annotated[str | None, typer.Option("--rating", "-R", help="Content rating: safe, questionable, explicit")] = None,
|
||||||
|
output: Annotated[Path | None, typer.Option("-o", "--output", help="Save template to file")] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Dump a JSON generation template with resolved defaults for a model.
|
||||||
|
|
||||||
|
Outputs a ready-to-use JSON object with all parameters auto-resolved from the
|
||||||
|
checkpoint family. Pipe to 'tsr generate --input' or save to a file for reuse.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
tsr template -m ponyDiffusionV6XL_v6StartWithThisOne.safetensors
|
||||||
|
tsr template -m beautifulRealistic_v7.safetensors -O portrait
|
||||||
|
tsr template -m waiIllustriousSDXL_v160.safetensors -l "Elvira iIlluLoRA.safetensors"
|
||||||
|
tsr template -m ponyRealism_V22.safetensors -o pony_preset.json
|
||||||
|
tsr generate --input "$(tsr template -m ponyRealism_V22.safetensors)" "a portrait"
|
||||||
|
"""
|
||||||
|
from tensors.config import get_model_generation_defaults, resolve_orientation # noqa: PLC0415
|
||||||
|
|
||||||
|
# Look up base_model from DB for accurate family detection
|
||||||
|
base_model_str: str | None = None
|
||||||
|
try:
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
base_model_str = db.get_base_model_by_filename(model)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
family = detect_model_family(model, base_model_str)
|
||||||
|
defaults = get_model_generation_defaults(model, base_model_str)
|
||||||
|
res_w, res_h = resolve_orientation(family, orientation)
|
||||||
|
|
||||||
|
# Build template
|
||||||
|
tpl: dict[str, Any] = {
|
||||||
|
"prompt": "",
|
||||||
|
"negative_prompt": defaults.get("negative_prompt", ""),
|
||||||
|
"model": model,
|
||||||
|
"width": res_w,
|
||||||
|
"height": res_h,
|
||||||
|
"steps": defaults.get("steps"),
|
||||||
|
"cfg": defaults.get("cfg"),
|
||||||
|
"sampler": defaults.get("sampler"),
|
||||||
|
"scheduler": defaults.get("scheduler"),
|
||||||
|
"vae": defaults.get("vae"),
|
||||||
|
"orientation": orientation,
|
||||||
|
"seed": -1,
|
||||||
|
"count": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add quality prefix if the family has one
|
||||||
|
quality_prefix = defaults.get("quality_prefix", "")
|
||||||
|
if quality_prefix:
|
||||||
|
tpl["quality_prefix"] = quality_prefix
|
||||||
|
|
||||||
|
# Add rating tag if specified
|
||||||
|
if rating:
|
||||||
|
from tensors.config import get_rating_tag # noqa: PLC0415
|
||||||
|
|
||||||
|
rating_tag = get_rating_tag(family, rating.lower())
|
||||||
|
if rating_tag:
|
||||||
|
tpl["rating"] = rating
|
||||||
|
tpl["rating_tag"] = rating_tag
|
||||||
|
|
||||||
|
# Add LoRA info
|
||||||
|
if lora:
|
||||||
|
tpl["lora"] = lora
|
||||||
|
tpl["lora_strength"] = lora_strength
|
||||||
|
|
||||||
|
# Look up trigger words
|
||||||
|
try:
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
trigger_words = db.get_trigger_words_by_filename(lora)
|
||||||
|
if trigger_words:
|
||||||
|
tpl["lora_triggers"] = trigger_words
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Add metadata (not used by generate, but informational)
|
||||||
|
tpl["_family"] = family or "unknown"
|
||||||
|
if base_model_str:
|
||||||
|
tpl["_base_model"] = base_model_str
|
||||||
|
|
||||||
|
json_str = json.dumps(tpl, indent=2)
|
||||||
|
|
||||||
|
if output:
|
||||||
|
output.write_text(json_str + "\n")
|
||||||
|
console.print(f"[green]Saved template:[/green] {output}")
|
||||||
|
else:
|
||||||
|
console.print(json_str)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -1231,22 +1481,49 @@ def db_cache(
|
|||||||
|
|
||||||
@db_app.command("list")
|
@db_app.command("list")
|
||||||
def db_list(
|
def db_list(
|
||||||
|
model_type: Annotated[str | None, typer.Option("-t", "--type", help="Filter by model type (Checkpoint, LORA, VAE, etc.)")] = None,
|
||||||
|
base: Annotated[str | None, typer.Option("-b", "--base", help="Filter by base model (Pony, Illustrious, SDXL 1.0, SD 1.5, etc.)")] = None,
|
||||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""List local files with CivitAI info."""
|
"""List local files with CivitAI info.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
tsr db list # All local files
|
||||||
|
tsr db list -t Checkpoint # Only checkpoints
|
||||||
|
tsr db list -t LORA # Only LoRAs
|
||||||
|
tsr db list -t Checkpoint -b Pony # Pony checkpoints only
|
||||||
|
tsr db list -b "SDXL 1.0" # All SDXL 1.0 models
|
||||||
|
"""
|
||||||
with Database() as db:
|
with Database() as db:
|
||||||
db.init_schema()
|
db.init_schema()
|
||||||
files = db.list_local_files()
|
files = db.list_local_files()
|
||||||
|
|
||||||
|
# Apply filters (case-insensitive substring match)
|
||||||
|
if model_type:
|
||||||
|
mt_lower = model_type.lower()
|
||||||
|
files = [f for f in files if (f.get("model_type") or "").lower() == mt_lower]
|
||||||
|
if base:
|
||||||
|
base_lower = base.lower()
|
||||||
|
files = [f for f in files if base_lower in (f.get("base_model") or "").lower()]
|
||||||
|
|
||||||
if json_output:
|
if json_output:
|
||||||
console.print_json(data=files)
|
console.print_json(data=files)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not files:
|
if not files:
|
||||||
console.print("[yellow]No files in database. Run 'tsr db scan' first.[/yellow]")
|
console.print("[yellow]No files found. Try 'tsr db scan' or adjust filters.[/yellow]")
|
||||||
return
|
return
|
||||||
|
|
||||||
table = Table(title="Local Files", show_header=True, header_style="bold magenta")
|
title = "Local Files"
|
||||||
|
if model_type or base:
|
||||||
|
parts = []
|
||||||
|
if model_type:
|
||||||
|
parts.append(model_type)
|
||||||
|
if base:
|
||||||
|
parts.append(base)
|
||||||
|
title = f"Local Files ({', '.join(parts)})"
|
||||||
|
|
||||||
|
table = Table(title=title, show_header=True, header_style="bold magenta")
|
||||||
table.add_column("Path", style="cyan", max_width=50)
|
table.add_column("Path", style="cyan", max_width=50)
|
||||||
table.add_column("Model", style="green")
|
table.add_column("Model", style="green")
|
||||||
table.add_column("Version", style="white")
|
table.add_column("Version", style="white")
|
||||||
@@ -1257,9 +1534,9 @@ def db_list(
|
|||||||
path = Path(f["file_path"]).name
|
path = Path(f["file_path"]).name
|
||||||
model = f.get("model_name") or "[dim]unlinked[/dim]"
|
model = f.get("model_name") or "[dim]unlinked[/dim]"
|
||||||
version = f.get("version_name") or ""
|
version = f.get("version_name") or ""
|
||||||
model_type = f.get("model_type") or ""
|
ft = f.get("model_type") or ""
|
||||||
base = f.get("base_model") or ""
|
base_model = f.get("base_model") or ""
|
||||||
table.add_row(path, model, version, model_type, base)
|
table.add_row(path, model, version, ft, base_model)
|
||||||
|
|
||||||
console.print(table)
|
console.print(table)
|
||||||
|
|
||||||
@@ -1653,183 +1930,56 @@ def comfy_history(
|
|||||||
console.print(table)
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
@comfy_app.command("generate")
|
@comfy_app.command("generate", deprecated=True)
|
||||||
def comfy_generate( # noqa: PLR0915
|
def comfy_generate(
|
||||||
prompt: Annotated[str, typer.Argument(help="Positive prompt text")],
|
prompt: Annotated[str, typer.Argument(help="Positive prompt text")],
|
||||||
url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None,
|
|
||||||
negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "",
|
|
||||||
model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None,
|
model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None,
|
||||||
width: Annotated[int | None, typer.Option("-W", "--width", help="Image width (auto from checkpoint)")] = None,
|
negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "",
|
||||||
height: Annotated[int | None, typer.Option("-H", "--height", help="Image height (auto from checkpoint)")] = None,
|
width: Annotated[int | None, typer.Option("-W", "--width")] = None,
|
||||||
steps: Annotated[int | None, typer.Option("--steps", help="Sampling steps (auto from checkpoint)")] = None,
|
height: Annotated[int | None, typer.Option("-H", "--height")] = None,
|
||||||
cfg: Annotated[float | None, typer.Option("--cfg", help="CFG scale (auto from checkpoint)")] = None,
|
steps: Annotated[int | None, typer.Option("--steps")] = None,
|
||||||
seed: Annotated[int, typer.Option("--seed", "-s", help="Random seed (-1 for random)")] = -1,
|
cfg: Annotated[float | None, typer.Option("--cfg")] = None,
|
||||||
sampler: Annotated[str | None, typer.Option("--sampler", help="Sampler name (auto from checkpoint)")] = None,
|
seed: Annotated[int, typer.Option("--seed", "-s")] = -1,
|
||||||
scheduler: Annotated[str | None, typer.Option("--scheduler", help="Scheduler name (auto from checkpoint)")] = None,
|
sampler: Annotated[str | None, typer.Option("--sampler")] = None,
|
||||||
orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "square",
|
scheduler: Annotated[str | None, typer.Option("--scheduler")] = None,
|
||||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None,
|
orientation: Annotated[str, typer.Option("-O", "--orientation")] = "square",
|
||||||
count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1,
|
output: Annotated[Path | None, typer.Option("-o", "--output")] = None,
|
||||||
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
count: Annotated[int, typer.Option("-c", "--count")] = 1,
|
||||||
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 1.0,
|
lora: Annotated[str | None, typer.Option("-l", "--lora")] = None,
|
||||||
no_quality: Annotated[bool, typer.Option("--no-quality", help="Disable auto quality tags")] = False,
|
lora_strength: Annotated[float, typer.Option("--lora-strength")] = 0.8,
|
||||||
no_negative: Annotated[bool, typer.Option("--no-negative", help="Disable auto negative prompt")] = False,
|
no_quality: Annotated[bool, typer.Option("--no-quality")] = False,
|
||||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
no_negative: Annotated[bool, typer.Option("--no-negative")] = False,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j")] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Generate an image with a simple text-to-image workflow.
|
"""[Deprecated] Use 'tsr generate' instead. All features have been merged into the top-level command."""
|
||||||
|
console.print("[yellow]Warning: 'tsr comfy generate' is deprecated. Use 'tsr generate' instead.[/yellow]")
|
||||||
Examples:
|
# Delegate to the unified generate command via context invocation
|
||||||
tsr comfy generate "a cat sitting on a windowsill"
|
ctx = typer.Context(generate)
|
||||||
tsr comfy generate "portrait photo" -n "blurry, bad quality" --steps 30
|
generate(
|
||||||
tsr comfy generate "landscape" -m "flux1-dev-fp8.safetensors" -W 1024 -H 768
|
ctx=ctx,
|
||||||
tsr comfy generate "cyberpunk city" --count 4 -o batch.png
|
prompt=prompt,
|
||||||
tsr comfy generate "girl" --lora spumcostyle.safetensors --lora-strength 0.8
|
|
||||||
tsr comfy generate "raw prompt" --no-quality --no-negative
|
|
||||||
"""
|
|
||||||
import random # noqa: PLC0415
|
|
||||||
|
|
||||||
from tensors.comfyui import generate_image, get_image # noqa: PLC0415
|
|
||||||
|
|
||||||
all_results: list[dict[str, Any]] = []
|
|
||||||
all_saved: list[Path] = []
|
|
||||||
|
|
||||||
# Determine base seed for batch
|
|
||||||
base_seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
|
|
||||||
|
|
||||||
# Detect model family and apply defaults
|
|
||||||
family_defaults: dict[str, Any] = {}
|
|
||||||
model_family: str | None = None
|
|
||||||
if model:
|
|
||||||
# Try to get base_model from database
|
|
||||||
base_model_str: str | None = None
|
|
||||||
try:
|
|
||||||
with Database() as db:
|
|
||||||
db.init_schema()
|
|
||||||
base_model_str = db.get_base_model_by_filename(model)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
model_family = detect_model_family(model, base_model_str)
|
|
||||||
if model_family:
|
|
||||||
family_defaults = MODEL_FAMILY_DEFAULTS.get(model_family, {})
|
|
||||||
if not json_output:
|
|
||||||
console.print(f"[dim]Detected model family: {model_family}[/dim]")
|
|
||||||
|
|
||||||
# Build enhanced prompt with quality prefix and LoRA trigger words
|
|
||||||
enhanced_prompt = prompt
|
|
||||||
prompt_parts: list[str] = []
|
|
||||||
|
|
||||||
# Add LoRA trigger words if using LoRA
|
|
||||||
if lora:
|
|
||||||
try:
|
|
||||||
with Database() as db:
|
|
||||||
db.init_schema()
|
|
||||||
trigger_words = db.get_trigger_words_by_filename(lora)
|
|
||||||
if trigger_words:
|
|
||||||
prompt_parts.extend(trigger_words)
|
|
||||||
if not json_output:
|
|
||||||
console.print(f"[dim]LoRA trigger words: {', '.join(trigger_words)}[/dim]")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Add quality prefix based on model family
|
|
||||||
if not no_quality and family_defaults.get("quality_prefix"):
|
|
||||||
prompt_parts.append(family_defaults["quality_prefix"])
|
|
||||||
|
|
||||||
# Add user prompt
|
|
||||||
prompt_parts.append(prompt)
|
|
||||||
enhanced_prompt = ", ".join(prompt_parts) if len(prompt_parts) > 1 else prompt
|
|
||||||
|
|
||||||
# Build enhanced negative prompt
|
|
||||||
enhanced_negative = negative
|
|
||||||
if not no_negative and family_defaults.get("negative_prompt"):
|
|
||||||
family_negative = family_defaults["negative_prompt"]
|
|
||||||
enhanced_negative = f"{negative}, {family_negative}" if negative else family_negative
|
|
||||||
|
|
||||||
if not json_output and (enhanced_prompt != prompt or enhanced_negative != negative):
|
|
||||||
if enhanced_prompt != prompt:
|
|
||||||
truncated = enhanced_prompt[:100] + "..." if len(enhanced_prompt) > 100 else enhanced_prompt # noqa: PLR2004
|
|
||||||
console.print(f"[dim]Enhanced prompt: {truncated}[/dim]")
|
|
||||||
if enhanced_negative != negative:
|
|
||||||
truncated = enhanced_negative[:80] + "..." if len(enhanced_negative) > 80 else enhanced_negative # noqa: PLR2004
|
|
||||||
console.print(f"[dim]Enhanced negative: {truncated}[/dim]")
|
|
||||||
|
|
||||||
# Use native ComfyUI batching - single workflow generates all images
|
|
||||||
result = generate_image(
|
|
||||||
prompt=enhanced_prompt,
|
|
||||||
url=url,
|
|
||||||
negative_prompt=enhanced_negative,
|
|
||||||
model=model,
|
model=model,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
steps=steps,
|
steps=steps,
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
seed=base_seed,
|
seed=seed,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
console=console if not json_output else None,
|
vae=None,
|
||||||
lora_name=lora,
|
|
||||||
lora_strength=lora_strength,
|
|
||||||
batch_size=count,
|
|
||||||
orientation=orientation,
|
orientation=orientation,
|
||||||
|
lora=lora,
|
||||||
|
lora_strength=lora_strength,
|
||||||
|
negative=negative,
|
||||||
|
count=count,
|
||||||
|
no_quality=no_quality,
|
||||||
|
no_negative=no_negative,
|
||||||
|
output=output,
|
||||||
|
remote=None,
|
||||||
|
json_output=json_output,
|
||||||
|
json_input=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result:
|
|
||||||
if json_output:
|
|
||||||
all_results.append({"success": False, "index": 0, "errors": {"generation": "Failed to generate"}})
|
|
||||||
else:
|
|
||||||
console.print("[red]Generation failed[/red]")
|
|
||||||
elif not result.success:
|
|
||||||
if json_output:
|
|
||||||
all_results.append({"success": False, "index": 0, "errors": result.node_errors})
|
|
||||||
else:
|
|
||||||
console.print("[red]Generation failed[/red]")
|
|
||||||
for node_id, errors in result.node_errors.items():
|
|
||||||
console.print(f" [yellow]Node {node_id}:[/yellow] {errors}")
|
|
||||||
else:
|
|
||||||
# Save all output images
|
|
||||||
for i, img_path in enumerate(result.images):
|
|
||||||
saved_path: Path | None = None
|
|
||||||
if output:
|
|
||||||
img_data = get_image(str(img_path), url=url)
|
|
||||||
if img_data:
|
|
||||||
save_path = output if count == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
|
|
||||||
save_path.write_bytes(img_data)
|
|
||||||
saved_path = save_path
|
|
||||||
all_saved.append(save_path)
|
|
||||||
if not json_output:
|
|
||||||
console.print(f"[green]Saved:[/green] {save_path}")
|
|
||||||
elif not json_output:
|
|
||||||
console.print(f"[yellow]Could not download image: {img_path}[/yellow]")
|
|
||||||
|
|
||||||
all_results.append(
|
|
||||||
{
|
|
||||||
"success": True,
|
|
||||||
"index": i,
|
|
||||||
"prompt_id": result.prompt_id,
|
|
||||||
"image": str(img_path),
|
|
||||||
"saved": str(saved_path) if saved_path else None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if json_output:
|
|
||||||
console.print_json(
|
|
||||||
data={
|
|
||||||
"success": all(r.get("success", False) for r in all_results),
|
|
||||||
"count": len(all_results),
|
|
||||||
"results": all_results,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
console.print("\n[bold green]Generation complete![/bold green]")
|
|
||||||
if count > 1:
|
|
||||||
successful = sum(1 for r in all_results if r.get("success", False))
|
|
||||||
console.print(f"[dim]Generated {successful}/{count} images[/dim]")
|
|
||||||
if all_saved:
|
|
||||||
console.print(f"[dim]Saved to: {all_saved[0].parent}/[/dim]")
|
|
||||||
elif all_results and all_results[0].get("prompt_id"):
|
|
||||||
console.print(f"[dim]Prompt ID: {all_results[0]['prompt_id']}[/dim]")
|
|
||||||
|
|
||||||
|
|
||||||
@comfy_app.command("run")
|
@comfy_app.command("run")
|
||||||
def comfy_run(
|
def comfy_run(
|
||||||
|
|||||||
@@ -506,6 +506,42 @@ COMFYUI_DEFAULT_SCHEDULER = "normal"
|
|||||||
# Model Family Defaults (Quality Tags, Negative Prompts, etc.)
|
# Model Family Defaults (Quality Tags, Negative Prompts, etc.)
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
# Rating tags per model family — maps (family, rating) to the tag to inject
|
||||||
|
# Families not listed here have no rating tag system (prompt-driven only)
|
||||||
|
RATING_TAGS: dict[str, dict[str, str]] = {
|
||||||
|
"pony": {
|
||||||
|
"safe": "rating_safe",
|
||||||
|
"questionable": "rating_questionable",
|
||||||
|
"explicit": "rating_explicit",
|
||||||
|
},
|
||||||
|
"illustrious": {
|
||||||
|
"safe": "rating:safe",
|
||||||
|
"questionable": "rating:questionable",
|
||||||
|
"explicit": "rating:explicit",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
# NoobAI uses same tags as Illustrious
|
||||||
|
RATING_TAGS["noobai"] = RATING_TAGS["illustrious"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_rating_tag(family: str | None, rating: str) -> str | None:
|
||||||
|
"""Get the rating tag for a model family and rating level.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
family: Model family key (e.g. "pony", "illustrious") or None
|
||||||
|
rating: One of "safe", "questionable", "explicit"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rating tag string to inject into prompt, or None if family has no rating system
|
||||||
|
"""
|
||||||
|
if not family:
|
||||||
|
return None
|
||||||
|
tags = RATING_TAGS.get(family)
|
||||||
|
if not tags:
|
||||||
|
return None
|
||||||
|
return tags.get(rating)
|
||||||
|
|
||||||
|
|
||||||
MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||||
"pony": {
|
"pony": {
|
||||||
"quality_prefix": "score_9, score_8_up, score_7_up",
|
"quality_prefix": "score_9, score_8_up, score_7_up",
|
||||||
|
|||||||
Reference in New Issue
Block a user