Add LoRA support and model family quality defaults to comfy generate
This commit is contained in:
+72
-8
@@ -21,6 +21,7 @@ from tensors.api import (
|
||||
)
|
||||
from tensors.config import (
|
||||
CONFIG_FILE,
|
||||
MODEL_FAMILY_DEFAULTS,
|
||||
BaseModel,
|
||||
CommercialUse,
|
||||
ModelType,
|
||||
@@ -28,6 +29,7 @@ from tensors.config import (
|
||||
Period,
|
||||
Provider,
|
||||
SortOrder,
|
||||
detect_model_family,
|
||||
get_default_output_path,
|
||||
get_model_paths,
|
||||
load_api_key,
|
||||
@@ -1118,7 +1120,7 @@ def comfy_history(
|
||||
|
||||
|
||||
@comfy_app.command("generate")
|
||||
def comfy_generate(
|
||||
def comfy_generate( # noqa: PLR0915
|
||||
prompt: Annotated[str, typer.Argument(help="Positive prompt text")],
|
||||
url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None,
|
||||
negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "",
|
||||
@@ -1132,6 +1134,10 @@ def comfy_generate(
|
||||
scheduler: Annotated[str, typer.Option("--scheduler", help="Scheduler name")] = "normal",
|
||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None,
|
||||
count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1,
|
||||
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
||||
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 1.0,
|
||||
no_quality: Annotated[bool, typer.Option("--no-quality", help="Disable auto quality tags")] = False,
|
||||
no_negative: Annotated[bool, typer.Option("--no-negative", help="Disable auto negative prompt")] = False,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""Generate an image with a simple text-to-image workflow.
|
||||
@@ -1141,6 +1147,8 @@ def comfy_generate(
|
||||
tsr comfy generate "portrait photo" -n "blurry, bad quality" --steps 30
|
||||
tsr comfy generate "landscape" -m "flux1-dev-fp8.safetensors" -W 1024 -H 768
|
||||
tsr comfy generate "cyberpunk city" --count 4 -o batch.png
|
||||
tsr comfy generate "girl" --lora spumcostyle.safetensors --lora-strength 0.8
|
||||
tsr comfy generate "raw prompt" --no-quality --no-negative
|
||||
"""
|
||||
import random # noqa: PLC0415
|
||||
|
||||
@@ -1152,6 +1160,64 @@ def comfy_generate(
|
||||
# Determine base seed for batch
|
||||
base_seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
|
||||
|
||||
# Detect model family and apply defaults
|
||||
family_defaults: dict[str, Any] = {}
|
||||
model_family: str | None = None
|
||||
if model:
|
||||
# Try to get base_model from database
|
||||
base_model_str: str | None = None
|
||||
try:
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
base_model_str = db.get_base_model_by_filename(model)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
model_family = detect_model_family(model, base_model_str)
|
||||
if model_family:
|
||||
family_defaults = MODEL_FAMILY_DEFAULTS.get(model_family, {})
|
||||
if not json_output:
|
||||
console.print(f"[dim]Detected model family: {model_family}[/dim]")
|
||||
|
||||
# Build enhanced prompt with quality prefix and LoRA trigger words
|
||||
enhanced_prompt = prompt
|
||||
prompt_parts: list[str] = []
|
||||
|
||||
# Add LoRA trigger words if using LoRA
|
||||
if lora:
|
||||
try:
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
trigger_words = db.get_trigger_words_by_filename(lora)
|
||||
if trigger_words:
|
||||
prompt_parts.extend(trigger_words)
|
||||
if not json_output:
|
||||
console.print(f"[dim]LoRA trigger words: {', '.join(trigger_words)}[/dim]")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Add quality prefix based on model family
|
||||
if not no_quality and family_defaults.get("quality_prefix"):
|
||||
prompt_parts.append(family_defaults["quality_prefix"])
|
||||
|
||||
# Add user prompt
|
||||
prompt_parts.append(prompt)
|
||||
enhanced_prompt = ", ".join(prompt_parts) if len(prompt_parts) > 1 else prompt
|
||||
|
||||
# Build enhanced negative prompt
|
||||
enhanced_negative = negative
|
||||
if not no_negative and family_defaults.get("negative_prompt"):
|
||||
family_negative = family_defaults["negative_prompt"]
|
||||
enhanced_negative = f"{negative}, {family_negative}" if negative else family_negative
|
||||
|
||||
if not json_output and (enhanced_prompt != prompt or enhanced_negative != negative):
|
||||
if enhanced_prompt != prompt:
|
||||
truncated = enhanced_prompt[:100] + "..." if len(enhanced_prompt) > 100 else enhanced_prompt # noqa: PLR2004
|
||||
console.print(f"[dim]Enhanced prompt: {truncated}[/dim]")
|
||||
if enhanced_negative != negative:
|
||||
truncated = enhanced_negative[:80] + "..." if len(enhanced_negative) > 80 else enhanced_negative # noqa: PLR2004
|
||||
console.print(f"[dim]Enhanced negative: {truncated}[/dim]")
|
||||
|
||||
for i in range(count):
|
||||
current_seed = base_seed + i if seed >= 0 else -1 # Increment seed or use random each time
|
||||
|
||||
@@ -1159,9 +1225,9 @@ def comfy_generate(
|
||||
console.print(f"\n[cyan]Generating image {i + 1}/{count}...[/cyan]")
|
||||
|
||||
result = generate_image(
|
||||
prompt=prompt,
|
||||
prompt=enhanced_prompt,
|
||||
url=url,
|
||||
negative_prompt=negative,
|
||||
negative_prompt=enhanced_negative,
|
||||
model=model,
|
||||
width=width,
|
||||
height=height,
|
||||
@@ -1171,6 +1237,8 @@ def comfy_generate(
|
||||
sampler=sampler,
|
||||
scheduler=scheduler,
|
||||
console=console if not json_output else None,
|
||||
lora_name=lora,
|
||||
lora_strength=lora_strength,
|
||||
)
|
||||
|
||||
if not result:
|
||||
@@ -1195,11 +1263,7 @@ def comfy_generate(
|
||||
img_path = result.images[0]
|
||||
img_data = get_image(str(img_path), url=url)
|
||||
if img_data:
|
||||
if count == 1:
|
||||
save_path = output
|
||||
else:
|
||||
# Add index suffix for batch: output.png -> output_001.png
|
||||
save_path = output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
|
||||
save_path = output if count == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
|
||||
save_path.write_bytes(img_data)
|
||||
saved_path = save_path
|
||||
all_saved.append(save_path)
|
||||
|
||||
@@ -526,6 +526,18 @@ def run_workflow(
|
||||
# Simple Text-to-Image Generation
|
||||
# ============================================================================
|
||||
|
||||
# LoRA loader node template (inserted between checkpoint and sampler)
|
||||
LORA_LOADER_NODE: dict[str, Any] = dict(
    class_type="LoraLoader",
    inputs=dict(
        # Placeholder values; callers are expected to fill these on a copy.
        lora_name="",
        strength_model=1.0,
        strength_clip=1.0,
        # [node_id, output_index] links pointing at node "4".
        model=["4", 0],  # From checkpoint
        clip=["4", 1],  # From checkpoint
    ),
)
|
||||
|
||||
# Default SDXL/Flux compatible workflow template
|
||||
# This is a minimal text-to-image workflow that works with most models
|
||||
DEFAULT_WORKFLOW_TEMPLATE: dict[str, Any] = {
|
||||
@@ -582,6 +594,8 @@ def _build_workflow(
|
||||
seed: int = -1,
|
||||
sampler: str = "euler",
|
||||
scheduler: str = "normal",
|
||||
lora_name: str | None = None,
|
||||
lora_strength: float = 1.0,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a text-to-image workflow from parameters.
|
||||
|
||||
@@ -596,6 +610,8 @@ def _build_workflow(
|
||||
seed: Random seed (-1 for random)
|
||||
sampler: Sampler name
|
||||
scheduler: Scheduler name
|
||||
lora_name: LoRA model filename (optional)
|
||||
lora_strength: LoRA strength (default 1.0)
|
||||
|
||||
Returns:
|
||||
ComfyUI workflow dict
|
||||
@@ -624,6 +640,25 @@ def _build_workflow(
|
||||
workflow["6"]["inputs"]["text"] = prompt
|
||||
workflow["7"]["inputs"]["text"] = negative_prompt
|
||||
|
||||
# Inject LoRA loader if specified
|
||||
if lora_name:
|
||||
# Add LoRA loader node (node 10)
|
||||
lora_node = copy.deepcopy(LORA_LOADER_NODE)
|
||||
lora_node["inputs"]["lora_name"] = lora_name
|
||||
lora_node["inputs"]["strength_model"] = lora_strength
|
||||
lora_node["inputs"]["strength_clip"] = lora_strength
|
||||
# LoRA takes model/clip from checkpoint (node 4)
|
||||
lora_node["inputs"]["model"] = ["4", 0]
|
||||
lora_node["inputs"]["clip"] = ["4", 1]
|
||||
workflow["10"] = lora_node
|
||||
|
||||
# Reroute KSampler model input from checkpoint (4) to LoRA (10)
|
||||
workflow["3"]["inputs"]["model"] = ["10", 0]
|
||||
|
||||
# Reroute CLIP text encoders from checkpoint (4) to LoRA (10)
|
||||
workflow["6"]["inputs"]["clip"] = ["10", 1]
|
||||
workflow["7"]["inputs"]["clip"] = ["10", 1]
|
||||
|
||||
return workflow
|
||||
|
||||
|
||||
@@ -642,6 +677,8 @@ def generate_image(
|
||||
console: Console | None = None,
|
||||
on_progress: ProgressCallback | None = None,
|
||||
timeout: float = 600.0,
|
||||
lora_name: str | None = None,
|
||||
lora_strength: float = 1.0,
|
||||
) -> GenerationResult | None:
|
||||
"""Generate an image using a simple text-to-image workflow.
|
||||
|
||||
@@ -660,6 +697,8 @@ def generate_image(
|
||||
console: Rich console for progress output
|
||||
on_progress: Optional callback for progress updates
|
||||
timeout: Maximum wait time in seconds
|
||||
lora_name: LoRA model filename (optional)
|
||||
lora_strength: LoRA strength (default 1.0)
|
||||
|
||||
Returns:
|
||||
GenerationResult with image paths, or None if generation failed
|
||||
@@ -690,6 +729,8 @@ def generate_image(
|
||||
seed=seed,
|
||||
sampler=sampler,
|
||||
scheduler=scheduler,
|
||||
lora_name=lora_name,
|
||||
lora_strength=lora_strength,
|
||||
)
|
||||
|
||||
# Run workflow
|
||||
|
||||
@@ -502,6 +502,93 @@ COMFYUI_DEFAULT_CFG = 7.0
|
||||
COMFYUI_DEFAULT_SAMPLER = "euler"
|
||||
COMFYUI_DEFAULT_SCHEDULER = "normal"
|
||||
|
||||
# ============================================================================
|
||||
# Model Family Defaults (Quality Tags, Negative Prompts, etc.)
|
||||
# ============================================================================
|
||||
|
||||
# Per-family generation defaults, keyed by the family names that
# detect_model_family() returns. An empty string means "add nothing".
MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
    "pony": {
        "quality_prefix": "score_9, score_8_up, score_7_up",
        "negative_prompt": (
            "score_5, score_4, ugly, deformed, blurry, "
            "bad anatomy, bad hands, missing fingers"
        ),
        "width": 1024,
        "height": 1024,
        "cfg": 7.0,
        "clip_skip": 2,  # only this family defines clip_skip
    },
    "illustrious": {
        "quality_prefix": "masterpiece, best quality, highres",
        "negative_prompt": (
            "worst quality, bad quality, low quality, lowres, "
            "bad anatomy, bad hands, jpeg artifacts, watermark"
        ),
        "width": 1024,
        "height": 1024,
        "cfg": 6.0,
    },
    "sdxl": {
        "quality_prefix": "",
        "negative_prompt": (
            "ugly, deformed, bad anatomy, bad hands, "
            "extra fingers, missing fingers, blurry, watermark"
        ),
        "width": 1024,
        "height": 1024,
        "cfg": 7.0,
    },
    "sd15": {
        "quality_prefix": "masterpiece, best quality",
        "negative_prompt": (
            "(worst quality:2), (low quality:2), bad anatomy, bad hands, "
            "extra fingers, missing fingers, poorly drawn hands, poorly drawn face, "
            "mutation, deformed, blurry, watermark"
        ),
        "width": 512,
        "height": 512,
        "cfg": 7.0,
    },
    "flux": {
        "quality_prefix": "",
        "negative_prompt": "",  # Flux doesn't use negative prompts effectively
        "width": 1024,
        "height": 1024,
        "cfg": 3.5,
    },
}
|
||||
|
||||
|
||||
def detect_model_family(model_name: str, base_model: str | None = None) -> str | None: # noqa: PLR0911
|
||||
"""Detect model family from filename or CivitAI base_model field.
|
||||
|
||||
Args:
|
||||
model_name: Filename of the model (e.g., "ponyDiffusionV6XL.safetensors")
|
||||
base_model: Optional CivitAI base_model field (e.g., "Pony", "SDXL 1.0")
|
||||
|
||||
Returns:
|
||||
Model family key (pony, illustrious, sdxl, sd15, flux) or None if unknown
|
||||
"""
|
||||
name_lower = model_name.lower()
|
||||
base_lower = (base_model or "").lower()
|
||||
|
||||
# Check base_model field first (most reliable from CivitAI)
|
||||
if base_lower:
|
||||
if "pony" in base_lower:
|
||||
return "pony"
|
||||
if "illustrious" in base_lower:
|
||||
return "illustrious"
|
||||
if "flux" in base_lower:
|
||||
return "flux"
|
||||
if "sd 1.5" in base_lower or "sd 1.4" in base_lower:
|
||||
return "sd15"
|
||||
if "sdxl" in base_lower:
|
||||
return "sdxl"
|
||||
|
||||
# Fall back to filename heuristics
|
||||
if "pony" in name_lower:
|
||||
return "pony"
|
||||
if "illustrious" in name_lower or "noob" in name_lower:
|
||||
return "illustrious"
|
||||
if "flux" in name_lower:
|
||||
return "flux"
|
||||
if any(x in name_lower for x in ["sd15", "sd1.5", "sd_1.5", "dreamshaper", "realistic", "deliberate", "anything"]):
|
||||
return "sd15"
|
||||
if any(x in name_lower for x in ["sdxl", "xl_"]):
|
||||
return "sdxl"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_comfyui_url() -> str:
|
||||
"""Get the ComfyUI server URL.
|
||||
|
||||
@@ -755,6 +755,54 @@ class Database:
|
||||
).all()
|
||||
return [w.word for w in words]
|
||||
|
||||
def get_trigger_words_by_filename(self, filename: str) -> list[str]:
    """Get trigger words for a LoRA by matching filename in version_files.

    An exact match on ``VersionFile.name`` is tried first. If that fails, the
    filename without its last extension is used for a substring match.

    Args:
        filename: The filename to search for (e.g., "spumcostyle.safetensors")

    Returns:
        List of trigger/trained words from CivitAI metadata, ordered by
        position; empty if no matching file (or no version id) is found.
    """
    # An empty filename can never match exactly, and its empty stem would
    # make the substring fallback match every row.
    if not filename:
        return []

    with self.session() as session:
        # Find version file by filename match
        vf = session.exec(select(VersionFile).where(VersionFile.name == filename)).first()
        if not vf:
            # Try partial match (without extension)
            base_name = filename.rsplit(".", 1)[0]
            if base_name:
                # autoescape keeps "_" and "%" in the name from acting as
                # LIKE wildcards.
                vf = session.exec(
                    select(VersionFile).where(col(VersionFile.name).contains(base_name, autoescape=True))
                ).first()

        if not vf or not vf.version_id:
            return []

        words = session.exec(
            select(TrainedWord).where(TrainedWord.version_id == vf.version_id).order_by(col(TrainedWord.position))
        ).all()
        return [w.word for w in words]
|
||||
|
||||
def get_base_model_by_filename(self, filename: str) -> str | None:
|
||||
"""Get base_model for a checkpoint/LoRA by filename lookup.
|
||||
|
||||
Args:
|
||||
filename: The filename to search for
|
||||
|
||||
Returns:
|
||||
Base model string (e.g., "Pony", "SDXL 1.0") or None
|
||||
"""
|
||||
with self.session() as session:
|
||||
# Find version file by filename match
|
||||
vf = session.exec(select(VersionFile).where(VersionFile.name == filename)).first()
|
||||
if not vf:
|
||||
# Try partial match (without extension)
|
||||
base_name = filename.rsplit(".", 1)[0] if "." in filename else filename
|
||||
vf = session.exec(select(VersionFile).where(col(VersionFile.name).contains(base_name))).first()
|
||||
|
||||
if not vf or not vf.version_id:
|
||||
return None
|
||||
|
||||
mv = session.get(ModelVersion, vf.version_id)
|
||||
return mv.base_model if mv else None
|
||||
|
||||
# =========================================================================
|
||||
# Statistics
|
||||
# =========================================================================
|
||||
|
||||
Reference in New Issue
Block a user