|
|
import os |
|
|
import torch |
|
|
import numpy as np |
|
|
import torch.nn as nn |
|
|
import lightning.pytorch as pl |
|
|
from safetensors.torch import save_file, load_file |
|
|
from transformers import AutoImageProcessor, AutoModel |
|
|
from transformers.image_utils import load_image |
|
|
|
|
|
class EmbeddingNetwork(nn.Module):
    """Three-layer MLP head mapping 1280-dim feature vectors to 7 outputs.

    Layer widths are 1280 -> 256 -> 128 -> 7, with a shared in-place ReLU
    after each of the first two linear layers.

    NOTE(review): ``dropout1`` and ``dropout2`` are constructed here but are
    not applied anywhere in ``forward``. They hold no parameters, so they do
    not affect ``state_dict`` loading. Confirm whether leaving them out of
    the forward pass is intentional.
    """

    def __init__(self):
        super().__init__()

        # First hidden stage.
        self.fc1 = nn.Linear(1280, 256)
        self.dropout1 = nn.Dropout(0.33)

        # Second hidden stage.
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(0.33)

        # Output projection and the activation shared by both hidden stages.
        self.fc3 = nn.Linear(128, 7)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run ``x`` through fc1 -> ReLU -> fc2 -> ReLU -> fc3.

        Args:
            x: Tensor whose last dimension is 1280 (the ``fc1`` input width).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of 7. No activation is applied after ``fc3``.
        """
        hidden = self.act(self.fc1(x))
        hidden = self.act(self.fc2(hidden))
        return self.fc3(hidden)
|
|
|
|
|
|
|
|
|
|
|
class PLModule(pl.LightningModule):
    """Lightning wrapper that owns an ``EmbeddingNetwork`` as ``self.network``.

    Only the forward pass and the prediction step are defined here; no
    training, validation, or optimizer hooks are visible in this class.
    """

    def __init__(self):
        super(PLModule, self).__init__()
        # Must be called directly from __init__; invoked without arguments,
        # exactly as in the original.
        self.save_hyperparameters()
        self.network = EmbeddingNetwork()

    def forward(self, x):
        """Delegate to the wrapped network and return its output."""
        embedding = self.network(x)
        return embedding

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Run the network on the first element of ``batch``.

        Args:
            batch: Indexable batch; ``batch[0]`` is fed to the network and
                ``batch[1]`` is passed through untouched (presumably labels
                or identifiers; confirm against the dataloader).
            batch_idx: Unused here.
            dataloader_idx: Unused here.

        Returns:
            Tuple ``(network_output, batch[1])``.
        """
        features = batch[0]
        predictions = self.forward(features)
        return predictions, batch[1]
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo: embed one image with a DINOv3 backbone, then feed the CLS token
    # through the trained EmbeddingNetwork head and print the 7 outputs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # One checkpoint id for both the image processor and the backbone, so the
    # preprocessing always matches the model that consumes it. The original
    # loaded the processor from the "vits16" checkpoint and the model from
    # "vith16plus".
    dino_checkpoint = "facebook/dinov3-vith16plus-pretrain-lvd1689m"

    embd_model = EmbeddingNetwork().to(device=device, dtype=torch.bfloat16)
    state_dict = load_file("Style Embedder v4.safetensors")
    embd_model.load_state_dict(state_dict)
    # Inference only: put the head in eval mode.
    embd_model.eval()

    # Read the Hugging Face access token from the environment rather than
    # hard-coding a secret in the source. When HF_TOKEN is unset this is None
    # and the hub library falls back to any locally cached login.
    token = os.environ.get("HF_TOKEN")

    processor = AutoImageProcessor.from_pretrained(dino_checkpoint,
                                                   do_resize=False, token=token)
    dino_model = AutoModel.from_pretrained(dino_checkpoint, token=token, device_map="auto",
                                           dtype=torch.bfloat16)

    image = load_image('images_for_style_embedding/6857740.webp')
    # Named `inputs` so the builtin `input` is not shadowed.
    inputs = processor(images=image, return_tensors="pt").to(device=dino_model.device, dtype=torch.bfloat16)

    # No gradients are needed; inference_mode skips autograd bookkeeping.
    with torch.inference_mode():
        output = dino_model(**inputs)
        last_hidden_states = output.last_hidden_state
        # Index 0 along the sequence axis is used as the CLS token.
        cls_token = last_hidden_states[:, 0, :]

        # device_map="auto" may place the backbone output on a different
        # device than `device`; move the features to where the head lives.
        pred = embd_model(cls_token.to(device)).cpu()

    print(pred)
|
|
|