style: clear all 23 pre-existing ruff lint errors
Master had been failing CI lint since before this PR. Knock out the backlog so the parallel-queue PR can ship green and future PRs don't inherit the red baseline. Changes by category: - UP042 (9): Migrate `class Foo(str, Enum)` to `class Foo(StrEnum)` in tensors/config.py (7 enums) and tensors/server/search_routes.py (2). Requires Python 3.11+, already enforced by `requires-python = ">=3.12"`. - PLR2004 (3): Extract magic comparison values in comfyui_api_routes.py to module-level constants (_DEFAULT_STEPS, _DEFAULT_CFG, _PROMPT_LOG_TRUNCATE). - PLW0108 (2): Inline `lambda: StubDB()` -> `StubDB` in test_server.py. - PLR0915 (3): Add explicit `# noqa: PLR0915` to typer command bodies that are intentionally long (template, templates_extract, _wait_for_completion_ws). - PLR1714 (1): `file_path.name == model or file_path.stem == model` -> `model in (file_path.name, file_path.stem)` in cli.py:3047. - SIM113 (1): Use `enumerate(as_completed(futures), start=1)` for the completion counter in style-sweep parallel loop. - RUF059 (1): Prefix unused tuple-unpacked vars with `_` in _run_one. - SIM105 (1): `try: ws.close() except Exception: pass` -> `contextlib.suppress(Exception)` in comfyui.py. - PLC0415 (1): Add missing `# noqa: PLC0415` to the second of two function-scoped tensors.config imports. No behavior changes. All 374 tests still pass.
This commit is contained in:
+7
-9
@@ -1485,7 +1485,7 @@ def _run_generation( # noqa: PLR0915
|
||||
|
||||
# ---- Resolve preset defaults for None params (both remote and local need these) ----
|
||||
from tensors.config import resolve_orientation # noqa: PLC0415
|
||||
from tensors.config import resolve_remote as do_resolve_remote
|
||||
from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415
|
||||
|
||||
# Use already-detected family_defaults from DB lookup above (not filename guessing)
|
||||
if family_defaults:
|
||||
@@ -2087,7 +2087,7 @@ def style_sweep( # noqa: PLR0915
|
||||
|
||||
def _run_one(task: tuple[int, dict[str, str], dict[str, Any], Path]) -> dict[str, Any]:
|
||||
"""Run a single style. Returns the result dict (success or error captured)."""
|
||||
idx, entry_in, res, opath = task
|
||||
_idx, _entry_in, res, opath = task
|
||||
composed = res["prompt"]
|
||||
start = time.perf_counter()
|
||||
try:
|
||||
@@ -2139,11 +2139,9 @@ def style_sweep( # noqa: PLR0915
|
||||
|
||||
with ThreadPoolExecutor(max_workers=parallel_queue) as pool:
|
||||
futures = {pool.submit(_run_one, task): task for task in pending_tasks}
|
||||
completed = 0
|
||||
for fut in as_completed(futures):
|
||||
completed += 1
|
||||
for completed, fut in enumerate(as_completed(futures), start=1):
|
||||
task = futures[fut]
|
||||
idx, _entry, _res, _out_path = task
|
||||
idx, _entry, _res, _out_path = task # idx used in log message below
|
||||
try:
|
||||
res = fut.result()
|
||||
except Exception as ex:
|
||||
@@ -2226,7 +2224,7 @@ def _print_styles_list(styles_origin: str, entries: list[dict[str, str]]) -> Non
|
||||
|
||||
|
||||
@app.command()
|
||||
def template(
|
||||
def template( # noqa: PLR0915
|
||||
model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")],
|
||||
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
||||
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
|
||||
@@ -3044,7 +3042,7 @@ def scene_extract(
|
||||
target_file = None
|
||||
for f in files:
|
||||
file_path = Path(f["file_path"])
|
||||
if file_path.name == model or file_path.stem == model:
|
||||
if model in (file_path.name, file_path.stem):
|
||||
target_file = f
|
||||
break
|
||||
|
||||
@@ -3172,7 +3170,7 @@ app.add_typer(templates_app)
|
||||
|
||||
|
||||
@templates_app.command("extract")
|
||||
def templates_extract(
|
||||
def templates_extract( # noqa: PLR0915
|
||||
model: Annotated[str, typer.Argument(help="Local model name (e.g. lust_v10.safetensors)")],
|
||||
orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "portrait",
|
||||
no_overrides: Annotated[
|
||||
|
||||
+3
-4
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import json
|
||||
import random
|
||||
@@ -388,7 +389,7 @@ def queue_prompt(
|
||||
return None
|
||||
|
||||
|
||||
def _wait_for_completion_ws(
|
||||
def _wait_for_completion_ws( # noqa: PLR0915
|
||||
prompt_id: str,
|
||||
url: str,
|
||||
client_id: str,
|
||||
@@ -494,10 +495,8 @@ def _wait_for_completion_ws(
|
||||
break
|
||||
|
||||
finally:
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fetch final outputs from history to ensure we have everything
|
||||
try:
|
||||
|
||||
+8
-8
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tomllib
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -65,7 +65,7 @@ CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class Provider(str, Enum):
|
||||
class Provider(StrEnum):
|
||||
"""Model search providers."""
|
||||
|
||||
civitai = "civitai"
|
||||
@@ -73,7 +73,7 @@ class Provider(str, Enum):
|
||||
all = "all"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
class ModelType(StrEnum):
|
||||
"""CivitAI model types."""
|
||||
|
||||
checkpoint = "checkpoint"
|
||||
@@ -110,7 +110,7 @@ class ModelType(str, Enum):
|
||||
return mapping[self.value]
|
||||
|
||||
|
||||
class BaseModel(str, Enum):
|
||||
class BaseModel(StrEnum):
|
||||
"""Common base models."""
|
||||
|
||||
# Stable Diffusion 1.x
|
||||
@@ -166,7 +166,7 @@ class BaseModel(str, Enum):
|
||||
return mapping[self.value]
|
||||
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
class SortOrder(StrEnum):
|
||||
"""Sort options for search."""
|
||||
|
||||
downloads = "downloads"
|
||||
@@ -183,7 +183,7 @@ class SortOrder(str, Enum):
|
||||
return mapping[self.value]
|
||||
|
||||
|
||||
class Period(str, Enum):
|
||||
class Period(StrEnum):
|
||||
"""Time period for sorting/filtering."""
|
||||
|
||||
all = "all"
|
||||
@@ -204,7 +204,7 @@ class Period(str, Enum):
|
||||
return mapping[self.value]
|
||||
|
||||
|
||||
class NsfwLevel(str, Enum):
|
||||
class NsfwLevel(StrEnum):
|
||||
"""NSFW content filter level."""
|
||||
|
||||
none = "none"
|
||||
@@ -219,7 +219,7 @@ class NsfwLevel(str, Enum):
|
||||
return self.value.capitalize() if self.value != "none" else "None"
|
||||
|
||||
|
||||
class CommercialUse(str, Enum):
|
||||
class CommercialUse(StrEnum):
|
||||
"""Commercial use permissions."""
|
||||
|
||||
none = "none"
|
||||
|
||||
@@ -27,6 +27,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/comfyui", tags=["ComfyUI API"])
|
||||
|
||||
# Schema default sentinels — see GenerateRequest defaults. These let us detect
|
||||
# "user accepted default" vs "user explicitly chose this value matching default"
|
||||
# is intentionally not distinguished; both paths apply family overrides.
|
||||
_DEFAULT_STEPS = 20
|
||||
_DEFAULT_CFG = 7.0
|
||||
# Logging truncation threshold for long prompts in info-level output.
|
||||
_PROMPT_LOG_TRUNCATE = 100
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Request/Response Models
|
||||
@@ -262,9 +270,9 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
||||
sampler = family_defaults["sampler"]
|
||||
if request.scheduler == "normal": # Default value in schema
|
||||
scheduler = family_defaults["scheduler"]
|
||||
if request.steps == 20: # Default value in schema
|
||||
if request.steps == _DEFAULT_STEPS: # Default value in schema
|
||||
steps = family_defaults["steps"]
|
||||
if request.cfg == 7.0: # Default value in schema
|
||||
if request.cfg == _DEFAULT_CFG: # Default value in schema
|
||||
cfg = family_defaults["cfg"]
|
||||
# Only override VAE if user explicitly specified one;
|
||||
# otherwise use checkpoint's built-in VAE (vae stays None)
|
||||
@@ -290,7 +298,7 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
||||
sampler,
|
||||
scheduler,
|
||||
lora_info,
|
||||
request.prompt[:100] + "..." if len(request.prompt) > 100 else request.prompt,
|
||||
request.prompt[:_PROMPT_LOG_TRUNCATE] + "..." if len(request.prompt) > _PROMPT_LOG_TRUNCATE else request.prompt,
|
||||
)
|
||||
if request.negative_prompt:
|
||||
logger.debug("Negative prompt: %r", request.negative_prompt[:100])
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Query
|
||||
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/search", tags=["Search"])
|
||||
|
||||
|
||||
class Provider(str, Enum):
|
||||
class Provider(StrEnum):
|
||||
"""Search provider options."""
|
||||
|
||||
civitai = "civitai"
|
||||
@@ -45,7 +45,7 @@ class Provider(str, Enum):
|
||||
all = "all"
|
||||
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
class SortOrder(StrEnum):
|
||||
"""Sort order options."""
|
||||
|
||||
downloads = "downloads"
|
||||
|
||||
@@ -1237,7 +1237,7 @@ class TestDownloadBackgroundTasks:
|
||||
captured["api_key"] = api_key
|
||||
return {"file_id": 42, "sha256": "deadbeef", "linked": True, "cached": True, "error": None}
|
||||
|
||||
monkeypatch.setattr(download_routes_module, "Database", lambda: StubDB())
|
||||
monkeypatch.setattr(download_routes_module, "Database", StubDB)
|
||||
return captured
|
||||
|
||||
def test_do_download_success(self, monkeypatch, tmp_path) -> None:
|
||||
@@ -1397,7 +1397,7 @@ class TestDownloadBackgroundTasks:
|
||||
def register_downloaded_file(self, *args, **kwargs):
|
||||
return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"}
|
||||
|
||||
monkeypatch.setattr(download_routes, "Database", lambda: FailingDB())
|
||||
monkeypatch.setattr(download_routes, "Database", FailingDB)
|
||||
|
||||
dest_path = tmp_path / "model.safetensors"
|
||||
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||
|
||||
Reference in New Issue
Block a user