Files
tensors/tensors/fragments.py
T
2026-05-18 00:20:10 +00:00

175 lines
6.3 KiB
Python

"""Generic prompt-fragment library.
A *fragment* is a named, ordered list of comma-style prompt elements
(e.g. ``["blond hair", "broad chin"]``) stored as a flat YAML list on disk.
Different *kinds* of fragments (characters, scenes, …) each get their own
subdirectory under ``~/.local/share/tensors/<kind>/`` and their own
``FragmentLibrary`` instance.
Files are written as JSON-encoded YAML scalars (``- "value"``) so they round-trip
through both ``json`` and any YAML parser. Hand-edited single-quoted or bare
scalars are accepted on read.
"""
from __future__ import annotations
import json
import re
from pathlib import Path # noqa: TC003 # used in runtime return annotations exposed to typer
from tensors.config import DATA_DIR
# Restrict fragment names (and library kinds) to a safe subset so they can't
# escape the storage dir via path traversal and so file listings stay tidy.
# ``\A``/``\Z`` anchor the whole string: a ``$`` anchor would also accept a
# trailing newline (``"foo\n"``).
_NAME_RE = re.compile(r"\A[A-Za-z0-9_.-]+\Z")
# Minimum length for a quoted YAML scalar: opening + closing quote.
_MIN_QUOTED_SCALAR_LEN = 2
# Characters that ``json.dumps(..., ensure_ascii=False)`` leaves unescaped but
# that ``str.splitlines`` (and YAML) treat as line breaks. Escaping them keeps
# every saved element on exactly one physical line.
_LINE_BREAK_ESCAPES = str.maketrans({"\x85": "\\u0085", "\u2028": "\\u2028", "\u2029": "\\u2029"})


def _parse_list_item(raw: str) -> str | None:
    """Decode one line of a fragment file.

    Returns the element text, or ``None`` when the line is not a list entry
    (blank line, ``#`` comment, YAML document marker, other non-list content,
    or a bare ``-`` with no value).
    """
    line = raw.strip()
    if not line or line.startswith("#"):
        return None
    # A YAML sequence entry is "-" followed by whitespace or end of line. Merely
    # starting with "-" is not enough: "---" (document header) must be skipped.
    if line[0] != "-" or line[1:2] not in ("", " ", "\t"):
        return None
    item = line[1:].lstrip()
    if not item:
        return None
    try:
        value = json.loads(item)
    except json.JSONDecodeError:
        # Single-quoted YAML scalar: strip the quotes and unescape doubled ''.
        if len(item) >= _MIN_QUOTED_SCALAR_LEN and item[0] == item[-1] == "'":
            return item[1:-1].replace("''", "'")
        # Bare scalar: keep the text as written.
        return item
    # Non-string JSON values (true, null, 1.50, ...) keep their literal text so a
    # bare "true" stays "true" instead of becoming Python's "True".
    return value if isinstance(value, str) else item


class FragmentLibrary:
    """A named collection of prompt-fragment YAML files of a single *kind*.

    Each ``FragmentLibrary`` is rooted at ``<DATA_DIR>/<kind>/`` unless an
    explicit ``base_dir`` is supplied (e.g. by tests). Instance methods are
    stateless wrappers around that directory; no file contents are cached.
    """

    def __init__(self, kind: str, base_dir: Path | None = None) -> None:
        """Create a library for ``kind`` (e.g. ``"characters"`` or ``"scenes"``).

        Args:
            kind: Subdirectory name. Must consist of letters, digits, ``.``,
                ``_`` or ``-`` and may not be ``"."`` or ``".."``, which would
                point at ``DATA_DIR`` itself or its parent.
            base_dir: Explicit storage directory. When omitted, ``base_dir`` is
                derived as ``DATA_DIR / kind`` on every access, so tests can
                monkeypatch this module's ``DATA_DIR`` *or* assign the
                ``base_dir`` attribute directly without re-importing.

        Raises:
            ValueError: If ``kind`` is empty or not a safe name.
        """
        if not kind or kind in {".", ".."} or not _NAME_RE.fullmatch(kind):
            raise ValueError(f"Invalid library kind {kind!r}")
        self.kind = kind
        # None means "derive from DATA_DIR at access time" (see ``base_dir``).
        self._base_dir_override: Path | None = base_dir

    @property
    def base_dir(self) -> Path:
        """Storage directory: the explicit override if set, else ``DATA_DIR / kind``.

        ``DATA_DIR`` is read from this module's globals on each access.
        """
        if self._base_dir_override is not None:
            return self._base_dir_override
        return DATA_DIR / self.kind

    @base_dir.setter
    def base_dir(self, value: Path | None) -> None:
        # Assigning None restores the DATA_DIR-derived default.
        self._base_dir_override = value

    # ---------- internals ----------

    @property
    def _singular(self) -> str:
        """Human-readable singular form of ``kind`` used in error messages."""
        return self.kind[:-1] if self.kind.endswith("s") else self.kind

    def _validate_name(self, name: str) -> None:
        """Raise ``ValueError`` unless ``name`` is a safe fragment name."""
        if not name or not _NAME_RE.fullmatch(name):
            raise ValueError(f"Invalid {self._singular} name {name!r}: only letters, digits, '.', '_', '-' allowed")

    def path(self, name: str) -> Path:
        """Return the on-disk path for ``name`` (without ensuring it exists).

        Raises:
            ValueError: If ``name`` is not a valid fragment name.
        """
        self._validate_name(name)
        return self.base_dir / f"{name}.yml"

    # ---------- CRUD ----------

    def save(self, name: str, elements: list[str]) -> Path:
        """Persist ``elements`` to disk and return the file path.

        Overwrites any existing file. Each element is written on its own line as
        a JSON-encoded YAML scalar so embedded commas, quotes and unicode are
        safe. The file is encoded as UTF-8; an empty ``elements`` list produces
        an empty file.

        Raises:
            ValueError: If ``name`` is not a valid fragment name.
        """
        # Validate the name before creating directories so a bad name leaves
        # the filesystem untouched.
        path = self.path(name)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        body = "".join(
            f"- {json.dumps(element, ensure_ascii=False).translate(_LINE_BREAK_ESCAPES)}\n"
            for element in elements
        )
        # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters that a
        # platform default encoding might not be able to represent.
        path.write_text(body, encoding="utf-8")
        return path

    def load(self, name: str) -> list[str]:
        """Load a fragment. Raises ``FileNotFoundError`` if missing.

        Accepts JSON-quoted scalars (``- "value"``), single-quoted YAML scalars
        (``- 'value'``) and bare scalars (``- value``). Blank lines, ``#``
        comments, YAML document markers (``---``) and other non-list lines are
        ignored.

        Raises:
            ValueError: If ``name`` is not a valid fragment name.
            FileNotFoundError: If no file exists for ``name``.
        """
        path = self.path(name)
        if not path.is_file():
            raise FileNotFoundError(f"{self._singular.capitalize()} {name!r} not found at {path}")
        # "utf-8-sig" also drops a leading BOM some editors add; otherwise the
        # first entry would not start with "-" and would be silently skipped.
        text = path.read_text(encoding="utf-8-sig")
        elements: list[str] = []
        for raw in text.splitlines():
            value = _parse_list_item(raw)
            if value is not None:
                elements.append(value)
        return elements

    def list(self) -> list[str]:
        """Return sorted names of loadable fragments.

        Returns an empty list if the directory doesn't exist yet. ``*.yml`` files
        whose stem is not a valid name are omitted, since ``load`` would reject them.
        """
        base = self.base_dir
        if not base.is_dir():
            return []
        return sorted(p.stem for p in base.glob("*.yml") if p.is_file() and _NAME_RE.fullmatch(p.stem))

    def delete(self, name: str) -> bool:
        """Delete a fragment. Returns True on success, False if it was missing.

        Raises:
            ValueError: If ``name`` is not a valid fragment name.
        """
        path = self.path(name)
        if not path.is_file():
            return False
        path.unlink()
        return True

    # ---------- helpers ----------

    def resolve(
        self,
        *,
        name: str | None = None,
        inline: str | None = None,
        extra: list[str] | None = None,
    ) -> list[str]:
        """Merge a named fragment with an inline CSV string and optional extras.

        Resolution order (first match wins per duplicate): named → inline → extra.
        Result preserves order and drops duplicates and empty strings.

        Raises:
            FileNotFoundError: If ``name`` is given but no such fragment exists.
        """
        merged: list[str] = []
        seen: set[str] = set()

        def _push(items: list[str]) -> None:
            for item in items:
                if item and item not in seen:
                    seen.add(item)
                    merged.append(item)

        if name:
            _push(self.load(name))
        if inline:
            _push(parse_elements(inline))
        if extra:
            _push(extra)
        return merged
def parse_elements(value: str) -> list[str]:
    """Turn a comma-separated prompt string into an ordered list of unique elements.

    Each comma-delimited piece is whitespace-trimmed; blank pieces are discarded
    and only the first occurrence of a repeated piece is kept. This helper backs
    both ``tsr character|scene save`` and the ``--character-prompt`` /
    ``--scene-prompt`` CLI flags, so every surface splits prompts the same way.
    """
    trimmed = (piece.strip() for piece in value.split(","))
    # dict preserves insertion order, so fromkeys() deduplicates while keeping
    # the first-seen position of each element.
    unique = dict.fromkeys(piece for piece in trimmed if piece)
    return list(unique)