22 changes: 22 additions & 0 deletions photomap/backend/config.py
@@ -149,6 +149,21 @@ class Config(BaseModel):
"placed in. None means Uncategorized."
),
)
encoder_idle_timeout_seconds: float = Field(
default=30.0,
description=(
"Seconds of search inactivity before the cached encoder is moved "
"from GPU VRAM to system RAM. Reload from RAM on the next query is "
"fast (no disk/network fetch). Set to 0 to disable offloading."
),
)

@field_validator("encoder_idle_timeout_seconds")
@classmethod
def validate_encoder_idle_timeout_seconds(cls, v: float) -> float:
if v < 0:
raise ValueError("encoder_idle_timeout_seconds must be >= 0")
return float(v)

@field_validator("config_version")
@classmethod
@@ -180,6 +195,7 @@ def to_dict(self) -> dict[str, Any]:
"invokeai_username": self.invokeai_username,
"invokeai_password": self.invokeai_password,
"invokeai_board_id": self.invokeai_board_id,
"encoder_idle_timeout_seconds": self.encoder_idle_timeout_seconds,
}


@@ -291,6 +307,11 @@ def load_config(self) -> Config:
for key, album_data in config_data.get("albums", {}).items():
albums[key] = Album.from_dict(key, album_data)

extra: dict[str, Any] = {}
if "encoder_idle_timeout_seconds" in config_data:
extra["encoder_idle_timeout_seconds"] = config_data[
"encoder_idle_timeout_seconds"
]
self._config = Config(
config_version=config_data.get("config_version", "1.0.0"),
albums=albums,
@@ -299,6 +320,7 @@
invokeai_username=config_data.get("invokeai_username"),
invokeai_password=config_data.get("invokeai_password"),
invokeai_board_id=config_data.get("invokeai_board_id"),
**extra,
)

except Exception as e:
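
Taken together, the three config.py hunks give the new setting a default, a validation guard, and a serialization path. A minimal sketch of the resulting behavior, assuming ``config_version`` and ``albums`` are the only required ``Config`` fields:

from pydantic import ValidationError

from photomap.backend.config import Config

cfg = Config(config_version="1.0.0", albums={})
assert cfg.encoder_idle_timeout_seconds == 30.0  # Field default
assert "encoder_idle_timeout_seconds" in cfg.to_dict()  # survives save/load round-trips

try:
    Config(config_version="1.0.0", albums={}, encoder_idle_timeout_seconds=-1.0)
except ValidationError:
    pass  # the field_validator rejects negative timeouts; 0 is allowed and disables offloading
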
248 changes: 206 additions & 42 deletions photomap/backend/encoders.py
@@ -12,14 +12,18 @@

from __future__ import annotations

import logging
import math
import threading
import time
from abc import ABC, abstractmethod

import numpy as np
import torch
from PIL import Image

logger = logging.getLogger(__name__)

# Default encoder for *new* albums. OpenCLIP-DFN ViT-L-14 is the best
# general-purpose pick across our three backends: noticeably stronger recall
# than legacy CLIP, with CLIP-style cosine semantics that work well on
@@ -98,6 +102,68 @@ def close(self) -> None: # noqa: B027
Default no-op so subclasses without weights to free aren't forced to override.
"""

# --- Offload / reload --------------------------------------------------
# Encoders cached for repeated search use can sit idle for a long time
# while still pinning multiple GB of VRAM. ``offload()`` moves the model
# weights to host RAM so other GPU workloads can run; the next call to an
# ``encode_*`` method transparently reloads them via ``_ensure_on_device``.
# Reload from RAM is sub-second — much cheaper than the initial Hub fetch
# and tensor allocation that ``build_encoder`` pays.

def _device_lock(self) -> threading.RLock:
"""Per-instance reentrant lock guarding offload/reload + encode.

Lazily created so subclasses that don't call ``super().__init__()``
(all of the bundled ones) still get a working lock.
"""
lock = self.__dict__.get("_offload_lock")
if lock is None:
lock = threading.RLock()
self.__dict__["_offload_lock"] = lock
return lock

@property
def is_offloaded(self) -> bool:
return bool(self.__dict__.get("_offloaded", False))

def offload(self) -> None:
"""Move model weights from VRAM to host RAM.

No-op when already offloaded, when running on CPU (nothing to free),
or when the subclass holds no ``_model`` attribute. Safe to call
concurrently with encode calls — the encode wrappers take the same
reentrant lock and will not be interrupted.
"""
if not self.device.startswith("cuda"):
return
with self._device_lock():
if self.is_offloaded:
return
model = getattr(self, "_model", None)
if model is None:
return
try:
model.to("cpu")
except Exception:
logger.exception("Failed to offload encoder %s", self.model_id)
return
self.__dict__["_offloaded"] = True
_free_cuda(self.device)
logger.info("Offloaded encoder %s from %s to cpu", self.model_id, self.device)

def _ensure_on_device(self) -> None:
"""If offloaded, move model weights back to ``self.device`` for the next encode."""
if not self.is_offloaded:
return
with self._device_lock():
if not self.is_offloaded:
return
model = getattr(self, "_model", None)
if model is not None:
model.to(self.device)
self.__dict__["_offloaded"] = False
logger.info("Reloaded encoder %s onto %s", self.model_id, self.device)


class OpenAIClipEncoder(ImageTextEncoder):
"""Original OpenAI CLIP via the ``clip`` package — preserves legacy behavior."""
@@ -126,17 +192,21 @@ def __init__(

@torch.no_grad()
def encode_images(self, images: list[Image.Image]) -> np.ndarray:
batch = torch.stack(
[self._preprocess(img.convert("RGB")) for img in images]
).to(self.device)
feats = self._model.encode_image(batch)
return _normalize(feats)
with self._device_lock():
self._ensure_on_device()
batch = torch.stack(
[self._preprocess(img.convert("RGB")) for img in images]
).to(self.device)
feats = self._model.encode_image(batch)
return _normalize(feats)

@torch.no_grad()
def encode_text(self, texts: list[str]) -> np.ndarray:
tokens = self._clip.tokenize(texts, truncate=True).to(self.device)
feats = self._model.encode_text(tokens)
return _normalize(feats)
with self._device_lock():
self._ensure_on_device()
tokens = self._clip.tokenize(texts, truncate=True).to(self.device)
feats = self._model.encode_text(tokens)
return _normalize(feats)

def close(self) -> None:
if hasattr(self, "_model"):
@@ -174,17 +244,21 @@ def __init__(

@torch.no_grad()
def encode_images(self, images: list[Image.Image]) -> np.ndarray:
batch = torch.stack(
[self._preprocess(img.convert("RGB")) for img in images]
).to(self.device)
feats = self._model.encode_image(batch)
return _normalize(feats)
with self._device_lock():
self._ensure_on_device()
batch = torch.stack(
[self._preprocess(img.convert("RGB")) for img in images]
).to(self.device)
feats = self._model.encode_image(batch)
return _normalize(feats)

@torch.no_grad()
def encode_text(self, texts: list[str]) -> np.ndarray:
tokens = self._tokenizer(texts).to(self.device)
feats = self._model.encode_text(tokens)
return _normalize(feats)
with self._device_lock():
self._ensure_on_device()
tokens = self._tokenizer(texts).to(self.device)
feats = self._model.encode_text(tokens)
return _normalize(feats)

def close(self) -> None:
for attr in ("_model", "_preprocess", "_tokenizer"):
@@ -236,37 +310,41 @@ def __init__(

@torch.no_grad()
def encode_images(self, images: list[Image.Image]) -> np.ndarray:
inputs = self._processor(
images=[img.convert("RGB") for img in images], return_tensors="pt"
).to(self.device)
feats = self._model.get_image_features(**inputs)
return _normalize(_unwrap_pooled(feats))
with self._device_lock():
self._ensure_on_device()
inputs = self._processor(
images=[img.convert("RGB") for img in images], return_tensors="pt"
).to(self.device)
feats = self._model.get_image_features(**inputs)
return _normalize(_unwrap_pooled(feats))

@torch.no_grad()
def encode_text(self, texts: list[str]) -> np.ndarray:
if self.use_ensembling:
# Prompt ensembling: encode each input wrapped in every template,
# L2-normalize each per-template embedding so longer phrasings
# can't dominate via larger magnitudes, then mean-pool across
# templates and re-normalize. Standard zero-shot CLIP/SigLIP
# practice.
n_templates = len(SIGLIP_PROMPT_TEMPLATES)
expanded = [
tpl.format(t) for t in texts for tpl in SIGLIP_PROMPT_TEMPLATES
]
with self._device_lock():
self._ensure_on_device()
if self.use_ensembling:
# Prompt ensembling: encode each input wrapped in every template,
# L2-normalize each per-template embedding so longer phrasings
# can't dominate via larger magnitudes, then mean-pool across
# templates and re-normalize. Standard zero-shot CLIP/SigLIP
# practice.
n_templates = len(SIGLIP_PROMPT_TEMPLATES)
expanded = [
tpl.format(t) for t in texts for tpl in SIGLIP_PROMPT_TEMPLATES
]
inputs = self._processor(
text=expanded, padding="max_length", truncation=True, return_tensors="pt"
).to(self.device)
feats = _unwrap_pooled(self._model.get_text_features(**inputs))
feats = feats / feats.norm(dim=-1, keepdim=True)
feats = feats.view(len(texts), n_templates, -1).mean(dim=1)
return _normalize(feats)

            inputs = self._processor(
                text=texts, padding="max_length", truncation=True, return_tensors="pt"
            ).to(self.device)
            feats = self._model.get_text_features(**inputs)
            return _normalize(_unwrap_pooled(feats))
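
The view/mean pooling above depends on the expansion order of ``expanded`` (all templates for the first text, then all templates for the second, and so on), which is exactly what the nested comprehension produces. A toy check of the reshape alone, independent of any model:

import torch

n_texts, n_templates, dim = 2, 3, 4
flat = torch.randn(n_texts * n_templates, dim)  # stand-in for per-template embeddings
flat = flat / flat.norm(dim=-1, keepdim=True)   # normalize before pooling
pooled = flat.view(n_texts, n_templates, dim).mean(dim=1)
assert pooled.shape == (n_texts, dim)           # one pooled embedding per input text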

def calibrate_similarity(self, cosines: np.ndarray) -> np.ndarray:
"""Apply SigLIP's learned sigmoid calibration.
@@ -371,6 +449,7 @@ def build_encoder(
# leave eviction to ``clear_encoder_cache`` since the working set is small —
# typically one encoder per album in active use.
_search_encoder_cache: dict[tuple[str, str | None], ImageTextEncoder] = {}
_search_encoder_last_access: dict[tuple[str, str | None], float] = {}
_search_encoder_lock = threading.Lock()


@@ -386,6 +465,10 @@ def get_cached_encoder(
calls return the same instance. The caller MUST NOT call ``encoder.close()``
on the result — eviction is the cache's responsibility via
:func:`clear_encoder_cache`.

Each call refreshes the entry's idle timestamp so the background watcher
started by :func:`start_idle_watcher` won't offload an encoder that's
actively being queried.
"""
resolved_spec = spec or DEFAULT_ENCODER_SPEC
key = (resolved_spec, cache_dir)
@@ -394,6 +477,7 @@
if encoder is None:
encoder = build_encoder(resolved_spec, cache_dir=cache_dir, device=device)
_search_encoder_cache[key] = encoder
_search_encoder_last_access[key] = time.monotonic()
return encoder
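
The caching contract in code form; a sketch that assumes the signature matches the calls shown here (``spec``, ``cache_dir``, ``device``) and that CUDA is available:

enc1 = get_cached_encoder(DEFAULT_ENCODER_SPEC, cache_dir=None, device="cuda:0")
enc2 = get_cached_encoder(DEFAULT_ENCODER_SPEC, cache_dir=None, device="cuda:0")
assert enc1 is enc2  # same instance; never call .close() on it yourself

# Both calls refreshed _search_encoder_last_access[(DEFAULT_ENCODER_SPEC, None)],
# pushing back the idle watcher's offload deadline for this encoder.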


Expand All @@ -403,3 +487,83 @@ def clear_encoder_cache() -> None:
for encoder in _search_encoder_cache.values():
encoder.close()
_search_encoder_cache.clear()
_search_encoder_last_access.clear()


# --- Idle watcher ----------------------------------------------------------
# A single daemon thread monitors ``_search_encoder_last_access`` and offloads
# any cached encoder that's gone untouched for ``timeout`` seconds. The watcher
# is opt-in (started from the FastAPI lifespan) so CLI tools and tests aren't
# burdened with a background thread they don't need.

_idle_watcher_thread: threading.Thread | None = None
_idle_watcher_stop = threading.Event()


def _idle_watcher_loop(timeout_seconds: float, poll_interval: float) -> None:
while not _idle_watcher_stop.is_set():
# Event.wait returns True if set during the wait — use it to exit
# promptly on shutdown rather than sleeping out the full interval.
if _idle_watcher_stop.wait(poll_interval):
return
now = time.monotonic()
# Snapshot under the lock; release it before calling .offload() so a
# slow GPU-CPU transfer can't block new search queries from grabbing
# the cache lock.
with _search_encoder_lock:
stale = [
(key, _search_encoder_cache[key])
for key, ts in _search_encoder_last_access.items()
if key in _search_encoder_cache and now - ts >= timeout_seconds
]
for _key, encoder in stale:
if encoder.is_offloaded:
continue
try:
encoder.offload()
except Exception:
logger.exception("Idle watcher failed to offload encoder")


def start_idle_watcher(timeout_seconds: float, poll_interval: float | None = None) -> None:
"""Start the background thread that offloads idle search encoders.

``timeout_seconds`` is the inactivity threshold. ``0`` disables the
watcher entirely. ``poll_interval`` defaults to ``min(timeout/2, 5.0)`` so
a stale encoder is detected within roughly one half-life of the timeout
without busy-waking on tight schedules.

Idempotent: a second call replaces the running watcher with one that
honours the new timeout.
"""
global _idle_watcher_thread
stop_idle_watcher()
if timeout_seconds <= 0:
return
interval = poll_interval if poll_interval is not None else min(timeout_seconds / 2, 5.0)
interval = max(interval, 0.1)
_idle_watcher_stop.clear()
thread = threading.Thread(
target=_idle_watcher_loop,
args=(timeout_seconds, interval),
name="encoder-idle-watcher",
daemon=True,
)
thread.start()
_idle_watcher_thread = thread
logger.info(
"Encoder idle watcher started (timeout=%.1fs, poll=%.2fs)",
timeout_seconds,
interval,
)


def stop_idle_watcher() -> None:
"""Signal the idle watcher to stop and wait for it to exit."""
global _idle_watcher_thread
thread = _idle_watcher_thread
if thread is None:
return
_idle_watcher_stop.set()
thread.join(timeout=5.0)
_idle_watcher_thread = None
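
As the comment above notes, the watcher is opt-in and meant to be started from the FastAPI lifespan. A minimal sketch of that wiring; the lifespan shape is standard FastAPI, but the ``get_config`` accessor is hypothetical and not part of this PR:

from contextlib import asynccontextmanager

from fastapi import FastAPI

from photomap.backend import encoders
from photomap.backend.config import get_config  # hypothetical accessor for the loaded Config

@asynccontextmanager
async def lifespan(app: FastAPI):
    config = get_config()
    # start_idle_watcher itself treats a timeout of 0 as "disabled".
    encoders.start_idle_watcher(config.encoder_idle_timeout_seconds)
    yield
    encoders.stop_idle_watcher()
    encoders.clear_encoder_cache()  # closes encoders and drops the last-access map

app = FastAPI(lifespan=lifespan)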