# Source code for stressnet.models._registry

"""Registry of pre-trained StressNET model weights hosted on Google Drive.

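Each entry maps a model name to its download metadata.  To register a new
model, add an entry to ``_REGISTRY`` following the same structure, keyed by
an upper-case name (``get_entry`` upper-cases the requested name before the
lookup, so a key in any other casing would be unreachable).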

Fields
------
gdrive_id : str
    The Google Drive file ID extracted from the sharing URL
    ``https://drive.google.com/file/d/<gdrive_id>/view``.
filename : str
    Local filename used when saving to the cache directory.
sha256 : str | None
    Hex-encoded SHA-256 digest of the file.  Set to ``None`` to skip
    integrity checking (not recommended for published weights).
description : str
    Short human-readable description shown by ``list_models()``.
points_per_edge : int
    Value of ``points_per_edge`` expected by ``build_model`` for this
    checkpoint.
"""

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ModelEntry:
    """Immutable download metadata for one pre-trained checkpoint.

    Attributes
    ----------
    gdrive_id : str
        Google Drive file ID of the weights file.
    filename : str
        Local filename used when saving to the cache directory.
    description : str
        Short human-readable description of the model.
    points_per_edge : int
        Value of ``points_per_edge`` expected by ``build_model`` for this
        checkpoint.
    sha256 : str | None
        Hex-encoded SHA-256 digest of the file, or ``None`` to skip
        integrity checking (not recommended for published weights).
    """

    gdrive_id: str
    filename: str
    description: str
    points_per_edge: int
    # ``None`` is the documented opt-out from integrity checking, so the
    # annotation has to admit it; the field stays required on purpose so
    # that skipping the check is always an explicit choice.
    sha256: str | None


# Mapping from model name to its download metadata.  Keys are written in
# upper case because ``get_entry`` upper-cases the requested name before the
# lookup.  Iteration order (used when listing models) is the order below.
_REGISTRY: dict[str, ModelEntry] = {
    # --- Pretrained replicas A-E ---
    'STRESSNET-PRETRAINED-A': ModelEntry(
        gdrive_id='1Uk5FGMctLNePXhHX_dfWZq-jOUW5OMIJ',
        filename='StressNET_weights_pretrained_A.h5',
        sha256='fd133b6e7fa7569db0ac9e79fc7d10ff2a91f7a915d694499cb4e686460a2502',
        description='StressNET pretrained model (replica A)',
        points_per_edge=9,
    ),
    'STRESSNET-PRETRAINED-B': ModelEntry(
        gdrive_id='1n_ZMn4QFInO6DS5MEeuOEo4dk9pyU2iI',
        filename='StressNET_weights_pretrained_B.h5',
        sha256='abd78a6fc9401d057fd8319a4c0bd57775cf7e401f89706d4ca993cae379cc2c',
        description='StressNET pretrained model (replica B)',
        points_per_edge=9,
    ),
    'STRESSNET-PRETRAINED-C': ModelEntry(
        gdrive_id='1SvNh4FJGBDyOIA1JYrFXzHfgVDFRqmv-',
        filename='StressNET_weights_pretrained_C.h5',
        sha256='20a8202b4722c2eca193374517afa873a06ab2658897a72e5e71c79f39ffed47',
        description='StressNET pretrained model (replica C)',
        points_per_edge=9,
    ),
    'STRESSNET-PRETRAINED-D': ModelEntry(
        gdrive_id='16Fo4uctlJhQoh12fapp5hifgvcx9cJEK',
        filename='StressNET_weights_pretrained_D.h5',
        sha256='399c6abad1c25ca55dea51397e99de3bc95a1624fce831f43f74ae973b19aca3',
        description='StressNET pretrained model (replica D)',
        points_per_edge=9,
    ),
    'STRESSNET-PRETRAINED-E': ModelEntry(
        gdrive_id='1JDzz4K7A7TqwZlsLObEMcvsWkTUfMpn5',
        filename='StressNET_weights_pretrained_E.h5',
        sha256='0018a81a55e88a47c0537526543d54526ea6bff7734b64237329e571e7641d3a',
        description='StressNET pretrained model (replica E)',
        points_per_edge=9,
    ),
    # --- Fine-tuned checkpoints ---
    'STRESSNET-FINETUNED-MYOSIN-88-A': ModelEntry(
        gdrive_id='1PEiscpCaeLdVjR6qE2IsVVEzFFUf0PkI',
        filename='StressNET_weights_finetuned_myosin_88_A.h5',
        sha256='b44255964a80d9e211e7703dc8ff4a53f99ec084ddb90df0cb97cb81a8779bd8',
        description='StressNET model (replica A) finetuned on 88 myosin samples',
        points_per_edge=9,
    ),
}


def get_entry(name: str) -> ModelEntry:
    """Look up the registry entry stored under *name*.

    *name* is upper-cased before the lookup, so callers may pass the model
    name in any casing.

    Raises
    ------
    KeyError
        If nothing is registered under the upper-cased name.  The message
        lists every registered model name.
    """
    key = name.upper()
    entry = _REGISTRY.get(key)
    if entry is not None:
        return entry
    # Fall through: build the list of valid names for the error message.
    known = ', '.join(_REGISTRY) or '(none registered yet)'
    raise KeyError(f'Unknown model {key!r}. Available models: {known}')


def list_models() -> list[str]:
    """Return the names of all registered models.

    The names come back as a new list in registry order, so callers may
    modify the result without affecting the registry.
    """
    # Iterating a dict yields its keys, so no explicit ``.keys()`` is needed.
    return list(_REGISTRY)