samarth09healthPM commited on
Commit
1c2a87b
Β·
1 Parent(s): 445c1de

Fix duplicate key error with session state

Browse files
Files changed (1) hide show
  1. main.py +331 -126
main.py CHANGED
@@ -7,28 +7,33 @@ from pathlib import Path
7
  import subprocess
8
  import torch
9
 
10
- # Fix torch.classes path error
11
  torch.classes.__path__ = []
12
 
13
- # HF Spaces env vars
14
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
15
  os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
16
 
17
  warnings.filterwarnings("ignore", category=DeprecationWarning)
18
  warnings.filterwarnings("ignore", category=UserWarning)
19
 
20
- st.set_page_config(page_title="Clinical AI Summarizer", layout="wide", initial_sidebar_state="expanded")
 
 
 
 
 
21
  st.title("πŸ₯ HIPAA-Compliant RAG Clinical Summarizer")
22
- st.markdown("De-identification β†’ Clinical Summarization β†’ Quality Assessment")
23
 
24
  # Global configuration
25
  secure_dir = "./secure_store"
26
- model_name = "google/flan-t5-base" # Changed to flan-t5 for better summarization
27
 
28
  # Ensure directories exist
29
  Path(secure_dir).mkdir(exist_ok=True)
30
 
31
- # Sidebar
32
  with st.sidebar:
33
  st.header("System Status")
34
 
@@ -36,7 +41,7 @@ with st.sidebar:
36
  from deid_pipeline import DeidPipeline
37
  st.success("βœ“ De-identification module")
38
  HAS_DEID = True
39
- except ImportError as e:
40
  st.warning("⚠ De-ID fallback mode")
41
  HAS_DEID = False
42
 
@@ -44,17 +49,19 @@ with st.sidebar:
44
  import transformers
45
  st.success("βœ“ Transformers loaded")
46
  except ImportError:
47
- st.error("βœ— Transformers missing")
48
  st.stop()
49
 
50
- st.info("Mode: Direct Summarization")
51
- st.caption(f"Model: {model_name}")
 
52
 
53
- # Fallback De-ID
54
  def fallback_deid(text: str) -> str:
 
55
  patterns = [
56
  (r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]'),
57
- (r'\b[A-Z][a-z]{2,}\b(?! (mg|mmHg|bpm|CT|MRI|TIA|BP|HR|RR|NIH|EF|BID|QID|PCP|PMH|HPI))', '[NAME]'),
58
  (r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b', '[DATE]'),
59
  (r'\b\d{3}[-.\s]?\d{3}[-.\s]?\d{4}\b', '[PHONE]'),
60
  (r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', '[EMAIL]'),
@@ -66,9 +73,10 @@ def fallback_deid(text: str) -> str:
66
  result = re.sub(pat, rep, result, flags=re.IGNORECASE)
67
  return result
68
 
69
- # Load model with proper caching
70
  @st.cache_resource
71
  def load_model(model_name):
 
72
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
73
 
74
  tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/hf_cache")
@@ -91,66 +99,151 @@ def load_model(model_name):
91
 
92
  tokenizer, model, device = load_model(model_name)
93
 
94
- def extract_sections_from_note(text: str) -> dict:
95
- """Extract clinical sections using keywords"""
96
- sections = {
97
- "Chief Complaint": "",
98
- "HPI": "",
99
- "Assessment": "",
100
- "Vitals": "",
101
- "Medications": "",
102
- "Plan": "",
103
- "Discharge Summary": ""
 
 
104
  }
105
 
 
 
 
 
 
 
 
 
 
 
106
  lines = text.split('\n')
107
  current_section = None
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  for line in lines:
110
- line_lower = line.lower().strip()
 
 
 
 
111
 
112
- # Detect section headers
113
- if any(kw in line_lower for kw in ['chief complaint', 'cc:']):
114
- current_section = "Chief Complaint"
115
- elif any(kw in line_lower for kw in ['history of present illness', 'hpi:', 'history:']):
116
- current_section = "HPI"
117
- elif any(kw in line_lower for kw in ['assessment', 'impression', 'diagnosis']):
118
- current_section = "Assessment"
119
- elif any(kw in line_lower for kw in ['vital signs', 'vitals:', 'bp:', 'temp:']):
120
- current_section = "Vitals"
121
- elif any(kw in line_lower for kw in ['medications', 'meds:', 'current medications']):
122
- current_section = "Medications"
123
- elif any(kw in line_lower for kw in ['plan:', 'treatment plan', 'recommendations']):
124
- current_section = "Plan"
125
- elif any(kw in line_lower for kw in ['discharge', 'discharge summary', 'disposition']):
126
- current_section = "Discharge Summary"
 
 
 
 
 
 
 
127
 
128
- # Append content to current section
129
- if current_section and line.strip():
130
- sections[current_section] += line + " "
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  return sections
133
 
134
- def summarize_clinical_note(text: str, tokenizer, model, device) -> str:
135
- """Generate structured clinical summary using T5"""
 
 
 
136
 
137
- # First extract any existing structure
138
- sections = extract_sections_from_note(text)
139
 
140
- # Truncate if too long
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  max_input_length = 1024
142
- if len(text) > max_input_length * 4: # Rough character estimate
143
  text = text[:max_input_length * 4]
144
 
145
- # Create focused prompt for T5
146
- prompt = f"""Summarize this clinical note into these sections:
147
- Chief Complaint: [patient's main concern]
148
- HPI: [history of present illness]
149
- Assessment: [clinical findings and diagnosis]
150
- Vitals: [vital signs if present]
151
- Medications: [current medications]
152
- Plan: [treatment plan]
153
- Discharge Summary: [discharge plan if applicable]
 
154
 
155
  Clinical Note:
156
  {text}
@@ -167,80 +260,159 @@ Structured Summary:"""
167
 
168
  inputs = {k: v.to(device) for k, v in inputs.items()}
169
 
 
170
  with torch.no_grad():
171
  outputs = model.generate(
172
  inputs['input_ids'],
173
- max_new_tokens=512,
174
- min_length=100,
175
  num_beams=4,
176
- temperature=0.7,
177
  do_sample=False,
178
  early_stopping=True,
179
- no_repeat_ngram_size=3, # Prevent repetition
180
- repetition_penalty=2.0, # Strong penalty for repetition
 
181
  pad_token_id=tokenizer.pad_token_id,
182
  eos_token_id=tokenizer.eos_token_id
183
  )
184
 
185
- summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
186
 
187
- # Format output with sections
188
- formatted_summary = ""
189
- for section_name in sections.keys():
190
- formatted_summary += f"**{section_name}:**\n"
 
 
 
 
 
 
191
 
192
- # Check if AI generated content for this section
193
- if section_name.lower() in summary.lower():
194
- # Extract relevant part from summary
195
- relevant_content = "Generated summary content"
196
- elif sections[section_name].strip():
197
- # Use extracted content
198
- formatted_summary += f"{sections[section_name].strip()[:200]}\n\n"
199
- else:
200
- formatted_summary += "Not documented\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
- # If summary is too short, add the full AI output
203
- if len(formatted_summary) < 200:
204
- formatted_summary = summary
 
205
 
206
- return formatted_summary
207
 
 
208
  def validate_summary(summary: str, original_text: str) -> dict:
209
- """Assess summary quality"""
210
  score = 100
211
  warnings = []
212
 
213
- required_sections = ["Chief Complaint", "HPI", "Assessment", "Vitals", "Medications", "Plan"]
214
- present = sum(1 for sec in required_sections if sec.lower() in summary.lower())
215
- missing_count = len(required_sections) - present
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
- if missing_count > 0:
218
- score -= missing_count * 10
219
- warnings.append(f"{missing_count} sections missing or incomplete")
220
 
221
- # Check for medical content
222
- if re.search(r'\d+\s*(mg|mmHg|bpm|%)', summary, re.I):
 
 
 
 
 
 
 
 
 
 
 
 
223
  score += 10
 
 
224
 
225
- # Check for repetition (like "windshield windshield")
226
  words = summary.lower().split()
227
- if len(words) > 10:
228
  unique_ratio = len(set(words)) / len(words)
229
- if unique_ratio < 0.3:
230
- score -= 40
231
- warnings.append("High repetition detected - review summary quality")
232
-
233
- # Check length
234
- if len(summary) < 100:
235
- score -= 20
236
- warnings.append("Summary too short")
 
 
 
 
 
 
 
 
 
 
237
 
238
  score = max(0, min(100, score))
239
- status = "EXCELLENT" if score >= 85 else "GOOD" if score >= 70 else "FAIR" if score >= 50 else "POOR"
240
 
241
- return {"quality_score": score, "status": status, "warnings": warnings}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
- # Session state
244
  if 'deid_text' not in st.session_state:
245
  st.session_state.deid_text = ""
246
  if 'original_text' not in st.session_state:
@@ -250,15 +422,19 @@ if 'summary' not in st.session_state:
250
  if 'validation' not in st.session_state:
251
  st.session_state.validation = None
252
 
253
- # UI Tabs
254
  tab1, tab2 = st.tabs(["πŸ“ De-Identify Note", "✨ Generate Summary"])
255
 
256
  with tab1:
257
  st.header("Step 1: De-identify Clinical Note")
 
258
 
259
  uploaded = st.file_uploader("Upload clinical note (.txt)", type=["txt"])
260
- input_text = st.text_area("Or paste clinical note here:", height=300,
261
- placeholder="Paste clinical documentation here...")
 
 
 
262
 
263
  note_text = ""
264
  if uploaded:
@@ -281,33 +457,34 @@ with tab1:
281
  with open(f"{secure_dir}/session_note.spanmap.enc", "wb") as f:
282
  f.write(result["encrypted_span_map"])
283
 
284
- st.success("βœ“ De-identified with encrypted audit trail")
285
  except Exception as e:
286
- st.warning(f"Using regex-based de-identification")
287
  deid_text = fallback_deid(note_text)
288
  else:
289
  deid_text = fallback_deid(note_text)
290
- st.info("Using regex-based de-identification")
291
 
292
  st.session_state.deid_text = deid_text
293
- st.success(f"βœ“ Processed {len(deid_text)} characters (PHI redacted)")
294
  else:
295
  st.warning("⚠ Please enter or upload a clinical note")
296
 
297
  if st.session_state.deid_text:
298
- with st.expander("πŸ“„ Preview De-identified Text"):
299
  st.text_area("", st.session_state.deid_text, height=250, disabled=True, key="preview_deid")
300
 
301
  with tab2:
302
  st.header("Step 2: Generate Clinical Summary")
 
303
 
304
  if not st.session_state.deid_text:
305
- st.warning("⚠ Please de-identify a note first (Tab 1)")
306
  else:
307
- st.info(f"βœ“ Ready to summarize: {len(st.session_state.deid_text)} characters")
308
 
309
  if st.button("πŸš€ Generate Summary", type="primary"):
310
- with st.spinner("Generating structured summary (30-60 seconds)..."):
311
  try:
312
  summary = summarize_clinical_note(
313
  st.session_state.deid_text,
@@ -317,51 +494,79 @@ with tab2:
317
  )
318
 
319
  st.session_state.summary = summary
320
- st.session_state.validation = validate_summary(summary, st.session_state.deid_text)
321
- st.success("βœ“ Summary generated successfully!")
 
 
 
322
 
323
  except Exception as e:
324
- st.error(f"Summarization failed: {str(e)}")
 
325
  st.session_state.summary = None
326
 
327
  if st.session_state.summary:
328
- col1, col2 = st.columns([3, 1])
 
 
329
 
330
  with col1:
331
  st.subheader("πŸ“‹ Structured Clinical Summary")
332
  st.markdown(st.session_state.summary)
333
 
334
  with col2:
335
- st.subheader("πŸ“Š Quality")
336
  val = st.session_state.validation
337
 
338
- color_map = {"EXCELLENT": "🟒", "GOOD": "πŸ”΅", "FAIR": "🟑", "POOR": "πŸ”΄"}
 
 
 
 
 
339
  status_color = color_map.get(val.get("status", ""), "βšͺ")
340
 
341
  st.markdown(f"### {status_color} {val.get('status', 'N/A')}")
342
  st.metric("Quality Score", f"{val.get('quality_score', 0)}/100")
 
 
 
 
343
 
344
  if val.get("warnings"):
345
- st.warning("**Issues:**")
346
- for w in val["warnings"]:
347
- st.write(f"β€’ {w}")
 
 
348
 
349
- # Download buttons
350
- col_a, col_b = st.columns(2)
351
  with col_a:
352
  st.download_button(
353
  "πŸ’Ύ Download Summary",
354
  st.session_state.summary,
355
  "clinical_summary.txt",
 
356
  type="secondary"
357
  )
358
  with col_b:
359
- if st.button("πŸ”„ Reset & Start Over"):
 
 
 
 
 
 
 
 
360
  st.session_state.deid_text = ""
361
  st.session_state.original_text = ""
362
  st.session_state.summary = None
363
  st.session_state.validation = None
364
  st.rerun()
365
 
 
366
  st.markdown("---")
367
- st.caption("πŸ₯ HIPAA-Compliant Clinical Summarizer | Portfolio Demo | Powered by Flan-T5")
 
 
7
  import subprocess
8
  import torch
9
 
10
+ # Fix torch.classes path error for Streamlit compatibility
11
  torch.classes.__path__ = []
12
 
13
+ # HF Spaces environment variables
14
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
15
  os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
16
 
17
  warnings.filterwarnings("ignore", category=DeprecationWarning)
18
  warnings.filterwarnings("ignore", category=UserWarning)
19
 
20
+ st.set_page_config(
21
+ page_title="Clinical AI Summarizer",
22
+ layout="wide",
23
+ initial_sidebar_state="expanded"
24
+ )
25
+
26
  st.title("πŸ₯ HIPAA-Compliant RAG Clinical Summarizer")
27
+ st.markdown("**De-identification β†’ Clinical Summarization β†’ Quality Assessment**")
28
 
29
  # Global configuration
30
  secure_dir = "./secure_store"
31
+ model_name = "google/flan-t5-base"
32
 
33
  # Ensure directories exist
34
  Path(secure_dir).mkdir(exist_ok=True)
35
 
36
+ # ==================== SIDEBAR ====================
37
  with st.sidebar:
38
  st.header("System Status")
39
 
 
41
  from deid_pipeline import DeidPipeline
42
  st.success("βœ“ De-identification module")
43
  HAS_DEID = True
44
+ except ImportError:
45
  st.warning("⚠ De-ID fallback mode")
46
  HAS_DEID = False
47
 
 
49
  import transformers
50
  st.success("βœ“ Transformers loaded")
51
  except ImportError:
52
+ st.error("βœ— Transformers missing - rebuild Space")
53
  st.stop()
54
 
55
+ st.info("**Mode:** Direct Summarization")
56
+ st.caption(f"**Model:** {model_name}")
57
+ st.caption(f"**Secure Dir:** {secure_dir}")
58
 
59
+ # ==================== FALLBACK DE-ID ====================
60
  def fallback_deid(text: str) -> str:
61
+ """Regex-based PHI removal fallback"""
62
  patterns = [
63
  (r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', '[NAME]'),
64
+ (r'\b[A-Z][a-z]{2,}\b(?! (mg|mmHg|bpm|CT|MRI|TIA|BP|HR|RR|NIH|EF|BID|QID|PCP|PMH|HPI|ROS))', '[NAME]'),
65
  (r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b', '[DATE]'),
66
  (r'\b\d{3}[-.\s]?\d{3}[-.\s]?\d{4}\b', '[PHONE]'),
67
  (r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', '[EMAIL]'),
 
73
  result = re.sub(pat, rep, result, flags=re.IGNORECASE)
74
  return result
75
 
76
+ # ==================== MODEL LOADING ====================
77
  @st.cache_resource
78
  def load_model(model_name):
79
+ """Load T5 model with proper caching"""
80
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
81
 
82
  tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/hf_cache")
 
99
 
100
  tokenizer, model, device = load_model(model_name)
101
 
102
+ # ==================== SECTION EXTRACTION FUNCTIONS ====================
103
+ def extract_vitals(text: str) -> str:
104
+ """Extract vital signs using pattern matching"""
105
+ vitals_found = []
106
+
107
+ patterns = {
108
+ 'BP': r'(?:BP|Blood Pressure)[:\s]+(\d{2,3}/\d{2,3})',
109
+ 'HR': r'(?:HR|Heart Rate|Pulse)[:\s]+(\d{2,3})(?:\s*bpm)?',
110
+ 'Temp': r'(?:Temp|Temperature)[:\s]+(\d{2,3}\.?\d*)(?:\s*[FCΒ°])?',
111
+ 'RR': r'(?:RR|Respiratory Rate|Resp)[:\s]+(\d{1,2})',
112
+ 'O2': r'(?:O2|Oxygen|SpO2)[:\s]+(\d{2,3})%?',
113
+ 'Weight': r'(?:Weight|Wt)[:\s]+(\d{2,3}\.?\d*)\s*(?:kg|lbs)?',
114
  }
115
 
116
+ for vital_name, pattern in patterns.items():
117
+ matches = re.findall(pattern, text, re.IGNORECASE)
118
+ if matches:
119
+ vitals_found.append(f"{vital_name}: {matches[0]}")
120
+
121
+ return ', '.join(vitals_found) if vitals_found else ""
122
+
123
+ def extract_all_sections(text: str) -> dict:
124
+ """Enhanced section extraction from clinical notes"""
125
+ sections = {}
126
  lines = text.split('\n')
127
  current_section = None
128
+ buffer = []
129
+
130
+ section_keywords = {
131
+ "Chief Complaint": ['chief complaint', 'cc:', 'presenting complaint', 'reason for visit', 'presenting concern'],
132
+ "HPI": ['history of present illness', 'hpi:', 'present illness', 'history:', 'clinical history'],
133
+ "Assessment": ['assessment:', 'impression:', 'diagnosis:', 'clinical impression', 'diagnoses:'],
134
+ "Vitals": ['vital signs', 'vitals:', 'bp:', 'blood pressure', 'temperature', 'pulse', 'hr:', 'physical exam'],
135
+ "Medications": ['medications:', 'meds:', 'current medications', 'home medications', 'prescriptions', 'drug list'],
136
+ "Plan": ['plan:', 'treatment plan', 'recommendations:', 'disposition:', 'instructions', 'management plan'],
137
+ "Discharge Summary": ['discharge', 'discharge summary', 'discharge plan', 'follow-up', 'disposition', 'discharge instructions']
138
+ }
139
 
140
  for line in lines:
141
+ line_clean = line.strip()
142
+ line_lower = line_clean.lower()
143
+
144
+ if not line_clean:
145
+ continue
146
 
147
+ # Check if this line is a section header
148
+ matched_section = None
149
+ for section_name, keywords in section_keywords.items():
150
+ if any(kw in line_lower for kw in keywords):
151
+ # Save previous section
152
+ if current_section and buffer:
153
+ sections[current_section] = ' '.join(buffer).strip()
154
+
155
+ matched_section = section_name
156
+ current_section = section_name
157
+ buffer = []
158
+
159
+ # Capture content on the same line after the header
160
+ for kw in keywords:
161
+ if kw in line_lower:
162
+ idx = line_lower.index(kw)
163
+ remainder = line_clean[idx + len(kw):].strip()
164
+ # Remove leading colon/dash
165
+ remainder = re.sub(r'^[:\-\s]+', '', remainder).strip()
166
+ if remainder and len(remainder) > 2:
167
+ buffer.append(remainder)
168
+ break
169
 
170
+ # If not a header and we have an active section, add to buffer
171
+ if not matched_section and current_section and line_clean:
172
+ # Avoid adding another section header accidentally
173
+ is_likely_header = any(kw in line_lower for keywords_list in section_keywords.values() for kw in keywords_list)
174
+ if not is_likely_header:
175
+ buffer.append(line_clean)
176
+
177
+ # Save final section
178
+ if current_section and buffer:
179
+ sections[current_section] = ' '.join(buffer).strip()
180
+
181
+ # Special extraction for vitals using regex
182
+ if "Vitals" not in sections or not sections["Vitals"]:
183
+ vitals = extract_vitals(text)
184
+ if vitals:
185
+ sections["Vitals"] = vitals
186
 
187
  return sections
188
 
189
+ def parse_ai_summary(ai_text: str) -> dict:
190
+ """Parse structured output from AI if it generated section-based content"""
191
+ sections = {}
192
+ current_section = None
193
+ buffer = []
194
 
195
+ lines = ai_text.split('\n')
 
196
 
197
+ for line in lines:
198
+ line_clean = line.strip()
199
+
200
+ # Check if line starts with a section name
201
+ section_starters = ['Chief Complaint:', 'HPI:', 'Assessment:', 'Vitals:',
202
+ 'Medications:', 'Plan:', 'Discharge Summary:']
203
+
204
+ matched = None
205
+ for starter in section_starters:
206
+ if line_clean.startswith(starter):
207
+ matched = starter
208
+ break
209
+
210
+ if matched:
211
+ # Save previous section
212
+ if current_section and buffer:
213
+ sections[current_section] = ' '.join(buffer).strip()
214
+
215
+ # Start new section
216
+ current_section = matched.replace(':', '').strip()
217
+ content = line_clean[len(matched):].strip()
218
+ buffer = [content] if content else []
219
+ elif current_section and line_clean:
220
+ buffer.append(line_clean)
221
+
222
+ # Save final section
223
+ if current_section and buffer:
224
+ sections[current_section] = ' '.join(buffer).strip()
225
+
226
+ return sections
227
+
228
+ # ==================== MAIN SUMMARIZATION FUNCTION ====================
229
+ def summarize_clinical_note(text: str, tokenizer, model, device) -> str:
230
+ """Generate structured clinical summary using T5 with proper section extraction"""
231
+
232
+ # Truncate if too long (T5 has token limits)
233
  max_input_length = 1024
234
+ if len(text) > max_input_length * 4:
235
  text = text[:max_input_length * 4]
236
 
237
+ # Create detailed prompt for T5
238
+ prompt = f"""Summarize this clinical documentation into a structured format with these exact sections:
239
+
240
+ Chief Complaint: State the patient's main presenting concern or reason for visit
241
+ HPI: Summarize the history of present illness including onset, duration, and progression
242
+ Assessment: List clinical findings, diagnoses, and impressions
243
+ Vitals: Extract all vital signs including BP, HR, Temperature, RR, O2 saturation
244
+ Medications: List all current medications with dosages and frequencies
245
+ Plan: Describe the treatment plan, recommendations, and next steps
246
+ Discharge Summary: Provide discharge status, instructions, and follow-up plans
247
 
248
  Clinical Note:
249
  {text}
 
260
 
261
  inputs = {k: v.to(device) for k, v in inputs.items()}
262
 
263
+ # Generate with optimal parameters to prevent repetition
264
  with torch.no_grad():
265
  outputs = model.generate(
266
  inputs['input_ids'],
267
+ max_new_tokens=650,
268
+ min_length=200,
269
  num_beams=4,
270
+ temperature=0.8,
271
  do_sample=False,
272
  early_stopping=True,
273
+ no_repeat_ngram_size=3,
274
+ repetition_penalty=2.5,
275
+ length_penalty=1.0,
276
  pad_token_id=tokenizer.pad_token_id,
277
  eos_token_id=tokenizer.eos_token_id
278
  )
279
 
280
+ ai_summary = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
281
+
282
+ # Extract sections from original text using keyword matching
283
+ sections_content = extract_all_sections(text)
284
 
285
+ # Parse AI output for any additional structured content
286
+ ai_sections = parse_ai_summary(ai_summary)
287
+
288
+ # Merge: prioritize extracted content, fallback to AI, then "Not documented"
289
+ final_sections = {}
290
+ section_names = ["Chief Complaint", "HPI", "Assessment", "Vitals", "Medications", "Plan", "Discharge Summary"]
291
+
292
+ for section in section_names:
293
+ # Try extracted content first
294
+ content = sections_content.get(section, "").strip()
295
 
296
+ # If no content or too short, try AI summary
297
+ if not content or len(content) < 15:
298
+ content = ai_sections.get(section, "").strip()
299
+
300
+ # If still no content and AI generated something generic, use it
301
+ if not content or len(content) < 10:
302
+ # Check if AI summary contains relevant info in unstructured format
303
+ if section.lower() in ai_summary.lower():
304
+ # Extract sentences mentioning this section
305
+ sentences = ai_summary.split('.')
306
+ relevant = [s.strip() for s in sentences if section.lower() in s.lower()]
307
+ if relevant:
308
+ content = '. '.join(relevant) + '.'
309
+
310
+ # Final fallback
311
+ if not content or len(content) < 10:
312
+ content = "Not documented"
313
+
314
+ # Clean up content
315
+ content = content.replace(' ', ' ').strip()
316
+ final_sections[section] = content
317
 
318
+ # Format output with proper markdown
319
+ formatted_output = ""
320
+ for section in section_names:
321
+ formatted_output += f"**{section}:**\n{final_sections[section]}\n\n"
322
 
323
+ return formatted_output
324
 
325
+ # ==================== QUALITY VALIDATION ====================
326
  def validate_summary(summary: str, original_text: str) -> dict:
327
+ """Assess summary quality with detailed metrics"""
328
  score = 100
329
  warnings = []
330
 
331
+ required_sections = ["Chief Complaint", "HPI", "Assessment", "Vitals", "Medications", "Plan", "Discharge Summary"]
332
+
333
+ # Count present sections
334
+ present_count = 0
335
+ for sec in required_sections:
336
+ section_content = ""
337
+ if sec + ":" in summary:
338
+ # Extract content for this section
339
+ lines = summary.split('\n')
340
+ in_section = False
341
+ for line in lines:
342
+ if line.startswith(f"**{sec}:**"):
343
+ in_section = True
344
+ continue
345
+ if in_section:
346
+ if line.startswith("**"):
347
+ break
348
+ section_content += line
349
+
350
+ if "not documented" not in section_content.lower() and len(section_content.strip()) > 10:
351
+ present_count += 1
352
 
353
+ missing_count = len(required_sections) - present_count
 
 
354
 
355
+ if missing_count > 0:
356
+ score -= missing_count * 12
357
+ warnings.append(f"{missing_count} of 7 sections incomplete")
358
+
359
+ # Check for medical content indicators
360
+ medical_patterns = [
361
+ r'\d+\s*mg',
362
+ r'\d+/\d+\s*mmHg',
363
+ r'\d+\s*bpm',
364
+ r'\d+\.?\d*\s*[FCΒ°]',
365
+ r'\d+%',
366
+ ]
367
+ medical_content_found = any(re.search(pattern, summary, re.I) for pattern in medical_patterns)
368
+ if medical_content_found:
369
  score += 10
370
+ else:
371
+ warnings.append("Limited quantitative clinical data")
372
 
373
+ # Check for repetition issues
374
  words = summary.lower().split()
375
+ if len(words) > 20:
376
  unique_ratio = len(set(words)) / len(words)
377
+ if unique_ratio < 0.35:
378
+ score -= 30
379
+ warnings.append("High repetition detected - summary quality poor")
380
+
381
+ # Check overall length
382
+ if len(summary) < 150:
383
+ score -= 15
384
+ warnings.append("Summary too brief")
385
+ elif len(summary) > 2000:
386
+ score -= 5
387
+ warnings.append("Summary may be overly verbose")
388
+
389
+ # Check for key clinical terms
390
+ clinical_terms = ['patient', 'diagnosis', 'treatment', 'plan', 'medication', 'assessment']
391
+ terms_found = sum(1 for term in clinical_terms if term in summary.lower())
392
+ if terms_found < 3:
393
+ score -= 10
394
+ warnings.append("Limited clinical terminology")
395
 
396
  score = max(0, min(100, score))
 
397
 
398
+ if score >= 90:
399
+ status = "EXCELLENT"
400
+ elif score >= 75:
401
+ status = "GOOD"
402
+ elif score >= 60:
403
+ status = "FAIR"
404
+ else:
405
+ status = "POOR"
406
+
407
+ return {
408
+ "quality_score": score,
409
+ "status": status,
410
+ "warnings": warnings,
411
+ "sections_present": present_count,
412
+ "sections_total": len(required_sections)
413
+ }
414
 
415
+ # ==================== SESSION STATE ====================
416
  if 'deid_text' not in st.session_state:
417
  st.session_state.deid_text = ""
418
  if 'original_text' not in st.session_state:
 
422
  if 'validation' not in st.session_state:
423
  st.session_state.validation = None
424
 
425
+ # ==================== UI TABS ====================
426
  tab1, tab2 = st.tabs(["πŸ“ De-Identify Note", "✨ Generate Summary"])
427
 
428
  with tab1:
429
  st.header("Step 1: De-identify Clinical Note")
430
+ st.markdown("Upload or paste a clinical note to remove PHI (Protected Health Information)")
431
 
432
  uploaded = st.file_uploader("Upload clinical note (.txt)", type=["txt"])
433
+ input_text = st.text_area(
434
+ "Or paste clinical note here:",
435
+ height=300,
436
+ placeholder="Paste clinical documentation here...\n\nExample:\nChief Complaint: Chest pain\nHPI: 72-year-old male presents with...\nVitals: BP 140/90, HR 88..."
437
+ )
438
 
439
  note_text = ""
440
  if uploaded:
 
457
  with open(f"{secure_dir}/session_note.spanmap.enc", "wb") as f:
458
  f.write(result["encrypted_span_map"])
459
 
460
+ st.success("βœ… De-identified with encrypted audit trail saved")
461
  except Exception as e:
462
+ st.warning(f"⚠ Using regex-based de-identification: {str(e)[:100]}")
463
  deid_text = fallback_deid(note_text)
464
  else:
465
  deid_text = fallback_deid(note_text)
466
+ st.info("β„Ή Using regex-based de-identification")
467
 
468
  st.session_state.deid_text = deid_text
469
+ st.success(f"βœ… Processed **{len(deid_text)}** characters (PHI redacted)")
470
  else:
471
  st.warning("⚠ Please enter or upload a clinical note")
472
 
473
  if st.session_state.deid_text:
474
+ with st.expander("πŸ“„ Preview De-identified Text", expanded=False):
475
  st.text_area("", st.session_state.deid_text, height=250, disabled=True, key="preview_deid")
476
 
477
  with tab2:
478
  st.header("Step 2: Generate Clinical Summary")
479
+ st.markdown("AI-powered structured summarization with quality assessment")
480
 
481
  if not st.session_state.deid_text:
482
+ st.warning("⚠ Please de-identify a note first in **Tab 1**")
483
  else:
484
+ st.info(f"βœ… Ready to summarize: **{len(st.session_state.deid_text)}** characters")
485
 
486
  if st.button("πŸš€ Generate Summary", type="primary"):
487
+ with st.spinner("⏳ Generating structured summary (30-60 seconds)..."):
488
  try:
489
  summary = summarize_clinical_note(
490
  st.session_state.deid_text,
 
494
  )
495
 
496
  st.session_state.summary = summary
497
+ st.session_state.validation = validate_summary(
498
+ summary,
499
+ st.session_state.deid_text
500
+ )
501
+ st.success("βœ… Summary generated successfully!")
502
 
503
  except Exception as e:
504
+ st.error(f"❌ Summarization failed: {str(e)}")
505
+ st.exception(e)
506
  st.session_state.summary = None
507
 
508
  if st.session_state.summary:
509
+ st.markdown("---")
510
+
511
+ col1, col2 = st.columns([2.5, 1])
512
 
513
  with col1:
514
  st.subheader("πŸ“‹ Structured Clinical Summary")
515
  st.markdown(st.session_state.summary)
516
 
517
  with col2:
518
+ st.subheader("πŸ“Š Quality Assessment")
519
  val = st.session_state.validation
520
 
521
+ color_map = {
522
+ "EXCELLENT": "🟒",
523
+ "GOOD": "πŸ”΅",
524
+ "FAIR": "🟑",
525
+ "POOR": "πŸ”΄"
526
+ }
527
  status_color = color_map.get(val.get("status", ""), "βšͺ")
528
 
529
  st.markdown(f"### {status_color} {val.get('status', 'N/A')}")
530
  st.metric("Quality Score", f"{val.get('quality_score', 0)}/100")
531
+ st.metric(
532
+ "Sections Complete",
533
+ f"{val.get('sections_present', 0)}/{val.get('sections_total', 7)}"
534
+ )
535
 
536
  if val.get("warnings"):
537
+ with st.expander("⚠ Quality Warnings", expanded=True):
538
+ for w in val["warnings"]:
539
+ st.warning(f"β€’ {w}")
540
+
541
+ st.markdown("---")
542
 
543
+ # Download and reset buttons
544
+ col_a, col_b, col_c = st.columns([2, 2, 1])
545
  with col_a:
546
  st.download_button(
547
  "πŸ’Ύ Download Summary",
548
  st.session_state.summary,
549
  "clinical_summary.txt",
550
+ mime="text/plain",
551
  type="secondary"
552
  )
553
  with col_b:
554
+ st.download_button(
555
+ "πŸ’Ύ Download De-identified Note",
556
+ st.session_state.deid_text,
557
+ "deidentified_note.txt",
558
+ mime="text/plain",
559
+ type="secondary"
560
+ )
561
+ with col_c:
562
+ if st.button("πŸ”„ Reset"):
563
  st.session_state.deid_text = ""
564
  st.session_state.original_text = ""
565
  st.session_state.summary = None
566
  st.session_state.validation = None
567
  st.rerun()
568
 
569
+ # ==================== FOOTER ====================
570
  st.markdown("---")
571
+ st.caption("πŸ₯ **HIPAA-Compliant Clinical Summarizer** | Portfolio Demo | Powered by Flan-T5 & Presidio")
572
+ st.caption("⚠ For demonstration purposes only - not for clinical use")