style: clear all 23 pre-existing ruff lint errors
Master had been failing CI lint since before this PR. Knock out the backlog so the parallel-queue PR can ship green and future PRs don't inherit the red baseline. Changes by category: - UP042 (9): Migrate `class Foo(str, Enum)` to `class Foo(StrEnum)` in tensors/config.py (7 enums) and tensors/server/search_routes.py (2). Requires Python 3.11+, already enforced by `requires-python = ">=3.12"`. - PLR2004 (3): Extract magic comparison values in comfyui_api_routes.py to module-level constants (_DEFAULT_STEPS, _DEFAULT_CFG, _PROMPT_LOG_TRUNCATE). - PLW0108 (2): Inline `lambda: StubDB()` -> `StubDB` in test_server.py. - PLR0915 (3): Add explicit `# noqa: PLR0915` to typer command bodies that are intentionally long (template, templates_extract, _wait_for_completion_ws). - PLR1714 (1): `file_path.name == model or file_path.stem == model` -> `model in (file_path.name, file_path.stem)` in cli.py:3047. - SIM113 (1): Use `enumerate(as_completed(futures), start=1)` for the completion counter in style-sweep parallel loop. - RUF059 (1): Prefix unused tuple-unpacked vars with `_` in _run_one. - SIM105 (1): `try: ws.close() except Exception: pass` -> `contextlib.suppress(Exception)` in comfyui.py. - PLC0415 (1): Add missing `# noqa: PLC0415` to the second of two function-scoped tensors.config imports. No behavior changes. All 374 tests still pass.
This commit is contained in:
+7
-9
@@ -1485,7 +1485,7 @@ def _run_generation( # noqa: PLR0915
|
|||||||
|
|
||||||
# ---- Resolve preset defaults for None params (both remote and local need these) ----
|
# ---- Resolve preset defaults for None params (both remote and local need these) ----
|
||||||
from tensors.config import resolve_orientation # noqa: PLC0415
|
from tensors.config import resolve_orientation # noqa: PLC0415
|
||||||
from tensors.config import resolve_remote as do_resolve_remote
|
from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415
|
||||||
|
|
||||||
# Use already-detected family_defaults from DB lookup above (not filename guessing)
|
# Use already-detected family_defaults from DB lookup above (not filename guessing)
|
||||||
if family_defaults:
|
if family_defaults:
|
||||||
@@ -2087,7 +2087,7 @@ def style_sweep( # noqa: PLR0915
|
|||||||
|
|
||||||
def _run_one(task: tuple[int, dict[str, str], dict[str, Any], Path]) -> dict[str, Any]:
|
def _run_one(task: tuple[int, dict[str, str], dict[str, Any], Path]) -> dict[str, Any]:
|
||||||
"""Run a single style. Returns the result dict (success or error captured)."""
|
"""Run a single style. Returns the result dict (success or error captured)."""
|
||||||
idx, entry_in, res, opath = task
|
_idx, _entry_in, res, opath = task
|
||||||
composed = res["prompt"]
|
composed = res["prompt"]
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
@@ -2139,11 +2139,9 @@ def style_sweep( # noqa: PLR0915
|
|||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=parallel_queue) as pool:
|
with ThreadPoolExecutor(max_workers=parallel_queue) as pool:
|
||||||
futures = {pool.submit(_run_one, task): task for task in pending_tasks}
|
futures = {pool.submit(_run_one, task): task for task in pending_tasks}
|
||||||
completed = 0
|
for completed, fut in enumerate(as_completed(futures), start=1):
|
||||||
for fut in as_completed(futures):
|
|
||||||
completed += 1
|
|
||||||
task = futures[fut]
|
task = futures[fut]
|
||||||
idx, _entry, _res, _out_path = task
|
idx, _entry, _res, _out_path = task # idx used in log message below
|
||||||
try:
|
try:
|
||||||
res = fut.result()
|
res = fut.result()
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
@@ -2226,7 +2224,7 @@ def _print_styles_list(styles_origin: str, entries: list[dict[str, str]]) -> Non
|
|||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def template(
|
def template( # noqa: PLR0915
|
||||||
model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")],
|
model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")],
|
||||||
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
||||||
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
|
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
|
||||||
@@ -3044,7 +3042,7 @@ def scene_extract(
|
|||||||
target_file = None
|
target_file = None
|
||||||
for f in files:
|
for f in files:
|
||||||
file_path = Path(f["file_path"])
|
file_path = Path(f["file_path"])
|
||||||
if file_path.name == model or file_path.stem == model:
|
if model in (file_path.name, file_path.stem):
|
||||||
target_file = f
|
target_file = f
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -3172,7 +3170,7 @@ app.add_typer(templates_app)
|
|||||||
|
|
||||||
|
|
||||||
@templates_app.command("extract")
|
@templates_app.command("extract")
|
||||||
def templates_extract(
|
def templates_extract( # noqa: PLR0915
|
||||||
model: Annotated[str, typer.Argument(help="Local model name (e.g. lust_v10.safetensors)")],
|
model: Annotated[str, typer.Argument(help="Local model name (e.g. lust_v10.safetensors)")],
|
||||||
orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "portrait",
|
orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "portrait",
|
||||||
no_overrides: Annotated[
|
no_overrides: Annotated[
|
||||||
|
|||||||
+3
-4
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
@@ -388,7 +389,7 @@ def queue_prompt(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _wait_for_completion_ws(
|
def _wait_for_completion_ws( # noqa: PLR0915
|
||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
url: str,
|
url: str,
|
||||||
client_id: str,
|
client_id: str,
|
||||||
@@ -494,10 +495,8 @@ def _wait_for_completion_ws(
|
|||||||
break
|
break
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
try:
|
with contextlib.suppress(Exception):
|
||||||
ws.close()
|
ws.close()
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Fetch final outputs from history to ensure we have everything
|
# Fetch final outputs from history to ensure we have everything
|
||||||
try:
|
try:
|
||||||
|
|||||||
+8
-8
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import tomllib
|
import tomllib
|
||||||
from enum import Enum
|
from enum import StrEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -65,7 +65,7 @@ CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
|
|||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class Provider(str, Enum):
|
class Provider(StrEnum):
|
||||||
"""Model search providers."""
|
"""Model search providers."""
|
||||||
|
|
||||||
civitai = "civitai"
|
civitai = "civitai"
|
||||||
@@ -73,7 +73,7 @@ class Provider(str, Enum):
|
|||||||
all = "all"
|
all = "all"
|
||||||
|
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
class ModelType(StrEnum):
|
||||||
"""CivitAI model types."""
|
"""CivitAI model types."""
|
||||||
|
|
||||||
checkpoint = "checkpoint"
|
checkpoint = "checkpoint"
|
||||||
@@ -110,7 +110,7 @@ class ModelType(str, Enum):
|
|||||||
return mapping[self.value]
|
return mapping[self.value]
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(str, Enum):
|
class BaseModel(StrEnum):
|
||||||
"""Common base models."""
|
"""Common base models."""
|
||||||
|
|
||||||
# Stable Diffusion 1.x
|
# Stable Diffusion 1.x
|
||||||
@@ -166,7 +166,7 @@ class BaseModel(str, Enum):
|
|||||||
return mapping[self.value]
|
return mapping[self.value]
|
||||||
|
|
||||||
|
|
||||||
class SortOrder(str, Enum):
|
class SortOrder(StrEnum):
|
||||||
"""Sort options for search."""
|
"""Sort options for search."""
|
||||||
|
|
||||||
downloads = "downloads"
|
downloads = "downloads"
|
||||||
@@ -183,7 +183,7 @@ class SortOrder(str, Enum):
|
|||||||
return mapping[self.value]
|
return mapping[self.value]
|
||||||
|
|
||||||
|
|
||||||
class Period(str, Enum):
|
class Period(StrEnum):
|
||||||
"""Time period for sorting/filtering."""
|
"""Time period for sorting/filtering."""
|
||||||
|
|
||||||
all = "all"
|
all = "all"
|
||||||
@@ -204,7 +204,7 @@ class Period(str, Enum):
|
|||||||
return mapping[self.value]
|
return mapping[self.value]
|
||||||
|
|
||||||
|
|
||||||
class NsfwLevel(str, Enum):
|
class NsfwLevel(StrEnum):
|
||||||
"""NSFW content filter level."""
|
"""NSFW content filter level."""
|
||||||
|
|
||||||
none = "none"
|
none = "none"
|
||||||
@@ -219,7 +219,7 @@ class NsfwLevel(str, Enum):
|
|||||||
return self.value.capitalize() if self.value != "none" else "None"
|
return self.value.capitalize() if self.value != "none" else "None"
|
||||||
|
|
||||||
|
|
||||||
class CommercialUse(str, Enum):
|
class CommercialUse(StrEnum):
|
||||||
"""Commercial use permissions."""
|
"""Commercial use permissions."""
|
||||||
|
|
||||||
none = "none"
|
none = "none"
|
||||||
|
|||||||
@@ -27,6 +27,14 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
router = APIRouter(prefix="/api/comfyui", tags=["ComfyUI API"])
|
router = APIRouter(prefix="/api/comfyui", tags=["ComfyUI API"])
|
||||||
|
|
||||||
|
# Schema default sentinels — see GenerateRequest defaults. These let us detect
|
||||||
|
# "user accepted default" vs "user explicitly chose this value matching default"
|
||||||
|
# is intentionally not distinguished; both paths apply family overrides.
|
||||||
|
_DEFAULT_STEPS = 20
|
||||||
|
_DEFAULT_CFG = 7.0
|
||||||
|
# Logging truncation threshold for long prompts in info-level output.
|
||||||
|
_PROMPT_LOG_TRUNCATE = 100
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Request/Response Models
|
# Request/Response Models
|
||||||
@@ -262,9 +270,9 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
|||||||
sampler = family_defaults["sampler"]
|
sampler = family_defaults["sampler"]
|
||||||
if request.scheduler == "normal": # Default value in schema
|
if request.scheduler == "normal": # Default value in schema
|
||||||
scheduler = family_defaults["scheduler"]
|
scheduler = family_defaults["scheduler"]
|
||||||
if request.steps == 20: # Default value in schema
|
if request.steps == _DEFAULT_STEPS: # Default value in schema
|
||||||
steps = family_defaults["steps"]
|
steps = family_defaults["steps"]
|
||||||
if request.cfg == 7.0: # Default value in schema
|
if request.cfg == _DEFAULT_CFG: # Default value in schema
|
||||||
cfg = family_defaults["cfg"]
|
cfg = family_defaults["cfg"]
|
||||||
# Only override VAE if user explicitly specified one;
|
# Only override VAE if user explicitly specified one;
|
||||||
# otherwise use checkpoint's built-in VAE (vae stays None)
|
# otherwise use checkpoint's built-in VAE (vae stays None)
|
||||||
@@ -290,7 +298,7 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
|||||||
sampler,
|
sampler,
|
||||||
scheduler,
|
scheduler,
|
||||||
lora_info,
|
lora_info,
|
||||||
request.prompt[:100] + "..." if len(request.prompt) > 100 else request.prompt,
|
request.prompt[:_PROMPT_LOG_TRUNCATE] + "..." if len(request.prompt) > _PROMPT_LOG_TRUNCATE else request.prompt,
|
||||||
)
|
)
|
||||||
if request.negative_prompt:
|
if request.negative_prompt:
|
||||||
logger.debug("Negative prompt: %r", request.negative_prompt[:100])
|
logger.debug("Negative prompt: %r", request.negative_prompt[:100])
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import StrEnum
|
||||||
from typing import Annotated, Any
|
from typing import Annotated, Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Query
|
from fastapi import APIRouter, Query
|
||||||
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
|
|||||||
router = APIRouter(prefix="/api/search", tags=["Search"])
|
router = APIRouter(prefix="/api/search", tags=["Search"])
|
||||||
|
|
||||||
|
|
||||||
class Provider(str, Enum):
|
class Provider(StrEnum):
|
||||||
"""Search provider options."""
|
"""Search provider options."""
|
||||||
|
|
||||||
civitai = "civitai"
|
civitai = "civitai"
|
||||||
@@ -45,7 +45,7 @@ class Provider(str, Enum):
|
|||||||
all = "all"
|
all = "all"
|
||||||
|
|
||||||
|
|
||||||
class SortOrder(str, Enum):
|
class SortOrder(StrEnum):
|
||||||
"""Sort order options."""
|
"""Sort order options."""
|
||||||
|
|
||||||
downloads = "downloads"
|
downloads = "downloads"
|
||||||
|
|||||||
@@ -1237,7 +1237,7 @@ class TestDownloadBackgroundTasks:
|
|||||||
captured["api_key"] = api_key
|
captured["api_key"] = api_key
|
||||||
return {"file_id": 42, "sha256": "deadbeef", "linked": True, "cached": True, "error": None}
|
return {"file_id": 42, "sha256": "deadbeef", "linked": True, "cached": True, "error": None}
|
||||||
|
|
||||||
monkeypatch.setattr(download_routes_module, "Database", lambda: StubDB())
|
monkeypatch.setattr(download_routes_module, "Database", StubDB)
|
||||||
return captured
|
return captured
|
||||||
|
|
||||||
def test_do_download_success(self, monkeypatch, tmp_path) -> None:
|
def test_do_download_success(self, monkeypatch, tmp_path) -> None:
|
||||||
@@ -1397,7 +1397,7 @@ class TestDownloadBackgroundTasks:
|
|||||||
def register_downloaded_file(self, *args, **kwargs):
|
def register_downloaded_file(self, *args, **kwargs):
|
||||||
return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"}
|
return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"}
|
||||||
|
|
||||||
monkeypatch.setattr(download_routes, "Database", lambda: FailingDB())
|
monkeypatch.setattr(download_routes, "Database", FailingDB)
|
||||||
|
|
||||||
dest_path = tmp_path / "model.safetensors"
|
dest_path = tmp_path / "model.safetensors"
|
||||||
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||||
|
|||||||
Reference in New Issue
Block a user