Overwrite with converted Qwen2.5-3B model files
- README.md +222 -71
- VLV_stage1.py +257 -0
- VLV_stage2.py +460 -0
- build.py +78 -0
- config.json +13 -9
- configuration_vlv.py +172 -0
- model-00001-of-00005.safetensors +3 -0
- model-00002-of-00005.safetensors +3 -0
- model-00003-of-00005.safetensors +3 -0
- model-00004-of-00005.safetensors +3 -0
- model-00005-of-00005.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_clip.py +60 -3
- vlv_utils.py +71 -0
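The converted checkpoint ships as five safetensors shards plus an index file. A minimal, hypothetical sketch (not part of this commit; it assumes the shard files sit in the current directory) for checking what the index maps and what a shard actually holds:

```python
# Hypothetical sanity check of the converted shards; file names are taken
# from the commit's file list and assumed to be in the current directory.
import json
from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    index = json.load(f)

# weight_map maps every parameter name to the shard file that stores it.
shards = sorted(set(index["weight_map"].values()))
print(f"{len(index['weight_map'])} tensors across {len(shards)} shards")

with safe_open("model-00001-of-00005.safetensors", framework="pt") as shard:
    for name in list(shard.keys())[:5]:
        print(name, shard.get_tensor(name).shape)
```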
README.md
CHANGED
@@ -1,104 +1,255 @@

Removed (old README, 104 lines): the Apache-2.0 front matter with `pipeline_tag: image-to-text`; badge links to the [arXiv paper](https://arxiv.org/abs/2507.07104), the [GitHub repository](https://github.com/Tiezheng11/Vision-Language-Vision), the [model weights](https://huggingface.co/lambertxiao/Vision-Language-Vision-Captioner-Qwen2.5-3B), and the [training dataset](https://huggingface.co/datasets/ccvl/LAION-High-Qualtiy-Pro-6M-VLV); an install step (`pip install -r requirements.txt`); a quick-start snippet built around `AutoModel.from_pretrained(..., trust_remote_code=True)`, a caption post-processing helper that drops dangling sentence fragments, and a NumPy-array wrapper accepting uint8 [0, 255] or float [0, 1] inputs; and a BibTeX entry skeleton.

Added (new README):
---
license: apache-2.0
tags:
- image-captioning
- multimodal
- vision-language
- diffusion
- pytorch
- transformers
library_name: transformers
pipeline_tag: image-to-text
datasets:
- conceptual_captions
- coco
model_type: VLV_decoder
---

# VLV Captioner Model

This is a VLV (Vision-Language-Vision) model for image captioning. It combines a Stable Diffusion-based image encoder with the Qwen2.5-3B language model to generate descriptive captions from images.

## Model Description

The VLV Captioner is a multimodal model that:

- Uses a diffusion-based vision encoder to extract image features
- Employs the Qwen2.5-3B language model for text generation
- Generates natural language descriptions of input images

## Model Architecture

- **Vision Encoder**: Stable Diffusion-based image encoder with Florence-2 components
- **Language Model**: Qwen2.5-3B transformer model
- **Image Size**: 384x384 pixels
- **Max Caption Length**: 300 tokens
- **Precision**: Mixed precision (bfloat16/float32)
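The 384x384 resolution above matches the `get_transform` helper bundled in `VLV_stage2.py`. The model applies that transform internally to the PIL images you pass in, so the sketch below is only meant to illustrate the expected input resolution and tensor shape:

```python
from PIL import Image
import torchvision.transforms as transforms

# Resize the short side to 384, then center-crop to 384x384,
# mirroring get_transform(image_size=384) in VLV_stage2.py.
transform = transforms.Compose([
    transforms.Resize(384),
    transforms.CenterCrop((384, 384)),
    transforms.PILToTensor(),
])

pixels = transform(Image.open("path/to/your/image.jpg").convert("RGB"))
print(pixels.shape)  # torch.Size([3, 384, 384])
```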
## Usage

### Method 1: Load from Hugging Face Hub

```python
from transformers import AutoModel, AutoConfig
from PIL import Image
import torch
import os

# Optional: Set custom cache directory if needed
cache_dir = "/path/to/your/cache"  # Use a directory with sufficient space
os.makedirs(cache_dir, exist_ok=True)

# Load the model with authentication token (if required)
token = os.getenv('HUGGINGFACE_TOKEN')  # or your token string

print("Loading config...")
config = AutoConfig.from_pretrained(
    "your-username/vlv-captioner",
    trust_remote_code=True,
    token=token,
    cache_dir=cache_dir
)

print("Loading model...")
try:
    model = AutoModel.from_pretrained(
        "your-username/vlv-captioner",
        trust_remote_code=True,
        token=token,
        cache_dir=cache_dir,
        torch_dtype=torch.float32,  # Specify dtype explicitly
        low_cpu_mem_usage=True
        # Note: Avoid device_map="auto" to prevent meta tensor issues
    )
    print("Model loaded successfully!")

    # Load and process an image
    image = Image.open("path/to/your/image.jpg")

    # Move model to GPU if available
    if torch.cuda.is_available():
        model = model.to('cuda')
        print("Model moved to GPU!")

    # Generate caption
    print("Generating caption...")
    with torch.no_grad():
        captions = model([image], max_length=300)

    # Handle different possible output formats
    if hasattr(captions, 'generated_text'):
        print("Generated caption:", captions.generated_text[0])
    elif isinstance(captions, list):
        print("Generated caption:", captions[0])
    else:
        print("Generated caption:", captions)

except Exception as e:
    print(f"Error during model loading or inference: {e}")
    # If cached files are corrupted, try clearing cache and redownloading
    import shutil
    cache_path = f"{cache_dir}/modules/transformers_modules/your-username/vlv-captioner"
    if os.path.exists(cache_path):
        print(f"Clearing cache at {cache_path}")
        shutil.rmtree(cache_path)

    # Retry with force download
    model = AutoModel.from_pretrained(
        "your-username/vlv-captioner",
        trust_remote_code=True,
        token=token,
        cache_dir=cache_dir,
        force_download=True,
        torch_dtype=torch.float32
    )
```

### Method 2: Load from original checkpoint

```python
import torch
from PIL import Image
from VLV_stage2 import VLV_MODEL

# Load from original .pt checkpoint file
model = VLV_MODEL.from_checkpoint("path/to/model.pt")

# Load and process an image
image = Image.open("path/to/your/image.jpg")

# Generate caption
with torch.no_grad():
    captions = model([image], max_length=300)
    print(captions.generated_text[0])  # Generated caption
```

## Model Details

- **Model Type**: Vision-Language Model
- **Architecture**: VLV_decoder
- **Language Backbone**: Qwen/Qwen2.5-3B
- **Vision Backbone**: Stable Diffusion + Florence-2
- **Training Data**: Various image-caption datasets
- **Framework**: PyTorch, Transformers

## Training Configuration

- **Batch Size**: 1 (inference)
- **Learnable Token Length**: 77
- **Guidance Scale**: 7.5
- **Inference Steps**: 50
- **Beam Search**: 4 beams

## Requirements

```bash
pip install torch transformers safetensors torchvision pillow diffusers
```

## Troubleshooting

### Common Issues and Solutions

#### 1. Meta Tensor Issues
If you encounter meta tensor errors, avoid using `device_map="auto"` when loading the model:

```python
# ❌ Don't use this - can cause meta tensor issues
model = AutoModel.from_pretrained("model-name", device_map="auto")

# ✅ Use this instead
model = AutoModel.from_pretrained("model-name", torch_dtype=torch.float32, low_cpu_mem_usage=True)
if torch.cuda.is_available():
    model = model.to('cuda')
```

#### 2. Cache Issues
If you experience corrupted cache files, clear the cache and redownload:

```python
import shutil
import os

cache_dir = "/your/cache/directory"
cache_path = f"{cache_dir}/modules/transformers_modules/your-username/model-name"
if os.path.exists(cache_path):
    shutil.rmtree(cache_path)

# Then reload with force_download=True
model = AutoModel.from_pretrained("model-name", force_download=True)
```

#### 3. Authentication Issues
Make sure your Hugging Face token is properly set:

```bash
# Option 1: Environment variable
export HUGGINGFACE_TOKEN="your_token_here"

# Option 2: Hugging Face CLI login
huggingface-cli login
```

#### 4. Memory Issues
For large models, use a custom cache directory with sufficient space:

```python
cache_dir = "/path/to/large/storage"
os.makedirs(cache_dir, exist_ok=True)
model = AutoModel.from_pretrained("model-name", cache_dir=cache_dir, low_cpu_mem_usage=True)
```

## Advanced Usage

### Batch Processing with Original Inference Script

For large-scale inference, you can use the original training inference script:

```bash
python Caption_inference.py \
    --input_path /path/to/images \
    --output_path captions.json \
    --clip_decoder_checkpoint /path/to/model.pt \
    --qwen_model Qwen/Qwen2.5-3B \
    --stable_diffusion_model_path stabilityai/stable-diffusion-2-1-base \
    --florence2_model_path microsoft/Florence-2-large \
    --batch_size 4 \
    --max_length 300 \
    --num_beams 4 \
    --image_size 384 \
    --guidance_scale 7.5 \
    --use_text_encoder \
    --distributed  # For multi-GPU inference
```

### Configuration Parameters

- `image_size`: Input image resolution (default: 384)
- `guidance_scale`: Diffusion guidance scale (default: 7.5)
- `learnable_token_length`: Number of vision tokens (default: 77)
- `max_length`: Maximum caption length (default: 300)
- `num_beams`: Beam search width (default: 4)
- `use_text_encoder`: Enable CLIP text encoder (recommended: True)
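These parameters correspond to fields of `VLV_Config` (see `configuration_vlv.py`). A minimal sketch, not a tested configuration, of building the config programmatically with the defaults listed above; the flat import assumes the repo files are importable, as in Method 2:

```python
from configuration_vlv import VLV_Config

# Mirror the documented defaults; any of these can be overridden per run.
config = VLV_Config(
    image_size=384,
    guidance_scale=7.5,
    learnable_token_length=77,
    max_length=300,
    num_beams=4,
    use_text_encoder=True,
)
print(config.image_size, config.num_beams)  # 384 4
```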

## Citation

```bibtex
@article{vlv_autoencoder,
  title={Vision-Language-Vision Auto-Encoder: Scalable Knowledge Distillation from Diffusion Models},
  author={Zhang, Tiezheng and Li, Yitong and Chou, Yu-Cheng and Chen, Jieneng and Yuille, Alan L. and Wei, Chen and Xiao, Junfei},
  journal={arXiv preprint},
  year={2024}
}
```

## License

This model is released under the Apache 2.0 license.
VLV_stage1.py
ADDED
@@ -0,0 +1,257 @@

import os
import torch
import torch.nn as nn
from typing import Optional
from dataclasses import dataclass
from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from .build import load_sd_model, load_Florence2_model
from .vlv_utils import initiate_time_steps, normalize


class SDConfig(PretrainedConfig):
    """Configuration class for SDModel."""
    model_type = "sd"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class MLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        )

    def forward(self, x):
        return self.layers(x)


@dataclass
class SDOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None


class SDModel(PreTrainedModel):
    config_class = SDConfig

    def __init__(
        self,
        config=None,
        training_args=None,
    ):
        if config is None:
            config = SDConfig()
        super().__init__(config)
        self.training_args = training_args
        if self.training_args.fp32:
            self._dtype = torch.float32
        else:
            self._dtype = torch.bfloat16
        self._device = torch.device(self.training_args.device if hasattr(self.training_args, 'device') else "cuda" if torch.cuda.is_available() else "cpu")

        self.vae, self.tokenizer, self.text_encoder, self.unet, self.scheduler = load_sd_model(training_args)
        torch.cuda.empty_cache()
        self.unet.eval()
        self.text_encoder.eval()
        self.model, self.processor = load_Florence2_model(training_args)

        self.unet = self.unet.to(self._dtype).to(device=self._device)
        self.text_encoder = self.text_encoder.to(self._dtype).to_empty(device=self._device)
        self.model = self.model.to(self._dtype).to_empty(device=self._device)
        self.vae = self.vae.to(torch.float32).to_empty(device=self._device)

        self.batch_size = self.training_args.batch_size

        hidden_dim = 1024
        self.language_proj = nn.Sequential(
            nn.Linear(1024, hidden_dim, dtype=self._dtype),
            nn.GELU(),
            nn.Linear(hidden_dim, 1024, dtype=self._dtype)
        ).to_empty(device=self._device)
        for param in self.language_proj.parameters():
            param.requires_grad = True

        self.num_queries = self.training_args.learnable_token_length
        self.query_embed = nn.Parameter(torch.randn(1, self.num_queries, 1024, dtype=self._dtype))
        self.query_embed.requires_grad = True

        self.unet.enable_gradient_checkpointing()

    def _unet_pred_noise(self, x_start, t, noise, context):
        t = t.to(dtype=torch.long)

        dtype = self.unet.dtype
        x_start = x_start.to(dtype)
        noise = noise.to(dtype)
        context = context.to(dtype)

        nt = t.shape[0]
        noised_latent = self.scheduler.add_noise(x_start, noise, t)

        pred_noise = self.unet(
            noised_latent,
            t,
            encoder_hidden_states=context.expand(nt, -1, -1)
        ).sample

        return pred_noise

    def generate_images(self, images):
        batch_size = self.training_args.eval_batch_size
        prompt = ["<MORE_DETAILED_CAPTION>"] * batch_size
        inputs = self.processor(text=prompt, images=images, return_tensors="pt").to(self._device).to(self._dtype)

        if inputs["input_ids"] is not None:
            inputs_embeds = self.model.language_model.get_input_embeddings()(inputs["input_ids"]).to(self._dtype)
        if inputs["pixel_values"] is not None:
            image_features = self.model._encode_image(inputs["pixel_values"]).to(self._dtype)
            inputs_embeds, attention_mask = self.model._merge_input_ids_with_image_features(image_features, inputs_embeds)
        if inputs_embeds is not None:
            attention_mask = attention_mask.to(inputs_embeds.dtype)
            encoder_outputs = self.model.language_model.model.encoder(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True
            )

        decoder_input_embeds = self.query_embed.expand(batch_size, -1, -1)
        decoder_attention_mask = torch.ones(
            (batch_size, self.num_queries),
            dtype=self._dtype,
            device=self._device
        )

        encoder_hidden_states = encoder_outputs.last_hidden_state.to(self._dtype)
        decoder_input_embeds = decoder_input_embeds.to(self._dtype)
        attention_mask = attention_mask.to(self._dtype)

        decoder_outputs = self.model.language_model.model.decoder(
            inputs_embeds=decoder_input_embeds,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        last_decoder_hidden_state = decoder_outputs.last_hidden_state
        conditional_context = self.language_proj(last_decoder_hidden_state)

        un_token = self.tokenizer("", padding="max_length", truncation=True, max_length=77, return_tensors="pt").input_ids.to(self._device)
        un_context_embeddings = self.text_encoder(un_token).last_hidden_state
        un_context_embeddings = un_context_embeddings.expand(batch_size, -1, -1)
        if self.training_args.use_text_encoder:
            context_embeddings = self.text_encoder(
                inputs_embeds=conditional_context.to(self._dtype)
            ).last_hidden_state

        latent_shape = (batch_size, 4, self.training_args.image_size // 8, self.training_args.image_size // 8)
        latents = torch.randn(latent_shape, device=self._device, dtype=self._dtype)

        scheduler = self.scheduler
        scheduler.set_timesteps(self.training_args.num_inference_steps)
        with torch.no_grad():
            for t in scheduler.timesteps:
                latent_model_input = torch.cat([latents, latents], dim=0)
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                combined_embeddings = torch.cat([un_context_embeddings, context_embeddings], dim=0).to(self._dtype)
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=combined_embeddings
                )[0]

                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
                noise_pred = noise_pred_uncond + self.training_args.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                latents = scheduler.step(noise_pred, t, latents)[0]

        scaled_latents = latents / 0.18215
        with torch.no_grad():
            decoded_latents = self.vae.decode(scaled_latents.to(torch.float32))[0]

        return decoded_latents

    def get_conditional_context(self, images, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size
        prompt = ["<MORE_DETAILED_CAPTION>"] * batch_size
        inputs = self.processor(text=prompt, images=images, return_tensors="pt").to(self._device).to(self._dtype)

        if inputs["input_ids"] is not None:
            inputs_embeds = self.model.language_model.get_input_embeddings()(inputs["input_ids"]).to(self._dtype)
        if inputs["pixel_values"] is not None:
            image_features = self.model._encode_image(inputs["pixel_values"]).to(self._dtype)
            inputs_embeds, attention_mask = self.model._merge_input_ids_with_image_features(image_features, inputs_embeds)
        if inputs_embeds is not None:
            attention_mask = attention_mask.to(inputs_embeds.dtype)
            encoder_outputs = self.model.language_model.model.encoder(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True
            )

        decoder_input_embeds = self.query_embed.expand(batch_size, -1, -1)
        decoder_attention_mask = torch.ones(
            (batch_size, self.num_queries),
            dtype=self._dtype,
            device=self._device
        )

        encoder_hidden_states = encoder_outputs.last_hidden_state.to(self._dtype)
        decoder_input_embeds = decoder_input_embeds.to(self._dtype)
        attention_mask = attention_mask.to(self._dtype)

        decoder_outputs = self.model.language_model.model.decoder(
            inputs_embeds=decoder_input_embeds,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        last_decoder_hidden_state = decoder_outputs.last_hidden_state
        return last_decoder_hidden_state

    def forward(
        self,
        image=None,
        filename=None,
        **kwargs,
    ) -> SDOutput:
        images_for_language_model = image
        normalize_images = normalize(image, rescale=True)
        x0 = self.vae.encode(normalize_images.to(torch.float32)).latent_dist.sample()
        latent = x0 * 0.18215

        total_timestep = self.scheduler.num_train_timesteps

        timesteps = initiate_time_steps(0, total_timestep, self.batch_size, self.training_args).long()
        timesteps = timesteps.to(self._device)
        c, h, w = latent.shape[1:]
        if not self.training_args.use_same_noise_among_timesteps:
            noise = torch.randn((self.batch_size, c, h, w), device=self._device, dtype=self._dtype)
        else:
            noise = torch.randn((1, c, h, w), device=self._device, dtype=self._dtype)
            noise = noise.repeat(self.batch_size, 1, 1, 1)

        conditional_context = self.get_conditional_context(images_for_language_model)
        conditional_context = self.language_proj(conditional_context)

        if self.training_args.use_text_encoder:
            text_encoder_output = self.text_encoder(input_ids=None, inputs_embeds=conditional_context.to(self._dtype))
            pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=text_encoder_output.last_hidden_state.to(self._dtype)).to(self._dtype)
        else:
            pred_noise = self._unet_pred_noise(x_start=latent, t=timesteps, noise=noise, context=conditional_context.to(self._dtype)).to(self._dtype)

        if self.training_args.loss == "l1":
            loss = torch.nn.functional.l1_loss(pred_noise, noise)
        else:
            loss = torch.nn.functional.mse_loss(pred_noise, noise)

        return SDOutput(loss=loss)
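A rough usage sketch for `SDModel` above, not taken from the repository: the `training_args` namespace lists only the fields the code above reads, the helpers in `vlv_utils.py` (not shown in this commit) may expect additional ones, and the flat import assumes the repo modules are importable as in the README's Method 2, so treat it purely as an illustration of the stage-1 interface:

```python
import argparse
import torch
from VLV_stage1 import SDModel, SDConfig

# Fields read by SDModel / build.py above; values are illustrative only.
args = argparse.Namespace(
    fp32=True, batch_size=1, eval_batch_size=1,
    learnable_token_length=77, image_size=384,
    num_inference_steps=50, guidance_scale=7.5,
    use_text_encoder=True, use_same_noise_among_timesteps=False, loss="mse",
    unfreeze_florence2_all=False,
    unfreeze_florence2_language_model=False,
    unfreeze_florence2_language_model_decoder=False,
)
model = SDModel(SDConfig(), training_args=args)

# One stage-1 step: predict the noise added to the VAE latent of the image
# and backpropagate the diffusion loss into query_embed / language_proj.
images = torch.rand(1, 3, 384, 384)  # placeholder batch in [0, 1]
out = model(image=images)
out.loss.backward()
```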
VLV_stage2.py
ADDED
@@ -0,0 +1,460 @@

from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, PretrainedConfig
from safetensors.torch import load_file
import torchvision.transforms as transforms
from .build import load_sd_model, load_Florence2_model
from .vlv_utils import initiate_time_steps, normalize, process_caption
from .VLV_stage1 import SDModel, SDConfig
from .configuration_vlv import VLV_Config
import os
import sys
import argparse


def handle_module_prefix(state_dict):
    """Handle 'module.' prefix in state dict keys."""
    if any(k.startswith('module.') for k in state_dict.keys()):
        return {k.replace('module.', ''): v for k, v in state_dict.items()}
    return state_dict


def create_model_args(args):
    """Create model arguments needed by SDModel."""
    model_args = argparse.Namespace()
    model_args.use_text_encoder = args.use_text_encoder
    model_args.batch_size = args.batch_size
    model_args.eval_batch_size = args.batch_size
    model_args.distributed_strategy = 'none'
    model_args.fp32 = args.fp32
    model_args.learnable_token_length = args.learnable_token_length
    model_args.num_inference_steps = args.num_inference_steps
    model_args.image_size = args.image_size
    model_args.guidance_scale = args.guidance_scale
    model_args.unfreeze_florence2_all = False
    model_args.unfreeze_florence2_language_model = False
    model_args.unfreeze_florence2_language_model_decoder = False
    return model_args


def load_model_checkpoint(model, model_path, device):
    """Load model checkpoint."""
    try:
        checkpoint = torch.load(model_path, map_location="cpu")

        # Handle different checkpoint formats
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        state_dict = handle_module_prefix(state_dict)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

        if missing_keys:
            print(f"Missing keys: {missing_keys[:10]}...")  # Show first 10
        if unexpected_keys:
            print(f"Unexpected keys: {unexpected_keys[:10]}...")  # Show first 10

        print(f"Successfully loaded model from {model_path}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise e

    return model


def initialize_diffusion_model(args):
    """Initialize the diffusion model."""
    config = SDConfig()
    diffusion_model_args = create_model_args(args)
    diffusion_model = SDModel(config, diffusion_model_args)
    _dtype = torch.float32 if diffusion_model_args.fp32 else torch.bfloat16

    # Delete components that aren't needed for inference
    if hasattr(diffusion_model, 'vae'):
        del diffusion_model.vae
    if hasattr(diffusion_model, 'unet'):
        del diffusion_model.unet

    # Clear CUDA cache
    torch.cuda.empty_cache()

    diffusion_model = diffusion_model.to(_dtype)

    # Freeze parameters that shouldn't be trained
    for param in diffusion_model.language_proj.parameters():
        param.requires_grad = False
    diffusion_model.query_embed.requires_grad = False

    return diffusion_model


class MLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        )

    def forward(self, x):
        return self.layers(x)


@dataclass
class CLIPDecoderOutput(ModelOutput):
    """
    Output class for the CLIP Decoder model.
    """
    last_hidden_state: Optional[torch.FloatTensor] = None
    generated_ids: Optional[torch.LongTensor] = None
    generated_text: Optional[list] = None


class CLIPDecoder(nn.Module):

    def __init__(
        self,
        language_model: str,
        VLV_model: SDModel,
        device: torch.device,
        bf16: str,
        qwen2_config: dict = None,
        args: argparse.Namespace = None
    ):
        """
        Initialize the CLIP Decoder model.

        Args:
            language_model: Path to the language model
            VLV_model: The VLV model instance
            device: The device to run the model on
            bf16: Whether to use bfloat16 precision
            qwen2_config: Optional qwen2 configuration dict
        """
        super(CLIPDecoder, self).__init__()

        self._dtype = torch.bfloat16 if bf16 == "bf16" else torch.float32
        self.qwen2_tokenizer = AutoTokenizer.from_pretrained(language_model)

        self.qwen2_config = AutoConfig.from_pretrained(language_model)
        self.qwen2_model = AutoModelForCausalLM.from_pretrained(
            language_model,
            torch_dtype=self._dtype,
            device_map=None,
            low_cpu_mem_usage=True
        )

        self.VLV_model = VLV_model  # fp32 in this case
        self.device = device
        self.mlp = MLP(input_dim=1024, output_dim=self.qwen2_model.config.hidden_size)
        self.ignore_token_id = -100

    def get_conditional_context(self, images, batch_size):
        """
        Get conditional context from images using the diffusion model.

        Args:
            images: Input images
            batch_size: Batch size

        Returns:
            Decoder hidden states from the diffusion model
        """
        prompt = ["<MORE_DETAILED_CAPTION>"] * batch_size
        inputs = self.VLV_model.processor(text=prompt, images=images, return_tensors="pt").to(self.device).to(self._dtype)

        # Ensure all components are on the correct device
        self.VLV_model = self.VLV_model.to(inputs["input_ids"].device)
        self.qwen2_model = self.qwen2_model.to(inputs["input_ids"].device)
        self.mlp = self.mlp.to(inputs["input_ids"].device)
        self.VLV_model.model.language_model.model = self.VLV_model.model.language_model.model.to(inputs["input_ids"].device)

        if inputs["input_ids"] is not None:
            inputs_embeds = self.VLV_model.model.language_model.get_input_embeddings()(inputs["input_ids"]).to(self.device)

        if inputs["pixel_values"] is not None:
            image_features = self.VLV_model.model._encode_image(inputs["pixel_values"]).to(self.device)
            inputs_embeds, attention_mask = self.VLV_model.model._merge_input_ids_with_image_features(
                image_features, inputs_embeds
            )

        if inputs_embeds is not None:
            attention_mask = attention_mask.to(inputs_embeds.dtype)

            encoder_outputs = self.VLV_model.model.language_model.model.encoder(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True
            )

        decoder_inputs_embeds = self.VLV_model.query_embed.expand(batch_size, -1, -1)
        decoder_attention_mask = torch.ones(
            (batch_size, self.VLV_model.num_queries),
            dtype=self._dtype,
            device=self.device
        )

        encoder_hidden_states = encoder_outputs.last_hidden_state.to(self._dtype)
        decoder_input_embeds = decoder_inputs_embeds.to(self._dtype)
        attention_mask = attention_mask.to(self._dtype)

        decoder_outputs = self.VLV_model.model.language_model.model.decoder(
            inputs_embeds=decoder_input_embeds,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )

        return decoder_outputs.last_hidden_state

    def process_image(self, images, batch_size):
        """
        Process images to get clip text embeddings.

        Args:
            images: Input images
            batch_size: Batch size

        Returns:
            Processed clip text embeddings and attention mask
        """
        decoder_hidden_states = self.get_conditional_context(images, batch_size)
        context_embeds = self.VLV_model.language_proj(decoder_hidden_states)
        clip_text_embeds = self.VLV_model.text_encoder(inputs_embeds=context_embeds).last_hidden_state
        clip_text_embeds = self.mlp(clip_text_embeds)
        clip_text_embeds_attention_mask = torch.ones(
            (batch_size, self.VLV_model.num_queries),
            dtype=torch.long,
            device=self.device
        )

        return clip_text_embeds, clip_text_embeds_attention_mask

    def prepare_generation_inputs(self, clip_text_embeds, clip_text_attention_mask=None):
        """
        Prepare inputs for text generation.

        Args:
            clip_text_embeds: Processed clip text embeddings
            clip_text_attention_mask: Attention mask for clip text embeddings

        Returns:
            Dictionary of generation inputs
        """
        if clip_text_attention_mask is None:
            clip_text_attention_mask = torch.ones(
                (clip_text_embeds.shape[0], clip_text_embeds.shape[1]),
                dtype=torch.long,
                device=clip_text_embeds.device
            )

        return {
            "inputs_embeds": clip_text_embeds,
            "attention_mask": clip_text_attention_mask
        }

    def generate(self, images, max_new_tokens=300, num_beams=4, early_stopping=True):
        """
        Generate text from images.

        Args:
            images: Input images
            max_new_tokens: Maximum number of tokens to generate
            num_beams: Number of beams for beam search
            early_stopping: Whether to stop early in beam search

        Returns:
            CLIPDecoderOutput with generated ids and text
        """
        batch_size = len(images)
        clip_text_embeds, clip_text_attention_mask = self.process_image(images, batch_size)
        generation_inputs = self.prepare_generation_inputs(clip_text_embeds, clip_text_attention_mask)

        generation_inputs["inputs_embeds"] = generation_inputs["inputs_embeds"].to(self._dtype)
        generation_inputs["attention_mask"] = generation_inputs["attention_mask"].to(self._dtype)

        generated_ids = self.qwen2_model.generate(
            inputs_embeds=generation_inputs["inputs_embeds"],
            attention_mask=generation_inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            early_stopping=early_stopping
        )

        generated_text = self.qwen2_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        processed_generated_text = [process_caption(text) for text in generated_text]

        return CLIPDecoderOutput(
            generated_ids=generated_ids,
            generated_text=processed_generated_text
        )

    def forward(self, images, captions=None):
        """
        Forward pass for training.

        Args:
            images: Input images
            captions: Target captions (optional, for training)

        Returns:
            CLIPDecoderOutput with loss and logits
        """
        batch_size = images.shape[0]

        # Process images
        clip_text_embeds, clip_text_attention_mask = self.process_image(images, batch_size)

        # If no captions provided, return embeddings for generation
        if captions is None:
            return CLIPDecoderOutput(
                last_hidden_state=clip_text_embeds
            )

        assert len(captions) == batch_size
        # Process captions for training
        processed_captions = [process_caption(caption) for caption in captions]
        qwen_input_ids = self.qwen2_tokenizer(
            text=processed_captions,
            truncation=True,
            return_tensors="pt",
            padding="max_length",
            max_length=300,
            return_token_type_ids=False,
        ).input_ids

        assert len(captions) == batch_size
        qwen_attention_mask = qwen_input_ids.ne(self.qwen2_tokenizer.pad_token_id).to(torch.long).to(self.device)

        # Prepare labels for training
        labels = qwen_input_ids
        labels[labels == self.qwen2_tokenizer.pad_token_id] = self.ignore_token_id
        labels = labels.to(self.device)

        # Get embeddings for captions to create the full input sequence
        labels_for_embeddings = labels.clone()
        labels_for_embeddings[labels_for_embeddings == self.ignore_token_id] = self.qwen2_tokenizer.pad_token_id
        clip_text_embeds_qwen = self.qwen2_model.get_input_embeddings()(labels_for_embeddings)

        # Concatenate the embeddings and prepare attention mask
        inputs_embeds = torch.cat((clip_text_embeds, clip_text_embeds_qwen), dim=1)
        clip_seq_len = clip_text_embeds.shape[1]
        clip_ignore_labels = torch.full((labels.shape[0], clip_seq_len), self.ignore_token_id).to(labels)
        combined_labels = torch.cat((clip_ignore_labels, labels), dim=1)

        attention_mask = torch.cat((
            clip_text_attention_mask,
            qwen_attention_mask
        ), dim=1)

        # Forward through language model
        outputs = self.qwen2_model(
            inputs_embeds=inputs_embeds,
            labels=combined_labels,
            attention_mask=attention_mask,
            use_cache=False
        )
        return outputs


# HuggingFace Model Wrapper
class VLV_MODEL(PreTrainedModel):
    config_class = VLV_Config
    model_type = "VLV_decoder"

    def __init__(self, config):
        super().__init__(config)
        """Load the CLIPDecoder model."""
        # Initialize the diffusion model first
        device = "cuda"
        de_diffusion_model = initialize_diffusion_model(config)
        clip_decoder_model = CLIPDecoder(
            language_model=config.qwen_model,
            VLV_model=de_diffusion_model,
            device=device,
            bf16=config.mixed_precision,
            qwen2_config=config.qwen2_config
        )

        # Load the trained weights
        # clip_decoder_model = load_model_checkpoint(clip_decoder_model, config.clip_decoder_checkpoint, device)

        # Set to evaluation mode
        clip_decoder_model.eval()

        # Store components directly as attributes to match checkpoint structure
        self.VLV_model = clip_decoder_model.VLV_model
        self.qwen2_model = clip_decoder_model.qwen2_model
        self.mlp = clip_decoder_model.mlp

        # Keep the full model for methods
        self._clip_decoder_model = clip_decoder_model
        self.max_new_tokens = config.max_length
        self.num_beams = config.num_beams
        self.transform = self.get_transform(config.image_size)

    def get_transform(self, image_size):
        """Transformation pipeline for input images."""
        return transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop((image_size, image_size)),
            transforms.PILToTensor(),
        ])

    @classmethod
    def from_checkpoint(cls, checkpoint_path, config=None, **kwargs):
        """
        Load model from original training checkpoint.

        Args:
            checkpoint_path: Path to the original model.pt checkpoint
            config: Optional VLV_Config, will create default if None
            **kwargs: Additional arguments for model initialization
        """
        if config is None:
            # Create default config
            config = VLV_Config(
                image_size=384,
                guidance_scale=7.5,
                learnable_token_length=77,
                max_length=300,
                num_beams=4,
                **kwargs
            )

        # Initialize model
        model = cls(config)

        # Load checkpoint weights
        device = "cuda" if torch.cuda.is_available() else "cpu"
        load_model_checkpoint(model._clip_decoder_model, checkpoint_path, device)

        return model

    def forward(self, valid_images, max_length):
        valid_images = [self.transform(img) for img in valid_images]
        if hasattr(self._clip_decoder_model, 'module'):
            outputs = self._clip_decoder_model.module.generate(
                valid_images,
                max_new_tokens=max_length,
                num_beams=self.num_beams,
                early_stopping=True
            )
        else:
            outputs = self._clip_decoder_model.generate(
                valid_images,
                max_new_tokens=max_length,
                num_beams=self.num_beams,
                early_stopping=True
            )
        return outputs
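A rough sketch, not from the repository, of driving `CLIPDecoder` above in training mode: the frozen stage-1 encoder turns each image into 77 visual tokens, the caption tokens supply labels, and the Qwen2 next-token loss comes back on `outputs.loss`. No checkpoint weights are loaded here and a CUDA device is assumed, so this only illustrates the call pattern; the flat imports assume the repo modules are importable as in the README's Method 2:

```python
import argparse
import torch
from VLV_stage2 import CLIPDecoder, initialize_diffusion_model

# Only the fields create_model_args() reads; values are illustrative.
args = argparse.Namespace(
    use_text_encoder=True, batch_size=1, fp32=True,
    learnable_token_length=77, num_inference_steps=50,
    image_size=384, guidance_scale=7.5,
)
vlv_encoder = initialize_diffusion_model(args)      # frozen stage-1 encoder
decoder = CLIPDecoder(
    language_model="Qwen/Qwen2.5-3B",
    VLV_model=vlv_encoder,
    device=torch.device("cuda"),
    bf16="fp32",                                     # keep dtypes consistent with fp32=True
    args=args,
)

images = torch.rand(1, 3, 384, 384)                  # placeholder batch
outputs = decoder(images, captions=["a placeholder caption"])
print(outputs.loss)
```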
build.py
ADDED
@@ -0,0 +1,78 @@

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from transformers import CLIPTokenizer, AutoProcessor
from .modeling_clip import CustomCLIPTextModel
from .modeling_florence2 import Florence2ForConditionalGeneration
from .configuration_florence2 import Florence2Config


def load_sd_model(training_args):
    """Load the Stable Diffusion components (VAE, tokenizer, text encoder, UNet, scheduler)."""

    repo_id = "stabilityai/stable-diffusion-2-1-base"

    text_encoder = CustomCLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder")
    tokenizer = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer")
    vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", revision=None)
    scheduler = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
    unet = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet", revision=None)

    for m in [vae, text_encoder, unet]:
        for param in m.parameters():
            param.requires_grad = False

    return (vae, tokenizer, text_encoder, unet, scheduler)


def load_Florence2_model(training_args):
    config = Florence2Config.from_pretrained("microsoft/Florence-2-large")
    config.vision_config.model_type = "davit"
    config._attn_implementation = "eager"

    # Load the model with pre-trained weights
    model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large", config=config)
    processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

    # Freeze the model, optionally unfreezing selected parts
    if training_args.unfreeze_florence2_all:
        for param in model.parameters():
            param.requires_grad = True
    elif training_args.unfreeze_florence2_language_model:
        for param in model.parameters():
            param.requires_grad = False
        for param in model.language_model.parameters():
            param.requires_grad = True
        for param in model.language_model.lm_head.parameters():
            param.requires_grad = False

        model.language_model.lm_head.weight = torch.nn.Parameter(
            model.language_model.lm_head.weight.detach().clone())

        for p in model.language_model.lm_head.parameters():
            p.requires_grad = False

    elif training_args.unfreeze_florence2_language_model_decoder:
        # Create a separate embedding layer for decoder
        original_embeddings = model.language_model.model.shared
        new_decoder_embeddings = torch.nn.Embedding(
            num_embeddings=original_embeddings.num_embeddings,
            embedding_dim=original_embeddings.embedding_dim,
            padding_idx=original_embeddings.padding_idx
        )
        # Copy the weights
        new_decoder_embeddings.weight.data = original_embeddings.weight.data.clone()

        # Replace the decoder embeddings
        model.language_model.model.encoder.embed_tokens = original_embeddings
        model.language_model.model.decoder.embed_tokens = new_decoder_embeddings
        for param in model.parameters():
            param.requires_grad = False
        for param in model.language_model.model.decoder.parameters():
            param.requires_grad = True
        model.language_model.model.decoder.embed_tokens.weight.requires_grad = False
    else:
        for param in model.parameters():
            param.requires_grad = False

    return model, processor
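A minimal sketch, not from the repository, of how these two loaders are typically called; only the unfreeze flags are required on the namespace, and the flat import assumes the repo modules are importable (e.g. from the checked-out model folder):

```python
import argparse
from build import load_sd_model, load_Florence2_model

# Only the flags the loaders read; with all False, everything stays frozen.
args = argparse.Namespace(
    unfreeze_florence2_all=False,
    unfreeze_florence2_language_model=False,
    unfreeze_florence2_language_model_decoder=False,
)

vae, tokenizer, text_encoder, unet, scheduler = load_sd_model(args)
florence2, processor = load_Florence2_model(args)
print(type(unet).__name__, type(processor).__name__)
```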
config.json
CHANGED
@@ -3,27 +3,31 @@
     "VLV_MODEL"
   ],
   "auto_map": {
-    "AutoConfig": "
-    "AutoModel": "
-    "AutoModelForCausalLM": "
+    "AutoConfig": "configuration_vlv.VLV_Config",
+    "AutoModel": "VLV_stage2.VLV_MODEL",
+    "AutoModelForCausalLM": "VLV_stage2.VLV_MODEL"
   },
   "model_type": "VLV_decoder",
   "batch_size": 1,
   "deepspeed": true,
   "distributed": true,
   "fp32": true,
-  "guidance_scale": 2.
+  "guidance_scale": 2.5,
   "hidden_size": 128,
-  "image_size":
+  "image_size": 384,
   "learnable_token_length": 77,
   "local_rank": 0,
-  "mixed_precision": "
+  "mixed_precision": "fp32",
   "num_inference_steps": 50,
-  "torch_dtype": "
+  "torch_dtype": "float32",
   "transformers_version": "4.51.1",
   "use_text_encoder": true,
   "verbose": true,
   "qwen_model": "Qwen/Qwen2.5-3B",
+  "stable_diffusion_model_path": "stabilityai/stable-diffusion-2-1-base",
+  "florence2_model_path": "microsoft/Florence-2-large",
+  "max_length": 300,
+  "num_beams": 4,
   "qwen2_config": {
     "architectures": [
       "Qwen2ForCausalLM"
@@ -45,11 +49,11 @@
     "rope_theta": 1000000.0,
     "sliding_window": 32768,
     "tie_word_embeddings": true,
-    "torch_dtype": "
+    "torch_dtype": "float32",
     "transformers_version": "4.40.1",
     "use_cache": true,
     "use_mrope": false,
     "use_sliding_window": false,
     "vocab_size": 151936
   }
-}
+}
configuration_vlv.py
ADDED
@@ -0,0 +1,172 @@
# coding=utf-8
# Copyright 2024 VLV Team and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""VLV model configuration"""

from typing import Optional, Dict, Any
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class VLV_Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`VLV_MODEL`]. It is used to instantiate a VLV model
    according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        model_type (`str`, *optional*, defaults to "VLV_decoder"):
            The model type identifier.
        batch_size (`int`, *optional*, defaults to 1):
            The batch size for inference.
        deepspeed (`bool`, *optional*, defaults to True):
            Whether to use deepspeed.
        distributed (`bool`, *optional*, defaults to True):
            Whether to use distributed training.
        fp32 (`bool`, *optional*, defaults to True):
            Whether to use fp32 precision.
        guidance_scale (`float`, *optional*, defaults to 2.0):
            The guidance scale for generation.
        hidden_size (`int`, *optional*, defaults to 128):
            The hidden size of the model.
        image_size (`int`, *optional*, defaults to 768):
            The size of input images.
        learnable_token_length (`int`, *optional*, defaults to 77):
            The length of learnable tokens.
        local_rank (`int`, *optional*, defaults to 0):
            The local rank for distributed training.
        mixed_precision (`str`, *optional*, defaults to "bf16"):
            The mixed precision mode.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of inference steps.
        torch_dtype (`str`, *optional*, defaults to "bfloat16"):
            The torch dtype.
        use_text_encoder (`bool`, *optional*, defaults to True):
            Whether to use text encoder.
        verbose (`bool`, *optional*, defaults to True):
            Whether to enable verbose mode.
        qwen_model (`str`, *optional*, defaults to "Qwen/Qwen2.5-3B"):
            The Qwen model to use.
        qwen2_config (`dict`, *optional*):
            The Qwen2 configuration.
        max_length (`int`, *optional*, defaults to 300):
            Maximum length for generation.
        num_beams (`int`, *optional*, defaults to 4):
            Number of beams for beam search.
    """

    model_type = "VLV_decoder"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        model_type: str = "VLV_decoder",
        batch_size: int = 1,
        deepspeed: bool = True,
        distributed: bool = True,
        fp32: bool = True,
        guidance_scale: float = 2.0,
        hidden_size: int = 128,
        image_size: int = 768,
        learnable_token_length: int = 77,
        local_rank: int = 0,
        mixed_precision: str = "bf16",
        num_inference_steps: int = 50,
        torch_dtype: str = "bfloat16",
        transformers_version: str = "4.51.1",
        use_text_encoder: bool = True,
        verbose: bool = True,
        qwen_model: str = "Qwen/Qwen2.5-3B",
        stable_diffusion_model_path: str = "stabilityai/stable-diffusion-2-1-base",
        florence2_model_path: str = "microsoft/Florence-2-large",
        qwen2_config: Optional[Dict[str, Any]] = None,
        max_length: int = 300,
        num_beams: int = 4,
        **kwargs,
    ):
        self.model_type = model_type
        self.batch_size = batch_size
        self.deepspeed = deepspeed
        self.distributed = distributed
        self.fp32 = fp32
        self.guidance_scale = guidance_scale
        self.hidden_size = hidden_size
        self.image_size = image_size
        self.learnable_token_length = learnable_token_length
        self.local_rank = local_rank
        self.mixed_precision = mixed_precision
        self.num_inference_steps = num_inference_steps
        self.torch_dtype = torch_dtype
        self.transformers_version = transformers_version
        self.use_text_encoder = use_text_encoder
        self.verbose = verbose
        self.qwen_model = qwen_model
        self.stable_diffusion_model_path = stable_diffusion_model_path
        self.florence2_model_path = florence2_model_path
        self.qwen2_config = qwen2_config or self._get_default_qwen2_config()
        self.max_length = max_length
        self.num_beams = num_beams

        super().__init__(**kwargs)

    def _get_default_qwen2_config(self):
        """Get default Qwen2 configuration."""
        return {
            "architectures": ["Qwen2ForCausalLM"],
            "attention_dropout": 0.0,
            "bos_token_id": 151643,
            "eos_token_id": 151643,
            "hidden_act": "silu",
            "hidden_size": 2048,
            "initializer_range": 0.02,
            "intermediate_size": 11008,
            "max_position_embeddings": 32768,
            "max_window_layers": 36,
            "model_type": "qwen2",
            "num_attention_heads": 16,
            "num_hidden_layers": 36,
            "num_key_value_heads": 2,
            "rms_norm_eps": 1e-06,
            "rope_theta": 1000000.0,
            "sliding_window": 32768,
            "tie_word_embeddings": True,
            "torch_dtype": "bfloat16",
            "transformers_version": "4.40.1",
            "use_cache": True,
            "use_mrope": False,
            "use_sliding_window": False,
            "vocab_size": 151936
        }


class CLIPDecoderConfig(PretrainedConfig):
    r"""
    Configuration class for CLIPDecoder model (legacy support).
    """

    model_type = "vlv_stage2"

    def __init__(
        self,
        input_dim: int = 1024,
        bf16: bool = False,
        **kwargs,
    ):
        self.input_dim = input_dim
        self.bf16 = bf16
        super().__init__(**kwargs)
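For reference, `VLV_Config` behaves like any other `PretrainedConfig`. A small sketch, assuming `configuration_vlv.py` is importable from a local clone of this repo and that the output directory is writable:

```python
# Minimal sketch of using VLV_Config directly.
from configuration_vlv import VLV_Config

config = VLV_Config(guidance_scale=2.5, image_size=384, mixed_precision="fp32")

# qwen2_config falls back to _get_default_qwen2_config() when not provided.
print(config.qwen2_config["hidden_size"])   # 2048
print(config.qwen2_config["vocab_size"])    # 151936

# The standard PretrainedConfig round-trip still applies.
config.save_pretrained("./vlv_config_demo")
reloaded = VLV_Config.from_pretrained("./vlv_config_demo")
print(reloaded.guidance_scale)              # 2.5
```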
model-00001-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d7460963b2ea4c7cde35d0c64c8d46d4a9324c7574433f8cf9878bbaf687f61b
size 622330008

model-00002-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dca6a859202a8817026897383409ec85fb0a22d4b6527da6ab5f5e2ccd3745be
size 832409864

model-00003-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:30d67c2d202ae6c4166ba0b82310f19225665305e6fb3b22c66ff5318fbf6f50
size 210079920

model-00004-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c015745c4638633cfb7d09e9b2b96bfa15fd21511fd74642d13296afc9423a4f
size 5215310704

model-00005-of-00005.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:da8bab2f53dbd82612d2034d6e67724a171a44fc040198ca5fe9d6120cc3409e
size 5046894020
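The `.safetensors` entries above are Git LFS pointer files: each records the sha256 (`oid`) and byte `size` of the actual shard, which is fetched separately. A sketch, with a hypothetical local directory, for confirming that downloaded shards match those pointers:

```python
# Sketch: verify local shards against the sha256/size recorded in the LFS pointers above.
import hashlib
from pathlib import Path

EXPECTED = {
    "model-00001-of-00005.safetensors": ("d7460963b2ea4c7cde35d0c64c8d46d4a9324c7574433f8cf9878bbaf687f61b", 622330008),
    "model-00002-of-00005.safetensors": ("dca6a859202a8817026897383409ec85fb0a22d4b6527da6ab5f5e2ccd3745be", 832409864),
    "model-00003-of-00005.safetensors": ("30d67c2d202ae6c4166ba0b82310f19225665305e6fb3b22c66ff5318fbf6f50", 210079920),
    "model-00004-of-00005.safetensors": ("c015745c4638633cfb7d09e9b2b96bfa15fd21511fd74642d13296afc9423a4f", 5215310704),
    "model-00005-of-00005.safetensors": ("da8bab2f53dbd82612d2034d6e67724a171a44fc040198ca5fe9d6120cc3409e", 5046894020),
}

def verify_shards(local_dir: str) -> None:
    for name, (expected_sha, expected_size) in EXPECTED.items():
        path = Path(local_dir) / name
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
                digest.update(chunk)
        ok = digest.hexdigest() == expected_sha and path.stat().st_size == expected_size
        print(f"{name}: {'OK' if ok else 'MISMATCH'}")

# verify_shards("./Vision-Language-Vision-Captioner-Qwen2.5-3B")  # hypothetical download path
```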
model.safetensors.index.json
CHANGED
The diff for this file is too large to render. See the raw diff.
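The index follows the standard safetensors sharding layout (a `weight_map` from parameter name to shard file), so it can be inspected directly even though the diff is too large to render. A sketch, assuming it is run from a local clone containing the file:

```python
# Sketch: summarize the shard assignment recorded in model.safetensors.index.json.
import json
from collections import Counter

with open("model.safetensors.index.json") as f:
    index = json.load(f)

# "weight_map" maps each parameter name to the shard file that stores it.
tensors_per_shard = Counter(index["weight_map"].values())
for shard, count in sorted(tensors_per_shard.items()):
    print(f"{shard}: {count} tensors")
```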
modeling_clip.py
CHANGED
@@ -1,5 +1,5 @@
 from transformers import CLIPTokenizer, CLIPImageProcessor, CLIPTextModel, CLIPPreTrainedModel, CLIPTextConfig
-from transformers.models.clip.modeling_clip import CLIPTextEmbeddings, CLIPEncoder, CLIPAttention, CLIPMLP, CLIPEncoderLayer, _create_4d_causal_attention_mask, _prepare_4d_attention_mask, BaseModelOutputWithPooling
+from transformers.models.clip.modeling_clip import CLIPTextEmbeddings, CLIPEncoder, CLIPAttention, CLIPMLP, CLIPEncoderLayer, _create_4d_causal_attention_mask, _prepare_4d_attention_mask, BaseModelOutputWithPooling, CLIPTextModelOutput
 from typing import Optional, Union, Tuple
 import torch
 from torch import nn
@@ -53,7 +53,8 @@ class CustomCLIPTextTransformer(nn.Module):
 
 
         if inputs_embeds is not None:
-            inputs_embeds
+            # inputs_embeds are already embeddings, just add positional embeddings
+            inputs_embeds = self.embeddings.position_embedding(self.embeddings.position_ids[:, :inputs_embeds.size(1)]) + inputs_embeds
         else:
             inputs_embeds = self.embeddings(input_ids=input_ids, position_ids=position_ids)
 
@@ -134,9 +135,49 @@ class CustomCLIPTextModel(CLIPPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
-
+        return self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
 
 
+class CustomCLIPTextModelWithProjection(CLIPPreTrainedModel):
+    config_class = CLIPTextConfig
+    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
+
+    def __init__(self, config: CLIPTextConfig):
+        super().__init__(config)
+        self.text_model = CustomCLIPTextTransformer(config)
+
+        # Add the projection layer for SDXL's second text encoder
+        projection_dim = getattr(config, 'projection_dim', config.hidden_size)
+        self.text_projection = nn.Linear(config.hidden_size, projection_dim, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.text_model.embeddings.token_embedding
+
+    def set_input_embeddings(self, value):
+        self.text_model.embeddings.token_embedding = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CLIPTextModelOutput]:
+        text_outputs = self.text_model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -145,3 +186,19 @@ class CustomCLIPTextModel(CLIPPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
+
+        pooled_output = text_outputs[1] if not return_dict else text_outputs.pooler_output
+
+        # Apply the projection to the pooled output
+        text_embeds = self.text_projection(pooled_output)
+
+        if not return_dict:
+            # Return last_hidden_state, pooler_output, text_embeds, and the remaining outputs
+            return (text_outputs[0], text_outputs[1], text_embeds) + text_outputs[2:]
+
+        return CLIPTextModelOutput(
+            text_embeds=text_embeds,  # Projected embeddings (for similarity)
+            last_hidden_state=text_outputs.last_hidden_state,  # All token representations
+            hidden_states=text_outputs.hidden_states,
+            attentions=text_outputs.attentions,
+        )
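The new `CustomCLIPTextModelWithProjection` mirrors the stock `CLIPTextModelWithProjection` on top of the custom text transformer. A rough usage sketch with randomly initialized weights, assuming `modeling_clip.py` is importable from a local clone; the point is the output shapes, not meaningful embeddings:

```python
# Sketch: run a prompt through CustomCLIPTextModelWithProjection and inspect shapes.
import torch
from transformers import CLIPTextConfig, CLIPTokenizer
from modeling_clip import CustomCLIPTextModelWithProjection

config = CLIPTextConfig(projection_dim=1024)   # projection_dim is read via getattr in __init__
model = CustomCLIPTextModelWithProjection(config).eval()

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
inputs = tokenizer(["a photo of a cat"], padding="max_length", return_tensors="pt")

with torch.no_grad():
    out = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, return_dict=True)

print(out.last_hidden_state.shape)  # [1, 77, config.hidden_size]
print(out.text_embeds.shape)        # [1, 1024] after the projection layer
```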
vlv_utils.py
ADDED
@@ -0,0 +1,71 @@
"""Utility functions"""
import importlib
import random
import re
import torch
import numpy as np
from PIL import Image


def normalize(image, rescale=True):
    if rescale:
        image = image.float() / 255.0  # convert to float and rescale to [0, 1]
    normalize_image = 2 * image - 1  # normalize to [-1, 1]
    return normalize_image


def process_caption(caption):
    """Process a caption to ensure proper formatting and remove duplicate sentences.

    Args:
        caption: A string containing the caption text

    Returns:
        processed_caption: A string with the processed caption
    """
    if not caption.endswith('.'):
        last_period_index = caption.rfind('.')
        if last_period_index != -1:
            caption = caption[:last_period_index + 1]

    sentences = re.split(r'(?<=[.!?])\s+', caption)

    unique_sentences = []
    for sentence in sentences:
        if sentence and sentence not in unique_sentences:
            unique_sentences.append(sentence)

    processed_caption = ' '.join(unique_sentences)

    return processed_caption


def initiate_time_steps(step, total_timestep, batch_size, config):
    """A helper function to initiate time steps for the diffusion model.

    Args:
        step: An integer of the constant step
        total_timestep: An integer of the total timesteps of the diffusion model
        batch_size: An integer of the batch size
        config: A config object

    Returns:
        timesteps: A tensor of shape [batch_size,] of the time steps
    """
    if config.rand_timestep_equal_int:
        # evenly spaced timesteps (one per batch element) from a random starting offset
        interval_val = total_timestep // batch_size
        start_point = random.randint(0, interval_val - 1)
        timesteps = torch.tensor(
            list(range(start_point, total_timestep, interval_val))
        ).long()
        return timesteps
    elif config.random_timestep_per_iteration:
        # a random timestep for each image in the batch (default)
        return torch.randint(0, total_timestep, (batch_size,)).long()
    else:
        # a constant timestep repeated for the whole batch
        return torch.tensor([step] * batch_size).long()
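A quick usage sketch for these helpers; `vlv_utils.py` is assumed importable from a local clone, and the `config` passed to `initiate_time_steps` is a stand-in namespace carrying only the two flags the function checks:

```python
# Sketch: exercise process_caption and initiate_time_steps.
from types import SimpleNamespace
from vlv_utils import process_caption, initiate_time_steps

# Dangling fragments are trimmed and repeated sentences are dropped.
print(process_caption("A cat sits on a mat. A cat sits on a mat. It is loo"))
# -> "A cat sits on a mat."

# Default behaviour: one random diffusion timestep per image in the batch.
cfg = SimpleNamespace(rand_timestep_equal_int=False, random_timestep_per_iteration=True)
print(initiate_time_steps(step=0, total_timestep=1000, batch_size=4, config=cfg))
# -> tensor of shape [4] with values in [0, 1000)
```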