Source code for stressnet.models._downloader

"""Download and cache pre-trained StressNET weights from Google Drive."""

from __future__ import annotations

import hashlib
import logging
from pathlib import Path

import gdown
from platformdirs import user_cache_dir

from ._registry import ModelEntry, get_entry

logger = logging.getLogger(__name__)

_CACHE_DIR = Path(user_cache_dir('stressnet', appauthor=False)) / 'models'


def _cache_dir() -> Path:
    """Return (and create if needed) the local model cache directory."""
    _CACHE_DIR.mkdir(parents=True, exist_ok=True)
    return _CACHE_DIR


def _sha256(path: Path, chunk: int = 1 << 20) -> str:
    h = hashlib.sha256()
    with path.open('rb') as fh:
        while data := fh.read(chunk):
            h.update(data)
    return h.hexdigest()


def _verify(path: Path, expected: str) -> None:
    actual = _sha256(path)
    if actual != expected:
        path.unlink(missing_ok=True)
        raise RuntimeError(
            f'SHA-256 mismatch for {path.name}.\n'
            f'  expected : {expected}\n'
            f'  got      : {actual}\n'
            'The file has been removed. Re-run to download again.'
        )


def _gdrive_url(file_id: str) -> str:
    return f'https://drive.google.com/uc?id={file_id}'


def download_model(name: str, *, force: bool = False) -> Path:
    """Download model weights by *name* and return the local cache path.

    The file is only downloaded when it is absent or *force* is ``True``.
    After downloading, the SHA-256 checksum is verified when one is present
    in the registry entry.

    Parameters
    ----------
    name:
        Model name as registered in ``_registry._REGISTRY``.
    force:
        Re-download even if the file already exists in the cache.

    Returns
    -------
    Path
        Absolute path to the cached weights file.
    """
    entry: ModelEntry = get_entry(name)
    dest = _cache_dir() / entry.filename

    if dest.exists() and not force:
        logger.debug('Model %r already cached at %s', name, dest)
        if entry.sha256:
            _verify(dest, entry.sha256)
        return dest

    logger.info('Downloading model %r from Google Drive…', name)
    url = _gdrive_url(entry.gdrive_id)
    gdown.download(url, str(dest), quiet=False)

    if not dest.exists():
        raise RuntimeError(
            f'Download of model {name!r} failed: file not found at {dest}.'
        )

    if entry.sha256:
        logger.info('Verifying checksum for %r…', name)
        _verify(dest, entry.sha256)

    logger.info('Model %r saved to %s', name, dest)
    return dest


[docs] def get_model_path(name: str) -> Path: """Return the cached path for *name*, downloading it if necessary.""" return download_model(name)
[docs] def clear_cache(name: str | None = None) -> None: """Delete cached weights. Parameters ---------- name: Model name to remove. Pass ``None`` to clear the entire cache. """ if name is not None: entry = get_entry(name) target = _cache_dir() / entry.filename if target.exists(): target.unlink() logger.info('Removed cached model %r (%s)', name, target) else: import shutil shutil.rmtree(_CACHE_DIR, ignore_errors=True) logger.info('Cleared entire StressNET model cache at %s', _CACHE_DIR)