Spaces: Build error
| import streamlit as st | |
| import torch | |
| from torch import nn | |
| from diffusers import DDPMScheduler, UNet2DModel | |
| import matplotlib.pyplot as plt | |
| from tqdm.auto import tqdm | |
# Model definition reused from the existing (training) code; its layout must
# match the checkpoint that load_model restores with load_state_dict.
class ClassConditionedUnet(nn.Module):
    """A ``UNet2DModel`` conditioned on a class label.

    Each label is looked up in a learned embedding table. The resulting vector
    is tiled over every pixel and concatenated to the image channels, and the
    stacked tensor is fed to the UNet.

    Args:
        num_classes: Number of rows in the class embedding table.
        class_emb_size: Length of each embedding vector; this many extra
            input channels are added on top of the 3 image channels.
    """

    def __init__(self, num_classes=3, class_emb_size=12):
        super().__init__()
        # `class_emb` and `model` become the state_dict key prefixes, so these
        # attribute names must stay in sync with saved checkpoints.
        self.class_emb = nn.Embedding(num_classes, class_emb_size)
        down_blocks = (
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        )
        up_blocks = (
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        )
        self.model = UNet2DModel(
            sample_size=64,
            # 3 image channels plus one extra channel per embedding dimension.
            in_channels=3 + class_emb_size,
            out_channels=3,
            layers_per_block=2,
            block_out_channels=(64, 128, 256, 512),
            down_block_types=down_blocks,
            up_block_types=up_blocks,
        )

    def forward(self, x, t, class_labels):
        """Run the UNet on ``x`` at timestep ``t``, conditioned on ``class_labels``.

        Args:
            x: Image batch, presumably (batch, channels, height, width);
                dims 0, 2 and 3 are read as batch size and spatial size.
            t: Timestep, passed unchanged to the UNet.
            class_labels: Integer labels, one per image in the batch.

        Returns:
            The ``.sample`` tensor produced by the UNet.
        """
        batch, _, height, width = x.shape
        emb = self.class_emb(class_labels)
        emb_dim = emb.shape[1]
        # Tile each image's embedding over all pixels so it can be stacked
        # onto the image as extra channels.
        cond_map = emb.view(batch, emb_dim, 1, 1).expand(batch, emb_dim, height, width)
        stacked = torch.cat((x, cond_map), 1)
        return self.model(stacked, t).sample
@st.cache_resource
def load_model(model_path):
    """Load the class-conditioned UNet checkpoint and create its noise scheduler.

    Streamlit reruns the whole script on every widget interaction. The
    ``st.cache_resource`` decorator makes this function run once per
    ``model_path`` and return the same objects on later reruns, so the
    weights are not re-read from disk each time. Exceptions are not cached,
    so a failed load is retried on the next rerun.

    The cached objects are shared by every session of the app, so callers
    must not mutate the returned model.

    Args:
        model_path: Path to a ``torch.save``-d dict that contains a
            ``'model_state_dict'`` entry.

    Returns:
        ``(net, noise_scheduler)``: a ``ClassConditionedUnet`` on CPU with the
        checkpoint weights loaded, and a 1000-step ``DDPMScheduler`` using the
        ``'squaredcos_cap_v2'`` beta schedule.
    """
    device = 'cpu'  # For deployment, we use CPU.
    net = ClassConditionedUnet().to(device)
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
    checkpoint = torch.load(model_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    return net, noise_scheduler
def generate_mixed_faces(net, noise_scheduler, mix_weights, num_images=1):
    """Generate faces conditioned on a weighted blend of class embeddings.

    No single class label is used. Instead, the learned embeddings of classes
    ``0 .. len(mix_weights) - 1`` (0 = Asian, 1 = Indian, 2 = European) are
    combined as ``sum(w_i * emb_i)``. That one vector conditions every
    denoising step.

    The blended vector is broadcast over the image exactly as
    ``ClassConditionedUnet.forward`` does with per-label embeddings, and
    ``net.model`` is called directly. Nothing on ``net`` is modified, so the
    model can safely be cached and shared between calls and sessions.

    Args:
        net: A ``ClassConditionedUnet``; its ``class_emb`` and ``model``
            submodules are used.
        noise_scheduler: Scheduler whose ``timesteps`` are iterated and whose
            ``step(...).prev_sample`` gives the next, less noisy sample.
        mix_weights: One weight per class, in class-index order. The caller
            in this file normalizes them to sum to 1.
        num_images: Number of images generated in one batch.

    Returns:
        Tensor of shape ``(num_images, 3, 64, 64)`` with values in ``[0, 1]``.

    Raises:
        ValueError: If ``mix_weights`` is empty.
    """
    if len(mix_weights) == 0:
        raise ValueError("mix_weights must contain at least one weight")

    device = next(net.parameters()).device
    net.eval()
    with torch.no_grad():
        x = torch.randn(num_images, 3, 64, 64).to(device)

        # The blend does not depend on the timestep, so compute it once
        # rather than on every denoising step.
        class_ids = torch.arange(len(mix_weights), device=device)
        class_embs = net.class_emb(class_ids)  # one row per class
        mixed_emb = sum(w * emb for w, emb in zip(mix_weights, class_embs))

        # (emb_size,) -> (num_images, emb_size, H, W): the same layout
        # ClassConditionedUnet.forward builds from integer labels.
        class_cond = mixed_emb.view(1, -1, 1, 1).expand(num_images, -1, *x.shape[2:])

        timesteps = noise_scheduler.timesteps
        n_steps = len(timesteps)
        progress_bar = st.progress(0)
        for idx, t in enumerate(timesteps):
            progress_bar.progress(idx / n_steps)
            net_input = torch.cat((x, class_cond), dim=1)
            residual = net.model(net_input, t).sample
            x = noise_scheduler.step(residual, t, x).prev_sample
        progress_bar.progress(1.0)

        # Map the model's [-1, 1] output range to [0, 1] for display.
        x = (x.clamp(-1, 1) + 1) / 2
    return x
def main():
    """Build the Streamlit page.

    Loads the model, lets the user set the relative share of each ethnicity
    with three sliders, and shows the normalized mix. When the
    "Generate Face" button is pressed, it generates and displays one face.
    Load and generation errors are shown on the page rather than raised.
    """
    st.title("AI Face Generator with Ethnic Features Mixing")

    try:
        net, noise_scheduler = load_model('final_model/final_diffusion_model.pt')
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return

    st.subheader("Adjust Ethnicity Mix")
    # (label, default percentage) per slider, in class-index order.
    slider_specs = (
        ("Asian Features %", 33),
        ("Indian Features %", 33),
        ("European Features %", 34),
    )
    percentages = []
    for column, (label, default) in zip(st.columns(3), slider_specs):
        with column:
            percentages.append(st.slider(label, 0, 100, default, 1))

    total = sum(percentages)
    if total == 0:
        st.warning("Total percentage cannot be 0%. Please adjust the sliders.")
        return

    # Rescale so the weights sum to 1 regardless of the raw slider values.
    weights = [pct / total for pct in percentages]
    asian_w, indian_w, european_w = weights
    st.write("Current mix (normalized):")
    st.write(f"Asian: {asian_w:.2%}, Indian: {indian_w:.2%}, European: {european_w:.2%}")

    if not st.button("Generate Face"):
        return

    try:
        with st.spinner("Generating face..."):
            images = generate_mixed_faces(net, noise_scheduler, weights)
            # Channels-first tensor -> channels-last numpy array for st.image.
            face = images[0].permute(1, 2, 0).cpu().numpy()
            st.image(face, caption="Generated Face", use_column_width=True)
    except Exception as e:
        st.error(f"Error generating image: {str(e)}")


if __name__ == "__main__":
    main()