fix(generate): dispatch hybrid Flux models to Flux workflow
Models like gonzalomoXLFluxPony are architecturally Flux but CivitAI tags them as 'Pony', causing the SDXL workflow to be sent to ComfyUI which fails validation. The filename now overrides base_model when it contains 'flux'. Also adds: - Full Flux Dev/Schnell workflow template (ModelSamplingFlux, FluxGuidance, ConditioningZeroOut, EmptySD3LatentImage); KSampler cfg locked to 1.0, caller cfg routed to FluxGuidance - --family/-F flag to manually override family detection - queue_prompt now surfaces ComfyUI node_errors from 400 responses - Tests for Flux workflow builder (8 cases) and updated family defaults
This commit is contained in:
+19
-1
@@ -772,6 +772,18 @@ def generate( # noqa: PLR0915
|
||||
] = None,
|
||||
no_quality: Annotated[bool, typer.Option("--no-quality", help="Disable auto quality tags")] = False,
|
||||
no_negative: Annotated[bool, typer.Option("--no-negative", help="Disable auto negative prompt")] = False,
|
||||
family: Annotated[
|
||||
str | None,
|
||||
typer.Option(
|
||||
"--family",
|
||||
"-F",
|
||||
help=(
|
||||
"Override detected model family "
|
||||
"(pony, illustrious, sdxl, sdxl_lightning, sdxl_turbo, "
|
||||
"sd15, sd15_lcm, flux, flux_schnell, zimage)"
|
||||
),
|
||||
),
|
||||
] = None,
|
||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Save path (default: current dir)")] = None,
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
@@ -893,10 +905,16 @@ def generate( # noqa: PLR0915
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
model_family = detect_model_family(model, base_model_str)
|
||||
detected_family = detect_model_family(model, base_model_str)
|
||||
model_family = family or detected_family
|
||||
if model_family:
|
||||
family_defaults = MODEL_FAMILY_DEFAULTS.get(model_family, {})
|
||||
if not json_output:
|
||||
if family and detected_family and family != detected_family:
|
||||
console.print(f"[dim]Model family: {model_family} (override; detected: {detected_family})[/dim]")
|
||||
elif family:
|
||||
console.print(f"[dim]Model family: {model_family} (override)[/dim]")
|
||||
else:
|
||||
console.print(f"[dim]Detected model family: {model_family}[/dim]")
|
||||
|
||||
# Build enhanced prompt with quality prefix and LoRA trigger words
|
||||
|
||||
+213
-4
@@ -375,6 +375,9 @@ def queue_prompt(
|
||||
error_detail = e.response.json()
|
||||
if "error" in error_detail:
|
||||
console.print(f" [yellow]{error_detail['error']}[/yellow]")
|
||||
if "node_errors" in error_detail:
|
||||
for node_id, errors in error_detail["node_errors"].items():
|
||||
console.print(f" [yellow]Node {node_id}:[/yellow] {errors}")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
@@ -661,6 +664,84 @@ LORA_LOADER_NODE: dict[str, Any] = {
|
||||
},
|
||||
}
|
||||
|
||||
# Flux.1 Dev / Schnell workflow template (CheckpointLoaderSimple-based).
|
||||
#
|
||||
# Differs from DEFAULT_WORKFLOW_TEMPLATE in three load-bearing ways:
|
||||
# 1. KSampler.cfg is HARDCODED to 1.0. Flux is guidance-distilled; raising
|
||||
# KSampler.cfg burns the image. Source: https://comfyanonymous.github.io/ComfyUI_examples/flux/
|
||||
# 2. The user-facing "cfg/guidance" dial is wired into the FluxGuidance node
|
||||
# (default 3.5), which feeds the model's distilled guidance embedding.
|
||||
# 3. Negative prompt is routed through ConditioningZeroOut — Flux ignores
|
||||
# classifier-free guidance, so negatives must be zero conditioning.
|
||||
# 4. ModelSamplingFlux applies the resolution-dependent shift schedule
|
||||
# (defaults max_shift=1.15, base_shift=0.5) which sharpens output at
|
||||
# non-1024² aspect ratios.
|
||||
# 5. EmptySD3LatentImage replaces EmptyLatentImage (Flux uses SD3-style latents).
|
||||
#
|
||||
# Use the all-in-one fp8 checkpoint (flux1-dev-fp8.safetensors) for the simplest
|
||||
# path; for the split-file release (UNETLoader + DualCLIPLoader + VAELoader),
|
||||
# see examples/flux1-dev/workflow.json.
|
||||
FLUX_WORKFLOW_TEMPLATE: dict[str, Any] = {
|
||||
"100": {
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"inputs": {"ckpt_name": ""},
|
||||
},
|
||||
"120": {
|
||||
"class_type": "ModelSamplingFlux",
|
||||
"inputs": {
|
||||
"model": ["100", 0],
|
||||
"max_shift": 1.15,
|
||||
"base_shift": 0.5,
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
},
|
||||
},
|
||||
"130": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {"text": "", "clip": ["100", 1]},
|
||||
},
|
||||
"131": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {"text": "", "clip": ["100", 1]},
|
||||
},
|
||||
"132": {
|
||||
"class_type": "ConditioningZeroOut",
|
||||
"inputs": {"conditioning": ["131", 0]},
|
||||
},
|
||||
"140": {
|
||||
"class_type": "FluxGuidance",
|
||||
"inputs": {"conditioning": ["130", 0], "guidance": 3.5},
|
||||
},
|
||||
"150": {
|
||||
"class_type": "EmptySD3LatentImage",
|
||||
"inputs": {"width": 1024, "height": 1024, "batch_size": 1},
|
||||
},
|
||||
"160": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"seed": 0,
|
||||
"steps": 20,
|
||||
"cfg": 1.0,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "simple",
|
||||
"denoise": 1.0,
|
||||
"model": ["120", 0],
|
||||
"positive": ["140", 0],
|
||||
"negative": ["132", 0],
|
||||
"latent_image": ["150", 0],
|
||||
},
|
||||
},
|
||||
"170": {
|
||||
"class_type": "VAEDecode",
|
||||
"inputs": {"samples": ["160", 0], "vae": ["100", 2]},
|
||||
},
|
||||
"180": {
|
||||
"class_type": "SaveImage",
|
||||
"inputs": {"filename_prefix": "flux", "images": ["170", 0]},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Default SDXL/Illustrious/Pony compatible workflow template
|
||||
# Uses separate VAE loader for better quality with modern models
|
||||
DEFAULT_WORKFLOW_TEMPLATE: dict[str, Any] = {
|
||||
@@ -710,6 +791,102 @@ DEFAULT_WORKFLOW_TEMPLATE: dict[str, Any] = {
|
||||
}
|
||||
|
||||
|
||||
def _resolve_flux_guidance(
|
||||
guidance: float | None,
|
||||
cfg: float | None,
|
||||
defaults: dict[str, Any],
|
||||
) -> float:
|
||||
"""Resolve the FluxGuidance value with the precedence:
|
||||
|
||||
explicit ``guidance`` > caller's ``cfg`` (re-interpreted as guidance for Flux) >
|
||||
family preset's ``guidance`` > 3.5 (BFL recommendation).
|
||||
"""
|
||||
if guidance is not None:
|
||||
return float(guidance)
|
||||
if cfg is not None:
|
||||
return float(cfg)
|
||||
return float(defaults.get("guidance", 3.5))
|
||||
|
||||
|
||||
def _build_flux_workflow(
|
||||
prompt: str,
|
||||
model: str | None,
|
||||
seed: int,
|
||||
steps: int,
|
||||
sampler: str,
|
||||
scheduler: str,
|
||||
width: int,
|
||||
height: int,
|
||||
batch_size: int,
|
||||
lora_name: str | None,
|
||||
lora_strength: float,
|
||||
vae: str | None,
|
||||
guidance: float,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a Flux Dev/Schnell workflow.
|
||||
|
||||
KSampler.cfg is force-locked to 1.0; the caller's intended CFG/guidance is
|
||||
routed to the FluxGuidance node. ModelSamplingFlux is wired with width/height
|
||||
matching the latent so the noise-shift schedule is correct.
|
||||
"""
|
||||
workflow = copy.deepcopy(FLUX_WORKFLOW_TEMPLATE)
|
||||
|
||||
# Set seed (random if -1)
|
||||
actual_seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
|
||||
|
||||
# Checkpoint
|
||||
if model:
|
||||
workflow["100"]["inputs"]["ckpt_name"] = model
|
||||
|
||||
# ModelSamplingFlux must match the latent dimensions
|
||||
workflow["120"]["inputs"]["width"] = width
|
||||
workflow["120"]["inputs"]["height"] = height
|
||||
|
||||
# Prompts (positive only — negative is zero'd via ConditioningZeroOut)
|
||||
workflow["130"]["inputs"]["text"] = prompt
|
||||
|
||||
# FluxGuidance carries the real prompt-adherence dial
|
||||
workflow["140"]["inputs"]["guidance"] = guidance
|
||||
|
||||
# Latent
|
||||
workflow["150"]["inputs"]["width"] = width
|
||||
workflow["150"]["inputs"]["height"] = height
|
||||
workflow["150"]["inputs"]["batch_size"] = batch_size
|
||||
|
||||
# KSampler — cfg stays 1.0
|
||||
workflow["160"]["inputs"]["seed"] = actual_seed
|
||||
workflow["160"]["inputs"]["steps"] = steps
|
||||
workflow["160"]["inputs"]["sampler_name"] = sampler
|
||||
workflow["160"]["inputs"]["scheduler"] = scheduler
|
||||
|
||||
# Optional external VAE — fall back to checkpoint's built-in if not provided
|
||||
if vae:
|
||||
workflow["171"] = {
|
||||
"class_type": "VAELoader",
|
||||
"inputs": {"vae_name": vae},
|
||||
}
|
||||
workflow["170"]["inputs"]["vae"] = ["171", 0]
|
||||
|
||||
# Optional LoRA injected between checkpoint and ModelSamplingFlux
|
||||
if lora_name:
|
||||
workflow["110"] = {
|
||||
"class_type": "LoraLoader",
|
||||
"inputs": {
|
||||
"model": ["100", 0],
|
||||
"clip": ["100", 1],
|
||||
"lora_name": lora_name,
|
||||
"strength_model": lora_strength,
|
||||
"strength_clip": lora_strength,
|
||||
},
|
||||
}
|
||||
# Reroute downstream consumers from checkpoint to LoRA outputs
|
||||
workflow["120"]["inputs"]["model"] = ["110", 0]
|
||||
workflow["130"]["inputs"]["clip"] = ["110", 1]
|
||||
workflow["131"]["inputs"]["clip"] = ["110", 1]
|
||||
|
||||
return workflow
|
||||
|
||||
|
||||
def _build_workflow(
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
@@ -726,20 +903,28 @@ def _build_workflow(
|
||||
batch_size: int = 1,
|
||||
vae: str | None = None,
|
||||
orientation: str = "square",
|
||||
guidance: float | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a text-to-image workflow from parameters.
|
||||
|
||||
Parameters set to None are auto-resolved from the checkpoint's family preset
|
||||
via config.get_model_generation_defaults(). User-provided values always win.
|
||||
|
||||
For Flux Dev/Schnell models, the workflow dispatches to FLUX_WORKFLOW_TEMPLATE
|
||||
which wires FluxGuidance + ConditioningZeroOut + ModelSamplingFlux around a
|
||||
KSampler locked to cfg=1.0 (Flux is guidance-distilled, real prompt-adherence
|
||||
lives on FluxGuidance). The ``guidance`` param maps to FluxGuidance; if not
|
||||
provided, falls back to ``cfg`` (treated as guidance for Flux), then preset.
|
||||
|
||||
Args:
|
||||
prompt: Positive prompt text
|
||||
negative_prompt: Negative prompt text
|
||||
negative_prompt: Negative prompt text (zeroed-out for Flux)
|
||||
model: Checkpoint filename (if None, uses first available)
|
||||
width: Image width (None = use preset for orientation)
|
||||
height: Image height (None = use preset for orientation)
|
||||
steps: Number of sampling steps (None = use preset)
|
||||
cfg: CFG scale (None = use preset)
|
||||
cfg: CFG scale (None = use preset). For Flux models, this is interpreted
|
||||
as the FluxGuidance value if ``guidance`` is not explicitly set.
|
||||
seed: Random seed (-1 for random)
|
||||
sampler: Sampler name (None = use preset)
|
||||
scheduler: Scheduler name (None = use preset)
|
||||
@@ -748,6 +933,7 @@ def _build_workflow(
|
||||
batch_size: Number of images to generate in one workflow (default 1)
|
||||
vae: VAE filename (None = use preset)
|
||||
orientation: Resolution orientation: "square", "portrait", or "landscape"
|
||||
guidance: FluxGuidance value (Flux only; default = preset 3.5)
|
||||
|
||||
Returns:
|
||||
ComfyUI workflow dict
|
||||
@@ -756,9 +942,10 @@ def _build_workflow(
|
||||
|
||||
# Get preset defaults for this checkpoint family
|
||||
defaults = get_model_generation_defaults(model or "") if model else get_model_generation_defaults("")
|
||||
family = defaults.get("family")
|
||||
|
||||
# Resolve orientation-based resolution
|
||||
res_w, res_h = resolve_orientation(defaults.get("family"), orientation)
|
||||
res_w, res_h = resolve_orientation(family, orientation)
|
||||
|
||||
# Merge: user overrides > preset defaults
|
||||
resolved_sampler = sampler if sampler is not None else defaults.get("sampler", "euler")
|
||||
@@ -769,6 +956,24 @@ def _build_workflow(
|
||||
resolved_height = height if height is not None else res_h
|
||||
resolved_vae = vae if vae is not None else defaults.get("vae")
|
||||
|
||||
# Dispatch to Flux-specific template when the family is flux/flux_schnell.
|
||||
if family in ("flux", "flux_schnell"):
|
||||
return _build_flux_workflow(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
seed=seed,
|
||||
steps=resolved_steps,
|
||||
sampler=resolved_sampler,
|
||||
scheduler=resolved_scheduler,
|
||||
width=resolved_width,
|
||||
height=resolved_height,
|
||||
batch_size=batch_size,
|
||||
lora_name=lora_name,
|
||||
lora_strength=lora_strength,
|
||||
vae=resolved_vae,
|
||||
guidance=_resolve_flux_guidance(guidance, cfg, defaults),
|
||||
)
|
||||
|
||||
workflow = copy.deepcopy(DEFAULT_WORKFLOW_TEMPLATE)
|
||||
|
||||
# Set seed (random if -1)
|
||||
@@ -846,11 +1051,14 @@ def generate_image(
|
||||
batch_size: int = 1,
|
||||
vae: str | None = None,
|
||||
orientation: str = "square",
|
||||
guidance: float | None = None,
|
||||
) -> GenerationResult | None:
|
||||
"""Generate an image using a simple text-to-image workflow.
|
||||
|
||||
Parameters set to None are auto-resolved from the checkpoint's family preset.
|
||||
User-provided values always override preset defaults.
|
||||
User-provided values always override preset defaults. For Flux Dev/Schnell
|
||||
checkpoints, ``guidance`` controls the FluxGuidance node (defaults to 3.5);
|
||||
KSampler cfg is locked to 1.0 by the Flux template.
|
||||
|
||||
Args:
|
||||
prompt: Positive prompt text
|
||||
@@ -907,6 +1115,7 @@ def generate_image(
|
||||
batch_size=batch_size,
|
||||
vae=vae,
|
||||
orientation=orientation,
|
||||
guidance=guidance,
|
||||
)
|
||||
|
||||
# Run workflow
|
||||
|
||||
+22
-6
@@ -645,7 +645,11 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"height": 1024,
|
||||
"portrait": (832, 1216),
|
||||
"landscape": (1216, 832),
|
||||
"cfg": 3.5,
|
||||
# Flux Dev is guidance-distilled: KSampler.cfg MUST be 1.0.
|
||||
# Real prompt-adherence dial lives on the FluxGuidance node (see "guidance" below).
|
||||
# Source: https://comfyanonymous.github.io/ComfyUI_examples/flux/
|
||||
"cfg": 1.0,
|
||||
"guidance": 3.5,
|
||||
"sampler": "euler",
|
||||
"scheduler": "simple",
|
||||
"steps": 20,
|
||||
@@ -658,7 +662,10 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"height": 1024,
|
||||
"portrait": (832, 1216),
|
||||
"landscape": (1216, 832),
|
||||
# Schnell is also distilled; FluxGuidance is typically left at 3.5 but
|
||||
# has minimal effect since the model is trained for 4 steps regardless.
|
||||
"cfg": 1.0,
|
||||
"guidance": 3.5,
|
||||
"sampler": "euler",
|
||||
"scheduler": "simple",
|
||||
"steps": 4,
|
||||
@@ -694,6 +701,14 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
|
||||
name_lower = model_name.lower()
|
||||
base_lower = (base_model or "").lower()
|
||||
|
||||
# Architecture override: filename containing "flux" wins over any base_model
|
||||
# field (handles hybrid models like "FluxPony" that CivitAI tags as "Pony"
|
||||
# but are architecturally Flux and need the Flux workflow).
|
||||
if "flux" in name_lower:
|
||||
if "schnell" in name_lower:
|
||||
return "flux_schnell"
|
||||
return "flux"
|
||||
|
||||
# Check base_model field first (most reliable from CivitAI)
|
||||
if base_lower:
|
||||
if "pony" in base_lower:
|
||||
@@ -722,15 +737,16 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
|
||||
return "sdxl"
|
||||
|
||||
# Fall back to filename heuristics (check specific variants first)
|
||||
if "pony" in name_lower:
|
||||
return "pony"
|
||||
if "illustrious" in name_lower or "noob" in name_lower:
|
||||
return "illustrious"
|
||||
# Flux variants
|
||||
# Flux variants take precedence — architecture wins over training dataset
|
||||
# (e.g. "FluxPony" hybrids are Flux models trained on Pony data, not SDXL/Pony)
|
||||
if "flux" in name_lower and "schnell" in name_lower:
|
||||
return "flux_schnell"
|
||||
if "flux" in name_lower:
|
||||
return "flux"
|
||||
if "pony" in name_lower:
|
||||
return "pony"
|
||||
if "illustrious" in name_lower or "noob" in name_lower:
|
||||
return "illustrious"
|
||||
# ZImageTurbo
|
||||
if "zimage" in name_lower:
|
||||
return "zimage"
|
||||
|
||||
+107
-2
@@ -359,14 +359,20 @@ class TestModelFamilyDetection:
|
||||
assert defaults["cfg"] == 6.5
|
||||
|
||||
def test_get_model_generation_defaults_flux(self) -> None:
|
||||
"""Test getting generation defaults for Flux models."""
|
||||
"""Test getting generation defaults for Flux models.
|
||||
|
||||
Flux Dev is guidance-distilled: KSampler.cfg MUST be 1.0; the real
|
||||
prompt-adherence dial is the FluxGuidance node's ``guidance`` value.
|
||||
See https://comfyanonymous.github.io/ComfyUI_examples/flux/
|
||||
"""
|
||||
from tensors.config import get_model_generation_defaults
|
||||
|
||||
defaults = get_model_generation_defaults("flux1-dev-fp8.safetensors")
|
||||
assert defaults["family"] == "flux"
|
||||
assert defaults["sampler"] == "euler"
|
||||
assert defaults["scheduler"] == "simple"
|
||||
assert defaults["cfg"] == 3.5
|
||||
assert defaults["cfg"] == 1.0
|
||||
assert defaults["guidance"] == 3.5
|
||||
|
||||
def test_get_model_generation_defaults_flux_schnell(self) -> None:
|
||||
"""Test getting generation defaults for Flux Schnell models."""
|
||||
@@ -376,6 +382,7 @@ class TestModelFamilyDetection:
|
||||
assert defaults["family"] == "flux_schnell"
|
||||
assert defaults["steps"] == 4
|
||||
assert defaults["cfg"] == 1.0
|
||||
assert defaults["guidance"] == 3.5
|
||||
|
||||
def test_detect_zimage(self) -> None:
|
||||
"""Test detecting ZImageTurbo family."""
|
||||
@@ -428,6 +435,104 @@ class TestModelFamilyDetection:
|
||||
assert defaults["scheduler"] == "karras"
|
||||
|
||||
|
||||
class TestFluxWorkflowBuilder:
|
||||
"""Tests for the Flux-specific branch of _build_workflow."""
|
||||
|
||||
def test_flux_dispatch_uses_flux_template(self) -> None:
|
||||
"""Building a workflow for a Flux model emits the Flux node graph."""
|
||||
from tensors.comfyui import _build_workflow
|
||||
|
||||
wf = _build_workflow(prompt="a cat", model="flux1-dev-fp8.safetensors")
|
||||
|
||||
# Flux template uses node IDs in the 100s; default SDXL template uses single digits.
|
||||
assert "100" in wf and wf["100"]["class_type"] == "CheckpointLoaderSimple"
|
||||
assert "120" in wf and wf["120"]["class_type"] == "ModelSamplingFlux"
|
||||
assert "140" in wf and wf["140"]["class_type"] == "FluxGuidance"
|
||||
assert "132" in wf and wf["132"]["class_type"] == "ConditioningZeroOut"
|
||||
assert "150" in wf and wf["150"]["class_type"] == "EmptySD3LatentImage"
|
||||
assert "3" not in wf # default SDXL KSampler ID must NOT be present
|
||||
|
||||
def test_flux_ksampler_cfg_locked_to_one(self) -> None:
|
||||
"""KSampler cfg MUST be 1.0 for Flux Dev — caller cfg must NOT leak through."""
|
||||
from tensors.comfyui import _build_workflow
|
||||
|
||||
wf = _build_workflow(prompt="a cat", model="flux1-dev-fp8.safetensors", cfg=7.5)
|
||||
assert wf["160"]["inputs"]["cfg"] == 1.0
|
||||
# The caller's cfg=7.5 should be re-routed to FluxGuidance
|
||||
assert wf["140"]["inputs"]["guidance"] == 7.5
|
||||
|
||||
def test_flux_explicit_guidance_wins_over_cfg(self) -> None:
|
||||
"""Explicit guidance overrides re-interpreted cfg."""
|
||||
from tensors.comfyui import _build_workflow
|
||||
|
||||
wf = _build_workflow(prompt="a cat", model="flux1-dev-fp8.safetensors", cfg=7.5, guidance=4.0)
|
||||
assert wf["140"]["inputs"]["guidance"] == 4.0
|
||||
|
||||
def test_flux_default_guidance_from_preset(self) -> None:
|
||||
"""No caller value -> preset guidance (3.5) wins."""
|
||||
from tensors.comfyui import _build_workflow
|
||||
|
||||
wf = _build_workflow(prompt="a cat", model="flux1-dev-fp8.safetensors")
|
||||
assert wf["140"]["inputs"]["guidance"] == 3.5
|
||||
|
||||
def test_flux_lora_injection(self) -> None:
|
||||
"""LoRA injects node 110 and reroutes ModelSamplingFlux + CLIPTextEncodes."""
|
||||
from tensors.comfyui import _build_workflow
|
||||
|
||||
wf = _build_workflow(
|
||||
prompt="a cat",
|
||||
model="flux1-dev-fp8.safetensors",
|
||||
lora_name="my_style.safetensors",
|
||||
lora_strength=0.7,
|
||||
)
|
||||
assert "110" in wf and wf["110"]["class_type"] == "LoraLoader"
|
||||
assert wf["110"]["inputs"]["lora_name"] == "my_style.safetensors"
|
||||
assert wf["110"]["inputs"]["strength_model"] == 0.7
|
||||
# Downstream consumers must read from the LoRA node
|
||||
assert wf["120"]["inputs"]["model"] == ["110", 0]
|
||||
assert wf["130"]["inputs"]["clip"] == ["110", 1]
|
||||
assert wf["131"]["inputs"]["clip"] == ["110", 1]
|
||||
|
||||
def test_flux_external_vae_swaps_decoder_input(self) -> None:
|
||||
"""Providing an external VAE adds node 171 (VAELoader) and rewires VAEDecode."""
|
||||
from tensors.comfyui import _build_workflow
|
||||
|
||||
wf = _build_workflow(
|
||||
prompt="a cat",
|
||||
model="flux1-dev-fp8.safetensors",
|
||||
vae="ae.safetensors",
|
||||
)
|
||||
assert "171" in wf and wf["171"]["class_type"] == "VAELoader"
|
||||
assert wf["171"]["inputs"]["vae_name"] == "ae.safetensors"
|
||||
assert wf["170"]["inputs"]["vae"] == ["171", 0]
|
||||
|
||||
def test_flux_model_sampling_dimensions_match_latent(self) -> None:
|
||||
"""ModelSamplingFlux width/height must equal the latent dimensions for correct shift."""
|
||||
from tensors.comfyui import _build_workflow
|
||||
|
||||
wf = _build_workflow(
|
||||
prompt="a cat",
|
||||
model="flux1-dev-fp8.safetensors",
|
||||
width=1216,
|
||||
height=832,
|
||||
)
|
||||
assert wf["120"]["inputs"]["width"] == 1216
|
||||
assert wf["120"]["inputs"]["height"] == 832
|
||||
assert wf["150"]["inputs"]["width"] == 1216
|
||||
assert wf["150"]["inputs"]["height"] == 832
|
||||
|
||||
def test_non_flux_model_uses_default_template(self) -> None:
|
||||
"""SDXL/Pony/etc. checkpoints continue to use the legacy template."""
|
||||
from tensors.comfyui import _build_workflow
|
||||
|
||||
wf = _build_workflow(prompt="a cat", model="ponyDiffusionV6XL.safetensors")
|
||||
# Default SDXL template has KSampler at node "3"
|
||||
assert "3" in wf and wf["3"]["class_type"] == "KSampler"
|
||||
# Flux-specific nodes must NOT be present
|
||||
assert "140" not in wf
|
||||
assert "120" not in wf
|
||||
|
||||
|
||||
class TestDisplayFormatters:
|
||||
"""Tests for display formatting functions."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user