borso271 committed on
Commit
7466f1c
·
1 Parent(s): 2a97c1d

Add name-based deduplication to prevent duplicate label names

Browse files
Files changed (1) hide show
  1. handler.py +23 -2
handler.py CHANGED
@@ -152,12 +152,32 @@ class EndpointHandler:
152
  if not new_items:
153
  return 0
154
  with self._lock:
155
- known = set(getattr(self, "class_ids", []))
156
- batch = [it for it in new_items if int(it.get("id")) not in known]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  if not batch:
158
  return 0
 
 
159
  prompts = [it["prompt"] for it in batch]
160
  feats = self._encode_text(prompts).detach().cpu().to(torch.float32)
 
 
161
  if not hasattr(self, "text_features_cpu"):
162
  self.text_features_cpu = feats.contiguous()
163
  self.class_ids = [int(it["id"]) for it in batch]
@@ -166,6 +186,7 @@ class EndpointHandler:
166
  self.text_features_cpu = torch.cat([self.text_features_cpu, feats], dim=0).contiguous()
167
  self.class_ids.extend([int(it["id"]) for it in batch])
168
  self.class_names.extend([it["name"] for it in batch])
 
169
  self._to_device()
170
  return len(batch)
171
 
 
152
  if not new_items:
153
  return 0
154
  with self._lock:
155
+ # Get ALL existing IDs and names from current state
156
+ known_ids = set(getattr(self, "class_ids", []))
157
+ known_names = set(getattr(self, "class_names", []))
158
+
159
+ # Filter items, checking against both ID and name
160
+ batch = []
161
+ for it in new_items:
162
+ item_id = int(it.get("id"))
163
+ item_name = it.get("name")
164
+
165
+ # Skip if either ID or name already exists
166
+ if item_id in known_ids:
167
+ continue # Skip duplicate ID
168
+ elif item_name in known_names:
169
+ continue # Skip duplicate name
170
+ else:
171
+ batch.append(it)
172
+
173
  if not batch:
174
  return 0
175
+
176
+ # Process the filtered batch
177
  prompts = [it["prompt"] for it in batch]
178
  feats = self._encode_text(prompts).detach().cpu().to(torch.float32)
179
+
180
+ # Update the persistent state
181
  if not hasattr(self, "text_features_cpu"):
182
  self.text_features_cpu = feats.contiguous()
183
  self.class_ids = [int(it["id"]) for it in batch]
 
186
  self.text_features_cpu = torch.cat([self.text_features_cpu, feats], dim=0).contiguous()
187
  self.class_ids.extend([int(it["id"]) for it in batch])
188
  self.class_names.extend([it["name"] for it in batch])
189
+
190
  self._to_device()
191
  return len(batch)
192