windamir123 commited on
Commit
9f64b79
·
verified ·
1 Parent(s): b2cafbe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -0
app.py CHANGED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ """
3
+ FastAPI app to inspect Hugging Face transformer model sizing:
4
+ - Total & trainable parameter counts
5
+ - Approximate memory usage in bytes / human-readable
6
+ - Saved model disk size
7
+ - Basic model config info
8
+
9
+ To run locally:
10
+ pip install fastapi "uvicorn[standard]" transformers torch
11
+ uvicorn app:app --reload
12
+
13
+ Endpoints:
14
+ / → simple HTML form
15
+ /inspect?model=bert-base-uncased → JSON sizing info
16
+ """
17
+
18
+ import os
19
+ import math
20
+ import tempfile
21
+ import shutil
22
+ from typing import Optional
23
+
24
+ from fastapi import FastAPI, Query, HTTPException
25
+ from fastapi.responses import HTMLResponse, JSONResponse
26
+ from pydantic import BaseModel
27
+ from transformers import AutoModel, AutoConfig, AutoTokenizer
28
+ from transformers.utils import logging as hf_logging
29
+ import torch
30
+
31
+ # quiet transformers logs
32
+ hf_logging.set_verbosity_error()
33
+
34
+ app = FastAPI(title="Hugging Face Transformer Sizing API")
35
+
36
+
37
+ # ---------- Helper Functions ----------
38
+
39
def humanize_bytes(n: int) -> str:
    """Convert a byte count into a human-readable string.

    Args:
        n: Non-negative number of bytes.

    Returns:
        A string such as ``"512 B"`` or ``"1.50 MB"``; values at or above
        1 TB are all expressed in TB.
    """
    if n < 1024:
        return f"{n} B"
    units = ["B", "KB", "MB", "GB", "TB"]
    # Clamp the exponent so sizes beyond the largest listed unit do not
    # index past the end of `units` (the original raised IndexError for
    # n >= 1024**5, i.e. >= 1 PB).
    i = min(int(math.floor(math.log(n, 1024))), len(units) - 1)
    return f"{n / (1024 ** i):.2f} {units[i]}"
46
+
47
+
48
def model_parameter_counts(model: torch.nn.Module):
    """Summarize parameter counts and approximate in-memory footprint.

    Args:
        model: Any ``torch.nn.Module``.

    Returns:
        Dict with total and trainable parameter counts, their approximate
        byte sizes (numel * element_size per tensor), and human-readable
        renderings of both byte totals.
    """
    total = trainable = 0
    bytes_total = bytes_trainable = 0
    for param in model.parameters():
        count = param.numel()
        size = count * param.element_size()
        total += count
        bytes_total += size
        if param.requires_grad:
            trainable += count
            bytes_trainable += size
    return {
        "total_params": total,
        "trainable_params": trainable,
        "approx_bytes": bytes_total,
        "trainable_bytes": bytes_trainable,
        "approx_bytes_human": humanize_bytes(bytes_total),
        "trainable_bytes_human": humanize_bytes(bytes_trainable),
    }
67
+
68
+
69
def folder_size_bytes(path: str) -> int:
    """Recursively sum the sizes of all regular files under *path*.

    Files that vanish or cannot be stat'ed mid-walk are silently skipped,
    so the result is a best-effort total.
    """
    size = 0
    for dirpath, _dirs, filenames in os.walk(path):
        for name in filenames:
            try:
                size += os.path.getsize(os.path.join(dirpath, name))
            except OSError:
                # File disappeared or is unreadable; keep going.
                continue
    return size
80
+
81
+
82
+ # ---------- Pydantic Model ----------
83
+
84
class InspectResult(BaseModel):
    """Response schema for the /inspect endpoint.

    NOTE(review): pydantic v2 treats the ``model_`` prefix as a protected
    namespace, so ``model_id``/``model_class`` may trigger warnings —
    confirm against the installed pydantic version.
    """

    model_id: str              # Hub model id echoed back to the caller
    model_class: str           # concrete class name instantiated by AutoModel
    config: dict               # subset of config fields (hidden_size, vocab_size, ...)
    sizing: dict               # output of model_parameter_counts()
    saved_size_bytes: Optional[int]   # on-disk size; None when save_to_disk=False
    saved_size_human: Optional[str]   # humanized rendering of saved_size_bytes
    notes: Optional[str]       # e.g. "Tokenizer not saved."; None when empty
92
+
93
+
94
+ # ---------- Routes ----------
95
+
96
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve a minimal HTML form that submits a model id to /inspect."""
    page = """
    <html>
    <head><title>Transformer Sizing</title></head>
    <body style="font-family:Arial; max-width:700px; margin:40px auto;">
    <h2>Hugging Face Transformer Sizing</h2>
    <form action="/inspect" method="get">
    <label>Enter Model ID (e.g. bert-base-uncased):</label><br>
    <input type="text" name="model" value="bert-base-uncased" style="width:70%; padding:6px;">
    <button type="submit" style="padding:6px;">Inspect</button>
    </form>
    <p>Examples: <code>bert-base-uncased</code>, <code>roberta-base</code>, <code>distilbert-base-uncased</code></p>
    <hr>
    <p>Results will appear in JSON format.</p>
    </body>
    </html>
    """
    return HTMLResponse(page)
116
+
117
+
118
@app.get("/inspect", response_model=InspectResult)
def inspect(
    model: str = Query(..., description="Model ID, e.g. bert-base-uncased"),
    save_to_disk: bool = Query(True, description="Save to disk temporarily to get size (default True)")
):
    """Inspect a Hugging Face model's parameters, memory, and disk size.

    Loads the config and model (CPU, float32), counts parameters, and —
    when ``save_to_disk`` is true — saves model (and, best-effort, the
    tokenizer) to a temp dir to measure the serialized footprint.

    Raises:
        HTTPException 400: missing model name or config could not load.
        HTTPException 500: model weights could not load.
    """
    if not model:
        raise HTTPException(status_code=400, detail="Missing model name.")

    # --- Load config ---
    try:
        config = AutoConfig.from_pretrained(model)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not load config: {e}")

    # --- Load model safely to CPU ---
    try:
        model_obj = AutoModel.from_pretrained(model, config=config, torch_dtype=torch.float32).to("cpu")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Could not load model: {e}")

    sizing = model_parameter_counts(model_obj)

    # --- Compute disk size (best-effort) ---
    saved_size_bytes = None
    saved_size_human = None
    notes = ""
    if save_to_disk:
        # mkdtemp runs OUTSIDE the try so `tmp` is always bound when the
        # finally-block cleanup executes (the original could NameError in
        # `finally` if mkdtemp itself raised inside the try).
        tmp = tempfile.mkdtemp(prefix="hf_")
        try:
            model_obj.save_pretrained(tmp)
            try:
                tok = AutoTokenizer.from_pretrained(model)
                tok.save_pretrained(tmp)
            except Exception:
                notes = "Tokenizer not saved."
            saved_size_bytes = folder_size_bytes(tmp)
            saved_size_human = humanize_bytes(saved_size_bytes)
        except Exception as e:
            # Parameter sizing above is still valid; a failed disk
            # measurement should degrade to a note, not a 500.
            notes = f"Could not measure saved size: {e}"
        finally:
            shutil.rmtree(tmp, ignore_errors=True)

    # --- Build config summary ---
    summary = {}
    for k in ("hidden_size", "num_hidden_layers", "vocab_size", "num_attention_heads", "intermediate_size"):
        if hasattr(config, k):
            summary[k] = getattr(config, k)

    # --- Result ---
    result = {
        "model_id": model,
        "model_class": model_obj.__class__.__name__,
        "config": summary,
        "sizing": sizing,
        "saved_size_bytes": saved_size_bytes,
        "saved_size_human": saved_size_human,
        "notes": notes or None
    }

    # Drop the (possibly large) model reference before returning; clear the
    # CUDA allocator cache if a GPU happens to be present.
    del model_obj
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return JSONResponse(result)
182
+
183
+
184
+ # ---------- Local Run ----------
185
# Local development entry point; in deployment (e.g. a HF Space) uvicorn is
# typically launched externally against "app:app" instead.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
188
+