should work now
Files changed:
- download_diffusion.py  (+0 -36)
- text2world_hf.py  (+2 -2)
download_diffusion.py
CHANGED
@@ -21,42 +21,6 @@ from huggingface_hub import snapshot_download
 from .convert_pixtral_ckpt import convert_pixtral_checkpoint
 
 
-def parse_args():
-    parser = argparse.ArgumentParser(description="Download NVIDIA Cosmos-1.0 Diffusion models from Hugging Face")
-    parser.add_argument(
-        "--model_sizes",
-        nargs="*",
-        default=[
-            "7B",
-            "14B",
-        ],  # Download all by default
-        choices=["7B", "14B"],
-        help="Which model sizes to download. Possible values: 7B, 14B",
-    )
-    parser.add_argument(
-        "--model_types",
-        nargs="*",
-        default=[
-            "Text2World",
-            "Video2World",
-        ],  # Download all by default
-        choices=["Text2World", "Video2World"],
-        help="Which model types to download. Possible values: Text2World, Video2World",
-    )
-    parser.add_argument(
-        "--cosmos_version",
-        type=str,
-        default="1.0",
-        choices=["1.0"],
-        help="Which version of Cosmos to download. Only 1.0 is available at the moment.",
-    )
-    parser.add_argument(
-        "--checkpoint_dir", type=str, default="checkpoints", help="Directory to save the downloaded checkpoints."
-    )
-    args = parser.parse_args()
-    return args
-
-
 def main(model_types, model_sizes, checkpoint_dir="checkpoints"):
     ORG_NAME = "nvidia"
 
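With the argparse CLI removed, the downloader is presumably meant to be driven as a plain function call rather than as a script. A minimal usage sketch, assuming the module can be imported as download_diffusion (the import path is an assumption, not part of this commit); the keyword values mirror the defaults the deleted parser used to supply:

# Hypothetical usage sketch; the import path is assumed, not part of this commit.
from download_diffusion import main

main(
    model_types=["Text2World", "Video2World"],  # previously --model_types (default: both)
    model_sizes=["7B", "14B"],                  # previously --model_sizes (default: both)
    checkpoint_dir="checkpoints",               # previously --checkpoint_dir
)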
text2world_hf.py
CHANGED
@@ -48,10 +48,10 @@ class DiffusionText2World(PreTrainedModel):
 
     def __init__(self, config=DiffusionText2WorldConfig()):
         super().__init__(config)
-        torch.enable_grad(False)
+        torch.enable_grad(False)
         self.config = config
         inference_type = "text2world"
-        config.prompt = 1  #
+        config.prompt = 1  # this is to hack args validation, maybe find a better way
         validate_args(config, inference_type)
         del config.prompt
         self.pipeline = DiffusionText2WorldGenerationPipeline(
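If the intent of torch.enable_grad(False) is to run the pipeline with autograd off, note that torch.enable_grad is a context manager / decorator and does not take a boolean; the usual spellings are torch.set_grad_enabled(False) or an inference_mode() scope. A small self-contained sketch of both idioms (illustrative only, not part of the commit):

import torch

# Option 1: flip the global autograd switch once, as the __init__ above appears to intend.
torch.set_grad_enabled(False)
w = torch.randn(3, requires_grad=True)
print((w * 2).requires_grad)  # False: grad tracking is off globally

# Option 2: restore the default and scope the no-grad region instead of mutating global state.
torch.set_grad_enabled(True)
with torch.inference_mode():
    print((w * 2).requires_grad)  # False inside the inference_mode block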