samarth09healthPM committed on
Commit f64b3f9 · 1 Parent(s): 7f09b7f

Add HIPAA RAG Clinical Summarizer (essential files only)

Files changed (13)
  1. .gitignore +46 -0
  2. audit.py +79 -0
  3. bcrypt_pw.py +2 -0
  4. deid_pipeline.py +266 -0
  5. indexer.py +289 -0
  6. main.py +494 -0
  7. notes.py +47 -0
  8. quick_check_chroma.py +20 -0
  9. rag_pipeline.py +117 -0
  10. retriever_context.py +7 -0
  11. run_pipeline.py +51 -0
  12. streamlit_config.yaml +17 -0
  13. summarizer.py +610 -0
.gitignore ADDED
@@ -0,0 +1,46 @@
+ # Large data folders
+ synthea/
+ data/raw_synthea/
+ data/vector_store*/
+ synthea/output/fhir/
+ *.json
+ *.csv
+ *.zip
+ *.sqlite
+ __pycache__/
+ *.ipynb_checkpoints
+
+ # Ignore vector store database
+ app/data/vector_store/chroma.sqlite3
+ *.sqlite3
+ *.db
+ # Models and cache
+ models/
+ .cache/
+ transformers_cache/
+
+ # Virtual environment
+ new_env_rag/
+ venv/
+ env/
+
+ # Secrets
+
+
+ # Logs and outputs
+ logs/
+ *.log
+ *.jsonl
+
+ # Python cache
+ __pycache__/
+ *.pyc
+
+ # IDE
+ .vs/
+ .vscode/
+ .idea/
+
+ # OS files
+ .DS_Store
+ Thumbs.db
audit.py ADDED
@@ -0,0 +1,79 @@
+ import json
+ import hashlib
+ from datetime import datetime
+ from pathlib import Path
+ import uuid
+ import pytz
+
+ class AuditLogger:
+     def __init__(self, log_file_path="logs/app_audit.jsonl"):
+         self.log_file = Path(log_file_path)
+         self.log_file.parent.mkdir(parents=True, exist_ok=True)
+         # Create file if it doesn't exist
+         if not self.log_file.exists():
+             self.log_file.touch()
+
+     def _get_last_hash(self):
+         """Read the last log entry and return its hash"""
+         try:
+             with open(self.log_file, 'r') as f:
+                 lines = f.readlines()
+                 if lines:
+                     last_entry = json.loads(lines[-1])
+                     return last_entry.get('sha256_curr', '')
+         except Exception:
+             pass
+         return ''  # First entry has no previous hash
+
+     def _compute_hash(self, log_entry):
+         """Create a hash fingerprint of the log entry"""
+         # Convert the log entry to a string and hash it
+         entry_string = json.dumps(log_entry, sort_keys=True)
+         return hashlib.sha256(entry_string.encode()).hexdigest()
+
+     def log_action(self, user, action, resource, additional_info=None):
+         """
+         Main logging function - call this whenever a user does something
+
+         Args:
+             user: username (e.g., 'dr_smith')
+             action: what they did (e.g., 'UPLOAD_NOTE', 'GENERATE_SUMMARY', 'VIEW_LOGS')
+             resource: what they acted on (e.g., 'note_12345.txt', 'patient_record')
+             additional_info: any extra details (dictionary)
+         """
+         # Get the hash of the previous log entry
+         previous_hash = self._get_last_hash()
+
+         # Generate unique IDs for tracing
+         trace_id = str(uuid.uuid4())
+         span_id = str(uuid.uuid4())[:16]  # Shorter ID for span
+
+         # India timezone
+         india = pytz.timezone('Asia/Kolkata')
+         local_time = datetime.now(india).isoformat()
+
+         # Create the new log entry
+         log_entry = {
+             "timestamp": local_time,  # ISO 8601 string already carries the +05:30 offset
+             "user": user,
+             "action": action,
+             "resource": resource,
+             "sha256_prev": previous_hash,
+             "additional_info": additional_info or {},
+
+             # OpenTelemetry attributes
+             "otel_trace_id": trace_id,
+             "otel_span_id": span_id,
+             "otel_service_name": "clinical-rag-app",
+             "severity": "INFO"  # Can be DEBUG, INFO, WARN, ERROR
+         }
+
+         # Compute hash of THIS entry
+         current_hash = self._compute_hash(log_entry)
+         log_entry["sha256_curr"] = current_hash
+
+         # Append to log file (append-only = cannot change old entries)
+         with open(self.log_file, 'a') as f:
+             f.write(json.dumps(log_entry) + '\n')
+
+         return log_entry
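Because each entry stores the previous entry's hash (sha256_prev) alongside its own (sha256_curr), the log can be checked for tampering offline. A minimal verification sketch, assuming the JSONL format written above (verify_chain is an illustrative helper, not part of this commit):

import json, hashlib

def verify_chain(log_file="logs/app_audit.jsonl") -> bool:
    """Walk the audit log and confirm the hash chain is intact."""
    prev = ""
    with open(log_file) as f:
        for line in f:
            entry = json.loads(line)
            claimed = entry.pop("sha256_curr")        # hash recorded for this entry
            if entry.get("sha256_prev", "") != prev:  # must point at the previous entry
                return False
            recomputed = hashlib.sha256(
                json.dumps(entry, sort_keys=True).encode()
            ).hexdigest()
            if recomputed != claimed:                 # entry was altered after being written
                return False
            prev = claimed
    return True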
bcrypt_pw.py ADDED
@@ -0,0 +1,2 @@
+ import bcrypt
+ print(bcrypt.hashpw(b"mypassword", bcrypt.gensalt()).decode())
deid_pipeline.py ADDED
@@ -0,0 +1,266 @@
1
+ import json
2
+ import os
3
+ from dataclasses import dataclass
4
+ from typing import List, Dict, Any, Tuple
5
+
6
+ # Presidio
7
+ from presidio_analyzer import AnalyzerEngine
8
+ from presidio_analyzer.recognizer_registry import RecognizerRegistry
9
+ from presidio_anonymizer import AnonymizerEngine
10
+ from presidio_analyzer import PatternRecognizer
11
+
12
+ # Define medical terms that should NOT be redacted
13
+ medical_terms_allowlist = [
14
+ "substernal", "exertional", "pressure-like", "diaphoresis",
15
+ "chest pain", "nausea", "radiation", "murmurs", "ischemia"
16
+ ]
17
+
18
+ # Configure analyzer to ignore these terms
19
+ analyzer_config = {
20
+ "nlp_engine_name": "spacy",
21
+ "models": [{"lang_code": "en", "model_name": "en_core_web_lg"}],
22
+ "allow_list": medical_terms_allowlist # Don't redact these
23
+ }
24
+
25
+ # NLP for optional section detection
26
+ import spacy
27
+
28
+ # If using medspacy, uncomment (preferred for clinical):
29
+ # import medspacy
30
+ # from medspacy.sectionizer import Sectionizer
31
+
32
+ # If not using medspacy, optional lightweight section tagging:
33
+ # We'll use regex on common headers as a fallback
34
+ import re
35
+
36
+ # Encryption
37
+ from cryptography.fernet import Fernet
38
+
39
+ @dataclass
40
+ class PHISpan:
41
+ entity_type: str
42
+ start: int
43
+ end: int
44
+ text: str
45
+ section: str
46
+
47
+ SECTION_HEADERS = [
48
+ # Common clinical sections; customize as needed
49
+ "HPI", "History of Present Illness",
50
+ "PMH", "Past Medical History",
51
+ "Medications", "Allergies",
52
+ "Assessment and Plan", "Assessment & Plan", "Assessment",
53
+ "Plan", "ROS", "Review of Systems",
54
+ "Physical Exam"
55
+ ]
56
+
57
+ SECTION_PATTERN = re.compile(
58
+ r"^(?P<header>(" + "|".join([re.escape(h) for h in SECTION_HEADERS]) + r"))\s*:\s*$",
59
+ re.IGNORECASE | re.MULTILINE
60
+ )
61
+
62
+ TAG_MAP = {
63
+ "PERSON": "[NAME]",
64
+ "PHONE_NUMBER": "[PHONE]",
65
+ "DATE_TIME": "[DATE]",
66
+ "DATE": "[DATE]",
67
+ "EMAIL_ADDRESS": "[EMAIL]",
68
+ "US_SSN": "[SSN]"
69
+ }
70
+
71
+ class DeidPipeline:
72
+ """
73
+ De-identification pipeline using Microsoft Presidio
74
+ """
75
+ def __init__(self, fernet_key_path="secure_store/fernet.key"):
76
+ """
77
+ Initialize de-identification pipeline with Presidio
78
+
79
+ Args:
80
+ fernet_key_path: Path to Fernet encryption key
81
+ """
82
+ import os
83
+ from cryptography.fernet import Fernet
84
+
85
+ # Initialize encryption
86
+ try:
87
+ if os.path.exists(fernet_key_path):
88
+ # Load existing key from file
89
+ with open(fernet_key_path, "rb") as f:
90
+ key = f.read()
91
+ else:
92
+ # Generate new key for this session
93
+ key = Fernet.generate_key()
94
+ # Try to save it (might fail on read-only filesystems)
95
+ try:
96
+ os.makedirs(os.path.dirname(fernet_key_path), exist_ok=True)
97
+ with open(fernet_key_path, "wb") as f:
98
+ f.write(key)
99
+ except (PermissionError, OSError):
100
+ # Cloud filesystem is read-only, just use the generated key
101
+ pass
102
+
103
+ self.fernet = Fernet(key)
104
+
105
+ except Exception as e:
106
+ # Emergency fallback: Generate temporary key
107
+ print(f"Warning: Could not load encryption key, generating temporary key: {e}")
108
+ key = Fernet.generate_key()
109
+ self.fernet = Fernet(key)
110
+
111
+ # Initialize Presidio components
112
+ self.analyzer = AnalyzerEngine()
113
+ self.anonymizer = AnonymizerEngine()
114
+
115
+ # Load spaCy model
116
+ try:
117
+ self.nlp = spacy.load("en_core_web_lg")
118
+ except OSError:
119
+ print("Downloading spaCy model...")
120
+ import subprocess
121
+ subprocess.run(["python", "-m", "spacy", "download", "en_core_web_lg"])
122
+ self.nlp = spacy.load("en_core_web_lg")
123
+
124
+ def _detect_sections(self, text: str) -> List[Tuple[str, int, int]]:
125
+ """
126
+ Lightweight section finder:
127
+ Return list of (section_title, start_idx, end_idx_of_section_block)
128
+ """
129
+ # Find headers by regex, map their start positions
130
+ headers = []
131
+ for m in SECTION_PATTERN.finditer(text):
132
+ headers.append((m.group("header"), m.start()))
133
+ # Add end sentinel
134
+ headers.append(("[END]", len(text)))
135
+
136
+ sections = []
137
+ for i in range(len(headers) - 1):
138
+ title, start_pos = headers[i]
139
+ next_title, next_pos = headers[i+1]
140
+ sections.append((title.strip(), start_pos, next_pos))
141
+ if not sections:
142
+ # Single default section if none found
143
+ sections = [("DOCUMENT", 0, len(text))]
144
+ return sections
145
+
146
+ def _find_section_for_span(self, sections, start_idx) -> str:
147
+ for title, s, e in sections:
148
+ if s <= start_idx < e:
149
+ return title
150
+ return "DOCUMENT"
151
+
152
+ def analyze(self, text: str) -> List[Dict[str, Any]]:
153
+ # Detect entities
154
+ results = self.analyzer.analyze(text=text, language="en")
155
+ # Convert to dict for consistency
156
+ detections = []
157
+ for r in results:
158
+ detections.append({
159
+ "entity_type": r.entity_type,
160
+ "start": r.start,
161
+ "end": r.end,
162
+ "score": r.score
163
+ })
164
+ return detections
165
+
166
+ def mask(self, text: str, detections: List[Dict[str, Any]]) -> Tuple[str, List[PHISpan]]:
167
+ """
168
+ Replace spans with tags safely (right-to-left to maintain indices).
169
+ """
170
+ # Determine sections for context
171
+ sections = self._detect_sections(text)
172
+
173
+ # Build PHI span records
174
+ spans: List[PHISpan] = []
175
+ for d in detections:
176
+ entity = d["entity_type"]
177
+ start = d["start"]
178
+ end = d["end"]
179
+ original = text[start:end]
180
+ section = self._find_section_for_span(sections, start)
181
+ spans.append(PHISpan(entity_type=entity, start=start, end=end, text=original, section=section))
182
+
183
+ # Replace from the end to avoid index shifting
184
+ masked = text
185
+ for d in sorted(detections, key=lambda x: x["start"], reverse=True):
186
+ entity = d["entity_type"]
187
+ start = d["start"]
188
+ end = d["end"]
189
+ tag = TAG_MAP.get(entity, f"[{entity}]")
190
+ masked = masked[:start] + tag + masked[end:]
191
+
192
+ return masked, spans
193
+
194
+ def encrypt_span_map(self, spans: List[PHISpan], meta: Dict[str, Any]) -> bytes:
195
+ payload = {
196
+ "meta": meta,
197
+ "spans": [s.__dict__ for s in spans]
198
+ }
199
+ blob = json.dumps(payload).encode("utf-8")
200
+ token = self.fernet.encrypt(blob)
201
+ return token
202
+
203
+ def run_on_text(self, text: str, note_id: str) -> Dict[str, Any]:
204
+ detections = self.analyze(text)
205
+ masked, spans = self.mask(text, detections)
206
+
207
+ # Encrypt span map
208
+ token = self.encrypt_span_map(
209
+ spans=spans,
210
+ meta={"note_id": note_id}
211
+ )
212
+
213
+ return {
214
+ "masked_text": masked,
215
+ "encrypted_span_map": token
216
+ }
217
+
218
+ def _read_text_with_fallback(path: str) -> str:
219
+ # 1) Try UTF-8 (preferred for cross-platform)
220
+ try:
221
+ with open(path, "r", encoding="utf-8") as f:
222
+ return f.read()
223
+ except UnicodeDecodeError:
224
+ pass
225
+ # 2) Try Windows-1252 (common for Notepad/docx copy-paste on Windows)
226
+ try:
227
+ with open(path, "r", encoding="cp1252") as f:
228
+ return f.read()
229
+ except UnicodeDecodeError:
230
+ pass
231
+ # 3) Last resort: decode with replacement to avoid crashing; preserves structure
232
+ with open(path, "r", encoding="utf-8", errors="replace") as f:
233
+ return f.read()
234
+
235
+ def run_file(input_path: str, outputs_dir: str = "data/outputs", secure_dir: str = "secure_store"):
236
+ os.makedirs(outputs_dir, exist_ok=True)
237
+ os.makedirs(secure_dir, exist_ok=True)
238
+
239
+ note_id = os.path.splitext(os.path.basename(input_path))[0]
240
+ text = _read_text_with_fallback(input_path)
241
+
242
+ pipeline = DeidPipeline()
243
+ result = pipeline.run_on_text(text, note_id=note_id)
244
+
245
+ # Save masked text normalized to UTF-8
246
+ out_txt = os.path.join(outputs_dir, f"{note_id}.deid.txt")
247
+ with open(out_txt, "w", encoding="utf-8", newline="\n") as f:
248
+ f.write(result["masked_text"])
249
+
250
+ # Save encrypted span map (binary)
251
+ out_bin = os.path.join(secure_dir, f"{note_id}.spanmap.enc")
252
+ with open(out_bin, "wb") as f:
253
+ f.write(result["encrypted_span_map"])
254
+
255
+ print(f"De-identified text -> {out_txt}")
256
+ print(f"Encrypted span map -> {out_bin}")
257
+
258
+
259
+ if __name__ == "__main__":
260
+ import argparse
261
+ parser = argparse.ArgumentParser(description="De-identify a clinical note and save encrypted span map.")
262
+ parser.add_argument("--input", required=True, help="Path to input .txt note")
263
+ parser.add_argument("--outputs_dir", default="data/outputs", help="Output folder for masked text")
264
+ parser.add_argument("--secure_dir", default="secure_store", help="Folder for encrypted span maps")
265
+ args = parser.parse_args()
266
+ run_file(args.input, args.outputs_dir, args.secure_dir)
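For later audit or re-identification, the encrypted span map written by run_file can be decrypted with the same Fernet key. A minimal sketch assuming the default paths above; the note id "sample_note" is illustrative:

import json
from cryptography.fernet import Fernet

# Load the key the pipeline saved (or generated) in secure_store/
with open("secure_store/fernet.key", "rb") as f:
    fernet = Fernet(f.read())

# Decrypt a span map produced by run_file() for a note called "sample_note"
with open("secure_store/sample_note.spanmap.enc", "rb") as f:
    payload = json.loads(fernet.decrypt(f.read()))

print(payload["meta"]["note_id"])
for span in payload["spans"]:
    print(span["entity_type"], span["start"], span["end"], span["section"])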
indexer.py ADDED
@@ -0,0 +1,289 @@
1
+ # app/indexer.py
2
+ # Day 6: Vector store & embeddings
3
+ # Usage examples:
4
+ # python app/indexer.py --input_dir ./data/outputs --db_type chroma --persist_dir ./data/vector_store
5
+ # python app/indexer.py --input_dir ./data/outputs --db_type faiss --persist_dir ./data/vector_store_faiss
6
+
7
+ import os
8
+ import json
9
+ import argparse
10
+ from pathlib import Path
11
+ from typing import List, Dict, Tuple
12
+ from tqdm import tqdm
13
+
14
+ # Embeddings
15
+ from sentence_transformers import SentenceTransformer
16
+
17
+ # Vector stores
18
+ # Chroma
19
+ import chromadb
20
+ from chromadb.config import Settings as ChromaSettings
21
+
22
+ # FAISS
23
+ import faiss
24
+ import pickle
25
+
26
+ DEFAULT_CHUNK_TOKENS = 200
27
+ DEFAULT_OVERLAP_TOKENS = 50
28
+
29
+ def read_note_files(input_dir: str) -> List[Dict]:
30
+ """
31
+ Reads de-identified notes from .txt or .json in input_dir.
32
+ Expects .json to have a 'text' field containing de-identified content.
33
+ Returns list of dicts: {id, text, section?}
34
+ """
35
+ items = []
36
+ p = Path(input_dir)
37
+ if not p.exists():
38
+ raise FileNotFoundError(f"Input dir not found: {input_dir}")
39
+
40
+ for fp in p.glob("**/*"):
41
+ if fp.is_dir():
42
+ continue
43
+ if fp.suffix.lower() == ".txt":
44
+ text = fp.read_text(encoding="utf-8", errors="ignore").strip()
45
+ if text:
46
+ items.append({"id": fp.stem, "text": text, "section": None})
47
+ elif fp.suffix.lower() == ".json":
48
+ try:
49
+ obj = json.loads(fp.read_text(encoding="utf-8", errors="ignore"))
50
+ text = obj.get("text") or obj.get("deidentified_text") or ""
51
+ section = obj.get("section")
52
+ if text:
53
+ items.append({"id": fp.stem, "text": text.strip(), "section": section})
54
+ except Exception:
55
+ # Skip malformed
56
+ continue
57
+ return items
58
+
59
+ def approx_tokenize(text: str) -> List[str]:
60
+ """
61
+ Approximate tokenization by splitting on whitespace.
62
+ For MVP this is fine; can replace with tiktoken later.
63
+ """
64
+ return text.split()
65
+
66
+ def detokenize(tokens: List[str]) -> str:
67
+ return " ".join(tokens)
68
+
69
+ def chunk_text(text: str, chunk_tokens: int, overlap_tokens: int) -> List[str]:
70
+ """
71
+ Simple sliding window chunking.
72
+ """
73
+ tokens = approx_tokenize(text)
74
+ chunks = []
75
+ i = 0
76
+ n = len(tokens)
77
+ while i < n:
78
+ j = min(i + chunk_tokens, n)
79
+ chunk = detokenize(tokens[i:j])
80
+ if chunk.strip():
81
+ chunks.append(chunk)
82
+ if j == n:
83
+ break
84
+ i = j - overlap_tokens
85
+ if i < 0:
86
+ i = 0
87
+ return chunks
88
+
89
+ def embed_texts(model: SentenceTransformer, texts: List[str]):
90
+ return model.encode(texts, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)
91
+
92
+ def build_chroma(persist_dir: str, collection_name: str = "notes"):
93
+ client = chromadb.PersistentClient(
94
+ path=persist_dir,
95
+ settings=ChromaSettings(allow_reset=True)
96
+ )
97
+ if collection_name in [c.name for c in client.list_collections()]:
98
+ coll = client.get_collection(collection_name)
99
+ else:
100
+ coll = client.create_collection(collection_name)
101
+ return client, coll
102
+
103
+ def save_faiss(index, vectors_meta: List[Dict], persist_dir: str):
104
+ os.makedirs(persist_dir, exist_ok=True)
105
+ faiss_path = os.path.join(persist_dir, "index.faiss")
106
+ meta_path = os.path.join(persist_dir, "meta.pkl")
107
+ faiss.write_index(index, faiss_path)
108
+ with open(meta_path, "wb") as f:
109
+ pickle.dump(vectors_meta, f)
110
+
111
+ def load_faiss(persist_dir: str):
112
+ faiss_path = os.path.join(persist_dir, "index.faiss")
113
+ meta_path = os.path.join(persist_dir, "meta.pkl")
114
+ if os.path.exists(faiss_path) and os.path.exists(meta_path):
115
+ index = faiss.read_index(faiss_path)
116
+ with open(meta_path, "rb") as f:
117
+ meta = pickle.load(f)
118
+ return index, meta
119
+ return None, []
120
+
121
+ def index_note(
122
+ text: str,
123
+ note_id: str = "temp_note",
124
+ persist_dir: str = "./data/vector_store",
125
+ db_type: str = "chroma",
126
+ model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
127
+ collection: str = "notes"
128
+ ) -> str:
129
+ from sentence_transformers import SentenceTransformer
130
+ import os
131
+
132
+ DEFAULT_CHUNK_TOKENS = 200
133
+ DEFAULT_OVERLAP_TOKENS = 50
134
+
135
+ def approx_tokenize(text: str):
136
+ return text.split()
137
+
138
+ def detokenize(tokens):
139
+ return " ".join(tokens)
140
+
141
+ def chunk_text(text, chunk_tokens, overlap_tokens):
142
+ tokens = approx_tokenize(text)
143
+ chunks = []
144
+ i = 0
145
+ n = len(tokens)
146
+ while i < n:
147
+ j = min(i + chunk_tokens, n)
148
+ chunk = detokenize(tokens[i:j])
149
+ if chunk.strip():
150
+ chunks.append(chunk)
151
+ if j == n:
152
+ break
153
+ i = j - overlap_tokens
154
+ if i < 0:
155
+ i = 0
156
+ return chunks
157
+
158
+ os.makedirs(persist_dir, exist_ok=True)
159
+ model = SentenceTransformer(model_name)
160
+ chunks = chunk_text(text, DEFAULT_CHUNK_TOKENS, DEFAULT_OVERLAP_TOKENS)
161
+ chunk_ids = [f"{note_id}::chunk_{i}" for i in range(len(chunks))]
162
+ metadatas = [{"note_id": note_id, "chunk_index": i} for i in range(len(chunks))]
163
+ vectors = model.encode(chunks, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)
164
+
165
+ if db_type == "chroma":
166
+ from chromadb.config import Settings as ChromaSettings
167
+ import chromadb
168
+ client = chromadb.PersistentClient(
169
+ path=persist_dir,
170
+ settings=ChromaSettings(allow_reset=True)
171
+ )
172
+ if collection in [c.name for c in client.list_collections()]:
173
+ coll = client.get_collection(collection)
174
+ else:
175
+ coll = client.create_collection(collection)
176
+ coll.upsert(
177
+ ids=chunk_ids,
178
+ embeddings=vectors.tolist(),
179
+ documents=chunks,
180
+ metadatas=metadatas,
181
+ )
182
+ elif db_type == "faiss":
183
+ import faiss
184
+ import pickle
185
+ d = vectors.shape[1]
186
+ index = faiss.IndexFlatIP(d)
187
+ index.add(vectors)
188
+ vectors_meta = [
189
+ {"id": chunk_ids[k], "text": chunks[k], "meta": metadatas[k]}
190
+ for k in range(len(chunks))
191
+ ]
192
+ faiss_path = os.path.join(persist_dir, "index.faiss")
193
+ meta_path = os.path.join(persist_dir, "meta.pkl")
194
+ faiss.write_index(index, faiss_path)
195
+ with open(meta_path, "wb") as f:
196
+ pickle.dump(vectors_meta, f)
197
+
198
+ return note_id
199
+
200
+
201
+ def main():
202
+ parser = argparse.ArgumentParser(description="Day 6: Build local vector DB from de-identified notes.")
203
+ parser.add_argument("--input_dir", required=True, help="Directory with de-identified notes (.txt or .json).")
204
+ parser.add_argument("--persist_dir", default="./data/vector_store", help="Where to persist the DB.")
205
+ parser.add_argument("--db_type", choices=["chroma", "faiss"], default="chroma", help="Vector DB type.")
206
+ parser.add_argument("--model_name", default="sentence-transformers/all-MiniLM-L6-v2", help="Embedding model.")
207
+ parser.add_argument("--chunk_tokens", type=int, default=DEFAULT_CHUNK_TOKENS, help="Approx tokens per chunk.")
208
+ parser.add_argument("--overlap_tokens", type=int, default=DEFAULT_OVERLAP_TOKENS, help="Token overlap.")
209
+ parser.add_argument("--collection", default="notes", help="Collection name (Chroma).")
210
+ args = parser.parse_args()
211
+
212
+ notes = read_note_files(args.input_dir)
213
+ if not notes:
214
+ print(f"No de-identified notes found in {args.input_dir}. Ensure Day 5 outputs exist.")
215
+ return
216
+
217
+ print(f"Loaded {len(notes)} de-identified notes from {args.input_dir}")
218
+ os.makedirs(args.persist_dir, exist_ok=True)
219
+
220
+ print(f"Loading embedding model: {args.model_name}")
221
+ model = SentenceTransformer(args.model_name)
222
+
223
+ all_chunk_texts = []
224
+ all_chunk_ids = []
225
+ all_metadata = []
226
+
227
+ print("Chunking notes...")
228
+ for note in tqdm(notes):
229
+ chunks = chunk_text(note["text"], args.chunk_tokens, args.overlap_tokens)
230
+ for idx, ch in enumerate(chunks):
231
+ cid = f"{note['id']}::chunk_{idx}"
232
+ all_chunk_texts.append(ch)
233
+ all_chunk_ids.append(cid)
234
+ all_metadata.append({
235
+ "note_id": note["id"],
236
+ "chunk_index": idx,
237
+ "section": note.get("section")
238
+ })
239
+
240
+ print(f"Total chunks: {len(all_chunk_texts)}")
241
+
242
+ print("Embedding chunks...")
243
+ vectors = embed_texts(model, all_chunk_texts)
244
+
245
+ if args.db_type == "chroma":
246
+ print("Building Chroma persistent collection...")
247
+ client, coll = build_chroma(args.persist_dir, args.collection)
248
+
249
+ # Upsert in manageable batches
250
+ batch = 512
251
+ for i in tqdm(range(0, len(all_chunk_texts), batch)):
252
+ j = min(i + batch, len(all_chunk_texts))
253
+ coll.upsert(
254
+ ids=all_chunk_ids[i:j],
255
+ embeddings=vectors[i:j].tolist(),
256
+ documents=all_chunk_texts[i:j],
257
+ metadatas=all_metadata[i:j],
258
+ )
259
+ print(f"Chroma collection '{args.collection}' persisted at {args.persist_dir}")
260
+
261
+ elif args.db_type == "faiss":
262
+ print("Building FAISS index...")
263
+ d = vectors.shape[1]
264
+ index = faiss.IndexFlatIP(d) # normalized vectors → use inner product as cosine
265
+ # Try to load existing
266
+ existing_index, existing_meta = load_faiss(args.persist_dir)
267
+ if existing_index is not None:
268
+ print("Appending to existing FAISS index...")
269
+ index = existing_index
270
+ vectors_meta = existing_meta
271
+ else:
272
+ vectors_meta = []
273
+ index.add(vectors)
274
+ vectors_meta.extend([
275
+ {
276
+ "id": all_chunk_ids[k],
277
+ "text": all_chunk_texts[k],
278
+ "meta": all_metadata[k]
279
+ } for k in range(len(all_chunk_texts))
280
+ ])
281
+ save_faiss(index, vectors_meta, args.persist_dir)
282
+ print(f"FAISS index persisted at {args.persist_dir}")
283
+
284
+ print("Done.")
285
+
286
+ if __name__ == "__main__":
287
+ main()
288
+ ##result = pipeline.run_on_text(text=note_text, note_id="temp_note")
289
+ ##deid_text = result["masked_text"]
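The sliding-window chunking used here (both the module-level chunk_text and the copy nested inside index_note behave the same way) can be sanity-checked on a tiny input. A short sketch, assuming indexer.py is importable as indexer:

from indexer import chunk_text

words = " ".join(str(i) for i in range(10))  # "0 1 2 ... 9"
# Windows of 4 "tokens" with an overlap of 1 start at positions 0, 3, 6
print(chunk_text(words, chunk_tokens=4, overlap_tokens=1))
# ['0 1 2 3', '3 4 5 6', '6 7 8 9']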
main.py ADDED
@@ -0,0 +1,494 @@
1
+ # --- Imports and page setup ---
2
+ import streamlit as st
3
+ import yaml
4
+ from yaml.loader import SafeLoader
5
+ import streamlit_authenticator as stauth
6
+ import uuid
7
+ import datetime
8
+ from audit import AuditLogger
9
+
10
+ st.set_page_config(page_title="Clinical Summarizer", layout="wide")
11
+ st.title("HIPAA-compliant Clinical RAG Summarizer (MVP)")
12
+
13
+ # --- Authentication setup ---
14
+ def load_config():
15
+ """Load configuration from Streamlit secrets or local YAML"""
16
+ try:
17
+ # Check if running on Streamlit Cloud (secrets available)
18
+ if "credentials" in st.secrets:
19
+ # Convert immutable Streamlit secrets to mutable dict
20
+ config = {
21
+ "credentials": {
22
+ "usernames": {}
23
+ },
24
+ "cookie": {
25
+ "name": str(st.secrets["cookie"]["name"]),
26
+ "key": str(st.secrets["cookie"]["key"]),
27
+ "expiry_days": int(st.secrets["cookie"]["expiry_days"])
28
+ }
29
+ }
30
+
31
+ # Convert each user to mutable dict
32
+ for username, user_data in st.secrets["credentials"]["usernames"].items():
33
+ config["credentials"]["usernames"][str(username)] = {
34
+ "email": str(user_data["email"]),
35
+ "failed_login_attempts": int(user_data.get("failed_login_attempts", 0)),
36
+ "logged_in": bool(user_data.get("logged_in", False)),
37
+ "name": str(user_data["name"]),
38
+ "password": str(user_data["password"]),
39
+ "role": str(user_data["role"])
40
+ }
41
+
42
+ return config
43
+ else:
44
+ # Local development: Load from YAML file
45
+ with open("app/streamlit_config.yaml") as f:
46
+ return yaml.load(f, Loader=SafeLoader)
47
+
48
+ except FileNotFoundError:
49
+ st.error("⚠️ Configuration file not found. Please set up authentication.")
50
+ st.stop()
51
+ except Exception as e:
52
+ st.error(f"⚠️ Configuration error: {e}")
53
+ st.info("Make sure secrets are configured in Streamlit Cloud settings.")
54
+ st.stop()
55
+
56
+ # Load config
57
+ config = load_config()
58
+
59
+ # Create authenticator with mutable config
60
+ authenticator = stauth.Authenticate(
61
+ config["credentials"],
62
+ config["cookie"]["name"],
63
+ config["cookie"]["key"],
64
+ config["cookie"]["expiry_days"],
65
+ )
66
+
67
+ # Render the login widget
68
+ authenticator.login(location="sidebar")
69
+
70
+ # Read values from session_state
71
+ auth_status = st.session_state.get("authentication_status")
72
+ username = st.session_state.get("username")
73
+ name = st.session_state.get("name")
74
+
75
+ if auth_status is False:
76
+ st.error("Invalid username or password")
77
+ st.stop()
78
+ elif auth_status is None:
79
+ st.info("Please log in")
80
+ st.stop()
81
+ else:
82
+ role = config["credentials"]["usernames"][username]["role"]
83
+ st.session_state["role"] = role
84
+ with st.sidebar:
85
+ st.header("Clinical RAG Summarizer")
86
+ st.markdown("HIPAA-compliant, secure, and easy to use.")
87
+ st.markdown("---")
88
+ st.success(f"Logged in as {name}")
89
+ st.markdown(f"**Role:** {role}")
90
+ authenticator.logout("Logout", location="sidebar")
91
+ st.markdown("---")
92
+ st.info("Use the tabs above to upload notes, generate summaries, and view logs.")
93
+
94
+ # Clear ChromaDB cache to prevent singleton conflicts
95
+ try:
96
+ from chromadb.api.client import SharedSystemClient
97
+ SharedSystemClient.clear_system_cache()
98
+ except:
99
+ pass
100
+
101
+ # Generate a unique persist_dir for each session if not already set
102
+ if "persist_dir" not in st.session_state:
103
+ if st.session_state.get("username"):
104
+ st.session_state["persist_dir"] = f"./data/vector_store_{st.session_state['username']}"
105
+ else:
106
+ unique_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + str(uuid.uuid4())[:8]
107
+ st.session_state["persist_dir"] = f"./data/vector_store_{unique_id}"
108
+
109
+ # Initialize the audit logger
110
+ audit_logger = AuditLogger()
111
+
112
+ # Initialize model cache in session state
113
+ if "t5_model" not in st.session_state:
114
+ st.session_state["t5_model"] = None
115
+ if "t5_tokenizer" not in st.session_state:
116
+ st.session_state["t5_tokenizer"] = None
117
+
118
+ # --- Tabs ---
119
+ upload_tab, summarize_tab, logs_tab = st.tabs(["Upload/Enter Note", "Summarize", "Logs"])
120
+
121
+ # --- Upload/Enter Note tab ---
122
+ with upload_tab:
123
+ st.subheader("Enter or Upload Note")
124
+ st.caption("Paste a synthetic note or upload a .txt file, then de-identify and index.")
125
+
126
+ col_upload, col_text = st.columns([1, 2])
127
+ with col_upload:
128
+ file = st.file_uploader("Upload .txt file", type=["txt"])
129
+ with col_text:
130
+ note_text = st.text_area("Paste note text", height=200, placeholder="Paste clinical note text here...")
131
+
132
+ col1, col2 = st.columns(2)
133
+ with col1:
134
+ deid_index_clicked = st.button("De-identify & Index", use_container_width=True)
135
+ with col2:
136
+ skip_index_clicked = st.button("Skip (already indexed)", use_container_width=True)
137
+
138
+ if file and not note_text:
139
+ note_text = file.read().decode("utf-8", errors="ignore")
140
+
141
+ if deid_index_clicked and note_text:
142
+ try:
143
+ with st.spinner("De-identifying and indexing..."):
144
+ from deid_pipeline import DeidPipeline
145
+ pipeline = DeidPipeline()
146
+ result = pipeline.run_on_text(text=note_text, note_id="temp_note")
147
+ deid_text = result["masked_text"]
148
+ st.success("De-identified.")
149
+ st.text_area("De-identified preview", deid_text, height=160)
150
+
151
+ from indexer import index_note
152
+ # Use session-specific persist_dir
153
+ note_id = index_note(
154
+ text=deid_text,
155
+ note_id=f"note_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}",
156
+ persist_dir=st.session_state["persist_dir"],
157
+ db_type="chroma",
158
+ model_name="sentence-transformers/all-MiniLM-L6-v2",
159
+ collection="notes"
160
+ )
161
+ st.session_state["last_note_id"] = note_id
162
+ st.session_state["last_deid_text"] = deid_text
163
+ st.session_state["last_note_indexed"] = True
164
+ st.success(f"✓ Indexed note_id: {note_id}")
165
+ st.info(f"📁 Stored in: {st.session_state['persist_dir']}")
166
+
167
+ # Audit log for indexing
168
+ audit_logger.log_action(
169
+ user=st.session_state.get('username', 'anonymous'),
170
+ action="INDEX_NOTE",
171
+ resource=note_id,
172
+ additional_info={
173
+ "text_length": len(deid_text),
174
+ "persist_dir": st.session_state["persist_dir"]
175
+ }
176
+ )
177
+
178
+ # Audit log for de-identification
179
+ audit_logger.log_action(
180
+ user=st.session_state.get('username', 'anonymous'),
181
+ action="DEID_PROCESS",
182
+ resource="temp_note",
183
+ additional_info={"original_length": len(note_text), "deid_length": len(deid_text)}
184
+ )
185
+ st.toast("Note indexed and de-identified!", icon="✅")
186
+ except Exception as e:
187
+ st.error(f"De-identification error: {e}")
188
+ import traceback
189
+ st.code(traceback.format_exc())
190
+ elif skip_index_clicked and note_text:
191
+ st.session_state["last_deid_text"] = note_text
192
+ st.session_state["last_note_indexed"] = False
193
+ st.info("Skipped indexing; text saved for summarization.")
194
+
195
+ # Audit log for skipping
196
+ audit_logger.log_action(
197
+ user=st.session_state.get('username', 'anonymous'),
198
+ action="SKIP_INDEX",
199
+ resource="temp_note",
200
+ additional_info={"text_length": len(note_text)}
201
+ )
202
+
203
+ if "last_deid_text" not in st.session_state:
204
+ st.caption("Tip: click 'De-identify & Index' or 'Skip' to carry text into the Summarize tab.")
205
+ else:
206
+ st.write(f"✓ Note text ready: {len(st.session_state['last_deid_text'])} characters")
207
+ st.write(f"Preview: {st.session_state['last_deid_text'][:100]}...")
208
+
209
+ # --- Summarize tab ---
210
+ with summarize_tab:
211
+ st.subheader("Summarize")
212
+ st.caption("Retrieves context and generates a structured clinical summary.")
213
+
214
+ from rag_pipeline import load_embedder, load_chroma, load_faiss_langchain, retrieve
215
+ from summarizer import make_t5, summarize_docs, validate_summary_quality
216
+
217
+ # Environment detection
218
+ import os
219
+ IS_CLOUD = os.path.exists('/mount/src')
220
+ if IS_CLOUD:
221
+ st.info("🌐 Cloud Mode: Using optimized model (flan-t5-base)")
222
+
223
+ # Clear ChromaDB system cache to avoid singleton conflicts
224
+ try:
225
+ import chromadb
226
+ from chromadb.api.client import SharedSystemClient
227
+ SharedSystemClient.clear_system_cache()
228
+ except Exception as e:
229
+ st.warning(f"Could not clear ChromaDB cache: {e}")
230
+
231
+ # Show current vector store location
232
+ st.info(f"📁 Using vector store: {st.session_state['persist_dir']}")
233
+
234
+ source_choice = st.radio("Use source:", ["Last de-identified text", "Note ID"], horizontal=True)
235
+ default_note_id = st.session_state.get("last_note_id", "")
236
+ user_note_id = st.text_input("Note ID (optional)", value=str(default_note_id))
237
+
238
+ # Add method selection
239
+ method_choice = st.radio("Extraction method:", ["multistage", "singleshot"], horizontal=True,
240
+ help="Multistage: Better quality, slower. Singleshot: Faster, may miss details.")
241
+
242
+ generate_clicked = st.button("Generate Summary", type="primary", use_container_width=True)
243
+
244
+ if generate_clicked:
245
+ try:
246
+ with st.spinner("Retrieving context..."):
247
+ embed_model = "sentence-transformers/all-MiniLM-L6-v2"
248
+ db_type = "chroma"
249
+ persist_dir = st.session_state["persist_dir"]
250
+ collection = "notes"
251
+ top_k = 5
252
+
253
+ # Cache vector database in session state to avoid recreating
254
+ cache_key = f"vdb_{persist_dir}_{collection}"
255
+
256
+ if cache_key not in st.session_state:
257
+ st.info("⏳ Loading vector database (first time)...")
258
+ _, embeddings = load_embedder(embed_model)
259
+
260
+ # Clear cache before creating new instance
261
+ try:
262
+ SharedSystemClient.clear_system_cache()
263
+ except:
264
+ pass
265
+
266
+ if db_type == "chroma":
267
+ vdb = load_chroma(persist_dir, collection, embeddings)
268
+ else:
269
+ vdb = load_faiss_langchain(persist_dir, embeddings)
270
+
271
+ st.session_state[cache_key] = vdb
272
+ st.success("✓ Vector database loaded")
273
+ else:
274
+ vdb = st.session_state[cache_key]
275
+ st.info("✓ Using cached vector database")
276
+
277
+ # Use actual note content for retrieval
278
+ if source_choice == "Note ID" and user_note_id:
279
+ query_text = user_note_id
280
+ st.info(f"🔍 Retrieving by Note ID: {user_note_id}")
281
+ else:
282
+ deid_text = st.session_state.get("last_deid_text", "")
283
+ if not deid_text:
284
+ st.warning("No de-identified text available. Please use the Upload tab first.")
285
+ st.stop()
286
+ query_text = deid_text[:500]
287
+ st.info(f"🔍 Retrieving using note content ({len(deid_text)} chars)")
288
+
289
+ docs = retrieve(vdb, query_text, top_k)
290
+
291
+ if not docs:
292
+ st.error("⚠ No documents retrieved from vector database!")
293
+ st.warning("This usually means:")
294
+ st.write("• The vector database is empty")
295
+ st.write("• The note wasn't properly indexed")
296
+ st.write(f"• Check if files exist in: {persist_dir}")
297
+
298
+ if st.button("🔄 Clear cache and retry"):
299
+ if cache_key in st.session_state:
300
+ del st.session_state[cache_key]
301
+ SharedSystemClient.clear_system_cache()
302
+ st.rerun()
303
+ st.stop()
304
+
305
+ st.success(f"✓ Retrieved {len(docs)} document(s)")
306
+
307
+ # Show preview of retrieved content
308
+ with st.expander("View retrieved content"):
309
+ for i, doc in enumerate(docs, 1):
310
+ st.write(f"**Document {i}:**")
311
+ st.code(doc.page_content[:300] + "..." if len(doc.page_content) > 300 else doc.page_content)
312
+
313
+ with st.spinner("Generating summary... (this may take 1-2 minutes on CPU)"):
314
+ # Cache model loading in session state
315
+ if st.session_state["t5_model"] is None or st.session_state["t5_tokenizer"] is None:
316
+ st.info("⏳ Loading T5 model (first time only)...")
317
+ tokenizer, model = make_t5("google/flan-t5-base")
318
+ st.session_state["t5_tokenizer"] = tokenizer
319
+ st.session_state["t5_model"] = model
320
+ else:
321
+ tokenizer = st.session_state["t5_tokenizer"]
322
+ model = st.session_state["t5_model"]
323
+ st.info("✓ Using cached model")
324
+
325
+ # Generate summary
326
+ summary = summarize_docs(tokenizer, model, docs, method=method_choice)
327
+
328
+ # Store summary in session state
329
+ summary_key = f"summary_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
330
+ st.session_state["last_summary"] = summary
331
+ st.session_state["last_summary_key"] = summary_key
332
+
333
+ # Validation
334
+ original_text = st.session_state.get("last_deid_text", "")
335
+ validation = validate_summary_quality(summary, original_text)
336
+
337
+ # Display validation results
338
+ status_color = {
339
+ "GOOD": "🟢",
340
+ "FAIR": "🟡",
341
+ "POOR": "🟠",
342
+ "FAILED": "🔴"
343
+ }
344
+
345
+ st.success("✓ Summary generated successfully")
346
+
347
+ # Show quality assessment in two columns
348
+ col_status, col_score = st.columns([3, 1])
349
+ with col_status:
350
+ st.markdown(f"### {status_color.get(validation['status'], '⚪')} Quality Status: **{validation['status']}**")
351
+ with col_score:
352
+ st.metric("Quality Score", f"{validation['quality_score']}/100")
353
+
354
+ # Display critical issues if any
355
+ if validation['issues']:
356
+ st.error("**❌ Critical Issues Detected:**")
357
+ for issue in validation['issues']:
358
+ st.markdown(f"- {issue}")
359
+ st.markdown("**Recommendation:** Review de-identification settings and retrieval quality")
360
+
361
+ # Display warnings if any
362
+ if validation['warnings']:
363
+ st.warning("**⚠️ Quality Warnings:**")
364
+ for warning in validation['warnings']:
365
+ st.markdown(f"- {warning}")
366
+
367
+ # Show detailed quality metrics in expandable section
368
+ with st.expander("📊 Detailed Quality Metrics"):
369
+ metric_col1, metric_col2, metric_col3, metric_col4 = st.columns(4)
370
+ with metric_col1:
371
+ st.metric("PHI Placeholders", validation['metrics']['total_placeholders'])
372
+ with metric_col2:
373
+ st.metric("Empty Sections", validation['metrics']['empty_sections'])
374
+ with metric_col3:
375
+ st.metric("Filled Sections", f"{validation['metrics']['filled_sections']}/7")
376
+ with metric_col4:
377
+ st.metric("Total Length", f"{validation['metrics']['total_length']} chars")
378
+
379
+ # Show warning banner if quality is poor
380
+ if validation['status'] in ['POOR', 'FAILED']:
381
+ st.warning("⚠️ **Quality Alert:** The summary below has significant quality issues. Review carefully before clinical use.")
382
+ elif validation['status'] == 'FAIR':
383
+ st.info("ℹ️ The summary has minor quality issues. Review the warnings above.")
384
+ else:
385
+ st.success("✅ Summary quality is acceptable.")
386
+
387
+ # Display the summary
388
+ st.text_area("Structured Summary", summary, height=400, key=f"summary_display_{summary_key}")
389
+ st.download_button("Download .txt", data=summary, file_name=f"summary_{user_note_id or 'latest'}.txt")
390
+
391
+ # Show summary statistics
392
+ col1, col2, col3 = st.columns(3)
393
+ with col1:
394
+ st.metric("Summary Length", f"{len(summary)} chars")
395
+ with col2:
396
+ st.metric("Documents Retrieved", len(docs))
397
+ with col3:
398
+ sections_filled = 7 - summary.count("None stated")
399
+ st.metric("Sections Filled", f"{sections_filled}/7")
400
+
401
+ # Audit log for summary generation with validation results
402
+ audit_logger.log_action(
403
+ user=st.session_state.get('username', 'anonymous'),
404
+ action="GENERATE_SUMMARY",
405
+ resource=user_note_id or "temp_note",
406
+ additional_info={
407
+ "retrieved_docs": len(docs),
408
+ "method": method_choice,
409
+ "summary_length": len(summary),
410
+ "persist_dir": persist_dir,
411
+ "sections_filled": sections_filled,
412
+ "quality_status": validation['status'],
413
+ "quality_score": validation['quality_score'],
414
+ "validation_issues": len(validation['issues']),
415
+ "validation_warnings": len(validation['warnings']),
416
+ "phi_placeholders": validation['metrics']['total_placeholders']
417
+ }
418
+ )
419
+
420
+ except ValueError as ve:
421
+ if "already exists" in str(ve):
422
+ st.error("❌ ChromaDB instance conflict detected!")
423
+ st.warning("This happens when the vector database is accessed with different settings.")
424
+ st.info("**Solution:** Click the button below to clear the cache and retry.")
425
+
426
+ if st.button("🔄 Clear ChromaDB cache and retry", type="primary"):
427
+ try:
428
+ SharedSystemClient.clear_system_cache()
429
+ except:
430
+ pass
431
+
432
+ keys_to_delete = [k for k in st.session_state.keys() if k.startswith("vdb_")]
433
+ for key in keys_to_delete:
434
+ del st.session_state[key]
435
+
436
+ st.success("✓ Cache cleared! Click 'Generate Summary' again.")
437
+ st.rerun()
438
+ else:
439
+ st.error(f"❌ Error during summarization: {ve}")
440
+ import traceback
441
+ st.code(traceback.format_exc())
442
+ except Exception as e:
443
+ st.error(f"❌ Error during summarization: {e}")
444
+ import traceback
445
+ st.code(traceback.format_exc())
446
+
447
+ # Show last summary if available (when button not clicked)
448
+ elif "last_summary" in st.session_state:
449
+ st.info("Showing last generated summary:")
450
+ st.text_area("Last Summary", st.session_state["last_summary"], height=400)
451
+ st.download_button("Download Last Summary",
452
+ data=st.session_state["last_summary"],
453
+ file_name="last_summary.txt")
454
+
455
+ # --- Logs tab ---
456
+ with logs_tab:
457
+ st.subheader("Logs")
458
+ if st.session_state.get("role") != "admin":
459
+ st.info("Admins only.")
460
+ else:
461
+ st.caption("Audit logs for all user actions.")
462
+
463
+ # Audit log for viewing logs
464
+ audit_logger.log_action(
465
+ user=st.session_state.get('username', 'anonymous'),
466
+ action="VIEW_LOGS",
467
+ resource="app_audit.jsonl"
468
+ )
469
+
470
+ # Add log filtering
471
+ col1, col2 = st.columns([3, 1])
472
+ with col1:
473
+ filter_action = st.selectbox("Filter by action:",
474
+ ["All", "INDEX_NOTE", "GENERATE_SUMMARY", "DEID_PROCESS", "VIEW_LOGS"])
475
+ with col2:
476
+ num_lines = st.number_input("Show last N lines:", min_value=10, max_value=500, value=50)
477
+
478
+ try:
479
+ import json
480
+ with open("logs/app_audit.jsonl") as f:
481
+ lines = f.readlines()[-num_lines:]
482
+
483
+ st.write(f"Showing {len(lines)} most recent log entries:")
484
+
485
+ for line in lines:
486
+ try:
487
+ log_entry = json.loads(line.strip())
488
+ if filter_action == "All" or log_entry.get("action") == filter_action:
489
+ with st.expander(f"{log_entry.get('timestamp', 'N/A')} - {log_entry.get('action', 'N/A')}"):
490
+ st.json(log_entry)
491
+ except json.JSONDecodeError:
492
+ st.code(line.strip())
493
+ except FileNotFoundError:
494
+ st.warning("No logs found yet. Logs will appear after you perform actions.")
notes.py ADDED
@@ -0,0 +1,47 @@
+ from langchain_chroma import Chroma
+ from sentence_transformers import SentenceTransformer
+ from langchain.embeddings.base import Embeddings
+
+ # 1. Wrap SentenceTransformer in a LangChain-compatible class
+ class STEmbeddings(Embeddings):
+     def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
+         self.model = SentenceTransformer(model_name)
+
+     def embed_documents(self, texts):
+         return self.model.encode(texts, convert_to_numpy=True, normalize_embeddings=True).tolist()
+
+     def embed_query(self, text):
+         return self.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)[0].tolist()
+
+ # 2. Instantiate embeddings
+ embeddings = STEmbeddings()
+
+ # 3. Create or load Chroma collection
+ db = Chroma(
+     collection_name="notes",
+     persist_directory="./data/vector_store",
+     embedding_function=embeddings
+ )
+
+ # 4. Add some sample texts
+ texts = [
+     "Patient presents with chest pain for 2 days.",
+     "History of hypertension and diabetes.",
+     "Currently taking metformin and lisinopril.",
+     "No known drug allergies.",
+     "Plan: schedule ECG and follow-up in 1 week."
+ ]
+
+ metadatas = [
+     {"note_id": "1", "section": "HPI", "chunk_index": 0},
+     {"note_id": "1", "section": "PMH", "chunk_index": 0},
+     {"note_id": "1", "section": "Medications", "chunk_index": 0},
+     {"note_id": "1", "section": "Allergies", "chunk_index": 0},
+     {"note_id": "1", "section": "Plan", "chunk_index": 0},
+ ]
+
+ db.add_texts(texts=texts, metadatas=metadatas)
+
+ # 5. Persist to disk (langchain_chroma writes to persist_directory automatically)
+
+ print("Ingestion complete. Collection 'notes' is ready.")
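To confirm the ingestion worked, the same db object can be queried directly. A short sketch; the query string is illustrative:

# Retrieve the two chunks most similar to a query
for doc in db.similarity_search("chest pain workup", k=2):
    print(doc.metadata.get("section"), "->", doc.page_content)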
quick_check_chroma.py ADDED
@@ -0,0 +1,20 @@
+ # quick_check_chroma.py
+ import chromadb
+ from chromadb.config import Settings as ChromaSettings
+
+ persist_dir = "./data/vector_store"
+ collection_name = "notes"
+
+ client = chromadb.PersistentClient(path=persist_dir, settings=ChromaSettings())
+ coll = client.get_collection(collection_name)
+
+ query = "Type 2 diabetes management plan with metformin"
+ res = coll.query(
+     query_texts=[query],
+     n_results=3,
+ )
+
+ for i, doc in enumerate(res["documents"][0]):
+     print(f"\nTop {i+1} doc:")
+     print(doc)
+     print("Meta:", res["metadatas"][0][i])
rag_pipeline.py ADDED
@@ -0,0 +1,117 @@
+ # app/rag_pipeline.py
+ # Day 7: Retriever + RAG baseline (retrieval only; generation comes on Day 8)
+ # Example usage:
+ # python app/rag_pipeline.py --db_type chroma --persist_dir ./data/vector_store --collection notes --query "Summarize into HPI/Assessment/Plan" --top_k 5
+ # python app/rag_pipeline.py --db_type faiss --persist_dir ./data/vector_store_faiss --query "Extract Assessment and Plan" --top_k 5
+
+ import os
+ import argparse
+ import pickle
+ from typing import List, Dict
+ import uuid
+ import datetime
+ import shutil
+
+ from sentence_transformers import SentenceTransformer
+ import numpy as np
+
+ # LangChain vector store wrappers
+ from langchain_community.vectorstores import Chroma, FAISS
+ from langchain_core.documents import Document
+
+ # For FAISS manual load if using custom persisted index
+ import faiss
+ from chromadb.config import Settings as ChromaSettings
+
+ def load_embedder(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
+     model = SentenceTransformer(model_name)
+     def embed_f(texts: List[str]) -> List[List[float]]:
+         vecs = model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
+         return vecs.tolist()
+     return model, embed_f
+
+ def load_chroma(persist_dir: str, collection: str, embed_f):
+     from langchain.embeddings.base import Embeddings
+     class STEmbeddings(Embeddings):
+         def embed_documents(self, texts: List[str]) -> List[List[float]]:
+             return embed_f(texts)
+         def embed_query(self, text: str) -> List[float]:
+             return embed_f([text])[0]
+
+     embeddings = STEmbeddings()
+     vectordb = Chroma(
+         collection_name=collection,
+         persist_directory=persist_dir,
+         embedding_function=embeddings
+     )
+     return vectordb
+
+ def load_faiss_langchain(persist_dir: str, embed_f):
+     # If Day 6 saved FAISS with LangChain's FAISS.save_local, we could do:
+     #   return FAISS.load_local(persist_dir, embeddings, allow_dangerous_deserialization=True)
+     # But Day 6 saved raw FAISS + meta.pkl; handle that manually and wrap.
+     from langchain.embeddings.base import Embeddings
+     class STEmbeddings(Embeddings):
+         def embed_documents(self, texts: List[str]) -> List[List[float]]:
+             return embed_f(texts)
+         def embed_query(self, text: str) -> List[float]:
+             return embed_f([text])[0]
+     embeddings = STEmbeddings()
+
+     index_path = os.path.join(persist_dir, "index.faiss")
+     meta_path = os.path.join(persist_dir, "meta.pkl")
+     if not (os.path.exists(index_path) and os.path.exists(meta_path)):
+         raise FileNotFoundError(f"FAISS files not found in {persist_dir}")
+
+     index = faiss.read_index(index_path)
+     with open(meta_path, "rb") as f:
+         meta = pickle.load(f)
+
+     # Build FAISS VectorStore from texts + metadata to leverage LC retriever
+     texts = [m["text"] for m in meta]
+     metadatas = [m["meta"] | {"id": m["id"]} for m in meta]
+     vectordb = FAISS.from_texts(texts=texts, embedding=embeddings, metadatas=metadatas)
+     # Replace the underlying index with the prebuilt one (saves re-embedding cost when querying)
+     vectordb.index = index
+     return vectordb
+
+ def retrieve(vdb, query: str, top_k: int = 5):
+     retriever = vdb.as_retriever(search_kwargs={"k": top_k})
+     docs: List[Document] = retriever.invoke(query)
+     return docs
+
+ def format_context(docs: List[Document]) -> str:
+     parts = []
+     for i, d in enumerate(docs, 1):
+         md = d.metadata or {}
+         parts.append(f"[{i}] note_id={md.get('note_id')} section={md.get('section')} chunk_idx={md.get('chunk_index')}\n{d.page_content}")
+     return "\n\n---\n\n".join(parts)
+
+ def main():
+     parser = argparse.ArgumentParser(description="Day 7: Retriever + RAG baseline (retrieval only).")
+     parser.add_argument("--db_type", choices=["chroma", "faiss"], default="chroma")
+     parser.add_argument("--persist_dir", default="./data/vector_store")
+     parser.add_argument("--collection", default="notes")
+     parser.add_argument("--model_name", default="sentence-transformers/all-MiniLM-L6-v2")
+     parser.add_argument("--query", required=True)
+     parser.add_argument("--top_k", type=int, default=5)
+     args = parser.parse_args()
+
+     # "Sure shot fix": remove any existing persist_dir so Chroma starts clean (note: this wipes previously indexed data)
+     if args.db_type == "chroma" and os.path.exists(args.persist_dir):
+         shutil.rmtree(args.persist_dir)
+
+     _, embed_f = load_embedder(args.model_name)
+
+     if args.db_type == "chroma":
+         vectordb = load_chroma(args.persist_dir, args.collection, embed_f)
+     else:
+         vectordb = load_faiss_langchain(args.persist_dir, embed_f)
+
+     docs = retrieve(vectordb, args.query, args.top_k)
+     context = format_context(docs)
+     print("\n=== Retrieved Context (to feed Day 8 summarizer) ===\n")
+     print(context)
+
+ if __name__ == "__main__":
+     main()
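run_pipeline.py and retriever_context.py below import a retrieve_context helper from this module that the committed file does not define. A minimal sketch of what such a helper could look like, reusing the functions above and the Chroma defaults (an assumption, not the author's implementation):

def retrieve_context(query: str, top_k: int = 5,
                     persist_dir: str = "./data/vector_store",
                     collection: str = "notes") -> str:
    """Load the persisted Chroma store, retrieve top_k chunks, and return formatted context."""
    _, embed_f = load_embedder()
    vectordb = load_chroma(persist_dir, collection, embed_f)
    docs = retrieve(vectordb, query, top_k)
    return format_context(docs)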
retriever_context.py ADDED
@@ -0,0 +1,7 @@
+ from rag_pipeline import retrieve_context
+ from summarizer import generate_summary
+
+ query = "Summarize into HPI/Assessment/Plan"
+ retrieved_text = retrieve_context(query, top_k=5)
+ summary = generate_summary(retrieved_text)
+ print(summary)
run_pipeline.py ADDED
@@ -0,0 +1,51 @@
+ # run_pipeline.py
+ from rag_pipeline import retrieve_context  # <-- your Day 7 retriever
+ from transformers import pipeline
+
+ # 1. Load a summarization model
+ # Option A: summarization-tuned model (recommended for clean summaries)
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
+
+ # Option B: instruction-tuned model (if you want to experiment with prompts)
+ # summarizer = pipeline("text2text-generation", model="google/flan-t5-base")
+
+ # 2. Define a function to generate a structured summary
+ def generate_summary(retrieved_text: str):
+     # For BART summarizer (Option A)
+     result = summarizer(retrieved_text, max_length=250, min_length=80, do_sample=False)
+     return result[0]['summary_text']
+
+ # If using Flan-T5 (Option B), uncomment this instead:
+ """
+ prompt = f'''
+ You are a clinical summarization assistant.
+ Use ONLY the provided context to create a structured summary.
+ Do not invent information.
+
+ Context:
+ {retrieved_text}
+
+ Write the output in this exact format:
+ Chief Complaint: ...
+ HPI: ...
+ PMH: ...
+ Medications: ...
+ Allergies: ...
+ Assessment: ...
+ Plan: ...
+ '''
+ result = summarizer(prompt, max_new_tokens=300, do_sample=False)
+ return result[0]['generated_text']
+ """
+
+ # 3. Main execution
+ if __name__ == "__main__":
+     query = "Summarize into HPI/Assessment/Plan"
+     # Get top 5 relevant chunks from your vector store
+     retrieved_text = retrieve_context(query, top_k=5)
+
+     print("=== Retrieved Context ===")
+     print(retrieved_text)
+     print("\n=== Structured Clinical Summary ===")
+     summary = generate_summary(retrieved_text)
+     print(summary)
streamlit_config.yaml ADDED
@@ -0,0 +1,17 @@
+ credentials:
+   usernames:
+     clinician1:
+
+       name: Clinician One
+       password: "$2b$12$r3uqzaknAfUAsMEVIKTR2eN8yuPxu8d8YJWmPOrvNKwK.K94sjl1W"
+       role: clinician
+     admin1:
+
+       name: Admin One
+       password: "$2b$12$r3uqzaknAfUAsMEVIKTR2eN8yuPxu8d8YJWmPOrvNKwK.K94sjl1W"
+       role: admin
+
+ cookie:
+   expiry_days: 1
+   key: some_random_secret
+   name: auth_cookie
summarizer.py ADDED
@@ -0,0 +1,610 @@
1
+ # app/summarizer.py
2
+ # Day 10: Enhanced HIPAA-compliant RAG clinical summarizer with robustness improvements
3
+ # Critical fixes:
4
+ # - Added progress indicators during model generation
5
+ # - Implemented timeout mechanism for long-running operations
6
+ # - Optimized for CPU with reduced generation parameters
7
+ # - Better error handling and verbose logging
8
+ # - Fallback to smaller max tokens if generation hangs
9
+
10
+ import os
11
+ import argparse
12
+ import traceback
13
+ from typing import List, Dict, Optional
14
+ import re
15
+ import time
16
+ import sys
17
+
18
+ from sentence_transformers import SentenceTransformer
19
+ from langchain_community.vectorstores import Chroma, FAISS
20
+ from langchain_core.documents import Document
21
+
22
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
23
+
24
+ # -----------------------------
25
+ # Embeddings / Vector stores
26
+ # -----------------------------
27
+ def load_embedder(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
28
+ """
29
+ Load sentence transformer for embeddings.
30
+ For medical domain: consider "emilyalsentzer/Bio_ClinicalBERT" or similar
31
+ """
32
+ print(f" → Loading embedding model...")
33
+ model = SentenceTransformer(model_name)
34
+ def embed_f(texts: List[str]):
35
+ vecs = model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
36
+ return vecs.tolist()
37
+ print(f" ✓ Embedding model loaded")
38
+ return embed_f
39
+
40
+ def load_chroma(persist_dir: str, collection: str, embed_f):
41
+ from langchain.embeddings.base import Embeddings
42
+ class STEmbeddings(Embeddings):
43
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
44
+ return embed_f(texts)
45
+ def embed_query(self, text: str) -> List[float]:
46
+ return embed_f([text])[0]
47
+ embeddings = STEmbeddings()
48
+ print(f" → Loading Chroma vector store from {persist_dir}...")
49
+ return Chroma(collection_name=collection, persist_directory=persist_dir, embedding_function=embeddings)
50
+
51
+ def load_faiss(persist_dir: str, embed_f):
52
+ import pickle, faiss
53
+ from langchain.embeddings.base import Embeddings
54
+ class STEmbeddings(Embeddings):
55
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
56
+ return embed_f(texts)
57
+ def embed_query(self, text: str) -> List[float]:
58
+ return embed_f([text])[0]
59
+ embeddings = STEmbeddings()
60
+ index_path = os.path.join(persist_dir, "index.faiss")
61
+ meta_path = os.path.join(persist_dir, "meta.pkl")
62
+ if not (os.path.exists(index_path) and os.path.exists(meta_path)):
63
+ raise FileNotFoundError(f"FAISS files not found in {persist_dir}")
64
+ print(f" → Loading FAISS index from {persist_dir}...")
65
+ with open(meta_path, "rb") as f:
66
+ meta = pickle.load(f)
67
+ texts = [m["text"] for m in meta]
68
+ metadatas = [m["meta"] | {"id": m["id"]} for m in meta]
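+ # Note: FAISS.from_texts re-embeds every stored chunk before the persisted
+ # index is swapped in below, so loading a large store this way can be slow.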
69
+ vdb = FAISS.from_texts(texts=texts, embedding=embeddings, metadatas=metadatas)
70
+ vdb.index = faiss.read_index(index_path)
71
+ return vdb
72
+
73
+ def retrieve_docs(db_type: str, persist_dir: str, collection: str, query: str, top_k: int, embed_f) -> List[Document]:
74
+ if db_type == "chroma":
75
+ vdb = load_chroma(persist_dir, collection, embed_f)
76
+ else:
77
+ vdb = load_faiss(persist_dir, embed_f)
78
+
79
+ print(f" → Retrieving documents...")
80
+ retriever = vdb.as_retriever(search_kwargs={"k": top_k})
81
+ docs: List[Document] = retriever.invoke(query)
82
+ print(f" ✓ Retrieved {len(docs)} document(s)")
83
+
84
+ # Debug: Show retrieved content length
85
+ if docs:
86
+ total_chars = sum(len(d.page_content) for d in docs)
87
+ print(f" ℹ Total retrieved content: {total_chars} characters")
88
+ else:
89
+ print(f" ⚠ WARNING: No documents retrieved!")
90
+
91
+ return docs
92
+
93
+ # -----------------------------
94
+ # T5 Summarization utilities
95
+ # -----------------------------
96
+ def make_t5(model_name="google/flan-t5-base", device="cpu"):
97
+ print(f" → Loading T5 model: {model_name}")
98
+ print(f" ℹ This may take 30-60 seconds for large models...")
99
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
100
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
101
+ print(f" ✓ Model loaded successfully")
102
+ return tokenizer, model
103
+
104
+ def t5_generate(tokenizer, model, prompt: str, max_input_tokens: int = 512, max_output_tokens: int = 256, section_name: str = ""):
105
+ """
106
+ Enhanced generation with progress indicators and optimized parameters for CPU
107
+ """
108
+ # Show progress
109
+ if section_name:
110
+ print(f" → Generating {section_name}...", end='', flush=True)
111
+ else:
112
+ print(f" → Generating summary...", end='', flush=True)
113
+
114
+ start_time = time.time()
115
+
116
+ try:
117
+ inputs = tokenizer(prompt, truncation=True, max_length=max_input_tokens, return_tensors="pt")
118
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
119
+
120
+ # Optimized parameters for CPU performance
121
+ outputs = model.generate(
122
+ **inputs,
123
+ max_new_tokens=max_output_tokens,
124
+ min_length=10, # Reduced minimum to avoid forcing long outputs
125
+ num_beams=2, # Reduced from 4 for faster CPU generation
126
+ length_penalty=1.0, # Reduced from 1.5
127
+ no_repeat_ngram_size=3,
128
+ early_stopping=True, # Re-enabled for faster completion
129
+ do_sample=False # Deterministic generation
130
+ )
131
+
132
+ result = tokenizer.decode(outputs[0], skip_special_tokens=True)
133
+ elapsed = time.time() - start_time
134
+ print(f" done ({elapsed:.1f}s)")
135
+
136
+ return result
137
+ except Exception as e:
138
+ elapsed = time.time() - start_time
139
+ print(f" FAILED ({elapsed:.1f}s)")
140
+ print(f" ✗ Error: {str(e)}")
141
+ return ""
142
+
143
+ def dedupe_texts(texts: List[str]) -> List[str]:
144
+ seen = set()
145
+ uniq = []
146
+ for t in texts:
147
+ key = " ".join(t.split())[:500]
148
+ if key not in seen:
149
+ seen.add(key)
150
+ uniq.append(t)
151
+ return uniq
152
+
153
+ # -----------------------------
154
+ # Section definitions
155
+ # -----------------------------
156
+ SECTION_ORDER = [
157
+ "Chief Complaint",
158
+ "HPI",
159
+ "PMH",
160
+ "Medications",
161
+ "Allergies",
162
+ "Assessment",
163
+ "Plan",
164
+ ]
165
+
166
+ # -----------------------------
167
+ # Multi-stage extraction prompts (optimized for T5)
168
+ # -----------------------------
169
+ SECTION_PROMPTS = {
170
+ "Chief Complaint": """Task: Extract the main reason for patient visit.
171
+
172
+ Clinical Note:
173
+ {context}
174
+
175
+ Answer with only the chief complaint (1-2 sentences):""",
176
+
177
+ "HPI": """Task: Extract the history of present illness including symptom onset, progression, and context.
178
+
179
+ Clinical Note:
180
+ {context}
181
+
182
+ Answer with the history of present illness:""",
183
+
184
+ "PMH": """Task: Extract past medical history including chronic conditions, past surgeries, and social history.
185
+
186
+ Clinical Note:
187
+ {context}
188
+
189
+ Answer with past medical history:""",
190
+
191
+ "Medications": """Task: List all medications with dosages mentioned in the note.
192
+
193
+ Clinical Note:
194
+ {context}
195
+
196
+ Answer with medication list:""",
197
+
198
+ "Allergies": """Task: Extract drug allergies. If none mentioned, state "No known drug allergies".
199
+
200
+ Clinical Note:
201
+ {context}
202
+
203
+ Answer with allergies:""",
204
+
205
+ "Assessment": """Task: Extract diagnosis, test results, physical findings, and vital signs.
206
+
207
+ Clinical Note:
208
+ {context}
209
+
210
+ Answer with assessment and findings:""",
211
+
212
+ "Plan": """Task: Extract treatment plan, medications prescribed, follow-up appointments, and discharge instructions.
213
+
214
+ Clinical Note:
215
+ {context}
216
+
217
+ Answer with treatment plan:"""
218
+ }
219
+
220
+ # -----------------------------
221
+ # Enhanced extraction pipeline
222
+ # -----------------------------
223
+ def extract_section_multistage(tokenizer, model, context: str, section: str) -> str:
224
+ """
225
+ Extract a single section using targeted prompting
226
+ """
227
+ if section not in SECTION_PROMPTS:
228
+ return "None stated"
229
+
230
+ # Truncate context if too long
231
+ max_context_chars = 2000
232
+ if len(context) > max_context_chars:
233
+ context = context[:max_context_chars] + "..."
234
+
235
+ prompt = SECTION_PROMPTS[section].format(context=context)
236
+
237
+ try:
238
+ result = t5_generate(tokenizer, model, prompt, max_input_tokens=512, max_output_tokens=200, section_name=section)
239
+ result = result.strip()
240
+
241
+ # Remove any section headers the model might have added
242
+ result = re.sub(r'^(Chief Complaint|HPI|PMH|Medications|Allergies|Assessment|Plan)\s*:\s*', '', result, flags=re.IGNORECASE)
243
+
244
+ # Check if extraction failed
245
+ if not result or len(result) < 5 or result.lower() in ["none", "none stated", "not mentioned", "n/a", "na"]:
246
+ return "None stated"
247
+
248
+ return result.strip()
249
+ except Exception as e:
250
+ print(f" ✗ Error extracting {section}: {str(e)}")
251
+ return "None stated"
252
+
253
+ def validate_extraction(sections: Dict[str, str]) -> bool:
254
+ """
255
+ Validate that extraction was successful (not all 'None stated')
256
+ """
257
+ non_empty = sum(1 for v in sections.values() if v and v != "None stated")
258
+ return non_empty >= 2 # At least 2 sections should have content
259
+
260
+ def summarize_docs_multistage(tokenizer, model, docs: List[Document]) -> str:
261
+ """
262
+ Multi-stage extraction: extract each section independently
263
+ """
264
+ print(f"\n📄 Processing documents...")
265
+ contents = dedupe_texts([d.page_content for d in docs if d and d.page_content])
266
+
267
+ if not contents:
268
+ print(" ⚠ No content to summarize!")
269
+ return format_output({sec: "None stated" for sec in SECTION_ORDER})
270
+
271
+ # Combine all retrieved content
272
+ full_context = "\n\n".join(contents)
273
+ print(f" ℹ Combined context length: {len(full_context)} characters")
274
+
275
+ # Extract each section independently
276
+ print(f"\n🔄 Extracting sections (this may take 1-3 minutes on CPU)...")
277
+ sections = {}
278
+ for i, section in enumerate(SECTION_ORDER, 1):
279
+ print(f" [{i}/{len(SECTION_ORDER)}] {section}:")
280
+ sections[section] = extract_section_multistage(tokenizer, model, full_context, section)
281
+
282
+ # Validate extraction
283
+ print(f"\n✓ Extraction complete")
284
+ if not validate_extraction(sections):
285
+ print("⚠ WARNING: Extraction appears incomplete. Most sections are empty.")
286
+ print(" Possible issues:")
287
+ print(" • Vector retrieval may not be finding relevant content")
288
+ print(" • Model may not understand the clinical text format")
289
+ print(" • Context may be too short or fragmented")
290
+ print(" • De-identification artifacts may be confusing the model")
291
+
292
+ return format_output(sections)
293
+
294
+ def format_output(sections: Dict[str, str]) -> str:
295
+ """
296
+ Format sections into structured output
297
+ """
298
+ output_lines = []
299
+ for section in SECTION_ORDER:
300
+ content = sections.get(section, "None stated")
301
+ output_lines.append(f"• {section}: {content}")
302
+
303
+ return "\n".join(output_lines)
304
+
305
+ # -----------------------------
306
+ # Summary Quality Validation
307
+ # -----------------------------
308
+ def validate_summary_quality(summary: str, original_text: str = "") -> dict:
309
+ """
310
+ Validate summary quality and detect common issues
311
+
312
+ Args:
313
+ summary: The generated summary text
314
+ original_text: Optional original note text for comparison
315
+
316
+ Returns:
317
+ Dictionary with validation results
318
+ """
319
+ issues = []
320
+ warnings = []
321
+
322
+ # Check for placeholder contamination (de-ID over-redaction)
323
+ placeholder_patterns = [
324
+ (r'\[LOCATION\]', 'LOCATION'),
325
+ (r'\[DATE\]', 'DATE'),
326
+ (r'\[NAME\]', 'NAME'),
327
+ (r'\[PHONE\]', 'PHONE')
328
+ ]
329
+
330
+ total_placeholders = 0
331
+ for pattern, name in placeholder_patterns:
332
+ count = len(re.findall(pattern, summary))
333
+ total_placeholders += count
334
+ if count > 2:
335
+ warnings.append(f"Too many [{name}] placeholders ({count}) - de-identification may be over-aggressive")
336
+
337
+ if total_placeholders > 5:
338
+ issues.append(f"Critical: {total_placeholders} PHI placeholders in summary - clinical content lost")
339
+
340
+ # Check for "None stated" sections
341
+ none_count = summary.count("None stated")
342
+ if none_count >= 5:
343
+ issues.append(f"Critical: {none_count}/7 sections are empty - summarization failed")
344
+ elif none_count >= 3:
345
+ warnings.append(f"Warning: {none_count}/7 sections are empty - may need better retrieval")
346
+
347
+ # Check for minimum content length per section
348
+ total_length = len(summary)
349
+ # Subtract bullets and "None stated" overhead
350
+ content_length = total_length - (summary.count("•") * 2) - (none_count * 11)
351
+ filled_sections = 7 - none_count
352
+
353
+ if filled_sections > 0:
354
+ avg_section_length = content_length / filled_sections
355
+ if avg_section_length < 30:
356
+ warnings.append(f"Warning: Sections too short (avg {avg_section_length:.0f} chars) - may lack detail")
357
+
358
+ # Check for duplicate medications
359
+ if "Medications:" in summary:
360
+ meds_section = summary.split("Medications:")[1].split("•")[0] if "Medications:" in summary else ""
361
+ meds_lower = meds_section.lower()
362
+ common_meds = ['atorvastatin', 'metoprolol', 'lisinopril', 'aspirin', 'metformin']
363
+ for med in common_meds:
364
+ if meds_lower.count(med) > 1:
365
+ warnings.append(f"Warning: Duplicate medication detected: {med}")
366
+
367
+ # Calculate quality score (0-100)
368
+ score = 100
369
+ score -= len(issues) * 30 # Critical issues: -30 each
370
+ score -= len(warnings) * 10 # Warnings: -10 each
371
+ score = max(0, min(100, score))
372
+
373
+ # Determine overall status
374
+ if len(issues) > 0:
375
+ status = "FAILED"
376
+ elif len(warnings) > 2:
377
+ status = "POOR"
378
+ elif len(warnings) > 0:
379
+ status = "FAIR"
380
+ else:
381
+ status = "GOOD"
382
+
383
+ return {
384
+ "is_valid": len(issues) == 0,
385
+ "status": status,
386
+ "quality_score": score,
387
+ "issues": issues,
388
+ "warnings": warnings,
389
+ "metrics": {
390
+ "total_placeholders": total_placeholders,
391
+ "empty_sections": none_count,
392
+ "filled_sections": filled_sections,
393
+ "total_length": total_length
394
+ }
395
+ }
396
+
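+ # Illustrative usage of the validation report (e.g. from main.py / the Streamlit UI):
+ #   report = validate_summary_quality(summary)
+ #   if not report["is_valid"]:
+ #       print("\n".join(report["issues"] + report["warnings"]))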
409
+ # -----------------------------
410
+ # Single-shot extraction (simplified fallback)
411
+ # -----------------------------
412
+ def summarize_docs_singleshot(tokenizer, model, docs: List[Document]) -> str:
413
+ """
414
+ Single-shot extraction method (faster but less comprehensive)
415
+ """
416
+ print(f"\n📄 Processing documents...")
417
+ contents = dedupe_texts([d.page_content for d in docs if d and d.page_content])
418
+
419
+ if not contents:
420
+ print(" ⚠ No content to summarize!")
421
+ return format_output({sec: "None stated" for sec in SECTION_ORDER})
422
+
423
+ raw_context = "\n\n".join(contents)
424
+ print(f" ℹ Combined context length: {len(raw_context)} characters")
425
+
426
+ # Simplified prompt for single-shot
427
+ instruction = """Summarize this clinical note into 7 sections:
428
+ 1. Chief Complaint (main reason for visit)
429
+ 2. HPI (symptom history and progression)
430
+ 3. PMH (past medical history)
431
+ 4. Medications (current medications with doses)
432
+ 5. Allergies (drug allergies)
433
+ 6. Assessment (diagnosis and findings)
434
+ 7. Plan (treatment plan and follow-up)
435
+
436
+ Clinical Note:
437
+ {context}
438
+
439
+ Structured Summary:"""
440
+
441
+ print(f"\n🔄 Generating structured summary...")
442
+ prompt = instruction.format(context=raw_context[:2000]) # Limit context
443
+ model_out = t5_generate(tokenizer, model, prompt, max_input_tokens=512, max_output_tokens=400)
444
+
445
+ # Parse output into sections
446
+ sections = parse_output_to_sections(model_out)
447
+
448
+ return format_output(sections)
449
+
450
+ def parse_output_to_sections(text: str) -> Dict[str, str]:
451
+ """
452
+ Parse model output into section dictionary
453
+ """
454
+ sections = {}
455
+ current_section = None
456
+ current_content = []
457
+
458
+ for line in text.split('\n'):
459
+ line = line.strip()
460
+ if not line:
461
+ continue
462
+
463
+ # Check if line starts with a section header
464
+ matched_section = None
465
+ for section in SECTION_ORDER:
466
+ # Match section headers with numbers or bullets
467
+ pattern = rf'^(\d+\.\s*)?{re.escape(section)}\s*:?'
468
+ if re.match(pattern, line, re.IGNORECASE):
469
+ matched_section = section
470
+ break
471
+
472
+ if matched_section:
473
+ # Save previous section
474
+ if current_section:
475
+ sections[current_section] = " ".join(current_content).strip()
476
+
477
+ # Start new section
478
+ current_section = matched_section
479
+ # Get content after the header
480
+ content = re.sub(rf'^(\d+\.\s*)?{re.escape(matched_section)}\s*:?\s*', '', line, flags=re.IGNORECASE).strip()
481
+ current_content = [content] if content else []
482
+ else:
483
+ # Continue current section
484
+ if current_section:
485
+ current_content.append(line)
486
+
487
+ # Save last section
488
+ if current_section:
489
+ sections[current_section] = " ".join(current_content).strip()
490
+
491
+ # Fill in missing sections
492
+ for section in SECTION_ORDER:
493
+ if section not in sections or not sections[section]:
494
+ sections[section] = "None stated"
495
+
496
+ return sections
497
+
498
+ # -----------------------------
499
+ # Backward compatibility wrapper for Streamlit integration
500
+ # -----------------------------
501
+ def summarize_docs(tokenizer, model, docs: List[Document], method: str = "multistage") -> str:
502
+ """
503
+ Wrapper function for backward compatibility with main.py (Streamlit UI)
504
+
505
+ Args:
506
+ tokenizer: T5 tokenizer instance
507
+ model: T5 model instance
508
+ docs: List of retrieved documents
509
+ method: "multistage" (default) or "singleshot" extraction method
510
+
511
+ Returns:
512
+ Formatted summary string with sections
513
+ """
514
+ if method == "multistage":
515
+ return summarize_docs_multistage(tokenizer, model, docs)
516
+ else:
517
+ return summarize_docs_singleshot(tokenizer, model, docs)
518
+
519
+ # -----------------------------
520
+ # Orchestration
521
+ # -----------------------------
522
+ def main():
523
+ parser = argparse.ArgumentParser(description="Day 10: Enhanced HIPAA-compliant RAG clinical summarizer")
524
+ parser.add_argument("--db_type", choices=["chroma", "faiss"], default="chroma")
525
+ parser.add_argument("--persist_dir", default="./data/vector_store")
526
+ parser.add_argument("--collection", default="notes")
527
+ parser.add_argument("--embed_model", default="sentence-transformers/all-MiniLM-L6-v2")
528
+ parser.add_argument("--model_name", default="google/flan-t5-small")
529
+ parser.add_argument("--query", required=True)
530
+ parser.add_argument("--top_k", type=int, default=5)
531
+ parser.add_argument("--out", default="./data/outputs/summaries/summary.txt")
532
+ parser.add_argument("--method", choices=["multistage", "singleshot"], default="multistage",
533
+ help="Extraction method: multistage (recommended) or singleshot (faster)")
534
+ args = parser.parse_args()
535
+
536
+ print("=" * 70)
537
+ print(" HIPAA-COMPLIANT RAG CLINICAL SUMMARIZER")
538
+ print("=" * 70)
539
+
540
+ out_dir = os.path.dirname(args.out) or "."
541
+ os.makedirs(out_dir, exist_ok=True)
542
+
543
+ try:
544
+ # Step 1: Load embedder
545
+ print(f"\n[1/4] LOADING EMBEDDER")
546
+ print(f" Model: {args.embed_model}")
547
+ embed_f = load_embedder(args.embed_model)
548
+
549
+ # Step 2: Retrieve documents
550
+ print(f"\n[2/4] RETRIEVING DOCUMENTS")
551
+ print(f" Database: {args.db_type}")
552
+ print(f" Location: {args.persist_dir}")
553
+ print(f" Query: {args.query}")
554
+ print(f" Top-K: {args.top_k}")
555
+ docs = retrieve_docs(args.db_type, args.persist_dir, args.collection, args.query, args.top_k, embed_f)
556
+
557
+ if not docs:
558
+ print("\n⚠ ERROR: No documents retrieved from vector database!")
559
+ print(" Possible causes:")
560
+ print(" • Vector database is empty or not properly indexed")
561
+ print(" • Query doesn't match indexed content")
562
+ print(" • Database path is incorrect")
563
+ result = format_output({sec: "None stated" for sec in SECTION_ORDER})
564
+ with open(args.out, "w", encoding="utf-8") as f:
565
+ f.write(result)
566
+ print(f"\n✓ Empty summary written to {args.out}")
567
+ return
568
+
569
+ # Step 3: Load summarization model
570
+ print(f"\n[3/4] LOADING SUMMARIZATION MODEL")
571
+ print(f" Model: {args.model_name}")
572
+ tokenizer, model = make_t5(args.model_name)
573
+
574
+ # Step 4: Generate summary
575
+ print(f"\n[4/4] GENERATING SUMMARY")
576
+ print(f" Method: {args.method}")
577
+
578
+ if args.method == "multistage":
579
+ summary = summarize_docs_multistage(tokenizer, model, docs)
580
+ else:
581
+ summary = summarize_docs_singleshot(tokenizer, model, docs)
582
+
583
+ # Write summary to output file
584
+ with open(args.out, "w", encoding="utf-8") as f:
585
+ f.write(summary)
586
+
587
+ print(f"\n{'=' * 70}")
588
+ print(f"✓ SUCCESS: Summary written to {args.out}")
589
+ print(f"{'=' * 70}")
590
+ print("\nGenerated Summary:")
591
+ print("-" * 70)
592
+ print(summary)
593
+ print("-" * 70)
594
+
595
+ except Exception as e:
596
+ err = traceback.format_exc()
597
+ error_msg = f"ERROR during summarization:\n{err}"
598
+
599
+ # Write error to file
600
+ with open(args.out, "w", encoding="utf-8") as f:
601
+ f.write(error_msg)
602
+
603
+ print(f"\n{'=' * 70}")
604
+ print(f"✗ ERROR: An error occurred during processing")
605
+ print(f"{'=' * 70}")
606
+ print(f"\n{err}")
607
+ print(f"\nError details written to {args.out}")
608
+
609
+ if __name__ == "__main__":
610
+ main()
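Example invocation (all flags match the argparse definitions in main(); the script path and query string are illustrative):

python summarizer.py --db_type chroma --persist_dir ./data/vector_store --collection notes --model_name google/flan-t5-small --query "Summarize into HPI/Assessment/Plan" --top_k 5 --method multistage --out ./data/outputs/summaries/summary.txt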