Ntdeseb commited on
Commit
635fc93
·
1 Parent(s): 7998696

Agregando más modelos: chat, traducción y optimizaciones de velocidad

Browse files
Files changed (1) hide show
  1. app.py +125 -39
app.py CHANGED
@@ -11,13 +11,42 @@ import base64
11
  MODELS = {
12
  "text": {
13
  "microsoft/DialoGPT-medium": "Chat conversacional",
 
 
14
  "gpt2": "Generación de texto",
 
 
15
  "distilgpt2": "GPT-2 optimizado",
16
- "EleutherAI/gpt-neo-125M": "GPT-Neo pequeño"
 
 
 
 
 
 
 
 
 
17
  },
18
  "image": {
19
  "runwayml/stable-diffusion-v1-5": "Stable Diffusion v1.5",
20
- "CompVis/stable-diffusion-v1-4": "Stable Diffusion v1.4"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  }
22
  }
23
 
@@ -25,16 +54,25 @@ MODELS = {
25
  model_cache = {}
26
 
27
  def load_text_model(model_name):
28
- """Cargar modelo de texto"""
29
  if model_name not in model_cache:
30
  print(f"Cargando modelo de texto: {model_name}")
31
- tokenizer = AutoTokenizer.from_pretrained(model_name)
32
- model = AutoModelForCausalLM.from_pretrained(model_name)
33
 
34
- # Configurar para chat si es DialoGPT
35
- if "dialogpt" in model_name.lower():
36
- tokenizer.pad_token = tokenizer.eos_token
37
- model.config.pad_token_id = model.config.eos_token_id
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  model_cache[model_name] = {
40
  "tokenizer": tokenizer,
@@ -45,16 +83,21 @@ def load_text_model(model_name):
45
  return model_cache[model_name]
46
 
47
  def load_image_model(model_name):
48
- """Cargar modelo de imagen"""
49
  if model_name not in model_cache:
50
  print(f"Cargando modelo de imagen: {model_name}")
 
 
51
  pipe = StableDiffusionPipeline.from_pretrained(
52
  model_name,
53
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
 
 
54
  )
55
 
56
- if torch.cuda.is_available():
57
- pipe = pipe.to("cuda")
 
58
 
59
  model_cache[model_name] = {
60
  "pipeline": pipe,
@@ -64,32 +107,40 @@ def load_image_model(model_name):
64
  return model_cache[model_name]
65
 
66
  def generate_text(prompt, model_name, max_length=100):
67
- """Generar texto con el modelo seleccionado"""
68
  try:
69
  model_data = load_text_model(model_name)
70
  tokenizer = model_data["tokenizer"]
71
  model = model_data["model"]
72
 
73
- # Preparar input
74
- inputs = tokenizer.encode(prompt, return_tensors="pt")
75
-
76
- # Generar
77
- with torch.no_grad():
78
- outputs = model.generate(
79
- inputs,
80
- max_length=max_length,
81
- num_return_sequences=1,
82
- temperature=0.7,
83
- do_sample=True,
84
- pad_token_id=tokenizer.eos_token_id
85
- )
86
-
87
- # Decodificar respuesta
88
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
89
-
90
- # Para DialoGPT, extraer solo la respuesta del asistente
91
- if "dialogpt" in model_name.lower():
92
- response = response.replace(prompt, "").strip()
 
 
 
 
 
 
 
 
93
 
94
  return response
95
 
@@ -97,16 +148,22 @@ def generate_text(prompt, model_name, max_length=100):
97
  return f"Error generando texto: {str(e)}"
98
 
99
  def generate_image(prompt, model_name, num_inference_steps=20):
100
- """Generar imagen con el modelo seleccionado"""
101
  try:
102
  model_data = load_image_model(model_name)
103
  pipeline = model_data["pipeline"]
104
 
105
- # Generar imagen
 
 
 
 
106
  image = pipeline(
107
  prompt,
108
  num_inference_steps=num_inference_steps,
109
- guidance_scale=7.5
 
 
110
  ).images[0]
111
 
112
  return image
@@ -207,7 +264,7 @@ with gr.Blocks(title="Modelos Libres de IA", theme=gr.themes.Soft()) as demo:
207
  with gr.Row():
208
  with gr.Column():
209
  chat_model = gr.Dropdown(
210
- choices=["microsoft/DialoGPT-medium"],
211
  value="microsoft/DialoGPT-medium",
212
  label="Modelo de Chat"
213
  )
@@ -237,6 +294,35 @@ with gr.Blocks(title="Modelos Libres de IA", theme=gr.themes.Soft()) as demo:
237
  outputs=[chatbot]
238
  )
239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  # Tab de Generación de Imágenes
241
  with gr.TabItem("🎨 Generación de Imágenes"):
242
  with gr.Row():
@@ -254,7 +340,7 @@ with gr.Blocks(title="Modelos Libres de IA", theme=gr.themes.Soft()) as demo:
254
  steps = gr.Slider(
255
  minimum=10,
256
  maximum=50,
257
- value=20,
258
  step=5,
259
  label="Pasos de inferencia"
260
  )
 
11
  MODELS = {
12
  "text": {
13
  "microsoft/DialoGPT-medium": "Chat conversacional",
14
+ "microsoft/DialoGPT-large": "Chat conversacional avanzado",
15
+ "microsoft/DialoGPT-small": "Chat conversacional rápido",
16
  "gpt2": "Generación de texto",
17
+ "gpt2-medium": "GPT-2 mediano",
18
+ "gpt2-large": "GPT-2 grande",
19
  "distilgpt2": "GPT-2 optimizado",
20
+ "EleutherAI/gpt-neo-125M": "GPT-Neo pequeño",
21
+ "EleutherAI/gpt-neo-1.3B": "GPT-Neo mediano",
22
+ "microsoft/DialoGPT-medium": "Chat conversacional",
23
+ "facebook/opt-125m": "OPT pequeño",
24
+ "facebook/opt-350m": "OPT mediano",
25
+ "bigscience/bloom-560m": "BLOOM multilingüe",
26
+ "bigscience/bloom-1b1": "BLOOM grande",
27
+ "microsoft/DialoGPT-medium": "Chat conversacional",
28
+ "Helsinki-NLP/opus-mt-es-en": "Traductor español-inglés",
29
+ "Helsinki-NLP/opus-mt-en-es": "Traductor inglés-español"
30
  },
31
  "image": {
32
  "runwayml/stable-diffusion-v1-5": "Stable Diffusion v1.5",
33
+ "CompVis/stable-diffusion-v1-4": "Stable Diffusion v1.4",
34
+ "stabilityai/stable-diffusion-2-1": "Stable Diffusion 2.1",
35
+ "stabilityai/stable-diffusion-xl-base-1.0": "SDXL Base",
36
+ "stabilityai/stable-diffusion-xl-refiner-1.0": "SDXL Refiner",
37
+ "prompthero/openjourney": "Midjourney style",
38
+ "dreamlike-art/dreamlike-photoreal-2.0": "Fotorealista",
39
+ "nitrosocke/Ghibli-Diffusion": "Estilo Studio Ghibli",
40
+ "nitrosocke/mo-di-diffusion": "Estilo moderno",
41
+ "CompVis/stable-diffusion-v1-4": "Stable Diffusion v1.4",
42
+ "runwayml/stable-diffusion-v1-5": "Stable Diffusion v1.5"
43
+ },
44
+ "chat": {
45
+ "microsoft/DialoGPT-medium": "Chat conversacional",
46
+ "microsoft/DialoGPT-large": "Chat conversacional avanzado",
47
+ "microsoft/DialoGPT-small": "Chat conversacional rápido",
48
+ "facebook/opt-350m": "OPT conversacional",
49
+ "bigscience/bloom-560m": "BLOOM multilingüe"
50
  }
51
  }
52
 
 
54
  model_cache = {}
55
 
56
  def load_text_model(model_name):
57
+ """Cargar modelo de texto con soporte para diferentes tipos"""
58
  if model_name not in model_cache:
59
  print(f"Cargando modelo de texto: {model_name}")
 
 
60
 
61
+ # Detectar tipo de modelo
62
+ if "opus-mt" in model_name.lower():
63
+ # Modelo de traducción
64
+ from transformers import MarianMTModel, MarianTokenizer
65
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
66
+ model = MarianMTModel.from_pretrained(model_name)
67
+ else:
68
+ # Modelo de generación de texto
69
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
70
+ model = AutoModelForCausalLM.from_pretrained(model_name)
71
+
72
+ # Configurar para chat si es DialoGPT
73
+ if "dialogpt" in model_name.lower():
74
+ tokenizer.pad_token = tokenizer.eos_token
75
+ model.config.pad_token_id = model.config.eos_token_id
76
 
77
  model_cache[model_name] = {
78
  "tokenizer": tokenizer,
 
83
  return model_cache[model_name]
84
 
85
  def load_image_model(model_name):
86
+ """Cargar modelo de imagen - optimizado para velocidad"""
87
  if model_name not in model_cache:
88
  print(f"Cargando modelo de imagen: {model_name}")
89
+
90
+ # Optimizaciones para CPU y velocidad
91
  pipe = StableDiffusionPipeline.from_pretrained(
92
  model_name,
93
+ torch_dtype=torch.float32, # Usar float32 para CPU
94
+ safety_checker=None, # Desactivar safety checker para velocidad
95
+ requires_safety_checker=False
96
  )
97
 
98
+ # Optimizaciones adicionales
99
+ pipe.enable_attention_slicing() # Reducir uso de memoria
100
+ pipe.enable_sequential_cpu_offload() # Optimizar para CPU
101
 
102
  model_cache[model_name] = {
103
  "pipeline": pipe,
 
107
  return model_cache[model_name]
108
 
109
  def generate_text(prompt, model_name, max_length=100):
110
+ """Generar texto con el modelo seleccionado - mejorado para diferentes tipos"""
111
  try:
112
  model_data = load_text_model(model_name)
113
  tokenizer = model_data["tokenizer"]
114
  model = model_data["model"]
115
 
116
+ # Detectar si es modelo de traducción
117
+ if "opus-mt" in model_name.lower():
118
+ # Traducción
119
+ inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
120
+ with torch.no_grad():
121
+ outputs = model.generate(inputs, max_length=max_length, num_beams=4, early_stopping=True)
122
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
123
+ else:
124
+ # Generación de texto
125
+ inputs = tokenizer.encode(prompt, return_tensors="pt")
126
+
127
+ # Generar
128
+ with torch.no_grad():
129
+ outputs = model.generate(
130
+ inputs,
131
+ max_length=max_length,
132
+ num_return_sequences=1,
133
+ temperature=0.7,
134
+ do_sample=True,
135
+ pad_token_id=tokenizer.eos_token_id
136
+ )
137
+
138
+ # Decodificar respuesta
139
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
140
+
141
+ # Para DialoGPT, extraer solo la respuesta del asistente
142
+ if "dialogpt" in model_name.lower():
143
+ response = response.replace(prompt, "").strip()
144
 
145
  return response
146
 
 
148
  return f"Error generando texto: {str(e)}"
149
 
150
  def generate_image(prompt, model_name, num_inference_steps=20):
151
+ """Generar imagen con el modelo seleccionado - optimizado para velocidad"""
152
  try:
153
  model_data = load_image_model(model_name)
154
  pipeline = model_data["pipeline"]
155
 
156
+ # Optimizaciones para velocidad
157
+ if num_inference_steps > 20:
158
+ num_inference_steps = 20 # Limitar a máximo 20 pasos para velocidad
159
+
160
+ # Generar imagen con configuración optimizada
161
  image = pipeline(
162
  prompt,
163
  num_inference_steps=num_inference_steps,
164
+ guidance_scale=7.0, # Reducido de 7.5 para velocidad
165
+ height=512, # Tamaño fijo para consistencia
166
+ width=512
167
  ).images[0]
168
 
169
  return image
 
264
  with gr.Row():
265
  with gr.Column():
266
  chat_model = gr.Dropdown(
267
+ choices=list(MODELS["chat"].keys()),
268
  value="microsoft/DialoGPT-medium",
269
  label="Modelo de Chat"
270
  )
 
294
  outputs=[chatbot]
295
  )
296
 
297
+ # Tab de Traducción
298
+ with gr.TabItem("🌐 Traducción"):
299
+ with gr.Row():
300
+ with gr.Column():
301
+ translate_model = gr.Dropdown(
302
+ choices=["Helsinki-NLP/opus-mt-es-en", "Helsinki-NLP/opus-mt-en-es"],
303
+ value="Helsinki-NLP/opus-mt-es-en",
304
+ label="Modelo de Traducción"
305
+ )
306
+ translate_text = gr.Textbox(
307
+ label="Texto a traducir",
308
+ placeholder="Escribe el texto que quieres traducir...",
309
+ lines=3
310
+ )
311
+ translate_btn = gr.Button("Traducir", variant="primary")
312
+
313
+ with gr.Column():
314
+ translate_output = gr.Textbox(
315
+ label="Traducción",
316
+ lines=3,
317
+ interactive=False
318
+ )
319
+
320
+ translate_btn.click(
321
+ generate_text,
322
+ inputs=[translate_text, translate_model, gr.Slider(value=100, visible=False)],
323
+ outputs=translate_output
324
+ )
325
+
326
  # Tab de Generación de Imágenes
327
  with gr.TabItem("🎨 Generación de Imágenes"):
328
  with gr.Row():
 
340
  steps = gr.Slider(
341
  minimum=10,
342
  maximum=50,
343
+ value=15,
344
  step=5,
345
  label="Pasos de inferencia"
346
  )