From b0b5bca5f8a63947bde0e800ebf11a6f3a4471c7 Mon Sep 17 00:00:00 2001 From: aladac Date: Mon, 18 May 2026 23:39:04 +0200 Subject: [PATCH] style: clear all 23 pre-existing ruff lint errors Master had been failing CI lint since before this PR. Knock out the backlog so the parallel-queue PR can ship green and future PRs don't inherit the red baseline. Changes by category: - UP042 (9): Migrate `class Foo(str, Enum)` to `class Foo(StrEnum)` in tensors/config.py (7 enums) and tensors/server/search_routes.py (2). Requires Python 3.11+, already enforced by `requires-python = ">=3.12"`. - PLR2004 (3): Extract magic comparison values in comfyui_api_routes.py to module-level constants (_DEFAULT_STEPS, _DEFAULT_CFG, _PROMPT_LOG_TRUNCATE). - PLW0108 (2): Inline `lambda: StubDB()` -> `StubDB` in test_server.py. - PLR0915 (3): Add explicit `# noqa: PLR0915` to typer command bodies that are intentionally long (template, templates_extract, _wait_for_completion_ws). - PLR1714 (1): `file_path.name == model or file_path.stem == model` -> `model in (file_path.name, file_path.stem)` in cli.py:3047. - SIM113 (1): Use `enumerate(as_completed(futures), start=1)` for the completion counter in style-sweep parallel loop. - RUF059 (1): Prefix unused tuple-unpacked vars with `_` in _run_one. - SIM105 (1): `try: ws.close() except Exception: pass` -> `contextlib.suppress(Exception)` in comfyui.py. - PLC0415 (1): Add missing `# noqa: PLC0415` to the second of two function-scoped tensors.config imports. No behavior changes. All 374 tests still pass. --- tensors/cli.py | 16 +++++++--------- tensors/comfyui.py | 7 +++---- tensors/config.py | 16 ++++++++-------- tensors/server/comfyui_api_routes.py | 14 +++++++++++--- tensors/server/search_routes.py | 6 +++--- tests/test_server.py | 4 ++-- 6 files changed, 34 insertions(+), 29 deletions(-) diff --git a/tensors/cli.py b/tensors/cli.py index 570a88e..b29c673 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -1485,7 +1485,7 @@ def _run_generation( # noqa: PLR0915 # ---- Resolve preset defaults for None params (both remote and local need these) ---- from tensors.config import resolve_orientation # noqa: PLC0415 - from tensors.config import resolve_remote as do_resolve_remote + from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415 # Use already-detected family_defaults from DB lookup above (not filename guessing) if family_defaults: @@ -2087,7 +2087,7 @@ def style_sweep( # noqa: PLR0915 def _run_one(task: tuple[int, dict[str, str], dict[str, Any], Path]) -> dict[str, Any]: """Run a single style. Returns the result dict (success or error captured).""" - idx, entry_in, res, opath = task + _idx, _entry_in, res, opath = task composed = res["prompt"] start = time.perf_counter() try: @@ -2139,11 +2139,9 @@ def style_sweep( # noqa: PLR0915 with ThreadPoolExecutor(max_workers=parallel_queue) as pool: futures = {pool.submit(_run_one, task): task for task in pending_tasks} - completed = 0 - for fut in as_completed(futures): - completed += 1 + for completed, fut in enumerate(as_completed(futures), start=1): task = futures[fut] - idx, _entry, _res, _out_path = task + idx, _entry, _res, _out_path = task # idx used in log message below try: res = fut.result() except Exception as ex: @@ -2226,7 +2224,7 @@ def _print_styles_list(styles_origin: str, entries: list[dict[str, str]]) -> Non @app.command() -def template( +def template( # noqa: PLR0915 model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")], lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None, lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8, @@ -3044,7 +3042,7 @@ def scene_extract( target_file = None for f in files: file_path = Path(f["file_path"]) - if file_path.name == model or file_path.stem == model: + if model in (file_path.name, file_path.stem): target_file = f break @@ -3172,7 +3170,7 @@ app.add_typer(templates_app) @templates_app.command("extract") -def templates_extract( +def templates_extract( # noqa: PLR0915 model: Annotated[str, typer.Argument(help="Local model name (e.g. lust_v10.safetensors)")], orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "portrait", no_overrides: Annotated[ diff --git a/tensors/comfyui.py b/tensors/comfyui.py index 5758a4c..13cd69c 100644 --- a/tensors/comfyui.py +++ b/tensors/comfyui.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextlib import copy import json import random @@ -388,7 +389,7 @@ def queue_prompt( return None -def _wait_for_completion_ws( +def _wait_for_completion_ws( # noqa: PLR0915 prompt_id: str, url: str, client_id: str, @@ -494,10 +495,8 @@ def _wait_for_completion_ws( break finally: - try: + with contextlib.suppress(Exception): ws.close() - except Exception: - pass # Fetch final outputs from history to ensure we have everything try: diff --git a/tensors/config.py b/tensors/config.py index 19f46f5..c4fe497 100644 --- a/tensors/config.py +++ b/tensors/config.py @@ -4,7 +4,7 @@ from __future__ import annotations import os import tomllib -from enum import Enum +from enum import StrEnum from pathlib import Path from typing import Any @@ -65,7 +65,7 @@ CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models" # ============================================================================ -class Provider(str, Enum): +class Provider(StrEnum): """Model search providers.""" civitai = "civitai" @@ -73,7 +73,7 @@ class Provider(str, Enum): all = "all" -class ModelType(str, Enum): +class ModelType(StrEnum): """CivitAI model types.""" checkpoint = "checkpoint" @@ -110,7 +110,7 @@ class ModelType(str, Enum): return mapping[self.value] -class BaseModel(str, Enum): +class BaseModel(StrEnum): """Common base models.""" # Stable Diffusion 1.x @@ -166,7 +166,7 @@ class BaseModel(str, Enum): return mapping[self.value] -class SortOrder(str, Enum): +class SortOrder(StrEnum): """Sort options for search.""" downloads = "downloads" @@ -183,7 +183,7 @@ class SortOrder(str, Enum): return mapping[self.value] -class Period(str, Enum): +class Period(StrEnum): """Time period for sorting/filtering.""" all = "all" @@ -204,7 +204,7 @@ class Period(str, Enum): return mapping[self.value] -class NsfwLevel(str, Enum): +class NsfwLevel(StrEnum): """NSFW content filter level.""" none = "none" @@ -219,7 +219,7 @@ class NsfwLevel(str, Enum): return self.value.capitalize() if self.value != "none" else "None" -class CommercialUse(str, Enum): +class CommercialUse(StrEnum): """Commercial use permissions.""" none = "none" diff --git a/tensors/server/comfyui_api_routes.py b/tensors/server/comfyui_api_routes.py index 1ce44e2..9974330 100644 --- a/tensors/server/comfyui_api_routes.py +++ b/tensors/server/comfyui_api_routes.py @@ -27,6 +27,14 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/comfyui", tags=["ComfyUI API"]) +# Schema default sentinels — see GenerateRequest defaults. These let us detect +# "user accepted default" vs "user explicitly chose this value matching default" +# is intentionally not distinguished; both paths apply family overrides. +_DEFAULT_STEPS = 20 +_DEFAULT_CFG = 7.0 +# Logging truncation threshold for long prompts in info-level output. +_PROMPT_LOG_TRUNCATE = 100 + # ============================================================================= # Request/Response Models @@ -262,9 +270,9 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]: sampler = family_defaults["sampler"] if request.scheduler == "normal": # Default value in schema scheduler = family_defaults["scheduler"] - if request.steps == 20: # Default value in schema + if request.steps == _DEFAULT_STEPS: # Default value in schema steps = family_defaults["steps"] - if request.cfg == 7.0: # Default value in schema + if request.cfg == _DEFAULT_CFG: # Default value in schema cfg = family_defaults["cfg"] # Only override VAE if user explicitly specified one; # otherwise use checkpoint's built-in VAE (vae stays None) @@ -290,7 +298,7 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]: sampler, scheduler, lora_info, - request.prompt[:100] + "..." if len(request.prompt) > 100 else request.prompt, + request.prompt[:_PROMPT_LOG_TRUNCATE] + "..." if len(request.prompt) > _PROMPT_LOG_TRUNCATE else request.prompt, ) if request.negative_prompt: logger.debug("Negative prompt: %r", request.negative_prompt[:100]) diff --git a/tensors/server/search_routes.py b/tensors/server/search_routes.py index 839ca39..3ca58dd 100644 --- a/tensors/server/search_routes.py +++ b/tensors/server/search_routes.py @@ -4,7 +4,7 @@ from __future__ import annotations import contextlib import logging -from enum import Enum +from enum import StrEnum from typing import Annotated, Any from fastapi import APIRouter, Query @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/search", tags=["Search"]) -class Provider(str, Enum): +class Provider(StrEnum): """Search provider options.""" civitai = "civitai" @@ -45,7 +45,7 @@ class Provider(str, Enum): all = "all" -class SortOrder(str, Enum): +class SortOrder(StrEnum): """Sort order options.""" downloads = "downloads" diff --git a/tests/test_server.py b/tests/test_server.py index 339f7ed..babed70 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1237,7 +1237,7 @@ class TestDownloadBackgroundTasks: captured["api_key"] = api_key return {"file_id": 42, "sha256": "deadbeef", "linked": True, "cached": True, "error": None} - monkeypatch.setattr(download_routes_module, "Database", lambda: StubDB()) + monkeypatch.setattr(download_routes_module, "Database", StubDB) return captured def test_do_download_success(self, monkeypatch, tmp_path) -> None: @@ -1397,7 +1397,7 @@ class TestDownloadBackgroundTasks: def register_downloaded_file(self, *args, **kwargs): return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"} - monkeypatch.setattr(download_routes, "Database", lambda: FailingDB()) + monkeypatch.setattr(download_routes, "Database", FailingDB) dest_path = tmp_path / "model.safetensors" _do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})