Add LoRA support and model family quality defaults to comfy generate
This commit is contained in:
+72
-8
@@ -21,6 +21,7 @@ from tensors.api import (
|
|||||||
)
|
)
|
||||||
from tensors.config import (
|
from tensors.config import (
|
||||||
CONFIG_FILE,
|
CONFIG_FILE,
|
||||||
|
MODEL_FAMILY_DEFAULTS,
|
||||||
BaseModel,
|
BaseModel,
|
||||||
CommercialUse,
|
CommercialUse,
|
||||||
ModelType,
|
ModelType,
|
||||||
@@ -28,6 +29,7 @@ from tensors.config import (
|
|||||||
Period,
|
Period,
|
||||||
Provider,
|
Provider,
|
||||||
SortOrder,
|
SortOrder,
|
||||||
|
detect_model_family,
|
||||||
get_default_output_path,
|
get_default_output_path,
|
||||||
get_model_paths,
|
get_model_paths,
|
||||||
load_api_key,
|
load_api_key,
|
||||||
@@ -1118,7 +1120,7 @@ def comfy_history(
|
|||||||
|
|
||||||
|
|
||||||
@comfy_app.command("generate")
|
@comfy_app.command("generate")
|
||||||
def comfy_generate(
|
def comfy_generate( # noqa: PLR0915
|
||||||
prompt: Annotated[str, typer.Argument(help="Positive prompt text")],
|
prompt: Annotated[str, typer.Argument(help="Positive prompt text")],
|
||||||
url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None,
|
url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None,
|
||||||
negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "",
|
negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "",
|
||||||
@@ -1132,6 +1134,10 @@ def comfy_generate(
|
|||||||
scheduler: Annotated[str, typer.Option("--scheduler", help="Scheduler name")] = "normal",
|
scheduler: Annotated[str, typer.Option("--scheduler", help="Scheduler name")] = "normal",
|
||||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None,
|
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None,
|
||||||
count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1,
|
count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1,
|
||||||
|
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
||||||
|
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 1.0,
|
||||||
|
no_quality: Annotated[bool, typer.Option("--no-quality", help="Disable auto quality tags")] = False,
|
||||||
|
no_negative: Annotated[bool, typer.Option("--no-negative", help="Disable auto negative prompt")] = False,
|
||||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Generate an image with a simple text-to-image workflow.
|
"""Generate an image with a simple text-to-image workflow.
|
||||||
@@ -1141,6 +1147,8 @@ def comfy_generate(
|
|||||||
tsr comfy generate "portrait photo" -n "blurry, bad quality" --steps 30
|
tsr comfy generate "portrait photo" -n "blurry, bad quality" --steps 30
|
||||||
tsr comfy generate "landscape" -m "flux1-dev-fp8.safetensors" -W 1024 -H 768
|
tsr comfy generate "landscape" -m "flux1-dev-fp8.safetensors" -W 1024 -H 768
|
||||||
tsr comfy generate "cyberpunk city" --count 4 -o batch.png
|
tsr comfy generate "cyberpunk city" --count 4 -o batch.png
|
||||||
|
tsr comfy generate "girl" --lora spumcostyle.safetensors --lora-strength 0.8
|
||||||
|
tsr comfy generate "raw prompt" --no-quality --no-negative
|
||||||
"""
|
"""
|
||||||
import random # noqa: PLC0415
|
import random # noqa: PLC0415
|
||||||
|
|
||||||
@@ -1152,6 +1160,64 @@ def comfy_generate(
|
|||||||
# Determine base seed for batch
|
# Determine base seed for batch
|
||||||
base_seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
|
base_seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
|
||||||
|
|
||||||
|
# Detect model family and apply defaults
|
||||||
|
family_defaults: dict[str, Any] = {}
|
||||||
|
model_family: str | None = None
|
||||||
|
if model:
|
||||||
|
# Try to get base_model from database
|
||||||
|
base_model_str: str | None = None
|
||||||
|
try:
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
base_model_str = db.get_base_model_by_filename(model)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
model_family = detect_model_family(model, base_model_str)
|
||||||
|
if model_family:
|
||||||
|
family_defaults = MODEL_FAMILY_DEFAULTS.get(model_family, {})
|
||||||
|
if not json_output:
|
||||||
|
console.print(f"[dim]Detected model family: {model_family}[/dim]")
|
||||||
|
|
||||||
|
# Build enhanced prompt with quality prefix and LoRA trigger words
|
||||||
|
enhanced_prompt = prompt
|
||||||
|
prompt_parts: list[str] = []
|
||||||
|
|
||||||
|
# Add LoRA trigger words if using LoRA
|
||||||
|
if lora:
|
||||||
|
try:
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
trigger_words = db.get_trigger_words_by_filename(lora)
|
||||||
|
if trigger_words:
|
||||||
|
prompt_parts.extend(trigger_words)
|
||||||
|
if not json_output:
|
||||||
|
console.print(f"[dim]LoRA trigger words: {', '.join(trigger_words)}[/dim]")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Add quality prefix based on model family
|
||||||
|
if not no_quality and family_defaults.get("quality_prefix"):
|
||||||
|
prompt_parts.append(family_defaults["quality_prefix"])
|
||||||
|
|
||||||
|
# Add user prompt
|
||||||
|
prompt_parts.append(prompt)
|
||||||
|
enhanced_prompt = ", ".join(prompt_parts) if len(prompt_parts) > 1 else prompt
|
||||||
|
|
||||||
|
# Build enhanced negative prompt
|
||||||
|
enhanced_negative = negative
|
||||||
|
if not no_negative and family_defaults.get("negative_prompt"):
|
||||||
|
family_negative = family_defaults["negative_prompt"]
|
||||||
|
enhanced_negative = f"{negative}, {family_negative}" if negative else family_negative
|
||||||
|
|
||||||
|
if not json_output and (enhanced_prompt != prompt or enhanced_negative != negative):
|
||||||
|
if enhanced_prompt != prompt:
|
||||||
|
truncated = enhanced_prompt[:100] + "..." if len(enhanced_prompt) > 100 else enhanced_prompt # noqa: PLR2004
|
||||||
|
console.print(f"[dim]Enhanced prompt: {truncated}[/dim]")
|
||||||
|
if enhanced_negative != negative:
|
||||||
|
truncated = enhanced_negative[:80] + "..." if len(enhanced_negative) > 80 else enhanced_negative # noqa: PLR2004
|
||||||
|
console.print(f"[dim]Enhanced negative: {truncated}[/dim]")
|
||||||
|
|
||||||
for i in range(count):
|
for i in range(count):
|
||||||
current_seed = base_seed + i if seed >= 0 else -1 # Increment seed or use random each time
|
current_seed = base_seed + i if seed >= 0 else -1 # Increment seed or use random each time
|
||||||
|
|
||||||
@@ -1159,9 +1225,9 @@ def comfy_generate(
|
|||||||
console.print(f"\n[cyan]Generating image {i + 1}/{count}...[/cyan]")
|
console.print(f"\n[cyan]Generating image {i + 1}/{count}...[/cyan]")
|
||||||
|
|
||||||
result = generate_image(
|
result = generate_image(
|
||||||
prompt=prompt,
|
prompt=enhanced_prompt,
|
||||||
url=url,
|
url=url,
|
||||||
negative_prompt=negative,
|
negative_prompt=enhanced_negative,
|
||||||
model=model,
|
model=model,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
@@ -1171,6 +1237,8 @@ def comfy_generate(
|
|||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
console=console if not json_output else None,
|
console=console if not json_output else None,
|
||||||
|
lora_name=lora,
|
||||||
|
lora_strength=lora_strength,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
@@ -1195,11 +1263,7 @@ def comfy_generate(
|
|||||||
img_path = result.images[0]
|
img_path = result.images[0]
|
||||||
img_data = get_image(str(img_path), url=url)
|
img_data = get_image(str(img_path), url=url)
|
||||||
if img_data:
|
if img_data:
|
||||||
if count == 1:
|
save_path = output if count == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
|
||||||
save_path = output
|
|
||||||
else:
|
|
||||||
# Add index suffix for batch: output.png -> output_001.png
|
|
||||||
save_path = output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
|
|
||||||
save_path.write_bytes(img_data)
|
save_path.write_bytes(img_data)
|
||||||
saved_path = save_path
|
saved_path = save_path
|
||||||
all_saved.append(save_path)
|
all_saved.append(save_path)
|
||||||
|
|||||||
@@ -526,6 +526,18 @@ def run_workflow(
|
|||||||
# Simple Text-to-Image Generation
|
# Simple Text-to-Image Generation
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
# LoRA loader node template (inserted between checkpoint and sampler).
# _build_workflow deep-copies this dict when a LoRA name is given, fills in
# "lora_name" and both strength fields, and stores the copy as node "10".
# It then repoints the "model" input of node "3" and the "clip" inputs of
# nodes "6" and "7" at node "10", so the ["4", N] references below are the
# only links that still read directly from node "4".
|
||||||
|
LORA_LOADER_NODE: dict[str, Any] = {
|
||||||
|
"class_type": "LoraLoader",
|
||||||
|
"inputs": {
|
||||||
|
"lora_name": "",
|
||||||
|
"strength_model": 1.0,
|
||||||
|
"strength_clip": 1.0,
|
||||||
|
"model": ["4", 0], # From checkpoint
|
||||||
|
"clip": ["4", 1], # From checkpoint
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
# Default SDXL/Flux compatible workflow template
|
# Default SDXL/Flux compatible workflow template
|
||||||
# This is a minimal text-to-image workflow that works with most models
|
# This is a minimal text-to-image workflow that works with most models
|
||||||
DEFAULT_WORKFLOW_TEMPLATE: dict[str, Any] = {
|
DEFAULT_WORKFLOW_TEMPLATE: dict[str, Any] = {
|
||||||
@@ -582,6 +594,8 @@ def _build_workflow(
|
|||||||
seed: int = -1,
|
seed: int = -1,
|
||||||
sampler: str = "euler",
|
sampler: str = "euler",
|
||||||
scheduler: str = "normal",
|
scheduler: str = "normal",
|
||||||
|
lora_name: str | None = None,
|
||||||
|
lora_strength: float = 1.0,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Build a text-to-image workflow from parameters.
|
"""Build a text-to-image workflow from parameters.
|
||||||
|
|
||||||
@@ -596,6 +610,8 @@ def _build_workflow(
|
|||||||
seed: Random seed (-1 for random)
|
seed: Random seed (-1 for random)
|
||||||
sampler: Sampler name
|
sampler: Sampler name
|
||||||
scheduler: Scheduler name
|
scheduler: Scheduler name
|
||||||
|
lora_name: LoRA model filename (optional)
|
||||||
|
lora_strength: LoRA strength (default 1.0)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ComfyUI workflow dict
|
ComfyUI workflow dict
|
||||||
@@ -624,6 +640,25 @@ def _build_workflow(
|
|||||||
workflow["6"]["inputs"]["text"] = prompt
|
workflow["6"]["inputs"]["text"] = prompt
|
||||||
workflow["7"]["inputs"]["text"] = negative_prompt
|
workflow["7"]["inputs"]["text"] = negative_prompt
|
||||||
|
|
||||||
|
# Inject LoRA loader if specified
|
||||||
|
if lora_name:
|
||||||
|
# Add LoRA loader node (node 10)
|
||||||
|
lora_node = copy.deepcopy(LORA_LOADER_NODE)
|
||||||
|
lora_node["inputs"]["lora_name"] = lora_name
|
||||||
|
lora_node["inputs"]["strength_model"] = lora_strength
|
||||||
|
lora_node["inputs"]["strength_clip"] = lora_strength
|
||||||
|
# LoRA takes model/clip from checkpoint (node 4)
|
||||||
|
lora_node["inputs"]["model"] = ["4", 0]
|
||||||
|
lora_node["inputs"]["clip"] = ["4", 1]
|
||||||
|
workflow["10"] = lora_node
|
||||||
|
|
||||||
|
# Reroute KSampler model input from checkpoint (4) to LoRA (10)
|
||||||
|
workflow["3"]["inputs"]["model"] = ["10", 0]
|
||||||
|
|
||||||
|
# Reroute CLIP text encoders from checkpoint (4) to LoRA (10)
|
||||||
|
workflow["6"]["inputs"]["clip"] = ["10", 1]
|
||||||
|
workflow["7"]["inputs"]["clip"] = ["10", 1]
|
||||||
|
|
||||||
return workflow
|
return workflow
|
||||||
|
|
||||||
|
|
||||||
@@ -642,6 +677,8 @@ def generate_image(
|
|||||||
console: Console | None = None,
|
console: Console | None = None,
|
||||||
on_progress: ProgressCallback | None = None,
|
on_progress: ProgressCallback | None = None,
|
||||||
timeout: float = 600.0,
|
timeout: float = 600.0,
|
||||||
|
lora_name: str | None = None,
|
||||||
|
lora_strength: float = 1.0,
|
||||||
) -> GenerationResult | None:
|
) -> GenerationResult | None:
|
||||||
"""Generate an image using a simple text-to-image workflow.
|
"""Generate an image using a simple text-to-image workflow.
|
||||||
|
|
||||||
@@ -660,6 +697,8 @@ def generate_image(
|
|||||||
console: Rich console for progress output
|
console: Rich console for progress output
|
||||||
on_progress: Optional callback for progress updates
|
on_progress: Optional callback for progress updates
|
||||||
timeout: Maximum wait time in seconds
|
timeout: Maximum wait time in seconds
|
||||||
|
lora_name: LoRA model filename (optional)
|
||||||
|
lora_strength: LoRA strength (default 1.0)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
GenerationResult with image paths, or None if generation failed
|
GenerationResult with image paths, or None if generation failed
|
||||||
@@ -690,6 +729,8 @@ def generate_image(
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
|
lora_name=lora_name,
|
||||||
|
lora_strength=lora_strength,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run workflow
|
# Run workflow
|
||||||
|
|||||||
@@ -502,6 +502,93 @@ COMFYUI_DEFAULT_CFG = 7.0
|
|||||||
COMFYUI_DEFAULT_SAMPLER = "euler"
|
COMFYUI_DEFAULT_SAMPLER = "euler"
|
||||||
COMFYUI_DEFAULT_SCHEDULER = "normal"
|
COMFYUI_DEFAULT_SCHEDULER = "normal"
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Model Family Defaults (Quality Tags, Negative Prompts, etc.)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Per-family generation defaults, keyed by the family names that
# detect_model_family returns ("pony", "illustrious", "sdxl", "sd15", "flux").
# comfy_generate looks a family up here and reads "quality_prefix" and
# "negative_prompt"; both are tested for truthiness, so an empty string means
# "add nothing" for that family.
# NOTE(review): "width", "height", "cfg" and "clip_skip" are not read by the
# code visible alongside this table - confirm where (or whether) they are used.
MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||||
|
"pony": {
|
||||||
|
"quality_prefix": "score_9, score_8_up, score_7_up",
|
||||||
|
"negative_prompt": "score_5, score_4, ugly, deformed, blurry, bad anatomy, bad hands, missing fingers",
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
|
"cfg": 7.0,
|
||||||
|
"clip_skip": 2,
|
||||||
|
},
|
||||||
|
"illustrious": {
|
||||||
|
"quality_prefix": "masterpiece, best quality, highres",
|
||||||
|
"negative_prompt": "worst quality, bad quality, low quality, lowres, bad anatomy, bad hands, jpeg artifacts, watermark",
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
|
"cfg": 6.0,
|
||||||
|
},
|
||||||
|
"sdxl": {
|
||||||
|
"quality_prefix": "",
|
||||||
|
"negative_prompt": "ugly, deformed, bad anatomy, bad hands, extra fingers, missing fingers, blurry, watermark",
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
|
"cfg": 7.0,
|
||||||
|
},
|
||||||
|
"sd15": {
|
||||||
|
"quality_prefix": "masterpiece, best quality",
|
||||||
|
"negative_prompt": (
|
||||||
|
"(worst quality:2), (low quality:2), bad anatomy, bad hands, extra fingers, "
|
||||||
|
"missing fingers, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, watermark"
|
||||||
|
),
|
||||||
|
"width": 512,
|
||||||
|
"height": 512,
|
||||||
|
"cfg": 7.0,
|
||||||
|
},
|
||||||
|
"flux": {
|
||||||
|
"quality_prefix": "",
|
||||||
|
"negative_prompt": "", # Flux doesn't use negative prompts effectively
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
|
"cfg": 3.5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def detect_model_family(model_name: str, base_model: str | None = None) -> str | None: # noqa: PLR0911
|
||||||
|
"""Detect model family from filename or CivitAI base_model field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Filename of the model (e.g., "ponyDiffusionV6XL.safetensors")
|
||||||
|
base_model: Optional CivitAI base_model field (e.g., "Pony", "SDXL 1.0")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model family key (pony, illustrious, sdxl, sd15, flux) or None if unknown
|
||||||
|
"""
|
||||||
|
name_lower = model_name.lower()
|
||||||
|
base_lower = (base_model or "").lower()
|
||||||
|
|
||||||
|
# Check base_model field first (most reliable from CivitAI)
|
||||||
|
if base_lower:
|
||||||
|
if "pony" in base_lower:
|
||||||
|
return "pony"
|
||||||
|
if "illustrious" in base_lower:
|
||||||
|
return "illustrious"
|
||||||
|
if "flux" in base_lower:
|
||||||
|
return "flux"
|
||||||
|
if "sd 1.5" in base_lower or "sd 1.4" in base_lower:
|
||||||
|
return "sd15"
|
||||||
|
if "sdxl" in base_lower:
|
||||||
|
return "sdxl"
|
||||||
|
|
||||||
|
# Fall back to filename heuristics
|
||||||
|
if "pony" in name_lower:
|
||||||
|
return "pony"
|
||||||
|
if "illustrious" in name_lower or "noob" in name_lower:
|
||||||
|
return "illustrious"
|
||||||
|
if "flux" in name_lower:
|
||||||
|
return "flux"
|
||||||
|
if any(x in name_lower for x in ["sd15", "sd1.5", "sd_1.5", "dreamshaper", "realistic", "deliberate", "anything"]):
|
||||||
|
return "sd15"
|
||||||
|
if any(x in name_lower for x in ["sdxl", "xl_"]):
|
||||||
|
return "sdxl"
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_comfyui_url() -> str:
|
def get_comfyui_url() -> str:
|
||||||
"""Get the ComfyUI server URL.
|
"""Get the ComfyUI server URL.
|
||||||
|
|||||||
@@ -755,6 +755,54 @@ class Database:
|
|||||||
).all()
|
).all()
|
||||||
return [w.word for w in words]
|
return [w.word for w in words]
|
||||||
|
|
||||||
|
def get_trigger_words_by_filename(self, filename: str) -> list[str]:
    """Get trigger words for a LoRA by matching filename in version_files.

    An exact match on ``VersionFile.name`` is tried first. If that finds
    nothing, the extension is stripped and the remaining stem is searched as
    a literal substring of ``VersionFile.name``; the first row the database
    yields is used.

    Args:
        filename: The filename to search for (e.g., "spumcostyle.safetensors")

    Returns:
        List of trigger/trained words ordered by position, or an empty list
        when the filename is empty or no matching file/version is found
    """
    # An empty name can never identify a file; bail out before opening a session.
    if not filename:
        return []

    with self.session() as session:
        # Find version file by filename match
        vf = session.exec(select(VersionFile).where(VersionFile.name == filename)).first()
        if not vf:
            # Try partial match (without extension)
            base_name = filename.rsplit(".", 1)[0]
            # An empty stem (e.g. ".safetensors") would match every row, so skip it.
            if base_name:
                # autoescape keeps "_" and "%" in the filename from acting as LIKE wildcards.
                vf = session.exec(
                    select(VersionFile).where(col(VersionFile.name).contains(base_name, autoescape=True))
                ).first()

        if not vf or not vf.version_id:
            return []

        words = session.exec(
            select(TrainedWord).where(TrainedWord.version_id == vf.version_id).order_by(col(TrainedWord.position))
        ).all()
        return [w.word for w in words]
|
||||||
|
|
||||||
|
def get_base_model_by_filename(self, filename: str) -> str | None:
|
||||||
|
"""Get base_model for a checkpoint/LoRA by filename lookup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: The filename to search for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Base model string (e.g., "Pony", "SDXL 1.0") or None
|
||||||
|
"""
|
||||||
|
with self.session() as session:
|
||||||
|
# Find version file by filename match
|
||||||
|
vf = session.exec(select(VersionFile).where(VersionFile.name == filename)).first()
|
||||||
|
if not vf:
|
||||||
|
# Try partial match (without extension)
|
||||||
|
base_name = filename.rsplit(".", 1)[0] if "." in filename else filename
|
||||||
|
vf = session.exec(select(VersionFile).where(col(VersionFile.name).contains(base_name))).first()
|
||||||
|
|
||||||
|
if not vf or not vf.version_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
mv = session.get(ModelVersion, vf.version_id)
|
||||||
|
return mv.base_model if mv else None
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Statistics
|
# Statistics
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|||||||
Reference in New Issue
Block a user