From 5588370a43ceb116b0b189cae0904fcea21d82c8 Mon Sep 17 00:00:00 2001 From: Adam Ladachowski Date: Tue, 3 Feb 2026 22:57:52 +0100 Subject: [PATCH] [Update] [2026-02-03 22:57:52] 10 files --- .coverage | Bin 61440 -> 69632 bytes pyproject.toml | 10 +- tensors.py | 1147 ----------------------------------------- tensors/__init__.py | 26 + tensors/api.py | 287 +++++++++++ tensors/cli.py | 386 ++++++++++++++ tensors/config.py | 166 ++++++ tensors/display.py | 324 ++++++++++++ tensors/safetensor.py | 92 ++++ tests/test_tensors.py | 30 +- 10 files changed, 1295 insertions(+), 1173 deletions(-) delete mode 100644 tensors.py create mode 100644 tensors/__init__.py create mode 100644 tensors/api.py create mode 100644 tensors/cli.py create mode 100644 tensors/config.py create mode 100644 tensors/display.py create mode 100644 tensors/safetensor.py diff --git a/.coverage b/.coverage index e98daa6c42b1176978935b1c704ca68eefb46931..1d6cb60c3e84aa7712ae042e0deb48a0e9ea6df3 100644 GIT binary patch literal 69632 zcmeI4d3+Vs-T%*-xifR-%mylAl;yI9kc}j)35tLqh&u`dkPyO!T*yXp!xA=UqSZ>> ziVK#2ihzK%Pi?K%)(x>Pt#(sdZR^IOT^_BiU9{?x=X>s%bGWvpKVH9oo~J(NzJRgJO57`v<4X9cy3iq^%;m1l~|D<$!WPIiMW)+&BEo~%B~Q9S|+tR!xz9#Qt72lY>niwc=Li}YidcVHPuv?YU4*SEWNmZ7Ibc;R_~Kf z#@D2>=TzHT2Y;N;oDD2Eb}G}Cz>eygTNksVU({OLR9ByBD@xYNGpuWd<6NELXWQD` zSej{0Hl&-fr`gn&&ZL{0lBw0Hy7o+J@jpGmB4(O+fqA)S1N&OYPG}+DGXAU_caClR zPqvI#k2hWd2iMdrY(2Ii`)JrWd(2Jkjf=qYW$leE?8|{I_y=d2kGJLVw#_uhCws!* z%|H1rkDc#lzD{8|`$qXz?M>+w?J4=?7bTBvS(H7Erskp~I|LmadBxRWu0QJGrZ zwxR*1PBpddndYpnfiJ(N4E}Mi>Pf8L{rlr{>-3}i;pvO*!#yH zH|#3;>>y!RQA0zrc}ez?VOuBpmjk0Y_nIjcX4JN_3oY3kuLwu_KRVjj!NM^l7d135 z`s`TCj+nL-Tp$n&#A-|0o9fu5(#FoAu6bpuwRUN$Bz~8Nh5zY+1Vg9HHhWL$f4re1 zhd@o4L>lIqR_`;;z>%FDM)9{5|FAwYgv2-bPr|5pt1(sE*4`=t$-fGcVvA!@Qd}eP zJ8(M1mf3sa8OJ-Fc)gmk7)WR(tlqtP;gDYuEH1IX#%E4~Eje}x{9PhA@xyPa&D1CV z@z~j7IPiSK>|NFCcn8i7p{9)eYhKku%O3Kzde;3$2>R@={+UB&OO74#jMSRqm9-5J zO3Q^x*fRWe~yxv)RBdKVVrusyzld2p&} zgPA~MtvJ_MYM&WdfA{9GL3R;DX_zAA4lZVRj)zS0Phb&|f9;F7bz-e}f7T$^qqoazHtt98eA@2b2TK0p);lKsoUFb3n%$wpjhI1#hC@J@`kxlmp5E z<$!WPIiMU+4k!nd1IhvAfO0@Npd9$Obimd0zJX{9DeHmtGxFjdfH5Uy6(wa9&Poe@ zje=hXfBv^TL3Pl|0p);lKslfsP!1>ulmp5E<$!WPIiMU+4oC-Fy)PEG0W_+gVaIm? z*zVMP%q_xazHtt98eA@2b2TK0p);lKslfsP!9a>cVMKcqvq1f z+u*+(D6OkcmtNf3yew6hX)Dd7n%bISq4?QWyClUK{U9ll^>VDEV zrvEYuUJROp?ZKG=1)*Nb0p);lKslfsP!1>ulmp5E<$!WPIiMW)H*jE--i%MA4E{|0 zZ^EF<$!WPIiMU+4k!nd1IhvAfO0@N@c-|C4($RSd;X7uPZ0c3FXe!8KslfsP!1>u zlmp5E<$!WPIiMU+4k!mchYskvXS3)3sK@8<ulmp5E z<$!WPIUpTi&;M2ZFYQpE98eA@2b2TK0p);lKslfsP!1>ulmp6v&!qzfoD~lK#?GxL z|M9szO?6Vr0p);lKslfsP!1>ulmp5E<$!WPIiMW)zw1Dcs91xu$H9BdgMk0|ziWp& zbLD_?KslfsP!1>ulmp5E<$!WPIiMU+4t#DL@I4dT8oY2R9*j1t-~a#Ie6#9Almp5E z<$!WPIiMU+4k!nd1IhvAfO0@Na9ju2@Be?%?yuc51|3m*>|B(N4{{?@)|B!#5f0uu&f1SV4&-iJ7 zfj`rq=AY@8`}uw!zo&0{AA9e6N4;NoKk}aT9`U~B-Q#WXuJ=gxAcyW`zbcc9zLCCP&IQI)zR@r>B#kf2F^ruhAdV1N0%fmv+$`X++y; zJ)K7{q!Z~#nosknV}EM@-hRt|$^M@Gu)WvbYF}?}v{%@5_T~1u_BgxH?rR4&CLfTa zNZO#`>l8qVl|yWRb(ZtQ!HYBV!m%4F<&uXFrPFZH20WY=1t~R=4!LSTxecqPBX`wBh10( zX{KXz8y^_&7{4@rWE?QQW!!7rVQezi8_h z`x4s{HzvYFTOyU1otU1eN)#vhCweAy{m=S4`m6d6^vCqC>pS#Y^e^bE^=0~e{X+e0 zy<8uxpQ;n>Z`yxqZ)z`V&uNcp4`{o!uV^=Dk+xD>rY+De)uw8d#~o1W4dsAxKslfs zP!9aNJ0KsGl;AS-`E=D&ha#SI@I!6sdSLLW7s53_mLXlL1a6Sri)L?X_ zn|)}6HSFL)8-+mcphFv^cpDvBFU3)GXq^;C(4nV=)aQt&mga~jmq_s>`d%tW>|ZR!9@JGQ#l5I&krcbpy|p=_YoQdo z(C(TX(X~K|ooLq;IihR66!)N=^Q5>Nb5^;H6w*H;Q$RzFT)x_+#{`i?OY8%IlQ7$tG}NQuiTB&N#+PAD%E zcvkrcf#b_d1y-F^BC)(!U}aU2z_I0p0>@MimsnmPaP*jA0!NnT3#=#~DzLnMh{Oql zC5|5?v1*{iu>&NINlF~qUt&c+iDi8SmW}QsvHT2yC1s}zEGozoSXgwLz~P0x1s0T? zDsgx(f%yfe2<)GKvcO*bPZH?&>M79mdkA#gKp=H|frNSjE#e9^Ek~eXQh~Z*+b|Si z2x~fF#D-#~tksb5D0&=f9lyokLLPKjVDc+2?dQH#k>1 zt^O_mZvT zMyT~yk>#X@%p~WMDya4ck>13yx~&hacc9+?k#)d&*t*ZU)7orZW39ECtvYMAHN%=< zRair<(<}=0{(I(|=1i617OOnf7;GjUtu+QiyKbD|FF z{qqwO6QdFZiM~+r8~R83yZT%DLH&8C`5(}C>38VY>!IGRr}ZoJOY|xFIK5aOp!d>k z?NjX!+EMLQ?L}?B_APC%woSWPyINbLHEOlm6h+@GXLe7vs&63-Ol)7YxTY3(oI{ZxTGT0N*G% zAKxH&@KAic;DP<{mjn+Oj5kRhh_4fz9Du(lxZenTt>h%WM)DE-1;KsB81QUuw!4|<+3N|ggQLtg+4T5z8uV;Y?LXfHBb=eqYYi2iIE7myD4|fQD zu>!9Vyl)QXb3LpXxUw6sil>_f9)K(HN~tW+?LwhHVhd-a`W5IlsScoD;Z~{kp#yk@ zR9$Ev&QAHTrWviM3pdBdVJvRNO@gb&;YPug<8XuEaf|VC$yIonULY90jVt1B!_e{k z_}ehL@x0iG1lrxb0MC_b7pQ8fcB0*QPLA4z`2-JZ3ACdd&ysK5+1-vWmufqx%cR z@ub-M#EKR?QEOR(zJ^7JR1Q#&$ej@PfU#O7Q#zxKi-E`FNaQh@@i$ z&zXzI2%g=DM@ycAM+u(sI36i@`V3qlc-nMaF8JJOxJ>ZWbMXklQ>Nll!IP)p62X&Z z<6_B^agpFr2XUd`icxsD;PMJwAh@g?4--714Cf0j9f5}mE-A%B1Q(Uy!Ga4X;X#s% z@Ib+X3h@BJ0|()x;QoDZf5H9w<9>qs_QQPz_Zf)$2+r$^&k%fS9zI=guTycJ;8S|x z(*&P%3hphq=Slcf!FEsFOE9tVDS{!Yo-7!m>Pdnjs`eBN*Ku~j2qP|`V?JTDwL~}e z#To|-u_yTPN!S&66nw@9IqDGlV~%ucg`z>K&m3<%Q1F@zdRg%540=iMsto$6;FTHl6T$5n z^y4_H>lyg?ABi_I9q5OW*Ps_AuSPFOUWI-jc_sS3o)TO+8|{}o9z7|!3Oyk>e+YV9@X#UXF~LJB(W8>{(RU>e zMc)zJcL;h!a9&^ZZNa_s(8Ge=-smC0j*Gq}m^kP`!6rf96l|F28-jHMJrK`ObRB&? zo}uX7=s)7CT*F(s(bwW{Tf^7kE$Du!T5u!!s#MLO?vttsH=}(ysuAs#su5p@_T;D* zbZ?GoM!R!V6WW!d8qiLu8gLW3Cr34+9a3EZw`=c~YCfpDq?!lnPO0XC+AdW!sBKcs z!40S@M_qxo=BWATD>-T&x+6!;MO$)IHM(7@3b?I&n^fhXZk4JGm!n&t{~s4?%h>b( z_k(wW*Moz>3&8=n3-I+|chChL0M`Z^;7-7@U|}#TI6s&ajD@=agM-t99)an9;{Oru z2)ypU?0?^X61o7s3U>vz_&4}h`5k_Xzr>&KUkZHyXZa)j;eOIT)pvc(`_OyOI|7{m zKlYyS9`(NA?eVsIw|LiiSHk^)W!^&Q1vuZEb3mpLl{fNFx-=YWU^Yk(L0Nq8m&`op$T|pPq+4MYmCLICy3r?lB{R#909I_AE z&)MI#zh>WU-(p{5ci4^gLVKn?6}ke7;BLW5wn085zay`cpOB}?!(<=XMs9+81uIFK z%qJI-$z&YdD;P}jh);CuL+f4Zko9xx1#7?c5Zo=e%evLN&e{n50cmT2HPf1AooSU@ z`BopRr)8QSLx;do^HuYQ<^l8D=Kbal^LFz(b3NQOs557o)1gD4)EsE`GKulA@t*OP z@v`x(@g3uS<1V;s@C9Ry(O}dVmqLfYSfkMBXY@1@*?R`BC4QVZka#Gu4|)V{N^~Yx zCen%diHj1G;GV&-#2E=cf%QM>ztLaUU(%n`AJ-q$_v+jA&HAg!Ym4o_1J!1@0R>r9G^DRl8feP1~ehskOts3^wC zBr4=M>L7`RbF3Ieq5_WP6(kzQv8{{i3V^SR7j#E$AN=L)SqMjJ|ybLv0r}@_2t;NABp;K>@$!=XK>8xOQO>`o|;FZ zJdV9iCDCadPw7RX-W*Rlg+!-v?0FK2dU3RSlIRqU#3s?n94$hklQ^0diF$H0OcM3r zs2f%k0NOaATanKe+KFxwdHn5zaAq#Y$Kli+j(fp#%5ifujch!Kjm%iUq>x5NEMQJZqeLuVN=PG^u)^MS3^PI+K|P-pFkzz+js?sIX@sP(Xt+@d zlZG3lFle}5iUbYUNuk@}S|%`?3DU5GzriMgG+Yx4mDxI7jx^P}OiSir=OhUr+qypM+Uv4CkG4VT6OW_>hV5(}90(J&PYnDfzaaV%iU zN5i^Uz>JTEi=@~`!&)i!(r}>^duUiA#l19KAjNJPULnOU8qSwuCk^LGaSsjW#sVgK zG^~yV%=2hCCl)Zxqv32Rw$N~v6t|P`au9H%x+6ov%Q&t9oXK%D;H4Z_0bas!CE&#z z+cS1}5krm5?nrncU!3V6;RP|SA>sKkt|sA(7*~<-yck!KaC(gGB%H>vdNv8q43ClS` z=n2a>Lg)!c0NPmBt+13WwDoQ~ECGV@4<>IkWI>+&b|K8&Xvo4m6EJn7Aq(_Oz|4(? zEYvdr6E_;NV9y1-!@@liFm0nD3;0~L(2#|EE?^rL^touFAq)FlG|-R*ekNeXMne|* zxqw}>;Lili*J#MXKNB!rqajNGOu%f7hAagz0h2WvvLwI+%++Yf(f|`MRihzG1YBG} zLzW7-m`_8N47iv_LzWJ>m`g*J5V)wOAxjBNz$A@^EGaMnb2J*Vw7>*R(P+pL0~7H3 znlxmofeGkHpdm{RT);alJum_N2sC5~f(d9xpdrf0^=tsPYN1>9s|mA0^|LZrvwc_hXLgoK||1BKzTyY z5cC(&h1`GOF3Qt^hM>EE@@&Aki}GZkA?PijJQpzTqC6F72s#TW&jcER#sbO{frg;3 zfbu+`A!sY0JPl|Fx(X=I0*t#TO9C(o5%d&LmIXNPp)3t>zL&B*z61%EUP+;u(;2t zoUot|R9Qt>#0MH%PFT1H8dFJFtY=hCSfB?QJ%+F-4>Yozun-SaQBGKV2P&^8EVwh8 zKv-mFG@cBMsfw_e&X$cO$(Y6v7R}kRk%WbEMiqp`aYki?1#zIV(S${CM&*QsZ=jMg z!eTd2Q2}9r8>p~|u&50*ypXVv4OCDeZjH7!Bn2 zA}l}yxjtdh83MPwkT16Vi)f?UjEF%YC;7Kqsd$6`Sk z#$~r+Nx)}w34t^WF5xg8c zA3Xm5=>GQtD9=nJTJE_5bC|9`PF(CO{C4yJ#hzoTzL zz5jjs1bq;G2Vgs#vU({8lmp5E<$!WPIiMU+4*c(MfOPUI09px2Cococd4RkQFz3!A zoxBh*=T?(WUJ01h^GPQ!1p(#fj=a|A5qWr11Tk96|7 zz$^shg@HM|m~`^Wz=T?&la~f2ln|Y~HZb#HU0xiRLkmbJuMW(i`J|JV2j<|Rq?6YN z=0I4N7YOEn!K9N{2LgvwS6X1@`nGg~EKvtN>Qn(9pA3+VTw(Rs0ec0U?Tj|Fu5(P&yMpxKW`=f(nh z{b)2*ikoOOB^J=>N27CM0gZlkG#Lc6T|u88i6$||(B?;?i5yq7kmziVEh|VgAx2o? zERM}EcFv3u)*a8Wv7JO!92e{*Q6bWE;xdyp}i^>Xei9irRAb>!)dO~E7dTDL7R{JWV;2o_O1OcsBLGfC(S}jFD!9ulK(%;!LhY>4%+y2uR-?JYQzR&E; z?3rh0p3Re)$%L{oH4Vvlbvjd*Xo&ZQ0w4s2#^V6M@XNz5`Jf09De!;t4L%YwC|-1) z<4p$Jm;&BZr>)n+o^J28eq#4$ldKxrofYE~>Y*G^4k!nd1K%45{4$#swrN9h3mXy@ zwaJD|qB5Bg-<{#(N0v?;8J}1>Y|O~Gm=^D9#_`dkN4zwiNnaGNPiErPHMPljO{%J< zGSN_znjUYMk(6^cW|LL?GYhBTrxu^FIKMon@|r4Ksv$WYS5cp-sY_($#b+hwbzc+lXqs3uc&%@^6E6H6O9rzC3s>Xd zQfX;xxFCONTsS}GRAXHQT0XO}uAaX+xP*A&hV+q^Jkqia>0o33<@e@0-pg=%4|_YM z<@_BLOEsoyW;Z4ez4_ha;ga3+n~_R)kMoVdJ)o0i6qYn6fO}0Z1Vb|f$A9=swN6Fzk2|tQz zYvbwa{2{}aj*B-3_vUXNGd-mli3~r`;^|;U9O-{`Z{vf-F~los(-nvBYxzAUo5TYI zvp~*P-I%K6hf0=jLuGnSGLx8|EDoL(vGCvCk#N`P#f-v{?T@tU9NGoSdmTd4Z)4EH zcI`-x<#(gt-3tDrhwVc0o&1g%l~2_r6WPX$43d8ql5&YKCFL;^JcFClgBpdU?T)lL z!F=Vt0*LuELJM2AB)&KzcrNj;=wTbdr-X;Vw>}8|_2Jhi8fL`*aqN6D4!ld$C>+@G zNCW3XDDTC;8%qnc{E#ox^X}JS=<_fAup#p);gHWt&g(HJQHyC(pQy>;A&VI-J^|u< z#;OF)oTx}QHstqTF7vx5|Lh6JD^?I{={?R>l%E|B?{m`GYz=;_@{g_T068_@Wkw>) zQz+XdUXf1MCKD;~*?>!&k;T`n$;R_z_{UGBA>lf*^Y}WZ&*_pMll*YuF+ULxI(d8{ z)YTi)#q8$?Cx0CgBDi3@3m%kNJVd{>PPx+HS)N#O{l)d@314enqDma@ z{7w1B8ZVlW{D;+vPooiuhGaucT{6yR$qy#KV_lTU#w#;Pd@}ef%?}J6oAK>YiPfBr z&&UiFTT|3U|7ZISe~Zg|@$bAP$M}zOetSX-yLBUeV{inEa4OH@PXcubxvhcJ4%@T7 z^<=n{jl-12ACbPfb8*AJ;USme6?h_wx4lAGC+7-&Kr)LQ93}C-0{o{Q$^qqoazHtt z98eA@2b2TK0p);lKslfsP!4?m9MFkIDDVH8_da-g@SA!l2b2TK0p);lKslfsP!1>u zlmp5E<$!WPIdD`uU~Bqup1driR{_!6wF#~P=vUmUcX4kMw@UMNfVacje^hRu8nkji zIiMU+4k!nd1IhvAfO0@Npd3&RC@KW7U9_1IhvAfO0@Npd3&RCM026#|Yp|Nk|3 zU;jIcs(Fulmp5E<$!WP zIiMU+4k!nX9tZIJKk1|o2XDW((|gZ*)!Xbn0*1F%i)4JKZ+FD}Gwd$-2>s)J$RciIMx><3nrDd82 z&3)#_=3C|#bDjB!d5?LUd7Zh;oNuPg3Ujh~ra9O=#q4ahHC&s4J$FpOY5&LUwPwa!(_Sg%tr($bjcg1drT^-BC=EP>k zrpCs{M#ToiienvOtzsrUNcYi?>05LQT}L0G_t4wub#xh>PgAslPNrwl!Soc`nYN`a z)rviSVcck3ZY(k~MzwLiG1eGn^fgX0jyH}mjObsZd!ipiw?|)y zJ{4UPy(@Z4^y+9XIwv|aIyE{zIx0FKS{&^VZ51^m2P69;A4lGbY>BLkJQBGla$DrO z$g;@%NGehhnH)JYGB|Qdq;sTg#EoeB0ez?buD(@&USF%P((ll3)GyZ;=^4FRKVKh< zyVC!+M=aMV%9utX#Z4@CB1bI}j})_9N5Mu!XSoi7qnu9=9En6(E-o3Tw->BO;w*Q( zWP|07lZ^8pD>=e)?Idd~*H*9=!Rc)z<2qaCS>L9wXSqVbTRFE9{EEJnuInhAc9&tD+;1^p$K^8|0=(_O(E^%q#q5xjvE=?FBH6#^TmRfa9$udr{Bx``I472f1c!J%%3ayQs!SIc`5Vf zNM6GHM#0s3IrAF?Cpl*YSLsRSXC&kF*@7$eD(2SeK_-fAb2pJeu3cA^r6f@U+_Ra{XD?~^wXF> zMKVr5S8$1r+cH`5Xy#9njN_ar8OJt3GLC1wf99Vp zxWA6m$4JKUoRw$ofVPYIXA0iW`3%9Ib3R@0K5akqM@z=(qXh5O_A!5?;7|GdBLwf! z_A-CCWSl-s@NPa&so-5Yj%{eb+5zSdk^DLH2TR8FpC9hItxZrv+oB1ODNj;w~ z@}HzhJ@ZBWBja?D|0G4y%oq8OjMGK_lRA=OzQ})MoG$X8)Dj%C$baN4^F{t6<9I~= zBjXrE{v+e-iu^~$*An@UjO!Qqk6g!mk^iKOv&etamy|JIC1eP|HvJgFY+I`1M@}xlQx`1{*z`0asL`8LLc6P517b&#P^wqe0Ue$V0uju#6-{|zQ#n( z!#3Qu>B5p%U@sFX51X3?G7<8y2~lJ`Y=q6MQp~#%HZhU#umN6S2{CnD(|9KG9iBxL z@ea?xvrMErJPpq<5$^C5tYaeE5uauv+TjUYOe8z3g{PPZcEl%`$aQ!e)-n<6@EAPK zM5@E1@E8-J4v)a2Ok_Gdh$|D34r|~MCK4U-K_&tnR>K-5@*Gyd1}5ShZh}=zq&eKs zRL?}1!}YL&i7bZ|a6J=I4l8g5k>qe4+`vSTBd%Z~$Ke{dDK?tBL%?4a8-+x)#o$vO z;e;$&OnH72vTzX{o`(esV?4QuHXp8`Jhus%zku@4CM0z$<%vy5ZHn@^CS+DEBZ<=IQfsIim>FCimG zQJ%Vl3>!jufS_$cMD&@IKNVh(ehbkdOoheULLb?`F9;bwK=}LK)64JRF8 z)yL{;9cvX>5%UXkm-)W=s=3L0(tN5=ik6CV+&Y{CD=qZnhjt*Se!M-mj7jJSM0sm*4T#FW3hWeH$D8v-BZ)H@$^kMK7UQT20TRW9Sg<_nl~KY8eNO zy~ZDm*Nqp9UmFh?zrc3?L*rs2ZB!VOjMI&QMzL{%(ZZnE?RQ7tk8X=TAALN!GWxUV zb3Bt7(4w!yC+lbEgY=$y2fd{p)4tO7XusE9 z)i!EtwfnT6Yb&&6+B~gRE7!)$+HV_tm&76(iINjSp_8~2Sa3=EP#EO+P(;abp@@)U zL!ps&QsBx-+feAFO(1yBC9S2vqDu;;z@|%DNr6?Dw3Gt7E@=^p^`yBJ*mlV=QefRB z&7{D-OA4gG!b?0Uu<;UC3aq@ukpep}u@4D-8wP@~3ft;dVus>n!a}iy#6s~Rp`qAB zj8JSK(NH``B2r-aC3+~<6HN-NzZg!z2(bSWkOB)ZG)aLC7!FE-6&U^|1$JQgS_&+| z@Kq?5!(T(O489D-rSL^4mcm~`u>=lCfmImxOMzV&{wxKSVfZ{0<*-i*ti!NZ3hcx1 znG{%v;ZrHF5yPHPRKxC2Bw<%5s^F7QBw%MKroo>=F%@=5fyEg97>aWESPHDh@R1bQ zjo}YcU^#{lLs0@BNP+bjejkbwcwY)E$nc&N*pT5}DX=2L@1($v4DU#RB^ln90$Vb? zB?Z=Gcrz3w@P-svl;QPIjE3!@7zM9|VkEp8iV?6a6s52=6ocWFPz;3MhGGD`9E$$1 zMGCCU@KPvB;5UKbr5Rq70$Ve@AO+TD*enJ1X4n*pU9eFKY|ik!6j+^MgA~}EVSOle z!E;hzdxmvVV10&X4+-r6JQIq~;ptH9gQr6ADf~JVd*I1X?1m>o@d>Pz0&6rp9*SM? zm=su~;ZZ5DNy8&jV3mf4rNAx?4@rS#8XlAa+cd0^0_!xa4g~Mh@PHIpsNsGouu;RR zP^^Ueq`*!ME2Y3v4fjfcts3q*B*<#`pHMsizY4|uaCazH!Cj%a4}K{Hc5Ap(3M|*~ z3n{Q&!yTchh1;dTehohlMJ@bH3T)VLTPSMbr&3_YhM$C@7H*XSTQ=Mhidwi?3hdc% zQz)|V<4|Pa#!$?L8$wYJ*M}kvD?*Wi>q1cn*GhqX8?KQ83pe~o3T)hPwG>#n;VLPx zbHkOP=m=MYq61tm1=eo(p%mD=;W8<(c*AljuzAC>P`GfZ6xh9CX(&8cA_cZ@$Vq|q z8~i}<{tcH{`41z)8~9?$Siu)b#tyzvGM4a*C1VR;5b}=shq$TrJjqzZ=Ss#VevxFX z;&UWp7jF#t{RYX{#$Q@xmCurlwR~pC zuhkslrk*n-V>6#F8LN4F0<1#Cejjs85lMP5oTSSk)&>#;!g|GM4p;lCiB%2>HSBlCiIs zNyfr{j$~}?<0NBcA1fI<``MDQw2wK&P4#C<#@c>n$Scl}jK%$QlkcB|H}}yTiAJ!x zk768ocOMzh!bKwjTCi|$^g!PimChm zWp7pgD+iPV$^qqoazHtt98eA@2b2TK0p);l;QQ+UzyJScHh^FM|0nNv>ioYt|F6#f zHxGV8rmOS+ny$|OYkIk^&i|L_>ioYt|F02s{$Jx)*{k#a>ioYt|DQhv$*i?U(E4)SC zY`pLPT<=WpH1A}*?!T4C+^_L|{}0^P+)eHi?kc?E|9W>h-sfNEPID)?qudg=J6`L5 z4Bp@Wh4YE?uJeks-g(ry2e0zK##!Rbacc0c{&9GP|EW$_r=8VbFhH@6J)OT4oGJ#(wM!Fulmp5E<$!YFd*ML- zLP(k diff --git a/pyproject.toml b/pyproject.toml index 3f94683..0b500e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ requires = ["hatchling"] build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] -packages = ["tensors.py"] +packages = ["tensors"] [dependency-groups] dev = [ @@ -33,7 +33,7 @@ dev = [ [tool.ruff] target-version = "py312" -line-length = 100 +line-length = 130 [tool.ruff.lint] select = [ @@ -51,11 +51,7 @@ select = [ "PL", # pylint "RUF", # ruff-specific ] -ignore = [ - "PLR0911", # too many return statements - "PLR0913", # too many arguments - "PLR2004", # magic value comparison -] +ignore = [] [tool.ruff.lint.isort] known-first-party = ["tensors"] diff --git a/tensors.py b/tensors.py deleted file mode 100644 index a7bf101..0000000 --- a/tensors.py +++ /dev/null @@ -1,1147 +0,0 @@ -#!/usr/bin/env python3 -""" -tsr: Read safetensor metadata, search and download CivitAI models. -""" - -from __future__ import annotations - -import hashlib -import json -import os -import re -import struct -import sys -import tomllib -from enum import Enum -from pathlib import Path -from typing import Annotated, Any - -import httpx -import typer -from rich.console import Console -from rich.progress import ( - BarColumn, - DownloadColumn, - Progress, - SpinnerColumn, - TaskProgressColumn, - TextColumn, - TimeRemainingColumn, - TransferSpeedColumn, -) -from rich.table import Table - -# ============================================================================ -# App and Console Setup -# ============================================================================ - -app = typer.Typer( - name="tsr", - help="Read safetensor metadata, search and download CivitAI models.", - no_args_is_help=True, -) -console = Console() - -# ============================================================================ -# Configuration -# ============================================================================ - -# XDG Base Directory spec -# Config: ~/.config/tensors/config.toml -# Data: ~/.local/share/tensors/models/, ~/.local/share/tensors/metadata/ -CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config")) / "tensors" -CONFIG_FILE = CONFIG_DIR / "config.toml" - -DATA_DIR = Path(os.environ.get("XDG_DATA_HOME", Path.home() / ".local" / "share")) / "tensors" -MODELS_DIR = DATA_DIR / "models" -METADATA_DIR = DATA_DIR / "metadata" - -# Legacy config for migration -LEGACY_RC_FILE = Path.home() / ".sftrc" - -# Default download paths by model type -DEFAULT_PATHS: dict[str, Path] = { - "Checkpoint": MODELS_DIR / "checkpoints", - "LORA": MODELS_DIR / "loras", - "LoCon": MODELS_DIR / "loras", -} - -CIVITAI_API_BASE = "https://civitai.com/api/v1" -CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models" - - -# ============================================================================ -# Enums for CLI -# ============================================================================ - - -class ModelType(str, Enum): - """CivitAI model types.""" - - checkpoint = "checkpoint" - lora = "lora" - embedding = "embedding" - vae = "vae" - controlnet = "controlnet" - locon = "locon" - - def to_api(self) -> str: - """Convert to CivitAI API value.""" - mapping = { - "checkpoint": "Checkpoint", - "lora": "LORA", - "embedding": "TextualInversion", - "vae": "VAE", - "controlnet": "Controlnet", - "locon": "LoCon", - } - return mapping[self.value] - - -class BaseModel(str, Enum): - """Common base models.""" - - sd15 = "sd15" - sdxl = "sdxl" - pony = "pony" - flux = "flux" - illustrious = "illustrious" - - def to_api(self) -> str: - """Convert to CivitAI API value.""" - mapping = { - "sd15": "SD 1.5", - "sdxl": "SDXL 1.0", - "pony": "Pony", - "flux": "Flux.1 D", - "illustrious": "Illustrious", - } - return mapping[self.value] - - -class SortOrder(str, Enum): - """Sort options for search.""" - - downloads = "downloads" - rating = "rating" - newest = "newest" - - def to_api(self) -> str: - """Convert to CivitAI API value.""" - mapping = { - "downloads": "Most Downloaded", - "rating": "Highest Rated", - "newest": "Newest", - } - return mapping[self.value] - - -# ============================================================================ -# Config Functions -# ============================================================================ - - -def load_config() -> dict[str, Any]: - """Load configuration from TOML config file.""" - if CONFIG_FILE.exists(): - with CONFIG_FILE.open("rb") as f: - return tomllib.load(f) - return {} - - -def save_config(config: dict[str, Any]) -> None: - """Save configuration to TOML config file.""" - CONFIG_DIR.mkdir(parents=True, exist_ok=True) - - lines: list[str] = [] - for key, value in config.items(): - if isinstance(value, dict): - lines.append(f"[{key}]") - for k, v in value.items(): - if isinstance(v, str): - lines.append(f'{k} = "{v}"') - else: - lines.append(f"{k} = {v}") - lines.append("") - elif isinstance(value, str): - lines.append(f'{key} = "{value}"') - else: - lines.append(f"{key} = {value}") - - CONFIG_FILE.write_text("\n".join(lines) + "\n") - - -def load_api_key() -> str | None: - """Load API key from config file or CIVITAI_API_KEY env var.""" - # Check environment variable first - env_key = os.environ.get("CIVITAI_API_KEY") - if env_key: - return env_key - - # Check TOML config file - config = load_config() - api_section = config.get("api", {}) - if isinstance(api_section, dict): - key = api_section.get("civitai_key") - if key: - return str(key) - - # Fall back to legacy RC file for migration - if LEGACY_RC_FILE.exists(): - content = LEGACY_RC_FILE.read_text().strip() - if content: - return content - return None - - -def get_default_output_path(model_type: str | None) -> Path | None: - """Get default output path based on model type.""" - if model_type and model_type in DEFAULT_PATHS: - return DEFAULT_PATHS[model_type] - return None - - -# ============================================================================ -# Safetensor Functions -# ============================================================================ - - -def read_safetensor_metadata(file_path: Path) -> dict[str, Any]: - """Read metadata from a safetensor file header.""" - with file_path.open("rb") as f: - # First 8 bytes are the header size (little-endian u64) - header_size_bytes = f.read(8) - if len(header_size_bytes) < 8: - raise ValueError("Invalid safetensor file: too short") - - header_size = struct.unpack(" 100_000_000: # 100MB sanity check - raise ValueError(f"Invalid header size: {header_size}") - - header_bytes = f.read(header_size) - if len(header_bytes) < header_size: - raise ValueError("Invalid safetensor file: header truncated") - - header: dict[str, Any] = json.loads(header_bytes.decode("utf-8")) - - # Extract __metadata__ if present - metadata: dict[str, Any] = header.get("__metadata__", {}) - - # Count tensors (keys that aren't __metadata__) - tensor_count = sum(1 for k in header if k != "__metadata__") - - return { - "metadata": metadata, - "tensor_count": tensor_count, - "header_size": header_size, - } - - -def compute_sha256(file_path: Path) -> str: - """Compute SHA256 hash of a file with progress display.""" - file_size = file_path.stat().st_size - sha256 = hashlib.sha256() - chunk_size = 1024 * 1024 * 8 # 8MB chunks - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - DownloadColumn(), - TransferSpeedColumn(), - TimeRemainingColumn(), - console=console, - ) as progress: - task = progress.add_task(f"[cyan]Hashing {file_path.name}...", total=file_size) - - with file_path.open("rb") as f: - while chunk := f.read(chunk_size): - sha256.update(chunk) - progress.update(task, advance=len(chunk)) - - return sha256.hexdigest().upper() - - -def get_base_name(file_path: Path) -> str: - """Get base filename without .safetensors extension.""" - name = file_path.name - for ext in (".safetensors", ".sft"): - if name.lower().endswith(ext): - return name[: -len(ext)] - return file_path.stem - - -# ============================================================================ -# CivitAI API Functions -# ============================================================================ - - -def _get_headers(api_key: str | None) -> dict[str, str]: - """Get headers for CivitAI API requests.""" - headers: dict[str, str] = {} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - return headers - - -def fetch_civitai_model_version( - version_id: int, api_key: str | None = None -) -> dict[str, Any] | None: - """Fetch model version information from CivitAI by version ID.""" - url = f"{CIVITAI_API_BASE}/model-versions/{version_id}" - - try: - response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) - if response.status_code == 404: - return None - response.raise_for_status() - result: dict[str, Any] = response.json() - return result - except httpx.HTTPStatusError as e: - console.print(f"[red]API error: {e.response.status_code}[/red]") - return None - except httpx.RequestError as e: - console.print(f"[red]Request error: {e}[/red]") - return None - - -def fetch_civitai_model(model_id: int, api_key: str | None = None) -> dict[str, Any] | None: - """Fetch model information from CivitAI by model ID.""" - url = f"{CIVITAI_API_BASE}/models/{model_id}" - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - transient=True, - ) as progress: - progress.add_task("[cyan]Fetching model from CivitAI...", total=None) - - try: - response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) - if response.status_code == 404: - return None - response.raise_for_status() - result: dict[str, Any] = response.json() - return result - except httpx.HTTPStatusError as e: - console.print(f"[red]API error: {e.response.status_code}[/red]") - return None - except httpx.RequestError as e: - console.print(f"[red]Request error: {e}[/red]") - return None - - -def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[str, Any] | None: - """Fetch model information from CivitAI by SHA256 hash.""" - url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}" - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - transient=True, - ) as progress: - progress.add_task("[cyan]Fetching from CivitAI...", total=None) - - try: - response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) - if response.status_code == 404: - return None - response.raise_for_status() - result: dict[str, Any] = response.json() - return result - except httpx.HTTPStatusError as e: - console.print(f"[red]API error: {e.response.status_code}[/red]") - return None - except httpx.RequestError as e: - console.print(f"[red]Request error: {e}[/red]") - return None - - -def search_civitai( - query: str | None = None, - model_type: ModelType | None = None, - base_model: BaseModel | None = None, - sort: SortOrder = SortOrder.downloads, - limit: int = 20, - api_key: str | None = None, -) -> dict[str, Any] | None: - """Search CivitAI models.""" - params: dict[str, Any] = { - "limit": min(limit, 100), - "nsfw": "true", - } - - # API quirk: query + filters don't work reliably together - # If we have filters, skip query and filter client-side - has_filters = model_type is not None or base_model is not None - - if query and not has_filters: - params["query"] = query - - if model_type: - params["types"] = model_type.to_api() - - if base_model: - params["baseModels"] = base_model.to_api() - - params["sort"] = sort.to_api() - - # Request more if we need client-side filtering - if query and has_filters: - params["limit"] = 100 - - url = f"{CIVITAI_API_BASE}/models" - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - transient=True, - ) as progress: - progress.add_task("[cyan]Searching CivitAI...", total=None) - - try: - response = httpx.get(url, params=params, headers=_get_headers(api_key), timeout=30.0) - response.raise_for_status() - result: dict[str, Any] = response.json() - - # Client-side filtering when query + filters combined - if query and has_filters: - q_lower = query.lower() - result["items"] = [ - m for m in result.get("items", []) if q_lower in m.get("name", "").lower() - ][:limit] - - return result - except httpx.HTTPStatusError as e: - console.print(f"[red]API error: {e.response.status_code}[/red]") - return None - except httpx.RequestError as e: - console.print(f"[red]Request error: {e}[/red]") - return None - - -def download_model( - version_id: int, - dest_path: Path, - api_key: str | None = None, - resume: bool = True, -) -> bool: - """Download a model from CivitAI by version ID with resume support.""" - url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}" - params: dict[str, str] = {} - if api_key: - params["token"] = api_key - - headers: dict[str, str] = {} - mode = "wb" - initial_size = 0 - - # Check for existing partial download - if resume and dest_path.exists(): - initial_size = dest_path.stat().st_size - headers["Range"] = f"bytes={initial_size}-" - mode = "ab" - console.print(f"[cyan]Resuming download from {initial_size / (1024**2):.1f} MB[/cyan]") - - try: - with httpx.stream( - "GET", - url, - params=params, - headers=headers, - follow_redirects=True, - timeout=httpx.Timeout(30.0, read=None), - ) as response: - if response.status_code == 416: - console.print("[green]File already fully downloaded.[/green]") - return True - - response.raise_for_status() - - content_length = response.headers.get("content-length") - total_size = int(content_length) + initial_size if content_length else 0 - - content_disp = response.headers.get("content-disposition", "") - if "filename=" in content_disp: - match = re.search(r'filename="?([^";\n]+)"?', content_disp) - if match and dest_path.is_dir(): - dest_path = dest_path / match.group(1) - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - DownloadColumn(), - TransferSpeedColumn(), - TimeRemainingColumn(), - console=console, - ) as progress: - task = progress.add_task( - f"[cyan]Downloading {dest_path.name}...", - total=total_size if total_size > 0 else None, - completed=initial_size, - ) - - with dest_path.open(mode) as f: - for chunk in response.iter_bytes(1024 * 1024): - f.write(chunk) - progress.update(task, advance=len(chunk)) - - console.print() - console.print(f"[magenta]Downloaded:[/magenta] [green]\"{dest_path}\"[/green]") - return True - - except httpx.HTTPStatusError as e: - console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]") - if e.response.status_code == 401: - console.print("[yellow]Hint: This model may require an API key.[/yellow]") - return False - except httpx.RequestError as e: - console.print(f"[red]Download error: {e}[/red]") - return False - - -# ============================================================================ -# Display Functions -# ============================================================================ - - -def _format_size(size_kb: float) -> str: - """Format size in KB to human-readable string.""" - if size_kb < 1024: - return f"{size_kb:.0f} KB" - if size_kb < 1024 * 1024: - return f"{size_kb / 1024:.1f} MB" - return f"{size_kb / 1024 / 1024:.2f} GB" - - -def _format_count(count: int) -> str: - """Format large numbers with K/M suffix.""" - if count < 1000: - return str(count) - if count < 1_000_000: - return f"{count / 1000:.1f}K" - return f"{count / 1_000_000:.1f}M" - - -def _display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str) -> None: - """Display file information table.""" - # Property column: 12 chars, Value fills remaining width - prop_width = 12 - - file_table = Table(title="File Information", show_header=True, header_style="bold magenta", expand=True) - file_table.add_column("Property", style="cyan", width=prop_width, no_wrap=True) - file_table.add_column("Value", style="green", no_wrap=True, overflow="ellipsis") - - file_table.add_row("File", str(file_path.name)) - file_table.add_row("Path", str(file_path.parent)) - file_table.add_row("Size", f"{file_path.stat().st_size / (1024**3):.2f} GB") - file_table.add_row("SHA256", sha256_hash) - file_table.add_row("Header Size", f"{local_metadata['header_size']:,} bytes") - file_table.add_row("Tensor Count", str(local_metadata["tensor_count"])) - - console.print() - console.print(file_table) - - -def _display_local_metadata(local_metadata: dict[str, Any], keys_filter: list[str] | None = None) -> None: - """Display local safetensor metadata table.""" - if not local_metadata["metadata"]: - console.print() - console.print("[yellow]No embedded metadata found in safetensor file.[/yellow]") - return - - metadata = local_metadata["metadata"] - - # If specific keys requested, show them in full - if keys_filter: - for key in keys_filter: - if key in metadata: - console.print(f"[cyan]{key}[/cyan]: {metadata[key]}") - else: - console.print(f"[yellow]{key}: not found[/yellow]") - return - - # Find the longest key to set column width - all_keys = list(metadata.keys()) - key_width = max(len(k) for k in all_keys) if all_keys else 20 - - # Value width: terminal minus key column and table borders (7 chars) - terminal_width = console.size.width - value_width = terminal_width - key_width - 7 - - meta_table = Table( - title="Safetensor Metadata", show_header=True, header_style="bold magenta", - ) - meta_table.add_column("Key", style="cyan", width=key_width, no_wrap=True) - meta_table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis") - - for key, value in sorted(metadata.items()): - meta_table.add_row(key, str(value)) - - console.print() - console.print(meta_table) - - -def _display_civitai_data(civitai_data: dict[str, Any] | None) -> None: - """Display CivitAI model information table.""" - if not civitai_data: - console.print() - console.print("[yellow]Model not found on CivitAI.[/yellow]") - return - - # Property column: 14 chars, Value fills remaining width - prop_width = 14 - terminal_width = console.size.width - overhead = 7 # borders and separators for 2 columns - value_width = max(40, terminal_width - prop_width - overhead) - - civit_table = Table( - title="CivitAI Model Information", show_header=True, header_style="bold magenta" - ) - civit_table.add_column("Property", style="cyan", width=prop_width, no_wrap=True) - civit_table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis") - - civit_table.add_row("Model ID", str(civitai_data.get("modelId", "N/A"))) - civit_table.add_row("Version ID", str(civitai_data.get("id", "N/A"))) - civit_table.add_row("Version Name", str(civitai_data.get("name", "N/A"))) - civit_table.add_row("Base Model", str(civitai_data.get("baseModel", "N/A"))) - civit_table.add_row("Created At", str(civitai_data.get("createdAt", "N/A"))) - - trained_words: list[str] = civitai_data.get("trainedWords", []) - if trained_words: - civit_table.add_row("Trigger Words", ", ".join(trained_words)) - - download_url = str(civitai_data.get("downloadUrl", "N/A")) - civit_table.add_row("Download URL", download_url) - - files: list[dict[str, Any]] = civitai_data.get("files", []) - for f in files: - if f.get("primary"): - civit_table.add_row("Primary File", str(f.get("name", "N/A"))) - civit_table.add_row("File Size", _format_size(f.get("sizeKB", 0))) - meta: dict[str, Any] = f.get("metadata", {}) - if meta: - civit_table.add_row("Format", str(meta.get("format", "N/A"))) - civit_table.add_row("Precision", str(meta.get("fp", "N/A"))) - civit_table.add_row("Size Type", str(meta.get("size", "N/A"))) - - console.print() - console.print(civit_table) - - model_id = civitai_data.get("modelId") - if model_id: - console.print() - console.print( - f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}" - ) - - -def _display_model_info(model_data: dict[str, Any]) -> None: - """Display full CivitAI model information.""" - # Property column: 10 chars, Value fills remaining width - prop_width = 10 - terminal_width = console.size.width - overhead = 7 # borders and separators for 2 columns - value_width = max(40, terminal_width - prop_width - overhead) - - model_table = Table(title="Model Information", show_header=True, header_style="bold magenta") - model_table.add_column("Property", style="cyan", width=prop_width, no_wrap=True) - model_table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis") - - model_table.add_row("ID", str(model_data.get("id", "N/A"))) - model_table.add_row("Name", str(model_data.get("name", "N/A"))) - model_table.add_row("Type", str(model_data.get("type", "N/A"))) - model_table.add_row("NSFW", str(model_data.get("nsfw", False))) - - creator = model_data.get("creator", {}) - if creator: - model_table.add_row("Creator", str(creator.get("username", "N/A"))) - - tags: list[str] = model_data.get("tags", []) - if tags: - model_table.add_row("Tags", ", ".join(tags[:10]) + ("..." if len(tags) > 10 else "")) - - stats: dict[str, Any] = model_data.get("stats", {}) - if stats: - model_table.add_row("Downloads", f"{stats.get('downloadCount', 0):,}") - model_table.add_row("Likes", f"{stats.get('thumbsUpCount', 0):,}") - - mode = model_data.get("mode") - if mode: - model_table.add_row("Status", str(mode)) - - console.print() - console.print(model_table) - - versions: list[dict[str, Any]] = model_data.get("modelVersions", []) - if versions: - # Static column widths for version table - # ID: 7 chars, Base Model: 20 chars, Created: 10 chars, Size: 8 chars - id_width = 7 - base_width = 20 - created_width = 10 - size_width = 8 - - # Calculate dynamic widths for Name and Filename - terminal_width = console.size.width - fixed_width = id_width + base_width + created_width + size_width - overhead = 20 # borders and separators for 5 columns - remaining = max(40, terminal_width - fixed_width - overhead) - name_width = remaining // 3 - file_width = remaining - name_width - - ver_table = Table(title="Model Versions", show_header=True, header_style="bold magenta") - ver_table.add_column("ID", style="cyan", width=id_width, no_wrap=True) - ver_table.add_column("Name", style="green", width=name_width, no_wrap=True, overflow="ellipsis") - ver_table.add_column("Base Model", style="yellow", width=base_width, no_wrap=True, overflow="ellipsis") - ver_table.add_column("Created", style="blue", width=created_width, no_wrap=True) - ver_table.add_column("Filename", style="white", width=file_width, no_wrap=True, overflow="ellipsis") - ver_table.add_column("Size", justify="right", width=size_width, no_wrap=True) - - for ver in versions: - files: list[dict[str, Any]] = ver.get("files", []) - primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) - filename = "N/A" - size = "N/A" - if primary_file: - filename = primary_file.get("name", "N/A") - size = _format_size(primary_file.get("sizeKB", 0)) - - created = str(ver.get("createdAt", "N/A"))[:10] - ver_table.add_row( - str(ver.get("id", "N/A")), - str(ver.get("name", "N/A")), - str(ver.get("baseModel", "N/A")), - created, - filename, - size, - ) - - console.print() - console.print(ver_table) - - model_id = model_data.get("id") - if model_id: - console.print() - console.print( - f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}" - ) - - -def _display_search_results(results: dict[str, Any]) -> None: - """Display search results in a table.""" - items = results.get("items", []) - if not items: - console.print("[yellow]No results found.[/yellow]") - return - - # Static column widths based on expected max values - # ID: 7 chars (max ~9,999,999) - # Type: 16 chars (longest: "TextualInversion") - # Base: 20 chars (e.g., "Flux.2 Klein 9B-base") - # Size: 8 chars (e.g., "11.08 GB") - # DLs: 6 chars (e.g., "999.9K") - # Likes: 6 chars (e.g., "999.9K") - id_width = 7 - type_width = 16 - base_width = 20 - size_width = 8 - dls_width = 6 - likes_width = 6 - - # Calculate name width: terminal width minus fixed columns and separators - # Table has 7 columns with separators: "│ col │ col │ ..." = 3 chars per col (space+pipe+space) - # Plus outer borders: "┃" on each side = 2 chars - # Total overhead: 2 (outer) + 7*3 (separators) = 23 chars - terminal_width = console.size.width - fixed_width = id_width + type_width + base_width + size_width + dls_width + likes_width - overhead = 23 # borders and separators - name_width = max(20, terminal_width - fixed_width - overhead) - - table = Table(show_header=True, header_style="bold magenta") - table.add_column("ID", style="cyan", justify="right", width=id_width, no_wrap=True) - table.add_column("Name", style="green", width=name_width, no_wrap=True, overflow="ellipsis") - table.add_column("Type", style="yellow", width=type_width, no_wrap=True) - table.add_column("Base", style="blue", width=base_width, no_wrap=True, overflow="ellipsis") - table.add_column("Size", justify="right", width=size_width, no_wrap=True) - table.add_column("DLs", justify="right", width=dls_width, no_wrap=True) - table.add_column("Likes", justify="right", width=likes_width, no_wrap=True) - - for model in items: - model_id = str(model.get("id", "")) - name = model.get("name", "N/A") - model_type = model.get("type", "N/A") - - # Get latest version info - versions = model.get("modelVersions", []) - base_model = "N/A" - size = "N/A" - if versions: - latest = versions[0] - base_model = latest.get("baseModel", "N/A") - files = latest.get("files", []) - primary = next((f for f in files if f.get("primary")), files[0] if files else None) - if primary: - size = _format_size(primary.get("sizeKB", 0)) - - stats = model.get("stats", {}) - downloads = _format_count(stats.get("downloadCount", 0)) - likes = _format_count(stats.get("thumbsUpCount", 0)) - - table.add_row(model_id, name, model_type, base_model, size, downloads, likes) - - console.print() - console.print(table) - - metadata = results.get("metadata", {}) - total = metadata.get("totalItems", len(items)) - console.print(f"\n[dim]Showing {len(items)} of {total:,} results[/dim]") - console.print("[dim]Use 'tsr get ' to view details or 'tsr dl -m ' to download[/dim]") - - -# ============================================================================ -# CLI Commands -# ============================================================================ - - -@app.command() -def info( - file: Annotated[Path, typer.Argument(help="Path to the safetensor file")], - meta: Annotated[ - list[str] | None, typer.Option("--meta", "-m", help="Show specific metadata key(s) in full") - ] = None, - api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, - skip_civitai: Annotated[ - bool, typer.Option("--skip-civitai", help="Skip CivitAI API lookup") - ] = False, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, - save_to: Annotated[ - Path | None, typer.Option("--save-to", help="Save metadata to directory") - ] = None, -) -> None: - """Read safetensor metadata and fetch CivitAI info.""" - file_path = file.resolve() - - if not file_path.exists(): - console.print(f"[red]Error: File not found: {file_path}[/red]") - raise typer.Exit(1) - - if file_path.suffix.lower() not in (".safetensors", ".sft"): - console.print("[yellow]Warning: File does not have .safetensors extension[/yellow]") - - try: - local_metadata = read_safetensor_metadata(file_path) - - # If just fetching specific metadata keys, skip everything else - if meta: - _display_local_metadata(local_metadata, keys_filter=meta) - return - - console.print(f"[bold]Reading safetensor file:[/bold] {file_path.name}") - sha256_hash = compute_sha256(file_path) - - civitai_data = None - if not skip_civitai: - key = api_key or load_api_key() - civitai_data = fetch_civitai_by_hash(sha256_hash, key) - - if json_output: - output = { - "file": str(file_path), - "sha256": sha256_hash, - "header_size": local_metadata["header_size"], - "tensor_count": local_metadata["tensor_count"], - "metadata": local_metadata["metadata"], - "civitai": civitai_data, - } - console.print_json(data=output) - else: - _display_file_info(file_path, local_metadata, sha256_hash) - _display_local_metadata(local_metadata) - _display_civitai_data(civitai_data) - - if save_to: - output_dir = save_to.resolve() - if not output_dir.exists() or not output_dir.is_dir(): - console.print(f"[red]Error: Invalid directory: {output_dir}[/red]") - raise typer.Exit(1) - - base_name = get_base_name(file_path) - json_path = output_dir / f"{base_name}.json" - sha_path = output_dir / f"{base_name}.sha256" - - output = { - "file": str(file_path), - "sha256": sha256_hash, - "header_size": local_metadata["header_size"], - "tensor_count": local_metadata["tensor_count"], - "metadata": local_metadata["metadata"], - "civitai": civitai_data, - } - json_path.write_text(json.dumps(output, indent=2)) - sha_path.write_text(f"{sha256_hash} {file_path.name}\n") - - console.print() - console.print(f"[green]Saved:[/green] {json_path}") - console.print(f"[green]Saved:[/green] {sha_path}") - - except ValueError as e: - console.print(f"[red]Error reading safetensor: {e}[/red]") - raise typer.Exit(1) from e - - -@app.command() -def search( - query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None, - model_type: Annotated[ - ModelType | None, typer.Option("-t", "--type", help="Model type filter") - ] = None, - base: Annotated[ - BaseModel | None, typer.Option("-b", "--base", help="Base model filter") - ] = None, - sort: Annotated[ - SortOrder, typer.Option("-s", "--sort", help="Sort order") - ] = SortOrder.downloads, - limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, - api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, -) -> None: - """Search CivitAI models.""" - key = api_key or load_api_key() - - results = search_civitai( - query=query, - model_type=model_type, - base_model=base, - sort=sort, - limit=limit, - api_key=key, - ) - - if not results: - console.print("[red]Search failed.[/red]") - raise typer.Exit(1) - - if json_output: - console.print_json(data=results) - else: - _display_search_results(results) - - -@app.command() -def get( - id_value: Annotated[int, typer.Argument(help="CivitAI model ID or version ID")], - version: Annotated[ - bool, typer.Option("-v", "--version", help="Treat ID as version ID instead of model ID") - ] = False, - api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, -) -> None: - """Fetch model information from CivitAI by model ID or version ID.""" - key = api_key or load_api_key() - - if version: - # Fetch by version ID - version_data = fetch_civitai_model_version(id_value, key) - if not version_data: - console.print(f"[red]Error: Version {id_value} not found on CivitAI.[/red]") - raise typer.Exit(1) - - if json_output: - console.print_json(data=version_data) - else: - _display_civitai_data(version_data) - else: - # Fetch by model ID - model_data = fetch_civitai_model(id_value, key) - if not model_data: - console.print(f"[red]Error: Model {id_value} not found on CivitAI.[/red]") - raise typer.Exit(1) - - if json_output: - console.print_json(data=model_data) - else: - _display_model_info(model_data) - - -def _resolve_version_id( - version_id: int | None, - hash_val: str | None, - model_id: int | None, - api_key: str | None, -) -> int | None: - """Resolve version ID from hash or model ID.""" - if version_id: - return version_id - - if hash_val: - console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]") - civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key) - if not civitai_data: - console.print("[red]Error: Model not found on CivitAI for this hash.[/red]") - return None - vid: int | None = civitai_data.get("id") - if vid: - console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}") - return vid - - if model_id: - console.print(f"[cyan]Looking up model {model_id}...[/cyan]") - model_data = fetch_civitai_model(model_id, api_key) - if not model_data: - console.print(f"[red]Error: Model {model_id} not found.[/red]") - return None - versions = model_data.get("modelVersions", []) - if not versions: - console.print("[red]Error: Model has no versions.[/red]") - return None - latest = versions[0] - latest_vid: int | None = latest.get("id") - if latest_vid: - name = latest.get("name", "N/A") - console.print(f"[green]Found latest:[/green] {name} (ID: {latest_vid})") - return latest_vid - - return None - - -def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Path | None: - """Prepare output directory for download.""" - if output is None: - output_dir = get_default_output_path(model_type_str) - if output_dir is None: - console.print( - f"[red]Error: No default path for type '{model_type_str}'. " - "Use --output to specify.[/red]" - ) - return None - console.print(f"[dim]Using default path for {model_type_str}: {output_dir}[/dim]") - else: - output_dir = output.resolve() - - if not output_dir.exists(): - console.print(f"[cyan]Creating directory: {output_dir}[/cyan]") - output_dir.mkdir(parents=True, exist_ok=True) - elif not output_dir.is_dir(): - console.print(f"[red]Error: Not a directory: {output_dir}[/red]") - return None - - return output_dir - - -@app.command("dl") -def download( - version_id: Annotated[ - int | None, typer.Option("-v", "--version-id", help="Model version ID") - ] = None, - model_id: Annotated[ - int | None, typer.Option("-m", "--model-id", help="Model ID (downloads latest)") - ] = None, - hash_val: Annotated[ - str | None, typer.Option("-H", "--hash", help="SHA256 hash to look up") - ] = None, - output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None, - no_resume: Annotated[ - bool, typer.Option("--no-resume", help="Don't resume partial downloads") - ] = False, - api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, -) -> None: - """Download a model from CivitAI.""" - key = api_key or load_api_key() - - resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key) - if not resolved_version_id: - if not version_id and not hash_val and not model_id: - console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") - raise typer.Exit(1) - - console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]") - version_info = fetch_civitai_model_version(resolved_version_id, key) - if not version_info: - console.print("[red]Error: Could not fetch model version info.[/red]") - raise typer.Exit(1) - - model_type_str: str | None = version_info.get("model", {}).get("type") - output_dir = _prepare_download_dir(output, model_type_str) - if not output_dir: - raise typer.Exit(1) - - files: list[dict[str, Any]] = version_info.get("files", []) - primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) - if not primary_file: - console.print("[red]Error: No files found for this version.[/red]") - raise typer.Exit(1) - - filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors") - dest_path = output_dir / filename - - table = Table(title="Model Download", show_header=True, header_style="bold magenta") - table.add_column("Property", style="cyan") - table.add_column("Value", style="green") - table.add_row("Version", version_info.get("name", "N/A")) - table.add_row("Base Model", version_info.get("baseModel", "N/A")) - table.add_row("File", filename) - table.add_row("Size", _format_size(primary_file.get("sizeKB", 0))) - table.add_row("Destination", str(dest_path)) - console.print() - console.print(table) - console.print() - - success = download_model(resolved_version_id, dest_path, key, resume=not no_resume) - if not success: - raise typer.Exit(1) - - -@app.command() -def config( - show: Annotated[bool, typer.Option("--show", help="Show current config")] = False, - set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None, -) -> None: - """Manage configuration.""" - if set_key: - cfg = load_config() - if "api" not in cfg: - cfg["api"] = {} - cfg["api"]["civitai_key"] = set_key - save_config(cfg) - console.print(f"[green]API key saved to {CONFIG_FILE}[/green]") - return - - if show or (not set_key): - console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}") - console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}") - - key = load_api_key() - if key: - masked = key[:4] + "..." + key[-4:] if len(key) > 8 else "***" - console.print(f"[bold]API key:[/bold] {masked}") - else: - console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]") - - console.print() - console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]") - - -def main() -> int: - """Main entry point.""" - # Handle legacy invocation: tsr -> tsr info - if len(sys.argv) > 1 and not sys.argv[1].startswith("-"): - arg = sys.argv[1] - if arg not in ("info", "search", "get", "dl", "download", "config") and ( - arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists() - ): - sys.argv = [sys.argv[0], "info", *sys.argv[1:]] - - app() - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/tensors/__init__.py b/tensors/__init__.py new file mode 100644 index 0000000..4456cfe --- /dev/null +++ b/tensors/__init__.py @@ -0,0 +1,26 @@ +"""tsr: Read safetensor metadata, search and download CivitAI models.""" + +from tensors.cli import main +from tensors.config import ( + CONFIG_DIR, + CONFIG_FILE, + LEGACY_RC_FILE, + get_default_output_path, + load_api_key, + load_config, + save_config, +) +from tensors.safetensor import get_base_name, read_safetensor_metadata + +__all__ = [ + "CONFIG_DIR", + "CONFIG_FILE", + "LEGACY_RC_FILE", + "get_base_name", + "get_default_output_path", + "load_api_key", + "load_config", + "main", + "read_safetensor_metadata", + "save_config", +] diff --git a/tensors/api.py b/tensors/api.py new file mode 100644 index 0000000..06d72a0 --- /dev/null +++ b/tensors/api.py @@ -0,0 +1,287 @@ +"""CivitAI API functions.""" + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pathlib import Path + +import httpx +from rich.progress import ( + BarColumn, + DownloadColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeRemainingColumn, + TransferSpeedColumn, +) + +from tensors.config import CIVITAI_API_BASE, CIVITAI_DOWNLOAD_BASE, BaseModel, ModelType, SortOrder + +if TYPE_CHECKING: + from rich.console import Console + + +def _get_headers(api_key: str | None) -> dict[str, str]: + """Get headers for CivitAI API requests.""" + headers: dict[str, str] = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + return headers + + +def fetch_civitai_model_version(version_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None: + """Fetch model version information from CivitAI by version ID.""" + url = f"{CIVITAI_API_BASE}/model-versions/{version_id}" + + try: + response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) + if response.status_code == 404: + return None + response.raise_for_status() + result: dict[str, Any] = response.json() + return result + except httpx.HTTPStatusError as e: + console.print(f"[red]API error: {e.response.status_code}[/red]") + return None + except httpx.RequestError as e: + console.print(f"[red]Request error: {e}[/red]") + return None + + +def fetch_civitai_model(model_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None: + """Fetch model information from CivitAI by model ID.""" + url = f"{CIVITAI_API_BASE}/models/{model_id}" + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + transient=True, + ) as progress: + progress.add_task("[cyan]Fetching model from CivitAI...", total=None) + + try: + response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) + if response.status_code == 404: + return None + response.raise_for_status() + result: dict[str, Any] = response.json() + return result + except httpx.HTTPStatusError as e: + console.print(f"[red]API error: {e.response.status_code}[/red]") + return None + except httpx.RequestError as e: + console.print(f"[red]Request error: {e}[/red]") + return None + + +def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Console) -> dict[str, Any] | None: + """Fetch model information from CivitAI by SHA256 hash.""" + url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}" + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + transient=True, + ) as progress: + progress.add_task("[cyan]Fetching from CivitAI...", total=None) + + try: + response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) + if response.status_code == 404: + return None + response.raise_for_status() + result: dict[str, Any] = response.json() + return result + except httpx.HTTPStatusError as e: + console.print(f"[red]API error: {e.response.status_code}[/red]") + return None + except httpx.RequestError as e: + console.print(f"[red]Request error: {e}[/red]") + return None + + +def _build_search_params( + query: str | None, + model_type: ModelType | None, + base_model: BaseModel | None, + sort: SortOrder, + limit: int, +) -> tuple[dict[str, Any], bool]: + """Build search parameters and return (params, has_filters).""" + params: dict[str, Any] = { + "limit": min(limit, 100), + "nsfw": "true", + } + + # API quirk: query + filters don't work reliably together + has_filters = model_type is not None or base_model is not None + + if query and not has_filters: + params["query"] = query + + if model_type: + params["types"] = model_type.to_api() + + if base_model: + params["baseModels"] = base_model.to_api() + + params["sort"] = sort.to_api() + + # Request more if we need client-side filtering + if query and has_filters: + params["limit"] = 100 + + return params, has_filters + + +def _filter_results(result: dict[str, Any], query: str | None, has_filters: bool, limit: int) -> dict[str, Any]: + """Apply client-side filtering when query + filters combined.""" + if query and has_filters: + q_lower = query.lower() + result["items"] = [m for m in result.get("items", []) if q_lower in m.get("name", "").lower()][:limit] + return result + + +def search_civitai( + query: str | None, + model_type: ModelType | None, + base_model: BaseModel | None, + sort: SortOrder, + limit: int, + api_key: str | None, + console: Console, +) -> dict[str, Any] | None: + """Search CivitAI models.""" + params, has_filters = _build_search_params(query, model_type, base_model, sort, limit) + url = f"{CIVITAI_API_BASE}/models" + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + transient=True, + ) as progress: + progress.add_task("[cyan]Searching CivitAI...", total=None) + + try: + response = httpx.get(url, params=params, headers=_get_headers(api_key), timeout=30.0) + response.raise_for_status() + result: dict[str, Any] = response.json() + return _filter_results(result, query, has_filters, limit) + except httpx.HTTPStatusError as e: + console.print(f"[red]API error: {e.response.status_code}[/red]") + return None + except httpx.RequestError as e: + console.print(f"[red]Request error: {e}[/red]") + return None + + +def _setup_resume(dest_path: Path, resume: bool, console: Console) -> tuple[dict[str, str], str, int]: + """Set up resume headers and mode for download.""" + headers: dict[str, str] = {} + mode = "wb" + initial_size = 0 + + if resume and dest_path.exists(): + initial_size = dest_path.stat().st_size + headers["Range"] = f"bytes={initial_size}-" + mode = "ab" + console.print(f"[cyan]Resuming download from {initial_size / (1024**2):.1f} MB[/cyan]") + + return headers, mode, initial_size + + +def _get_dest_from_response(response: httpx.Response, dest_path: Path) -> Path: + """Extract destination path from response headers if dest is directory.""" + content_disp = response.headers.get("content-disposition", "") + if "filename=" in content_disp: + match = re.search(r'filename="?([^";\n]+)"?', content_disp) + if match and dest_path.is_dir(): + return dest_path / match.group(1) + return dest_path + + +def _stream_download( + response: httpx.Response, + dest_path: Path, + mode: str, + initial_size: int, + console: Console, +) -> bool: + """Stream download content to file with progress.""" + content_length = response.headers.get("content-length") + total_size = int(content_length) + initial_size if content_length else 0 + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + DownloadColumn(), + TransferSpeedColumn(), + TimeRemainingColumn(), + console=console, + ) as progress: + task = progress.add_task( + f"[cyan]Downloading {dest_path.name}...", + total=total_size if total_size > 0 else None, + completed=initial_size, + ) + + with dest_path.open(mode) as f: + for chunk in response.iter_bytes(1024 * 1024): + f.write(chunk) + progress.update(task, advance=len(chunk)) + + console.print() + console.print(f'[magenta]Downloaded:[/magenta] [green]"{dest_path}"[/green]') + return True + + +def download_model( + version_id: int, + dest_path: Path, + api_key: str | None, + console: Console, + resume: bool = True, +) -> bool: + """Download a model from CivitAI by version ID with resume support.""" + url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}" + params: dict[str, str] = {} + if api_key: + params["token"] = api_key + + headers, mode, initial_size = _setup_resume(dest_path, resume, console) + + try: + with httpx.stream( + "GET", + url, + params=params, + headers=headers, + follow_redirects=True, + timeout=httpx.Timeout(30.0, read=None), + ) as response: + if response.status_code == 416: + console.print("[green]File already fully downloaded.[/green]") + return True + + response.raise_for_status() + dest_path = _get_dest_from_response(response, dest_path) + return _stream_download(response, dest_path, mode, initial_size, console) + + except httpx.HTTPStatusError as e: + console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]") + if e.response.status_code == 401: + console.print("[yellow]Hint: This model may require an API key.[/yellow]") + return False + except httpx.RequestError as e: + console.print(f"[red]Download error: {e}[/red]") + return False diff --git a/tensors/cli.py b/tensors/cli.py new file mode 100644 index 0000000..d0f58f9 --- /dev/null +++ b/tensors/cli.py @@ -0,0 +1,386 @@ +"""CLI application and commands for tsr.""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import Annotated, Any + +import typer +from rich.console import Console +from rich.table import Table + +from tensors.api import ( + download_model, + fetch_civitai_by_hash, + fetch_civitai_model, + fetch_civitai_model_version, + search_civitai, +) +from tensors.config import ( + CONFIG_FILE, + BaseModel, + ModelType, + SortOrder, + get_default_output_path, + load_api_key, + load_config, + save_config, +) +from tensors.display import ( + _format_size, + display_civitai_data, + display_file_info, + display_local_metadata, + display_model_info, + display_search_results, +) +from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata + +app = typer.Typer( + name="tsr", + help="Read safetensor metadata, search and download CivitAI models.", + no_args_is_help=True, +) +console = Console() + + +@app.command() +def info( + file: Annotated[Path, typer.Argument(help="Path to the safetensor file")], + meta: Annotated[list[str] | None, typer.Option("--meta", "-m", help="Show specific metadata key(s) in full")] = None, + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, + skip_civitai: Annotated[bool, typer.Option("--skip-civitai", help="Skip CivitAI API lookup")] = False, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, + save_to: Annotated[Path | None, typer.Option("--save-to", help="Save metadata to directory")] = None, +) -> None: + """Read safetensor metadata and fetch CivitAI info.""" + file_path = file.resolve() + + if not file_path.exists(): + console.print(f"[red]Error: File not found: {file_path}[/red]") + raise typer.Exit(1) + + if file_path.suffix.lower() not in (".safetensors", ".sft"): + console.print("[yellow]Warning: File does not have .safetensors extension[/yellow]") + + try: + local_metadata = read_safetensor_metadata(file_path) + + if meta: + display_local_metadata(local_metadata, console, keys_filter=meta) + return + + console.print(f"[bold]Reading safetensor file:[/bold] {file_path.name}") + sha256_hash = compute_sha256(file_path, console) + + civitai_data = None + if not skip_civitai: + key = api_key or load_api_key() + civitai_data = fetch_civitai_by_hash(sha256_hash, key, console) + + if json_output: + _output_info_json(file_path, sha256_hash, local_metadata, civitai_data) + else: + display_file_info(file_path, local_metadata, sha256_hash, console) + display_local_metadata(local_metadata, console) + display_civitai_data(civitai_data, console) + + if save_to: + _save_metadata(save_to, file_path, sha256_hash, local_metadata, civitai_data) + + except ValueError as e: + console.print(f"[red]Error reading safetensor: {e}[/red]") + raise typer.Exit(1) from e + + +def _output_info_json( + file_path: Path, + sha256_hash: str, + local_metadata: dict[str, Any], + civitai_data: dict[str, Any] | None, +) -> None: + """Output info command result as JSON.""" + output = { + "file": str(file_path), + "sha256": sha256_hash, + "header_size": local_metadata["header_size"], + "tensor_count": local_metadata["tensor_count"], + "metadata": local_metadata["metadata"], + "civitai": civitai_data, + } + console.print_json(data=output) + + +def _save_metadata( + save_to: Path, + file_path: Path, + sha256_hash: str, + local_metadata: dict[str, Any], + civitai_data: dict[str, Any] | None, +) -> None: + """Save metadata to directory.""" + output_dir = save_to.resolve() + if not output_dir.exists() or not output_dir.is_dir(): + console.print(f"[red]Error: Invalid directory: {output_dir}[/red]") + raise typer.Exit(1) + + base_name = get_base_name(file_path) + json_path = output_dir / f"{base_name}.json" + sha_path = output_dir / f"{base_name}.sha256" + + output = { + "file": str(file_path), + "sha256": sha256_hash, + "header_size": local_metadata["header_size"], + "tensor_count": local_metadata["tensor_count"], + "metadata": local_metadata["metadata"], + "civitai": civitai_data, + } + json_path.write_text(json.dumps(output, indent=2)) + sha_path.write_text(f"{sha256_hash} {file_path.name}\n") + + console.print() + console.print(f"[green]Saved:[/green] {json_path}") + console.print(f"[green]Saved:[/green] {sha_path}") + + +@app.command() +def search( + query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None, + model_type: Annotated[ModelType | None, typer.Option("-t", "--type", help="Model type filter")] = None, + base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter")] = None, + sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads, + limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, +) -> None: + """Search CivitAI models.""" + key = api_key or load_api_key() + + results = search_civitai( + query=query, + model_type=model_type, + base_model=base, + sort=sort, + limit=limit, + api_key=key, + console=console, + ) + + if not results: + console.print("[red]Search failed.[/red]") + raise typer.Exit(1) + + if json_output: + console.print_json(data=results) + else: + display_search_results(results, console) + + +@app.command() +def get( + id_value: Annotated[int, typer.Argument(help="CivitAI model ID or version ID")], + version: Annotated[bool, typer.Option("-v", "--version", help="Treat ID as version ID instead of model ID")] = False, + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Fetch model information from CivitAI by model ID or version ID.""" + key = api_key or load_api_key() + + if version: + version_data = fetch_civitai_model_version(id_value, key, console) + if not version_data: + console.print(f"[red]Error: Version {id_value} not found on CivitAI.[/red]") + raise typer.Exit(1) + + if json_output: + console.print_json(data=version_data) + else: + display_civitai_data(version_data, console) + else: + model_data = fetch_civitai_model(id_value, key, console) + if not model_data: + console.print(f"[red]Error: Model {id_value} not found on CivitAI.[/red]") + raise typer.Exit(1) + + if json_output: + console.print_json(data=model_data) + else: + display_model_info(model_data, console) + + +def _resolve_version_id( + version_id: int | None, + hash_val: str | None, + model_id: int | None, + api_key: str | None, +) -> int | None: + """Resolve version ID from hash or model ID.""" + if version_id: + return version_id + + if hash_val: + console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]") + civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key, console) + if not civitai_data: + console.print("[red]Error: Model not found on CivitAI for this hash.[/red]") + return None + vid: int | None = civitai_data.get("id") + if vid: + console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}") + return vid + + if model_id: + console.print(f"[cyan]Looking up model {model_id}...[/cyan]") + model_data = fetch_civitai_model(model_id, api_key, console) + if not model_data: + console.print(f"[red]Error: Model {model_id} not found.[/red]") + return None + versions = model_data.get("modelVersions", []) + if not versions: + console.print("[red]Error: Model has no versions.[/red]") + return None + latest = versions[0] + latest_vid: int | None = latest.get("id") + if latest_vid: + name = latest.get("name", "N/A") + console.print(f"[green]Found latest:[/green] {name} (ID: {latest_vid})") + return latest_vid + + return None + + +def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Path | None: + """Prepare output directory for download.""" + if output is None: + output_dir = get_default_output_path(model_type_str) + if output_dir is None: + console.print(f"[red]Error: No default path for type '{model_type_str}'. Use --output to specify.[/red]") + return None + console.print(f"[dim]Using default path for {model_type_str}: {output_dir}[/dim]") + else: + output_dir = output.resolve() + + if not output_dir.exists(): + console.print(f"[cyan]Creating directory: {output_dir}[/cyan]") + output_dir.mkdir(parents=True, exist_ok=True) + elif not output_dir.is_dir(): + console.print(f"[red]Error: Not a directory: {output_dir}[/red]") + return None + + return output_dir + + +@app.command("dl") +def download( + version_id: Annotated[int | None, typer.Option("-v", "--version-id", help="Model version ID")] = None, + model_id: Annotated[int | None, typer.Option("-m", "--model-id", help="Model ID (downloads latest)")] = None, + hash_val: Annotated[str | None, typer.Option("-H", "--hash", help="SHA256 hash to look up")] = None, + output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None, + no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False, + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, +) -> None: + """Download a model from CivitAI.""" + key = api_key or load_api_key() + + resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key) + if not resolved_version_id: + if not version_id and not hash_val and not model_id: + console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") + raise typer.Exit(1) + + console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]") + version_info = fetch_civitai_model_version(resolved_version_id, key, console) + if not version_info: + console.print("[red]Error: Could not fetch model version info.[/red]") + raise typer.Exit(1) + + model_type_str: str | None = version_info.get("model", {}).get("type") + output_dir = _prepare_download_dir(output, model_type_str) + if not output_dir: + raise typer.Exit(1) + + files: list[dict[str, Any]] = version_info.get("files", []) + primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) + if not primary_file: + console.print("[red]Error: No files found for this version.[/red]") + raise typer.Exit(1) + + filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors") + dest_path = output_dir / filename + + _display_download_info(version_info, filename, primary_file, dest_path) + + success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume) + if not success: + raise typer.Exit(1) + + +def _display_download_info( + version_info: dict[str, Any], + filename: str, + primary_file: dict[str, Any], + dest_path: Path, +) -> None: + """Display download info table.""" + table = Table(title="Model Download", show_header=True, header_style="bold magenta") + table.add_column("Property", style="cyan") + table.add_column("Value", style="green") + table.add_row("Version", version_info.get("name", "N/A")) + table.add_row("Base Model", version_info.get("baseModel", "N/A")) + table.add_row("File", filename) + table.add_row("Size", _format_size(primary_file.get("sizeKB", 0))) + table.add_row("Destination", str(dest_path)) + console.print() + console.print(table) + console.print() + + +@app.command() +def config( + show: Annotated[bool, typer.Option("--show", help="Show current config")] = False, + set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None, +) -> None: + """Manage configuration.""" + if set_key: + cfg = load_config() + if "api" not in cfg: + cfg["api"] = {} + cfg["api"]["civitai_key"] = set_key + save_config(cfg) + console.print(f"[green]API key saved to {CONFIG_FILE}[/green]") + return + + if show or (not set_key): + console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}") + console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}") + + key = load_api_key() + if key: + masked = key[:4] + "..." + key[-4:] if len(key) > 8 else "***" + console.print(f"[bold]API key:[/bold] {masked}") + else: + console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]") + + console.print() + console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]") + + +def main() -> int: + """Main entry point.""" + # Handle legacy invocation: tsr -> tsr info + if len(sys.argv) > 1 and not sys.argv[1].startswith("-"): + arg = sys.argv[1] + if arg not in ("info", "search", "get", "dl", "download", "config") and ( + arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists() + ): + sys.argv = [sys.argv[0], "info", *sys.argv[1:]] + + app() + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tensors/config.py b/tensors/config.py new file mode 100644 index 0000000..9dd14bd --- /dev/null +++ b/tensors/config.py @@ -0,0 +1,166 @@ +"""Configuration, constants, and enums for tsr CLI.""" + +from __future__ import annotations + +import os +import tomllib +from enum import Enum +from pathlib import Path +from typing import Any + +# ============================================================================ +# XDG Base Directory Configuration +# ============================================================================ + +# Config: ~/.config/tensors/config.toml +# Data: ~/.local/share/tensors/models/, ~/.local/share/tensors/metadata/ +CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config")) / "tensors" +CONFIG_FILE = CONFIG_DIR / "config.toml" + +DATA_DIR = Path(os.environ.get("XDG_DATA_HOME", Path.home() / ".local" / "share")) / "tensors" +MODELS_DIR = DATA_DIR / "models" +METADATA_DIR = DATA_DIR / "metadata" + +# Legacy config for migration +LEGACY_RC_FILE = Path.home() / ".sftrc" + +# Default download paths by model type +DEFAULT_PATHS: dict[str, Path] = { + "Checkpoint": MODELS_DIR / "checkpoints", + "LORA": MODELS_DIR / "loras", + "LoCon": MODELS_DIR / "loras", +} + +CIVITAI_API_BASE = "https://civitai.com/api/v1" +CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models" + + +# ============================================================================ +# Enums for CLI +# ============================================================================ + + +class ModelType(str, Enum): + """CivitAI model types.""" + + checkpoint = "checkpoint" + lora = "lora" + embedding = "embedding" + vae = "vae" + controlnet = "controlnet" + locon = "locon" + + def to_api(self) -> str: + """Convert to CivitAI API value.""" + mapping = { + "checkpoint": "Checkpoint", + "lora": "LORA", + "embedding": "TextualInversion", + "vae": "VAE", + "controlnet": "Controlnet", + "locon": "LoCon", + } + return mapping[self.value] + + +class BaseModel(str, Enum): + """Common base models.""" + + sd15 = "sd15" + sdxl = "sdxl" + pony = "pony" + flux = "flux" + illustrious = "illustrious" + + def to_api(self) -> str: + """Convert to CivitAI API value.""" + mapping = { + "sd15": "SD 1.5", + "sdxl": "SDXL 1.0", + "pony": "Pony", + "flux": "Flux.1 D", + "illustrious": "Illustrious", + } + return mapping[self.value] + + +class SortOrder(str, Enum): + """Sort options for search.""" + + downloads = "downloads" + rating = "rating" + newest = "newest" + + def to_api(self) -> str: + """Convert to CivitAI API value.""" + mapping = { + "downloads": "Most Downloaded", + "rating": "Highest Rated", + "newest": "Newest", + } + return mapping[self.value] + + +# ============================================================================ +# Config Functions +# ============================================================================ + + +def load_config() -> dict[str, Any]: + """Load configuration from TOML config file.""" + if CONFIG_FILE.exists(): + with CONFIG_FILE.open("rb") as f: + return tomllib.load(f) + return {} + + +def save_config(config: dict[str, Any]) -> None: + """Save configuration to TOML config file.""" + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + + lines: list[str] = [] + for key, value in config.items(): + if isinstance(value, dict): + lines.append(f"[{key}]") + for k, v in value.items(): + if isinstance(v, str): + lines.append(f'{k} = "{v}"') + else: + lines.append(f"{k} = {v}") + lines.append("") + elif isinstance(value, str): + lines.append(f'{key} = "{value}"') + else: + lines.append(f"{key} = {value}") + + CONFIG_FILE.write_text("\n".join(lines) + "\n") + + +def load_api_key() -> str | None: + """Load API key from config file or CIVITAI_API_KEY env var.""" + # Check environment variable first + env_key = os.environ.get("CIVITAI_API_KEY") + if env_key: + return env_key + + # Check TOML config file + config = load_config() + api_section = config.get("api", {}) + if isinstance(api_section, dict): + key = api_section.get("civitai_key") + if key: + return str(key) + + # Fall back to legacy RC file for migration + if LEGACY_RC_FILE.exists(): + content = LEGACY_RC_FILE.read_text().strip() + if content: + return content + return None + + +def get_default_output_path(model_type: str | None) -> Path | None: + """Get default output path based on model type.""" + if model_type and model_type in DEFAULT_PATHS: + return DEFAULT_PATHS[model_type] + return None diff --git a/tensors/display.py b/tensors/display.py new file mode 100644 index 0000000..9f5d64d --- /dev/null +++ b/tensors/display.py @@ -0,0 +1,324 @@ +"""Rich table display functions for tsr CLI.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pathlib import Path + +from rich.table import Table + +if TYPE_CHECKING: + from rich.console import Console + + +def _format_size(size_kb: float) -> str: + """Format size in KB to human-readable string.""" + if size_kb < 1024: + return f"{size_kb:.0f} KB" + if size_kb < 1024 * 1024: + return f"{size_kb / 1024:.1f} MB" + return f"{size_kb / 1024 / 1024:.2f} GB" + + +def _format_count(count: int) -> str: + """Format large numbers with K/M suffix.""" + if count < 1000: + return str(count) + if count < 1_000_000: + return f"{count / 1000:.1f}K" + return f"{count / 1_000_000:.1f}M" + + +def display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str, console: Console) -> None: + """Display file information table.""" + prop_width = 12 + + file_table = Table(title="File Information", show_header=True, header_style="bold magenta", expand=True) + file_table.add_column("Property", style="cyan", width=prop_width, no_wrap=True) + file_table.add_column("Value", style="green", no_wrap=True, overflow="ellipsis") + + file_table.add_row("File", str(file_path.name)) + file_table.add_row("Path", str(file_path.parent)) + file_table.add_row("Size", f"{file_path.stat().st_size / (1024**3):.2f} GB") + file_table.add_row("SHA256", sha256_hash) + file_table.add_row("Header Size", f"{local_metadata['header_size']:,} bytes") + file_table.add_row("Tensor Count", str(local_metadata["tensor_count"])) + + console.print() + console.print(file_table) + + +def display_local_metadata(local_metadata: dict[str, Any], console: Console, keys_filter: list[str] | None = None) -> None: + """Display local safetensor metadata table.""" + if not local_metadata["metadata"]: + console.print() + console.print("[yellow]No embedded metadata found in safetensor file.[/yellow]") + return + + metadata = local_metadata["metadata"] + + # If specific keys requested, show them in full + if keys_filter: + for key in keys_filter: + if key in metadata: + console.print(f"[cyan]{key}[/cyan]: {metadata[key]}") + else: + console.print(f"[yellow]{key}: not found[/yellow]") + return + + # Find the longest key to set column width + all_keys = list(metadata.keys()) + key_width = max(len(k) for k in all_keys) if all_keys else 20 + + # Value width: terminal minus key column and table borders (7 chars) + terminal_width = console.size.width + value_width = terminal_width - key_width - 7 + + meta_table = Table( + title="Safetensor Metadata", + show_header=True, + header_style="bold magenta", + ) + meta_table.add_column("Key", style="cyan", width=key_width, no_wrap=True) + meta_table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis") + + for key, value in sorted(metadata.items()): + meta_table.add_row(key, str(value)) + + console.print() + console.print(meta_table) + + +def _build_civitai_table(console: Console) -> tuple[Table, int]: + """Build CivitAI info table with proper column widths.""" + prop_width = 14 + terminal_width = console.size.width + overhead = 7 + value_width = max(40, terminal_width - prop_width - overhead) + + table = Table(title="CivitAI Model Information", show_header=True, header_style="bold magenta") + table.add_column("Property", style="cyan", width=prop_width, no_wrap=True) + table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis") + + return table, value_width + + +def display_civitai_data(civitai_data: dict[str, Any] | None, console: Console) -> None: + """Display CivitAI model information table.""" + if not civitai_data: + console.print() + console.print("[yellow]Model not found on CivitAI.[/yellow]") + return + + civit_table, _ = _build_civitai_table(console) + + civit_table.add_row("Model ID", str(civitai_data.get("modelId", "N/A"))) + civit_table.add_row("Version ID", str(civitai_data.get("id", "N/A"))) + civit_table.add_row("Version Name", str(civitai_data.get("name", "N/A"))) + civit_table.add_row("Base Model", str(civitai_data.get("baseModel", "N/A"))) + civit_table.add_row("Created At", str(civitai_data.get("createdAt", "N/A"))) + + trained_words: list[str] = civitai_data.get("trainedWords", []) + if trained_words: + civit_table.add_row("Trigger Words", ", ".join(trained_words)) + + download_url = str(civitai_data.get("downloadUrl", "N/A")) + civit_table.add_row("Download URL", download_url) + + files: list[dict[str, Any]] = civitai_data.get("files", []) + for f in files: + if f.get("primary"): + civit_table.add_row("Primary File", str(f.get("name", "N/A"))) + civit_table.add_row("File Size", _format_size(f.get("sizeKB", 0))) + meta: dict[str, Any] = f.get("metadata", {}) + if meta: + civit_table.add_row("Format", str(meta.get("format", "N/A"))) + civit_table.add_row("Precision", str(meta.get("fp", "N/A"))) + civit_table.add_row("Size Type", str(meta.get("size", "N/A"))) + + console.print() + console.print(civit_table) + + model_id = civitai_data.get("modelId") + if model_id: + console.print() + console.print(f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}") + + +def _build_model_table(console: Console) -> Table: + """Build model info table with proper column widths.""" + prop_width = 10 + terminal_width = console.size.width + overhead = 7 + value_width = max(40, terminal_width - prop_width - overhead) + + table = Table(title="Model Information", show_header=True, header_style="bold magenta") + table.add_column("Property", style="cyan", width=prop_width, no_wrap=True) + table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis") + + return table + + +def _add_model_basic_info(table: Table, model_data: dict[str, Any]) -> None: + """Add basic model info rows to table.""" + table.add_row("ID", str(model_data.get("id", "N/A"))) + table.add_row("Name", str(model_data.get("name", "N/A"))) + table.add_row("Type", str(model_data.get("type", "N/A"))) + table.add_row("NSFW", str(model_data.get("nsfw", False))) + + creator = model_data.get("creator", {}) + if creator: + table.add_row("Creator", str(creator.get("username", "N/A"))) + + tags: list[str] = model_data.get("tags", []) + if tags: + table.add_row("Tags", ", ".join(tags[:10]) + ("..." if len(tags) > 10 else "")) + + stats: dict[str, Any] = model_data.get("stats", {}) + if stats: + table.add_row("Downloads", f"{stats.get('downloadCount', 0):,}") + table.add_row("Likes", f"{stats.get('thumbsUpCount', 0):,}") + + mode = model_data.get("mode") + if mode: + table.add_row("Status", str(mode)) + + +def _build_versions_table(console: Console) -> Table: + """Build model versions table with proper column widths.""" + id_width = 7 + base_width = 20 + created_width = 10 + size_width = 8 + + terminal_width = console.size.width + fixed_width = id_width + base_width + created_width + size_width + overhead = 20 + remaining = max(40, terminal_width - fixed_width - overhead) + name_width = remaining // 3 + file_width = remaining - name_width + + table = Table(title="Model Versions", show_header=True, header_style="bold magenta") + table.add_column("ID", style="cyan", width=id_width, no_wrap=True) + table.add_column("Name", style="green", width=name_width, no_wrap=True, overflow="ellipsis") + table.add_column("Base Model", style="yellow", width=base_width, no_wrap=True, overflow="ellipsis") + table.add_column("Created", style="blue", width=created_width, no_wrap=True) + table.add_column("Filename", style="white", width=file_width, no_wrap=True, overflow="ellipsis") + table.add_column("Size", justify="right", width=size_width, no_wrap=True) + + return table + + +def _add_version_rows(table: Table, versions: list[dict[str, Any]]) -> None: + """Add version rows to versions table.""" + for ver in versions: + files: list[dict[str, Any]] = ver.get("files", []) + primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) + filename = "N/A" + size = "N/A" + if primary_file: + filename = primary_file.get("name", "N/A") + size = _format_size(primary_file.get("sizeKB", 0)) + + created = str(ver.get("createdAt", "N/A"))[:10] + table.add_row( + str(ver.get("id", "N/A")), + str(ver.get("name", "N/A")), + str(ver.get("baseModel", "N/A")), + created, + filename, + size, + ) + + +def display_model_info(model_data: dict[str, Any], console: Console) -> None: + """Display full CivitAI model information.""" + model_table = _build_model_table(console) + _add_model_basic_info(model_table, model_data) + + console.print() + console.print(model_table) + + versions: list[dict[str, Any]] = model_data.get("modelVersions", []) + if versions: + ver_table = _build_versions_table(console) + _add_version_rows(ver_table, versions) + console.print() + console.print(ver_table) + + model_id = model_data.get("id") + if model_id: + console.print() + console.print(f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}") + + +def _build_search_table(console: Console) -> Table: + """Build search results table with proper column widths.""" + id_width = 7 + type_width = 16 + base_width = 20 + size_width = 8 + dls_width = 6 + likes_width = 6 + + terminal_width = console.size.width + fixed_width = id_width + type_width + base_width + size_width + dls_width + likes_width + overhead = 23 + name_width = max(20, terminal_width - fixed_width - overhead) + + table = Table(show_header=True, header_style="bold magenta") + table.add_column("ID", style="cyan", justify="right", width=id_width, no_wrap=True) + table.add_column("Name", style="green", width=name_width, no_wrap=True, overflow="ellipsis") + table.add_column("Type", style="yellow", width=type_width, no_wrap=True) + table.add_column("Base", style="blue", width=base_width, no_wrap=True, overflow="ellipsis") + table.add_column("Size", justify="right", width=size_width, no_wrap=True) + table.add_column("DLs", justify="right", width=dls_width, no_wrap=True) + table.add_column("Likes", justify="right", width=likes_width, no_wrap=True) + + return table + + +def _add_search_rows(table: Table, items: list[dict[str, Any]]) -> None: + """Add search result rows to table.""" + for model in items: + model_id = str(model.get("id", "")) + name = model.get("name", "N/A") + model_type = model.get("type", "N/A") + + versions = model.get("modelVersions", []) + base_model = "N/A" + size = "N/A" + if versions: + latest = versions[0] + base_model = latest.get("baseModel", "N/A") + files = latest.get("files", []) + primary = next((f for f in files if f.get("primary")), files[0] if files else None) + if primary: + size = _format_size(primary.get("sizeKB", 0)) + + stats = model.get("stats", {}) + downloads = _format_count(stats.get("downloadCount", 0)) + likes = _format_count(stats.get("thumbsUpCount", 0)) + + table.add_row(model_id, name, model_type, base_model, size, downloads, likes) + + +def display_search_results(results: dict[str, Any], console: Console) -> None: + """Display search results in a table.""" + items = results.get("items", []) + if not items: + console.print("[yellow]No results found.[/yellow]") + return + + table = _build_search_table(console) + _add_search_rows(table, items) + + console.print() + console.print(table) + + metadata = results.get("metadata", {}) + total = metadata.get("totalItems", len(items)) + console.print(f"\n[dim]Showing {len(items)} of {total:,} results[/dim]") + console.print("[dim]Use 'tsr get ' to view details or 'tsr dl -m ' to download[/dim]") diff --git a/tensors/safetensor.py b/tensors/safetensor.py new file mode 100644 index 0000000..899ca38 --- /dev/null +++ b/tensors/safetensor.py @@ -0,0 +1,92 @@ +"""Safetensor file reading functions.""" + +from __future__ import annotations + +import hashlib +import json +import struct +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pathlib import Path + +from rich.progress import ( + BarColumn, + DownloadColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeRemainingColumn, + TransferSpeedColumn, +) + +if TYPE_CHECKING: + from rich.console import Console + + +def read_safetensor_metadata(file_path: Path) -> dict[str, Any]: + """Read metadata from a safetensor file header.""" + with file_path.open("rb") as f: + # First 8 bytes are the header size (little-endian u64) + header_size_bytes = f.read(8) + if len(header_size_bytes) < 8: + raise ValueError("Invalid safetensor file: too short") + + header_size = struct.unpack(" 100_000_000: # 100MB sanity check + raise ValueError(f"Invalid header size: {header_size}") + + header_bytes = f.read(header_size) + if len(header_bytes) < header_size: + raise ValueError("Invalid safetensor file: header truncated") + + header: dict[str, Any] = json.loads(header_bytes.decode("utf-8")) + + # Extract __metadata__ if present + metadata: dict[str, Any] = header.get("__metadata__", {}) + + # Count tensors (keys that aren't __metadata__) + tensor_count = sum(1 for k in header if k != "__metadata__") + + return { + "metadata": metadata, + "tensor_count": tensor_count, + "header_size": header_size, + } + + +def compute_sha256(file_path: Path, console: Console) -> str: + """Compute SHA256 hash of a file with progress display.""" + file_size = file_path.stat().st_size + sha256 = hashlib.sha256() + chunk_size = 1024 * 1024 * 8 # 8MB chunks + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + DownloadColumn(), + TransferSpeedColumn(), + TimeRemainingColumn(), + console=console, + ) as progress: + task = progress.add_task(f"[cyan]Hashing {file_path.name}...", total=file_size) + + with file_path.open("rb") as f: + while chunk := f.read(chunk_size): + sha256.update(chunk) + progress.update(task, advance=len(chunk)) + + return sha256.hexdigest().upper() + + +def get_base_name(file_path: Path) -> str: + """Get base filename without .safetensors extension.""" + name = file_path.name + for ext in (".safetensors", ".sft"): + if name.lower().endswith(ext): + return name[: -len(ext)] + return file_path.stem diff --git a/tests/test_tensors.py b/tests/test_tensors.py index 9672d6c..134350d 100644 --- a/tests/test_tensors.py +++ b/tests/test_tensors.py @@ -7,13 +7,9 @@ from pathlib import Path import pytest -import tensors -from tensors import ( - get_base_name, - get_default_output_path, - load_api_key, - read_safetensor_metadata, -) +from tensors import config +from tensors.config import get_default_output_path, load_api_key +from tensors.safetensor import get_base_name, read_safetensor_metadata class TestReadSafetensorMetadata: @@ -111,28 +107,24 @@ class TestLoadApiKey: """Test that None is returned when no key is available.""" monkeypatch.delenv("CIVITAI_API_KEY", raising=False) # Point config and legacy files to nonexistent paths - monkeypatch.setattr(tensors, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml") - monkeypatch.setattr(tensors, "LEGACY_RC_FILE", tmp_path / "nonexistent") + monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml") + monkeypatch.setattr(config, "LEGACY_RC_FILE", tmp_path / "nonexistent") assert load_api_key() is None - def test_returns_key_from_config_file( - self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path - ) -> None: + def test_returns_key_from_config_file(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: """Test that key is loaded from TOML config file.""" monkeypatch.delenv("CIVITAI_API_KEY", raising=False) config_file = tmp_path / "config.toml" config_file.write_text('[api]\ncivitai_key = "key-from-config"\n') - monkeypatch.setattr(tensors, "CONFIG_FILE", config_file) - monkeypatch.setattr(tensors, "LEGACY_RC_FILE", tmp_path / "nonexistent") + monkeypatch.setattr(config, "CONFIG_FILE", config_file) + monkeypatch.setattr(config, "LEGACY_RC_FILE", tmp_path / "nonexistent") assert load_api_key() == "key-from-config" - def test_returns_key_from_legacy_file( - self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path - ) -> None: + def test_returns_key_from_legacy_file(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: """Test that key is loaded from legacy RC file when no config exists.""" monkeypatch.delenv("CIVITAI_API_KEY", raising=False) legacy_file = tmp_path / ".sftrc" legacy_file.write_text("legacy-key") - monkeypatch.setattr(tensors, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml") - monkeypatch.setattr(tensors, "LEGACY_RC_FILE", legacy_file) + monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml") + monkeypatch.setattr(config, "LEGACY_RC_FILE", legacy_file) assert load_api_key() == "legacy-key"