Commit
Β·
7d10354
1
Parent(s):
603833c
Fix duplicate key error with session state
Browse files- indexer.py +6 -212
- main.py +65 -121
indexer.py
CHANGED
|
@@ -1,75 +1,24 @@
|
|
| 1 |
-
# app/indexer.py
|
| 2 |
-
# Day 6: Vector store & embeddings
|
| 3 |
-
# Usage examples:
|
| 4 |
-
# python app/indexer.py --input_dir ./data/outputs --db_type chroma --persist_dir ./data/vector_store
|
| 5 |
-
# python app/indexer.py --input_dir ./data/outputs --db_type faiss --persist_dir ./data/vector_store_faiss
|
| 6 |
-
|
| 7 |
import os
|
| 8 |
import json
|
| 9 |
import argparse
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import List, Dict, Tuple
|
| 12 |
-
from tqdm import tqdm
|
| 13 |
-
|
| 14 |
-
# Embeddings
|
| 15 |
from sentence_transformers import SentenceTransformer
|
| 16 |
-
|
| 17 |
-
# Vector stores
|
| 18 |
-
# Chroma
|
| 19 |
import chromadb
|
| 20 |
from chromadb.config import Settings as ChromaSettings
|
| 21 |
-
|
| 22 |
-
# FAISS
|
| 23 |
import faiss
|
| 24 |
import pickle
|
| 25 |
|
| 26 |
DEFAULT_CHUNK_TOKENS = 200
|
| 27 |
DEFAULT_OVERLAP_TOKENS = 50
|
| 28 |
|
| 29 |
-
def read_note_files(input_dir: str) -> List[Dict]:
|
| 30 |
-
"""
|
| 31 |
-
Reads de-identified notes from .txt or .json in input_dir.
|
| 32 |
-
Expects .json to have a 'text' field containing de-identified content.
|
| 33 |
-
Returns list of dicts: {id, text, section?}
|
| 34 |
-
"""
|
| 35 |
-
items = []
|
| 36 |
-
p = Path(input_dir)
|
| 37 |
-
if not p.exists():
|
| 38 |
-
raise FileNotFoundError(f"Input dir not found: {input_dir}")
|
| 39 |
-
|
| 40 |
-
for fp in p.glob("**/*"):
|
| 41 |
-
if fp.is_dir():
|
| 42 |
-
continue
|
| 43 |
-
if fp.suffix.lower() == ".txt":
|
| 44 |
-
text = fp.read_text(encoding="utf-8", errors="ignore").strip()
|
| 45 |
-
if text:
|
| 46 |
-
items.append({"id": fp.stem, "text": text, "section": None})
|
| 47 |
-
elif fp.suffix.lower() == ".json":
|
| 48 |
-
try:
|
| 49 |
-
obj = json.loads(fp.read_text(encoding="utf-8", errors="ignore"))
|
| 50 |
-
text = obj.get("text") or obj.get("deidentified_text") or ""
|
| 51 |
-
section = obj.get("section")
|
| 52 |
-
if text:
|
| 53 |
-
items.append({"id": fp.stem, "text": text.strip(), "section": section})
|
| 54 |
-
except Exception:
|
| 55 |
-
# Skip malformed
|
| 56 |
-
continue
|
| 57 |
-
return items
|
| 58 |
-
|
| 59 |
def approx_tokenize(text: str) -> List[str]:
|
| 60 |
-
"""
|
| 61 |
-
Approximate tokenization by splitting on whitespace.
|
| 62 |
-
For MVP this is fine; can replace with tiktoken later.
|
| 63 |
-
"""
|
| 64 |
return text.split()
|
| 65 |
|
| 66 |
def detokenize(tokens: List[str]) -> str:
|
| 67 |
return " ".join(tokens)
|
| 68 |
|
| 69 |
def chunk_text(text: str, chunk_tokens: int, overlap_tokens: int) -> List[str]:
|
| 70 |
-
"""
|
| 71 |
-
Simple sliding window chunking.
|
| 72 |
-
"""
|
| 73 |
tokens = approx_tokenize(text)
|
| 74 |
chunks = []
|
| 75 |
i = 0
|
|
@@ -86,38 +35,6 @@ def chunk_text(text: str, chunk_tokens: int, overlap_tokens: int) -> List[str]:
|
|
| 86 |
i = 0
|
| 87 |
return chunks
|
| 88 |
|
| 89 |
-
def embed_texts(model: SentenceTransformer, texts: List[str]):
|
| 90 |
-
return model.encode(texts, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)
|
| 91 |
-
|
| 92 |
-
def build_chroma(persist_dir: str, collection_name: str = "notes"):
|
| 93 |
-
client = chromadb.PersistentClient(
|
| 94 |
-
path=persist_dir,
|
| 95 |
-
settings=ChromaSettings(allow_reset=True)
|
| 96 |
-
)
|
| 97 |
-
if collection_name in [c.name for c in client.list_collections()]:
|
| 98 |
-
coll = client.get_collection(collection_name)
|
| 99 |
-
else:
|
| 100 |
-
coll = client.create_collection(collection_name)
|
| 101 |
-
return client, coll
|
| 102 |
-
|
| 103 |
-
def save_faiss(index, vectors_meta: List[Dict], persist_dir: str):
|
| 104 |
-
os.makedirs(persist_dir, exist_ok=True)
|
| 105 |
-
faiss_path = os.path.join(persist_dir, "index.faiss")
|
| 106 |
-
meta_path = os.path.join(persist_dir, "meta.pkl")
|
| 107 |
-
faiss.write_index(index, faiss_path)
|
| 108 |
-
with open(meta_path, "wb") as f:
|
| 109 |
-
pickle.dump(vectors_meta, f)
|
| 110 |
-
|
| 111 |
-
def load_faiss(persist_dir: str):
|
| 112 |
-
faiss_path = os.path.join(persist_dir, "index.faiss")
|
| 113 |
-
meta_path = os.path.join(persist_dir, "meta.pkl")
|
| 114 |
-
if os.path.exists(faiss_path) and os.path.exists(meta_path):
|
| 115 |
-
index = faiss.read_index(faiss_path)
|
| 116 |
-
with open(meta_path, "rb") as f:
|
| 117 |
-
meta = pickle.load(f)
|
| 118 |
-
return index, meta
|
| 119 |
-
return None, []
|
| 120 |
-
|
| 121 |
def index_note(
|
| 122 |
text: str,
|
| 123 |
note_id: str = "temp_note",
|
|
@@ -126,35 +43,6 @@ def index_note(
|
|
| 126 |
model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
|
| 127 |
collection: str = "notes"
|
| 128 |
) -> str:
|
| 129 |
-
from sentence_transformers import SentenceTransformer
|
| 130 |
-
import os
|
| 131 |
-
|
| 132 |
-
DEFAULT_CHUNK_TOKENS = 200
|
| 133 |
-
DEFAULT_OVERLAP_TOKENS = 50
|
| 134 |
-
|
| 135 |
-
def approx_tokenize(text: str):
|
| 136 |
-
return text.split()
|
| 137 |
-
|
| 138 |
-
def detokenize(tokens):
|
| 139 |
-
return " ".join(tokens)
|
| 140 |
-
|
| 141 |
-
def chunk_text(text, chunk_tokens, overlap_tokens):
|
| 142 |
-
tokens = approx_tokenize(text)
|
| 143 |
-
chunks = []
|
| 144 |
-
i = 0
|
| 145 |
-
n = len(tokens)
|
| 146 |
-
while i < n:
|
| 147 |
-
j = min(i + chunk_tokens, n)
|
| 148 |
-
chunk = detokenize(tokens[i:j])
|
| 149 |
-
if chunk.strip():
|
| 150 |
-
chunks.append(chunk)
|
| 151 |
-
if j == n:
|
| 152 |
-
break
|
| 153 |
-
i = j - overlap_tokens
|
| 154 |
-
if i < 0:
|
| 155 |
-
i = 0
|
| 156 |
-
return chunks
|
| 157 |
-
|
| 158 |
os.makedirs(persist_dir, exist_ok=True)
|
| 159 |
model = SentenceTransformer(model_name)
|
| 160 |
chunks = chunk_text(text, DEFAULT_CHUNK_TOKENS, DEFAULT_OVERLAP_TOKENS)
|
|
@@ -163,16 +51,15 @@ def index_note(
|
|
| 163 |
vectors = model.encode(chunks, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)
|
| 164 |
|
| 165 |
if db_type == "chroma":
|
| 166 |
-
|
| 167 |
-
import chromadb
|
| 168 |
client = chromadb.PersistentClient(
|
| 169 |
path=persist_dir,
|
| 170 |
-
settings=ChromaSettings(
|
|
|
|
|
|
|
|
|
|
| 171 |
)
|
| 172 |
-
|
| 173 |
-
coll = client.get_collection(collection)
|
| 174 |
-
else:
|
| 175 |
-
coll = client.create_collection(collection)
|
| 176 |
coll.upsert(
|
| 177 |
ids=chunk_ids,
|
| 178 |
embeddings=vectors.tolist(),
|
|
@@ -180,8 +67,6 @@ def index_note(
|
|
| 180 |
metadatas=metadatas,
|
| 181 |
)
|
| 182 |
elif db_type == "faiss":
|
| 183 |
-
import faiss
|
| 184 |
-
import pickle
|
| 185 |
d = vectors.shape[1]
|
| 186 |
index = faiss.IndexFlatIP(d)
|
| 187 |
index.add(vectors)
|
|
@@ -196,94 +81,3 @@ def index_note(
|
|
| 196 |
pickle.dump(vectors_meta, f)
|
| 197 |
|
| 198 |
return note_id
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
def main():
|
| 202 |
-
parser = argparse.ArgumentParser(description="Day 6: Build local vector DB from de-identified notes.")
|
| 203 |
-
parser.add_argument("--input_dir", required=True, help="Directory with de-identified notes (.txt or .json).")
|
| 204 |
-
parser.add_argument("--persist_dir", default="./data/vector_store", help="Where to persist the DB.")
|
| 205 |
-
parser.add_argument("--db_type", choices=["chroma", "faiss"], default="chroma", help="Vector DB type.")
|
| 206 |
-
parser.add_argument("--model_name", default="sentence-transformers/all-MiniLM-L6-v2", help="Embedding model.")
|
| 207 |
-
parser.add_argument("--chunk_tokens", type=int, default=DEFAULT_CHUNK_TOKENS, help="Approx tokens per chunk.")
|
| 208 |
-
parser.add_argument("--overlap_tokens", type=int, default=DEFAULT_OVERLAP_TOKENS, help="Token overlap.")
|
| 209 |
-
parser.add_argument("--collection", default="notes", help="Collection name (Chroma).")
|
| 210 |
-
args = parser.parse_args()
|
| 211 |
-
|
| 212 |
-
notes = read_note_files(args.input_dir)
|
| 213 |
-
if not notes:
|
| 214 |
-
print(f"No de-identified notes found in {args.input_dir}. Ensure Day 5 outputs exist.")
|
| 215 |
-
return
|
| 216 |
-
|
| 217 |
-
print(f"Loaded {len(notes)} de-identified notes from {args.input_dir}")
|
| 218 |
-
os.makedirs(args.persist_dir, exist_ok=True)
|
| 219 |
-
|
| 220 |
-
print(f"Loading embedding model: {args.model_name}")
|
| 221 |
-
model = SentenceTransformer(args.model_name)
|
| 222 |
-
|
| 223 |
-
all_chunk_texts = []
|
| 224 |
-
all_chunk_ids = []
|
| 225 |
-
all_metadata = []
|
| 226 |
-
|
| 227 |
-
print("Chunking notes...")
|
| 228 |
-
for note in tqdm(notes):
|
| 229 |
-
chunks = chunk_text(note["text"], args.chunk_tokens, args.overlap_tokens)
|
| 230 |
-
for idx, ch in enumerate(chunks):
|
| 231 |
-
cid = f"{note['id']}::chunk_{idx}"
|
| 232 |
-
all_chunk_texts.append(ch)
|
| 233 |
-
all_chunk_ids.append(cid)
|
| 234 |
-
all_metadata.append({
|
| 235 |
-
"note_id": note["id"],
|
| 236 |
-
"chunk_index": idx,
|
| 237 |
-
"section": note.get("section")
|
| 238 |
-
})
|
| 239 |
-
|
| 240 |
-
print(f"Total chunks: {len(all_chunk_texts)}")
|
| 241 |
-
|
| 242 |
-
print("Embedding chunks...")
|
| 243 |
-
vectors = embed_texts(model, all_chunk_texts)
|
| 244 |
-
|
| 245 |
-
if args.db_type == "chroma":
|
| 246 |
-
print("Building Chroma persistent collection...")
|
| 247 |
-
client, coll = build_chroma(args.persist_dir, args.collection)
|
| 248 |
-
|
| 249 |
-
# Upsert in manageable batches
|
| 250 |
-
batch = 512
|
| 251 |
-
for i in tqdm(range(0, len(all_chunk_texts), batch)):
|
| 252 |
-
j = min(i + batch, len(all_chunk_texts))
|
| 253 |
-
coll.upsert(
|
| 254 |
-
ids=all_chunk_ids[i:j],
|
| 255 |
-
embeddings=vectors[i:j].tolist(),
|
| 256 |
-
documents=all_chunk_texts[i:j],
|
| 257 |
-
metadatas=all_metadata[i:j],
|
| 258 |
-
)
|
| 259 |
-
print(f"Chroma collection '{args.collection}' persisted at {args.persist_dir}")
|
| 260 |
-
|
| 261 |
-
elif args.db_type == "faiss":
|
| 262 |
-
print("Building FAISS index...")
|
| 263 |
-
d = vectors.shape[1]
|
| 264 |
-
index = faiss.IndexFlatIP(d) # normalized vectors β use inner product as cosine
|
| 265 |
-
# Try to load existing
|
| 266 |
-
existing_index, existing_meta = load_faiss(args.persist_dir)
|
| 267 |
-
if existing_index is not None:
|
| 268 |
-
print("Appending to existing FAISS index...")
|
| 269 |
-
index = existing_index
|
| 270 |
-
vectors_meta = existing_meta
|
| 271 |
-
else:
|
| 272 |
-
vectors_meta = []
|
| 273 |
-
index.add(vectors)
|
| 274 |
-
vectors_meta.extend([
|
| 275 |
-
{
|
| 276 |
-
"id": all_chunk_ids[k],
|
| 277 |
-
"text": all_chunk_texts[k],
|
| 278 |
-
"meta": all_metadata[k]
|
| 279 |
-
} for k in range(len(all_chunk_texts))
|
| 280 |
-
])
|
| 281 |
-
save_faiss(index, vectors_meta, args.persist_dir)
|
| 282 |
-
print(f"FAISS index persisted at {args.persist_dir}")
|
| 283 |
-
|
| 284 |
-
print("Done.")
|
| 285 |
-
|
| 286 |
-
if __name__ == "__main__":
|
| 287 |
-
main()
|
| 288 |
-
##result = pipeline.run_on_text(text=note_text, note_id="temp_note")
|
| 289 |
-
##deid_text = result["masked_text"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import json
|
| 3 |
import argparse
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import List, Dict, Tuple
|
|
|
|
|
|
|
|
|
|
| 6 |
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
|
|
|
|
| 7 |
import chromadb
|
| 8 |
from chromadb.config import Settings as ChromaSettings
|
|
|
|
|
|
|
| 9 |
import faiss
|
| 10 |
import pickle
|
| 11 |
|
| 12 |
DEFAULT_CHUNK_TOKENS = 200
|
| 13 |
DEFAULT_OVERLAP_TOKENS = 50
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def approx_tokenize(text: str) -> List[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
return text.split()
|
| 17 |
|
| 18 |
def detokenize(tokens: List[str]) -> str:
|
| 19 |
return " ".join(tokens)
|
| 20 |
|
| 21 |
def chunk_text(text: str, chunk_tokens: int, overlap_tokens: int) -> List[str]:
|
|
|
|
|
|
|
|
|
|
| 22 |
tokens = approx_tokenize(text)
|
| 23 |
chunks = []
|
| 24 |
i = 0
|
|
|
|
| 35 |
i = 0
|
| 36 |
return chunks
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
def index_note(
|
| 39 |
text: str,
|
| 40 |
note_id: str = "temp_note",
|
|
|
|
| 43 |
model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
|
| 44 |
collection: str = "notes"
|
| 45 |
) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
os.makedirs(persist_dir, exist_ok=True)
|
| 47 |
model = SentenceTransformer(model_name)
|
| 48 |
chunks = chunk_text(text, DEFAULT_CHUNK_TOKENS, DEFAULT_OVERLAP_TOKENS)
|
|
|
|
| 51 |
vectors = model.encode(chunks, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)
|
| 52 |
|
| 53 |
if db_type == "chroma":
|
| 54 |
+
# FIX: Use get_or_create with consistent settings
|
|
|
|
| 55 |
client = chromadb.PersistentClient(
|
| 56 |
path=persist_dir,
|
| 57 |
+
settings=ChromaSettings(
|
| 58 |
+
allow_reset=False, # Changed to False for consistency
|
| 59 |
+
anonymized_telemetry=False
|
| 60 |
+
)
|
| 61 |
)
|
| 62 |
+
coll = client.get_or_create_collection(collection)
|
|
|
|
|
|
|
|
|
|
| 63 |
coll.upsert(
|
| 64 |
ids=chunk_ids,
|
| 65 |
embeddings=vectors.tolist(),
|
|
|
|
| 67 |
metadatas=metadatas,
|
| 68 |
)
|
| 69 |
elif db_type == "faiss":
|
|
|
|
|
|
|
| 70 |
d = vectors.shape[1]
|
| 71 |
index = faiss.IndexFlatIP(d)
|
| 72 |
index.add(vectors)
|
|
|
|
| 81 |
pickle.dump(vectors_meta, f)
|
| 82 |
|
| 83 |
return note_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main.py
CHANGED
|
@@ -7,7 +7,7 @@ from pathlib import Path
|
|
| 7 |
import subprocess
|
| 8 |
import torch
|
| 9 |
|
| 10 |
-
# Fix torch.classes path error
|
| 11 |
torch.classes.__path__ = []
|
| 12 |
|
| 13 |
# HF Spaces env vars
|
|
@@ -18,15 +18,13 @@ os.environ["SPACY_MODEL"] = "en_core_web_lg"
|
|
| 18 |
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 19 |
warnings.filterwarnings("ignore", category=UserWarning)
|
| 20 |
|
| 21 |
-
# Dynamic install helpers
|
| 22 |
def install_package(package):
|
| 23 |
try:
|
| 24 |
subprocess.check_call([sys.executable, "-m", "pip", "install", package, "--quiet"])
|
| 25 |
st.sidebar.success(f"Installed {package}")
|
| 26 |
except Exception:
|
| 27 |
-
st.sidebar.error(f"Failed to install {package}
|
| 28 |
|
| 29 |
-
# Check transformers
|
| 30 |
try:
|
| 31 |
import transformers
|
| 32 |
TRANSFORMERS_OK = True
|
|
@@ -51,7 +49,7 @@ method = "multistage"
|
|
| 51 |
Path(secure_dir).mkdir(exist_ok=True)
|
| 52 |
Path(persist_dir).mkdir(exist_ok=True)
|
| 53 |
|
| 54 |
-
# Sidebar
|
| 55 |
with st.sidebar:
|
| 56 |
st.header("Status")
|
| 57 |
HAS_MODULES = True
|
|
@@ -79,24 +77,13 @@ with st.sidebar:
|
|
| 79 |
HAS_MODULES = False
|
| 80 |
st.error(f"summarizer: {e}")
|
| 81 |
|
| 82 |
-
if not TRANSFORMERS_OK:
|
| 83 |
-
st.error("Transformers failedβrebuild Space.")
|
| 84 |
-
|
| 85 |
st.info(modular_status)
|
| 86 |
st.caption(f"DB: {persist_dir} | Secure: {secure_dir}")
|
| 87 |
-
|
| 88 |
-
if st.button("π§ Install Missing"):
|
| 89 |
-
install_package("presidio-analyzer")
|
| 90 |
-
install_package("spacy")
|
| 91 |
-
subprocess.check_call(["python", "-m", "spacy", "download", "en_core_web_lg"], stdout=subprocess.DEVNULL)
|
| 92 |
-
st.rerun()
|
| 93 |
|
| 94 |
# Fallback functions
|
| 95 |
def fallback_deid(text: str) -> str:
|
| 96 |
patterns = [
|
| 97 |
(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]'),
|
| 98 |
-
(r'\b[A-Z][a-z]{2,}\b(?=\s+(her|his|the|by)\b)', '[LAST_NAME]'),
|
| 99 |
-
(r'\b[A-Z][a-z]{2,}\b(?! (BP|HR|RR|mg|mmHg|bpm|CT|MRI|TIA|NIH|EF|RA|HS|BID|QID|PCP))', '[NAME]'),
|
| 100 |
(r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b', '[DATE]'),
|
| 101 |
(r'\b\d{3}[-.\s]?\d{3}[-.\s]?\d{4}\b', '[PHONE]'),
|
| 102 |
(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', '[EMAIL]'),
|
|
@@ -104,7 +91,7 @@ def fallback_deid(text: str) -> str:
|
|
| 104 |
]
|
| 105 |
for pat, rep in patterns:
|
| 106 |
text = re.sub(pat, rep, text, flags=re.I)
|
| 107 |
-
return
|
| 108 |
|
| 109 |
def fallback_retrieve(deid_text: str, top_k: int = 5) -> list:
|
| 110 |
if len(deid_text) > 3000:
|
|
@@ -114,71 +101,33 @@ def fallback_retrieve(deid_text: str, top_k: int = 5) -> list:
|
|
| 114 |
|
| 115 |
def fallback_summarize(chunks: list, tokenizer, model) -> str:
|
| 116 |
context = "\n\n".join(chunks)
|
| 117 |
-
prompt = f"summarize:
|
| 118 |
-
inputs = tokenizer(prompt, return_tensors="pt", max_length=
|
| 119 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 120 |
model.to(device)
|
| 121 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 122 |
with torch.no_grad():
|
| 123 |
outputs = model.generate(
|
| 124 |
inputs['input_ids'],
|
| 125 |
-
max_new_tokens=
|
| 126 |
-
min_length=
|
| 127 |
num_beams=2,
|
| 128 |
-
length_penalty=1.0,
|
| 129 |
early_stopping=True,
|
| 130 |
-
|
| 131 |
-
repetition_penalty=1.1,
|
| 132 |
-
pad_token_id=tokenizer.pad_token_id,
|
| 133 |
-
use_cache=True
|
| 134 |
)
|
| 135 |
-
|
| 136 |
-
sections = {
|
| 137 |
-
"Chief Complaint:": "Not documented",
|
| 138 |
-
"HPI:": "Not documented",
|
| 139 |
-
"Assessment:": "Not documented",
|
| 140 |
-
"Vitals:": "Not documented",
|
| 141 |
-
"Medication:": "Not documented",
|
| 142 |
-
"Plan:": "Not documented",
|
| 143 |
-
"Discharge Summary:": "Not documented"
|
| 144 |
-
}
|
| 145 |
-
for line in raw.split('\n'):
|
| 146 |
-
line_lower = line.lower()
|
| 147 |
-
if any(kw in line_lower for kw in ['chief', 'complaint']):
|
| 148 |
-
sections["Chief Complaint:"] = line
|
| 149 |
-
elif any(kw in line_lower for kw in ['hpi', 'history', 'onset']):
|
| 150 |
-
sections["HPI:"] = line
|
| 151 |
-
elif any(kw in line_lower for kw in ['assessment', 'impression']):
|
| 152 |
-
sections["Assessment:"] = line
|
| 153 |
-
elif any(kw in line_lower for kw in ['vital', 'bp', 'hr']):
|
| 154 |
-
sections["Vitals:"] = line
|
| 155 |
-
elif any(kw in line_lower for kw in ['medication', 'mg', 'bid']):
|
| 156 |
-
sections["Medication:"] = line
|
| 157 |
-
elif any(kw in line_lower for kw in ['plan', 'admit', 'labs']):
|
| 158 |
-
sections["Plan:"] = line
|
| 159 |
-
elif 'discharge' in line_lower:
|
| 160 |
-
sections["Discharge Summary:"] = line
|
| 161 |
-
return "\n\n".join([f"{k}\n{sections[k]}" for k in sections])
|
| 162 |
|
| 163 |
def simple_validate(summary: str) -> dict:
|
| 164 |
-
score =
|
| 165 |
warnings = []
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
if re.search(r'\d+\s*(mg|%|bpm|mmHg)', summary, re.I):
|
| 173 |
-
score += 20
|
| 174 |
-
if "not documented" in summary.lower() and summary.lower().count("not documented") > 3:
|
| 175 |
-
score -= 25
|
| 176 |
-
warnings.append("Excessive gapsβreview input.")
|
| 177 |
-
score = max(0, min(100, score))
|
| 178 |
-
status = "EXCELLENT" if score >= 85 else "GOOD" if score >= 70 else "FAIR" if score >= 50 else "POOR"
|
| 179 |
return {"quality_score": score, "status": status, "warnings": warnings}
|
| 180 |
|
| 181 |
-
# Load model
|
| 182 |
@st.cache_resource
|
| 183 |
def load_model(model_name):
|
| 184 |
try:
|
|
@@ -192,11 +141,8 @@ def load_model(model_name):
|
|
| 192 |
low_cpu_mem_usage=True,
|
| 193 |
cache_dir="/tmp/hf_cache"
|
| 194 |
)
|
| 195 |
-
if
|
| 196 |
-
|
| 197 |
-
model.to('cpu')
|
| 198 |
-
else:
|
| 199 |
-
model.to('cuda')
|
| 200 |
st.sidebar.success("β Model Loaded")
|
| 201 |
return tokenizer, model
|
| 202 |
except Exception as e:
|
|
@@ -219,7 +165,7 @@ tab1, tab2 = st.tabs(["π De-ID & Prepare", "β¨ Generate Note"])
|
|
| 219 |
with tab1:
|
| 220 |
st.header("Upload/Paste Note")
|
| 221 |
uploaded = st.file_uploader("Upload .txt", type=["txt"])
|
| 222 |
-
input_text = st.text_area("Or paste
|
| 223 |
note_text = ""
|
| 224 |
if uploaded:
|
| 225 |
note_text = uploaded.read().decode("utf-8", errors="ignore")
|
|
@@ -237,90 +183,88 @@ with tab1:
|
|
| 237 |
if "encrypted_span_map" in result:
|
| 238 |
with open(f"{secure_dir}/session_note.spanmap.enc", "wb") as f:
|
| 239 |
f.write(result["encrypted_span_map"])
|
| 240 |
-
|
| 241 |
except Exception as e:
|
| 242 |
-
st.warning(f"
|
| 243 |
deid_text = fallback_deid(note_text)
|
| 244 |
else:
|
| 245 |
deid_text = fallback_deid(note_text)
|
| 246 |
|
| 247 |
st.session_state.deid_text = deid_text
|
| 248 |
-
st.success(f"Ready: {len(deid_text)} chars (PHI redacted)
|
| 249 |
else:
|
| 250 |
-
st.warning("Enter text
|
| 251 |
|
| 252 |
if st.session_state.deid_text:
|
| 253 |
-
with st.expander("Preview
|
| 254 |
-
st.text_area("", st.session_state.deid_text, height=200, disabled=True)
|
| 255 |
|
| 256 |
with tab2:
|
| 257 |
st.header("RAG Summarization")
|
| 258 |
if not st.session_state.deid_text:
|
| 259 |
-
st.warning("
|
| 260 |
else:
|
| 261 |
-
st.info(f"
|
| 262 |
-
|
| 263 |
-
|
|
|
|
| 264 |
deid_text = st.session_state.deid_text
|
| 265 |
|
| 266 |
try:
|
| 267 |
if HAS_MODULES:
|
| 268 |
-
# WORKAROUND: Delete vector store to avoid Chroma singleton conflict
|
| 269 |
-
import shutil
|
| 270 |
-
if Path(persist_dir).exists():
|
| 271 |
-
shutil.rmtree(persist_dir)
|
| 272 |
-
Path(persist_dir).mkdir(exist_ok=True)
|
| 273 |
-
|
| 274 |
# Index
|
| 275 |
index_note(deid_text, note_id="session_note", persist_dir=persist_dir, db_type=db_type)
|
| 276 |
|
| 277 |
# Retrieve
|
| 278 |
embed_f = load_embedder()
|
| 279 |
docs = retrieve_docs(db_type, persist_dir, "notes", deid_text[:200], top_k, embed_f)
|
| 280 |
-
chunks = [doc.page_content for doc in docs] if docs else fallback_retrieve(deid_text, top_k)
|
| 281 |
|
| 282 |
# Summarize
|
| 283 |
-
|
| 284 |
-
|
|
|
|
|
|
|
|
|
|
| 285 |
else:
|
| 286 |
chunks = fallback_retrieve(deid_text, top_k)
|
| 287 |
summary = fallback_summarize(chunks, tokenizer, model)
|
| 288 |
st.session_state.validation = simple_validate(summary)
|
| 289 |
|
| 290 |
st.session_state.summary = summary
|
| 291 |
-
st.success("
|
| 292 |
|
| 293 |
except Exception as e:
|
| 294 |
-
st.error(f"
|
| 295 |
-
|
|
|
|
| 296 |
st.session_state.summary = summary
|
| 297 |
st.session_state.validation = simple_validate(summary)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
for w in val["warnings"]:
|
| 315 |
-
st.warning(w)
|
| 316 |
-
|
| 317 |
-
st.download_button("πΎ Download", summ, "note.txt")
|
| 318 |
-
|
| 319 |
-
if st.button("π Reset"):
|
| 320 |
-
st.session_state.deid_text = ""
|
| 321 |
-
st.session_state.summary = None
|
| 322 |
-
st.session_state.validation = None
|
| 323 |
-
st.rerun()
|
| 324 |
|
| 325 |
st.markdown("---")
|
| 326 |
-
st.
|
|
|
|
| 7 |
import subprocess
|
| 8 |
import torch
|
| 9 |
|
| 10 |
+
# Fix torch.classes path error
|
| 11 |
torch.classes.__path__ = []
|
| 12 |
|
| 13 |
# HF Spaces env vars
|
|
|
|
| 18 |
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 19 |
warnings.filterwarnings("ignore", category=UserWarning)
|
| 20 |
|
|
|
|
| 21 |
def install_package(package):
|
| 22 |
try:
|
| 23 |
subprocess.check_call([sys.executable, "-m", "pip", "install", package, "--quiet"])
|
| 24 |
st.sidebar.success(f"Installed {package}")
|
| 25 |
except Exception:
|
| 26 |
+
st.sidebar.error(f"Failed to install {package}")
|
| 27 |
|
|
|
|
| 28 |
try:
|
| 29 |
import transformers
|
| 30 |
TRANSFORMERS_OK = True
|
|
|
|
| 49 |
Path(secure_dir).mkdir(exist_ok=True)
|
| 50 |
Path(persist_dir).mkdir(exist_ok=True)
|
| 51 |
|
| 52 |
+
# Sidebar
|
| 53 |
with st.sidebar:
|
| 54 |
st.header("Status")
|
| 55 |
HAS_MODULES = True
|
|
|
|
| 77 |
HAS_MODULES = False
|
| 78 |
st.error(f"summarizer: {e}")
|
| 79 |
|
|
|
|
|
|
|
|
|
|
| 80 |
st.info(modular_status)
|
| 81 |
st.caption(f"DB: {persist_dir} | Secure: {secure_dir}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
# Fallback functions
|
| 84 |
def fallback_deid(text: str) -> str:
|
| 85 |
patterns = [
|
| 86 |
(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]'),
|
|
|
|
|
|
|
| 87 |
(r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b', '[DATE]'),
|
| 88 |
(r'\b\d{3}[-.\s]?\d{3}[-.\s]?\d{4}\b', '[PHONE]'),
|
| 89 |
(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', '[EMAIL]'),
|
|
|
|
| 91 |
]
|
| 92 |
for pat, rep in patterns:
|
| 93 |
text = re.sub(pat, rep, text, flags=re.I)
|
| 94 |
+
return text
|
| 95 |
|
| 96 |
def fallback_retrieve(deid_text: str, top_k: int = 5) -> list:
|
| 97 |
if len(deid_text) > 3000:
|
|
|
|
| 101 |
|
| 102 |
def fallback_summarize(chunks: list, tokenizer, model) -> str:
|
| 103 |
context = "\n\n".join(chunks)
|
| 104 |
+
prompt = f"summarize: Clinical note. Extract: Chief Complaint, HPI, Assessment, Vitals, Medication, Plan, Discharge Summary.\n\nNote: {context}\n\nSummary:"
|
| 105 |
+
inputs = tokenizer(prompt, return_tensors="pt", max_length=2048, truncation=True)
|
| 106 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 107 |
model.to(device)
|
| 108 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 109 |
with torch.no_grad():
|
| 110 |
outputs = model.generate(
|
| 111 |
inputs['input_ids'],
|
| 112 |
+
max_new_tokens=300,
|
| 113 |
+
min_length=100,
|
| 114 |
num_beams=2,
|
|
|
|
| 115 |
early_stopping=True,
|
| 116 |
+
pad_token_id=tokenizer.pad_token_id
|
|
|
|
|
|
|
|
|
|
| 117 |
)
|
| 118 |
+
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
def simple_validate(summary: str) -> dict:
|
| 121 |
+
score = 75
|
| 122 |
warnings = []
|
| 123 |
+
if "not documented" in summary.lower():
|
| 124 |
+
count = summary.lower().count("not documented")
|
| 125 |
+
if count > 3:
|
| 126 |
+
score -= 25
|
| 127 |
+
warnings.append(f"Excessive gaps ({count} sections empty)")
|
| 128 |
+
status = "GOOD" if score >= 70 else "FAIR" if score >= 50 else "POOR"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
return {"quality_score": score, "status": status, "warnings": warnings}
|
| 130 |
|
|
|
|
| 131 |
@st.cache_resource
|
| 132 |
def load_model(model_name):
|
| 133 |
try:
|
|
|
|
| 141 |
low_cpu_mem_usage=True,
|
| 142 |
cache_dir="/tmp/hf_cache"
|
| 143 |
)
|
| 144 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 145 |
+
model.to(device)
|
|
|
|
|
|
|
|
|
|
| 146 |
st.sidebar.success("β Model Loaded")
|
| 147 |
return tokenizer, model
|
| 148 |
except Exception as e:
|
|
|
|
| 165 |
with tab1:
|
| 166 |
st.header("Upload/Paste Note")
|
| 167 |
uploaded = st.file_uploader("Upload .txt", type=["txt"])
|
| 168 |
+
input_text = st.text_area("Or paste clinical note:", height=250)
|
| 169 |
note_text = ""
|
| 170 |
if uploaded:
|
| 171 |
note_text = uploaded.read().decode("utf-8", errors="ignore")
|
|
|
|
| 183 |
if "encrypted_span_map" in result:
|
| 184 |
with open(f"{secure_dir}/session_note.spanmap.enc", "wb") as f:
|
| 185 |
f.write(result["encrypted_span_map"])
|
| 186 |
+
st.success("β De-identified with audit trail")
|
| 187 |
except Exception as e:
|
| 188 |
+
st.warning(f"Using fallback De-ID: {e}")
|
| 189 |
deid_text = fallback_deid(note_text)
|
| 190 |
else:
|
| 191 |
deid_text = fallback_deid(note_text)
|
| 192 |
|
| 193 |
st.session_state.deid_text = deid_text
|
| 194 |
+
st.success(f"Ready: {len(deid_text)} chars (PHI redacted)")
|
| 195 |
else:
|
| 196 |
+
st.warning("Enter text first")
|
| 197 |
|
| 198 |
if st.session_state.deid_text:
|
| 199 |
+
with st.expander("Preview De-identified Text"):
|
| 200 |
+
st.text_area("", st.session_state.deid_text, height=200, disabled=True, key="preview")
|
| 201 |
|
| 202 |
with tab2:
|
| 203 |
st.header("RAG Summarization")
|
| 204 |
if not st.session_state.deid_text:
|
| 205 |
+
st.warning("β Please de-identify a note first (Tab 1)")
|
| 206 |
else:
|
| 207 |
+
st.info(f"β Ready: {len(st.session_state.deid_text)} chars | Mode: {'Modular RAG' if HAS_MODULES else 'Fallback'}")
|
| 208 |
+
|
| 209 |
+
if st.button("π Generate Summary", type="primary"):
|
| 210 |
+
with st.spinner("Processing (this may take 1-2 minutes)..."):
|
| 211 |
deid_text = st.session_state.deid_text
|
| 212 |
|
| 213 |
try:
|
| 214 |
if HAS_MODULES:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
# Index
|
| 216 |
index_note(deid_text, note_id="session_note", persist_dir=persist_dir, db_type=db_type)
|
| 217 |
|
| 218 |
# Retrieve
|
| 219 |
embed_f = load_embedder()
|
| 220 |
docs = retrieve_docs(db_type, persist_dir, "notes", deid_text[:200], top_k, embed_f)
|
|
|
|
| 221 |
|
| 222 |
# Summarize
|
| 223 |
+
if docs:
|
| 224 |
+
summary = summarize_docs(tokenizer, model, docs, method)
|
| 225 |
+
st.session_state.validation = validate_summary_quality(summary, deid_text)
|
| 226 |
+
else:
|
| 227 |
+
raise Exception("No documents retrieved")
|
| 228 |
else:
|
| 229 |
chunks = fallback_retrieve(deid_text, top_k)
|
| 230 |
summary = fallback_summarize(chunks, tokenizer, model)
|
| 231 |
st.session_state.validation = simple_validate(summary)
|
| 232 |
|
| 233 |
st.session_state.summary = summary
|
| 234 |
+
st.success("β Summary generated!")
|
| 235 |
|
| 236 |
except Exception as e:
|
| 237 |
+
st.error(f"RAG failed: {e}. Using direct fallback.")
|
| 238 |
+
chunks = fallback_retrieve(deid_text, 3)
|
| 239 |
+
summary = fallback_summarize(chunks, tokenizer, model)
|
| 240 |
st.session_state.summary = summary
|
| 241 |
st.session_state.validation = simple_validate(summary)
|
| 242 |
+
|
| 243 |
+
if st.session_state.summary:
|
| 244 |
+
summ = st.session_state.summary
|
| 245 |
+
val = st.session_state.validation
|
| 246 |
|
| 247 |
+
col1, col2 = st.columns([3, 1])
|
| 248 |
+
with col1:
|
| 249 |
+
st.subheader("π Structured Clinical Summary")
|
| 250 |
+
st.markdown(summ)
|
| 251 |
+
with col2:
|
| 252 |
+
st.subheader("π Quality Assessment")
|
| 253 |
+
color = {"EXCELLENT": "π’", "GOOD": "π΅", "FAIR": "π‘", "POOR": "π΄"}.get(val.get("status", ""), "βͺ")
|
| 254 |
+
st.markdown(f"**{color} {val.get('status', 'N/A')}**")
|
| 255 |
+
st.metric("Quality Score", f"{val.get('quality_score', 0)}/100")
|
| 256 |
|
| 257 |
+
if val.get("warnings"):
|
| 258 |
+
for w in val["warnings"]:
|
| 259 |
+
st.warning(w)
|
| 260 |
+
|
| 261 |
+
st.download_button("πΎ Download Summary", summ, "clinical_summary.txt", type="secondary")
|
| 262 |
+
|
| 263 |
+
if st.button("π Reset & Start Over"):
|
| 264 |
+
st.session_state.deid_text = ""
|
| 265 |
+
st.session_state.summary = None
|
| 266 |
+
st.session_state.validation = None
|
| 267 |
+
st.rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
st.markdown("---")
|
| 270 |
+
st.caption("*HIPAA-Compliant RAG Clinical Summarizer | Portfolio Demo*")
|