Source code for stressnet.models

"""Pre-trained model registry and downloader for StressNET.

Quick start
-----------
>>> import stressnet
>>> stressnet.list_models()          # show available model names
>>> model = stressnet.load_model()   # download + build

Lower-level access
------------------
>>> from stressnet.models import get_model_path, clear_cache, list_models
>>> path = get_model_path(list_models()[0])  # returns local Path, downloads if needed
>>> clear_cache(list_models()[0])            # remove cached weights for that name
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any

from ._downloader import clear_cache, download_model, get_model_path
from ._registry import list_models

if TYPE_CHECKING:
    import tensorflow as tf


# Registry name that ``load_model`` falls back to when called with ``name=None``.
DEFAULT_MODEL = "STRESSNET-PRETRAINED-A"


def load_model(
    name: str | None = None,
    *,
    force_download: bool = False,
    device: str | None = None,
    **net_kwargs: Any,
) -> tf.keras.Model:
    """Fetch pre-trained weights (if needed) and build the StressNET model.

    Parameters
    ----------
    name:
        Model name as listed by :func:`list_models`. ``None`` selects
        ``DEFAULT_MODEL``.
    force_download:
        Forwarded to :func:`download_model` as ``force``; set it to fetch
        the weights again even when a cached copy already exists.
    device:
        TensorFlow device string (e.g. ``'/GPU:0'``), passed through
        unchanged to the model builder. With ``None`` the builder picks the
        device itself (expected: GPU when available, otherwise CPU --
        confirm against ``get_model``).
    **net_kwargs:
        Extra keyword arguments passed through unchanged to the model
        builder (``get_model``); see :func:`stressnet.build_stressnet` for
        the options it is expected to accept.

    Returns
    -------
    tf.keras.Model
        The model returned by the builder for the downloaded weights,
        intended to be ready for inference.
    """
    # Imported at call time rather than at module import
    # (presumably to keep ``import stressnet.models`` lightweight -- TODO confirm).
    from ..net.stressnet import get_model
    from ._registry import get_entry

    model_name = DEFAULT_MODEL if name is None else name

    # Resolve the registry entry first, then obtain the local weights file.
    registry_entry = get_entry(model_name)
    cached_weights: Path = download_model(model_name, force=force_download)

    # The builder takes the weights location as a plain string.
    weights_file = str(cached_weights)
    edge_points = registry_entry.points_per_edge

    return get_model(
        weights_path=weights_file,
        points_per_edge=edge_points,
        device=device,
        **net_kwargs,
    )
# Public API of this module: the names exported by ``from stressnet.models import *``.
__all__: list[str] = [
    "list_models",
    "get_model_path",
    "download_model",
    "clear_cache",
    "load_model",
]