Roman190928 commited on
Commit
5ad1055
·
verified ·
1 Parent(s): 0343654

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -36
app.py CHANGED
@@ -2,35 +2,26 @@ import gradio as gr
2
  import numpy as np
3
  import random
4
 
5
- from diffusers import DiffusionPipeline, EulerDiscreteScheduler
 
6
  import torch
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
- model_repo_id = "black-forest-labs/FLUX.1-schnell"
10
 
11
- # Choose dtype
12
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
 
13
 
14
- # Load scheduler first
15
- scheduler = EulerDiscreteScheduler.from_pretrained(
16
- model_repo_id,
17
- subfolder="scheduler"
18
- )
19
-
20
- # Load pipeline w/ custom scheduler
21
- pipe = DiffusionPipeline.from_pretrained(
22
- model_repo_id,
23
- scheduler=scheduler,
24
- torch_dtype=torch_dtype
25
- ).to(device)
26
-
27
- # Speed tweaks for CPU / weak GPU
28
- pipe.enable_attention_slicing()
29
- pipe.enable_sequential_cpu_offload()
30
 
31
  MAX_SEED = np.iinfo(np.int32).max
32
  MAX_IMAGE_SIZE = 1024
33
 
 
 
34
  def infer(
35
  prompt,
36
  negative_prompt,
@@ -50,10 +41,10 @@ def infer(
50
  image = pipe(
51
  prompt=prompt,
52
  negative_prompt=negative_prompt,
53
- width=width,
54
- height=height,
55
  guidance_scale=guidance_scale,
56
  num_inference_steps=num_inference_steps,
 
 
57
  generator=generator,
58
  ).images[0]
59
 
@@ -61,9 +52,9 @@ def infer(
61
 
62
 
63
  examples = [
64
- "Astronaut in the jungle, hyper-detailed",
65
- "A green horse in space",
66
- "Cyberpunk ramen bowl with neon lights",
67
  ]
68
 
69
  css = """
@@ -75,7 +66,7 @@ css = """
75
 
76
  with gr.Blocks(css=css) as demo:
77
  with gr.Column(elem_id="col-container"):
78
- gr.Markdown(" # Flux Schnell Inference 🎨🔥")
79
 
80
  with gr.Row():
81
  prompt = gr.Text(
@@ -85,6 +76,7 @@ with gr.Blocks(css=css) as demo:
85
  placeholder="Enter your prompt",
86
  container=False,
87
  )
 
88
  run_button = gr.Button("Run", scale=0, variant="primary")
89
 
90
  result = gr.Image(label="Result", show_label=False)
@@ -93,7 +85,7 @@ with gr.Blocks(css=css) as demo:
93
  negative_prompt = gr.Text(
94
  label="Negative prompt",
95
  max_lines=1,
96
- placeholder="bad quality, blurry",
97
  visible=True,
98
  )
99
 
@@ -110,17 +102,18 @@ with gr.Blocks(css=css) as demo:
110
  with gr.Row():
111
  width = gr.Slider(
112
  label="Width",
113
- minimum=256,
114
  maximum=MAX_IMAGE_SIZE,
115
  step=32,
116
- value=512,
117
  )
 
118
  height = gr.Slider(
119
  label="Height",
120
- minimum=256,
121
  maximum=MAX_IMAGE_SIZE,
122
  step=32,
123
- value=512,
124
  )
125
 
126
  with gr.Row():
@@ -129,19 +122,18 @@ with gr.Blocks(css=css) as demo:
129
  minimum=0.0,
130
  maximum=10.0,
131
  step=0.1,
132
- value=0.0,
133
  )
134
 
135
  num_inference_steps = gr.Slider(
136
  label="Number of inference steps",
137
  minimum=1,
138
- maximum=20,
139
  step=1,
140
- value=4,
141
  )
142
 
143
  gr.Examples(examples=examples, inputs=[prompt])
144
-
145
  gr.on(
146
  triggers=[run_button.click, prompt.submit],
147
  fn=infer,
@@ -159,4 +151,4 @@ with gr.Blocks(css=css) as demo:
159
  )
160
 
161
  if __name__ == "__main__":
162
- demo.launch()
 
2
  import numpy as np
3
  import random
4
 
5
+ # import spaces #[uncomment to use ZeroGPU]
6
+ from diffusers import DiffusionPipeline
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
 
12
+ if torch.cuda.is_available():
13
+ torch_dtype = torch.float16
14
+ else:
15
+ torch_dtype = torch.float32
16
 
17
+ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
+ pipe = pipe.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  MAX_IMAGE_SIZE = 1024
22
 
23
+
24
+ # @spaces.GPU #[uncomment to use ZeroGPU]
25
  def infer(
26
  prompt,
27
  negative_prompt,
 
41
  image = pipe(
42
  prompt=prompt,
43
  negative_prompt=negative_prompt,
 
 
44
  guidance_scale=guidance_scale,
45
  num_inference_steps=num_inference_steps,
46
+ width=width,
47
+ height=height,
48
  generator=generator,
49
  ).images[0]
50
 
 
52
 
53
 
54
  examples = [
55
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
+ "An astronaut riding a green horse",
57
+ "A delicious ceviche cheesecake slice",
58
  ]
59
 
60
  css = """
 
66
 
67
  with gr.Blocks(css=css) as demo:
68
  with gr.Column(elem_id="col-container"):
69
+ gr.Markdown(" # Text-to-Image Gradio Template")
70
 
71
  with gr.Row():
72
  prompt = gr.Text(
 
76
  placeholder="Enter your prompt",
77
  container=False,
78
  )
79
+
80
  run_button = gr.Button("Run", scale=0, variant="primary")
81
 
82
  result = gr.Image(label="Result", show_label=False)
 
85
  negative_prompt = gr.Text(
86
  label="Negative prompt",
87
  max_lines=1,
88
+ placeholder="Enter a negative prompt",
89
  visible=True,
90
  )
91
 
 
102
  with gr.Row():
103
  width = gr.Slider(
104
  label="Width",
105
+ minimum=64,
106
  maximum=MAX_IMAGE_SIZE,
107
  step=32,
108
+ value=1024, # Replace with defaults that work for your model
109
  )
110
+
111
  height = gr.Slider(
112
  label="Height",
113
+ minimum=64,
114
  maximum=MAX_IMAGE_SIZE,
115
  step=32,
116
+ value=1024, # Replace with defaults that work for your model
117
  )
118
 
119
  with gr.Row():
 
122
  minimum=0.0,
123
  maximum=10.0,
124
  step=0.1,
125
+ value=1, # Replace with defaults that work for your model
126
  )
127
 
128
  num_inference_steps = gr.Slider(
129
  label="Number of inference steps",
130
  minimum=1,
131
+ maximum=50,
132
  step=1,
133
+ value=7, # Replace with defaults that work for your model
134
  )
135
 
136
  gr.Examples(examples=examples, inputs=[prompt])
 
137
  gr.on(
138
  triggers=[run_button.click, prompt.submit],
139
  fn=infer,
 
151
  )
152
 
153
  if __name__ == "__main__":
154
+ demo.launch()