diff --git a/.coverage b/.coverage index 448d21c..2639881 100644 Binary files a/.coverage and b/.coverage differ diff --git a/tensors/cli.py b/tensors/cli.py index 5f390f4..ce03f8f 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -21,6 +21,7 @@ from tensors.api import ( ) from tensors.config import ( CONFIG_FILE, + MODEL_FAMILY_DEFAULTS, BaseModel, CommercialUse, ModelType, @@ -28,6 +29,7 @@ from tensors.config import ( Period, Provider, SortOrder, + detect_model_family, get_default_output_path, get_model_paths, load_api_key, @@ -1118,7 +1120,7 @@ def comfy_history( @comfy_app.command("generate") -def comfy_generate( +def comfy_generate( # noqa: PLR0915 prompt: Annotated[str, typer.Argument(help="Positive prompt text")], url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None, negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "", @@ -1132,6 +1134,10 @@ def comfy_generate( scheduler: Annotated[str, typer.Option("--scheduler", help="Scheduler name")] = "normal", output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None, count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1, + lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None, + lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 1.0, + no_quality: Annotated[bool, typer.Option("--no-quality", help="Disable auto quality tags")] = False, + no_negative: Annotated[bool, typer.Option("--no-negative", help="Disable auto negative prompt")] = False, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Generate an image with a simple text-to-image workflow. 
@@ -1141,6 +1147,8 @@ def comfy_generate( tsr comfy generate "portrait photo" -n "blurry, bad quality" --steps 30 tsr comfy generate "landscape" -m "flux1-dev-fp8.safetensors" -W 1024 -H 768 tsr comfy generate "cyberpunk city" --count 4 -o batch.png + tsr comfy generate "girl" --lora spumcostyle.safetensors --lora-strength 0.8 + tsr comfy generate "raw prompt" --no-quality --no-negative """ import random # noqa: PLC0415 @@ -1152,6 +1160,64 @@ def comfy_generate( # Determine base seed for batch base_seed = seed if seed >= 0 else random.randint(0, 2**32 - 1) + # Detect model family and apply defaults + family_defaults: dict[str, Any] = {} + model_family: str | None = None + if model: + # Try to get base_model from database + base_model_str: str | None = None + try: + with Database() as db: + db.init_schema() + base_model_str = db.get_base_model_by_filename(model) + except Exception: + pass + + model_family = detect_model_family(model, base_model_str) + if model_family: + family_defaults = MODEL_FAMILY_DEFAULTS.get(model_family, {}) + if not json_output: + console.print(f"[dim]Detected model family: {model_family}[/dim]") + + # Build enhanced prompt with quality prefix and LoRA trigger words + enhanced_prompt = prompt + prompt_parts: list[str] = [] + + # Add LoRA trigger words if using LoRA + if lora: + try: + with Database() as db: + db.init_schema() + trigger_words = db.get_trigger_words_by_filename(lora) + if trigger_words: + prompt_parts.extend(trigger_words) + if not json_output: + console.print(f"[dim]LoRA trigger words: {', '.join(trigger_words)}[/dim]") + except Exception: + pass + + # Add quality prefix based on model family + if not no_quality and family_defaults.get("quality_prefix"): + prompt_parts.append(family_defaults["quality_prefix"]) + + # Add user prompt + prompt_parts.append(prompt) + enhanced_prompt = ", ".join(prompt_parts) if len(prompt_parts) > 1 else prompt + + # Build enhanced negative prompt + enhanced_negative = negative + if not 
no_negative and family_defaults.get("negative_prompt"): + family_negative = family_defaults["negative_prompt"] + enhanced_negative = f"{negative}, {family_negative}" if negative else family_negative + + if not json_output and (enhanced_prompt != prompt or enhanced_negative != negative): + if enhanced_prompt != prompt: + truncated = enhanced_prompt[:100] + "..." if len(enhanced_prompt) > 100 else enhanced_prompt # noqa: PLR2004 + console.print(f"[dim]Enhanced prompt: {truncated}[/dim]") + if enhanced_negative != negative: + truncated = enhanced_negative[:80] + "..." if len(enhanced_negative) > 80 else enhanced_negative # noqa: PLR2004 + console.print(f"[dim]Enhanced negative: {truncated}[/dim]") + for i in range(count): current_seed = base_seed + i if seed >= 0 else -1 # Increment seed or use random each time @@ -1159,9 +1225,9 @@ def comfy_generate( console.print(f"\n[cyan]Generating image {i + 1}/{count}...[/cyan]") result = generate_image( - prompt=prompt, + prompt=enhanced_prompt, url=url, - negative_prompt=negative, + negative_prompt=enhanced_negative, model=model, width=width, height=height, @@ -1171,6 +1237,8 @@ def comfy_generate( sampler=sampler, scheduler=scheduler, console=console if not json_output else None, + lora_name=lora, + lora_strength=lora_strength, ) if not result: @@ -1195,11 +1263,7 @@ def comfy_generate( img_path = result.images[0] img_data = get_image(str(img_path), url=url) if img_data: - if count == 1: - save_path = output - else: - # Add index suffix for batch: output.png -> output_001.png - save_path = output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}" + save_path = output if count == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}" save_path.write_bytes(img_data) saved_path = save_path all_saved.append(save_path) diff --git a/tensors/comfyui.py b/tensors/comfyui.py index 960f345..c51357b 100644 --- a/tensors/comfyui.py +++ b/tensors/comfyui.py @@ -526,6 +526,18 @@ def run_workflow( # Simple Text-to-Image 
Generation # ============================================================================ +# LoRA loader node template (inserted between checkpoint and sampler) +LORA_LOADER_NODE: dict[str, Any] = { + "class_type": "LoraLoader", + "inputs": { + "lora_name": "", + "strength_model": 1.0, + "strength_clip": 1.0, + "model": ["4", 0], # From checkpoint + "clip": ["4", 1], # From checkpoint + }, +} + # Default SDXL/Flux compatible workflow template # This is a minimal text-to-image workflow that works with most models DEFAULT_WORKFLOW_TEMPLATE: dict[str, Any] = { @@ -582,6 +594,8 @@ def _build_workflow( seed: int = -1, sampler: str = "euler", scheduler: str = "normal", + lora_name: str | None = None, + lora_strength: float = 1.0, ) -> dict[str, Any]: """Build a text-to-image workflow from parameters. @@ -596,6 +610,8 @@ def _build_workflow( seed: Random seed (-1 for random) sampler: Sampler name scheduler: Scheduler name + lora_name: LoRA model filename (optional) + lora_strength: LoRA strength (default 1.0) Returns: ComfyUI workflow dict @@ -624,6 +640,25 @@ def _build_workflow( workflow["6"]["inputs"]["text"] = prompt workflow["7"]["inputs"]["text"] = negative_prompt + # Inject LoRA loader if specified + if lora_name: + # Add LoRA loader node (node 10) + lora_node = copy.deepcopy(LORA_LOADER_NODE) + lora_node["inputs"]["lora_name"] = lora_name + lora_node["inputs"]["strength_model"] = lora_strength + lora_node["inputs"]["strength_clip"] = lora_strength + # LoRA takes model/clip from checkpoint (node 4) + lora_node["inputs"]["model"] = ["4", 0] + lora_node["inputs"]["clip"] = ["4", 1] + workflow["10"] = lora_node + + # Reroute KSampler model input from checkpoint (4) to LoRA (10) + workflow["3"]["inputs"]["model"] = ["10", 0] + + # Reroute CLIP text encoders from checkpoint (4) to LoRA (10) + workflow["6"]["inputs"]["clip"] = ["10", 1] + workflow["7"]["inputs"]["clip"] = ["10", 1] + return workflow @@ -642,6 +677,8 @@ def generate_image( console: Console | None = None, 
on_progress: ProgressCallback | None = None, timeout: float = 600.0, + lora_name: str | None = None, + lora_strength: float = 1.0, ) -> GenerationResult | None: """Generate an image using a simple text-to-image workflow. @@ -660,6 +697,8 @@ def generate_image( console: Rich console for progress output on_progress: Optional callback for progress updates timeout: Maximum wait time in seconds + lora_name: LoRA model filename (optional) + lora_strength: LoRA strength (default 1.0) Returns: GenerationResult with image paths, or None if generation failed @@ -690,6 +729,8 @@ def generate_image( seed=seed, sampler=sampler, scheduler=scheduler, + lora_name=lora_name, + lora_strength=lora_strength, ) # Run workflow diff --git a/tensors/config.py b/tensors/config.py index 2cf9fe9..14ed4d4 100644 --- a/tensors/config.py +++ b/tensors/config.py @@ -502,6 +502,93 @@ COMFYUI_DEFAULT_CFG = 7.0 COMFYUI_DEFAULT_SAMPLER = "euler" COMFYUI_DEFAULT_SCHEDULER = "normal" +# ============================================================================ +# Model Family Defaults (Quality Tags, Negative Prompts, etc.) 
+# ============================================================================ + +MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { + "pony": { + "quality_prefix": "score_9, score_8_up, score_7_up", + "negative_prompt": "score_5, score_4, ugly, deformed, blurry, bad anatomy, bad hands, missing fingers", + "width": 1024, + "height": 1024, + "cfg": 7.0, + "clip_skip": 2, + }, + "illustrious": { + "quality_prefix": "masterpiece, best quality, highres", + "negative_prompt": "worst quality, bad quality, low quality, lowres, bad anatomy, bad hands, jpeg artifacts, watermark", + "width": 1024, + "height": 1024, + "cfg": 6.0, + }, + "sdxl": { + "quality_prefix": "", + "negative_prompt": "ugly, deformed, bad anatomy, bad hands, extra fingers, missing fingers, blurry, watermark", + "width": 1024, + "height": 1024, + "cfg": 7.0, + }, + "sd15": { + "quality_prefix": "masterpiece, best quality", + "negative_prompt": ( + "(worst quality:2), (low quality:2), bad anatomy, bad hands, extra fingers, " + "missing fingers, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, watermark" + ), + "width": 512, + "height": 512, + "cfg": 7.0, + }, + "flux": { + "quality_prefix": "", + "negative_prompt": "", # Flux doesn't use negative prompts effectively + "width": 1024, + "height": 1024, + "cfg": 3.5, + }, +} + + +def detect_model_family(model_name: str, base_model: str | None = None) -> str | None: # noqa: PLR0911 + """Detect model family from filename or CivitAI base_model field. 
+ + Args: + model_name: Filename of the model (e.g., "ponyDiffusionV6XL.safetensors") + base_model: Optional CivitAI base_model field (e.g., "Pony", "SDXL 1.0") + + Returns: + Model family key (pony, illustrious, sdxl, sd15, flux) or None if unknown + """ + name_lower = model_name.lower() + base_lower = (base_model or "").lower() + + # Check base_model field first (most reliable from CivitAI) + if base_lower: + if "pony" in base_lower: + return "pony" + if "illustrious" in base_lower: + return "illustrious" + if "flux" in base_lower: + return "flux" + if "sd 1.5" in base_lower or "sd 1.4" in base_lower: + return "sd15" + if "sdxl" in base_lower: + return "sdxl" + + # Fall back to filename heuristics + if "pony" in name_lower: + return "pony" + if "illustrious" in name_lower or "noob" in name_lower: + return "illustrious" + if "flux" in name_lower: + return "flux" + if any(x in name_lower for x in ["sd15", "sd1.5", "sd_1.5", "dreamshaper", "realistic", "deliberate", "anything"]): + return "sd15" + if any(x in name_lower for x in ["sdxl", "xl_"]): + return "sdxl" + + return None + def get_comfyui_url() -> str: """Get the ComfyUI server URL. diff --git a/tensors/db.py b/tensors/db.py index d2e5e56..3ac4cce 100644 --- a/tensors/db.py +++ b/tensors/db.py @@ -755,6 +755,54 @@ class Database: ).all() return [w.word for w in words] + def get_trigger_words_by_filename(self, filename: str) -> list[str]: + """Get trigger words for a LoRA by matching filename in version_files. + + Args: + filename: The filename to search for (e.g., "spumcostyle.safetensors") + + Returns: + List of trigger/trained words from CivitAI metadata + """ + with self.session() as session: + # Find version file by filename match + vf = session.exec(select(VersionFile).where(VersionFile.name == filename)).first() + if not vf: + # Try partial match (without extension) + base_name = filename.rsplit(".", 1)[0] if "." 
def _find_version_file_by_filename(self, session, filename: str) -> "VersionFile | None":
    """Look up a VersionFile row by filename, exact match first, then by stem.

    Args:
        session: Open database session used to run the queries
        filename: The filename to search for (e.g., "spumcostyle.safetensors")

    Returns:
        The first matching VersionFile, or None when nothing matches
    """
    exact = session.exec(select(VersionFile).where(VersionFile.name == filename)).first()
    if exact is not None:
        return exact

    # Fall back to a substring match on the name without its extension.
    # rsplit leaves the string untouched when it contains no ".".
    base_name = filename.rsplit(".", 1)[0]
    if not base_name:
        # An empty pattern would match every row and return an arbitrary file.
        return None

    # autoescape=True stops "%" and "_" in the filename acting as LIKE wildcards.
    return session.exec(
        select(VersionFile).where(col(VersionFile.name).contains(base_name, autoescape=True))
    ).first()


def get_trigger_words_by_filename(self, filename: str) -> list[str]:
    """Get trigger words for a LoRA by matching filename in version_files.

    Args:
        filename: The filename to search for (e.g., "spumcostyle.safetensors")

    Returns:
        List of trigger/trained words ordered by position, or an empty list
        when no matching file (or no version for it) is found
    """
    with self.session() as session:
        vf = self._find_version_file_by_filename(session, filename)
        if vf is None or not vf.version_id:
            return []

        words = session.exec(
            select(TrainedWord).where(TrainedWord.version_id == vf.version_id).order_by(col(TrainedWord.position))
        ).all()
        return [w.word for w in words]


def get_base_model_by_filename(self, filename: str) -> str | None:
    """Get base_model for a checkpoint/LoRA by filename lookup.

    Args:
        filename: The filename to search for

    Returns:
        Base model string (e.g., "Pony", "SDXL 1.0") or None when the file
        or its model version cannot be found
    """
    with self.session() as session:
        vf = self._find_version_file_by_filename(session, filename)
        if vf is None or not vf.version_id:
            return None

        mv = session.get(ModelVersion, vf.version_id)
        return mv.base_model if mv else None