diff --git a/tensors/cli.py b/tensors/cli.py index 52fb815..e0cb436 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -401,7 +401,7 @@ def config( def generate( prompt: Annotated[str, typer.Argument(help="Text prompt for image generation.")], host: Annotated[str, typer.Option(help="sd-server address.")] = "127.0.0.1", - port: Annotated[int, typer.Option(help="sd-server port.")] = 1234, + port: Annotated[int, typer.Option(help="sd-server port.")] = 8080, output: Annotated[str, typer.Option("-o", help="Output directory.")] = ".", negative_prompt: Annotated[str, typer.Option("-n", help="Negative prompt.")] = "", width: Annotated[int, typer.Option("-W", help="Image width.")] = 512, @@ -437,6 +437,58 @@ def generate( console.print(f"[green]Saved:[/green] {p}") +@app.command() +def status( + host: Annotated[str, typer.Option(help="Wrapper API host.")] = "127.0.0.1", + port: Annotated[int, typer.Option(help="Wrapper API port.")] = 8080, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Show sd-server wrapper status.""" + import httpx # noqa: PLC0415 + + url = f"http://{host}:{port}/status" + try: + resp = httpx.get(url, timeout=10) + resp.raise_for_status() + data = resp.json() + except httpx.HTTPError as e: + console.print(f"[red]Error: Could not reach wrapper at {url}: {e}[/red]") + raise typer.Exit(1) from e + + if json_output: + console.print_json(data=data) + return + + table = Table(title="Server Status", show_header=True, header_style="bold magenta") + table.add_column("Property", style="cyan") + table.add_column("Value", style="green") + for key, value in data.items(): + table.add_row(key, str(value)) + console.print(table) + + +@app.command() +def reload( + model: Annotated[str, typer.Option(help="Path to model file for sd-server.")], + host: Annotated[str, typer.Option(help="Wrapper API host.")] = "127.0.0.1", + port: Annotated[int, typer.Option(help="Wrapper API port.")] = 8080, +) -> None: + """Reload sd-server with a new model.""" + import httpx # noqa: PLC0415 + + url = f"http://{host}:{port}/reload" + console.print(f"[cyan]Reloading model: {model}[/cyan]") + try: + resp = httpx.post(url, json={"model": model}, timeout=300) + resp.raise_for_status() + data = resp.json() + except httpx.HTTPError as e: + console.print(f"[red]Error: Reload failed at {url}: {e}[/red]") + raise typer.Exit(1) from e + + console.print(f"[green]{data.get('status', 'OK')}[/green]") + + @app.command() def serve( model: Annotated[str, typer.Option(help="Path to model file for sd-server.")], @@ -464,7 +516,7 @@ def main() -> int: # Handle legacy invocation: tsr -> tsr info if len(sys.argv) > 1 and not sys.argv[1].startswith("-"): arg = sys.argv[1] - if arg not in ("info", "search", "get", "dl", "download", "config", "generate", "serve") and ( + if arg not in ("info", "search", "get", "dl", "download", "config", "generate", "serve", "status", "reload") and ( arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists() ): sys.argv = [sys.argv[0], "info", *sys.argv[1:]]