Add LoRA support and model family quality defaults to comfy generate

This commit is contained in:
Adam Ladachowski
2026-02-19 04:30:32 +01:00
parent 91c5c1e0a7
commit 1ed40a3142
5 changed files with 248 additions and 8 deletions
BIN
View File
Binary file not shown.
+72 -8
View File
@@ -21,6 +21,7 @@ from tensors.api import (
) )
from tensors.config import ( from tensors.config import (
CONFIG_FILE, CONFIG_FILE,
MODEL_FAMILY_DEFAULTS,
BaseModel, BaseModel,
CommercialUse, CommercialUse,
ModelType, ModelType,
@@ -28,6 +29,7 @@ from tensors.config import (
Period, Period,
Provider, Provider,
SortOrder, SortOrder,
detect_model_family,
get_default_output_path, get_default_output_path,
get_model_paths, get_model_paths,
load_api_key, load_api_key,
@@ -1118,7 +1120,7 @@ def comfy_history(
@comfy_app.command("generate") @comfy_app.command("generate")
def comfy_generate( def comfy_generate( # noqa: PLR0915
prompt: Annotated[str, typer.Argument(help="Positive prompt text")], prompt: Annotated[str, typer.Argument(help="Positive prompt text")],
url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None, url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None,
negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "", negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "",
@@ -1132,6 +1134,10 @@ def comfy_generate(
scheduler: Annotated[str, typer.Option("--scheduler", help="Scheduler name")] = "normal", scheduler: Annotated[str, typer.Option("--scheduler", help="Scheduler name")] = "normal",
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None, output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None,
count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1, count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1,
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 1.0,
no_quality: Annotated[bool, typer.Option("--no-quality", help="Disable auto quality tags")] = False,
no_negative: Annotated[bool, typer.Option("--no-negative", help="Disable auto negative prompt")] = False,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None: ) -> None:
"""Generate an image with a simple text-to-image workflow. """Generate an image with a simple text-to-image workflow.
@@ -1141,6 +1147,8 @@ def comfy_generate(
tsr comfy generate "portrait photo" -n "blurry, bad quality" --steps 30 tsr comfy generate "portrait photo" -n "blurry, bad quality" --steps 30
tsr comfy generate "landscape" -m "flux1-dev-fp8.safetensors" -W 1024 -H 768 tsr comfy generate "landscape" -m "flux1-dev-fp8.safetensors" -W 1024 -H 768
tsr comfy generate "cyberpunk city" --count 4 -o batch.png tsr comfy generate "cyberpunk city" --count 4 -o batch.png
tsr comfy generate "girl" --lora spumcostyle.safetensors --lora-strength 0.8
tsr comfy generate "raw prompt" --no-quality --no-negative
""" """
import random # noqa: PLC0415 import random # noqa: PLC0415
@@ -1152,6 +1160,64 @@ def comfy_generate(
# Determine base seed for batch # Determine base seed for batch
base_seed = seed if seed >= 0 else random.randint(0, 2**32 - 1) base_seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
# Detect model family and apply defaults
family_defaults: dict[str, Any] = {}
model_family: str | None = None
if model:
# Try to get base_model from database
base_model_str: str | None = None
try:
with Database() as db:
db.init_schema()
base_model_str = db.get_base_model_by_filename(model)
except Exception:
pass
model_family = detect_model_family(model, base_model_str)
if model_family:
family_defaults = MODEL_FAMILY_DEFAULTS.get(model_family, {})
if not json_output:
console.print(f"[dim]Detected model family: {model_family}[/dim]")
# Build enhanced prompt with quality prefix and LoRA trigger words
enhanced_prompt = prompt
prompt_parts: list[str] = []
# Add LoRA trigger words if using LoRA
if lora:
try:
with Database() as db:
db.init_schema()
trigger_words = db.get_trigger_words_by_filename(lora)
if trigger_words:
prompt_parts.extend(trigger_words)
if not json_output:
console.print(f"[dim]LoRA trigger words: {', '.join(trigger_words)}[/dim]")
except Exception:
pass
# Add quality prefix based on model family
if not no_quality and family_defaults.get("quality_prefix"):
prompt_parts.append(family_defaults["quality_prefix"])
# Add user prompt
prompt_parts.append(prompt)
enhanced_prompt = ", ".join(prompt_parts) if len(prompt_parts) > 1 else prompt
# Build enhanced negative prompt
enhanced_negative = negative
if not no_negative and family_defaults.get("negative_prompt"):
family_negative = family_defaults["negative_prompt"]
enhanced_negative = f"{negative}, {family_negative}" if negative else family_negative
if not json_output and (enhanced_prompt != prompt or enhanced_negative != negative):
if enhanced_prompt != prompt:
truncated = enhanced_prompt[:100] + "..." if len(enhanced_prompt) > 100 else enhanced_prompt # noqa: PLR2004
console.print(f"[dim]Enhanced prompt: {truncated}[/dim]")
if enhanced_negative != negative:
truncated = enhanced_negative[:80] + "..." if len(enhanced_negative) > 80 else enhanced_negative # noqa: PLR2004
console.print(f"[dim]Enhanced negative: {truncated}[/dim]")
for i in range(count): for i in range(count):
current_seed = base_seed + i if seed >= 0 else -1 # Increment seed or use random each time current_seed = base_seed + i if seed >= 0 else -1 # Increment seed or use random each time
@@ -1159,9 +1225,9 @@ def comfy_generate(
console.print(f"\n[cyan]Generating image {i + 1}/{count}...[/cyan]") console.print(f"\n[cyan]Generating image {i + 1}/{count}...[/cyan]")
result = generate_image( result = generate_image(
prompt=prompt, prompt=enhanced_prompt,
url=url, url=url,
negative_prompt=negative, negative_prompt=enhanced_negative,
model=model, model=model,
width=width, width=width,
height=height, height=height,
@@ -1171,6 +1237,8 @@ def comfy_generate(
sampler=sampler, sampler=sampler,
scheduler=scheduler, scheduler=scheduler,
console=console if not json_output else None, console=console if not json_output else None,
lora_name=lora,
lora_strength=lora_strength,
) )
if not result: if not result:
@@ -1195,11 +1263,7 @@ def comfy_generate(
img_path = result.images[0] img_path = result.images[0]
img_data = get_image(str(img_path), url=url) img_data = get_image(str(img_path), url=url)
if img_data: if img_data:
if count == 1: save_path = output if count == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
save_path = output
else:
# Add index suffix for batch: output.png -> output_001.png
save_path = output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
save_path.write_bytes(img_data) save_path.write_bytes(img_data)
saved_path = save_path saved_path = save_path
all_saved.append(save_path) all_saved.append(save_path)
+41
View File
@@ -526,6 +526,18 @@ def run_workflow(
# Simple Text-to-Image Generation # Simple Text-to-Image Generation
# ============================================================================ # ============================================================================
# LoRA loader node template (inserted between checkpoint and sampler)
LORA_LOADER_NODE: dict[str, Any] = {
"class_type": "LoraLoader",
"inputs": {
"lora_name": "",
"strength_model": 1.0,
"strength_clip": 1.0,
"model": ["4", 0], # From checkpoint
"clip": ["4", 1], # From checkpoint
},
}
# Default SDXL/Flux compatible workflow template # Default SDXL/Flux compatible workflow template
# This is a minimal text-to-image workflow that works with most models # This is a minimal text-to-image workflow that works with most models
DEFAULT_WORKFLOW_TEMPLATE: dict[str, Any] = { DEFAULT_WORKFLOW_TEMPLATE: dict[str, Any] = {
@@ -582,6 +594,8 @@ def _build_workflow(
seed: int = -1, seed: int = -1,
sampler: str = "euler", sampler: str = "euler",
scheduler: str = "normal", scheduler: str = "normal",
lora_name: str | None = None,
lora_strength: float = 1.0,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Build a text-to-image workflow from parameters. """Build a text-to-image workflow from parameters.
@@ -596,6 +610,8 @@ def _build_workflow(
seed: Random seed (-1 for random) seed: Random seed (-1 for random)
sampler: Sampler name sampler: Sampler name
scheduler: Scheduler name scheduler: Scheduler name
lora_name: LoRA model filename (optional)
lora_strength: LoRA strength (default 1.0)
Returns: Returns:
ComfyUI workflow dict ComfyUI workflow dict
@@ -624,6 +640,25 @@ def _build_workflow(
workflow["6"]["inputs"]["text"] = prompt workflow["6"]["inputs"]["text"] = prompt
workflow["7"]["inputs"]["text"] = negative_prompt workflow["7"]["inputs"]["text"] = negative_prompt
# Inject LoRA loader if specified
if lora_name:
# Add LoRA loader node (node 10)
lora_node = copy.deepcopy(LORA_LOADER_NODE)
lora_node["inputs"]["lora_name"] = lora_name
lora_node["inputs"]["strength_model"] = lora_strength
lora_node["inputs"]["strength_clip"] = lora_strength
# LoRA takes model/clip from checkpoint (node 4)
lora_node["inputs"]["model"] = ["4", 0]
lora_node["inputs"]["clip"] = ["4", 1]
workflow["10"] = lora_node
# Reroute KSampler model input from checkpoint (4) to LoRA (10)
workflow["3"]["inputs"]["model"] = ["10", 0]
# Reroute CLIP text encoders from checkpoint (4) to LoRA (10)
workflow["6"]["inputs"]["clip"] = ["10", 1]
workflow["7"]["inputs"]["clip"] = ["10", 1]
return workflow return workflow
@@ -642,6 +677,8 @@ def generate_image(
console: Console | None = None, console: Console | None = None,
on_progress: ProgressCallback | None = None, on_progress: ProgressCallback | None = None,
timeout: float = 600.0, timeout: float = 600.0,
lora_name: str | None = None,
lora_strength: float = 1.0,
) -> GenerationResult | None: ) -> GenerationResult | None:
"""Generate an image using a simple text-to-image workflow. """Generate an image using a simple text-to-image workflow.
@@ -660,6 +697,8 @@ def generate_image(
console: Rich console for progress output console: Rich console for progress output
on_progress: Optional callback for progress updates on_progress: Optional callback for progress updates
timeout: Maximum wait time in seconds timeout: Maximum wait time in seconds
lora_name: LoRA model filename (optional)
lora_strength: LoRA strength (default 1.0)
Returns: Returns:
GenerationResult with image paths, or None if generation failed GenerationResult with image paths, or None if generation failed
@@ -690,6 +729,8 @@ def generate_image(
seed=seed, seed=seed,
sampler=sampler, sampler=sampler,
scheduler=scheduler, scheduler=scheduler,
lora_name=lora_name,
lora_strength=lora_strength,
) )
# Run workflow # Run workflow
+87
View File
@@ -502,6 +502,93 @@ COMFYUI_DEFAULT_CFG = 7.0
COMFYUI_DEFAULT_SAMPLER = "euler" COMFYUI_DEFAULT_SAMPLER = "euler"
COMFYUI_DEFAULT_SCHEDULER = "normal" COMFYUI_DEFAULT_SCHEDULER = "normal"
# ============================================================================
# Model Family Defaults (Quality Tags, Negative Prompts, etc.)
# ============================================================================
MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"pony": {
"quality_prefix": "score_9, score_8_up, score_7_up",
"negative_prompt": "score_5, score_4, ugly, deformed, blurry, bad anatomy, bad hands, missing fingers",
"width": 1024,
"height": 1024,
"cfg": 7.0,
"clip_skip": 2,
},
"illustrious": {
"quality_prefix": "masterpiece, best quality, highres",
"negative_prompt": "worst quality, bad quality, low quality, lowres, bad anatomy, bad hands, jpeg artifacts, watermark",
"width": 1024,
"height": 1024,
"cfg": 6.0,
},
"sdxl": {
"quality_prefix": "",
"negative_prompt": "ugly, deformed, bad anatomy, bad hands, extra fingers, missing fingers, blurry, watermark",
"width": 1024,
"height": 1024,
"cfg": 7.0,
},
"sd15": {
"quality_prefix": "masterpiece, best quality",
"negative_prompt": (
"(worst quality:2), (low quality:2), bad anatomy, bad hands, extra fingers, "
"missing fingers, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, watermark"
),
"width": 512,
"height": 512,
"cfg": 7.0,
},
"flux": {
"quality_prefix": "",
"negative_prompt": "", # Flux doesn't use negative prompts effectively
"width": 1024,
"height": 1024,
"cfg": 3.5,
},
}
def detect_model_family(model_name: str, base_model: str | None = None) -> str | None: # noqa: PLR0911
"""Detect model family from filename or CivitAI base_model field.
Args:
model_name: Filename of the model (e.g., "ponyDiffusionV6XL.safetensors")
base_model: Optional CivitAI base_model field (e.g., "Pony", "SDXL 1.0")
Returns:
Model family key (pony, illustrious, sdxl, sd15, flux) or None if unknown
"""
name_lower = model_name.lower()
base_lower = (base_model or "").lower()
# Check base_model field first (most reliable from CivitAI)
if base_lower:
if "pony" in base_lower:
return "pony"
if "illustrious" in base_lower:
return "illustrious"
if "flux" in base_lower:
return "flux"
if "sd 1.5" in base_lower or "sd 1.4" in base_lower:
return "sd15"
if "sdxl" in base_lower:
return "sdxl"
# Fall back to filename heuristics
if "pony" in name_lower:
return "pony"
if "illustrious" in name_lower or "noob" in name_lower:
return "illustrious"
if "flux" in name_lower:
return "flux"
if any(x in name_lower for x in ["sd15", "sd1.5", "sd_1.5", "dreamshaper", "realistic", "deliberate", "anything"]):
return "sd15"
if any(x in name_lower for x in ["sdxl", "xl_"]):
return "sdxl"
return None
def get_comfyui_url() -> str: def get_comfyui_url() -> str:
"""Get the ComfyUI server URL. """Get the ComfyUI server URL.
+48
View File
@@ -755,6 +755,54 @@ class Database:
).all() ).all()
return [w.word for w in words] return [w.word for w in words]
def get_trigger_words_by_filename(self, filename: str) -> list[str]:
"""Get trigger words for a LoRA by matching filename in version_files.
Args:
filename: The filename to search for (e.g., "spumcostyle.safetensors")
Returns:
List of trigger/trained words from CivitAI metadata
"""
with self.session() as session:
# Find version file by filename match
vf = session.exec(select(VersionFile).where(VersionFile.name == filename)).first()
if not vf:
# Try partial match (without extension)
base_name = filename.rsplit(".", 1)[0] if "." in filename else filename
vf = session.exec(select(VersionFile).where(col(VersionFile.name).contains(base_name))).first()
if not vf or not vf.version_id:
return []
words = session.exec(
select(TrainedWord).where(TrainedWord.version_id == vf.version_id).order_by(col(TrainedWord.position))
).all()
return [w.word for w in words]
def get_base_model_by_filename(self, filename: str) -> str | None:
"""Get base_model for a checkpoint/LoRA by filename lookup.
Args:
filename: The filename to search for
Returns:
Base model string (e.g., "Pony", "SDXL 1.0") or None
"""
with self.session() as session:
# Find version file by filename match
vf = session.exec(select(VersionFile).where(VersionFile.name == filename)).first()
if not vf:
# Try partial match (without extension)
base_name = filename.rsplit(".", 1)[0] if "." in filename else filename
vf = session.exec(select(VersionFile).where(col(VersionFile.name).contains(base_name))).first()
if not vf or not vf.version_id:
return None
mv = session.get(ModelVersion, vf.version_id)
return mv.base_model if mv else None
# ========================================================================= # =========================================================================
# Statistics # Statistics
# ========================================================================= # =========================================================================