moseleydev committed on
Commit
237e309
·
verified ·
1 Parent(s): a30a7a5

loaded custom model

Browse files
Files changed (1) hide show
  1. main.py +55 -32
main.py CHANGED
@@ -2,16 +2,16 @@ from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  from transformers import AutoTokenizer, AutoModel
5
- from sklearn.cluster import KMeans
6
  import torch
7
- import numpy as np
8
  import spacy
9
  import spacy.cli
10
  import time
 
11
 
12
  app = FastAPI(
13
  title="Clinical Extractive Summarization",
14
- description="SciBERT + KMeans NLP Engine for Medical Reports"
15
  )
16
 
17
  app.add_middleware(
@@ -21,9 +21,25 @@ app.add_middleware(
21
  allow_headers=["*"],
22
  )
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  tokenizer = None
25
  model = None
26
  nlp = None
 
27
 
28
  class ReportRequest(BaseModel):
29
  text: str
@@ -41,14 +57,28 @@ def health_check():
41
  def summarize_medical_report(request: ReportRequest):
42
  start_time = time.time()
43
 
44
- global tokenizer, model, nlp
45
  if model is None:
46
- print("Initializing SciBERT and SpaCy... This takes a moment.")
47
 
48
- # Load SciBERT
49
  model_name = "allenai/scibert_scivocab_uncased"
50
  tokenizer = AutoTokenizer.from_pretrained(model_name)
51
- model = AutoModel.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  try:
54
  nlp = spacy.load("en_core_web_sm")
@@ -61,36 +91,29 @@ def summarize_medical_report(request: ReportRequest):
61
 
62
  # 1. Safely split text into sentences using SpaCy NLP
63
  doc = nlp(request.text)
64
- sentences = [sent.text.strip() for sent in doc.sents if len(sent.text.strip()) > 5]
 
65
 
66
  # Edge case: Report is too short to summarize
67
  if len(sentences) <= request.num_sentences:
68
  return {"summary": request.text, "metadata": {"status": "too_short"}}
69
 
70
- # 2. Get embeddings for each sentence using SciBERT
71
- embeddings = []
72
- for sent in sentences:
73
- inputs = tokenizer(sent, return_tensors="pt", truncation=True, padding=True, max_length=512)
74
- with torch.no_grad():
75
- output = model(**inputs)
76
-
77
- # Extract the [CLS] token representation
78
- cls_embedding = output.last_hidden_state[0][0].numpy()
79
- embeddings.append(cls_embedding)
80
 
81
- # 3. Use KMeans to cluster the embeddings and find the most central sentences
82
- # n_init='auto' suppresses sklearn warnings
83
- kmeans = KMeans(n_clusters=request.num_sentences, n_init='auto', random_state=42).fit(embeddings)
84
-
85
- avg = []
86
- for i in range(request.num_sentences):
87
- # Find the sentence closest to the cluster centroid
88
- idx = np.argmin(np.linalg.norm(embeddings - kmeans.cluster_centers_[i], axis=1))
89
- avg.append(idx)
90
 
91
- # 4. Sort indices chronologically to maintain original report flow
92
- avg = sorted(list(set(avg)))
93
- final_summary = " ".join([sentences[i] for i in avg])
94
 
95
  process_time = round((time.time() - start_time) * 1000, 2)
96
 
@@ -99,7 +122,7 @@ def summarize_medical_report(request: ReportRequest):
99
  "metadata": {
100
  "processing_time_ms": process_time,
101
  "original_length": len(sentences),
102
- "summary_length": len(avg),
103
- "engine": "SciBERT + KMeans"
104
  }
105
  }
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  from transformers import AutoTokenizer, AutoModel
 
5
  import torch
6
+ import torch.nn as nn
7
  import spacy
8
  import spacy.cli
9
  import time
10
+ import os
11
 
12
  app = FastAPI(
13
  title="Clinical Extractive Summarization",
14
+ description="SciBERT + BERTsum Fine-Tuned Engine for Medical Reports"
15
  )
16
 
17
  app.add_middleware(
 
21
  allow_headers=["*"],
22
  )
23
 
24
+ # --- ARCHITECTURE DEFINITION ---
25
class BioExtractor(nn.Module):
    """Sentence-salience scorer: a pretrained encoder plus a one-unit head.

    The encoder is loaded with ``AutoModel.from_pretrained(model_name)``.
    The hidden state at sequence position 0 of each input is passed through
    a single linear unit and a sigmoid, giving one score in (0, 1) per
    input sequence.

    Args:
        model_name: Identifier or path handed to
            ``AutoModel.from_pretrained``.
    """

    def __init__(self, model_name):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # Size the head from the loaded encoder instead of hard-coding 768,
        # so the layer always matches the encoder's hidden width. For a
        # 768-wide encoder the parameter shapes (and therefore saved
        # state_dicts) are unchanged.
        hidden_size = self.bert.config.hidden_size
        # The classification layer that predicts sentence salience.
        self.classifier = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return a salience score for every sequence in the batch.

        Args:
            input_ids: Token-id tensor forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            Tensor with one sigmoid-activated score per input sequence
            (last dimension of size 1).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state of the first token position as the
        # sequence-level representation.
        cls_output = outputs.last_hidden_state[:, 0, :]
        return self.sigmoid(self.classifier(cls_output))
37
+
38
+ # Global variables to cache models in memory
39
  tokenizer = None
40
  model = None
41
  nlp = None
42
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
43
 
44
  class ReportRequest(BaseModel):
45
  text: str
 
57
  def summarize_medical_report(request: ReportRequest):
58
  start_time = time.time()
59
 
60
+ global tokenizer, model, nlp, device
61
  if model is None:
62
+ print("Initializing Fine-Tuned SciBERT and SpaCy...")
63
 
64
+ # Load the base tokenizer
65
  model_name = "allenai/scibert_scivocab_uncased"
66
  tokenizer = AutoTokenizer.from_pretrained(model_name)
67
+
68
+ # Instantiate your custom architecture
69
+ model = BioExtractor(model_name)
70
+
71
+ # Load the trained weights from the uploaded .pt file
72
+ model_path = "med_summarizer_trained.pt"
73
+ if os.path.exists(model_path):
74
+ print(f"Loading fine-tuned weights from {model_path}...")
75
+ # map_location ensures it works even if Hugging Face runs on a CPU space
76
+ model.load_state_dict(torch.load(model_path, map_location=device))
77
+ else:
78
+ print(f"WARNING: {model_path} not found! Upload it to your Space.")
79
+
80
+ model.to(device)
81
+ model.eval() # Lock the model for inference
82
 
83
  try:
84
  nlp = spacy.load("en_core_web_sm")
 
91
 
92
  # 1. Safely split text into sentences using SpaCy NLP
93
  doc = nlp(request.text)
94
+ # Filter out extremely short strings just like your Colab script
95
+ sentences = [sent.text.strip() for sent in doc.sents if len(sent.text.strip()) > 10]
96
 
97
  # Edge case: Report is too short to summarize
98
  if len(sentences) <= request.num_sentences:
99
  return {"summary": request.text, "metadata": {"status": "too_short"}}
100
 
101
+ # 2. Get probability scores for each sentence using the fine-tuned model
102
+ scores = []
103
+ with torch.no_grad():
104
+ for sent in sentences:
105
+ inputs = tokenizer(sent, return_tensors="pt", truncation=True, padding='max_length', max_length=128).to(device)
106
+ output = model(inputs['input_ids'], inputs['attention_mask'])
107
+ scores.append(output.item())
 
 
 
108
 
109
+ # 3. Rank and select the top N sentences
110
+ # Enumerate keeps track of the original sentence index (e.g., (index, score))
111
+ scored_sentences = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
112
+ top_indices = [idx for idx, score in scored_sentences[:request.num_sentences]]
 
 
 
 
 
113
 
114
+ # 4. Sort indices chronologically to maintain original report flow
115
+ top_indices_sorted = sorted(top_indices)
116
+ final_summary = " ".join([sentences[i] for i in top_indices_sorted])
117
 
118
  process_time = round((time.time() - start_time) * 1000, 2)
119
 
 
122
  "metadata": {
123
  "processing_time_ms": process_time,
124
  "original_length": len(sentences),
125
+ "summary_length": len(top_indices_sorted),
126
+ "engine": "SciBERT + BERTsum Fine-Tuned"
127
  }
128
  }