QJerry committed
Commit 524d875 · verified · 1 Parent(s): c3cf953

Update app.py

Files changed (1):
  1. app.py +50 -22
app.py CHANGED
@@ -7,6 +7,7 @@ import sys
 import logging
 import warnings
 import re
+import spaces
 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from transformers import AutoModel, AutoTokenizer
 from dataclasses import dataclass
@@ -17,14 +18,14 @@ from diffusers import ZImagePipeline
 from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
 from pe import prompt_template
 
-
-# ==================== Environment Variables ================================
+# ==================== Environment Variables ==================================
 MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
 ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
 ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
-ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "_flash_3")
+ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
 DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
-# ===========================================================================
+HF_TOKEN = os.environ.get("HF_TOKEN")
+# =============================================================================
 
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -69,22 +70,43 @@ def get_resolution(resolution):
 
 def load_models(model_path, enable_compile=False, attention_backend="native"):
     print(f"Loading models from {model_path}...")
-    if not os.path.exists(model_path):
-        raise FileNotFoundError(f"Model directory not found: {model_path}")
-
-    vae = AutoencoderKL.from_pretrained(
-        os.path.join(model_path, "vae"),
-        torch_dtype=torch.bfloat16,
-        device_map="cuda"
-    )
-
-    text_encoder = AutoModel.from_pretrained(
-        os.path.join(model_path, "text_encoder"),
-        torch_dtype=torch.bfloat16,
-        device_map="cuda",
-    ).eval()
-
-    tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
+
+    use_auth_token = HF_TOKEN if HF_TOKEN else True
+
+    if not os.path.exists(model_path):
+        vae = AutoencoderKL.from_pretrained(
+            f"{model_path}/vae",
+            torch_dtype=torch.bfloat16,
+            device_map="cuda",
+            use_auth_token=use_auth_token
+        )
+
+        text_encoder = AutoModel.from_pretrained(
+            f"{model_path}/text_encoder",
+            torch_dtype=torch.bfloat16,
+            device_map="cuda",
+            use_auth_token=use_auth_token
+        ).eval()
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            f"{model_path}/tokenizer",
+            use_auth_token=use_auth_token
+        )
+    else:
+        vae = AutoencoderKL.from_pretrained(
+            os.path.join(model_path, "vae"),
+            torch_dtype=torch.bfloat16,
+            device_map="cuda"
+        )
+
+        text_encoder = AutoModel.from_pretrained(
+            os.path.join(model_path, "text_encoder"),
+            torch_dtype=torch.bfloat16,
+            device_map="cuda",
+        ).eval()
+
+        tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
 
     tokenizer.padding_side = "left"
 
     if enable_compile:
@@ -108,9 +130,15 @@ def load_models(model_path, enable_compile=False, attention_backend="native"):
     if enable_compile:
         pipe.vae.disable_tiling()
 
-    transformer = ZImageTransformer2DModel.from_pretrained(
-        os.path.join(model_path, "transformer")
-    ).to("cuda", torch.bfloat16)
+    if not os.path.exists(model_path):
+        transformer = ZImageTransformer2DModel.from_pretrained(
+            f"{model_path}/transformer",
+            use_auth_token=use_auth_token
+        ).to("cuda", torch.bfloat16)
+    else:
+        transformer = ZImageTransformer2DModel.from_pretrained(
+            os.path.join(model_path, "transformer")
+        ).to("cuda", torch.bfloat16)
 
     pipe.transformer = transformer
     pipe.transformer.set_attention_backend(attention_backend)
@@ -320,6 +348,7 @@ def prompt_enhance(prompt, enable_enhance):
     except Exception as e:
         return prompt, f"Error: {str(e)}"
 
+@spaces.GPU
 def generate(prompt, resolution, seed, steps, shift, enhance):
     if pipe is None:
         raise gr.Error("Model not loaded.")
@@ -350,7 +379,6 @@ def generate(prompt, resolution, seed, steps, shift, enhance):
 
     return image, final_prompt, str(seed)
 
-# ==================== Gradio Interface ====================
 init_app()
 
 with gr.Blocks(title="Z-Image Demo") as demo:
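
Note on the loading change: when MODEL_PATH does not resolve to a local directory, the updated load_models treats it as a Hub repo id and forwards use_auth_token (the HF_TOKEN value, or True to reuse cached login credentials) so gated weights can be downloaded. Below is a minimal standalone sketch of the same pattern for one component only; it is not the committed code, and it uses the subfolder= and token= arguments that current diffusers releases accept (use_auth_token= is deprecated in recent versions):

import os
import torch
from diffusers import AutoencoderKL

MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
HF_TOKEN = os.environ.get("HF_TOKEN")

def load_vae(model_path: str) -> AutoencoderKL:
    if os.path.exists(model_path):
        # Local checkout: the component lives in a real subdirectory.
        return AutoencoderKL.from_pretrained(
            os.path.join(model_path, "vae"),
            torch_dtype=torch.bfloat16,
        )
    # Hub repo id: address the component with subfolder= and authenticate;
    # token=True falls back to credentials cached by `huggingface-cli login`.
    return AutoencoderKL.from_pretrained(
        model_path,
        subfolder="vae",
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN if HF_TOKEN else True,
    )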
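
The new @spaces.GPU decorator comes from the spaces package that Hugging Face provides inside Spaces; on ZeroGPU hardware it attaches a GPU only while the decorated function runs. A sketch of the usual pattern (the duration argument and the body are illustrative; the commit uses the bare @spaces.GPU form with the default duration):

import spaces

@spaces.GPU(duration=120)  # optionally reserve up to ~120 s of GPU time per call
def generate(prompt, resolution, seed, steps, shift, enhance):
    # The GPU is allocated when the call starts and released when it returns.
    ...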