# Anime_Images_Style_Embedder / minimal_script.py

import os
import torch
import numpy as np
import torch.nn as nn
import lightning.pytorch as pl
from safetensors.torch import save_file, load_file
from transformers import AutoImageProcessor, AutoModel
from transformers.image_utils import load_image

class EmbeddingNetwork(nn.Module):
    """MLP head mapping a 1280-dim DINOv3 CLS embedding to a 7-dim style vector."""
    def __init__(self):
        super(EmbeddingNetwork, self).__init__()
        self.fc1 = nn.Linear(1280, 256)
        self.dropout1 = nn.Dropout(0.33)
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(0.33)
        self.fc3 = nn.Linear(128, 7)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.fc1(x)
        # x = self.dropout1(x)  # dropout kept for reference, disabled at inference
        x = self.act(x)
        x = self.fc2(x)
        # x = self.dropout2(x)  # dropout kept for reference, disabled at inference
        x = self.act(x)
        x = self.fc3(x)
        return x
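
# Quick shape check (sketch, not part of the original script): a single
# 1280-dim CLS embedding maps to a 7-dim output.
#   head = EmbeddingNetwork()
#   head(torch.randn(1, 1280)).shape  # -> torch.Size([1, 7])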

class PLModule(pl.LightningModule):
    """Lightning wrapper around EmbeddingNetwork for batched prediction."""
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.network = EmbeddingNetwork()

    def forward(self, x):
        return self.network(x)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # Each batch is expected to be a (features, labels/ids) tuple.
        outputs = self.forward(batch[0])
        return outputs, batch[1]
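
# Usage sketch (assumption, not in the original script): PLModule can be driven
# by a Lightning Trainer for batched inference, where each batch is a
# (cls_embeddings, labels_or_ids) tuple. `embedding_loader` below is hypothetical.
#   trainer = pl.Trainer(accelerator="auto", devices=1)
#   results = trainer.predict(PLModule(), dataloaders=embedding_loader)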

if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the trained style-embedding head from its safetensors checkpoint.
    embd_model = EmbeddingNetwork().to(device=device, dtype=torch.bfloat16)
    state_dict = load_file("Style Embedder v4.safetensors")
    embd_model.load_state_dict(state_dict)
    embd_model.eval()

    # Load the DINOv3 image processor and backbone (gated repos, so a Hugging Face token is needed).
    token = 'Enter your huggingface token here'
    processor = AutoImageProcessor.from_pretrained("facebook/dinov3-vits16-pretrain-lvd1689m",
                                                   do_resize=False, token=token)
    dino_model = AutoModel.from_pretrained("facebook/dinov3-vith16plus-pretrain-lvd1689m",
                                           token=token, device_map="auto", dtype=torch.bfloat16)

    # Embed one image: run DINOv3, take the CLS token, and feed it to the style head.
    image = load_image('images_for_style_embedding/6857740.webp')
    inputs = processor(images=image, return_tensors="pt").to(device=dino_model.device, dtype=torch.bfloat16)
    with torch.no_grad():
        output = dino_model(**inputs)
        last_hidden_states = output.last_hidden_state
        cls_token = last_hidden_states[:, 0, :]  # CLS token, shape (1, 1280)
        pred = embd_model(cls_token).cpu()
    print(pred)
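
    # Interpretation sketch (assumption): `pred` is a (1, 7) bfloat16 tensor. If the
    # 7 values are treated as style logits, probabilities could be obtained with:
    #   probs = torch.softmax(pred.float(), dim=-1)
    #   print(probs, probs.argmax(dim=-1).item())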