samarth09healthPM commited on
Commit
7d10354
Β·
1 Parent(s): 603833c

Fix duplicate key error with session state

Browse files
Files changed (2) hide show
  1. indexer.py +6 -212
  2. main.py +65 -121
indexer.py CHANGED
@@ -1,75 +1,24 @@
1
- # app/indexer.py
2
- # Day 6: Vector store & embeddings
3
- # Usage examples:
4
- # python app/indexer.py --input_dir ./data/outputs --db_type chroma --persist_dir ./data/vector_store
5
- # python app/indexer.py --input_dir ./data/outputs --db_type faiss --persist_dir ./data/vector_store_faiss
6
-
7
  import os
8
  import json
9
  import argparse
10
  from pathlib import Path
11
  from typing import List, Dict, Tuple
12
- from tqdm import tqdm
13
-
14
- # Embeddings
15
  from sentence_transformers import SentenceTransformer
16
-
17
- # Vector stores
18
- # Chroma
19
  import chromadb
20
  from chromadb.config import Settings as ChromaSettings
21
-
22
- # FAISS
23
  import faiss
24
  import pickle
25
 
26
  DEFAULT_CHUNK_TOKENS = 200
27
  DEFAULT_OVERLAP_TOKENS = 50
28
 
29
- def read_note_files(input_dir: str) -> List[Dict]:
30
- """
31
- Reads de-identified notes from .txt or .json in input_dir.
32
- Expects .json to have a 'text' field containing de-identified content.
33
- Returns list of dicts: {id, text, section?}
34
- """
35
- items = []
36
- p = Path(input_dir)
37
- if not p.exists():
38
- raise FileNotFoundError(f"Input dir not found: {input_dir}")
39
-
40
- for fp in p.glob("**/*"):
41
- if fp.is_dir():
42
- continue
43
- if fp.suffix.lower() == ".txt":
44
- text = fp.read_text(encoding="utf-8", errors="ignore").strip()
45
- if text:
46
- items.append({"id": fp.stem, "text": text, "section": None})
47
- elif fp.suffix.lower() == ".json":
48
- try:
49
- obj = json.loads(fp.read_text(encoding="utf-8", errors="ignore"))
50
- text = obj.get("text") or obj.get("deidentified_text") or ""
51
- section = obj.get("section")
52
- if text:
53
- items.append({"id": fp.stem, "text": text.strip(), "section": section})
54
- except Exception:
55
- # Skip malformed
56
- continue
57
- return items
58
-
59
  def approx_tokenize(text: str) -> List[str]:
60
- """
61
- Approximate tokenization by splitting on whitespace.
62
- For MVP this is fine; can replace with tiktoken later.
63
- """
64
  return text.split()
65
 
66
  def detokenize(tokens: List[str]) -> str:
67
  return " ".join(tokens)
68
 
69
  def chunk_text(text: str, chunk_tokens: int, overlap_tokens: int) -> List[str]:
70
- """
71
- Simple sliding window chunking.
72
- """
73
  tokens = approx_tokenize(text)
74
  chunks = []
75
  i = 0
@@ -86,38 +35,6 @@ def chunk_text(text: str, chunk_tokens: int, overlap_tokens: int) -> List[str]:
86
  i = 0
87
  return chunks
88
 
89
- def embed_texts(model: SentenceTransformer, texts: List[str]):
90
- return model.encode(texts, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)
91
-
92
- def build_chroma(persist_dir: str, collection_name: str = "notes"):
93
- client = chromadb.PersistentClient(
94
- path=persist_dir,
95
- settings=ChromaSettings(allow_reset=True)
96
- )
97
- if collection_name in [c.name for c in client.list_collections()]:
98
- coll = client.get_collection(collection_name)
99
- else:
100
- coll = client.create_collection(collection_name)
101
- return client, coll
102
-
103
- def save_faiss(index, vectors_meta: List[Dict], persist_dir: str):
104
- os.makedirs(persist_dir, exist_ok=True)
105
- faiss_path = os.path.join(persist_dir, "index.faiss")
106
- meta_path = os.path.join(persist_dir, "meta.pkl")
107
- faiss.write_index(index, faiss_path)
108
- with open(meta_path, "wb") as f:
109
- pickle.dump(vectors_meta, f)
110
-
111
- def load_faiss(persist_dir: str):
112
- faiss_path = os.path.join(persist_dir, "index.faiss")
113
- meta_path = os.path.join(persist_dir, "meta.pkl")
114
- if os.path.exists(faiss_path) and os.path.exists(meta_path):
115
- index = faiss.read_index(faiss_path)
116
- with open(meta_path, "rb") as f:
117
- meta = pickle.load(f)
118
- return index, meta
119
- return None, []
120
-
121
  def index_note(
122
  text: str,
123
  note_id: str = "temp_note",
@@ -126,35 +43,6 @@ def index_note(
126
  model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
127
  collection: str = "notes"
128
  ) -> str:
129
- from sentence_transformers import SentenceTransformer
130
- import os
131
-
132
- DEFAULT_CHUNK_TOKENS = 200
133
- DEFAULT_OVERLAP_TOKENS = 50
134
-
135
- def approx_tokenize(text: str):
136
- return text.split()
137
-
138
- def detokenize(tokens):
139
- return " ".join(tokens)
140
-
141
- def chunk_text(text, chunk_tokens, overlap_tokens):
142
- tokens = approx_tokenize(text)
143
- chunks = []
144
- i = 0
145
- n = len(tokens)
146
- while i < n:
147
- j = min(i + chunk_tokens, n)
148
- chunk = detokenize(tokens[i:j])
149
- if chunk.strip():
150
- chunks.append(chunk)
151
- if j == n:
152
- break
153
- i = j - overlap_tokens
154
- if i < 0:
155
- i = 0
156
- return chunks
157
-
158
  os.makedirs(persist_dir, exist_ok=True)
159
  model = SentenceTransformer(model_name)
160
  chunks = chunk_text(text, DEFAULT_CHUNK_TOKENS, DEFAULT_OVERLAP_TOKENS)
@@ -163,16 +51,15 @@ def index_note(
163
  vectors = model.encode(chunks, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)
164
 
165
  if db_type == "chroma":
166
- from chromadb.config import Settings as ChromaSettings
167
- import chromadb
168
  client = chromadb.PersistentClient(
169
  path=persist_dir,
170
- settings=ChromaSettings(allow_reset=True)
 
 
 
171
  )
172
- if collection in [c.name for c in client.list_collections()]:
173
- coll = client.get_collection(collection)
174
- else:
175
- coll = client.create_collection(collection)
176
  coll.upsert(
177
  ids=chunk_ids,
178
  embeddings=vectors.tolist(),
@@ -180,8 +67,6 @@ def index_note(
180
  metadatas=metadatas,
181
  )
182
  elif db_type == "faiss":
183
- import faiss
184
- import pickle
185
  d = vectors.shape[1]
186
  index = faiss.IndexFlatIP(d)
187
  index.add(vectors)
@@ -196,94 +81,3 @@ def index_note(
196
  pickle.dump(vectors_meta, f)
197
 
198
  return note_id
199
-
200
-
201
- def main():
202
- parser = argparse.ArgumentParser(description="Day 6: Build local vector DB from de-identified notes.")
203
- parser.add_argument("--input_dir", required=True, help="Directory with de-identified notes (.txt or .json).")
204
- parser.add_argument("--persist_dir", default="./data/vector_store", help="Where to persist the DB.")
205
- parser.add_argument("--db_type", choices=["chroma", "faiss"], default="chroma", help="Vector DB type.")
206
- parser.add_argument("--model_name", default="sentence-transformers/all-MiniLM-L6-v2", help="Embedding model.")
207
- parser.add_argument("--chunk_tokens", type=int, default=DEFAULT_CHUNK_TOKENS, help="Approx tokens per chunk.")
208
- parser.add_argument("--overlap_tokens", type=int, default=DEFAULT_OVERLAP_TOKENS, help="Token overlap.")
209
- parser.add_argument("--collection", default="notes", help="Collection name (Chroma).")
210
- args = parser.parse_args()
211
-
212
- notes = read_note_files(args.input_dir)
213
- if not notes:
214
- print(f"No de-identified notes found in {args.input_dir}. Ensure Day 5 outputs exist.")
215
- return
216
-
217
- print(f"Loaded {len(notes)} de-identified notes from {args.input_dir}")
218
- os.makedirs(args.persist_dir, exist_ok=True)
219
-
220
- print(f"Loading embedding model: {args.model_name}")
221
- model = SentenceTransformer(args.model_name)
222
-
223
- all_chunk_texts = []
224
- all_chunk_ids = []
225
- all_metadata = []
226
-
227
- print("Chunking notes...")
228
- for note in tqdm(notes):
229
- chunks = chunk_text(note["text"], args.chunk_tokens, args.overlap_tokens)
230
- for idx, ch in enumerate(chunks):
231
- cid = f"{note['id']}::chunk_{idx}"
232
- all_chunk_texts.append(ch)
233
- all_chunk_ids.append(cid)
234
- all_metadata.append({
235
- "note_id": note["id"],
236
- "chunk_index": idx,
237
- "section": note.get("section")
238
- })
239
-
240
- print(f"Total chunks: {len(all_chunk_texts)}")
241
-
242
- print("Embedding chunks...")
243
- vectors = embed_texts(model, all_chunk_texts)
244
-
245
- if args.db_type == "chroma":
246
- print("Building Chroma persistent collection...")
247
- client, coll = build_chroma(args.persist_dir, args.collection)
248
-
249
- # Upsert in manageable batches
250
- batch = 512
251
- for i in tqdm(range(0, len(all_chunk_texts), batch)):
252
- j = min(i + batch, len(all_chunk_texts))
253
- coll.upsert(
254
- ids=all_chunk_ids[i:j],
255
- embeddings=vectors[i:j].tolist(),
256
- documents=all_chunk_texts[i:j],
257
- metadatas=all_metadata[i:j],
258
- )
259
- print(f"Chroma collection '{args.collection}' persisted at {args.persist_dir}")
260
-
261
- elif args.db_type == "faiss":
262
- print("Building FAISS index...")
263
- d = vectors.shape[1]
264
- index = faiss.IndexFlatIP(d) # normalized vectors β†’ use inner product as cosine
265
- # Try to load existing
266
- existing_index, existing_meta = load_faiss(args.persist_dir)
267
- if existing_index is not None:
268
- print("Appending to existing FAISS index...")
269
- index = existing_index
270
- vectors_meta = existing_meta
271
- else:
272
- vectors_meta = []
273
- index.add(vectors)
274
- vectors_meta.extend([
275
- {
276
- "id": all_chunk_ids[k],
277
- "text": all_chunk_texts[k],
278
- "meta": all_metadata[k]
279
- } for k in range(len(all_chunk_texts))
280
- ])
281
- save_faiss(index, vectors_meta, args.persist_dir)
282
- print(f"FAISS index persisted at {args.persist_dir}")
283
-
284
- print("Done.")
285
-
286
- if __name__ == "__main__":
287
- main()
288
- ##result = pipeline.run_on_text(text=note_text, note_id="temp_note")
289
- ##deid_text = result["masked_text"]
 
 
 
 
 
 
 
1
  import os
2
  import json
3
  import argparse
4
  from pathlib import Path
5
  from typing import List, Dict, Tuple
 
 
 
6
  from sentence_transformers import SentenceTransformer
 
 
 
7
  import chromadb
8
  from chromadb.config import Settings as ChromaSettings
 
 
9
  import faiss
10
  import pickle
11
 
12
  DEFAULT_CHUNK_TOKENS = 200
13
  DEFAULT_OVERLAP_TOKENS = 50
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def approx_tokenize(text: str) -> List[str]:
 
 
 
 
16
  return text.split()
17
 
18
  def detokenize(tokens: List[str]) -> str:
19
  return " ".join(tokens)
20
 
21
  def chunk_text(text: str, chunk_tokens: int, overlap_tokens: int) -> List[str]:
 
 
 
22
  tokens = approx_tokenize(text)
23
  chunks = []
24
  i = 0
 
35
  i = 0
36
  return chunks
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def index_note(
39
  text: str,
40
  note_id: str = "temp_note",
 
43
  model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
44
  collection: str = "notes"
45
  ) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  os.makedirs(persist_dir, exist_ok=True)
47
  model = SentenceTransformer(model_name)
48
  chunks = chunk_text(text, DEFAULT_CHUNK_TOKENS, DEFAULT_OVERLAP_TOKENS)
 
51
  vectors = model.encode(chunks, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)
52
 
53
  if db_type == "chroma":
54
+ # FIX: Use get_or_create with consistent settings
 
55
  client = chromadb.PersistentClient(
56
  path=persist_dir,
57
+ settings=ChromaSettings(
58
+ allow_reset=False, # Changed to False for consistency
59
+ anonymized_telemetry=False
60
+ )
61
  )
62
+ coll = client.get_or_create_collection(collection)
 
 
 
63
  coll.upsert(
64
  ids=chunk_ids,
65
  embeddings=vectors.tolist(),
 
67
  metadatas=metadatas,
68
  )
69
  elif db_type == "faiss":
 
 
70
  d = vectors.shape[1]
71
  index = faiss.IndexFlatIP(d)
72
  index.add(vectors)
 
81
  pickle.dump(vectors_meta, f)
82
 
83
  return note_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py CHANGED
@@ -7,7 +7,7 @@ from pathlib import Path
7
  import subprocess
8
  import torch
9
 
10
- # Fix torch.classes path error for Streamlit compatibility
11
  torch.classes.__path__ = []
12
 
13
  # HF Spaces env vars
@@ -18,15 +18,13 @@ os.environ["SPACY_MODEL"] = "en_core_web_lg"
18
  warnings.filterwarnings("ignore", category=DeprecationWarning)
19
  warnings.filterwarnings("ignore", category=UserWarning)
20
 
21
- # Dynamic install helpers
22
  def install_package(package):
23
  try:
24
  subprocess.check_call([sys.executable, "-m", "pip", "install", package, "--quiet"])
25
  st.sidebar.success(f"Installed {package}")
26
  except Exception:
27
- st.sidebar.error(f"Failed to install {package}. Use requirements.txt.")
28
 
29
- # Check transformers
30
  try:
31
  import transformers
32
  TRANSFORMERS_OK = True
@@ -51,7 +49,7 @@ method = "multistage"
51
  Path(secure_dir).mkdir(exist_ok=True)
52
  Path(persist_dir).mkdir(exist_ok=True)
53
 
54
- # Sidebar for status
55
  with st.sidebar:
56
  st.header("Status")
57
  HAS_MODULES = True
@@ -79,24 +77,13 @@ with st.sidebar:
79
  HAS_MODULES = False
80
  st.error(f"summarizer: {e}")
81
 
82
- if not TRANSFORMERS_OK:
83
- st.error("Transformers failedβ€”rebuild Space.")
84
-
85
  st.info(modular_status)
86
  st.caption(f"DB: {persist_dir} | Secure: {secure_dir}")
87
-
88
- if st.button("πŸ”§ Install Missing"):
89
- install_package("presidio-analyzer")
90
- install_package("spacy")
91
- subprocess.check_call(["python", "-m", "spacy", "download", "en_core_web_lg"], stdout=subprocess.DEVNULL)
92
- st.rerun()
93
 
94
  # Fallback functions
95
  def fallback_deid(text: str) -> str:
96
  patterns = [
97
  (r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]'),
98
- (r'\b[A-Z][a-z]{2,}\b(?=\s+(her|his|the|by)\b)', '[LAST_NAME]'),
99
- (r'\b[A-Z][a-z]{2,}\b(?! (BP|HR|RR|mg|mmHg|bpm|CT|MRI|TIA|NIH|EF|RA|HS|BID|QID|PCP))', '[NAME]'),
100
  (r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b', '[DATE]'),
101
  (r'\b\d{3}[-.\s]?\d{3}[-.\s]?\d{4}\b', '[PHONE]'),
102
  (r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', '[EMAIL]'),
@@ -104,7 +91,7 @@ def fallback_deid(text: str) -> str:
104
  ]
105
  for pat, rep in patterns:
106
  text = re.sub(pat, rep, text, flags=re.I)
107
- return re.sub(r'\b[A-Z][a-z]{2,}\b(?! (mg|daily|nightly|BID|QID|PCP|RA|HS|ED|PMH))', '[NAME]', text)
108
 
109
  def fallback_retrieve(deid_text: str, top_k: int = 5) -> list:
110
  if len(deid_text) > 3000:
@@ -114,71 +101,33 @@ def fallback_retrieve(deid_text: str, top_k: int = 5) -> list:
114
 
115
  def fallback_summarize(chunks: list, tokenizer, model) -> str:
116
  context = "\n\n".join(chunks)
117
- prompt = f"summarize: Structured clinical note from context. Sections: Chief Complaint | HPI | Assessment | Vitals | Medication | Plan | Discharge Summary\n\nContext: {context}\n\nOutput only structured sections."
118
- inputs = tokenizer(prompt, return_tensors="pt", max_length=4096, truncation=True)
119
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
120
  model.to(device)
121
  inputs = {k: v.to(device) for k, v in inputs.items()}
122
  with torch.no_grad():
123
  outputs = model.generate(
124
  inputs['input_ids'],
125
- max_new_tokens=400,
126
- min_length=150,
127
  num_beams=2,
128
- length_penalty=1.0,
129
  early_stopping=True,
130
- do_sample=False,
131
- repetition_penalty=1.1,
132
- pad_token_id=tokenizer.pad_token_id,
133
- use_cache=True
134
  )
135
- raw = tokenizer.decode(outputs[0], skip_special_tokens=True)
136
- sections = {
137
- "Chief Complaint:": "Not documented",
138
- "HPI:": "Not documented",
139
- "Assessment:": "Not documented",
140
- "Vitals:": "Not documented",
141
- "Medication:": "Not documented",
142
- "Plan:": "Not documented",
143
- "Discharge Summary:": "Not documented"
144
- }
145
- for line in raw.split('\n'):
146
- line_lower = line.lower()
147
- if any(kw in line_lower for kw in ['chief', 'complaint']):
148
- sections["Chief Complaint:"] = line
149
- elif any(kw in line_lower for kw in ['hpi', 'history', 'onset']):
150
- sections["HPI:"] = line
151
- elif any(kw in line_lower for kw in ['assessment', 'impression']):
152
- sections["Assessment:"] = line
153
- elif any(kw in line_lower for kw in ['vital', 'bp', 'hr']):
154
- sections["Vitals:"] = line
155
- elif any(kw in line_lower for kw in ['medication', 'mg', 'bid']):
156
- sections["Medication:"] = line
157
- elif any(kw in line_lower for kw in ['plan', 'admit', 'labs']):
158
- sections["Plan:"] = line
159
- elif 'discharge' in line_lower:
160
- sections["Discharge Summary:"] = line
161
- return "\n\n".join([f"{k}\n{sections[k]}" for k in sections])
162
 
163
  def simple_validate(summary: str) -> dict:
164
- score = 100
165
  warnings = []
166
- required = ["Chief Complaint", "HPI", "Assessment", "Vitals", "Medication", "Plan", "Discharge Summary"]
167
- present = [sec for sec in required if sec.lower() in summary.lower()]
168
- missing = [sec for sec in required if sec not in present]
169
- if missing:
170
- score -= len(missing) * 15
171
- warnings.append(f"Missing: {', '.join(missing)}")
172
- if re.search(r'\d+\s*(mg|%|bpm|mmHg)', summary, re.I):
173
- score += 20
174
- if "not documented" in summary.lower() and summary.lower().count("not documented") > 3:
175
- score -= 25
176
- warnings.append("Excessive gapsβ€”review input.")
177
- score = max(0, min(100, score))
178
- status = "EXCELLENT" if score >= 85 else "GOOD" if score >= 70 else "FAIR" if score >= 50 else "POOR"
179
  return {"quality_score": score, "status": status, "warnings": warnings}
180
 
181
- # Load model
182
  @st.cache_resource
183
  def load_model(model_name):
184
  try:
@@ -192,11 +141,8 @@ def load_model(model_name):
192
  low_cpu_mem_usage=True,
193
  cache_dir="/tmp/hf_cache"
194
  )
195
- if not torch.cuda.is_available():
196
- model.gradient_checkpointing_enable()
197
- model.to('cpu')
198
- else:
199
- model.to('cuda')
200
  st.sidebar.success("βœ“ Model Loaded")
201
  return tokenizer, model
202
  except Exception as e:
@@ -219,7 +165,7 @@ tab1, tab2 = st.tabs(["πŸ“ De-ID & Prepare", "✨ Generate Note"])
219
  with tab1:
220
  st.header("Upload/Paste Note")
221
  uploaded = st.file_uploader("Upload .txt", type=["txt"])
222
- input_text = st.text_area("Or paste (long OK):", height=250)
223
  note_text = ""
224
  if uploaded:
225
  note_text = uploaded.read().decode("utf-8", errors="ignore")
@@ -237,90 +183,88 @@ with tab1:
237
  if "encrypted_span_map" in result:
238
  with open(f"{secure_dir}/session_note.spanmap.enc", "wb") as f:
239
  f.write(result["encrypted_span_map"])
240
- st.success("De-ID + encrypted audit saved.")
241
  except Exception as e:
242
- st.warning(f"Modular De-ID failed ({e})β€”using fallback.")
243
  deid_text = fallback_deid(note_text)
244
  else:
245
  deid_text = fallback_deid(note_text)
246
 
247
  st.session_state.deid_text = deid_text
248
- st.success(f"Ready: {len(deid_text)} chars (PHI redacted).")
249
  else:
250
- st.warning("Enter text.")
251
 
252
  if st.session_state.deid_text:
253
- with st.expander("Preview (De-ID'd)"):
254
- st.text_area("", st.session_state.deid_text, height=200, disabled=True)
255
 
256
  with tab2:
257
  st.header("RAG Summarization")
258
  if not st.session_state.deid_text:
259
- st.warning("De-ID first.")
260
  else:
261
- st.info(f"Length: {len(st.session_state.deid_text)} chars | Mode: {'Modular' if HAS_MODULES else 'Fallback'}")
262
- if st.button("πŸš€ Generate", type="primary"):
263
- with st.spinner("Processing (Index/Retrieve/Summarize)..."):
 
264
  deid_text = st.session_state.deid_text
265
 
266
  try:
267
  if HAS_MODULES:
268
- # WORKAROUND: Delete vector store to avoid Chroma singleton conflict
269
- import shutil
270
- if Path(persist_dir).exists():
271
- shutil.rmtree(persist_dir)
272
- Path(persist_dir).mkdir(exist_ok=True)
273
-
274
  # Index
275
  index_note(deid_text, note_id="session_note", persist_dir=persist_dir, db_type=db_type)
276
 
277
  # Retrieve
278
  embed_f = load_embedder()
279
  docs = retrieve_docs(db_type, persist_dir, "notes", deid_text[:200], top_k, embed_f)
280
- chunks = [doc.page_content for doc in docs] if docs else fallback_retrieve(deid_text, top_k)
281
 
282
  # Summarize
283
- summary = summarize_docs(tokenizer, model, docs if docs else [], method)
284
- st.session_state.validation = validate_summary_quality(summary, deid_text)
 
 
 
285
  else:
286
  chunks = fallback_retrieve(deid_text, top_k)
287
  summary = fallback_summarize(chunks, tokenizer, model)
288
  st.session_state.validation = simple_validate(summary)
289
 
290
  st.session_state.summary = summary
291
- st.success("Generated!")
292
 
293
  except Exception as e:
294
- st.error(f"Generation failed: {e}. Using basic fallback.")
295
- summary = fallback_summarize(fallback_retrieve(deid_text, 3), tokenizer, model)
 
296
  st.session_state.summary = summary
297
  st.session_state.validation = simple_validate(summary)
 
 
 
 
298
 
299
- if st.session_state.summary:
300
- summ = st.session_state.summary
301
- val = st.session_state.validation
 
 
 
 
 
 
302
 
303
- col1, col2 = st.columns([3,1])
304
- with col1:
305
- st.subheader("Structured Note")
306
- st.markdown(summ)
307
- with col2:
308
- st.subheader("Assessment")
309
- color = {"EXCELLENT": "🟒", "GOOD": "πŸ”΅", "FAIR": "🟑", "POOR": "πŸ”΄"}.get(val.get("status", ""), "βšͺ")
310
- st.markdown(f"**{color} {val.get('status', 'N/A')}**")
311
- st.metric("Score", f"{val.get('quality_score', 0)}/100")
312
-
313
- if val.get("warnings"):
314
- for w in val["warnings"]:
315
- st.warning(w)
316
-
317
- st.download_button("πŸ’Ύ Download", summ, "note.txt")
318
-
319
- if st.button("πŸ”„ Reset"):
320
- st.session_state.deid_text = ""
321
- st.session_state.summary = None
322
- st.session_state.validation = None
323
- st.rerun()
324
 
325
  st.markdown("---")
326
- st.markdown("*Error-Resilient RAG Demo | Portfolio: HIPAA Audit-Ready.*")
 
7
  import subprocess
8
  import torch
9
 
10
+ # Fix torch.classes path error
11
  torch.classes.__path__ = []
12
 
13
  # HF Spaces env vars
 
18
  warnings.filterwarnings("ignore", category=DeprecationWarning)
19
  warnings.filterwarnings("ignore", category=UserWarning)
20
 
 
21
  def install_package(package):
22
  try:
23
  subprocess.check_call([sys.executable, "-m", "pip", "install", package, "--quiet"])
24
  st.sidebar.success(f"Installed {package}")
25
  except Exception:
26
+ st.sidebar.error(f"Failed to install {package}")
27
 
 
28
  try:
29
  import transformers
30
  TRANSFORMERS_OK = True
 
49
  Path(secure_dir).mkdir(exist_ok=True)
50
  Path(persist_dir).mkdir(exist_ok=True)
51
 
52
+ # Sidebar
53
  with st.sidebar:
54
  st.header("Status")
55
  HAS_MODULES = True
 
77
  HAS_MODULES = False
78
  st.error(f"summarizer: {e}")
79
 
 
 
 
80
  st.info(modular_status)
81
  st.caption(f"DB: {persist_dir} | Secure: {secure_dir}")
 
 
 
 
 
 
82
 
83
  # Fallback functions
84
  def fallback_deid(text: str) -> str:
85
  patterns = [
86
  (r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]'),
 
 
87
  (r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b', '[DATE]'),
88
  (r'\b\d{3}[-.\s]?\d{3}[-.\s]?\d{4}\b', '[PHONE]'),
89
  (r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', '[EMAIL]'),
 
91
  ]
92
  for pat, rep in patterns:
93
  text = re.sub(pat, rep, text, flags=re.I)
94
+ return text
95
 
96
  def fallback_retrieve(deid_text: str, top_k: int = 5) -> list:
97
  if len(deid_text) > 3000:
 
101
 
102
  def fallback_summarize(chunks: list, tokenizer, model) -> str:
103
  context = "\n\n".join(chunks)
104
+ prompt = f"summarize: Clinical note. Extract: Chief Complaint, HPI, Assessment, Vitals, Medication, Plan, Discharge Summary.\n\nNote: {context}\n\nSummary:"
105
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=2048, truncation=True)
106
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
107
  model.to(device)
108
  inputs = {k: v.to(device) for k, v in inputs.items()}
109
  with torch.no_grad():
110
  outputs = model.generate(
111
  inputs['input_ids'],
112
+ max_new_tokens=300,
113
+ min_length=100,
114
  num_beams=2,
 
115
  early_stopping=True,
116
+ pad_token_id=tokenizer.pad_token_id
 
 
 
117
  )
118
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  def simple_validate(summary: str) -> dict:
121
+ score = 75
122
  warnings = []
123
+ if "not documented" in summary.lower():
124
+ count = summary.lower().count("not documented")
125
+ if count > 3:
126
+ score -= 25
127
+ warnings.append(f"Excessive gaps ({count} sections empty)")
128
+ status = "GOOD" if score >= 70 else "FAIR" if score >= 50 else "POOR"
 
 
 
 
 
 
 
129
  return {"quality_score": score, "status": status, "warnings": warnings}
130
 
 
131
  @st.cache_resource
132
  def load_model(model_name):
133
  try:
 
141
  low_cpu_mem_usage=True,
142
  cache_dir="/tmp/hf_cache"
143
  )
144
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
145
+ model.to(device)
 
 
 
146
  st.sidebar.success("βœ“ Model Loaded")
147
  return tokenizer, model
148
  except Exception as e:
 
165
  with tab1:
166
  st.header("Upload/Paste Note")
167
  uploaded = st.file_uploader("Upload .txt", type=["txt"])
168
+ input_text = st.text_area("Or paste clinical note:", height=250)
169
  note_text = ""
170
  if uploaded:
171
  note_text = uploaded.read().decode("utf-8", errors="ignore")
 
183
  if "encrypted_span_map" in result:
184
  with open(f"{secure_dir}/session_note.spanmap.enc", "wb") as f:
185
  f.write(result["encrypted_span_map"])
186
+ st.success("βœ“ De-identified with audit trail")
187
  except Exception as e:
188
+ st.warning(f"Using fallback De-ID: {e}")
189
  deid_text = fallback_deid(note_text)
190
  else:
191
  deid_text = fallback_deid(note_text)
192
 
193
  st.session_state.deid_text = deid_text
194
+ st.success(f"Ready: {len(deid_text)} chars (PHI redacted)")
195
  else:
196
+ st.warning("Enter text first")
197
 
198
  if st.session_state.deid_text:
199
+ with st.expander("Preview De-identified Text"):
200
+ st.text_area("", st.session_state.deid_text, height=200, disabled=True, key="preview")
201
 
202
  with tab2:
203
  st.header("RAG Summarization")
204
  if not st.session_state.deid_text:
205
+ st.warning("⚠ Please de-identify a note first (Tab 1)")
206
  else:
207
+ st.info(f"βœ“ Ready: {len(st.session_state.deid_text)} chars | Mode: {'Modular RAG' if HAS_MODULES else 'Fallback'}")
208
+
209
+ if st.button("πŸš€ Generate Summary", type="primary"):
210
+ with st.spinner("Processing (this may take 1-2 minutes)..."):
211
  deid_text = st.session_state.deid_text
212
 
213
  try:
214
  if HAS_MODULES:
 
 
 
 
 
 
215
  # Index
216
  index_note(deid_text, note_id="session_note", persist_dir=persist_dir, db_type=db_type)
217
 
218
  # Retrieve
219
  embed_f = load_embedder()
220
  docs = retrieve_docs(db_type, persist_dir, "notes", deid_text[:200], top_k, embed_f)
 
221
 
222
  # Summarize
223
+ if docs:
224
+ summary = summarize_docs(tokenizer, model, docs, method)
225
+ st.session_state.validation = validate_summary_quality(summary, deid_text)
226
+ else:
227
+ raise Exception("No documents retrieved")
228
  else:
229
  chunks = fallback_retrieve(deid_text, top_k)
230
  summary = fallback_summarize(chunks, tokenizer, model)
231
  st.session_state.validation = simple_validate(summary)
232
 
233
  st.session_state.summary = summary
234
+ st.success("βœ“ Summary generated!")
235
 
236
  except Exception as e:
237
+ st.error(f"RAG failed: {e}. Using direct fallback.")
238
+ chunks = fallback_retrieve(deid_text, 3)
239
+ summary = fallback_summarize(chunks, tokenizer, model)
240
  st.session_state.summary = summary
241
  st.session_state.validation = simple_validate(summary)
242
+
243
+ if st.session_state.summary:
244
+ summ = st.session_state.summary
245
+ val = st.session_state.validation
246
 
247
+ col1, col2 = st.columns([3, 1])
248
+ with col1:
249
+ st.subheader("πŸ“‹ Structured Clinical Summary")
250
+ st.markdown(summ)
251
+ with col2:
252
+ st.subheader("πŸ“Š Quality Assessment")
253
+ color = {"EXCELLENT": "🟒", "GOOD": "πŸ”΅", "FAIR": "🟑", "POOR": "πŸ”΄"}.get(val.get("status", ""), "βšͺ")
254
+ st.markdown(f"**{color} {val.get('status', 'N/A')}**")
255
+ st.metric("Quality Score", f"{val.get('quality_score', 0)}/100")
256
 
257
+ if val.get("warnings"):
258
+ for w in val["warnings"]:
259
+ st.warning(w)
260
+
261
+ st.download_button("πŸ’Ύ Download Summary", summ, "clinical_summary.txt", type="secondary")
262
+
263
+ if st.button("πŸ”„ Reset & Start Over"):
264
+ st.session_state.deid_text = ""
265
+ st.session_state.summary = None
266
+ st.session_state.validation = None
267
+ st.rerun()
 
 
 
 
 
 
 
 
 
 
268
 
269
  st.markdown("---")
270
+ st.caption("*HIPAA-Compliant RAG Clinical Summarizer | Portfolio Demo*")