Commit f64b3f9
Parent(s): 7f09b7f

Add HIPAA RAG Clinical Summarizer (essential files only)

Files changed:
- .gitignore            +46 -0
- audit.py              +79 -0
- bcrypt_pw.py          +2 -0
- deid_pipeline.py      +266 -0
- indexer.py            +289 -0
- main.py               +494 -0
- notes.py              +47 -0
- quick_check_chroma.py +20 -0
- rag_pipeline.py       +117 -0
- retriever_context.py  +7 -0
- run_pipeline.py       +51 -0
- streamlit_config.yaml +17 -0
- summarizer.py         +610 -0
.gitignore
ADDED
@@ -0,0 +1,46 @@
+# Large data folders
+synthea/
+data/raw_synthea/
+data/vector_store*/
+synthea/output/fhir/
+*.json
+*.csv
+*.zip
+*.sqlite
+__pycache__/
+*.ipynb_checkpoints
+
+# Ignore vector store database
+app/data/vector_store/chroma.sqlite3
+*.sqlite3
+*.db
+# Models and cache
+models/
+.cache/
+transformers_cache/
+
+# Virtual environment
+new_env_rag/
+venv/
+env/
+
+# Secrets
+
+
+# Logs and outputs
+logs/
+*.log
+*.jsonl
+
+# Python cache
+__pycache__/
+*.pyc
+
+# IDE
+.vs/
+.vscode/
+.idea/
+
+# OS files
+.DS_Store
+Thumbs.db
audit.py
ADDED
@@ -0,0 +1,79 @@
+import json
+import hashlib
+from datetime import datetime
+from pathlib import Path
+import uuid
+import pytz
+
+class AuditLogger:
+    def __init__(self, log_file_path="logs/app_audit.jsonl"):
+        self.log_file = Path(log_file_path)
+        self.log_file.parent.mkdir(parents=True, exist_ok=True)
+        # Create file if it doesn't exist
+        if not self.log_file.exists():
+            self.log_file.touch()
+
+    def _get_last_hash(self):
+        """Read the last log entry and return its hash"""
+        try:
+            with open(self.log_file, 'r') as f:
+                lines = f.readlines()
+                if lines:
+                    last_entry = json.loads(lines[-1])
+                    return last_entry.get('sha256_curr', '')
+        except:
+            pass
+        return ''  # First entry has no previous hash
+
+    def _compute_hash(self, log_entry):
+        """Create a hash fingerprint of the log entry"""
+        # Convert the log entry to a string and hash it
+        entry_string = json.dumps(log_entry, sort_keys=True)
+        return hashlib.sha256(entry_string.encode()).hexdigest()
+
+    def log_action(self, user, action, resource, additional_info=None):
+        """
+        Main logging function - call this whenever a user does something
+
+        Args:
+            user: username (e.g., 'dr_smith')
+            action: what they did (e.g., 'UPLOAD_NOTE', 'GENERATE_SUMMARY', 'VIEW_LOGS')
+            resource: what they acted on (e.g., 'note_12345.txt', 'patient_record')
+            additional_info: any extra details (dictionary)
+        """
+        # Get the hash of the previous log entry
+        previous_hash = self._get_last_hash()
+
+        # Generate unique IDs for tracing
+        trace_id = str(uuid.uuid4())
+        span_id = str(uuid.uuid4())[:16]  # Shorter ID for span
+
+        # India timezone
+        india = pytz.timezone('Asia/Kolkata')
+        local_time = datetime.now(india).isoformat()
+
+        # Create the new log entry
+        log_entry = {
+            "timestamp": local_time,  # offset-aware ISO 8601 (+05:30); appending "Z" would mislabel it as UTC
+            "user": user,
+            "action": action,
+            "resource": resource,
+            "sha256_prev": previous_hash,
+            "additional_info": additional_info or {},
+
+            # OpenTelemetry attributes
+            "otel_trace_id": trace_id,
+            "otel_span_id": span_id,
+            "otel_service_name": "clinical-rag-app",
+            "severity": "INFO"  # Can be DEBUG, INFO, WARN, ERROR
+        }
+
+        # Compute hash of THIS entry
+        current_hash = self._compute_hash(log_entry)
+        log_entry["sha256_curr"] = current_hash
+
+        # Append to log file (append-only = cannot change old entries)
+        with open(self.log_file, 'a') as f:
+            f.write(json.dumps(log_entry) + '\n')
+
+        return log_entry
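Editor's note: a minimal verification sketch (not part of the commit) showing how the sha256_prev/sha256_curr chain written by AuditLogger above could be checked offline. The verify_chain helper and its default path are illustrative assumptions; it recomputes each entry's hash exactly as _compute_hash does (over the entry without sha256_curr, sort_keys=True) and checks the back-links.

import json
import hashlib

def verify_chain(log_path="logs/app_audit.jsonl") -> bool:
    prev_hash = ""
    with open(log_path, "r") as f:
        for line_no, line in enumerate(f, start=1):
            entry = json.loads(line)
            stored_curr = entry.pop("sha256_curr")
            # Recompute the hash over the entry as it looked at write time
            recomputed = hashlib.sha256(
                json.dumps(entry, sort_keys=True).encode()
            ).hexdigest()
            if recomputed != stored_curr or entry["sha256_prev"] != prev_hash:
                print(f"Tamper suspected at line {line_no}")
                return False
            prev_hash = stored_curr
    return True

if __name__ == "__main__":
    print("Audit chain intact:", verify_chain())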
bcrypt_pw.py
ADDED
@@ -0,0 +1,2 @@
+import bcrypt
+print(bcrypt.hashpw(b"mypassword", bcrypt.gensalt()).decode())
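Editor's note: the hash printed by bcrypt_pw.py is intended to be pasted into the authenticator credentials as a user's password entry. A minimal sketch (not part of the commit) of how such a hash is checked later; the stored value below is a placeholder, not a real hash.

import bcrypt

stored_hash = b"$2b$12$<hash printed by bcrypt_pw.py>"  # placeholder value

# checkpw re-hashes the candidate with the salt embedded in stored_hash
print(bcrypt.checkpw(b"mypassword", stored_hash))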
deid_pipeline.py
ADDED
@@ -0,0 +1,266 @@
+import json
+import os
+from dataclasses import dataclass
+from typing import List, Dict, Any, Tuple
+
+# Presidio
+from presidio_analyzer import AnalyzerEngine
+from presidio_analyzer.recognizer_registry import RecognizerRegistry
+from presidio_anonymizer import AnonymizerEngine
+from presidio_analyzer import PatternRecognizer
+
+# Define medical terms that should NOT be redacted
+medical_terms_allowlist = [
+    "substernal", "exertional", "pressure-like", "diaphoresis",
+    "chest pain", "nausea", "radiation", "murmurs", "ischemia"
+]
+
+# Configure analyzer to ignore these terms
+analyzer_config = {
+    "nlp_engine_name": "spacy",
+    "models": [{"lang_code": "en", "model_name": "en_core_web_lg"}],
+    "allow_list": medical_terms_allowlist  # Don't redact these
+}
+
+# NLP for optional section detection
+import spacy
+
+# If using medspacy, uncomment (preferred for clinical):
+# import medspacy
+# from medspacy.sectionizer import Sectionizer
+
+# If not using medspacy, optional lightweight section tagging:
+# We'll use regex on common headers as a fallback
+import re
+
+# Encryption
+from cryptography.fernet import Fernet
+
+@dataclass
+class PHISpan:
+    entity_type: str
+    start: int
+    end: int
+    text: str
+    section: str
+
+SECTION_HEADERS = [
+    # Common clinical sections; customize as needed
+    "HPI", "History of Present Illness",
+    "PMH", "Past Medical History",
+    "Medications", "Allergies",
+    "Assessment and Plan", "Assessment & Plan", "Assessment",
+    "Plan", "ROS", "Review of Systems",
+    "Physical Exam"
+]
+
+SECTION_PATTERN = re.compile(
+    r"^(?P<header>(" + "|".join([re.escape(h) for h in SECTION_HEADERS]) + r"))\s*:\s*$",
+    re.IGNORECASE | re.MULTILINE
+)
+
+TAG_MAP = {
+    "PERSON": "[NAME]",
+    "PHONE_NUMBER": "[PHONE]",
+    "DATE_TIME": "[DATE]",
+    "DATE": "[DATE]",
+    "EMAIL_ADDRESS": "[EMAIL]",
+    "US_SSN": "[SSN]"
+}
+
+class DeidPipeline:
+    """
+    De-identification pipeline using Microsoft Presidio
+    """
+    def __init__(self, fernet_key_path="secure_store/fernet.key"):
+        """
+        Initialize de-identification pipeline with Presidio
+
+        Args:
+            fernet_key_path: Path to Fernet encryption key
+        """
+        import os
+        from cryptography.fernet import Fernet
+
+        # Initialize encryption
+        try:
+            if os.path.exists(fernet_key_path):
+                # Load existing key from file
+                with open(fernet_key_path, "rb") as f:
+                    key = f.read()
+            else:
+                # Generate new key for this session
+                key = Fernet.generate_key()
+                # Try to save it (might fail on read-only filesystems)
+                try:
+                    os.makedirs(os.path.dirname(fernet_key_path), exist_ok=True)
+                    with open(fernet_key_path, "wb") as f:
+                        f.write(key)
+                except (PermissionError, OSError):
+                    # Cloud filesystem is read-only, just use the generated key
+                    pass
+
+            self.fernet = Fernet(key)
+
+        except Exception as e:
+            # Emergency fallback: Generate temporary key
+            print(f"Warning: Could not load encryption key, generating temporary key: {e}")
+            key = Fernet.generate_key()
+            self.fernet = Fernet(key)
+
+        # Initialize Presidio components
+        self.analyzer = AnalyzerEngine()
+        self.anonymizer = AnonymizerEngine()
+
+        # Load spaCy model
+        try:
+            self.nlp = spacy.load("en_core_web_lg")
+        except OSError:
+            print("Downloading spaCy model...")
+            import subprocess
+            subprocess.run(["python", "-m", "spacy", "download", "en_core_web_lg"])
+            self.nlp = spacy.load("en_core_web_lg")
+
+    def _detect_sections(self, text: str) -> List[Tuple[str, int, int]]:
+        """
+        Lightweight section finder:
+        Return list of (section_title, start_idx, end_idx_of_section_block)
+        """
+        # Find headers by regex, map their start positions
+        headers = []
+        for m in SECTION_PATTERN.finditer(text):
+            headers.append((m.group("header"), m.start()))
+        # Add end sentinel
+        headers.append(("[END]", len(text)))
+
+        sections = []
+        for i in range(len(headers) - 1):
+            title, start_pos = headers[i]
+            next_title, next_pos = headers[i+1]
+            sections.append((title.strip(), start_pos, next_pos))
+        if not sections:
+            # Single default section if none found
+            sections = [("DOCUMENT", 0, len(text))]
+        return sections
+
+    def _find_section_for_span(self, sections, start_idx) -> str:
+        for title, s, e in sections:
+            if s <= start_idx < e:
+                return title
+        return "DOCUMENT"
+
+    def analyze(self, text: str) -> List[Dict[str, Any]]:
+        # Detect entities
+        results = self.analyzer.analyze(text=text, language="en")
+        # Convert to dict for consistency
+        detections = []
+        for r in results:
+            detections.append({
+                "entity_type": r.entity_type,
+                "start": r.start,
+                "end": r.end,
+                "score": r.score
+            })
+        return detections
+
+    def mask(self, text: str, detections: List[Dict[str, Any]]) -> Tuple[str, List[PHISpan]]:
+        """
+        Replace spans with tags safely (right-to-left to maintain indices).
+        """
+        # Determine sections for context
+        sections = self._detect_sections(text)
+
+        # Build PHI span records
+        spans: List[PHISpan] = []
+        for d in detections:
+            entity = d["entity_type"]
+            start = d["start"]
+            end = d["end"]
+            original = text[start:end]
+            section = self._find_section_for_span(sections, start)
+            spans.append(PHISpan(entity_type=entity, start=start, end=end, text=original, section=section))
+
+        # Replace from the end to avoid index shifting
+        masked = text
+        for d in sorted(detections, key=lambda x: x["start"], reverse=True):
+            entity = d["entity_type"]
+            start = d["start"]
+            end = d["end"]
+            tag = TAG_MAP.get(entity, f"[{entity}]")
+            masked = masked[:start] + tag + masked[end:]
+
+        return masked, spans
+
+    def encrypt_span_map(self, spans: List[PHISpan], meta: Dict[str, Any]) -> bytes:
+        payload = {
+            "meta": meta,
+            "spans": [s.__dict__ for s in spans]
+        }
+        blob = json.dumps(payload).encode("utf-8")
+        token = self.fernet.encrypt(blob)
+        return token
+
+    def run_on_text(self, text: str, note_id: str) -> Dict[str, Any]:
+        detections = self.analyze(text)
+        masked, spans = self.mask(text, detections)
+
+        # Encrypt span map
+        token = self.encrypt_span_map(
+            spans=spans,
+            meta={"note_id": note_id}
+        )
+
+        return {
+            "masked_text": masked,
+            "encrypted_span_map": token
+        }
+
+def _read_text_with_fallback(path: str) -> str:
+    # 1) Try UTF-8 (preferred for cross-platform)
+    try:
+        with open(path, "r", encoding="utf-8") as f:
+            return f.read()
+    except UnicodeDecodeError:
+        pass
+    # 2) Try Windows-1252 (common for Notepad/docx copy-paste on Windows)
+    try:
+        with open(path, "r", encoding="cp1252") as f:
+            return f.read()
+    except UnicodeDecodeError:
+        pass
+    # 3) Last resort: decode with replacement to avoid crashing; preserves structure
+    with open(path, "r", encoding="utf-8", errors="replace") as f:
+        return f.read()
+
+def run_file(input_path: str, outputs_dir: str = "data/outputs", secure_dir: str = "secure_store"):
+    os.makedirs(outputs_dir, exist_ok=True)
+    os.makedirs(secure_dir, exist_ok=True)
+
+    note_id = os.path.splitext(os.path.basename(input_path))[0]
+    text = _read_text_with_fallback(input_path)
+
+    pipeline = DeidPipeline()
+    result = pipeline.run_on_text(text, note_id=note_id)
+
+    # Save masked text normalized to UTF-8
+    out_txt = os.path.join(outputs_dir, f"{note_id}.deid.txt")
+    with open(out_txt, "w", encoding="utf-8", newline="\n") as f:
+        f.write(result["masked_text"])
+
+    # Save encrypted span map (binary)
+    out_bin = os.path.join(secure_dir, f"{note_id}.spanmap.enc")
+    with open(out_bin, "wb") as f:
+        f.write(result["encrypted_span_map"])
+
+    print(f"De-identified text -> {out_txt}")
+    print(f"Encrypted span map -> {out_bin}")
+
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser(description="De-identify a clinical note and save encrypted span map.")
+    parser.add_argument("--input", required=True, help="Path to input .txt note")
+    parser.add_argument("--outputs_dir", default="data/outputs", help="Output folder for masked text")
+    parser.add_argument("--secure_dir", default="secure_store", help="Folder for encrypted span maps")
+    args = parser.parse_args()
+    run_file(args.input, args.outputs_dir, args.secure_dir)
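Editor's note: an illustrative round trip (not part of the commit), assuming deid_pipeline.py is importable and a Fernet key can be read or created under secure_store/. The sample note text and the commented output are assumptions; the decryption step shows that the span map can only be recovered with the same key.

from deid_pipeline import DeidPipeline
import json

pipeline = DeidPipeline()
result = pipeline.run_on_text(
    "John Smith was seen on 01/02/2024 for substernal chest pain.",
    note_id="demo_note",
)
print(result["masked_text"])  # e.g. "[NAME] was seen on [DATE] for substernal chest pain."

# Decrypt the span map with the pipeline's own Fernet key
decrypted = json.loads(pipeline.fernet.decrypt(result["encrypted_span_map"]))
for span in decrypted["spans"]:
    print(span["entity_type"], repr(span["text"]), "in section", span["section"])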
indexer.py
ADDED
@@ -0,0 +1,289 @@
+# app/indexer.py
+# Day 6: Vector store & embeddings
+# Usage examples:
+# python app/indexer.py --input_dir ./data/outputs --db_type chroma --persist_dir ./data/vector_store
+# python app/indexer.py --input_dir ./data/outputs --db_type faiss --persist_dir ./data/vector_store_faiss
+
+import os
+import json
+import argparse
+from pathlib import Path
+from typing import List, Dict, Tuple
+from tqdm import tqdm
+
+# Embeddings
+from sentence_transformers import SentenceTransformer
+
+# Vector stores
+# Chroma
+import chromadb
+from chromadb.config import Settings as ChromaSettings
+
+# FAISS
+import faiss
+import pickle
+
+DEFAULT_CHUNK_TOKENS = 200
+DEFAULT_OVERLAP_TOKENS = 50
+
+def read_note_files(input_dir: str) -> List[Dict]:
+    """
+    Reads de-identified notes from .txt or .json in input_dir.
+    Expects .json to have a 'text' field containing de-identified content.
+    Returns list of dicts: {id, text, section?}
+    """
+    items = []
+    p = Path(input_dir)
+    if not p.exists():
+        raise FileNotFoundError(f"Input dir not found: {input_dir}")
+
+    for fp in p.glob("**/*"):
+        if fp.is_dir():
+            continue
+        if fp.suffix.lower() == ".txt":
+            text = fp.read_text(encoding="utf-8", errors="ignore").strip()
+            if text:
+                items.append({"id": fp.stem, "text": text, "section": None})
+        elif fp.suffix.lower() == ".json":
+            try:
+                obj = json.loads(fp.read_text(encoding="utf-8", errors="ignore"))
+                text = obj.get("text") or obj.get("deidentified_text") or ""
+                section = obj.get("section")
+                if text:
+                    items.append({"id": fp.stem, "text": text.strip(), "section": section})
+            except Exception:
+                # Skip malformed
+                continue
+    return items
+
+def approx_tokenize(text: str) -> List[str]:
+    """
+    Approximate tokenization by splitting on whitespace.
+    For MVP this is fine; can replace with tiktoken later.
+    """
+    return text.split()
+
+def detokenize(tokens: List[str]) -> str:
+    return " ".join(tokens)
+
+def chunk_text(text: str, chunk_tokens: int, overlap_tokens: int) -> List[str]:
+    """
+    Simple sliding window chunking.
+    """
+    tokens = approx_tokenize(text)
+    chunks = []
+    i = 0
+    n = len(tokens)
+    while i < n:
+        j = min(i + chunk_tokens, n)
+        chunk = detokenize(tokens[i:j])
+        if chunk.strip():
+            chunks.append(chunk)
+        if j == n:
+            break
+        i = j - overlap_tokens
+        if i < 0:
+            i = 0
+    return chunks
+
+def embed_texts(model: SentenceTransformer, texts: List[str]):
+    return model.encode(texts, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)
+
+def build_chroma(persist_dir: str, collection_name: str = "notes"):
+    client = chromadb.PersistentClient(
+        path=persist_dir,
+        settings=ChromaSettings(allow_reset=True)
+    )
+    if collection_name in [c.name for c in client.list_collections()]:
+        coll = client.get_collection(collection_name)
+    else:
+        coll = client.create_collection(collection_name)
+    return client, coll
+
+def save_faiss(index, vectors_meta: List[Dict], persist_dir: str):
+    os.makedirs(persist_dir, exist_ok=True)
+    faiss_path = os.path.join(persist_dir, "index.faiss")
+    meta_path = os.path.join(persist_dir, "meta.pkl")
+    faiss.write_index(index, faiss_path)
+    with open(meta_path, "wb") as f:
+        pickle.dump(vectors_meta, f)
+
+def load_faiss(persist_dir: str):
+    faiss_path = os.path.join(persist_dir, "index.faiss")
+    meta_path = os.path.join(persist_dir, "meta.pkl")
+    if os.path.exists(faiss_path) and os.path.exists(meta_path):
+        index = faiss.read_index(faiss_path)
+        with open(meta_path, "rb") as f:
+            meta = pickle.load(f)
+        return index, meta
+    return None, []
+
+def index_note(
+    text: str,
+    note_id: str = "temp_note",
+    persist_dir: str = "./data/vector_store",
+    db_type: str = "chroma",
+    model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
+    collection: str = "notes"
+) -> str:
+    from sentence_transformers import SentenceTransformer
+    import os
+
+    DEFAULT_CHUNK_TOKENS = 200
+    DEFAULT_OVERLAP_TOKENS = 50
+
+    def approx_tokenize(text: str):
+        return text.split()
+
+    def detokenize(tokens):
+        return " ".join(tokens)
+
+    def chunk_text(text, chunk_tokens, overlap_tokens):
+        tokens = approx_tokenize(text)
+        chunks = []
+        i = 0
+        n = len(tokens)
+        while i < n:
+            j = min(i + chunk_tokens, n)
+            chunk = detokenize(tokens[i:j])
+            if chunk.strip():
+                chunks.append(chunk)
+            if j == n:
+                break
+            i = j - overlap_tokens
+            if i < 0:
+                i = 0
+        return chunks
+
+    os.makedirs(persist_dir, exist_ok=True)
+    model = SentenceTransformer(model_name)
+    chunks = chunk_text(text, DEFAULT_CHUNK_TOKENS, DEFAULT_OVERLAP_TOKENS)
+    chunk_ids = [f"{note_id}::chunk_{i}" for i in range(len(chunks))]
+    metadatas = [{"note_id": note_id, "chunk_index": i} for i in range(len(chunks))]
+    vectors = model.encode(chunks, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)
+
+    if db_type == "chroma":
+        from chromadb.config import Settings as ChromaSettings
+        import chromadb
+        client = chromadb.PersistentClient(
+            path=persist_dir,
+            settings=ChromaSettings(allow_reset=True)
+        )
+        if collection in [c.name for c in client.list_collections()]:
+            coll = client.get_collection(collection)
+        else:
+            coll = client.create_collection(collection)
+        coll.upsert(
+            ids=chunk_ids,
+            embeddings=vectors.tolist(),
+            documents=chunks,
+            metadatas=metadatas,
+        )
+    elif db_type == "faiss":
+        import faiss
+        import pickle
+        d = vectors.shape[1]
+        index = faiss.IndexFlatIP(d)
+        index.add(vectors)
+        vectors_meta = [
+            {"id": chunk_ids[k], "text": chunks[k], "meta": metadatas[k]}
+            for k in range(len(chunks))
+        ]
+        faiss_path = os.path.join(persist_dir, "index.faiss")
+        meta_path = os.path.join(persist_dir, "meta.pkl")
+        faiss.write_index(index, faiss_path)
+        with open(meta_path, "wb") as f:
+            pickle.dump(vectors_meta, f)
+
+    return note_id
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Day 6: Build local vector DB from de-identified notes.")
+    parser.add_argument("--input_dir", required=True, help="Directory with de-identified notes (.txt or .json).")
+    parser.add_argument("--persist_dir", default="./data/vector_store", help="Where to persist the DB.")
+    parser.add_argument("--db_type", choices=["chroma", "faiss"], default="chroma", help="Vector DB type.")
+    parser.add_argument("--model_name", default="sentence-transformers/all-MiniLM-L6-v2", help="Embedding model.")
+    parser.add_argument("--chunk_tokens", type=int, default=DEFAULT_CHUNK_TOKENS, help="Approx tokens per chunk.")
+    parser.add_argument("--overlap_tokens", type=int, default=DEFAULT_OVERLAP_TOKENS, help="Token overlap.")
+    parser.add_argument("--collection", default="notes", help="Collection name (Chroma).")
+    args = parser.parse_args()
+
+    notes = read_note_files(args.input_dir)
+    if not notes:
+        print(f"No de-identified notes found in {args.input_dir}. Ensure Day 5 outputs exist.")
+        return
+
+    print(f"Loaded {len(notes)} de-identified notes from {args.input_dir}")
+    os.makedirs(args.persist_dir, exist_ok=True)
+
+    print(f"Loading embedding model: {args.model_name}")
+    model = SentenceTransformer(args.model_name)
+
+    all_chunk_texts = []
+    all_chunk_ids = []
+    all_metadata = []
+
+    print("Chunking notes...")
+    for note in tqdm(notes):
+        chunks = chunk_text(note["text"], args.chunk_tokens, args.overlap_tokens)
+        for idx, ch in enumerate(chunks):
+            cid = f"{note['id']}::chunk_{idx}"
+            all_chunk_texts.append(ch)
+            all_chunk_ids.append(cid)
+            all_metadata.append({
+                "note_id": note["id"],
+                "chunk_index": idx,
+                "section": note.get("section")
+            })
+
+    print(f"Total chunks: {len(all_chunk_texts)}")
+
+    print("Embedding chunks...")
+    vectors = embed_texts(model, all_chunk_texts)
+
+    if args.db_type == "chroma":
+        print("Building Chroma persistent collection...")
+        client, coll = build_chroma(args.persist_dir, args.collection)
+
+        # Upsert in manageable batches
+        batch = 512
+        for i in tqdm(range(0, len(all_chunk_texts), batch)):
+            j = min(i + batch, len(all_chunk_texts))
+            coll.upsert(
+                ids=all_chunk_ids[i:j],
+                embeddings=vectors[i:j].tolist(),
+                documents=all_chunk_texts[i:j],
+                metadatas=all_metadata[i:j],
+            )
+        print(f"Chroma collection '{args.collection}' persisted at {args.persist_dir}")
+
+    elif args.db_type == "faiss":
+        print("Building FAISS index...")
+        d = vectors.shape[1]
+        index = faiss.IndexFlatIP(d)  # normalized vectors → use inner product as cosine
+        # Try to load existing
+        existing_index, existing_meta = load_faiss(args.persist_dir)
+        if existing_index is not None:
+            print("Appending to existing FAISS index...")
+            index = existing_index
+            vectors_meta = existing_meta
+        else:
+            vectors_meta = []
+        index.add(vectors)
+        vectors_meta.extend([
+            {
+                "id": all_chunk_ids[k],
+                "text": all_chunk_texts[k],
+                "meta": all_metadata[k]
+            } for k in range(len(all_chunk_texts))
+        ])
+        save_faiss(index, vectors_meta, args.persist_dir)
+        print(f"FAISS index persisted at {args.persist_dir}")
+
+    print("Done.")
+
+if __name__ == "__main__":
+    main()
+##result = pipeline.run_on_text(text=note_text, note_id="temp_note")
+##deid_text = result["masked_text"]
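Editor's note: a retrieval sketch (not part of the commit) against a Chroma store built by this script with its default paths; the query string is illustrative. Because the index stores precomputed MiniLM vectors, the query is embedded with the same model instead of relying on Chroma's default embedding function.

import chromadb
from chromadb.config import Settings as ChromaSettings
from sentence_transformers import SentenceTransformer

# Assumes `python app/indexer.py --input_dir ./data/outputs --db_type chroma` has already run.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
client = chromadb.PersistentClient(path="./data/vector_store",
                                   settings=ChromaSettings(allow_reset=True))
coll = client.get_collection("notes")

# Embed the query the same way chunks were embedded (normalized MiniLM vectors)
query_vec = model.encode(["chest pain with hypertension history"],
                         convert_to_numpy=True, normalize_embeddings=True)
res = coll.query(query_embeddings=query_vec.tolist(), n_results=3)
for doc, meta in zip(res["documents"][0], res["metadatas"][0]):
    print(meta["note_id"], meta["chunk_index"], doc[:80])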
main.py
ADDED
@@ -0,0 +1,494 @@
+# --- Imports and page setup ---
+import streamlit as st
+import yaml
+from yaml.loader import SafeLoader
+import streamlit_authenticator as stauth
+import uuid
+import datetime
+from audit import AuditLogger
+
+st.set_page_config(page_title="Clinical Summarizer", layout="wide")
+st.title("HIPAA-compliant Clinical RAG Summarizer (MVP)")
+
+# --- Authentication setup ---
+def load_config():
+    """Load configuration from Streamlit secrets or local YAML"""
+    try:
+        # Check if running on Streamlit Cloud (secrets available)
+        if "credentials" in st.secrets:
+            # Convert immutable Streamlit secrets to mutable dict
+            config = {
+                "credentials": {
+                    "usernames": {}
+                },
+                "cookie": {
+                    "name": str(st.secrets["cookie"]["name"]),
+                    "key": str(st.secrets["cookie"]["key"]),
+                    "expiry_days": int(st.secrets["cookie"]["expiry_days"])
+                }
+            }
+
+            # Convert each user to mutable dict
+            for username, user_data in st.secrets["credentials"]["usernames"].items():
+                config["credentials"]["usernames"][str(username)] = {
+                    "email": str(user_data["email"]),
+                    "failed_login_attempts": int(user_data.get("failed_login_attempts", 0)),
+                    "logged_in": bool(user_data.get("logged_in", False)),
+                    "name": str(user_data["name"]),
+                    "password": str(user_data["password"]),
+                    "role": str(user_data["role"])
+                }
+
+            return config
+        else:
+            # Local development: Load from YAML file
+            with open("app/streamlit_config.yaml") as f:
+                return yaml.load(f, Loader=SafeLoader)
+
+    except FileNotFoundError:
+        st.error("⚠️ Configuration file not found. Please set up authentication.")
+        st.stop()
+    except Exception as e:
+        st.error(f"⚠️ Configuration error: {e}")
+        st.info("Make sure secrets are configured in Streamlit Cloud settings.")
+        st.stop()
+
+# Load config
+config = load_config()
+
+# Create authenticator with mutable config
+authenticator = stauth.Authenticate(
+    config["credentials"],
+    config["cookie"]["name"],
+    config["cookie"]["key"],
+    config["cookie"]["expiry_days"],
+)
+
+# Render the login widget
+authenticator.login(location="sidebar")
+
+# Read values from session_state
+auth_status = st.session_state.get("authentication_status")
+username = st.session_state.get("username")
+name = st.session_state.get("name")
+
+if auth_status is False:
+    st.error("Invalid username or password")
+    st.stop()
+elif auth_status is None:
+    st.info("Please log in")
+    st.stop()
+else:
+    role = config["credentials"]["usernames"][username]["role"]
+    st.session_state["role"] = role
+    with st.sidebar:
+        st.header("Clinical RAG Summarizer")
+        st.markdown("HIPAA-compliant, secure, and easy to use.")
+        st.markdown("---")
+        st.success(f"Logged in as {name}")
+        st.markdown(f"**Role:** {role}")
+        authenticator.logout("Logout", location="sidebar")
+        st.markdown("---")
+        st.info("Use the tabs above to upload notes, generate summaries, and view logs.")
+
+# Clear ChromaDB cache to prevent singleton conflicts
+try:
+    from chromadb.api.client import SharedSystemClient
+    SharedSystemClient.clear_system_cache()
+except:
+    pass
+
+# Generate a unique persist_dir for each session if not already set
+if "persist_dir" not in st.session_state:
+    if st.session_state.get("username"):
+        st.session_state["persist_dir"] = f"./data/vector_store_{st.session_state['username']}"
+    else:
+        unique_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + str(uuid.uuid4())[:8]
+        st.session_state["persist_dir"] = f"./data/vector_store_{unique_id}"
+
+# Initialize the audit logger
+audit_logger = AuditLogger()
+
+# Initialize model cache in session state
+if "t5_model" not in st.session_state:
+    st.session_state["t5_model"] = None
+if "t5_tokenizer" not in st.session_state:
+    st.session_state["t5_tokenizer"] = None
+
+# --- Tabs ---
+upload_tab, summarize_tab, logs_tab = st.tabs(["Upload/Enter Note", "Summarize", "Logs"])
+
+# --- Upload/Enter Note tab ---
+with upload_tab:
+    st.subheader("Enter or Upload Note")
+    st.caption("Paste a synthetic note or upload a .txt file, then de-identify and index.")
+
+    col_upload, col_text = st.columns([1, 2])
+    with col_upload:
+        file = st.file_uploader("Upload .txt file", type=["txt"])
+    with col_text:
+        note_text = st.text_area("Paste note text", height=200, placeholder="Paste clinical note text here...")
+
+    col1, col2 = st.columns(2)
+    with col1:
+        deid_index_clicked = st.button("De-identify & Index", use_container_width=True)
+    with col2:
+        skip_index_clicked = st.button("Skip (already indexed)", use_container_width=True)
+
+    if file and not note_text:
+        note_text = file.read().decode("utf-8", errors="ignore")
+
+    if deid_index_clicked and note_text:
+        try:
+            with st.spinner("De-identifying and indexing..."):
+                from deid_pipeline import DeidPipeline
+                pipeline = DeidPipeline()
+                result = pipeline.run_on_text(text=note_text, note_id="temp_note")
+                deid_text = result["masked_text"]
+                st.success("De-identified.")
+                st.text_area("De-identified preview", deid_text, height=160)
+
+                from indexer import index_note
+                # Use session-specific persist_dir
+                note_id = index_note(
+                    text=deid_text,
+                    note_id=f"note_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}",
+                    persist_dir=st.session_state["persist_dir"],
+                    db_type="chroma",
+                    model_name="sentence-transformers/all-MiniLM-L6-v2",
+                    collection="notes"
+                )
+                st.session_state["last_note_id"] = note_id
+                st.session_state["last_deid_text"] = deid_text
+                st.session_state["last_note_indexed"] = True
+                st.success(f"✓ Indexed note_id: {note_id}")
+                st.info(f"📁 Stored in: {st.session_state['persist_dir']}")
+
+                # Audit log for indexing
+                audit_logger.log_action(
+                    user=st.session_state.get('username', 'anonymous'),
+                    action="INDEX_NOTE",
+                    resource=note_id,
+                    additional_info={
+                        "text_length": len(deid_text),
+                        "persist_dir": st.session_state["persist_dir"]
+                    }
+                )
+
+                # Audit log for de-identification
+                audit_logger.log_action(
+                    user=st.session_state.get('username', 'anonymous'),
+                    action="DEID_PROCESS",
+                    resource="temp_note",
+                    additional_info={"original_length": len(note_text), "deid_length": len(deid_text)}
+                )
+                st.toast("Note indexed and de-identified!", icon="✅")
+        except Exception as e:
+            st.error(f"De-identification error: {e}")
+            import traceback
+            st.code(traceback.format_exc())
+    elif skip_index_clicked and note_text:
+        st.session_state["last_deid_text"] = note_text
+        st.session_state["last_note_indexed"] = False
+        st.info("Skipped indexing; text saved for summarization.")
+
+        # Audit log for skipping
+        audit_logger.log_action(
+            user=st.session_state.get('username', 'anonymous'),
+            action="SKIP_INDEX",
+            resource="temp_note",
+            additional_info={"text_length": len(note_text)}
+        )
+
+    if "last_deid_text" not in st.session_state:
+        st.caption("Tip: click 'De-identify & Index' or 'Skip' to carry text into the Summarize tab.")
+    else:
+        st.write(f"✓ Note text ready: {len(st.session_state['last_deid_text'])} characters")
+        st.write(f"Preview: {st.session_state['last_deid_text'][:100]}...")
+
+# --- Summarize tab ---
+with summarize_tab:
+    st.subheader("Summarize")
+    st.caption("Retrieves context and generates a structured clinical summary.")
+
+    from rag_pipeline import load_embedder, load_chroma, load_faiss_langchain, retrieve
+    from summarizer import make_t5, summarize_docs, validate_summary_quality
+
+    # Environment detection
+    import os
+    IS_CLOUD = os.path.exists('/mount/src')
+    if IS_CLOUD:
+        st.info("🌐 Cloud Mode: Using optimized model (flan-t5-base)")
+
+    # Clear ChromaDB system cache to avoid singleton conflicts
+    try:
+        import chromadb
+        from chromadb.api.client import SharedSystemClient
+        SharedSystemClient.clear_system_cache()
+    except Exception as e:
+        st.warning(f"Could not clear ChromaDB cache: {e}")
+
+    # Show current vector store location
+    st.info(f"📁 Using vector store: {st.session_state['persist_dir']}")
+
+    source_choice = st.radio("Use source:", ["Last de-identified text", "Note ID"], horizontal=True)
+    default_note_id = st.session_state.get("last_note_id", "")
+    user_note_id = st.text_input("Note ID (optional)", value=str(default_note_id))
+
+    # Add method selection
+    method_choice = st.radio("Extraction method:", ["multistage", "singleshot"], horizontal=True,
+                             help="Multistage: Better quality, slower. Singleshot: Faster, may miss details.")
+
+    generate_clicked = st.button("Generate Summary", type="primary", use_container_width=True)
+
+    if generate_clicked:
+        try:
+            with st.spinner("Retrieving context..."):
+                embed_model = "sentence-transformers/all-MiniLM-L6-v2"
+                db_type = "chroma"
+                persist_dir = st.session_state["persist_dir"]
+                collection = "notes"
+                top_k = 5
+
+                # Cache vector database in session state to avoid recreating
+                cache_key = f"vdb_{persist_dir}_{collection}"
+
+                if cache_key not in st.session_state:
+                    st.info("⏳ Loading vector database (first time)...")
+                    _, embeddings = load_embedder(embed_model)
+
+                    # Clear cache before creating new instance
+                    try:
+                        SharedSystemClient.clear_system_cache()
+                    except:
+                        pass
+
+                    if db_type == "chroma":
+                        vdb = load_chroma(persist_dir, collection, embeddings)
+                    else:
+                        vdb = load_faiss_langchain(persist_dir, embeddings)
+
+                    st.session_state[cache_key] = vdb
+                    st.success("✓ Vector database loaded")
+                else:
+                    vdb = st.session_state[cache_key]
+                    st.info("✓ Using cached vector database")
+
+                # Use actual note content for retrieval
+                if source_choice == "Note ID" and user_note_id:
+                    query_text = user_note_id
+                    st.info(f"🔍 Retrieving by Note ID: {user_note_id}")
+                else:
+                    deid_text = st.session_state.get("last_deid_text", "")
+                    if not deid_text:
+                        st.warning("No de-identified text available. Please use the Upload tab first.")
+                        st.stop()
+                    query_text = deid_text[:500]
+                    st.info(f"🔍 Retrieving using note content ({len(deid_text)} chars)")
+
+                docs = retrieve(vdb, query_text, top_k)
+
+                if not docs:
+                    st.error("⚠ No documents retrieved from vector database!")
+                    st.warning("This usually means:")
+                    st.write("• The vector database is empty")
+                    st.write("• The note wasn't properly indexed")
+                    st.write(f"• Check if files exist in: {persist_dir}")
+
+                    if st.button("🔄 Clear cache and retry"):
+                        if cache_key in st.session_state:
+                            del st.session_state[cache_key]
+                        SharedSystemClient.clear_system_cache()
+                        st.rerun()
+                    st.stop()
+
+                st.success(f"✓ Retrieved {len(docs)} document(s)")
+
+                # Show preview of retrieved content
+                with st.expander("View retrieved content"):
+                    for i, doc in enumerate(docs, 1):
+                        st.write(f"**Document {i}:**")
+                        st.code(doc.page_content[:300] + "..." if len(doc.page_content) > 300 else doc.page_content)
+
+            with st.spinner("Generating summary... (this may take 1-2 minutes on CPU)"):
+                # Cache model loading in session state
+                if st.session_state["t5_model"] is None or st.session_state["t5_tokenizer"] is None:
+                    st.info("⏳ Loading T5 model (first time only)...")
+                    tokenizer, model = make_t5("google/flan-t5-base")
+                    st.session_state["t5_tokenizer"] = tokenizer
+                    st.session_state["t5_model"] = model
+                else:
+                    tokenizer = st.session_state["t5_tokenizer"]
+                    model = st.session_state["t5_model"]
+                    st.info("✓ Using cached model")
+
+                # Generate summary
+                summary = summarize_docs(tokenizer, model, docs, method=method_choice)
+
+                # Store summary in session state
+                summary_key = f"summary_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
+                st.session_state["last_summary"] = summary
+                st.session_state["last_summary_key"] = summary_key
+
+                # Validation
+                original_text = st.session_state.get("last_deid_text", "")
+                validation = validate_summary_quality(summary, original_text)
+
+                # Display validation results
+                status_color = {
+                    "GOOD": "🟢",
+                    "FAIR": "🟡",
+                    "POOR": "🟠",
+                    "FAILED": "🔴"
+                }
+
+                st.success("✓ Summary generated successfully")
+
+                # Show quality assessment in two columns
+                col_status, col_score = st.columns([3, 1])
+                with col_status:
+                    st.markdown(f"### {status_color.get(validation['status'], '⚪')} Quality Status: **{validation['status']}**")
+                with col_score:
+                    st.metric("Quality Score", f"{validation['quality_score']}/100")
+
+                # Display critical issues if any
+                if validation['issues']:
+                    st.error("**❌ Critical Issues Detected:**")
+                    for issue in validation['issues']:
+                        st.markdown(f"- {issue}")
+                    st.markdown("**Recommendation:** Review de-identification settings and retrieval quality")
+
+                # Display warnings if any
+                if validation['warnings']:
+                    st.warning("**⚠️ Quality Warnings:**")
+                    for warning in validation['warnings']:
+                        st.markdown(f"- {warning}")
+
+                # Show detailed quality metrics in expandable section
+                with st.expander("📊 Detailed Quality Metrics"):
+                    metric_col1, metric_col2, metric_col3, metric_col4 = st.columns(4)
+                    with metric_col1:
+                        st.metric("PHI Placeholders", validation['metrics']['total_placeholders'])
+                    with metric_col2:
+                        st.metric("Empty Sections", validation['metrics']['empty_sections'])
+                    with metric_col3:
+                        st.metric("Filled Sections", f"{validation['metrics']['filled_sections']}/7")
+                    with metric_col4:
+                        st.metric("Total Length", f"{validation['metrics']['total_length']} chars")
+
+                # Show warning banner if quality is poor
+                if validation['status'] in ['POOR', 'FAILED']:
+                    st.warning("⚠️ **Quality Alert:** The summary below has significant quality issues. Review carefully before clinical use.")
+                elif validation['status'] == 'FAIR':
+                    st.info("ℹ️ The summary has minor quality issues. Review the warnings above.")
+                else:
+                    st.success("✅ Summary quality is acceptable.")
+
+                # Display the summary
+                st.text_area("Structured Summary", summary, height=400, key=f"summary_display_{summary_key}")
+                st.download_button("Download .txt", data=summary, file_name=f"summary_{user_note_id or 'latest'}.txt")
+
+                # Show summary statistics
+                col1, col2, col3 = st.columns(3)
+                with col1:
+                    st.metric("Summary Length", f"{len(summary)} chars")
+                with col2:
+                    st.metric("Documents Retrieved", len(docs))
+                with col3:
+                    sections_filled = 7 - summary.count("None stated")
+                    st.metric("Sections Filled", f"{sections_filled}/7")
+
+                # Audit log for summary generation with validation results
+                audit_logger.log_action(
+                    user=st.session_state.get('username', 'anonymous'),
+                    action="GENERATE_SUMMARY",
+                    resource=user_note_id or "temp_note",
+                    additional_info={
+                        "retrieved_docs": len(docs),
+                        "method": method_choice,
+                        "summary_length": len(summary),
+                        "persist_dir": persist_dir,
+                        "sections_filled": sections_filled,
+                        "quality_status": validation['status'],
+                        "quality_score": validation['quality_score'],
+                        "validation_issues": len(validation['issues']),
+                        "validation_warnings": len(validation['warnings']),
+                        "phi_placeholders": validation['metrics']['total_placeholders']
+                    }
+                )
+
+        except ValueError as ve:
+            if "already exists" in str(ve):
+                st.error("❌ ChromaDB instance conflict detected!")
+                st.warning("This happens when the vector database is accessed with different settings.")
+                st.info("**Solution:** Click the button below to clear the cache and retry.")
+
+                if st.button("🔄 Clear ChromaDB cache and retry", type="primary"):
+                    try:
+                        SharedSystemClient.clear_system_cache()
+                    except:
+                        pass
+
+                    keys_to_delete = [k for k in st.session_state.keys() if k.startswith("vdb_")]
+                    for key in keys_to_delete:
+                        del st.session_state[key]
+
+                    st.success("✓ Cache cleared! Click 'Generate Summary' again.")
+                    st.rerun()
+            else:
+                st.error(f"❌ Error during summarization: {ve}")
+                import traceback
+                st.code(traceback.format_exc())
+        except Exception as e:
+            st.error(f"❌ Error during summarization: {e}")
+            import traceback
+            st.code(traceback.format_exc())
+
+    # Show last summary if available (when button not clicked)
+    elif "last_summary" in st.session_state:
+        st.info("Showing last generated summary:")
+        st.text_area("Last Summary", st.session_state["last_summary"], height=400)
+        st.download_button("Download Last Summary",
+                           data=st.session_state["last_summary"],
+                           file_name="last_summary.txt")
+
+# --- Logs tab ---
+with logs_tab:
+    st.subheader("Logs")
+    if st.session_state.get("role") != "admin":
+        st.info("Admins only.")
+    else:
+        st.caption("Audit logs for all user actions.")
+
+        # Audit log for viewing logs
+        audit_logger.log_action(
+            user=st.session_state.get('username', 'anonymous'),
+            action="VIEW_LOGS",
+            resource="app_audit.jsonl"
+        )
+
+        # Add log filtering
+        col1, col2 = st.columns([3, 1])
+        with col1:
+            filter_action = st.selectbox("Filter by action:",
+                                         ["All", "INDEX_NOTE", "GENERATE_SUMMARY", "DEID_PROCESS", "VIEW_LOGS"])
+        with col2:
+            num_lines = st.number_input("Show last N lines:", min_value=10, max_value=500, value=50)
+
+        try:
+            import json
+            with open("logs/app_audit.jsonl") as f:
+                lines = f.readlines()[-num_lines:]
+
+            st.write(f"Showing {len(lines)} most recent log entries:")
+
+            for line in lines:
+                try:
+                    log_entry = json.loads(line.strip())
+                    if filter_action == "All" or log_entry.get("action") == filter_action:
+                        with st.expander(f"{log_entry.get('timestamp', 'N/A')} - {log_entry.get('action', 'N/A')}"):
+                            st.json(log_entry)
+                except json.JSONDecodeError:
+                    st.code(line.strip())
+        except FileNotFoundError:
+            st.warning("No logs found yet. Logs will appear after you perform actions.")
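Editor's note: for reference, a minimal sketch (illustrative values only, not part of the commit) of the config shape that load_config() must return. Both the Streamlit Cloud secrets and the local app/streamlit_config.yaml have to map onto this structure before it is passed to stauth.Authenticate; the username, email, hash, and cookie key below are placeholders.

config = {
    "credentials": {
        "usernames": {
            "dr_smith": {
                "email": "dr_smith@example.com",
                "failed_login_attempts": 0,
                "logged_in": False,
                "name": "Dr. Smith",
                "password": "<bcrypt hash produced by bcrypt_pw.py>",
                "role": "admin",
            }
        }
    },
    "cookie": {"name": "clinical_rag_auth", "key": "<random signing key>", "expiry_days": 7},
}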
notes.py
ADDED
@@ -0,0 +1,47 @@
from langchain_chroma import Chroma
from sentence_transformers import SentenceTransformer
from langchain.embeddings.base import Embeddings

# 1. Wrap SentenceTransformer in a LangChain-compatible class
class STEmbeddings(Embeddings):
    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        return self.model.encode(texts, convert_to_numpy=True, normalize_embeddings=True).tolist()

    def embed_query(self, text):
        return self.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)[0].tolist()

# 2. Instantiate embeddings
embeddings = STEmbeddings()

# 3. Create or load Chroma collection
db = Chroma(
    collection_name="notes",
    persist_directory="./data/vector_store",
    embedding_function=embeddings
)

# 4. Add some sample texts
texts = [
    "Patient presents with chest pain for 2 days.",
    "History of hypertension and diabetes.",
    "Currently taking metformin and lisinopril.",
    "No known drug allergies.",
    "Plan: schedule ECG and follow-up in 1 week."
]

metadatas = [
    {"note_id": "1", "section": "HPI", "chunk_index": 0},
    {"note_id": "1", "section": "PMH", "chunk_index": 0},
    {"note_id": "1", "section": "Medications", "chunk_index": 0},
    {"note_id": "1", "section": "Allergies", "chunk_index": 0},
    {"note_id": "1", "section": "Plan", "chunk_index": 0},
]

db.add_texts(texts=texts, metadatas=metadatas)

# 5. Persist to disk (langchain_chroma persists automatically when persist_directory is set)

print("Ingestion complete. Collection 'notes' is ready.")
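
A quick retrieval check against the store built above (a sketch, not in this commit; it assumes notes.py has been run once so ./data/vector_store contains the "notes" collection):

# Retrieval smoke test for the "notes" collection created by notes.py.
from langchain_chroma import Chroma
from langchain.embeddings.base import Embeddings
from sentence_transformers import SentenceTransformer

class _MiniLM(Embeddings):
    def __init__(self):
        self.m = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    def embed_documents(self, texts):
        return self.m.encode(texts, normalize_embeddings=True).tolist()
    def embed_query(self, text):
        return self.m.encode([text], normalize_embeddings=True)[0].tolist()

db = Chroma(collection_name="notes", persist_directory="./data/vector_store",
            embedding_function=_MiniLM())
for doc in db.similarity_search("diabetes medications", k=2):
    print(doc.metadata.get("section"), "->", doc.page_content)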
quick_check_chroma.py
ADDED
@@ -0,0 +1,20 @@
# quick_check_chroma.py
import chromadb
from chromadb.config import Settings as ChromaSettings

persist_dir = "./data/vector_store"
collection_name = "notes"

client = chromadb.PersistentClient(path=persist_dir, settings=ChromaSettings())
coll = client.get_collection(collection_name)

query = "Type 2 diabetes management plan with metformin"
res = coll.query(
    query_texts=[query],
    n_results=3,
)

for i, doc in enumerate(res["documents"][0]):
    print(f"\nTop {i+1} doc:")
    print(doc)
    print("Meta:", res["metadatas"][0][i])
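
Because the "notes" collection was written through LangChain with its own MiniLM wrapper, query_texts above falls back to Chroma's default embedding function. A variant (a sketch, not in this commit) that passes precomputed query embeddings from the same MiniLM model instead:

# Query with explicit embeddings so the query vector matches the stored vectors.
import chromadb
from sentence_transformers import SentenceTransformer

client = chromadb.PersistentClient(path="./data/vector_store")
coll = client.get_collection("notes")

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
vec = model.encode(["Type 2 diabetes management plan with metformin"],
                   normalize_embeddings=True).tolist()

res = coll.query(query_embeddings=vec, n_results=3)
for doc, meta in zip(res["documents"][0], res["metadatas"][0]):
    print(doc, meta)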
rag_pipeline.py
ADDED
@@ -0,0 +1,117 @@
# app/rag_pipeline.py
# Day 7: Retriever + RAG baseline (retrieval only; generation comes on Day 8)
# Example usage:
# python app/rag_pipeline.py --db_type chroma --persist_dir ./data/vector_store --collection notes --query "Summarize into HPI/Assessment/Plan" --top_k 5
# python app/rag_pipeline.py --db_type faiss --persist_dir ./data/vector_store_faiss --query "Extract Assessment and Plan" --top_k 5

import os
import argparse
import pickle
from typing import List, Dict
import uuid
import datetime
import shutil

from sentence_transformers import SentenceTransformer
import numpy as np

# LangChain vector store wrappers
from langchain_community.vectorstores import Chroma, FAISS
from langchain_core.documents import Document

# For FAISS manual load if using custom persisted index
import faiss
from chromadb.config import Settings as ChromaSettings

def load_embedder(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    model = SentenceTransformer(model_name)
    def embed_f(texts: List[str]) -> List[List[float]]:
        vecs = model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
        return vecs.tolist()
    return model, embed_f

def load_chroma(persist_dir: str, collection: str, embed_f):
    from langchain.embeddings.base import Embeddings
    class STEmbeddings(Embeddings):
        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            return embed_f(texts)
        def embed_query(self, text: str) -> List[float]:
            return embed_f([text])[0]

    embeddings = STEmbeddings()
    vectordb = Chroma(
        collection_name=collection,
        persist_directory=persist_dir,
        embedding_function=embeddings
    )
    return vectordb

def load_faiss_langchain(persist_dir: str, embed_f):
    # If Day 6 saved FAISS with LangChain's FAISS.save_local, we can do:
    # return FAISS.load_local(persist_dir, embeddings, allow_dangerous_deserialization=True)
    # But Day 6 saved raw FAISS + meta.pkl; handle that manually and wrap.
    from langchain.embeddings.base import Embeddings
    class STEmbeddings(Embeddings):
        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            return embed_f(texts)
        def embed_query(self, text: str) -> List[float]:
            return embed_f([text])[0]
    embeddings = STEmbeddings()

    index_path = os.path.join(persist_dir, "index.faiss")
    meta_path = os.path.join(persist_dir, "meta.pkl")
    if not (os.path.exists(index_path) and os.path.exists(meta_path)):
        raise FileNotFoundError(f"FAISS files not found in {persist_dir}")

    index = faiss.read_index(index_path)
    with open(meta_path, "rb") as f:
        meta = pickle.load(f)

    # Build FAISS VectorStore from texts + metadata to leverage LC retriever
    texts = [m["text"] for m in meta]
    metadatas = [m["meta"] | {"id": m["id"]} for m in meta]
    vectordb = FAISS.from_texts(texts=texts, embedding=embeddings, metadatas=metadatas)
    # Replace the underlying index with prebuilt (saves re-embedding cost when querying)
    vectordb.index = index
    return vectordb

def retrieve(vdb, query: str, top_k: int = 5):
    retriever = vdb.as_retriever(search_kwargs={"k": top_k})
    docs: List[Document] = retriever.invoke(query)
    return docs

def format_context(docs: List[Document]) -> str:
    parts = []
    for i, d in enumerate(docs, 1):
        md = d.metadata or {}
        parts.append(f"[{i}] note_id={md.get('note_id')} section={md.get('section')} chunk_idx={md.get('chunk_index')}\n{d.page_content}")
    return "\n\n---\n\n".join(parts)

def main():
    parser = argparse.ArgumentParser(description="Day 7: Retriever + RAG baseline (retrieval only).")
    parser.add_argument("--db_type", choices=["chroma", "faiss"], default="chroma")
    parser.add_argument("--persist_dir", default="./data/vector_store")
    parser.add_argument("--collection", default="notes")
    parser.add_argument("--model_name", default="sentence-transformers/all-MiniLM-L6-v2")
    parser.add_argument("--query", required=True)
    parser.add_argument("--top_k", type=int, default=5)
    args = parser.parse_args()

    # Sure shot fix: Remove existing persist_dir if it exists
    if args.db_type == "chroma" and os.path.exists(args.persist_dir):
        shutil.rmtree(args.persist_dir)

    _, embed_f = load_embedder(args.model_name)

    if args.db_type == "chroma":
        vectordb = load_chroma(args.persist_dir, args.collection, embed_f)
    else:
        vectordb = load_faiss_langchain(args.persist_dir, embed_f)

    docs = retrieve(vectordb, args.query, args.top_k)
    context = format_context(docs)
    print("\n=== Retrieved Context (to feed Day 8 summarizer) ===\n")
    print(context)

if __name__ == "__main__":
    main()
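
retriever_context.py and run_pipeline.py below import retrieve_context from this module, which does not appear in the file as committed, and main() removes the Chroma persist_dir before querying it. A hypothetical retrieve_context helper, sketched only from load_embedder/load_chroma/retrieve/format_context above and loading the existing store instead of recreating it, could look like:

# Sketch of the retrieve_context() helper assumed by retriever_context.py and
# run_pipeline.py; would be appended to rag_pipeline.py. Names and defaults are
# taken from the functions above; the helper itself is an assumption.
def retrieve_context(query: str, top_k: int = 5,
                     persist_dir: str = "./data/vector_store",
                     collection: str = "notes") -> str:
    _, embed_f = load_embedder()
    vectordb = load_chroma(persist_dir, collection, embed_f)
    docs = retrieve(vectordb, query, top_k=top_k)
    return format_context(docs)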
retriever_context.py
ADDED
@@ -0,0 +1,7 @@
from rag_pipeline import retrieve_context
from summarizer import generate_summary

query = "Summarize into HPI/Assessment/Plan"
retrieved_text = retrieve_context(query, top_k=5)
summary = generate_summary(retrieved_text)
print(summary)
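
summarizer.py exposes load_embedder, retrieve_docs, make_t5 and summarize_docs rather than a generate_summary function, so the same end-to-end flow can also be written against those names (a sketch, not in this commit):

# End-to-end retrieval + summarization using the API summarizer.py actually defines.
from summarizer import load_embedder, retrieve_docs, make_t5, summarize_docs

embed_f = load_embedder()
docs = retrieve_docs("chroma", "./data/vector_store", "notes",
                     "Summarize into HPI/Assessment/Plan", 5, embed_f)
tokenizer, model = make_t5("google/flan-t5-small")
print(summarize_docs(tokenizer, model, docs))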
run_pipeline.py
ADDED
@@ -0,0 +1,51 @@
# run_pipeline.py
from rag_pipeline import retrieve_context  # <-- your Day 7 retriever
from transformers import pipeline

# 1. Load a summarization model
# Option A: summarization-tuned model (recommended for clean summaries)
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

# Option B: instruction-tuned model (if you want to experiment with prompts)
# summarizer = pipeline("text2text-generation", model="google/flan-t5-base")

# 2. Define a function to generate structured summary
def generate_summary(retrieved_text: str):
    # For BART summarizer (Option A)
    result = summarizer(retrieved_text, max_length=250, min_length=80, do_sample=False)
    return result[0]['summary_text']

    # If using Flan-T5 (Option B), uncomment this instead:
    """
    prompt = f'''
    You are a clinical summarization assistant.
    Use ONLY the provided context to create a structured summary.
    Do not invent information.

    Context:
    {retrieved_text}

    Write the output in this exact format:
    Chief Complaint: ...
    HPI: ...
    PMH: ...
    Medications: ...
    Allergies: ...
    Assessment: ...
    Plan: ...
    '''
    result = summarizer(prompt, max_new_tokens=300, do_sample=False)
    return result[0]['generated_text']
    """

# 3. Main execution
if __name__ == "__main__":
    query = "Summarize into HPI/Assessment/Plan"
    # Get top 5 relevant chunks from your vector store
    retrieved_text = retrieve_context(query, top_k=5)

    print("=== Retrieved Context ===")
    print(retrieved_text)
    print("\n=== Structured Clinical Summary ===")
    summary = generate_summary(retrieved_text)
    print(summary)
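
facebook/bart-large-cnn accepts roughly 1024 input tokens, so long retrieved contexts can exceed Option A's input window. A small guard (a sketch, not in this commit) that truncates with the pipeline's own tokenizer before summarizing:

# Truncate over-long retrieved contexts to BART's input window before Option A.
def truncate_for_bart(text: str, max_tokens: int = 1000) -> str:
    tok = summarizer.tokenizer
    ids = tok.encode(text, truncation=True, max_length=max_tokens)
    return tok.decode(ids, skip_special_tokens=True)

# usage: summary = generate_summary(truncate_for_bart(retrieved_text))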
streamlit_config.yaml
ADDED
@@ -0,0 +1,17 @@
credentials:
  usernames:
    clinician1:
      email: [email protected]
      name: Clinician One
      password: "$2b$12$r3uqzaknAfUAsMEVIKTR2eN8yuPxu8d8YJWmPOrvNKwK.K94sjl1W"
      role: clinician
    admin1:
      email: [email protected]
      name: Admin One
      password: "$2b$12$r3uqzaknAfUAsMEVIKTR2eN8yuPxu8d8YJWmPOrvNKwK.K94sjl1W"
      role: admin

cookie:
  expiry_days: 1
  key: some_random_secret
  name: auth_cookie
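
The password values appear to be bcrypt hashes in the credentials/cookie layout used by streamlit-authenticator, and some_random_secret is a placeholder cookie secret to replace before deployment. A sketch (not in this commit, assuming the bcrypt package) for producing a new hash to paste into a password field:

# Generate a bcrypt hash for the "password" field above.
import bcrypt

new_hash = bcrypt.hashpw(b"choose-a-strong-password", bcrypt.gensalt(rounds=12))
print(new_hash.decode())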
summarizer.py
ADDED
@@ -0,0 +1,598 @@
# app/summarizer.py
# Day 10: Enhanced HIPAA-compliant RAG clinical summarizer with robustness improvements
# Critical fixes:
# - Added progress indicators during model generation
# - Implemented timeout mechanism for long-running operations
# - Optimized for CPU with reduced generation parameters
# - Better error handling and verbose logging
# - Fallback to smaller max tokens if generation hangs

import os
import argparse
import traceback
from typing import List, Dict, Optional
import re
import time
import sys

from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import Chroma, FAISS
from langchain_core.documents import Document

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# -----------------------------
# Embeddings / Vector stores
# -----------------------------
def load_embedder(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """
    Load sentence transformer for embeddings.
    For medical domain: consider "emilyalsentzer/Bio_ClinicalBERT" or similar
    """
    print(f" → Loading embedding model...")
    model = SentenceTransformer(model_name)
    def embed_f(texts: List[str]):
        vecs = model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
        return vecs.tolist()
    print(f" ✓ Embedding model loaded")
    return embed_f

def load_chroma(persist_dir: str, collection: str, embed_f):
    from langchain.embeddings.base import Embeddings
    class STEmbeddings(Embeddings):
        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            return embed_f(texts)
        def embed_query(self, text: str) -> List[float]:
            return embed_f([text])[0]
    embeddings = STEmbeddings()
    print(f" → Loading Chroma vector store from {persist_dir}...")
    return Chroma(collection_name=collection, persist_directory=persist_dir, embedding_function=embeddings)

def load_faiss(persist_dir: str, embed_f):
    import pickle, faiss
    from langchain.embeddings.base import Embeddings
    class STEmbeddings(Embeddings):
        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            return embed_f(texts)
        def embed_query(self, text: str) -> List[float]:
            return embed_f([text])[0]
    embeddings = STEmbeddings()
    index_path = os.path.join(persist_dir, "index.faiss")
    meta_path = os.path.join(persist_dir, "meta.pkl")
    if not (os.path.exists(index_path) and os.path.exists(meta_path)):
        raise FileNotFoundError(f"FAISS files not found in {persist_dir}")
    print(f" → Loading FAISS index from {persist_dir}...")
    with open(meta_path, "rb") as f:
        meta = pickle.load(f)
    texts = [m["text"] for m in meta]
    metadatas = [m["meta"] | {"id": m["id"]} for m in meta]
    vdb = FAISS.from_texts(texts=texts, embedding=embeddings, metadatas=metadatas)
    vdb.index = faiss.read_index(index_path)
    return vdb

def retrieve_docs(db_type: str, persist_dir: str, collection: str, query: str, top_k: int, embed_f) -> List[Document]:
    if db_type == "chroma":
        vdb = load_chroma(persist_dir, collection, embed_f)
    else:
        vdb = load_faiss(persist_dir, embed_f)

    print(f" → Retrieving documents...")
    retriever = vdb.as_retriever(search_kwargs={"k": top_k})
    docs: List[Document] = retriever.invoke(query)
    print(f" ✓ Retrieved {len(docs)} document(s)")

    # Debug: Show retrieved content length
    if docs:
        total_chars = sum(len(d.page_content) for d in docs)
        print(f" ℹ Total retrieved content: {total_chars} characters")
    else:
        print(f" ⚠ WARNING: No documents retrieved!")

    return docs

# -----------------------------
# T5 Summarization utilities
# -----------------------------
def make_t5(model_name="google/flan-t5-base", device="cpu"):
    print(f" → Loading T5 model: {model_name}")
    print(f" ℹ This may take 30-60 seconds for large models...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    print(f" ✓ Model loaded successfully")
    return tokenizer, model

def t5_generate(tokenizer, model, prompt: str, max_input_tokens: int = 512, max_output_tokens: int = 256, section_name: str = ""):
    """
    Enhanced generation with progress indicators and optimized parameters for CPU
    """
    # Show progress
    if section_name:
        print(f" → Generating {section_name}...", end='', flush=True)
    else:
        print(f" → Generating summary...", end='', flush=True)

    start_time = time.time()

    try:
        inputs = tokenizer(prompt, truncation=True, max_length=max_input_tokens, return_tensors="pt")
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Optimized parameters for CPU performance
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_output_tokens,
            min_length=10,  # Reduced minimum to avoid forcing long outputs
            num_beams=2,  # Reduced from 4 for faster CPU generation
            length_penalty=1.0,  # Reduced from 1.5
            no_repeat_ngram_size=3,
            early_stopping=True,  # Re-enabled for faster completion
            do_sample=False  # Deterministic generation
        )

        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        elapsed = time.time() - start_time
        print(f" done ({elapsed:.1f}s)")

        return result
    except Exception as e:
        elapsed = time.time() - start_time
        print(f" FAILED ({elapsed:.1f}s)")
        print(f" ✗ Error: {str(e)}")
        return ""

def dedupe_texts(texts: List[str]) -> List[str]:
    seen = set()
    uniq = []
    for t in texts:
        key = " ".join(t.split())[:500]
        if key not in seen:
            seen.add(key)
            uniq.append(t)
    return uniq

# -----------------------------
# Section definitions
# -----------------------------
SECTION_ORDER = [
    "Chief Complaint",
    "HPI",
    "PMH",
    "Medications",
    "Allergies",
    "Assessment",
    "Plan",
]

# -----------------------------
# Multi-stage extraction prompts (optimized for T5)
# -----------------------------
SECTION_PROMPTS = {
    "Chief Complaint": """Task: Extract the main reason for patient visit.

Clinical Note:
{context}

Answer with only the chief complaint (1-2 sentences):""",

    "HPI": """Task: Extract the history of present illness including symptom onset, progression, and context.

Clinical Note:
{context}

Answer with the history of present illness:""",

    "PMH": """Task: Extract past medical history including chronic conditions, past surgeries, and social history.

Clinical Note:
{context}

Answer with past medical history:""",

    "Medications": """Task: List all medications with dosages mentioned in the note.

Clinical Note:
{context}

Answer with medication list:""",

    "Allergies": """Task: Extract drug allergies. If none mentioned, state "No known drug allergies".

Clinical Note:
{context}

Answer with allergies:""",

    "Assessment": """Task: Extract diagnosis, test results, physical findings, and vital signs.

Clinical Note:
{context}

Answer with assessment and findings:""",

    "Plan": """Task: Extract treatment plan, medications prescribed, follow-up appointments, and discharge instructions.

Clinical Note:
{context}

Answer with treatment plan:"""
}

# -----------------------------
# Enhanced extraction pipeline
# -----------------------------
def extract_section_multistage(tokenizer, model, context: str, section: str) -> str:
    """
    Extract a single section using targeted prompting
    """
    if section not in SECTION_PROMPTS:
        return "None stated"

    # Truncate context if too long
    max_context_chars = 2000
    if len(context) > max_context_chars:
        context = context[:max_context_chars] + "..."

    prompt = SECTION_PROMPTS[section].format(context=context)

    try:
        result = t5_generate(tokenizer, model, prompt, max_input_tokens=512, max_output_tokens=200, section_name=section)
        result = result.strip()

        # Remove any section headers the model might have added
        result = re.sub(r'^(Chief Complaint|HPI|PMH|Medications|Allergies|Assessment|Plan)\s*:\s*', '', result, flags=re.IGNORECASE)

        # Check if extraction failed
        if not result or len(result) < 5 or result.lower() in ["none", "none stated", "not mentioned", "n/a", "na"]:
            return "None stated"

        return result.strip()
    except Exception as e:
        print(f" ✗ Error extracting {section}: {str(e)}")
        return "None stated"

def validate_extraction(sections: Dict[str, str]) -> bool:
    """
    Validate that extraction was successful (not all 'None stated')
    """
    non_empty = sum(1 for v in sections.values() if v and v != "None stated")
    return non_empty >= 2  # At least 2 sections should have content

def summarize_docs_multistage(tokenizer, model, docs: List[Document]) -> str:
    """
    Multi-stage extraction: extract each section independently
    """
    print(f"\n📄 Processing documents...")
    contents = dedupe_texts([d.page_content for d in docs if d and d.page_content])

    if not contents:
        print(" ⚠ No content to summarize!")
        return format_output({sec: "None stated" for sec in SECTION_ORDER})

    # Combine all retrieved content
    full_context = "\n\n".join(contents)
    print(f" ℹ Combined context length: {len(full_context)} characters")

    # Extract each section independently
    print(f"\n🔄 Extracting sections (this may take 1-3 minutes on CPU)...")
    sections = {}
    for i, section in enumerate(SECTION_ORDER, 1):
        print(f" [{i}/{len(SECTION_ORDER)}] {section}:")
        sections[section] = extract_section_multistage(tokenizer, model, full_context, section)

    # Validate extraction
    print(f"\n✓ Extraction complete")
    if not validate_extraction(sections):
        print("⚠ WARNING: Extraction appears incomplete. Most sections are empty.")
        print(" Possible issues:")
        print(" • Vector retrieval may not be finding relevant content")
        print(" • Model may not understand the clinical text format")
        print(" • Context may be too short or fragmented")
        print(" • De-identification artifacts may be confusing the model")

    return format_output(sections)

def format_output(sections: Dict[str, str]) -> str:
    """
    Format sections into structured output
    """
    output_lines = []
    for section in SECTION_ORDER:
        content = sections.get(section, "None stated")
        output_lines.append(f"• {section}: {content}")

    return "\n".join(output_lines)

# -----------------------------
# Summary Quality Validation
# -----------------------------
def validate_summary_quality(summary: str, original_text: str = "") -> dict:
    """
    Validate summary quality and detect common issues

    Args:
        summary: The generated summary text
        original_text: Optional original note text for comparison

    Returns:
        Dictionary with validation results
    """
    issues = []
    warnings = []

    # Check for placeholder contamination (de-ID over-redaction)
    placeholder_patterns = [
        (r'\[LOCATION\]', 'LOCATION'),
        (r'\[DATE\]', 'DATE'),
        (r'\[NAME\]', 'NAME'),
        (r'\[PHONE\]', 'PHONE')
    ]

    total_placeholders = 0
    for pattern, name in placeholder_patterns:
        count = len(re.findall(pattern, summary))
        total_placeholders += count
        if count > 2:
            warnings.append(f"Too many [{name}] placeholders ({count}) - de-identification may be over-aggressive")

    if total_placeholders > 5:
        issues.append(f"Critical: {total_placeholders} PHI placeholders in summary - clinical content lost")

    # Check for "None stated" sections
    none_count = summary.count("None stated")
    if none_count >= 5:
        issues.append(f"Critical: {none_count}/7 sections are empty - summarization failed")
    elif none_count >= 3:
        warnings.append(f"Warning: {none_count}/7 sections are empty - may need better retrieval")

    # Check for minimum content length per section
    total_length = len(summary)
    # Subtract bullets and "None stated" overhead
    content_length = total_length - (summary.count("•") * 2) - (none_count * 11)
    filled_sections = 7 - none_count

    if filled_sections > 0:
        avg_section_length = content_length / filled_sections
        if avg_section_length < 30:
            warnings.append(f"Warning: Sections too short (avg {avg_section_length:.0f} chars) - may lack detail")

    # Check for duplicate medications
    if "Medications:" in summary:
        meds_section = summary.split("Medications:")[1].split("•")[0] if "Medications:" in summary else ""
        meds_lower = meds_section.lower()
        common_meds = ['atorvastatin', 'metoprolol', 'lisinopril', 'aspirin', 'metformin']
        for med in common_meds:
            if meds_lower.count(med) > 1:
                warnings.append(f"Warning: Duplicate medication detected: {med}")

    # Calculate quality score (0-100)
    score = 100
    score -= len(issues) * 30  # Critical issues: -30 each
    score -= len(warnings) * 10  # Warnings: -10 each
    score = max(0, min(100, score))

    # Determine overall status
    if len(issues) > 0:
        status = "FAILED"
    elif len(warnings) > 2:
        status = "POOR"
    elif len(warnings) > 0:
        status = "FAIR"
    else:
        status = "GOOD"

    return {
        "is_valid": len(issues) == 0,
        "status": status,
        "quality_score": score,
        "issues": issues,
        "warnings": warnings,
        "metrics": {
            "total_placeholders": total_placeholders,
            "empty_sections": none_count,
            "filled_sections": filled_sections,
            "total_length": total_length
        }
    }

# -----------------------------
# Single-shot extraction (simplified fallback)
# -----------------------------
def summarize_docs_singleshot(tokenizer, model, docs: List[Document]) -> str:
    """
    Single-shot extraction method (faster but less comprehensive)
    """
    print(f"\n📄 Processing documents...")
    contents = dedupe_texts([d.page_content for d in docs if d and d.page_content])

    if not contents:
        print(" ⚠ No content to summarize!")
        return format_output({sec: "None stated" for sec in SECTION_ORDER})

    raw_context = "\n\n".join(contents)
    print(f" ℹ Combined context length: {len(raw_context)} characters")

    # Simplified prompt for single-shot
    instruction = """Summarize this clinical note into 7 sections:
1. Chief Complaint (main reason for visit)
2. HPI (symptom history and progression)
3. PMH (past medical history)
4. Medications (current medications with doses)
5. Allergies (drug allergies)
6. Assessment (diagnosis and findings)
7. Plan (treatment plan and follow-up)

Clinical Note:
{context}

Structured Summary:"""

    print(f"\n🔄 Generating structured summary...")
    prompt = instruction.format(context=raw_context[:2000])  # Limit context
    model_out = t5_generate(tokenizer, model, prompt, max_input_tokens=512, max_output_tokens=400)

    # Parse output into sections
    sections = parse_output_to_sections(model_out)

    return format_output(sections)

def parse_output_to_sections(text: str) -> Dict[str, str]:
    """
    Parse model output into section dictionary
    """
    sections = {}
    current_section = None
    current_content = []

    for line in text.split('\n'):
        line = line.strip()
        if not line:
            continue

        # Check if line starts with a section header
        matched_section = None
        for section in SECTION_ORDER:
            # Match section headers with numbers or bullets
            pattern = rf'^(\d+\.\s*)?{re.escape(section)}\s*:?'
            if re.match(pattern, line, re.IGNORECASE):
                matched_section = section
                break

        if matched_section:
            # Save previous section
            if current_section:
                sections[current_section] = " ".join(current_content).strip()

            # Start new section
            current_section = matched_section
            # Get content after the header
            content = re.sub(rf'^(\d+\.\s*)?{re.escape(matched_section)}\s*:?\s*', '', line, flags=re.IGNORECASE).strip()
            current_content = [content] if content else []
        else:
            # Continue current section
            if current_section:
                current_content.append(line)

    # Save last section
    if current_section:
        sections[current_section] = " ".join(current_content).strip()

    # Fill in missing sections
    for section in SECTION_ORDER:
        if section not in sections or not sections[section]:
            sections[section] = "None stated"

    return sections

# -----------------------------
# Backward compatibility wrapper for Streamlit integration
# -----------------------------
def summarize_docs(tokenizer, model, docs: List[Document], method: str = "multistage") -> str:
    """
    Wrapper function for backward compatibility with main.py (Streamlit UI)

    Args:
        tokenizer: T5 tokenizer instance
        model: T5 model instance
        docs: List of retrieved documents
        method: "multistage" (default) or "singleshot" extraction method

    Returns:
        Formatted summary string with sections
    """
    if method == "multistage":
        return summarize_docs_multistage(tokenizer, model, docs)
    else:
        return summarize_docs_singleshot(tokenizer, model, docs)

# -----------------------------
# Orchestration
# -----------------------------
def main():
    parser = argparse.ArgumentParser(description="Day 10: Enhanced HIPAA-compliant RAG clinical summarizer")
    parser.add_argument("--db_type", choices=["chroma", "faiss"], default="chroma")
    parser.add_argument("--persist_dir", default="./data/vector_store")
    parser.add_argument("--collection", default="notes")
    parser.add_argument("--embed_model", default="sentence-transformers/all-MiniLM-L6-v2")
    parser.add_argument("--model_name", default="google/flan-t5-small")
    parser.add_argument("--query", required=True)
    parser.add_argument("--top_k", type=int, default=5)
    parser.add_argument("--out", default="./data/outputs/summaries/summary.txt")
    parser.add_argument("--method", choices=["multistage", "singleshot"], default="multistage",
                        help="Extraction method: multistage (recommended) or singleshot (faster)")
    args = parser.parse_args()

    print("=" * 70)
    print(" HIPAA-COMPLIANT RAG CLINICAL SUMMARIZER")
    print("=" * 70)

    out_dir = os.path.dirname(args.out) or "."
    os.makedirs(out_dir, exist_ok=True)

    try:
        # Step 1: Load embedder
        print(f"\n[1/4] LOADING EMBEDDER")
        print(f" Model: {args.embed_model}")
        embed_f = load_embedder(args.embed_model)

        # Step 2: Retrieve documents
        print(f"\n[2/4] RETRIEVING DOCUMENTS")
        print(f" Database: {args.db_type}")
        print(f" Location: {args.persist_dir}")
        print(f" Query: {args.query}")
        print(f" Top-K: {args.top_k}")
        docs = retrieve_docs(args.db_type, args.persist_dir, args.collection, args.query, args.top_k, embed_f)

        if not docs:
            print("\n⚠ ERROR: No documents retrieved from vector database!")
            print(" Possible causes:")
            print(" • Vector database is empty or not properly indexed")
            print(" • Query doesn't match indexed content")
            print(" • Database path is incorrect")
            result = format_output({sec: "None stated" for sec in SECTION_ORDER})
            with open(args.out, "w", encoding="utf-8") as f:
                f.write(result)
            print(f"\n✓ Empty summary written to {args.out}")
            return

        # Step 3: Load summarization model
        print(f"\n[3/4] LOADING SUMMARIZATION MODEL")
        print(f" Model: {args.model_name}")
        tokenizer, model = make_t5(args.model_name)

        # Step 4: Generate summary
        print(f"\n[4/4] GENERATING SUMMARY")
        print(f" Method: {args.method}")

        if args.method == "multistage":
            summary = summarize_docs_multistage(tokenizer, model, docs)
        else:
            summary = summarize_docs_singleshot(tokenizer, model, docs)

        # Write summary to output file
        with open(args.out, "w", encoding="utf-8") as f:
            f.write(summary)

        print(f"\n{'=' * 70}")
        print(f"✓ SUCCESS: Summary written to {args.out}")
        print(f"{'=' * 70}")
        print("\nGenerated Summary:")
        print("-" * 70)
        print(summary)
        print("-" * 70)

    except Exception as e:
        err = traceback.format_exc()
        error_msg = f"ERROR during summarization:\n{err}"

        # Write error to file
        with open(args.out, "w", encoding="utf-8") as f:
            f.write(error_msg)

        print(f"\n{'=' * 70}")
        print(f"✗ ERROR: An error occurred during processing")
        print(f"{'=' * 70}")
        print(f"\n{err}")
        print(f"\nError details written to {args.out}")

if __name__ == "__main__":
    main()
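
validate_summary_quality is defined above but not called from main(); a sketch (not in this commit) of scoring a generated summary with it, using the bullet format produced by format_output:

# Example: score a summary string with validate_summary_quality().
from summarizer import validate_summary_quality

summary = (
    "• Chief Complaint: Chest pain for 2 days\n"
    "• HPI: None stated\n"
    "• PMH: Hypertension, diabetes\n"
    "• Medications: Metformin, lisinopril\n"
    "• Allergies: No known drug allergies\n"
    "• Assessment: None stated\n"
    "• Plan: ECG and follow-up in 1 week"
)
report = validate_summary_quality(summary)
print(report["status"], report["quality_score"])
for msg in report["issues"] + report["warnings"]:
    print(" -", msg)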