import streamlit as st
import os
import warnings
import sys
import re
from pathlib import Path
import subprocess
import torch

# Fix torch.classes path error for Streamlit compatibility
torch.classes.__path__ = []

# HF Spaces environment variables
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

st.set_page_config(
    page_title="Clinical AI Summarizer",
    layout="wide",
    initial_sidebar_state="expanded",
)

st.title("🏥 HIPAA-Compliant RAG Clinical Summarizer")
st.markdown("**De-identification → Clinical Summarization → Quality Assessment**")

# Global configuration
secure_dir = "./secure_store"
model_name = "google/flan-t5-xl"

# Output sections in display order. Shared by the extractor, the AI-output
# parser, the summarizer and the validator so the lists cannot drift apart.
SECTION_NAMES = (
    "Chief Complaint",
    "HPI",
    "Assessment",
    "Vitals",
    "Medications",
    "Plan",
    "Discharge Summary",
)

# Ensure directories exist
Path(secure_dir).mkdir(parents=True, exist_ok=True)

# ==================== SIDEBAR ====================
with st.sidebar:
    st.header("System Status")

    try:
        from deid_pipeline import DeidPipeline
        st.success("✓ De-identification module")
        HAS_DEID = True
    except ImportError:
        st.warning("⚠ De-ID fallback mode")
        HAS_DEID = False

    try:
        import transformers
        st.success("✓ Transformers loaded")
    except ImportError:
        st.error("✗ Transformers missing - rebuild Space")
        st.stop()

    st.info("**Mode:** Direct Summarization")
    st.caption(f"**Model:** {model_name}")
    st.caption(f"**Secure Dir:** {secure_dir}")


# ==================== FALLBACK DE-ID ====================

# Lower-cased vocabulary used by the section headers and vital-sign labels that
# extract_all_sections()/extract_vitals() search for. Without this allowlist the
# capitalised-word name heuristics redact those labels, and the downstream
# extraction has nothing left to anchor on.
_DEID_KEEP_WORDS = frozenset({
    "chief", "complaint", "compliant", "presenting", "history", "present",
    "illness", "clinical", "assessment", "impression", "diagnosis",
    "diagnoses", "plan", "treatment", "management", "recommendations",
    "disposition", "discharge", "summary", "instructions", "medications",
    "medication", "meds", "current", "home", "list", "vital", "vitals",
    "signs", "physical", "exam", "examination", "blood", "pressure", "heart",
    "rate", "pulse", "temp", "temperature", "respiratory", "resp", "oxygen",
    "weight",
})


def _mask_name(match: re.Match) -> str:
    """Return "[NAME]" for a capitalised-word match unless every word is a known clinical label."""
    words = match.group(0).split()
    if all(word.lower() in _DEID_KEEP_WORDS for word in words):
        return match.group(0)
    return "[NAME]"


# Applied in order. Emails go first so their local part is not mangled by the
# name rules. The dashed SSN rule goes before the phone/ZIP rules. The name
# rules are deliberately case-sensitive: under IGNORECASE, [A-Z][a-z]+ matches
# every word and the entire note collapses into [NAME] tokens.
_DEID_PATTERNS = (
    (re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"), "[EMAIL]"),
    (re.compile(r"\b[A-Z][a-z]+ [A-Z][a-z]+\b"), _mask_name),
    (
        re.compile(
            r"\b[A-Z][a-z]{2,}\b(?! (mg|mmHg|bpm|CT|MRI|TIA|BP|HR|RR|NIH|EF|BID|QID|PCP|PMH|HPI|ROS))"
        ),
        _mask_name,
    ),
    (re.compile(r"\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b"), "[DATE]"),
    (re.compile(r"\b\d{3}-\d{2}-\d{4}\b"), "[SSN]"),
    (re.compile(r"\b\d{3}[-.\s]?\d{3}[-.\s]?\d{4}\b"), "[PHONE]"),
    (re.compile(r"\b\d{5}(-\d{4})?\b"), "[ZIP]"),
    (re.compile(r"\b\d{9}\b"), "[SSN]"),
)


def fallback_deid(text: str) -> str:
    """Regex-based PHI removal used when the real de-identification pipeline is unavailable.

    Masks emails, capitalised name-like words (except known clinical labels),
    dates, SSNs, phone numbers and ZIP codes with bracketed placeholders.

    Args:
        text: Raw clinical note.

    Returns:
        The note with matched spans replaced by "[EMAIL]", "[NAME]", "[DATE]",
        "[SSN]", "[PHONE]" or "[ZIP]".
    """
    result = text
    for pattern, replacement in _DEID_PATTERNS:
        result = pattern.sub(replacement, result)
    return result


# ==================== MODEL LOADING ====================
@st.cache_resource
def load_model(model_name):
    """Load the seq2seq tokenizer and model once per process (cached by Streamlit).

    Args:
        model_name: Hugging Face model id, e.g. "google/flan-t5-xl".

    Returns:
        (tokenizer, model, device), where device is "cuda" when available and "cpu" otherwise.
    """
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/hf_cache")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        cache_dir="/tmp/hf_cache",
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    st.sidebar.success("✓ Model ready")
    return tokenizer, model, device


tokenizer, model, device = load_model(model_name)


# ==================== SECTION EXTRACTION FUNCTIONS ====================

# Each label is anchored with \b so short codes such as "HR" or "Wt" do not
# match inside longer words. Every pattern has exactly one capturing group: the value.
_VITAL_PATTERNS = {
    "BP": re.compile(r"\b(?:BP|Blood Pressure)[:\s]+(\d{2,3}/\d{2,3})", re.IGNORECASE),
    "HR": re.compile(r"\b(?:HR|Heart Rate|Pulse)[:\s]+(\d{2,3})(?:\s*bpm)?", re.IGNORECASE),
    "Temp": re.compile(r"\b(?:Temp|Temperature)[:\s]+(\d{2,3}\.?\d*)(?:\s*[FC°])?", re.IGNORECASE),
    "RR": re.compile(r"\b(?:RR|Respiratory Rate|Resp)[:\s]+(\d{1,2})", re.IGNORECASE),
    "O2": re.compile(r"\b(?:O2|Oxygen|SpO2)(?:\s*sat(?:uration)?)?[:\s]+(\d{2,3})%?", re.IGNORECASE),
    "Weight": re.compile(r"\b(?:Weight|Wt)[:\s]+(\d{2,3}\.?\d*)\s*(?:kg|lbs)?", re.IGNORECASE),
}


def extract_vitals(text: str) -> str:
    """Extract vital signs using pattern matching.

    Args:
        text: Clinical note text.

    Returns:
        A comma-separated "Name: value" list holding the first match for each
        vital sign, or "" when none is found.
    """
    vitals_found = []
    for vital_name, pattern in _VITAL_PATTERNS.items():
        match = pattern.search(text)
        if match:
            vitals_found.append(f"{vital_name}: {match.group(1)}")
    return ", ".join(vitals_found)


# (section name, lower-case header prefixes), checked in this order. A line is
# assigned to the FIRST section whose prefix it starts with.
_SECTION_HEADERS = (
    ("Chief Complaint", ("chief complaint:", "cc:", "presenting complaint:", "chief compliant:")),
    ("HPI", ("history of present illness:", "hpi:", "present illness:", "clinical history:")),
    ("Assessment", ("assessment:", "impression:", "diagnosis:", "diagnoses:", "clinical impression:")),
    ("Plan", ("plan:", "treatment plan:", "management plan:", "recommendations:", "disposition:")),
    ("Discharge Summary", ("discharge summary:", "discharge:", "discharge plan:", "discharge instructions:")),
    ("Medications", ("medications:", "meds:", "current medications:", "home medications:", "medication list:")),
    ("Vitals", ("vital signs:", "vitals:", "physical exam:", "examination:")),
)


def _find_section_headers(lines: list) -> list:
    """Return (line_index, section_name) for each line that opens a known section (one section per line)."""
    markers = []
    for index, line in enumerate(lines):
        lowered = line.strip().lower()
        if not lowered:
            continue
        for section_name, keywords in _SECTION_HEADERS:
            if lowered.startswith(keywords):
                markers.append((index, section_name))
                break  # a header line belongs to exactly one section
    return markers


def extract_all_sections(text: str) -> dict:
    """Split a clinical note into the standard sections using header keywords plus content heuristics.

    Args:
        text: Clinical note text (typically already de-identified).

    Returns:
        Dict keyed by every name in SECTION_NAMES. A value is "" when nothing was found.
    """
    sections = dict.fromkeys(SECTION_NAMES, "")
    lines = text.split("\n")
    markers = _find_section_headers(lines)

    # Content of a section runs from its header line to the next header (or EOF).
    for idx, (line_num, section_name) in enumerate(markers):
        end_line = markers[idx + 1][0] if idx + 1 < len(markers) else len(lines)

        content_lines = []
        _, _, header_content = lines[line_num].strip().partition(":")
        header_content = header_content.strip()
        if len(header_content) > 2:
            content_lines.append(header_content)
        content_lines.extend(line.strip() for line in lines[line_num + 1:end_line] if line.strip())

        if content_lines:
            # Headers that map to the same section (e.g. "Diagnosis:" and
            # "Impression:") are appended rather than overwriting each other.
            block = " ".join(content_lines).strip()
            sections[section_name] = f"{sections[section_name]} {block}".strip()

    # Special handling: extract vitals using regex if not found
    if len(sections["Vitals"]) < 10:
        vitals = extract_vitals(text)
        if vitals:
            sections["Vitals"] = vitals

    # Chief Complaint fallback: the first early line that mentions a symptom
    if not sections["Chief Complaint"]:
        symptom_keywords = ["pain", "fever", "cough", "weakness", "dizzy", "nausea",
                            "shortness of breath", "headache"]
        for line in text[:500].split("\n"):
            if any(symptom in line.lower() for symptom in symptom_keywords):
                sections["Chief Complaint"] = line.strip()
                break

    # HPI fallback: a paragraph with narrative/temporal wording
    if not sections["HPI"]:
        hpi_keywords = ["year-old", "year old", "presented", "reports", "denies", "states",
                        "onset", "duration", "started", "began"]
        for para in text.split("\n\n"):
            if any(kw in para.lower() for kw in hpi_keywords) and len(para) > 50:
                sections["HPI"] = para.strip()
                break

    # Assessment fallback: a short paragraph that names a diagnosis
    if not sections["Assessment"]:
        assessment_terms = ["hypertension", "diabetes", "pneumonia", "fracture", "infection",
                            "disease", "syndrome", "disorder"]
        for para in text.split("\n\n"):
            if any(term in para.lower() for term in assessment_terms) and 20 < len(para) < 300:
                sections["Assessment"] = para.strip()
                break

    # Plan fallback: a paragraph with follow-up / treatment verbs
    if not sections["Plan"]:
        plan_keywords = ["continue", "follow-up", "follow up", "prescribe", "instruct",
                         "monitor", "schedule", "arrange", "refer"]
        for para in text.split("\n\n"):
            if any(kw in para.lower() for kw in plan_keywords) and len(para) > 40:
                sections["Plan"] = para.strip()
                break

    return sections


# Matches "<Section Name>:" anywhere in the model output. Flan-T5 usually emits
# everything on one line, so headers cannot be assumed to start a line.
_AI_HEADER_RE = re.compile(r"\b(" + "|".join(re.escape(name) for name in SECTION_NAMES) + r"):")


def parse_ai_summary(ai_text: str) -> dict:
    """Parse "Section: content" runs out of the model's generated summary.

    Args:
        ai_text: Decoded model output.

    Returns:
        Dict mapping section name to whitespace-normalised content. Sections
        with empty content are omitted; a later non-empty occurrence of a
        section wins. Text before the first header is ignored.
    """
    sections = {}
    matches = list(_AI_HEADER_RE.finditer(ai_text))
    for idx, match in enumerate(matches):
        end = matches[idx + 1].start() if idx + 1 < len(matches) else len(ai_text)
        body = " ".join(ai_text[match.end():end].split())
        if body:
            sections[match.group(1)] = body
    return sections


# ==================== MAIN SUMMARIZATION FUNCTION ====================
def _resolve_section_content(section: str, extracted: str, ai_sections: dict, ai_summary: str) -> str:
    """Choose the content for one section: extracted text first, then the AI's section, then AI sentences naming it."""
    content = extracted.strip()

    # Short extracted content is only replaced when the AI offers something longer,
    # so a valid brief entry such as "HR: 88" is not discarded.
    if len(content) < 15:
        ai_content = ai_sections.get(section, "").strip()
        if len(ai_content) > len(content):
            content = ai_content

    if len(content) < 10 and section.lower() in ai_summary.lower():
        relevant = [s.strip() for s in ai_summary.split(".") if section.lower() in s.lower()]
        if relevant:
            candidate = ". ".join(relevant) + "."
            if len(candidate) > len(content):
                content = candidate

    content = " ".join(content.split())
    return content or "Not documented"


def summarize_clinical_note(text: str, tokenizer, model, device) -> str:
    """Generate a structured Markdown clinical summary.

    Section content is taken from keyword extraction on the full note and
    backfilled from the model's generated summary.

    Args:
        text: De-identified clinical note.
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: Seq2seq model with a ``generate`` method.
        device: Torch device string the model lives on.

    Returns:
        Markdown with one "**Section:**" heading per entry of SECTION_NAMES.
    """
    # Only the copy sent to the model is truncated (≈4 chars per token);
    # section extraction below still sees the whole note.
    max_input_length = 1024
    model_text = text[: max_input_length * 4]

    prompt = f"""Summarize this clinical documentation into a structured format with these exact sections:

Chief Complaint: State the patient's main presenting concern or reason for visit
HPI: Summarize the history of present illness including onset, duration, and progression
Assessment: List clinical findings, diagnoses, and impressions
Vitals: Extract all vital signs including BP, HR, Temperature, RR, O2 saturation
Medications: List all current medications with dosages and frequencies
Plan: Describe the treatment plan, recommendations, and next steps
Discharge Summary: Provide discharge status, instructions, and follow-up plans

Clinical Note:
{model_text}

Structured Summary:"""

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
        padding=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Deterministic beam search. `temperature` is omitted because it only
    # applies when do_sample=True. Passing **inputs forwards the attention mask.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=650,
            min_length=200,
            num_beams=4,
            do_sample=False,
            early_stopping=True,
            no_repeat_ngram_size=3,
            repetition_penalty=2.5,
            length_penalty=1.0,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    ai_summary = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    sections_content = extract_all_sections(text)
    ai_sections = parse_ai_summary(ai_summary)

    final_sections = {
        section: _resolve_section_content(section, sections_content.get(section, ""), ai_sections, ai_summary)
        for section in SECTION_NAMES
    }

    return "".join(f"**{section}:**\n{final_sections[section]}\n\n" for section in SECTION_NAMES)


# ==================== QUALITY VALIDATION ====================
def _summary_section_text(summary: str, section: str) -> str:
    """Return the raw text under "**{section}:**" in a formatted summary, or "" if the section is absent."""
    if f"{section}:" not in summary:
        return ""
    header = f"**{section}:**"
    collected = []
    in_section = False
    for line in summary.split("\n"):
        if line.startswith(header):
            in_section = True
            continue
        if in_section:
            if line.startswith("**"):
                break
            collected.append(line)
    return "".join(collected)


def validate_summary(summary: str, original_text: str) -> dict:
    """Score a formatted summary with simple completeness, content and repetition heuristics.

    Args:
        summary: Markdown produced by summarize_clinical_note().
        original_text: Source note (currently unused; kept for API compatibility).

    Returns:
        Dict with "quality_score" (0-100), "status", "warnings" (list of str),
        "sections_present" and "sections_total".
    """
    score = 100
    issues = []
    required_sections = list(SECTION_NAMES)

    present_count = 0
    for sec in required_sections:
        body = _summary_section_text(summary, sec).strip()
        # Only the exact placeholder counts as missing. Real text that happens
        # to contain "not documented" is still content.
        if len(body) > 10 and body.lower() != "not documented":
            present_count += 1

    missing_count = len(required_sections) - present_count
    if missing_count > 0:
        score -= missing_count * 12
        issues.append(f"{missing_count} of {len(required_sections)} sections incomplete")

    # Check for medical content indicators
    medical_patterns = [
        r"\d+\s*mg",
        r"\d+/\d+\s*mmHg",
        r"\d+\s*bpm",
        r"\d+\.?\d*\s*[FC°]",
        r"\d+%",
    ]
    if any(re.search(pattern, summary, re.I) for pattern in medical_patterns):
        score += 10
    else:
        issues.append("Limited quantitative clinical data")

    # Check for repetition issues
    words = summary.lower().split()
    if len(words) > 20:
        unique_ratio = len(set(words)) / len(words)
        if unique_ratio < 0.35:
            score -= 30
            issues.append("High repetition detected - summary quality poor")

    # Check overall length
    if len(summary) < 150:
        score -= 15
        issues.append("Summary too brief")
    elif len(summary) > 2000:
        score -= 5
        issues.append("Summary may be overly verbose")

    # Check for key clinical terms
    clinical_terms = ["patient", "diagnosis", "treatment", "plan", "medication", "assessment"]
    summary_lower = summary.lower()
    terms_found = sum(1 for term in clinical_terms if term in summary_lower)
    if terms_found < 3:
        score -= 10
        issues.append("Limited clinical terminology")

    score = max(0, min(100, score))

    if score >= 90:
        status = "EXCELLENT"
    elif score >= 75:
        status = "GOOD"
    elif score >= 60:
        status = "FAIR"
    else:
        status = "POOR"

    return {
        "quality_score": score,
        "status": status,
        "warnings": issues,
        "sections_present": present_count,
        "sections_total": len(required_sections),
    }


# ==================== SESSION STATE ====================
_SESSION_DEFAULTS = {
    "deid_text": "",
    "original_text": "",
    "summary": None,
    "validation": None,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# ==================== UI TABS ====================
tab1, tab2 = st.tabs(["📝 De-Identify Note", "✨ Generate Summary"])

with tab1:
    st.header("Step 1: De-identify Clinical Note")
    st.markdown("Upload or paste a clinical note to remove PHI (Protected Health Information)")

    uploaded = st.file_uploader("Upload clinical note (.txt)", type=["txt"])
    input_text = st.text_area(
        "Or paste clinical note here:",
        height=300,
        placeholder="Paste clinical documentation here...\n\nExample:\nChief Complaint: Chest pain\nHPI: 72-year-old male presents with...\nVitals: BP 140/90, HR 88...",
    )

    note_text = ""
    if uploaded:
        # getvalue() returns the whole buffer regardless of the stream
        # position, unlike read(), which returns b"" once consumed.
        note_text = uploaded.getvalue().decode("utf-8", errors="ignore")
    elif input_text:
        note_text = input_text

    if st.button("🔒 De-Identify & Process", type="primary"):
        if note_text:
            with st.spinner("De-identifying PHI..."):
                st.session_state.original_text = note_text

                if HAS_DEID:
                    try:
                        pipeline = DeidPipeline(secure_dir)
                        result = pipeline.run_on_text(note_text, "session_note")
                        deid_text = result["masked_text"]
                    except Exception as e:
                        st.warning(f"⚠ Using regex-based de-identification: {str(e)[:100]}")
                        deid_text = fallback_deid(note_text)
                    else:
                        # A failed audit-trail write must not discard the
                        # pipeline's (better) de-identified text.
                        if "encrypted_span_map" in result:
                            span_map_path = Path(secure_dir) / "session_note.spanmap.enc"
                            try:
                                span_map_path.write_bytes(result["encrypted_span_map"])
                            except (OSError, TypeError) as e:
                                st.warning(f"⚠ De-identified, but audit trail could not be saved: {str(e)[:100]}")
                            else:
                                st.success("✅ De-identified with encrypted audit trail saved")
                        else:
                            st.success("✅ De-identified")
                else:
                    deid_text = fallback_deid(note_text)
                    st.info("ℹ Using regex-based de-identification")

                st.session_state.deid_text = deid_text
                st.success(f"✅ Processed **{len(deid_text)}** characters (PHI redacted)")
        else:
            st.warning("⚠ Please enter or upload a clinical note")

    if st.session_state.deid_text:
        with st.expander("📄 Preview De-identified Text", expanded=False):
            st.text_area(
                "De-identified text",
                st.session_state.deid_text,
                height=250,
                disabled=True,
                key="preview_deid",
                label_visibility="collapsed",
            )

with tab2:
    st.header("Step 2: Generate Clinical Summary")
    st.markdown("AI-powered structured summarization with quality assessment")

    if not st.session_state.deid_text:
        st.warning("⚠ Please de-identify a note first in **Tab 1**")
    else:
        st.info(f"✅ Ready to summarize: **{len(st.session_state.deid_text)}** characters")

        if st.button("🚀 Generate Summary", type="primary"):
            with st.spinner("⏳ Generating structured summary (30-60 seconds)..."):
                try:
                    summary = summarize_clinical_note(
                        st.session_state.deid_text, tokenizer, model, device
                    )
                    st.session_state.summary = summary
                    st.session_state.validation = validate_summary(
                        summary, st.session_state.deid_text
                    )
                    st.success("✅ Summary generated successfully!")
                except Exception as e:
                    st.error(f"❌ Summarization failed: {str(e)}")
                    st.exception(e)
                    st.session_state.summary = None
                    st.session_state.validation = None

        if st.session_state.summary:
            st.markdown("---")

            col1, col2 = st.columns([2.5, 1])

            with col1:
                st.subheader("📋 Structured Clinical Summary")
                st.markdown(st.session_state.summary)

            with col2:
                st.subheader("📊 Quality Assessment")
                val = st.session_state.validation or {}

                color_map = {
                    "EXCELLENT": "🟢",
                    "GOOD": "🔵",
                    "FAIR": "🟡",
                    "POOR": "🔴",
                }
                status_color = color_map.get(val.get("status", ""), "⚪")

                st.markdown(f"### {status_color} {val.get('status', 'N/A')}")
                st.metric("Quality Score", f"{val.get('quality_score', 0)}/100")
                st.metric(
                    "Sections Complete",
                    f"{val.get('sections_present', 0)}/{val.get('sections_total', 7)}",
                )

                if val.get("warnings"):
                    with st.expander("⚠ Quality Warnings", expanded=True):
                        for w in val["warnings"]:
                            st.warning(f"• {w}")

            st.markdown("---")

            # Download and reset buttons
            col_a, col_b, col_c = st.columns([2, 2, 1])

            with col_a:
                st.download_button(
                    "💾 Download Summary",
                    st.session_state.summary,
                    "clinical_summary.txt",
                    mime="text/plain",
                    type="secondary",
                )

            with col_b:
                st.download_button(
                    "💾 Download De-identified Note",
                    st.session_state.deid_text,
                    "deidentified_note.txt",
                    mime="text/plain",
                    type="secondary",
                )

            with col_c:
                if st.button("🔄 Reset"):
                    for _key, _default in _SESSION_DEFAULTS.items():
                        st.session_state[_key] = _default
                    st.rerun()

# ==================== FOOTER ====================
st.markdown("---")
st.caption("🏥 **HIPAA-Compliant Clinical Summarizer** | Portfolio Demo | Powered by Flan-T5 & Presidio")
st.caption("⚠ For demonstration purposes only - not for clinical use")