# my-streamlit-app / main.py
# samarth09healthPM's picture
# Update main.py
# c33edbc verified
import streamlit as st
import os
import warnings
import sys
import re
from pathlib import Path
import subprocess
import torch
# Fix torch.classes path error for Streamlit compatibility
torch.classes.__path__ = []

# HF Spaces environment variables
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"

# Silence deprecation/user warnings for the whole process.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

st.set_page_config(
    page_title="Clinical AI Summarizer",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.title("🏥 HIPAA-Compliant RAG Clinical Summarizer")
st.markdown("**De-identification → Clinical Summarization → Quality Assessment**")

# Global configuration
secure_dir = "./secure_store"  # where the encrypted span map is written
model_name = "google/flan-t5-xl"

# Ensure directories exist (parents=True so a nested secure_dir also works)
Path(secure_dir).mkdir(parents=True, exist_ok=True)
# ==================== SIDEBAR ====================
with st.sidebar:
    st.header("System Status")

    # The dedicated de-identification module is optional; when it cannot be
    # imported the app falls back to the regex-based `fallback_deid`.
    try:
        from deid_pipeline import DeidPipeline
        st.success("✓ De-identification module")
        HAS_DEID = True
    except ImportError:
        st.warning("⚠ De-ID fallback mode")
        HAS_DEID = False

    # Transformers is mandatory: without it no summary can be generated,
    # so the script run is halted here.
    try:
        import transformers
        st.success("✓ Transformers loaded")
    except ImportError:
        st.error("✗ Transformers missing - rebuild Space")
        st.stop()

    st.info("**Mode:** Direct Summarization")
    st.caption(f"**Model:** {model_name}")
    st.caption(f"**Secure Dir:** {secure_dir}")
# ==================== FALLBACK DE-ID ====================
def fallback_deid(text: str) -> str:
    """Mask likely PHI in *text* using regular expressions only.

    This is the fallback used when the dedicated de-identification pipeline
    is unavailable. Matches are replaced with bracketed placeholders
    (``[EMAIL]``, ``[SSN]``, ``[DATE]``, ``[PHONE]``, ``[ZIP]``, ``[NAME]``).

    The name heuristics are deliberately conservative: any capitalised word
    of three or more letters is treated as a possible name, so ordinary
    capitalised words are masked as well.

    Args:
        text: Raw clinical note.

    Returns:
        The note with matched spans replaced by placeholders.
    """
    # Structured identifiers are masked first so that the broad name
    # heuristics below cannot corrupt them (e.g. "John.Doe@Example.com").
    # All patterns are case-sensitive on purpose: applying IGNORECASE to the
    # [A-Z][a-z] name patterns would make them match almost every word.
    patterns = [
        (r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b', '[EMAIL]'),
        (r'\b\d{3}-\d{2}-\d{4}\b', '[SSN]'),
        (r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b', '[DATE]'),
        (r'\b\d{3}[-.\s]?\d{3}[-.\s]?\d{4}\b', '[PHONE]'),
        (r'\b\d{5}(-\d{4})?\b', '[ZIP]'),
        (r'\b\d{9}\b', '[SSN]'),
        # Two consecutive capitalised words: "First Last".
        (r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]'),
        # A single capitalised word, unless directly followed by a clinical
        # unit/abbreviation from the exclusion list.
        (r'\b[A-Z][a-z]{2,}\b(?! (mg|mmHg|bpm|CT|MRI|TIA|BP|HR|RR|NIH|EF|BID|QID|PCP|PMH|HPI|ROS))', '[NAME]'),
    ]
    result = text
    for pat, rep in patterns:
        result = re.sub(pat, rep, result)
    return result
# ==================== MODEL LOADING ====================
@st.cache_resource
def load_model(model_name):
    """Load the seq2seq tokenizer and model once per process.

    Decorated with ``st.cache_resource`` so the returned objects are reused
    instead of being reloaded on every script run.

    Args:
        model_name: Hugging Face model identifier passed to
            ``from_pretrained``.

    Returns:
        A ``(tokenizer, model, device)`` tuple, where ``device`` is the
        string ``'cuda'`` or ``'cpu'``.
    """
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    cache_dir = "/tmp/hf_cache"

    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    # generate() is given pad_token_id later, so make sure one exists.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        cache_dir=cache_dir,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()  # inference only

    st.sidebar.success("✓ Model ready")
    return tokenizer, model, device


tokenizer, model, device = load_model(model_name)
# ==================== SECTION EXTRACTION FUNCTIONS ====================
def extract_vitals(text: str) -> str:
    """Extract vital signs from free text using pattern matching.

    For each vital (BP, HR, Temp, RR, O2, Weight) the first occurrence in
    *text* is reported. Matching is case-insensitive.

    Args:
        text: Clinical note text.

    Returns:
        A comma-separated string such as ``"BP: 120/80, HR: 72"``, or an
        empty string when no vital sign is found.
    """
    # The leading \b keeps labels from matching inside longer words
    # (e.g. "Temp" inside "attempt", "HR" inside another abbreviation).
    patterns = {
        'BP': r'\b(?:BP|Blood Pressure)[:\s]+(\d{2,3}/\d{2,3})',
        'HR': r'\b(?:HR|Heart Rate|Pulse)[:\s]+(\d{2,3})(?:\s*bpm)?',
        'Temp': r'\b(?:Temp|Temperature)[:\s]+(\d{2,3}\.?\d*)(?:\s*°?\s*[FC]\b)?',
        'RR': r'\b(?:RR|Respiratory Rate|Resp)[:\s]+(\d{1,2})',
        # "SpO2" is listed explicitly because \b cannot match before the
        # "O2" inside it; an optional "sat"/"saturation" may follow the label.
        'O2': r'\b(?:SpO2|O2|Oxygen)(?:\s*sat(?:uration)?)?[:\s]+(\d{2,3})%?',
        'Weight': r'\b(?:Weight|Wt)[:\s]+(\d{2,3}\.?\d*)\s*(?:kg|lbs)?',
    }
    vitals_found = []
    for vital_name, pattern in patterns.items():
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            vitals_found.append(f"{vital_name}: {match.group(1)}")
    return ', '.join(vitals_found)
def extract_all_sections(text: str) -> dict:
    """Split a clinical note into its standard sections.

    The note is scanned for header lines (e.g. ``"HPI:"``, ``"Plan:"``).
    Everything between one header and the next is assigned to that header's
    section. Sections with no header are then filled by keyword heuristics
    where possible.

    Args:
        text: Clinical note text (any line-ending convention).

    Returns:
        A dict with the keys ``"Chief Complaint"``, ``"HPI"``,
        ``"Assessment"``, ``"Vitals"``, ``"Medications"``, ``"Plan"`` and
        ``"Discharge Summary"``. Sections that could not be found map to
        an empty string.
    """
    sections = {
        "Chief Complaint": "",
        "HPI": "",
        "Assessment": "",
        "Vitals": "",
        "Medications": "",
        "Plan": "",
        "Discharge Summary": ""
    }

    # Normalise line endings so that both the per-line scan and the
    # blank-line paragraph split below work for CRLF/CR input.
    text = text.replace('\r\n', '\n').replace('\r', '\n')
    lines = text.split('\n')
    paragraphs = text.split('\n\n')

    # Format: (section_name, [keywords], priority)
    section_patterns = [
        ("Chief Complaint", ['chief complaint:', 'cc:', 'presenting complaint:', 'chief compliant:'], 1),
        ("HPI", ['history of present illness:', 'hpi:', 'present illness:', 'clinical history:'], 1),
        ("Assessment", ['assessment:', 'impression:', 'diagnosis:', 'diagnoses:', 'clinical impression:'], 1),
        ("Plan", ['plan:', 'treatment plan:', 'management plan:', 'recommendations:', 'disposition:'], 1),
        ("Discharge Summary", ['discharge summary:', 'discharge:', 'discharge plan:', 'discharge instructions:'], 1),
        ("Medications", ['medications:', 'meds:', 'current medications:', 'home medications:', 'medication list:'], 1),
        ("Vitals", ['vital signs:', 'vitals:', 'physical exam:', 'examination:'], 1),
    ]

    # First pass: find header lines. Each line contributes at most one
    # marker (the first section whose keyword matches).
    section_markers = []
    for i, line in enumerate(lines):
        line_lower = line.strip().lower()
        if not line_lower:
            continue
        # Text before the first ": ", if any. NOTE(review): the keywords end
        # with ':' so this secondary test only fires when that prefix itself
        # contains a colon; kept as in the original.
        prefix = line_lower.split(': ')[0] if ': ' in line_lower else None
        matched_section = None
        for section_name, keywords, _priority in section_patterns:
            if any(
                line_lower.startswith(kw) or (prefix is not None and kw in prefix)
                for kw in keywords
            ):
                matched_section = section_name
                break
        if matched_section is not None:
            section_markers.append((i, matched_section))

    # Second pass: collect the content between consecutive markers.
    for idx, (line_num, section_name) in enumerate(section_markers):
        end_line = section_markers[idx + 1][0] if idx + 1 < len(section_markers) else len(lines)
        content_lines = []
        start_line = lines[line_num].strip()
        # Text following the header's colon on the same line; fragments of
        # one or two characters are ignored.
        if ':' in start_line:
            header_content = start_line.split(':', 1)[1].strip()
            if len(header_content) > 2:
                content_lines.append(header_content)
        for body_line in lines[line_num + 1:end_line]:
            body_line = body_line.strip()
            if body_line:
                content_lines.append(body_line)
        if content_lines:
            chunk = ' '.join(content_lines).strip()
            # A section may appear under several headers (e.g. "Assessment:"
            # and later "Diagnosis:"); append instead of overwriting.
            sections[section_name] = f"{sections[section_name]} {chunk}".strip()

    # Vitals: when the section is missing or very short, use the regex
    # extractor on the whole note.
    if len(sections["Vitals"]) < 10:
        vitals = extract_vitals(text)
        if vitals:
            sections["Vitals"] = vitals

    # Fallbacks for sections without a recognisable header.

    # Chief Complaint: first line within the opening 500 characters that
    # mentions a symptom.
    if not sections["Chief Complaint"]:
        symptom_keywords = ['pain', 'fever', 'cough', 'weakness', 'dizzy', 'nausea', 'shortness of breath', 'headache']
        for line in text[:500].split('\n'):
            if any(symptom in line.lower() for symptom in symptom_keywords):
                sections["Chief Complaint"] = line.strip()
                break

    # HPI: first longer paragraph containing narrative/temporal wording.
    if not sections["HPI"]:
        hpi_keywords = ['year-old', 'year old', 'presented', 'reports', 'denies', 'states', 'onset', 'duration', 'started', 'began']
        for para in paragraphs:
            if any(kw in para.lower() for kw in hpi_keywords) and len(para) > 50:
                sections["HPI"] = para.strip()
                break

    # Assessment: first short paragraph that names a diagnosis-like term.
    if not sections["Assessment"]:
        assessment_terms = ['hypertension', 'diabetes', 'pneumonia', 'fracture', 'infection', 'disease', 'syndrome', 'disorder']
        for para in paragraphs:
            if any(term in para.lower() for term in assessment_terms) and 20 < len(para) < 300:
                sections["Assessment"] = para.strip()
                break

    # Plan: first paragraph containing follow-up/treatment wording.
    if not sections["Plan"]:
        plan_keywords = ['continue', 'follow-up', 'follow up', 'prescribe', 'instruct', 'monitor', 'schedule', 'arrange', 'refer']
        for para in paragraphs:
            if any(kw in para.lower() for kw in plan_keywords) and len(para) > 40:
                sections["Plan"] = para.strip()
                break

    return sections
def parse_ai_summary(ai_text: str) -> dict:
    """Split model output into sections keyed by heading name.

    A line that begins with one of the known headings (case-sensitive,
    including the trailing colon) opens a new section; following non-empty
    lines are appended to it until the next heading. Sections that end up
    with no text are omitted from the result.

    Args:
        ai_text: Raw text generated by the model.

    Returns:
        A dict mapping heading names (without the colon) to their text,
        with the collected lines joined by single spaces.
    """
    headings = (
        'Chief Complaint:', 'HPI:', 'Assessment:', 'Vitals:',
        'Medications:', 'Plan:', 'Discharge Summary:',
    )
    parsed = {}
    active = None   # name of the section currently being filled
    pieces = []     # text fragments gathered for `active`

    for raw_line in ai_text.split('\n'):
        stripped = raw_line.strip()
        heading = next((h for h in headings if stripped.startswith(h)), None)

        if heading is None:
            # Continuation line: only kept once a section has been opened.
            if active and stripped:
                pieces.append(stripped)
            continue

        # A new heading closes whatever section was being collected.
        if active and pieces:
            parsed[active] = ' '.join(pieces).strip()

        active = heading.replace(':', '').strip()
        remainder = stripped[len(heading):].strip()
        pieces = [remainder] if remainder else []

    # Flush the section that was still open at end of input.
    if active and pieces:
        parsed[active] = ' '.join(pieces).strip()

    return parsed
# ==================== MAIN SUMMARIZATION FUNCTION ====================
def summarize_clinical_note(text: str, tokenizer, model, device) -> str:
    """Build a structured, seven-section summary of a clinical note.

    Section content is taken, in order of preference, from:
      1. rule-based extraction on the note (``extract_all_sections``),
      2. sections parsed out of the model output (``parse_ai_summary``),
      3. sentences of the model output that mention the section name,
      4. the literal ``"Not documented"``.

    Args:
        text: De-identified clinical note.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Seq2seq model returned by ``load_model``.
        device: Device string the model lives on (``'cuda'`` or ``'cpu'``).

    Returns:
        Markdown text with one ``**Section:**`` heading per section, each
        followed by its content and a blank line.
    """
    max_input_length = 1024  # token budget for the encoder input

    # Only the model prompt is shortened (roughly 4 characters per token);
    # the rule-based extraction below still sees the complete note.
    model_text = text[:max_input_length * 4]

    prompt = f"""Summarize this clinical documentation into a structured format with these exact sections:
Chief Complaint: State the patient's main presenting concern or reason for visit
HPI: Summarize the history of present illness including onset, duration, and progression
Assessment: List clinical findings, diagnoses, and impressions
Vitals: Extract all vital signs including BP, HR, Temperature, RR, O2 saturation
Medications: List all current medications with dosages and frequencies
Plan: Describe the treatment plan, recommendations, and next steps
Discharge Summary: Provide discharge status, instructions, and follow-up plans
Clinical Note:
{model_text}
Structured Summary:"""

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
        padding=True
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Deterministic beam search. No `temperature` is passed: it only applies
    # when sampling, and do_sample is False here.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs.get('attention_mask'),
            max_new_tokens=650,
            min_length=200,
            num_beams=4,
            do_sample=False,
            early_stopping=True,
            no_repeat_ngram_size=3,
            repetition_penalty=2.5,
            length_penalty=1.0,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    ai_summary = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # Rule-based sections from the full note, plus whatever structure the
    # model produced.
    sections_content = extract_all_sections(text)
    ai_sections = parse_ai_summary(ai_summary)

    section_names = ["Chief Complaint", "HPI", "Assessment", "Vitals", "Medications", "Plan", "Discharge Summary"]
    ai_summary_lower = ai_summary.lower()
    final_sections = {}
    for section in section_names:
        # 1. Content extracted directly from the note.
        content = sections_content.get(section, "").strip()

        # 2. Too little? Use the model's section of the same name.
        if not content or len(content) < 15:
            content = ai_sections.get(section, "").strip()

        # 3. Still too little? Use model sentences mentioning the section.
        if (not content or len(content) < 10) and section.lower() in ai_summary_lower:
            relevant = [s.strip() for s in ai_summary.split('.') if section.lower() in s.lower()]
            if relevant:
                content = '. '.join(relevant) + '.'

        # 4. Final fallback.
        if not content or len(content) < 10:
            content = "Not documented"

        # Collapse runs of whitespace (including newlines) to single spaces
        # so each section occupies one line of the output.
        final_sections[section] = re.sub(r'\s+', ' ', content).strip()

    # Format output with proper markdown
    return "".join(
        f"**{section}:**\n{final_sections[section]}\n\n" for section in section_names
    )
# ==================== QUALITY VALIDATION ====================
def validate_summary(summary: str, original_text: str) -> dict:
    """Score a structured summary with simple heuristics.

    Starting from 100, points are deducted for incomplete sections, heavy
    word repetition, unusual length and sparse clinical vocabulary. Ten
    points are added when quantitative clinical data is present. The final
    score is clamped to the range 0-100.

    Args:
        summary: Markdown produced by ``summarize_clinical_note`` (sections
            introduced by ``**Name:**`` lines).
        original_text: The source note. Currently not used by any check;
            kept for interface compatibility.

    Returns:
        A dict with ``quality_score`` (int), ``status`` (``"EXCELLENT"``,
        ``"GOOD"``, ``"FAIR"`` or ``"POOR"``), ``warnings`` (list of str),
        ``sections_present`` and ``sections_total``.
    """
    score = 100
    issues = []  # returned under the "warnings" key
    required_sections = ["Chief Complaint", "HPI", "Assessment", "Vitals", "Medications", "Plan", "Discharge Summary"]
    summary_lines = summary.split('\n')

    # Count sections that carry real content.
    present_count = 0
    for sec in required_sections:
        section_content = ""
        if sec + ":" in summary:
            in_section = False
            for line in summary_lines:
                if line.startswith(f"**{sec}:**"):
                    in_section = True
                    continue
                if in_section:
                    # The next bold heading ends this section.
                    if line.startswith("**"):
                        break
                    section_content += line
        if "not documented" not in section_content.lower() and len(section_content.strip()) > 10:
            present_count += 1

    missing_count = len(required_sections) - present_count
    if missing_count > 0:
        score -= missing_count * 12
        issues.append(f"{missing_count} of 7 sections incomplete")

    # Quantitative clinical data: doses, blood pressure, heart rate,
    # temperature, percentages.
    medical_patterns = [
        r'\d+\s*mg',
        r'\d+/\d+\s*mmHg',
        r'\d+\s*bpm',
        # A number followed by a degree sign and/or a standalone F/C unit.
        # The \b stops ordinary words such as "2 complaints" from counting
        # as a temperature.
        r'\d+\.?\d*\s*(?:°\s*[FC]?|[FC]\b)',
        r'\d+%',
    ]
    if any(re.search(pattern, summary, re.I) for pattern in medical_patterns):
        score += 10
    else:
        issues.append("Limited quantitative clinical data")

    # Repetition: share of distinct words among all words.
    words = summary.lower().split()
    if len(words) > 20:
        unique_ratio = len(set(words)) / len(words)
        if unique_ratio < 0.35:
            score -= 30
            issues.append("High repetition detected - summary quality poor")

    # Overall length, in characters.
    if len(summary) < 150:
        score -= 15
        issues.append("Summary too brief")
    elif len(summary) > 2000:
        score -= 5
        issues.append("Summary may be overly verbose")

    # Presence of key clinical terms.
    clinical_terms = ['patient', 'diagnosis', 'treatment', 'plan', 'medication', 'assessment']
    summary_lower = summary.lower()
    terms_found = sum(1 for term in clinical_terms if term in summary_lower)
    if terms_found < 3:
        score -= 10
        issues.append("Limited clinical terminology")

    score = max(0, min(100, score))
    if score >= 90:
        status = "EXCELLENT"
    elif score >= 75:
        status = "GOOD"
    elif score >= 60:
        status = "FAIR"
    else:
        status = "POOR"

    return {
        "quality_score": score,
        "status": status,
        "warnings": issues,
        "sections_present": present_count,
        "sections_total": len(required_sections)
    }
# ==================== SESSION STATE ====================
# Seed every session key the UI reads, without overwriting values that are
# already present from an earlier script run.
_session_defaults = {
    'deid_text': "",
    'original_text': "",
    'summary': None,
    'validation': None,
}
for _key, _default in _session_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# ==================== UI TABS ====================
tab1, tab2 = st.tabs(["📝 De-Identify Note", "✨ Generate Summary"])

with tab1:
    st.header("Step 1: De-identify Clinical Note")
    st.markdown("Upload or paste a clinical note to remove PHI (Protected Health Information)")

    uploaded = st.file_uploader("Upload clinical note (.txt)", type=["txt"])
    input_text = st.text_area(
        "Or paste clinical note here:",
        height=300,
        placeholder="Paste clinical documentation here...\n\nExample:\nChief Complaint: Chest pain\nHPI: 72-year-old male presents with...\nVitals: BP 140/90, HR 88..."
    )

    # An uploaded file takes precedence over pasted text.
    note_text = ""
    if uploaded:
        # getvalue() returns the whole buffer regardless of the current
        # stream position, unlike read().
        note_text = uploaded.getvalue().decode("utf-8", errors="ignore")
    elif input_text:
        note_text = input_text

    if st.button("🔒 De-Identify & Process", type="primary"):
        if note_text.strip():
            with st.spinner("De-identifying PHI..."):
                st.session_state.original_text = note_text
                if HAS_DEID:
                    try:
                        pipeline = DeidPipeline(secure_dir)
                        result = pipeline.run_on_text(note_text, "session_note")
                        deid_text = result["masked_text"]
                        if "encrypted_span_map" in result:
                            # Persist the encrypted span map next to the
                            # other secure artefacts.
                            span_map_path = Path(secure_dir) / "session_note.spanmap.enc"
                            span_map_path.write_bytes(result["encrypted_span_map"])
                            st.success("✅ De-identified with encrypted audit trail saved")
                    except Exception as e:
                        # Best effort: any pipeline failure degrades to the
                        # regex-based masking instead of aborting.
                        st.warning(f"⚠ Using regex-based de-identification: {str(e)[:100]}")
                        deid_text = fallback_deid(note_text)
                else:
                    deid_text = fallback_deid(note_text)
                    st.info("ℹ Using regex-based de-identification")

                st.session_state.deid_text = deid_text
                # A summary produced for a previous note no longer matches
                # the text just processed, so drop it.
                st.session_state.summary = None
                st.session_state.validation = None
            st.success(f"✅ Processed **{len(deid_text)}** characters (PHI redacted)")
        else:
            st.warning("⚠ Please enter or upload a clinical note")

    if st.session_state.deid_text:
        with st.expander("📄 Preview De-identified Text", expanded=False):
            # Widgets need a non-empty label; it is hidden because the
            # expander title already describes the content.
            st.text_area(
                "De-identified text preview",
                st.session_state.deid_text,
                height=250,
                disabled=True,
                key="preview_deid",
                label_visibility="collapsed",
            )
with tab2:
    st.header("Step 2: Generate Clinical Summary")
    st.markdown("AI-powered structured summarization with quality assessment")

    if not st.session_state.deid_text:
        st.warning("⚠ Please de-identify a note first in **Tab 1**")
    else:
        st.info(f"✅ Ready to summarize: **{len(st.session_state.deid_text)}** characters")

        if st.button("🚀 Generate Summary", type="primary"):
            with st.spinner("⏳ Generating structured summary (30-60 seconds)..."):
                try:
                    summary = summarize_clinical_note(
                        st.session_state.deid_text,
                        tokenizer,
                        model,
                        device
                    )
                    st.session_state.summary = summary
                    st.session_state.validation = validate_summary(
                        summary,
                        st.session_state.deid_text
                    )
                    st.success("✅ Summary generated successfully!")
                except Exception as e:
                    # Top-level UI boundary: report the failure and clear
                    # both results so no stale assessment is displayed.
                    st.error(f"❌ Summarization failed: {str(e)}")
                    st.exception(e)
                    st.session_state.summary = None
                    st.session_state.validation = None

        if st.session_state.summary:
            st.markdown("---")
            col1, col2 = st.columns([2.5, 1])

            with col1:
                st.subheader("📋 Structured Clinical Summary")
                st.markdown(st.session_state.summary)

            with col2:
                st.subheader("📊 Quality Assessment")
                # Guard against a missing validation result.
                val = st.session_state.validation or {}
                color_map = {
                    "EXCELLENT": "🟢",
                    "GOOD": "🔵",
                    "FAIR": "🟡",
                    "POOR": "🔴"
                }
                status_color = color_map.get(val.get("status", ""), "⚪")
                st.markdown(f"### {status_color} {val.get('status', 'N/A')}")
                st.metric("Quality Score", f"{val.get('quality_score', 0)}/100")
                st.metric(
                    "Sections Complete",
                    f"{val.get('sections_present', 0)}/{val.get('sections_total', 7)}"
                )
                if val.get("warnings"):
                    with st.expander("⚠ Quality Warnings", expanded=True):
                        for w in val["warnings"]:
                            st.warning(f"• {w}")

            st.markdown("---")

            # Download and reset buttons
            col_a, col_b, col_c = st.columns([2, 2, 1])
            with col_a:
                st.download_button(
                    "💾 Download Summary",
                    st.session_state.summary,
                    "clinical_summary.txt",
                    mime="text/plain",
                    type="secondary"
                )
            with col_b:
                st.download_button(
                    "💾 Download De-identified Note",
                    st.session_state.deid_text,
                    "deidentified_note.txt",
                    mime="text/plain",
                    type="secondary"
                )
            with col_c:
                if st.button("🔄 Reset"):
                    # Clear all per-session results and restart the script.
                    st.session_state.deid_text = ""
                    st.session_state.original_text = ""
                    st.session_state.summary = None
                    st.session_state.validation = None
                    st.rerun()
# ==================== FOOTER ====================
# Page footer: attribution and a usage disclaimer.
st.markdown("---")
st.caption("🏥 **HIPAA-Compliant Clinical Summarizer** | Portfolio Demo | Powered by Flan-T5 & Presidio")
st.caption("⚠ For demonstration purposes only - not for clinical use")