Upload 2 files
Browse files- media.py +1 -4
- modeling_vila.py +106 -11
media.py
CHANGED
|
@@ -11,7 +11,7 @@ import requests
|
|
| 11 |
from transformers import PretrainedConfig
|
| 12 |
|
| 13 |
# from llava.constants import MEDIA_TOKENS
|
| 14 |
-
|
| 15 |
# from llava.utils import make_list
|
| 16 |
# from llava.utils.logging import logger
|
| 17 |
|
|
@@ -31,9 +31,6 @@ class Image(File):
|
|
| 31 |
pass
|
| 32 |
|
| 33 |
|
| 34 |
-
class Video(File):
    # Marker subclass distinguishing video media from images; adds no behavior
    # beyond what File provides (File is defined earlier in media.py).
    pass
|
| 36 |
-
|
| 37 |
def make_list(obj: Any) -> List:
|
| 38 |
return obj if isinstance(obj, list) else [obj]
|
| 39 |
|
|
|
|
| 11 |
from transformers import PretrainedConfig
|
| 12 |
|
| 13 |
# from llava.constants import MEDIA_TOKENS
|
| 14 |
+
from llava.media import Image, Video
|
| 15 |
# from llava.utils import make_list
|
| 16 |
# from llava.utils.logging import logger
|
| 17 |
|
|
|
|
| 31 |
pass
|
| 32 |
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
def make_list(obj: Any) -> List:
    """Wrap *obj* in a single-element list unless it already is a list."""
    if isinstance(obj, list):
        return obj
    return [obj]
|
| 36 |
|
modeling_vila.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import copy
|
| 2 |
import json
|
| 3 |
import logging
|
|
@@ -142,14 +143,97 @@ class VILAPretrainedModel(PreTrainedModel):
|
|
| 142 |
self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
|
| 143 |
), "At least one of the components must be instantiated."
|
| 144 |
|
| 145 |
-
|
| 146 |
-
|
| 147 |
@classmethod
|
| 148 |
-
def
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
@classmethod
|
| 155 |
def from_pretrained(
|
|
@@ -202,6 +286,16 @@ class VILAPretrainedModel(PreTrainedModel):
|
|
| 202 |
if getattr(self.config, "mm_projector_cfg", None) is None:
|
| 203 |
self.config.mm_projector_cfg = self.mm_projector.config
|
| 204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
def get_vision_tower(self):
|
| 206 |
vision_tower = getattr(self, "vision_tower", None)
|
| 207 |
if type(vision_tower) is list:
|
|
@@ -408,7 +502,7 @@ class VILAForCasualLM(VILAPretrainedModel):
|
|
| 408 |
if self.training:
|
| 409 |
# Gather metainfo of media objects from all ranks
|
| 410 |
info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
|
| 411 |
-
infos = list(chain(
|
| 412 |
|
| 413 |
# The entire batch does not contain any media objects of this type.
|
| 414 |
if not infos:
|
|
@@ -750,7 +844,7 @@ class VILAForCasualLM(VILAPretrainedModel):
|
|
| 750 |
if images is not None:
|
| 751 |
if media is not None:
|
| 752 |
raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
|
| 753 |
-
|
| 754 |
media = {"image": images}
|
| 755 |
|
| 756 |
if media_config is None:
|
|
@@ -845,7 +939,7 @@ class VILAForCasualLM(VILAPretrainedModel):
|
|
| 845 |
images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
|
| 846 |
media[name] = [image for image in images]
|
| 847 |
elif name == "video":
|
| 848 |
-
if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
|
| 849 |
media[name] = [
|
| 850 |
process_images(
|
| 851 |
images,
|
|
@@ -856,7 +950,7 @@ class VILAForCasualLM(VILAPretrainedModel):
|
|
| 856 |
).half()
|
| 857 |
for images in media[name]
|
| 858 |
]
|
| 859 |
-
elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
|
| 860 |
self.config.image_processor = self.vision_tower.image_processor
|
| 861 |
if type(self.config.s2_scales) is str:
|
| 862 |
self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
|
|
@@ -930,3 +1024,4 @@ class VILAForCasualLM(VILAPretrainedModel):
|
|
| 930 |
if generation_config.eos_token_id is None:
|
| 931 |
generation_config.eos_token_id = self.tokenizer.eos_token_id
|
| 932 |
return generation_config
|
|
|
|
|
|
| 1 |
+
import shutil
|
| 2 |
import copy
|
| 3 |
import json
|
| 4 |
import logging
|
|
|
|
| 143 |
self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
|
| 144 |
), "At least one of the components must be instantiated."
|
| 145 |
|
|
|
|
|
|
|
| 146 |
@classmethod
def convert_vila_dev_ckpt_to_remote(cls, model_path: str, output_dir: str = None, *model_args, **kwargs):
    """Convert a dev-format VILA checkpoint into an HF remote-code layout.

    If *model_path* is not a local directory it is treated as a Hub repo id
    and downloaded (into *output_dir* when given). The checkpoint's
    config.json is then rewritten in place with the "vila" model type and
    the auto_map entries needed for trust_remote_code loading, and the .py
    files of this package are copied next to it.

    Args:
        model_path: Local checkpoint directory or HF Hub repo id.
        output_dir: Optional local directory for the Hub download.
    """
    # assert type(self) == VILAForCasualLM, "This method is only available for VILAForCasualLM."
    if not os.path.isdir(model_path):
        # Only touch the Hub when the path is not already a local checkout.
        # (The original unconditionally queried HfApi even for local dirs.)
        from huggingface_hub import HfApi, snapshot_download

        api = HfApi()
        if api.repo_exists(model_path):
            model_path = snapshot_download(model_path, local_dir=output_dir)
            print("downloading HF model to", model_path)

    cfg_path = os.path.join(model_path, "config.json")
    # Use context managers so the config file handles are closed promptly.
    with open(cfg_path) as f:
        config = json.load(f)
    config["version"] = "2.0"  # nvila tag
    config["architectures"] = ["VILAForCasualLM"]
    config["auto_map"] = {
        "AutoConfig": "modeling_vila.VILAConfig",
        "AutoModel": "modeling_vila.VILAForCasualLM",
        "AutoModelForCausalLM": "modeling_vila.VILAForCasualLM",
    }
    config["model_type"] = "vila"
    with open(cfg_path, "w") as f:
        json.dump(config, f, indent=2)
    cls.copy_remote_py_files(model_path)
|
| 170 |
+
|
| 171 |
+
@classmethod
def copy_remote_py_files(cls, output_dir):
    """Copy every .py file that sits next to this module into *output_dir*.

    This makes the saved checkpoint self-contained for loading with
    trust_remote_code (README handling is not done here).
    """
    source_dir = os.path.dirname(os.path.abspath(__file__))
    for entry in os.listdir(source_dir):
        if not entry.endswith(".py"):
            continue
        src_path = os.path.join(source_dir, entry)
        if os.path.isfile(src_path):
            shutil.copy(src_path, output_dir)
            print("[HF remote code] copying", src_path, "to", output_dir)
|
| 182 |
+
|
| 183 |
+
def save_pretrained(self, output_dir, state_dict=None, safe_serialization=None):
    """Save the model's sub-modules (llm, vision_tower, mm_projector) under
    *output_dir*, each in its own subdirectory, then copy this package's .py
    files there so the checkpoint can be reloaded with remote code.

    NOTE(review): safe_serialization is accepted for API compatibility but is
    never used in this body — confirm whether it should be forwarded to the
    sub-module save_pretrained calls.
    """
    if state_dict is None:
        # otherwise fetch from deepspeed
        # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
        state_dict = self.state_dict()

    if getattr(self, "tokenizer", None):
        self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))

    if self.get_llm():
        print(f"saving llm to {osp.join(output_dir, 'llm')}")
        self.llm.config._name_or_path = osp.join(output_dir, "llm")
        # NOTE(review): substring filter `"llm" in k` also matches keys that
        # merely contain "llm" elsewhere (same pattern for the filters below);
        # presumably safe for this model's key layout — verify.
        llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
        self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
        self.config.llm_cfg = self.llm.config

    if self.get_vision_tower():
        print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
        self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
        vision_tower_state_dict = OrderedDict(
            {k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k}
        )
        self.vision_tower.vision_tower.save_pretrained(
            os.path.join(output_dir, "vision_tower"),
            state_dict=vision_tower_state_dict,
        )
        self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
        self.config.vision_tower_cfg = self.vision_tower.config
        # Drop auto_map from the saved vision tower config (except for radio
        # towers) so reloads do not pull in the tower's own remote code.
        if hasattr(self.config.vision_tower_cfg, "auto_map"):
            if "radio" not in self.get_vision_tower().__class__.__name__.lower():
                delattr(self.config.vision_tower_cfg, "auto_map")

    if self.get_mm_projector():
        print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
        self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
        mm_projector_state_dict = OrderedDict(
            {k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k}
        )
        self.mm_projector.save_pretrained(
            os.path.join(output_dir, "mm_projector"),
            state_dict=mm_projector_state_dict,
        )
        self.config.mm_projector_cfg = self.mm_projector.config

    ## update and save top-level config
    self.config._name_or_path = output_dir
    self.config.architectures = [self.__class__.__name__]
    #print(self.config)
    #self.config.save_pretrained(output_dir)

    ## copy .py and README for next loading remote code
    self.copy_remote_py_files(output_dir)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
|
| 238 |
@classmethod
|
| 239 |
def from_pretrained(
|
|
|
|
| 286 |
if getattr(self.config, "mm_projector_cfg", None) is None:
|
| 287 |
self.config.mm_projector_cfg = self.mm_projector.config
|
| 288 |
|
| 289 |
+
def get_llm(self):
    """Return the wrapped language model, unwrapping a singleton list."""
    candidate = getattr(self, "llm", None)
    if type(candidate) is list:
        return candidate[0]
    return candidate
|
| 294 |
+
|
| 295 |
+
def get_lm_head(self):
    """Return the lm_head of the underlying LLM, or None when absent."""
    return getattr(self.get_llm(), "lm_head", None)
|
| 298 |
+
|
| 299 |
def get_vision_tower(self):
|
| 300 |
vision_tower = getattr(self, "vision_tower", None)
|
| 301 |
if type(vision_tower) is list:
|
|
|
|
| 502 |
if self.training:
|
| 503 |
# Gather metainfo of media objects from all ranks
|
| 504 |
info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
|
| 505 |
+
infos = list(chain(all_gather(info)))
|
| 506 |
|
| 507 |
# The entire batch does not contain any media objects of this type.
|
| 508 |
if not infos:
|
|
|
|
| 844 |
if images is not None:
|
| 845 |
if media is not None:
|
| 846 |
raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
|
| 847 |
+
print("The 'images' argument is deprecated. Please use 'media' instead.")
|
| 848 |
media = {"image": images}
|
| 849 |
|
| 850 |
if media_config is None:
|
|
|
|
| 939 |
images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
|
| 940 |
media[name] = [image for image in images]
|
| 941 |
elif name == "video":
|
| 942 |
+
if False: #self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
|
| 943 |
media[name] = [
|
| 944 |
process_images(
|
| 945 |
images,
|
|
|
|
| 950 |
).half()
|
| 951 |
for images in media[name]
|
| 952 |
]
|
| 953 |
+
elif False: #self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
|
| 954 |
self.config.image_processor = self.vision_tower.image_processor
|
| 955 |
if type(self.config.s2_scales) is str:
|
| 956 |
self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
|
|
|
|
| 1024 |
if generation_config.eos_token_id is None:
|
| 1025 |
generation_config.eos_token_id = self.tokenizer.eos_token_id
|
| 1026 |
return generation_config
|
| 1027 |
+
|