Updating source code to support batching

#1
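Adds a detect_multi() entry point that encodes the image once, clones its KV caches across the batch, and decodes all labels in parallel instead of calling detect() once per label. Intended call pattern (a sketch; the loading snippet assumes the standard moondream2 setup via transformers and is not part of this diff):

    from PIL import Image
    from transformers import AutoModelForCausalLM

    # Assumed standard loading path for moondream2; adjust to your checkout.
    model = AutoModelForCausalLM.from_pretrained(
        "vikhyatk/moondream2", trust_remote_code=True
    )
    image = Image.open("street.jpg")  # hypothetical input

    # One image encode + one batched prefill/decode covers every label.
    result = model.detect_multi(image, ["person", "car", "traffic light"])
    for label, boxes in result["objects"].items():
        print(label, len(boxes), "boxes")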
Files changed (1)
  1. moondream.py +225 -0
moondream.py CHANGED
@@ -828,6 +828,231 @@ class MoondreamModel(nn.Module):

        return {"points": objects}

+
+    # === BEGIN: Batched multi-label detection additions ===
+    def _load_encoded_image_batched(self, encoded_image, batch_size: int):
+        """
+        Clone single-image KV caches into a batch-B cache so we can decode B labels in parallel.
+        """
+        for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
+            T = k.size(2)
+            # Allocate new [B, n_kv_heads, T_max, head_dim] caches if needed
+            if b.kv_cache.k_cache.size(0) != batch_size:
+                new_k = b.kv_cache.k_cache.new_zeros((batch_size,) + b.kv_cache.k_cache.shape[1:])
+                new_v = b.kv_cache.v_cache.new_zeros((batch_size,) + b.kv_cache.v_cache.shape[1:])
+                b.kv_cache.k_cache = new_k
+                b.kv_cache.v_cache = new_v
+            # Copy the current prefix from the encoded image into all B rows
+            b.kv_cache.k_cache[:, :, :T, :] = k.expand(batch_size, -1, -1, -1)
+            b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
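+            # NOTE: expand() only creates broadcast views; the indexed copies
+            # above are what materialize the prefix into all B rows, so
+            # KV-cache memory grows linearly with batch size.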
+
+    def _prefill_prompt_batched(self, labels, pos: int, lora=None, temperature: float = 0.0, top_p: float = 0.0):
+        """
+        Build detect prompts for many labels, pad them to the same length, prefill once
+        as a batch, then return (last hidden per row, next token per row, pos per row).
+        """
+        import torch
+        from .text import text_encoder, lm_head
+
+        tpl = self.config.tokenizer.templates["detect"]
+        if tpl is None:
+            raise NotImplementedError("Model does not support object detection (no detect template).")
+
+        rows, lens = [], []
+        for lab in labels:
+            ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
+            rows.append(torch.tensor(ids, device=self.device, dtype=torch.long))
+            lens.append(len(ids))
+        B = len(rows)
+        T = max(lens)
+        eos = self.config.tokenizer.eos_id
+
+        # Pad with eos so we can prefill as a single batch
+        prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
+        for i, ids in enumerate(rows):
+            prompt_ids[i, : ids.numel()] = ids
+
+        # Embed & prefill once
+        prompt_emb = text_encoder(prompt_ids, self.text)  # (B, T, C)
+        torch._dynamo.mark_dynamic(prompt_emb, 1)  # allow variable T
+
+        mask = self.attn_mask[:, :, pos : pos + T, :].expand(B, -1, -1, -1).contiguous()
+        pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)
+
+        hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B, T, C)
+        logits_BTV = lm_head(hidden_BTC, self.text)  # (B, T, V)
+
+        # Take the last *real* token per row (ignore padding positions)
+        idx = (torch.tensor(lens, device=self.device, dtype=torch.long) - 1).clamp_min(0)
+        last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :]  # (B, 1, C)
+        last_logits = logits_BTV[torch.arange(B, device=self.device), idx]  # (B, V)
+
+        if temperature == 0.0:
+            next_token = last_logits.argmax(dim=-1, keepdim=True)  # (B, 1)
+        else:
+            probs = torch.softmax(last_logits / temperature, dim=-1)
+            probs = self._apply_top_p(probs, top_p)
+            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
+
+        # Each row resumes decoding right after its own (unpadded) prompt
+        pos_vec = pos + torch.tensor(lens, device=self.device, dtype=torch.long)
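+        # NOTE: rows shorter than T leave padded-eos KV entries at positions
+        # pos+len..pos+T-1. The causal prefill mask keeps each row's last real
+        # token from attending to them, and the batched decode loop only
+        # unmasks (and progressively overwrites) positions from pos_vec onward.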
+
+        return last_hidden, next_token, pos_vec  # (B, 1, C), (B, 1), (B,)
+
+    def _generate_points_batched(self, hidden, next_token, pos_vec, include_size: bool = True, max_objects: int = 50, lora=None):
+        """
+        Vectorized version of _generate_points() that decodes x -> y -> size -> next-token
+        for all rows in the batch simultaneously.
+        Returns: list-of-lists of dicts, length B.
+        """
+        import torch
+        from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
+
+        B = hidden.size(0)
+        device = self.device
+        out = [[] for _ in range(B)]
+        eos_id = self.config.tokenizer.eos_id
+
+        # Per-row attention/masking state
+        max_ctx = self.config.text.max_context
+        mask = torch.zeros(B, 1, max_ctx, device=device, dtype=torch.bool)
+        for i in range(B):
+            mask[i, :, : int(pos_vec[i].item())] = 1
+        pos_ids = pos_vec.clone()
+
+        alive = torch.ones(B, dtype=torch.bool, device=device)
+        counts = torch.zeros(B, dtype=torch.int32, device=device)
+
+        with torch.inference_mode():
+            while alive.any() and (counts < max_objects).any():
+                # --- x coordinate (from current hidden) ---
+                x_logits = decode_coordinate(hidden, self.region)  # (B, 1, 1024) or (B, 1024)
+                if x_logits.dim() == 3:
+                    x_logits = x_logits.squeeze(1)  # (B, 1024)
+                x_bin = x_logits.argmax(dim=-1).to(torch.float32)  # (B,)
+                x_center = x_bin / float(x_logits.size(-1))  # normalize to [0, 1]
+                x_emb = encode_coordinate(x_center.to(dtype=x_logits.dtype), self.region).unsqueeze(1)  # (B, 1, C)
+
+                # step: decode to get hidden for y
+                for i in range(B):
+                    if alive[i]:
+                        mask[i, :, pos_ids[i]] = 1
+                logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
+                pos_ids = pos_ids + alive.to(torch.long)
+
+                # --- y coordinate ---
+                y_logits = decode_coordinate(hidden, self.region)
+                if y_logits.dim() == 3:
+                    y_logits = y_logits.squeeze(1)  # (B, 1024)
+                y_bin = y_logits.argmax(dim=-1).to(torch.float32)
+                y_center = y_bin / float(y_logits.size(-1))
+                y_emb = encode_coordinate(y_center.to(dtype=y_logits.dtype), self.region).unsqueeze(1)
+
+                # step: decode to get hidden for size (or eos)
+                for i in range(B):
+                    if alive[i]:
+                        mask[i, :, pos_ids[i]] = 1
+                logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
+                pos_ids = pos_ids + alive.to(torch.long)
+
+                if include_size:
+                    # --- size logits (batched) ---
+                    size_logits = decode_size(hidden, self.region)  # [w_logits, h_logits], each (B, 1, 1024)
+                    w_logits, h_logits = size_logits[0].squeeze(1), size_logits[1].squeeze(1)  # (B, 1024) each
+                    w_bin = w_logits.argmax(dim=-1).to(torch.float32)
+                    h_bin = h_logits.argmax(dim=-1).to(torch.float32)
+                    # Convert from log-scale bin to size in [0, 1]
+                    w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
+                    h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
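+                    # e.g. bin 1023 -> 2**0 = 1.0 (full image extent),
+                    #      bin 0    -> 2**-10 ~= 0.001 of the image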
+                    size_emb = encode_size(torch.stack([w, h], dim=-1), self.region).unsqueeze(1)  # (B, 2) -> (B, 1, C), assuming encode_size maps over the last dim
+
+                    # Commit boxes for alive rows
+                    for i in range(B):
+                        if not alive[i]:
+                            continue
+                        out[i].append({
+                            "x_min": (x_center[i] - w[i] / 2).item(),
+                            "y_min": (y_center[i] - h[i] / 2).item(),
+                            "x_max": (x_center[i] + w[i] / 2).item(),
+                            "y_max": (y_center[i] + h[i] / 2).item(),
+                        })
+
+                    # step: decode "next token" to decide continuation
+                    for i in range(B):
+                        if alive[i]:
+                            mask[i, :, pos_ids[i]] = 1
+                    logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
+                    pos_ids = pos_ids + alive.to(torch.long)
+                    next_tok = logits.argmax(dim=-1).squeeze(-1)  # (B,)
+                else:
+                    # Points mode (no size): the y-step decode above already
+                    # produced the continuation logits, so decoding y_emb a
+                    # second time would write it into the KV cache twice.
+                    # Reuse those logits instead.
+                    for i in range(B):
+                        if not alive[i]:
+                            continue
+                        out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
+                    next_tok = logits.argmax(dim=-1).squeeze(-1)
+
+                # Mark finished rows and count committed objects
+                finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
+                counts = counts + (~finished_now & alive).to(counts.dtype)
+                alive &= ~finished_now
+
+        return out
+
+    def detect_multi(self, image, objects, settings=None):
+        """
+        Parallel multi-label detection.
+        Args:
+            image: PIL.Image or EncodedImage
+            objects: list[str], e.g. ["person", "car"]
+            settings: Optional[ObjectSamplingSettings]; honors "max_objects" and "variant"
+        Returns:
+            {"objects": {label: [box_dict, ...]}}
+        """
+        if self.config.tokenizer.templates["detect"] is None:
+            raise NotImplementedError("Model does not support object detection.")
+        settings = settings or {}
+
+        # Encode the image once; every label reuses its KV caches
+        image = self.encode_image(image, settings)
+        B = len(objects)
+        self._load_encoded_image_batched(image, B)
+
+        # Optional LoRA variant (same as detect())
+        lora = None
+        if "variant" in settings:
+            from .lora import variant_state_dict
+            lora = variant_state_dict(settings["variant"], device=self.device)
+
+        # Prefill all prompts at once
+        last_hidden, next_token, pos_vec = self._prefill_prompt_batched(
+            objects, image.pos, lora=lora, temperature=0.0, top_p=0.0
+        )
+
+        # Batched decode loop
+        max_objects = settings.get("max_objects", 50)
+        det_lists = self._generate_points_batched(
+            last_hidden, next_token, pos_vec,
+            include_size=True, max_objects=max_objects, lora=lora,
+        )
+
+        # Map results back to their labels and tag each box
+        res = {}
+        for lab, lst in zip(objects, det_lists):
+            for d in lst:
+                d["label"] = lab
+            res[lab] = lst
+        return {"objects": res}
+    # === END: Batched multi-label detection additions ===
    def _detect_gaze(
        self,
        image: EncodedImage,
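For review, a quick parity check against the existing single-label detect() path (a sketch; assumes model and image as in the example above, and that greedy decoding makes both paths deterministic):

    labels = ["person", "car"]
    batched = model.detect_multi(image, labels)["objects"]
    for label in labels:
        single = model.detect(image, label)["objects"]
        # Both paths decode greedily, so per-label counts should match.
        assert len(single) == len(batched[label]), (label, len(single), len(batched[label]))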