diff --git a/tensors/cli.py b/tensors/cli.py index 03ad951..5b754df 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -788,7 +788,7 @@ def generate( # noqa: PLR0915 help=( "Override detected model family " "(pony, illustrious, sdxl, sdxl_lightning, sdxl_turbo, " - "sd15, sd15_lcm, flux, flux_schnell, flux_unet, zimage)" + "sd15, sd15_lcm, flux, flux_schnell, flux_unet, flux2_klein, zimage)" ), ), ] = None, diff --git a/tensors/comfyui.py b/tensors/comfyui.py index 9d33bb3..c564aaf 100644 --- a/tensors/comfyui.py +++ b/tensors/comfyui.py @@ -830,6 +830,83 @@ FLUX_UNET_WORKFLOW_TEMPLATE: dict[str, Any] = { } +# Flux.2 Klein 9B workflow template — different architecture from Flux.1: +# - Single Qwen3-8B text encoder via CLIPLoader(type=flux2), produces 12288-dim +# conditioning (3 stacked hidden layers) +# - EmptyFlux2LatentImage instead of EmptySD3LatentImage (different latent shape) +# - Custom-sampling pipeline (Flux2Scheduler + BasicGuider + RandomNoise + +# KSamplerSelect + SamplerCustomAdvanced) instead of plain KSampler +# - Dedicated VAE (flux2-vae.safetensors), not the Flux.1 ae.safetensors +# Verified end-to-end against ComfyUI on madcat with lust_v10.safetensors. +FLUX2_KLEIN_WORKFLOW_TEMPLATE: dict[str, Any] = { + "100": { + "class_type": "UNETLoader", + "inputs": {"unet_name": "", "weight_dtype": "default"}, + }, + "101": { + "class_type": "CLIPLoader", + "inputs": { + "clip_name": "qwen_3_8b_fp8mixed.safetensors", + "type": "flux2", + }, + }, + "102": { + "class_type": "VAELoader", + "inputs": {"vae_name": "flux2-vae.safetensors"}, + }, + "130": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "", "clip": ["101", 0]}, + }, + "131": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "", "clip": ["101", 0]}, + }, + "140": { + "class_type": "FluxGuidance", + "inputs": {"conditioning": ["130", 0], "guidance": 3.5}, + }, + "150": { + "class_type": "EmptyFlux2LatentImage", + "inputs": {"width": 1024, "height": 1024, "batch_size": 1}, + }, + "151": { + "class_type": "RandomNoise", + "inputs": {"noise_seed": 0}, + }, + "152": { + "class_type": "KSamplerSelect", + "inputs": {"sampler_name": "euler"}, + }, + "153": { + "class_type": "Flux2Scheduler", + "inputs": {"steps": 20, "width": 1024, "height": 1024}, + }, + "154": { + "class_type": "BasicGuider", + "inputs": {"model": ["100", 0], "conditioning": ["140", 0]}, + }, + "160": { + "class_type": "SamplerCustomAdvanced", + "inputs": { + "noise": ["151", 0], + "guider": ["154", 0], + "sampler": ["152", 0], + "sigmas": ["153", 0], + "latent_image": ["150", 0], + }, + }, + "170": { + "class_type": "VAEDecode", + "inputs": {"samples": ["160", 0], "vae": ["102", 0]}, + }, + "180": { + "class_type": "SaveImage", + "inputs": {"filename_prefix": "flux2", "images": ["170", 0]}, + }, +} + + # Default SDXL/Illustrious/Pony compatible workflow template # Uses separate VAE loader for better quality with modern models DEFAULT_WORKFLOW_TEMPLATE: dict[str, Any] = { @@ -1056,6 +1133,93 @@ def _build_flux_unet_workflow( return workflow +def _build_flux2_klein_workflow( + prompt: str, + model: str | None, + seed: int, + steps: int, + sampler: str, + width: int, + height: int, + batch_size: int, + lora_name: str | None, + lora_strength: float, + vae: str | None, + guidance: float, + clip_encoder: str, + clip_type: str, +) -> dict[str, Any]: + """Build a Flux.2 Klein 9B workflow (single Qwen3 encoder, custom sampling). + + The graph differs from Flux.1 in three ways: + 1. Single-encoder ``CLIPLoader`` (type=flux2) instead of ``DualCLIPLoader``. + 2. ``EmptyFlux2LatentImage`` for the Flux2-specific latent shape. + 3. Custom-sampling pipeline: ``Flux2Scheduler`` produces SIGMAS, fed into + ``SamplerCustomAdvanced`` along with ``BasicGuider``/``RandomNoise``/ + ``KSamplerSelect``. There is no standalone ``KSampler`` node, so the + caller-provided ``scheduler`` is ignored (Flux2Scheduler is the only + supported sigma source). + """ + workflow = copy.deepcopy(FLUX2_KLEIN_WORKFLOW_TEMPLATE) + + actual_seed = seed if seed >= 0 else random.randint(0, 2**32 - 1) + + # UNet checkpoint + if model: + workflow["100"]["inputs"]["unet_name"] = model + + # Text encoder (Qwen3-8B variant — fp8 default, callers can override via + # family preset's clip_encoder field if a different quantization is desired). + workflow["101"]["inputs"]["clip_name"] = clip_encoder + workflow["101"]["inputs"]["type"] = clip_type + + # External VAE — fall back to flux2-vae.safetensors from the template if unset + if vae: + workflow["102"]["inputs"]["vae_name"] = vae + + # Positive prompt (negative is unused for Flux — guidance is distilled) + workflow["130"]["inputs"]["text"] = prompt + + # FluxGuidance carries the real prompt-adherence dial + workflow["140"]["inputs"]["guidance"] = guidance + + # Latent dimensions + workflow["150"]["inputs"]["width"] = width + workflow["150"]["inputs"]["height"] = height + workflow["150"]["inputs"]["batch_size"] = batch_size + + # Noise seed (separate node in custom-sampling pipeline) + workflow["151"]["inputs"]["noise_seed"] = actual_seed + + # Sampler selection + workflow["152"]["inputs"]["sampler_name"] = sampler + + # Flux2Scheduler — must receive matching width/height for correct sigma schedule + workflow["153"]["inputs"]["steps"] = steps + workflow["153"]["inputs"]["width"] = width + workflow["153"]["inputs"]["height"] = height + + # Optional LoRA: injected between UNet/CLIP loaders and BasicGuider/ + # CLIPTextEncode consumers. Mirrors the flux_unet wiring pattern. + if lora_name: + workflow["110"] = { + "class_type": "LoraLoader", + "inputs": { + "model": ["100", 0], + "clip": ["101", 0], + "lora_name": lora_name, + "strength_model": lora_strength, + "strength_clip": lora_strength, + }, + } + # Re-route consumers from raw loaders (100/101) to LoRA outputs (110). + workflow["154"]["inputs"]["model"] = ["110", 0] + workflow["130"]["inputs"]["clip"] = ["110", 1] + workflow["131"]["inputs"]["clip"] = ["110", 1] + + return workflow + + def _build_workflow( prompt: str, negative_prompt: str = "", @@ -1143,6 +1307,27 @@ def _build_workflow( guidance=_resolve_flux_guidance(guidance, cfg, defaults), ) + # Flux.2 Klein 9B: different architecture (single Qwen3 encoder, custom + # sampling pipeline, Flux2 latent format). Must dispatch BEFORE flux_unet + # since Klein checkpoints also set external_clip=True. + if family == "flux2_klein": + return _build_flux2_klein_workflow( + prompt=prompt, + model=model, + seed=seed, + steps=resolved_steps, + sampler=resolved_sampler, + width=resolved_width, + height=resolved_height, + batch_size=batch_size, + lora_name=lora_name, + lora_strength=lora_strength, + vae=resolved_vae, + guidance=_resolve_flux_guidance(guidance, cfg, defaults), + clip_encoder=defaults.get("clip_encoder", "qwen_3_8b_fp8mixed.safetensors"), + clip_type=defaults.get("clip_type", "flux2"), + ) + # UNet-only Flux checkpoints (no baked-in CLIP/T5/VAE): use the split-loader # variant. Triggered by family="flux_unet" — also covers any family whose # preset opts in via external_clip=True. diff --git a/tensors/config.py b/tensors/config.py index f7fb876..21607b2 100644 --- a/tensors/config.py +++ b/tensors/config.py @@ -692,6 +692,29 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "clip_l": "clip_l.safetensors", "clip_t5": "t5xxl_fp16.safetensors", }, + # Flux.2 Klein 9B — newer Black Forest Labs release. Different architecture + # from Flux.1: single Qwen3-8B text encoder (12288-dim conditioning, 3 stacked + # hidden layers), Flux2 latent format, custom Flux2Scheduler, dedicated VAE + # (flux2-vae.safetensors). Workflow uses CLIPLoader (type=flux2) instead of + # DualCLIPLoader, and the custom-sampling pipeline (SamplerCustomAdvanced + + # BasicGuider + Flux2Scheduler) instead of plain KSampler. + "flux2_klein": { + "quality_prefix": "", + "negative_prompt": "", + "width": 1024, + "height": 1024, + "portrait": (832, 1216), + "landscape": (1216, 832), + "cfg": 1.0, + "guidance": 3.5, + "sampler": "euler", + "scheduler": "simple", # unused — Flux2Scheduler provides sigmas + "steps": 20, + "vae": "flux2-vae.safetensors", + "external_clip": True, + "clip_encoder": "qwen_3_8b_fp8mixed.safetensors", + "clip_type": "flux2", + }, "zimage": { "quality_prefix": "", "negative_prompt": "", @@ -730,6 +753,28 @@ def _is_flux_unet_only(name_lower: str) -> bool: return any(p in name_lower for p in FLUX_UNET_ONLY_PATTERNS) +# Flux.2 Klein 9B filename substrings (case-insensitive). These checkpoints are +# UNet-only AND require the Flux.2 architecture (Qwen3-8B encoder, Flux2 +# scheduler). Detection is primarily via base_model field ("Flux.2 Klein"); the +# filename patterns are a fallback for checkpoints with missing/wrong DB +# metadata. Filename match wins over base_model. +FLUX2_KLEIN_PATTERNS: tuple[str, ...] = ( + "lust_", # lust_v10.safetensors + "moodydesire", # moodyDesireMix_v20PRO.safetensors +) + + +def _is_flux2_klein(name_lower: str, base_lower: str) -> bool: + """True if the model is a Flux.2 Klein 9B checkpoint. + + Detects via base_model field ("flux.2 klein", "flux2 klein") first, + then filename pattern fallback. + """ + if "flux.2 klein" in base_lower or "flux2 klein" in base_lower: + return True + return any(p in name_lower for p in FLUX2_KLEIN_PATTERNS) + + def detect_model_family(model_name: str, base_model: str | None = None) -> str | None: # noqa: PLR0911 """Detect model family from filename or CivitAI base_model field. @@ -739,11 +784,19 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str | Returns: Model family key (pony, illustrious, sdxl, sdxl_lightning, sdxl_turbo, - sd15, sd15_lcm, flux, flux_schnell, flux_unet, zimage) or None if unknown + sd15, sd15_lcm, flux, flux_schnell, flux_unet, flux2_klein, zimage) + or None if unknown """ name_lower = model_name.lower() base_lower = (base_model or "").lower() + # Flux.2 Klein 9B override: must run BEFORE flux_unet (Klein patterns like + # "lust_" and "moodydesire" also appear in FLUX_UNET_ONLY_PATTERNS) AND + # before the generic flux check. Detection prefers base_model field but + # falls back to filename pattern for checkpoints with missing metadata. + if _is_flux2_klein(name_lower, base_lower): + return "flux2_klein" + # UNet-only Flux override: must run BEFORE the generic flux check below, # since some patterns ("cyberrealisticflux", "getphatflux", "fcfluxpony") # also contain the substring "flux". Filename wins over base_model diff --git a/tests/test_tensors.py b/tests/test_tensors.py index 3042970..876d0d6 100644 --- a/tests/test_tensors.py +++ b/tests/test_tensors.py @@ -321,11 +321,16 @@ class TestModelFamilyDetection: assert detect_model_family("model.safetensors", "Flux.1 S schnell") == "flux_schnell" def test_detect_flux_unet_lust(self) -> None: - """lust_*.safetensors → flux_unet (no 'flux' in name, custom pattern).""" + """lust_*.safetensors → flux2_klein (Klein detection wins via filename pattern). + + Originally classified as flux_unet, but lust_v10 is actually Flux.2 Klein 9B + (per CivitAI base_model). Klein detection runs before flux_unet, so the + lust_ pattern in FLUX2_KLEIN_PATTERNS takes precedence. + """ from tensors.config import detect_model_family - assert detect_model_family("lust_v10.safetensors") == "flux_unet" - assert detect_model_family("LUST_v10.safetensors") == "flux_unet" + assert detect_model_family("lust_v10.safetensors") == "flux2_klein" + assert detect_model_family("LUST_v10.safetensors") == "flux2_klein" def test_detect_flux_unet_cyberrealistic(self) -> None: """cyberrealisticFlux_*.safetensors → flux_unet (intercepts generic 'flux' match).""" @@ -340,10 +345,14 @@ class TestModelFamilyDetection: assert detect_model_family("getphatFLUXReality_v11Softcore.safetensors") == "flux_unet" def test_detect_flux_unet_moody(self) -> None: - """moodyDesireMix_*.safetensors → flux_unet (no 'flux' in name).""" + """moodyDesireMix_*.safetensors → flux2_klein (Klein, not Flux.1 D). + + Originally classified as flux_unet, but moodyDesireMix is Flux.2 Klein + 9B per CivitAI. Klein detection wins via the moodydesire filename pattern. + """ from tensors.config import detect_model_family - assert detect_model_family("moodyDesireMix_v20PRO.safetensors") == "flux_unet" + assert detect_model_family("moodyDesireMix_v20PRO.safetensors") == "flux2_klein" def test_detect_flux_unet_fcfluxpony(self) -> None: """fcFluxPony*.safetensors → flux_unet (intercepts flux + fluxpony).""" @@ -358,11 +367,15 @@ class TestModelFamilyDetection: """Filename UNet-only pattern wins over a (likely wrong) CivitAI base_model tag.""" from tensors.config import detect_model_family - # Even if CivitAI claims "SDXL 1.0", the filename pattern wins. - assert detect_model_family("lust_v10.safetensors", "SDXL 1.0") == "flux_unet" + # cyberrealisticFlux: filename pattern wins over wrong "Pony" tag → flux_unet. assert ( detect_model_family("cyberrealisticFlux_v25.safetensors", "Pony") == "flux_unet" ) + # getphat: filename pattern wins over wrong "SDXL 1.0" tag → flux_unet. + assert ( + detect_model_family("getphatFLUXReality_v11.safetensors", "SDXL 1.0") + == "flux_unet" + ) def test_flux_unet_family_defaults_has_external_clip(self) -> None: """flux_unet preset advertises external_clip + clip filenames.""" @@ -381,12 +394,106 @@ class TestModelFamilyDetection: """flux_unet model resolves to the flux_unet preset with external_clip set.""" from tensors.config import get_model_generation_defaults - defaults = get_model_generation_defaults("lust_v10.safetensors") + # getphat is genuinely Flux.1 D UNet-only (not Klein). + defaults = get_model_generation_defaults("getphatFLUXReality_v11.safetensors") assert defaults["family"] == "flux_unet" assert defaults["external_clip"] is True assert defaults["sampler"] == "euler" assert defaults["scheduler"] == "simple" + # ---- Flux.2 Klein 9B detection + workflow ---- + + def test_detect_flux2_klein_from_base_model(self) -> None: + """base_model='Flux.2 Klein 9B-base' → flux2_klein.""" + from tensors.config import detect_model_family + + assert detect_model_family("anything.safetensors", "Flux.2 Klein 9B-base") == "flux2_klein" + assert detect_model_family("anything.safetensors", "Flux.2 Klein 9B") == "flux2_klein" + # Compact variant ("flux2 klein" without the dot) — also accepted. + assert detect_model_family("anything.safetensors", "flux2 Klein") == "flux2_klein" + + def test_detect_flux2_klein_from_filename(self) -> None: + """Filename fallback: lust_ and moodydesire → flux2_klein even without DB metadata.""" + from tensors.config import detect_model_family + + assert detect_model_family("lust_v10.safetensors") == "flux2_klein" + assert detect_model_family("moodyDesireMix_v20PRO.safetensors") == "flux2_klein" + + def test_detect_flux2_klein_overrides_flux_unet(self) -> None: + """Klein detection runs BEFORE flux_unet, so Klein patterns win.""" + from tensors.config import detect_model_family + + # lust_ matches both FLUX2_KLEIN_PATTERNS and FLUX_UNET_ONLY_PATTERNS. + # Klein check runs first → flux2_klein. + assert detect_model_family("lust_v10.safetensors") == "flux2_klein" + # Even with wrong base_model, Klein filename wins. + assert detect_model_family("lust_v10.safetensors", "SDXL 1.0") == "flux2_klein" + + def test_flux2_klein_family_defaults(self) -> None: + """flux2_klein preset has external_clip + Qwen3 encoder + Flux.2 VAE.""" + from tensors.config import MODEL_FAMILY_DEFAULTS + + defaults = MODEL_FAMILY_DEFAULTS["flux2_klein"] + assert defaults["external_clip"] is True + assert defaults["clip_encoder"] == "qwen_3_8b_fp8mixed.safetensors" + assert defaults["clip_type"] == "flux2" + assert defaults["vae"] == "flux2-vae.safetensors" + assert defaults["cfg"] == 1.0 + assert defaults["guidance"] == 3.5 + + def test_build_workflow_flux2_klein_uses_cliploader(self) -> None: + """Flux.2 Klein workflow uses CLIPLoader(type=flux2) + EmptyFlux2LatentImage + + custom-sampling pipeline (no plain KSampler, no DualCLIPLoader).""" + from tensors.comfyui import _build_workflow + + wf = _build_workflow(prompt="test", model="lust_v10.safetensors", seed=42) + + class_types = {node["class_type"] for node in wf.values()} + # Required Flux.2-specific nodes + assert "CLIPLoader" in class_types + assert "EmptyFlux2LatentImage" in class_types + assert "Flux2Scheduler" in class_types + assert "BasicGuider" in class_types + assert "SamplerCustomAdvanced" in class_types + assert "RandomNoise" in class_types + # Forbidden — these belong to Flux.1 / SDXL paths + assert "DualCLIPLoader" not in class_types + assert "KSampler" not in class_types + assert "EmptySD3LatentImage" not in class_types + assert "CheckpointLoaderSimple" not in class_types + assert "ModelSamplingFlux" not in class_types + # Verify CLIPLoader is configured correctly + clip_nodes = [n for n in wf.values() if n["class_type"] == "CLIPLoader"] + assert len(clip_nodes) == 1 + assert clip_nodes[0]["inputs"]["type"] == "flux2" + assert clip_nodes[0]["inputs"]["clip_name"] == "qwen_3_8b_fp8mixed.safetensors" + # VAE is the Flux.2 one, not Flux.1's ae.safetensors + vae_nodes = [n for n in wf.values() if n["class_type"] == "VAELoader"] + assert vae_nodes[0]["inputs"]["vae_name"] == "flux2-vae.safetensors" + + def test_build_workflow_flux2_klein_with_lora(self) -> None: + """LoRA injection inserts LoraLoader and reroutes BasicGuider + text encoders.""" + from tensors.comfyui import _build_workflow + + wf = _build_workflow( + prompt="test", + model="lust_v10.safetensors", + seed=42, + lora_name="some_flux_lora.safetensors", + lora_strength=0.8, + ) + + # LoRA node added at "110" + assert "110" in wf + assert wf["110"]["class_type"] == "LoraLoader" + assert wf["110"]["inputs"]["lora_name"] == "some_flux_lora.safetensors" + assert wf["110"]["inputs"]["strength_model"] == 0.8 + # BasicGuider (model consumer) now wired to LoRA output + assert wf["154"]["inputs"]["model"] == ["110", 0] + # Both text encoders re-routed to LoRA clip output + assert wf["130"]["inputs"]["clip"] == ["110", 1] + assert wf["131"]["inputs"]["clip"] == ["110", 1] + def test_detect_sdxl_variants(self) -> None: """Test detecting SDXL family variants.""" from tensors.config import detect_model_family @@ -607,7 +714,9 @@ class TestFluxUnetWorkflowBuilder: """flux_unet checkpoints emit UNETLoader + DualCLIPLoader + VAELoader and NO CheckpointLoaderSimple.""" from tensors.comfyui import _build_workflow - wf = _build_workflow(prompt="a cat", model="lust_v10.safetensors") + # getphat is genuinely Flux.1 D UNet-only. lust_v10 used to live here + # but is actually Flux.2 Klein — see TestFamilyDetection. + wf = _build_workflow(prompt="a cat", model="getphatFLUXReality_v11.safetensors") # Three split loaders at the canonical IDs assert wf["100"]["class_type"] == "UNETLoader" @@ -619,7 +728,7 @@ class TestFluxUnetWorkflowBuilder: assert node["class_type"] != "CheckpointLoaderSimple" # UNet filename plumbed through - assert wf["100"]["inputs"]["unet_name"] == "lust_v10.safetensors" + assert wf["100"]["inputs"]["unet_name"] == "getphatFLUXReality_v11.safetensors" # DualCLIPLoader configured for flux with both encoders clip_inputs = wf["101"]["inputs"] @@ -642,7 +751,7 @@ class TestFluxUnetWorkflowBuilder: from tensors.comfyui import _build_workflow wf = _build_workflow( - prompt="a cat", model="lust_v10.safetensors", cfg=7.5 + prompt="a cat", model="getphatFLUXReality_v11.safetensors", cfg=7.5 ) assert wf["160"]["inputs"]["cfg"] == 1.0 # The caller's cfg=7.5 should re-route to FluxGuidance (same precedence as plain flux) @@ -654,7 +763,7 @@ class TestFluxUnetWorkflowBuilder: wf = _build_workflow( prompt="a cat", - model="lust_v10.safetensors", + model="getphatFLUXReality_v11.safetensors", lora_name="my_style.safetensors", lora_strength=0.6, ) @@ -675,7 +784,7 @@ class TestFluxUnetWorkflowBuilder: wf = _build_workflow( prompt="a cat", - model="lust_v10.safetensors", + model="getphatFLUXReality_v11.safetensors", vae="other_vae.safetensors", ) assert wf["102"]["inputs"]["vae_name"] == "other_vae.safetensors" @@ -692,8 +801,11 @@ class TestFluxUnetWorkflowBuilder: """ from tensors.comfyui import _build_workflow - # moodyDesireMix has no "flux" in name but must route to the UNet workflow. - wf = _build_workflow(prompt="a cat", model="moodyDesireMix_v20PRO.safetensors") + # fcFluxPony is genuinely Flux.1 D (UNet-only) — moodyDesire was Klein. + wf = _build_workflow( + prompt="a cat", + model="fcFluxPonyPerfectBase_fcFluxPerfectBase.safetensors", + ) assert wf["100"]["class_type"] == "UNETLoader" assert wf["101"]["class_type"] == "DualCLIPLoader"