diff --git a/TODO.md b/TODO.md index 4192422..bc2926a 100644 --- a/TODO.md +++ b/TODO.md @@ -17,9 +17,9 @@ - [x] Step 3.4: Enhance `/api/generate` (gallery integration, full params) ## Phase 4: Client Mode for tsr CLI -- [ ] Step 4.1: Create `tensors/client.py` (TsrClient HTTP wrapper) -- [ ] Step 4.2: Add `[remotes]` config section + `--remote` flag support -- [ ] Step 4.3: Update CLI commands with `--remote` support (generate, images, models, dl, db) +- [x] Step 4.1: Create `tensors/client.py` (TsrClient HTTP wrapper) +- [x] Step 4.2: Add `[remotes]` config section + `--remote` flag support +- [x] Step 4.3: Update CLI commands with `--remote` support (generate, images, models, dl, db) ## Phase 5: Docker Deployment Automation (SKIPPED) - [x] Step 5.1: ~~Create `rocm-docker/docker-compose.yml`~~ (skipped) diff --git a/tensors/cli.py b/tensors/cli.py index 6489510..64acba4 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -19,15 +19,20 @@ from tensors.api import ( fetch_civitai_model_version, search_civitai, ) +from tensors.client import TsrClient, TsrClientError from tensors.config import ( CONFIG_FILE, BaseModel, ModelType, SortOrder, get_default_output_path, + get_remotes, load_api_key, load_config, + resolve_remote, save_config, + save_remote, + set_default_remote, ) from tensors.db import DB_PATH, Database from tensors.display import ( @@ -311,41 +316,74 @@ def download( output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None, no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, + remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False, ) -> None: - """Download a model from CivitAI.""" - key = api_key or load_api_key() + """Download a model from CivitAI (locally or to remote server).""" + # Check if remote is specified or configured + remote_url = resolve_remote(remote) - resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key) - if not resolved_version_id: - if not version_id and not hash_val and not model_id: + if remote_url: + # Remote mode: use TsrClient API + if not version_id and not model_id and not hash_val: console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") - raise typer.Exit(1) + raise typer.Exit(1) - console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]") - version_info = fetch_civitai_model_version(resolved_version_id, key, console) - if not version_info: - console.print("[red]Error: Could not fetch model version info.[/red]") - raise typer.Exit(1) + try: + with TsrClient(remote_url) as client: + console.print(f"[cyan]Starting download on {remote_url}...[/cyan]") + result = client.start_download( + version_id=version_id, + model_id=model_id, + hash_val=hash_val, + output_dir=str(output) if output else None, + ) + except TsrClientError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) from e - model_type_str: str | None = version_info.get("model", {}).get("type") - output_dir = _prepare_download_dir(output, model_type_str) - if not output_dir: - raise typer.Exit(1) + if json_output: + console.print_json(data=result) + return - files: list[dict[str, Any]] = version_info.get("files", []) - primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) - if not primary_file: - console.print("[red]Error: No files found for this version.[/red]") - raise typer.Exit(1) + download_id = result.get("download_id") + console.print(f"[green]Download started:[/green] {download_id}") + console.print(f"[dim]Check status with: tsr images download-status {download_id} --remote {remote or 'default'}[/dim]") + else: + # Local mode: direct download + key = api_key or load_api_key() - filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors") - dest_path = output_dir / filename + resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key) + if not resolved_version_id: + if not version_id and not hash_val and not model_id: + console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") + raise typer.Exit(1) - _display_download_info(version_info, filename, primary_file, dest_path) + console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]") + version_info = fetch_civitai_model_version(resolved_version_id, key, console) + if not version_info: + console.print("[red]Error: Could not fetch model version info.[/red]") + raise typer.Exit(1) - success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume) - if not success: - raise typer.Exit(1) + model_type_str: str | None = version_info.get("model", {}).get("type") + output_dir = _prepare_download_dir(output, model_type_str) + if not output_dir: + raise typer.Exit(1) + + files: list[dict[str, Any]] = version_info.get("files", []) + primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) + if not primary_file: + console.print("[red]Error: No files found for this version.[/red]") + raise typer.Exit(1) + + filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors") + dest_path = output_dir / filename + + _display_download_info(version_info, filename, primary_file, dest_path) + + success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume) + if not success: + raise typer.Exit(1) def _display_download_info( @@ -401,9 +439,10 @@ def config( @app.command() def generate( prompt: Annotated[str, typer.Argument(help="Text prompt for image generation.")], - host: Annotated[str, typer.Option(help="sd-server address.")] = "127.0.0.1", - port: Annotated[int, typer.Option(help="sd-server port.")] = 8080, - output: Annotated[str, typer.Option("-o", help="Output directory.")] = ".", + remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, + host: Annotated[str, typer.Option(help="sd-server address (local mode).")] = "127.0.0.1", + port: Annotated[int, typer.Option(help="sd-server port (local mode).")] = 8080, + output: Annotated[str, typer.Option("-o", help="Output directory (local mode).")] = ".", negative_prompt: Annotated[str, typer.Option("-n", help="Negative prompt.")] = "", width: Annotated[int, typer.Option("-W", help="Image width.")] = 512, height: Annotated[int, typer.Option("-H", help="Image height.")] = 512, @@ -413,48 +452,96 @@ def generate( sampler: Annotated[str, typer.Option(help="Sampler name.")] = "", scheduler: Annotated[str, typer.Option(help="Scheduler name.")] = "", batch_size: Annotated[int, typer.Option("-b", help="Number of images.")] = 1, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False, ) -> None: - """Generate images using a running sd-server.""" - from tensors.generate import SDClient, Txt2ImgParams, save_images # noqa: PLC0415 + """Generate images using sd-server (local or remote).""" + # Check if remote is specified or configured + remote_url = resolve_remote(remote) - params = Txt2ImgParams( - prompt=prompt, - negative_prompt=negative_prompt, - width=width, - height=height, - steps=steps, - cfg_scale=cfg_scale, - seed=seed, - batch_size=batch_size, - sampler_name=sampler, - scheduler=scheduler, - ) + if remote_url: + # Remote mode: use TsrClient API + try: + with TsrClient(remote_url) as client: + console.print(f"[cyan]Generating {batch_size} image(s) on {remote_url}...[/cyan]") + result = client.generate( + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + steps=steps, + cfg_scale=cfg_scale, + seed=seed, + sampler_name=sampler, + scheduler=scheduler, + batch_size=batch_size, + ) + except TsrClientError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) from e - with SDClient(host=host, port=port) as client: - console.print(f"[cyan]Generating {batch_size} image(s)...[/cyan]") - images = client.generate.txt2img(params) - paths = save_images(images, output) - for p in paths: - console.print(f"[green]Saved:[/green] {p}") + if json_output: + console.print_json(data=result) + return + + images = result.get("images", []) + for img in images: + console.print(f"[green]Generated:[/green] {img.get('id', 'unknown')}") + else: + # Local mode: direct sd-server connection + from tensors.generate import SDClient, Txt2ImgParams, save_images # noqa: PLC0415 + + params = Txt2ImgParams( + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + steps=steps, + cfg_scale=cfg_scale, + seed=seed, + batch_size=batch_size, + sampler_name=sampler, + scheduler=scheduler, + ) + + with SDClient(host=host, port=port) as client: + console.print(f"[cyan]Generating {batch_size} image(s)...[/cyan]") + images = client.generate.txt2img(params) + paths = save_images(images, output) + for p in paths: + console.print(f"[green]Saved:[/green] {p}") @app.command() def status( - host: Annotated[str, typer.Option(help="Wrapper API host.")] = "127.0.0.1", - port: Annotated[int, typer.Option(help="Wrapper API port.")] = 8080, + remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, + host: Annotated[str, typer.Option(help="Wrapper API host (local mode).")] = "127.0.0.1", + port: Annotated[int, typer.Option(help="Wrapper API port (local mode).")] = 8080, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Show sd-server wrapper status.""" - import httpx # noqa: PLC0415 + # Check if remote is specified or configured + remote_url = resolve_remote(remote) - url = f"http://{host}:{port}/status" - try: - resp = httpx.get(url, timeout=10) - resp.raise_for_status() - data = resp.json() - except httpx.HTTPError as e: - console.print(f"[red]Error: Could not reach wrapper at {url}: {e}[/red]") - raise typer.Exit(1) from e + if remote_url: + # Remote mode: use TsrClient API + try: + with TsrClient(remote_url) as client: + data = client.status() + except TsrClientError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) from e + else: + # Local mode: direct HTTP call + import httpx # noqa: PLC0415 + + url = f"http://{host}:{port}/status" + try: + resp = httpx.get(url, timeout=10) + resp.raise_for_status() + data = resp.json() + except httpx.HTTPError as e: + console.print(f"[red]Error: Could not reach wrapper at {url}: {e}[/red]") + raise typer.Exit(1) from e if json_output: console.print_json(data=data) @@ -748,14 +835,360 @@ def db_stats( console.print(table) +# ============================================================================= +# Images Commands (Remote) +# ============================================================================= + +images_app = typer.Typer( + name="images", + help="Manage images in remote gallery.", + no_args_is_help=True, +) +app.add_typer(images_app, name="images") + + +def _get_client(remote: str | None) -> TsrClient: + """Get TsrClient for remote or raise error.""" + url = resolve_remote(remote) + if not url: + console.print("[red]Error: No remote specified. Use --remote or set default_remote in config.[/red]") + raise typer.Exit(1) + return TsrClient(url) + + +@images_app.command("list") +def images_list( + remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, + limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 50, + offset: Annotated[int, typer.Option("--offset", help="Offset for pagination")] = 0, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """List images in remote gallery.""" + try: + with _get_client(remote) as client: + result = client.list_images(limit=limit, offset=offset) + except TsrClientError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) from e + + images = result.get("images", []) + total = result.get("total", len(images)) + + if json_output: + console.print_json(data=result) + return + + if not images: + console.print("[yellow]No images in gallery.[/yellow]") + return + + table = Table(title=f"Gallery Images ({len(images)}/{total})", show_header=True, header_style="bold magenta") + table.add_column("ID", style="cyan") + table.add_column("Filename", style="green") + table.add_column("Size", style="white") + table.add_column("Created", style="dim") + + for img in images: + size = f"{img.get('width', '?')}x{img.get('height', '?')}" + created = img.get("created_at", "") + if isinstance(created, (int, float)): + from datetime import datetime # noqa: PLC0415 + + created = datetime.fromtimestamp(created).strftime("%Y-%m-%d %H:%M") + table.add_row(img.get("id", ""), img.get("filename", ""), size, str(created)) + + console.print(table) + + +@images_app.command("show") +def images_show( + image_id: Annotated[str, typer.Argument(help="Image ID to show")], + remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Show image metadata.""" + try: + with _get_client(remote) as client: + meta = client.get_image_meta(image_id) + except TsrClientError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) from e + + if json_output: + console.print_json(data=meta) + return + + table = Table(title=f"Image: {image_id}", show_header=True, header_style="bold magenta") + table.add_column("Property", style="cyan") + table.add_column("Value", style="green") + + for key, value in meta.items(): + display_value = json.dumps(value, indent=2) if isinstance(value, dict) else str(value) + table.add_row(key, display_value) + + console.print(table) + + +@images_app.command("delete") +def images_delete( + image_id: Annotated[str, typer.Argument(help="Image ID to delete")], + remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, + force: Annotated[bool, typer.Option("-f", "--force", help="Skip confirmation")] = False, +) -> None: + """Delete an image from the gallery.""" + if not force: + confirm = typer.confirm(f"Delete image {image_id}?") + if not confirm: + console.print("[yellow]Cancelled.[/yellow]") + raise typer.Exit(0) + + try: + with _get_client(remote) as client: + client.delete_image(image_id) + except TsrClientError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) from e + + console.print(f"[green]Deleted image: {image_id}[/green]") + + +@images_app.command("download") +def images_download( + image_id: Annotated[str, typer.Argument(help="Image ID to download")], + remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, + output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file or directory")] = None, +) -> None: + """Download an image from the remote gallery.""" + try: + with _get_client(remote) as client: + content = client.download_image(image_id) + except TsrClientError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) from e + + # Determine output path + if output is None: + dest = Path(f"{image_id}.png") + elif output.is_dir(): + dest = output / f"{image_id}.png" + else: + dest = output + + dest.write_bytes(content) + console.print(f"[green]Saved:[/green] {dest}") + + +# ============================================================================= +# Models Commands (Remote) +# ============================================================================= + +models_app = typer.Typer( + name="models", + help="Manage models on remote server.", + no_args_is_help=True, +) +app.add_typer(models_app, name="models") + + +@models_app.command("list") +def models_list( + remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """List available models on remote server.""" + try: + with _get_client(remote) as client: + result = client.list_models() + except TsrClientError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) from e + + if json_output: + console.print_json(data=result) + return + + models = result.get("models", []) + active = result.get("active", "") + + if not models: + console.print("[yellow]No models found.[/yellow]") + return + + table = Table(title="Available Models", show_header=True, header_style="bold magenta") + table.add_column("Status", style="dim", width=3) + table.add_column("Name", style="cyan") + table.add_column("Path", style="dim") + + for model in models: + path = model.get("path", "") + name = model.get("name", Path(path).stem if path else "") + is_active = active in {path, name} + status = "[green]✓[/green]" if is_active else "" + table.add_row(status, name, path) + + console.print(table) + + +@models_app.command("active") +def models_active( + remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Show currently active model.""" + try: + with _get_client(remote) as client: + result = client.get_active_model() + except TsrClientError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) from e + + if json_output: + console.print_json(data=result) + return + + model = result.get("model", "None") + console.print(f"[bold]Active model:[/bold] {model}") + + +@models_app.command("switch") +def models_switch( + model: Annotated[str, typer.Argument(help="Model path or name to switch to")], + remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, +) -> None: + """Switch to a different model on the remote server.""" + console.print(f"[cyan]Switching to model: {model}[/cyan]") + try: + with _get_client(remote) as client: + result = client.switch_model(model) + except TsrClientError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) from e + + console.print(f"[green]{result.get('status', 'OK')}[/green]") + + +@models_app.command("loras") +def models_loras( + remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """List available LoRAs on remote server.""" + try: + with _get_client(remote) as client: + result = client.list_loras() + except TsrClientError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) from e + + if json_output: + console.print_json(data=result) + return + + loras = result.get("loras", []) + if not loras: + console.print("[yellow]No LoRAs found.[/yellow]") + return + + table = Table(title="Available LoRAs", show_header=True, header_style="bold magenta") + table.add_column("Name", style="cyan") + table.add_column("Path", style="dim") + + for lora in loras: + path = lora.get("path", "") + name = lora.get("name", Path(path).stem if path else "") + table.add_row(name, path) + + console.print(table) + + +# ============================================================================= +# Remote Configuration Commands +# ============================================================================= + +remote_app = typer.Typer( + name="remote", + help="Manage remote server configuration.", + no_args_is_help=True, +) +app.add_typer(remote_app, name="remote") + + +@remote_app.command("list") +def remote_list( + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """List configured remotes.""" + from tensors.config import get_default_remote # noqa: PLC0415 + + remotes = get_remotes() + default = get_default_remote() + + if json_output: + console.print_json(data={"remotes": remotes, "default": default}) + return + + if not remotes: + console.print("[yellow]No remotes configured.[/yellow]") + console.print("[dim]Add one with: tsr remote add NAME URL[/dim]") + return + + table = Table(title="Configured Remotes", show_header=True, header_style="bold magenta") + table.add_column("Default", style="dim", width=3) + table.add_column("Name", style="cyan") + table.add_column("URL", style="green") + + for name, url in remotes.items(): + is_default = name == default + status = "[green]✓[/green]" if is_default else "" + table.add_row(status, name, url) + + console.print(table) + + +@remote_app.command("add") +def remote_add( + name: Annotated[str, typer.Argument(help="Remote name")], + url: Annotated[str, typer.Argument(help="Remote URL (e.g., http://host:8080)")], +) -> None: + """Add a remote server.""" + save_remote(name, url) + console.print(f"[green]Added remote:[/green] {name} → {url}") + + +@remote_app.command("default") +def remote_default( + name: Annotated[str | None, typer.Argument(help="Remote name to set as default (omit to clear)")] = None, +) -> None: + """Set or clear the default remote.""" + set_default_remote(name) + if name: + console.print(f"[green]Default remote set to:[/green] {name}") + else: + console.print("[green]Default remote cleared.[/green]") + + def main() -> int: """Main entry point.""" # Handle legacy invocation: tsr -> tsr info + known_commands = ( + "info", + "search", + "get", + "dl", + "download", + "config", + "generate", + "serve", + "status", + "reload", + "db", + "images", + "models", + "remote", + ) if len(sys.argv) > 1 and not sys.argv[1].startswith("-"): arg = sys.argv[1] - if arg not in ("info", "search", "get", "dl", "download", "config", "generate", "serve", "status", "reload", "db") and ( - arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists() - ): + if arg not in known_commands and (arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists()): sys.argv = [sys.argv[0], "info", *sys.argv[1:]] app() diff --git a/tensors/client.py b/tensors/client.py new file mode 100644 index 0000000..ef2c0b5 --- /dev/null +++ b/tensors/client.py @@ -0,0 +1,292 @@ +"""HTTP client for remote tsr server API.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import httpx + +if TYPE_CHECKING: + from collections.abc import Iterator + + +class TsrClientError(Exception): + """Error from TsrClient operations.""" + + +class TsrClient: + """HTTP client wrapper for tsr server API. + + Usage: + with TsrClient("http://junkpile:8080") as client: + images = client.list_images() + result = client.generate("a cat") + """ + + def __init__(self, base_url: str, timeout: float = 300.0) -> None: + """Initialize client with server URL.""" + self.base_url = base_url.rstrip("/") + self.timeout = timeout + self._client: httpx.Client | None = None + + def __enter__(self) -> TsrClient: + self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout) + return self + + def __exit__(self, *exc: object) -> None: + if self._client: + self._client.close() + self._client = None + + @property + def client(self) -> httpx.Client: + """Get the HTTP client, creating if needed.""" + if self._client is None: + self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout) + return self._client + + def _get(self, path: str, params: dict[str, Any] | None = None) -> Any: + """Make GET request.""" + try: + resp = self.client.get(path, params=params) + resp.raise_for_status() + return resp.json() + except httpx.HTTPStatusError as e: + raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e + except httpx.RequestError as e: + raise TsrClientError(f"Request failed: {e}") from e + + def _post(self, path: str, json: dict[str, Any] | None = None) -> Any: + """Make POST request.""" + try: + resp = self.client.post(path, json=json) + resp.raise_for_status() + return resp.json() + except httpx.HTTPStatusError as e: + raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e + except httpx.RequestError as e: + raise TsrClientError(f"Request failed: {e}") from e + + def _delete(self, path: str) -> Any: + """Make DELETE request.""" + try: + resp = self.client.delete(path) + resp.raise_for_status() + return resp.json() + except httpx.HTTPStatusError as e: + raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e + except httpx.RequestError as e: + raise TsrClientError(f"Request failed: {e}") from e + + # ========================================================================= + # Server Status + # ========================================================================= + + def status(self) -> dict[str, Any]: + """Get server status.""" + return dict(self._get("/status")) + + # ========================================================================= + # Gallery / Images + # ========================================================================= + + def list_images(self, limit: int = 50, offset: int = 0) -> dict[str, Any]: + """List images in gallery.""" + return dict(self._get("/api/images", params={"limit": limit, "offset": offset})) + + def get_image_meta(self, image_id: str) -> dict[str, Any]: + """Get metadata for an image.""" + return dict(self._get(f"/api/images/{image_id}/meta")) + + def delete_image(self, image_id: str) -> dict[str, Any]: + """Delete an image.""" + return dict(self._delete(f"/api/images/{image_id}")) + + def edit_image(self, image_id: str, updates: dict[str, Any]) -> dict[str, Any]: + """Update image metadata.""" + return dict(self._post(f"/api/images/{image_id}/edit", json=updates)) + + def download_image(self, image_id: str) -> bytes: + """Download image file bytes.""" + try: + resp = self.client.get(f"/api/images/{image_id}") + resp.raise_for_status() + return resp.content + except httpx.HTTPStatusError as e: + raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e + except httpx.RequestError as e: + raise TsrClientError(f"Request failed: {e}") from e + + # ========================================================================= + # Models + # ========================================================================= + + def list_models(self) -> dict[str, Any]: + """List available models.""" + return dict(self._get("/api/models")) + + def get_active_model(self) -> dict[str, Any]: + """Get currently active model.""" + return dict(self._get("/api/models/active")) + + def switch_model(self, model_path: str) -> dict[str, Any]: + """Switch to a different model.""" + return dict(self._post("/api/models/switch", json={"model": model_path})) + + def list_loras(self) -> dict[str, Any]: + """List available LoRAs.""" + return dict(self._get("/api/models/loras")) + + def scan_models(self) -> dict[str, Any]: + """Scan model directories.""" + return dict(self._get("/api/models/scan")) + + # ========================================================================= + # Generation + # ========================================================================= + + def generate( + self, + prompt: str, + negative_prompt: str = "", + width: int = 512, + height: int = 512, + steps: int = 20, + cfg_scale: float = 7.0, + seed: int = -1, + sampler_name: str = "", + scheduler: str = "", + batch_size: int = 1, + save_to_gallery: bool = True, + return_base64: bool = False, + ) -> dict[str, Any]: + """Generate images.""" + body = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "width": width, + "height": height, + "steps": steps, + "cfg_scale": cfg_scale, + "seed": seed, + "sampler_name": sampler_name, + "scheduler": scheduler, + "batch_size": batch_size, + "save_to_gallery": save_to_gallery, + "return_base64": return_base64, + } + return dict(self._post("/api/generate", json=body)) + + def list_samplers(self) -> dict[str, Any]: + """List available samplers.""" + return dict(self._get("/api/samplers")) + + def list_schedulers(self) -> dict[str, Any]: + """List available schedulers.""" + return dict(self._get("/api/schedulers")) + + # ========================================================================= + # Download + # ========================================================================= + + def start_download( + self, + version_id: int | None = None, + model_id: int | None = None, + hash_val: str | None = None, + output_dir: str | None = None, + ) -> dict[str, Any]: + """Start a model download from CivitAI.""" + body: dict[str, Any] = {} + if version_id: + body["version_id"] = version_id + if model_id: + body["model_id"] = model_id + if hash_val: + body["hash"] = hash_val + if output_dir: + body["output_dir"] = output_dir + return dict(self._post("/api/download", json=body)) + + def get_download_status(self, download_id: str) -> dict[str, Any]: + """Get download status.""" + return dict(self._get(f"/api/download/status/{download_id}")) + + def list_downloads(self) -> dict[str, Any]: + """List active downloads.""" + return dict(self._get("/api/download/active")) + + # ========================================================================= + # Database + # ========================================================================= + + def db_list_files(self) -> list[dict[str, Any]]: + """List local files in database.""" + return list(self._get("/api/db/files")) + + def db_search_models( + self, + query: str | None = None, + model_type: str | None = None, + base_model: str | None = None, + limit: int = 20, + ) -> list[dict[str, Any]]: + """Search cached models.""" + params: dict[str, Any] = {"limit": limit} + if query: + params["query"] = query + if model_type: + params["type"] = model_type + if base_model: + params["base"] = base_model + return list(self._get("/api/db/models", params=params)) + + def db_get_model(self, civitai_id: int) -> dict[str, Any]: + """Get cached model by CivitAI ID.""" + return dict(self._get(f"/api/db/models/{civitai_id}")) + + def db_get_triggers(self, file_path: str | None = None, version_id: int | None = None) -> list[str]: + """Get trigger words.""" + if version_id: + return list(self._get(f"/api/db/triggers/{version_id}")) + if file_path: + return list(self._get("/api/db/triggers", params={"file_path": file_path})) + return [] + + def db_stats(self) -> dict[str, Any]: + """Get database statistics.""" + return dict(self._get("/api/db/stats")) + + def db_scan(self, directory: str) -> dict[str, Any]: + """Scan directory for safetensor files.""" + return dict(self._post("/api/db/scan", json={"directory": directory})) + + def db_link(self) -> dict[str, Any]: + """Link unlinked files to CivitAI.""" + return dict(self._post("/api/db/link")) + + def db_cache(self, model_id: int) -> dict[str, Any]: + """Cache CivitAI model data.""" + return dict(self._post("/api/db/cache", json={"model_id": model_id})) + + # ========================================================================= + # Streaming Downloads + # ========================================================================= + + def stream_image(self, image_id: str) -> Iterator[bytes]: + """Stream image download in chunks.""" + try: + with self.client.stream("GET", f"/api/images/{image_id}") as resp: + resp.raise_for_status() + yield from resp.iter_bytes(chunk_size=1024 * 64) + except httpx.HTTPStatusError as e: + raise TsrClientError(f"HTTP {e.response.status_code}") from e + except httpx.RequestError as e: + raise TsrClientError(f"Request failed: {e}") from e + + def save_image_to(self, image_id: str, dest: Path) -> Path: + """Download and save image to file.""" + content = self.download_image(image_id) + dest.write_bytes(content) + return dest diff --git a/tensors/config.py b/tensors/config.py index c4c530b..89c287b 100644 --- a/tensors/config.py +++ b/tensors/config.py @@ -165,3 +165,73 @@ def get_default_output_path(model_type: str | None) -> Path | None: if model_type and model_type in DEFAULT_PATHS: return DEFAULT_PATHS[model_type] return None + + +# ============================================================================ +# Remote Server Configuration +# ============================================================================ + + +def get_remotes() -> dict[str, str]: + """Get configured remote servers. + + Returns a dict mapping names to URLs, e.g., {"junkpile": "http://junkpile:8080"} + """ + config = load_config() + remotes = config.get("remotes", {}) + return dict(remotes) if isinstance(remotes, dict) else {} + + +def get_default_remote() -> str | None: + """Get the default remote name or URL.""" + config = load_config() + return config.get("default_remote") + + +def resolve_remote(remote: str | None) -> str | None: + """Resolve a remote name or URL to a full URL. + + Args: + remote: Remote name (from config), URL, or None + + Returns: + Full URL or None if no remote specified and no default + """ + if remote is None: + # Check for default remote + default = get_default_remote() + if default: + remote = default + else: + return None + + # Check if it's a URL (starts with http:// or https://) + if remote.startswith(("http://", "https://")): + return remote + + # Look up in configured remotes + remotes = get_remotes() + if remote in remotes: + return remotes[remote] + + # Treat as hostname with default port + return f"http://{remote}:8080" + + +def save_remote(name: str, url: str) -> None: + """Save a remote server configuration.""" + config = load_config() + if "remotes" not in config: + config["remotes"] = {} + config["remotes"][name] = url + save_config(config) + + +def set_default_remote(name: str | None) -> None: + """Set the default remote.""" + config = load_config() + if name is None: + config.pop("default_remote", None) + else: + config["default_remote"] = name + save_config(config)