File size: 1,682 Bytes
9913c1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
from pathlib import Path
from typing import Optional


def _resolve_default_cache_path() -> Path:
    """Return the cache directory implied by the current environment.

    The variables ``MEDRAX_MODEL_DIR``, ``HF_HOME`` and ``TRANSFORMERS_CACHE``
    are consulted in that order; the first one holding a non-empty value wins
    and is returned with a leading ``~`` expanded. When none of them is set
    (or all are empty) the result is ``~/.cache/medrax_hf`` under the user's
    home directory. Nothing is created on disk by this function.
    """
    env_names = ("MEDRAX_MODEL_DIR", "HF_HOME", "TRANSFORMERS_CACHE")
    configured = next(
        (value for value in map(os.environ.get, env_names) if value),
        None,
    )
    if configured is None:
        # No usable override in the environment: use the per-user default.
        return Path.home() / ".cache" / "medrax_hf"
    return Path(configured).expanduser()


def _ensure_subdirs(base: Path) -> None:
    """Make sure ``base/datasets`` and ``base/hub`` exist on disk.

    Missing parent directories (including ``base`` itself) are created as
    needed, and directories that already exist are left untouched.
    """
    wanted = [base / "datasets", base / "hub"]
    for directory in wanted:
        directory.mkdir(parents=True, exist_ok=True)


# Resolve the default cache location once at import time and make sure it,
# along with its "datasets" and "hub" subdirectories, exists on disk.
_DEFAULT_CACHE_PATH = _resolve_default_cache_path()
_DEFAULT_CACHE_PATH.mkdir(parents=True, exist_ok=True)
_ensure_subdirs(_DEFAULT_CACHE_PATH)

# String form of the default cache path; this is what resolve_cache_dir()
# returns when no explicit directory is given.
DEFAULT_CACHE_DIR = str(_DEFAULT_CACHE_PATH)


def _ensure_hf_env_vars(base: Path) -> None:
    """Fill in cache-related environment variables that are not yet set.

    ``MEDRAX_MODEL_DIR``, ``HF_HOME`` and ``TRANSFORMERS_CACHE`` default to
    ``base``; ``HF_DATASETS_CACHE`` and ``HF_HUB_CACHE`` default to the
    ``datasets`` and ``hub`` subdirectories of ``base``. A variable that
    already exists in the environment keeps its current value.
    """
    root = str(base)
    fallbacks = {
        "MEDRAX_MODEL_DIR": root,
        "HF_HOME": root,
        "TRANSFORMERS_CACHE": root,
        "HF_DATASETS_CACHE": str(base / "datasets"),
        "HF_HUB_CACHE": str(base / "hub"),
    }
    for name, value in fallbacks.items():
        # setdefault only writes when the variable is absent.
        os.environ.setdefault(name, value)


# Publish the default cache location through the environment at import time;
# variables that are already set are left as they are.
_ensure_hf_env_vars(_DEFAULT_CACHE_PATH)


def resolve_cache_dir(cache_dir: Optional[str] = None) -> str:
    """Return the cache directory to use, as a string path.

    Args:
        cache_dir: Optional directory chosen by the caller. A leading ``~``
            is expanded.

    Returns:
        ``DEFAULT_CACHE_DIR`` when ``cache_dir`` is ``None`` or empty.
        Otherwise the expanded ``cache_dir``, which is created on disk
        together with its ``datasets`` and ``hub`` subdirectories before
        being returned.
    """
    if not cache_dir:
        # Nothing requested: fall back to the module-level default.
        return DEFAULT_CACHE_DIR

    target = Path(cache_dir).expanduser()
    target.mkdir(parents=True, exist_ok=True)
    _ensure_subdirs(target)
    return str(target)