From 7a3735adbad92712d4f64d505d856cc2bf618283 Mon Sep 17 00:00:00 2001 From: Adam Ladachowski Date: Tue, 3 Feb 2026 21:03:22 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AC=20Commit=20message:=20Update=20202?= =?UTF-8?q?6-02-03=2021:03:22,=202=20files,=20806=20lines?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 📁 Files changed: 2 📝 Lines changed: 806 • .coverage • tensors.py --- .coverage | Bin 53248 -> 53248 bytes tensors.py | 804 +++++++++++++++++++++++++++++++---------------------- 2 files changed, 469 insertions(+), 335 deletions(-) diff --git a/.coverage b/.coverage index 1d9f032775d221623ed27a198cb375aaa9883980..78fc8c4bced27403f6ca750009c47b13d4c638e0 100644 GIT binary patch literal 53248 zcmeI43v^V~y~h80&dk}f&odB0fCw3=5RjKZMP34;gzyG=Q%N!j6G%)lI5Po!ds5}8 zR8gs-LIqza@hzgYYPkZw@YU9~6oIQ%K`pIXy%o{i|DJs|2dwnguDezjU3(S@zi-bz zXZHE#obSwMB2y-htIi~1RmoIcJQM2&IY0;ukHr9h;g^kH@<9>8QsCd@7(Nj)=ymxl z+no-UJ`>zhyPez9nrD4t-ewJ8)68nxgZ08E)I&L-98eA@2mWIm@FrU{zinI6cv&W1 zQJcu5;+2V%_->zb-pI43j*Lw`d&IbrF)iT)H%>0C$-H=Yq<@YQc!@Cxr*ekm|r?Tp~xK$=G4|lO3Rb3ZPT@ z#nmUJt-*%;rE%lzn(G_tD$w$phPnm(&A~0i3ulri+VVu(W|IEF{?qTxcf6Ou@gDbf zO3V2>Dz<82s{Hin|l7hTmpDjCKjhJti_K`W$}hg zGONq*%`fYN->j=zL}`9uAz6B1_C(?KlvPw`(no^L#PZE9= z)z-$6RoP31ZyghF4xY`&UNgO=8Sxaq&|*n{BChn`dbaV!;u>NVwaJR(&$a9slTP3T zf>|J^RW;OC@=GPnkD)TTFp-MSOZ4)e6|wN|K9TU$>BIE=k%cEZb&j3_WqppK;kDIi zenA0gWZBcmf4BU<=y9izd?&vnM&(m=iFmppC4=N2g{0gfNJ)8(_|M?s^rU+J*###$ z9DiO}A0MI~4b%L*JmQHfg69(diX3+kd?dI8zI7w`=a;`Ao|zx}`?d37TzFAL&mWw3 zqJ{G%l=b1?^{aA1*(ER2vhLSm=<_e_xFz$EV98?>7xi2iuf;T35U)<*C5ss>?f@}9 zac&$Z$19Qzne6$?W%l&s_nu(AVh4ei-s7xp+39$BPfDiK)p%Rwx2^ObIT|aPA5ZfX zN*Bc{lF8acyk6WJaEsAte9h`~EW3ukzbg$3_L07b?_=J=qU@UF&4tJORJ`cq^@UJZ z7q6e|->>*>SnL(EpI@B(bwr5ZhOr{NDARa}erum{r~b1%vE=@HEkIBBUhCp>#pTZ5 zlwye=f*RMOm$r%#;0T#lRdE(#nZ9MR05y$@1@y=p=0yEy(+Pr5`0Ex zs5qK#McO|)clcXe)`x%REjh|R%6Wx^=6CN-yaxXY7U5Ku#!muuae1u13>|l7ee22K zB%6pSjUSPYJ-N8y$9Tx4cmulmjQF16D|D@5)~b>61gWg0}we0G!#YPrqLM3_PlkdjQ-6?&l}v0jfnS2b2TK z0p);lKslfsP!1>ulmp5E<$!WPIdIegOKVT$e+CE{S^@L_D}ea@|KEfA`=e78CE7w8 zOz%vhKAlXZ@gG^>KmRB0Zooh4p&U>SCulmp+#0S$jcz!pFMKLYNN z@3X0zs~k`cC|lmp5E z<$!WPIiMU+4k!nd1IhvAfO0@N@UJ+)fBvuf|G(m#+D18`98eA@2b2TK0p);lKslfs zP!1>ulmjP^1NifQQmEYk?&t1@?mqX|?u+ho?i20jrC?b(xj2=2_=kldWQFuywlC z*=lFGR@nT~JY>FW{?6QGZZn@SH=B2uYs{7Am1cukW0sjy&C%viv!B`3Y;U$Sb>pyc z$au$i%XrCn#(2cI+gNL?G8&DAMvYNwoM#jpgN$BAC!@7tuy5F>?0vSIy~?(+E$lwF zj;&@(*u|`#Rj}!73>(V&vM#J0b66<)MfAhy-sqdr9nq(v4@5Ua*F;xDFOR08Rnghe z3DFVJ{?XH-9ilCxIz3De(Rb)u^dE-bHVxJ9;ScMdZWC-pHGg9g(LZ z4@5RZ)1s zzAlk`7i$E`8<=;6%#& z@K9|S^J*pI_yvN`(uOjxM(|)hzgqAh&hrHi;9Etxk;a7zuxizVZFMrJwmdT0mp z&K3L`=W_(_;ygm|E1}nzceZ34A1?Ui&@0RvCiq1@f2iOWLN7D#EXg=NMDR{NXRzQM z8m?`S&!N|uH&F5}<_(aH`#)3ii_GgU`32^kA$ceBekgf6^ZE(i9>O_&CFAS%5xgpd zb50jLGPH_$y#?omMl!EgmPvC6{p>0D5NT#!55ez|L(J*0PTq8)B)Ud%+Y!>x!S*kLW)%0#ZiEwGk}ScjY8 zHYQRX@fIdR9d5$SM5e=ya5EE;j(8Iji4HfwjZ6eOtcDwy$aA^gG z-5Pps7HX4}=Qbg=b(DuTAvGz=6Pu8VjdXYx;uVx{y$^)5@snaP>V?w6Pp*)Harc$26 zgiP8>c?c6SaT4VTOvr?Zl*cb2<0nv_y@ZS#PkHbXGIkv0sY}S{v6M$HAth5N&s&7i zl!q-L!%HYnT0({mr#xl}88VO-Wnu6T+9eBv22-A{gbW-;d9)JJe-P!lN=U!{l!q!I zefv?KsDyOyOL?3U(ya^SSxQLPZj=WpAw^v&Pfm{l!qiCwnKSB5@Oku z$0H%8MR_(7VwjW%BO%P7JQWFvGRh;75IstH9ugALDGx(}_b$qlkdUnu%2NT)XJPogB5q-@=F-S!6H5b4K`15~K7;51C|DgM} z`-Z#S{ki+FdyjiNw*Tes6>fujfg5*cxZ~Y(u>1FRJG%vL3zs@yJD)iRvHAbTdD;1; z^MupnY;u=V_*!X{Kyk z`Pla-S?5{E@|s z8`Cnu_^WZi*kinIJa23@HXA=PZZwt}7aMg(xiJNs{y?LLQE0S8^Z&*^Vtd&e?0NPC z_Waw~b?nD%A**KRvx#g3JA-v)1=#XQ^smtmqPwHJqR&Mijoux-C3;Qt3hela=*;NY z=+J1N=&8~Cs6oG`f1&SU!+()JO&_Ei=uLDPy@b}&IGsvIVZZN5V>FlQ`j`60`abt}a$lA!Ykt-u< zZ1%Gv<08W%eIuPCZ6ap)oA4*$KZbvcz5W;Bhr%1fYr@OJmxhz!itw~>Nq8`}`i|jN zVX7V0{-o{Kex>cup42vL>$TO|)!Jh10xeA{c3xshfu$E?!S?`LFEOOR+Dn)e*n3G-3M{^aN`cLn=z(Y=kw9!F;XrI6S|IKv zp+MY)5#_I7owk8MAl5^3AZ~|mq`(RcUk73x9FYP`FnlEiwqQ6c1=e8rQVQ(B@HZ*2 z2*dwKflU~`kOHePd>)86{525e@L3?r;L|{q!Y6^41AmbM3o-mz3T(u1ND8dP@TWiw zgpZ}bQVbtSfvp%mlmcrp9FzikF&qfQaQGk)!{Gfu42AasF&O?Bh(YjfAO^rYQeZuX z{|>}J*e?YZWY{MKHe}c<1y*EuI}o|BM+z*-@P|O;!fq+BCd2Qgz@7}hlLCt}{8kEV z%J7yHSe4;7QeanxUrT{y8GaRrTzFFotjq95AUeY9fr!CtfhdGm1JMC?1)>062}B!s zIS_g9QXpEvi-BkfFGzvK8FmID7k2o9S7+ER1$Jk6UJ5MF@Ly73dxmX+*a5$k0{b&O zCj}O0cvcE*(C|zkcEB&Bzzz*hOMxXCetuMhUWcawu?wCI#4E5h5HG?Lfp`J71Y##V z9*FJmm=sv1;n6_sfJdajIt>p?fqfbtk^&1gJm?GFsNn%Auu?;l6xgX@vlLjW;eIKw zRl|K!V6BEtQedx!djqiq?vVnUHQX%)R%^IR3hdT!XCPL>#z3rq4S`q=cLZV?{45Yl zVZ9XCuwk7PSh3-DDX?S1ZGk9;TcyC34Qm5Y47W&uJsWNgL^1qS3T)c2CJ@DNlN8vs z;l@A|!wpgtljX2F5aZ$cK#YUy0x=eT5{NOdDiEXL+CY@RHGvofE2Y5J4J)L;+6~Kn z!Fx9>lLCu3ER_PAH!L|SaN^a0I09Emf#n;1ECsf2Xp{o$H+WKD|As5G9}hx=H}ESY zV+FrF;DeV*#u9$1WNhJ=NX8m|v1IJw7fHqv&2s_VEiPV+_*7pv%BQ;eVo7U8`c#`dS5ne{hO@n=!HG7axGiMF-X~xU}K24u-rcYC+ z_xEYaoHHa%{h?2jw)XRB;-tPlO_E2#a(J4N~x)k`-I@Zpoyw+`fYL(Z{@t?r++ zy!%HaWBCtD#`dpC#`+)PCw3^rc>h=b|Nr-^u4=Y&KslfsP!1>ulmp5E<$!WPIiMU+ L4k!msMhE^E29YIL literal 53248 zcmeI44QyP+oxo>y9na4Dbg-R`<0gJi;@ZYvbrQ$ml!ic35(toBn|#M!uh;g*_HMHK z{FT@an*`Dn2!$5f&~lW1NvH&@-l4YY^&~_~rJkA!q3VunPo>pWS6U9#p15AGMY;dH znemH(^h$M7o9e%^H^2XTZ{E)Q-pv2)-|M(@=k~Op@&>ZG(S+}hnzIe?y=oFiB-ItxCiZJZj&`={mT4|bq((})2xvi`TX3hV5HCxy`FHJa{xUd7Hdk zcEUTH%6S9nk(8It^rw>vKb;x${GpUSdMuyn7xygK26ruSU~|x3P*1uaTKTC#IK|;y zdNh%n@`h7Wjllux)%?_?FE_wRQt82rIKrz_&Fj5fY9N(MWs<49>_uI=zuu@ndE8*} z(o%Xl6^tp7OTw4AU^K8H9LlxleN6%Ox_&p;W%nOU#cjnT38%`f{|n>}ZRh z^+wW}V3?VF+D~UQUTQLx9P?BC?>fLnAx%F(-ld~~bM=Y=^~#p&wJxy}Hh!0u>TuO~ z06Le+YOP^I@zT&Z=yPUlv=5XY8XG+%f?_BnQWsc1_3jm#w?CsTS^Jc%G6~3$Oyc3>Pd|G*&yzL zi{H}xfaZrdqf9aMW=b1v z<+LDIgvd+cd(nbHh%Mm~c-xKOkMI8Bgg@lHv+rUvbi8&^aePDhB?*N`Sus;C@C;GBuelUMU z7)(!b?+N=WPY}l1JsedZ91hESM>d~N!);aEw(=c%t5-Xe$crqLul4$}*^yKtBkv8+ zVp|?wGoAN>KD_g;v@twKeoCBUaJ)9?lfJo#G~We_POmS3(a}VvU)`_7ZCIWaa=%!d z;&o&afref!EXq7A(YKvbpH%G?8B3qP`7k&UXB|!S%jGUE$|Yx<7)rh4bn<^qAcFfKnr7P1*k+MH;Xo(dDlGwUE{l8m+8QzXQ@0vbRAXaEhM0W^RH&;S}h184vZpaC>+r5LE8QL?k8I}d+& zp(Qz#ZrPE`9!e$sd`o`7Z}C%^d^VSdKe8Y`|EKQ15ct9e8bAYR01co4G=K)s02)98 zXaEhM0W@&M8K{Zw;FH+9JWCNXaEhM0W^RH&;S}h184vZ zpaC>++yR+_?JK`qYz3xtT ztGm%%?>4!s-AeZ=m%GIIjq|$mGv`OnOU`$l=bW!O&p2Ol9&3+6Y>IrEF=qviwVDRauqm`QWDd8>JY*={zP)n>WL zjkkX*^~;Xxwc~8KXwR=r*<)*BPxwjZtCP2IX(^SNXs47x}mO z*ZFz=1b-M5;sXt!0W^RH&;S}h1Mks5tWdp(^9YTexX25hz(_O_J;4iA0Y~15MtEU` zYp?MWA zaLp@uA*Q*S7no*`7m77k@xmg_D|n$u^0K#HUkyr#cEHO%p_#usirzJ+0H%+#nC=2b;1v zwlRQAmVGFI;dAV|01lmF9}FNp%r;0k#5w{P%&=<%=$~WP1dyC#?E&;9+4=wyeXK2j z-hQ@DLV~Rg;J_SP6Ttpn)+*ru`#=Eu_Oq4%_U>cN0qoh!ngZCphcyP!J;xdX=-SP$ zme9@W1L*8xbph-cWVHcw>|pN?p#1`?31EF2TOB~#dbUc!1-3GPb?r>#QA*dfF_A?n zUDLrt3Z=BUfr9zbIg6Pc6JhBZv2O-kz;naG)x*3>bPFezPK!$h{Ew0boY zsglyFY9{g|rR7yjBuPrk%9+TJl$Mq;ksc{sT*^dlq_ku)6N!;hw}gqTNU7s8krFAj z9VYT2rIyV^GNjbBn8<{b8YUBIkWy|ikpn4>aV8QVrNuEOvLB_9VkS}_r57Vi1^T$+h=N9kFg%UpMHf}G(p)xGsvWQq9v|B2WyME(D&`(yV9?mxQE zxnFh9yPt>Z{{i=I_n3Rc9dr-4o$fZM{oC9Ix7v-nwi|VR>%0M#{|}uPo$ol`bpF

V{|VLp56$nJ&zoO0zid8ke%w559y7C0>+dpeF*llP&04e4 zbj_&o8{?P8kByg%?-<`Oo-saeJYt+Nrj4x82bKO-V}sFbR2yXmw&fz$2%_po-iR1}n(jI)K`noYDal-{hnY zp!y~!!XQoxI)M6{+!b6nxN`Bx(&L(;1fSImHTaBXD8i>TLlu5ZGnC;+HA5Xfr5OtG zNzG7+PlRlbhir|7Z2BP^`H*=|v8cz7XoiCPux6;pvznnK&uE64d{i?O}$bK2KB|nMbW-p(e&|KG|k+s=;-t|MU(zkMH5GF(lq&D zMdK4&)TyC(->f$J<2PzLdxP2y_4_8ZF`K(yQ6{@l(eSwsDLQoSIz{Q>4{AEJLD68Q zLs9?SwThB+*C^^swkt~Xtyk3B-=-?5bC^W3Wz9$BtS>?HAs!Xnk9aqPF#`HCNbR z)^vD^nj5MVH8rnL)Y!CKQNx<|DXMF%R8&*9OwsC^3Pshcmny2Njw>p!TB4||yj)Rf zS(&26rKO5W7GI^vEm^F{aZ420j;qMB9Yv;PD>6(=5jRXlF>WX-j&Vg1Ksjfj@>g_W zU9qA$_-U5hdoEq1=&T=AlV6nokzjp9xTyc}|NsBRy-~pqp#e022G9T+Km%w14WI!u NfCkV28u))V@V~yaXd(ar diff --git a/tensors.py b/tensors.py index 9c148c8..29402a6 100644 --- a/tensors.py +++ b/tensors.py @@ -1,11 +1,10 @@ #!/usr/bin/env python3 """ -sft-get: Read safetensor metadata and fetch CivitAI model information. +tsr: Read safetensor metadata, search and download CivitAI models. """ from __future__ import annotations -import argparse import hashlib import json import os @@ -13,10 +12,12 @@ import re import struct import sys import tomllib +from enum import Enum from pathlib import Path -from typing import Any +from typing import Annotated, Any import httpx +import typer from rich.console import Console from rich.progress import ( BarColumn, @@ -30,8 +31,21 @@ from rich.progress import ( ) from rich.table import Table +# ============================================================================ +# App and Console Setup +# ============================================================================ + +app = typer.Typer( + name="tsr", + help="Read safetensor metadata, search and download CivitAI models.", + no_args_is_help=True, +) console = Console() +# ============================================================================ +# Configuration +# ============================================================================ + # XDG Base Directory spec: ~/.config/tensors/config.toml CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config")) / "tensors" CONFIG_FILE = CONFIG_DIR / "config.toml" @@ -46,6 +60,48 @@ DEFAULT_PATHS: dict[str, Path] = { "LoCon": Path.home() / ".xm" / "models" / "loras", } +CIVITAI_API_BASE = "https://civitai.com/api/v1" +CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models" + + +# ============================================================================ +# Enums for CLI +# ============================================================================ + + +class ModelType(str, Enum): + """CivitAI model types.""" + + checkpoint = "Checkpoint" + lora = "LORA" + embedding = "TextualInversion" + vae = "VAE" + controlnet = "Controlnet" + locon = "LoCon" + + +class BaseModel(str, Enum): + """Common base models.""" + + sd15 = "SD 1.5" + sdxl = "SDXL 1.0" + pony = "Pony" + flux = "Flux.1 D" + illustrious = "Illustrious" + + +class SortOrder(str, Enum): + """Sort options for search.""" + + downloads = "Most Downloaded" + rating = "Highest Rated" + newest = "Newest" + + +# ============================================================================ +# Config Functions +# ============================================================================ + def load_config() -> dict[str, Any]: """Load configuration from TOML config file.""" @@ -107,8 +163,9 @@ def get_default_output_path(model_type: str | None) -> Path | None: return None -CIVITAI_API_BASE = "https://civitai.com/api/v1" -CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models" +# ============================================================================ +# Safetensor Functions +# ============================================================================ def read_safetensor_metadata(file_path: Path) -> dict[str, Any]: @@ -169,17 +226,36 @@ def compute_sha256(file_path: Path) -> str: return sha256.hexdigest().upper() +def get_base_name(file_path: Path) -> str: + """Get base filename without .safetensors extension.""" + name = file_path.name + for ext in (".safetensors", ".sft"): + if name.lower().endswith(ext): + return name[: -len(ext)] + return file_path.stem + + +# ============================================================================ +# CivitAI API Functions +# ============================================================================ + + +def _get_headers(api_key: str | None) -> dict[str, str]: + """Get headers for CivitAI API requests.""" + headers: dict[str, str] = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + return headers + + def fetch_civitai_model_version( version_id: int, api_key: str | None = None ) -> dict[str, Any] | None: """Fetch model version information from CivitAI by version ID.""" url = f"{CIVITAI_API_BASE}/model-versions/{version_id}" - headers: dict[str, str] = {} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" try: - response = httpx.get(url, headers=headers, timeout=30.0) + response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) if response.status_code == 404: return None response.raise_for_status() @@ -196,9 +272,6 @@ def fetch_civitai_model_version( def fetch_civitai_model(model_id: int, api_key: str | None = None) -> dict[str, Any] | None: """Fetch model information from CivitAI by model ID.""" url = f"{CIVITAI_API_BASE}/models/{model_id}" - headers: dict[str, str] = {} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" with Progress( SpinnerColumn(), @@ -209,7 +282,7 @@ def fetch_civitai_model(model_id: int, api_key: str | None = None) -> dict[str, progress.add_task("[cyan]Fetching model from CivitAI...", total=None) try: - response = httpx.get(url, headers=headers, timeout=30.0) + response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) if response.status_code == 404: return None response.raise_for_status() @@ -226,9 +299,6 @@ def fetch_civitai_model(model_id: int, api_key: str | None = None) -> dict[str, def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[str, Any] | None: """Fetch model information from CivitAI by SHA256 hash.""" url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}" - headers: dict[str, str] = {} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" with Progress( SpinnerColumn(), @@ -239,7 +309,7 @@ def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[ progress.add_task("[cyan]Fetching from CivitAI...", total=None) try: - response = httpx.get(url, headers=headers, timeout=30.0) + response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) if response.status_code == 404: return None response.raise_for_status() @@ -253,16 +323,77 @@ def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[ return None +def search_civitai( + query: str | None = None, + model_type: ModelType | None = None, + base_model: BaseModel | None = None, + sort: SortOrder = SortOrder.downloads, + limit: int = 20, + api_key: str | None = None, +) -> dict[str, Any] | None: + """Search CivitAI models.""" + params: dict[str, Any] = { + "limit": min(limit, 100), + "nsfw": "true", + } + + # API quirk: query + filters don't work reliably together + # If we have filters, skip query and filter client-side + has_filters = model_type is not None or base_model is not None + + if query and not has_filters: + params["query"] = query + + if model_type: + params["types"] = model_type.value + + if base_model: + params["baseModels"] = base_model.value + + params["sort"] = sort.value + + # Request more if we need client-side filtering + if query and has_filters: + params["limit"] = 100 + + url = f"{CIVITAI_API_BASE}/models" + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + transient=True, + ) as progress: + progress.add_task("[cyan]Searching CivitAI...", total=None) + + try: + response = httpx.get(url, params=params, headers=_get_headers(api_key), timeout=30.0) + response.raise_for_status() + result: dict[str, Any] = response.json() + + # Client-side filtering when query + filters combined + if query and has_filters: + q_lower = query.lower() + result["items"] = [ + m for m in result.get("items", []) if q_lower in m.get("name", "").lower() + ][:limit] + + return result + except httpx.HTTPStatusError as e: + console.print(f"[red]API error: {e.response.status_code}[/red]") + return None + except httpx.RequestError as e: + console.print(f"[red]Request error: {e}[/red]") + return None + + def download_model( version_id: int, dest_path: Path, api_key: str | None = None, resume: bool = True, ) -> bool: - """Download a model from CivitAI by version ID with resume support. - - Returns True on success, False on failure. - """ + """Download a model from CivitAI by version ID with resume support.""" url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}" params: dict[str, str] = {} if api_key: @@ -286,20 +417,17 @@ def download_model( params=params, headers=headers, follow_redirects=True, - timeout=httpx.Timeout(30.0, read=None), # No read timeout for large files + timeout=httpx.Timeout(30.0, read=None), ) as response: - # Handle 416 Range Not Satisfiable (file already complete) if response.status_code == 416: console.print("[green]File already fully downloaded.[/green]") return True response.raise_for_status() - # Get total size from Content-Length or Content-Range content_length = response.headers.get("content-length") total_size = int(content_length) + initial_size if content_length else 0 - # Get filename from Content-Disposition if available content_disp = response.headers.get("content-disposition", "") if "filename=" in content_disp: match = re.search(r'filename="?([^";\n]+)"?', content_disp) @@ -323,7 +451,7 @@ def download_model( ) with dest_path.open(mode) as f: - for chunk in response.iter_bytes(1024 * 1024): # 1MB chunks + for chunk in response.iter_bytes(1024 * 1024): f.write(chunk) progress.update(task, advance=len(chunk)) @@ -340,6 +468,29 @@ def download_model( return False +# ============================================================================ +# Display Functions +# ============================================================================ + + +def _format_size(size_kb: float) -> str: + """Format size in KB to human-readable string.""" + if size_kb < 1024: + return f"{size_kb:.0f} KB" + if size_kb < 1024 * 1024: + return f"{size_kb / 1024:.1f} MB" + return f"{size_kb / 1024 / 1024:.2f} GB" + + +def _format_count(count: int) -> str: + """Format large numbers with K/M suffix.""" + if count < 1000: + return str(count) + if count < 1_000_000: + return f"{count / 1000:.1f}K" + return f"{count / 1_000_000:.1f}M" + + def _display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str) -> None: """Display file information table.""" file_table = Table(title="File Information", show_header=True, header_style="bold magenta") @@ -398,24 +549,18 @@ def _display_civitai_data(civitai_data: dict[str, Any] | None) -> None: civit_table.add_row("Base Model", str(civitai_data.get("baseModel", "N/A"))) civit_table.add_row("Created At", str(civitai_data.get("createdAt", "N/A"))) - # Trained words trained_words: list[str] = civitai_data.get("trainedWords", []) if trained_words: civit_table.add_row("Trigger Words", ", ".join(trained_words)) - # Download URL download_url = str(civitai_data.get("downloadUrl", "N/A")) civit_table.add_row("Download URL", download_url) - # File info from CivitAI files: list[dict[str, Any]] = civitai_data.get("files", []) for f in files: if f.get("primary"): civit_table.add_row("Primary File", str(f.get("name", "N/A"))) - civit_table.add_row( - "File Size (CivitAI)", - f"{f.get('sizeKB', 0) / 1024:.2f} MB", - ) + civit_table.add_row("File Size (CivitAI)", _format_size(f.get("sizeKB", 0))) meta: dict[str, Any] = f.get("metadata", {}) if meta: civit_table.add_row("Format", str(meta.get("format", "N/A"))) @@ -425,7 +570,6 @@ def _display_civitai_data(civitai_data: dict[str, Any] | None) -> None: console.print() console.print(civit_table) - # Model page link model_id = civitai_data.get("modelId") if model_id: console.print() @@ -436,7 +580,6 @@ def _display_civitai_data(civitai_data: dict[str, Any] | None) -> None: def _display_model_info(model_data: dict[str, Any]) -> None: """Display full CivitAI model information.""" - # Main model info table model_table = Table(title="Model Information", show_header=True, header_style="bold magenta") model_table.add_column("Property", style="cyan") model_table.add_column("Value", style="green", max_width=80) @@ -446,17 +589,14 @@ def _display_model_info(model_data: dict[str, Any]) -> None: model_table.add_row("Type", str(model_data.get("type", "N/A"))) model_table.add_row("NSFW", str(model_data.get("nsfw", False))) - # Creator info creator = model_data.get("creator", {}) if creator: model_table.add_row("Creator", str(creator.get("username", "N/A"))) - # Tags tags: list[str] = model_data.get("tags", []) if tags: model_table.add_row("Tags", ", ".join(tags[:10]) + ("..." if len(tags) > 10 else "")) - # Stats stats: dict[str, Any] = model_data.get("stats", {}) if stats: model_table.add_row("Downloads", f"{stats.get('downloadCount', 0):,}") @@ -465,7 +605,6 @@ def _display_model_info(model_data: dict[str, Any]) -> None: "Rating", f"{stats.get('rating', 0):.1f} ({stats.get('ratingCount', 0)} ratings)" ) - # Mode (archived/taken down) mode = model_data.get("mode") if mode: model_table.add_row("Status", str(mode)) @@ -473,7 +612,6 @@ def _display_model_info(model_data: dict[str, Any]) -> None: console.print() console.print(model_table) - # Versions table versions: list[dict[str, Any]] = model_data.get("modelVersions", []) if versions: ver_table = Table(title="Model Versions", show_header=True, header_style="bold magenta") @@ -488,15 +626,12 @@ def _display_model_info(model_data: dict[str, Any]) -> None: primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) file_info = "" if primary_file: - size_kb = primary_file.get("sizeKB", 0) - size_str = ( - f"{size_kb / 1024:.0f} MB" - if size_kb < 1024 * 1024 - else f"{size_kb / 1024 / 1024:.1f} GB" + file_info = ( + f"{primary_file.get('name', 'N/A')} " + f"({_format_size(primary_file.get('sizeKB', 0))})" ) - file_info = f"{primary_file.get('name', 'N/A')} ({size_str})" - created = str(ver.get("createdAt", "N/A"))[:10] # Just date portion + created = str(ver.get("createdAt", "N/A"))[:10] ver_table.add_row( str(ver.get("id", "N/A")), str(ver.get("name", "N/A")), @@ -508,7 +643,6 @@ def _display_model_info(model_data: dict[str, Any]) -> None: console.print() console.print(ver_table) - # Model page link model_id = model_data.get("id") if model_id: console.print() @@ -517,82 +651,94 @@ def _display_model_info(model_data: dict[str, Any]) -> None: ) -def display_results( - file_path: Path, - local_metadata: dict[str, Any], - sha256_hash: str, - civitai_data: dict[str, Any] | None, +def _display_search_results(results: dict[str, Any]) -> None: + """Display search results in a table.""" + items = results.get("items", []) + if not items: + console.print("[yellow]No results found.[/yellow]") + return + + table = Table(show_header=True, header_style="bold magenta") + table.add_column("ID", style="cyan", justify="right") + table.add_column("Name", style="green", max_width=40) + table.add_column("Type", style="yellow") + table.add_column("Base", style="blue") + table.add_column("Size", justify="right") + table.add_column("DLs", justify="right") + table.add_column("Rating", justify="right") + + for model in items: + model_id = str(model.get("id", "")) + name = model.get("name", "N/A") + if len(name) > 40: + name = name[:37] + "..." + model_type = model.get("type", "N/A") + + # Get latest version info + versions = model.get("modelVersions", []) + base_model = "N/A" + size = "N/A" + if versions: + latest = versions[0] + base_model = latest.get("baseModel", "N/A") + files = latest.get("files", []) + primary = next((f for f in files if f.get("primary")), files[0] if files else None) + if primary: + size = _format_size(primary.get("sizeKB", 0)) + + stats = model.get("stats", {}) + downloads = _format_count(stats.get("downloadCount", 0)) + rating = f"{stats.get('rating', 0):.1f}" + + table.add_row(model_id, name, model_type, base_model, size, downloads, rating) + + console.print() + console.print(table) + + metadata = results.get("metadata", {}) + total = metadata.get("totalItems", len(items)) + console.print(f"\n[dim]Showing {len(items)} of {total:,} results[/dim]") + console.print("[dim]Use 'tsr get ' to view details or 'tsr dl -m ' to download[/dim]") + + +# ============================================================================ +# CLI Commands +# ============================================================================ + + +@app.command() +def info( + file: Annotated[Path, typer.Argument(help="Path to the safetensor file")], + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, + skip_civitai: Annotated[ + bool, typer.Option("--skip-civitai", help="Skip CivitAI API lookup") + ] = False, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, + save_to: Annotated[ + Path | None, typer.Option("--save-to", help="Save metadata to directory") + ] = None, ) -> None: - """Display results in rich tables.""" - _display_file_info(file_path, local_metadata, sha256_hash) - _display_local_metadata(local_metadata) - _display_civitai_data(civitai_data) - - -def get_base_name(file_path: Path) -> str: - """Get base filename without .safetensors extension.""" - name = file_path.name - for ext in (".safetensors", ".sft"): - if name.lower().endswith(ext): - return name[: -len(ext)] - return file_path.stem - - -def save_metadata( - file_path: Path, - sha256_hash: str, - local_metadata: dict[str, Any], - civitai_data: dict[str, Any] | None, - output_dir: Path, -) -> tuple[Path, Path]: - """Save metadata JSON and SHA256 hash to the specified output directory.""" - base_name = get_base_name(file_path) - - # Save JSON metadata - json_path = output_dir / f"{base_name}-xm.json" - output = { - "file": str(file_path), - "sha256": sha256_hash, - "header_size": local_metadata["header_size"], - "tensor_count": local_metadata["tensor_count"], - "metadata": local_metadata["metadata"], - "civitai": civitai_data, - } - json_path.write_text(json.dumps(output, indent=2)) - - # Save SHA256 hash - sha_path = output_dir / f"{base_name}-xm.sha256" - sha_path.write_text(f"{sha256_hash} {file_path.name}\n") - - return json_path, sha_path - - -def cmd_info(args: argparse.Namespace) -> int: - """Handle the info subcommand (default behavior).""" - file_path: Path = args.file.resolve() + """Read safetensor metadata and fetch CivitAI info.""" + file_path = file.resolve() if not file_path.exists(): console.print(f"[red]Error: File not found: {file_path}[/red]") - return 1 + raise typer.Exit(1) if file_path.suffix.lower() not in (".safetensors", ".sft"): console.print("[yellow]Warning: File does not have .safetensors extension[/yellow]") try: - # Read local metadata console.print(f"[bold]Reading safetensor file:[/bold] {file_path.name}") local_metadata = read_safetensor_metadata(file_path) - - # Compute SHA256 sha256_hash = compute_sha256(file_path) - # Fetch from CivitAI civitai_data = None - if not args.skip_civitai: - api_key = args.api_key or load_api_key() - civitai_data = fetch_civitai_by_hash(sha256_hash, api_key) + if not skip_civitai: + key = api_key or load_api_key() + civitai_data = fetch_civitai_by_hash(sha256_hash, key) - if args.json_output: + if json_output: output = { "file": str(file_path), "sha256": sha256_hash, @@ -603,296 +749,284 @@ def cmd_info(args: argparse.Namespace) -> int: } console.print_json(data=output) else: - display_results(file_path, local_metadata, sha256_hash, civitai_data) + _display_file_info(file_path, local_metadata, sha256_hash) + _display_local_metadata(local_metadata) + _display_civitai_data(civitai_data) + + if save_to: + output_dir = save_to.resolve() + if not output_dir.exists() or not output_dir.is_dir(): + console.print(f"[red]Error: Invalid directory: {output_dir}[/red]") + raise typer.Exit(1) + + base_name = get_base_name(file_path) + json_path = output_dir / f"{base_name}-xm.json" + sha_path = output_dir / f"{base_name}-xm.sha256" + + output = { + "file": str(file_path), + "sha256": sha256_hash, + "header_size": local_metadata["header_size"], + "tensor_count": local_metadata["tensor_count"], + "metadata": local_metadata["metadata"], + "civitai": civitai_data, + } + json_path.write_text(json.dumps(output, indent=2)) + sha_path.write_text(f"{sha256_hash} {file_path.name}\n") - # Save files if requested - if args.save_to: - output_dir: Path = args.save_to.resolve() - if not output_dir.exists(): - console.print(f"[red]Error: Output directory not found: {output_dir}[/red]") - return 1 - if not output_dir.is_dir(): - console.print(f"[red]Error: Not a directory: {output_dir}[/red]") - return 1 - json_path, sha_path = save_metadata( - file_path, sha256_hash, local_metadata, civitai_data, output_dir - ) console.print() console.print(f"[green]Saved:[/green] {json_path}") console.print(f"[green]Saved:[/green] {sha_path}") - return 0 - except ValueError as e: console.print(f"[red]Error reading safetensor: {e}[/red]") - return 1 - except Exception as e: - console.print(f"[red]Unexpected error: {e}[/red]") - return 1 + raise typer.Exit(1) from e + + +@app.command() +def search( + query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None, + model_type: Annotated[ + ModelType | None, typer.Option("-t", "--type", help="Model type filter") + ] = None, + base: Annotated[ + BaseModel | None, typer.Option("-b", "--base", help="Base model filter") + ] = None, + sort: Annotated[ + SortOrder, typer.Option("-s", "--sort", help="Sort order") + ] = SortOrder.downloads, + limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, +) -> None: + """Search CivitAI models.""" + key = api_key or load_api_key() + + results = search_civitai( + query=query, + model_type=model_type, + base_model=base, + sort=sort, + limit=limit, + api_key=key, + ) + + if not results: + console.print("[red]Search failed.[/red]") + raise typer.Exit(1) + + if json_output: + console.print_json(data=results) + else: + _display_search_results(results) + + +@app.command() +def get( + id_value: Annotated[int, typer.Argument(help="CivitAI model ID or version ID")], + version: Annotated[ + bool, typer.Option("-v", "--version", help="Treat ID as version ID instead of model ID") + ] = False, + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Fetch model information from CivitAI by model ID or version ID.""" + key = api_key or load_api_key() + + if version: + # Fetch by version ID + version_data = fetch_civitai_model_version(id_value, key) + if not version_data: + console.print(f"[red]Error: Version {id_value} not found on CivitAI.[/red]") + raise typer.Exit(1) + + if json_output: + console.print_json(data=version_data) + else: + _display_civitai_data(version_data) + else: + # Fetch by model ID + model_data = fetch_civitai_model(id_value, key) + if not model_data: + console.print(f"[red]Error: Model {id_value} not found on CivitAI.[/red]") + raise typer.Exit(1) + + if json_output: + console.print_json(data=model_data) + else: + _display_model_info(model_data) def _resolve_version_id( version_id: int | None, - sha256_hash: str | None, + hash_val: str | None, model_id: int | None, api_key: str | None, ) -> int | None: - """Resolve version ID from hash or model ID if needed.""" + """Resolve version ID from hash or model ID.""" if version_id: return version_id - if sha256_hash: - console.print(f"[cyan]Looking up model by hash: {sha256_hash[:16]}...[/cyan]") - civitai_data = fetch_civitai_by_hash(sha256_hash.upper(), api_key) + + if hash_val: + console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]") + civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key) if not civitai_data: console.print("[red]Error: Model not found on CivitAI for this hash.[/red]") return None - vid = civitai_data.get("id") + vid: int | None = civitai_data.get("id") if vid: - console.print(f"[green]Found model version:[/green] {civitai_data.get('name', 'N/A')}") - else: - console.print("[red]Error: Could not determine version ID from CivitAI response.[/red]") + console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}") return vid + if model_id: console.print(f"[cyan]Looking up model {model_id}...[/cyan]") model_data = fetch_civitai_model(model_id, api_key) if not model_data: - console.print(f"[red]Error: Model {model_id} not found on CivitAI.[/red]") + console.print(f"[red]Error: Model {model_id} not found.[/red]") return None - versions: list[dict[str, Any]] = model_data.get("modelVersions", []) + versions = model_data.get("modelVersions", []) if not versions: console.print("[red]Error: Model has no versions.[/red]") return None - # First version is the latest latest = versions[0] - vid = latest.get("id") - if vid: - console.print( - f"[green]Found latest version:[/green] {latest.get('name', 'N/A')} (ID: {vid})" - ) - return vid + latest_vid: int | None = latest.get("id") + if latest_vid: + name = latest.get("name", "N/A") + console.print(f"[green]Found latest:[/green] {name} (ID: {latest_vid})") + return latest_vid + return None -def cmd_download(args: argparse.Namespace) -> int: - """Handle the download subcommand.""" - api_key: str | None = args.api_key or load_api_key() - - # Resolve version ID from hash or model ID if needed - version_id = _resolve_version_id( - args.version_id, args.hash, getattr(args, "model_id", None), api_key - ) - if not version_id: - if not args.version_id and not args.hash and not getattr(args, "model_id", None): - console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") - return 1 - - # Fetch version info to get filename and model type - console.print(f"[cyan]Fetching model info for version {version_id}...[/cyan]") - version_info = fetch_civitai_model_version(version_id, api_key) - - if not version_info: - console.print("[red]Error: Could not fetch model version info.[/red]") - return 1 - - # Determine model type for default path - model_type: str | None = version_info.get("model", {}).get("type") - - # Determine output directory - if args.output is None: - # Use model type-based default - output_dir = get_default_output_path(model_type) +def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Path | None: + """Prepare output directory for download.""" + if output is None: + output_dir = get_default_output_path(model_type_str) if output_dir is None: console.print( - f"[red]Error: No default path for model type '{model_type}'. " + f"[red]Error: No default path for type '{model_type_str}'. " "Use --output to specify.[/red]" ) - return 1 - console.print(f"[dim]Using default path for {model_type}: {output_dir}[/dim]") + return None + console.print(f"[dim]Using default path for {model_type_str}: {output_dir}[/dim]") else: - output_dir = args.output.resolve() + output_dir = output.resolve() - # Create directory if it doesn't exist if not output_dir.exists(): console.print(f"[cyan]Creating directory: {output_dir}[/cyan]") output_dir.mkdir(parents=True, exist_ok=True) elif not output_dir.is_dir(): console.print(f"[red]Error: Not a directory: {output_dir}[/red]") - return 1 + return None + + return output_dir + + +@app.command("dl") +def download( + version_id: Annotated[ + int | None, typer.Option("-v", "--version-id", help="Model version ID") + ] = None, + model_id: Annotated[ + int | None, typer.Option("-m", "--model-id", help="Model ID (downloads latest)") + ] = None, + hash_val: Annotated[ + str | None, typer.Option("-H", "--hash", help="SHA256 hash to look up") + ] = None, + output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None, + no_resume: Annotated[ + bool, typer.Option("--no-resume", help="Don't resume partial downloads") + ] = False, + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, +) -> None: + """Download a model from CivitAI.""" + key = api_key or load_api_key() + + resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key) + if not resolved_version_id: + if not version_id and not hash_val and not model_id: + console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") + raise typer.Exit(1) + + console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]") + version_info = fetch_civitai_model_version(resolved_version_id, key) + if not version_info: + console.print("[red]Error: Could not fetch model version info.[/red]") + raise typer.Exit(1) + + model_type_str: str | None = version_info.get("model", {}).get("type") + output_dir = _prepare_download_dir(output, model_type_str) + if not output_dir: + raise typer.Exit(1) - # Find primary file or first file files: list[dict[str, Any]] = version_info.get("files", []) primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) - if not primary_file: - console.print("[red]Error: No files found for this model version.[/red]") - return 1 + console.print("[red]Error: No files found for this version.[/red]") + raise typer.Exit(1) - filename = primary_file.get("name", f"model-{version_id}.safetensors") + filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors") dest_path = output_dir / filename - # Display model info - model_table = Table(title="Model Download", show_header=True, header_style="bold magenta") - model_table.add_column("Property", style="cyan") - model_table.add_column("Value", style="green") - model_table.add_row("Version", version_info.get("name", "N/A")) - model_table.add_row("Base Model", version_info.get("baseModel", "N/A")) - model_table.add_row("File", filename) - model_table.add_row("Size", f"{primary_file.get('sizeKB', 0) / 1024:.2f} MB") - model_table.add_row("Destination", str(dest_path)) + table = Table(title="Model Download", show_header=True, header_style="bold magenta") + table.add_column("Property", style="cyan") + table.add_column("Value", style="green") + table.add_row("Version", version_info.get("name", "N/A")) + table.add_row("Base Model", version_info.get("baseModel", "N/A")) + table.add_row("File", filename) + table.add_row("Size", _format_size(primary_file.get("sizeKB", 0))) + table.add_row("Destination", str(dest_path)) console.print() - console.print(model_table) + console.print(table) console.print() - # Download - success = download_model(version_id, dest_path, api_key, resume=not args.no_resume) - return 0 if success else 1 + success = download_model(resolved_version_id, dest_path, key, resume=not no_resume) + if not success: + raise typer.Exit(1) -def cmd_get(args: argparse.Namespace) -> int: - """Handle the get subcommand - fetch model info by ID.""" - model_id: int = args.model_id - api_key: str | None = args.api_key or load_api_key() +@app.command() +def config( + show: Annotated[bool, typer.Option("--show", help="Show current config")] = False, + set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None, +) -> None: + """Manage configuration.""" + if set_key: + cfg = load_config() + if "api" not in cfg: + cfg["api"] = {} + cfg["api"]["civitai_key"] = set_key + save_config(cfg) + console.print(f"[green]API key saved to {CONFIG_FILE}[/green]") + return - model_data = fetch_civitai_model(model_id, api_key) + if show or (not set_key): + console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}") + console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}") - if not model_data: - console.print(f"[red]Error: Model {model_id} not found on CivitAI.[/red]") - return 1 + key = load_api_key() + if key: + masked = key[:4] + "..." + key[-4:] if len(key) > 8 else "***" + console.print(f"[bold]API key:[/bold] {masked}") + else: + console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]") - if args.json_output: - console.print_json(data=model_data) - else: - _display_model_info(model_data) - - return 0 + console.print() + console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]") def main() -> int: """Main entry point.""" - parser = argparse.ArgumentParser( - description="Read safetensor metadata and download CivitAI models.", - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - subparsers = parser.add_subparsers(dest="command", help="Commands") + # Handle legacy invocation: tsr -> tsr info + if len(sys.argv) > 1 and not sys.argv[1].startswith("-"): + arg = sys.argv[1] + if arg not in ("info", "search", "get", "dl", "download", "config") and ( + arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists() + ): + sys.argv = [sys.argv[0], "info", *sys.argv[1:]] - # Info command (default) - info_parser = subparsers.add_parser( - "info", - help="Read safetensor metadata and fetch CivitAI info (default)", - ) - info_parser.add_argument( - "file", - type=Path, - help="Path to the safetensor file", - ) - info_parser.add_argument( - "--api-key", - type=str, - default=None, - help="CivitAI API key for authenticated requests", - ) - info_parser.add_argument( - "--skip-civitai", - action="store_true", - help="Skip CivitAI API lookup", - ) - info_parser.add_argument( - "--json", - action="store_true", - dest="json_output", - help="Output results as JSON", - ) - info_parser.add_argument( - "--save-to", - type=Path, - metavar="DIR", - help="Save metadata JSON and SHA256 hash to the specified directory", - ) - info_parser.set_defaults(func=cmd_info) - - # Download command - dl_parser = subparsers.add_parser( - "download", - aliases=["dl"], - help="Download a model from CivitAI", - ) - dl_parser.add_argument( - "--version-id", - "-v", - type=int, - help="CivitAI model version ID to download", - ) - dl_parser.add_argument( - "--model-id", - "-m", - type=int, - help="CivitAI model ID (downloads latest version)", - ) - dl_parser.add_argument( - "--hash", - "-H", - type=str, - help="SHA256 hash to look up and download", - ) - dl_parser.add_argument( - "--api-key", - type=str, - default=None, - help="CivitAI API key for authenticated requests", - ) - dl_parser.add_argument( - "--output", - "-o", - type=Path, - default=None, - help="Output directory (default: type-based, e.g. ~/.xm/models/checkpoints for Checkpoint)", - ) - dl_parser.add_argument( - "--no-resume", - action="store_true", - help="Don't resume partial downloads, start fresh", - ) - dl_parser.set_defaults(func=cmd_download) - - # Get command - get_parser = subparsers.add_parser( - "get", - help="Fetch model information from CivitAI by model ID", - ) - get_parser.add_argument( - "model_id", - type=int, - help="CivitAI model ID", - ) - get_parser.add_argument( - "--api-key", - type=str, - default=None, - help="CivitAI API key for authenticated requests", - ) - get_parser.add_argument( - "--json", - action="store_true", - dest="json_output", - help="Output results as JSON", - ) - get_parser.set_defaults(func=cmd_get) - - # Parse and handle default command - args = parser.parse_args() - - # If no command specified and file argument given, assume 'info' command - if args.command is None: - # Check if there's a positional argument (file path) - if len(sys.argv) > 1 and not sys.argv[1].startswith("-"): - # Re-parse with 'info' prepended - args = parser.parse_args(["info", *sys.argv[1:]]) - else: - parser.print_help() - return 0 - - result: int = args.func(args) - return result + app() + return 0 if __name__ == "__main__":