Spaces:
Runtime error
Runtime error
| import torch | |
| import time | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from datasets import load_dataset | |
| # Конфигурация моделей | |
| MODEL_CONFIGS = { | |
| "GigaChat-like": "ai-forever/rugpt2large", | |
| "ChatGPT-like": "ai-forever/rugpt3large_based_on_gpt2", | |
| "DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2" | |
| } | |
| # Устройство | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # Загрузка моделей | |
| models = {} | |
| for label, name in MODEL_CONFIGS.items(): | |
| tokenizer = AutoTokenizer.from_pretrained(name) | |
| model = AutoModelForCausalLM.from_pretrained(name) | |
| model.to(device) | |
| model.eval() | |
| models[label] = (tokenizer, model) | |
| # Загрузка датасета (не используется напрямую, но может быть полезен) | |
| dataset = load_dataset("ZhenDOS/alpha_bank_data", split="train") | |
| # CoT-промпты | |
| def cot_prompt_1(text): | |
| return f"Клиент задал вопрос: {text}\nПодумай шаг за шагом и объясни, как бы ты ответил на это обращение от лица банка." | |
| def cot_prompt_2(text): | |
| return f"Вопрос клиента: {text}\nРазложи на части, что именно спрашивает клиент, и предложи логичный ответ с пояснениями." | |
| # Генерация | |
| def generate_all_responses(question): | |
| results = {} | |
| for model_name, (tokenizer, model) in models.items(): | |
| results[model_name] = {} | |
| for i, prompt_func in enumerate([cot_prompt_1, cot_prompt_2], start=1): | |
| prompt = prompt_func(question) | |
| inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512) | |
| inputs = {k: v.to(device) for k, v in inputs.items()} | |
| start_time = time.time() | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| **inputs, | |
| max_new_tokens=200, | |
| do_sample=True, | |
| temperature=0.7, | |
| top_p=0.9, | |
| eos_token_id=tokenizer.eos_token_id | |
| ) | |
| end_time = time.time() | |
| response = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| response = response.replace(prompt, "").strip() | |
| duration = round(end_time - start_time, 2) | |
| results[model_name][f"CoT Промпт {i}"] = { | |
| "response": response, | |
| "time": f"{duration} сек." | |
| } | |
| return results | |
| # Отображение | |
| def display_responses(question): | |
| all_responses = generate_all_responses(question) | |
| output = "" | |
| for model_name, prompts in all_responses.items(): | |
| output += f"\n### Модель: {model_name}\n" | |
| for prompt_label, content in prompts.items(): | |
| output += f"\n**{prompt_label}** ({content['time']}):\n{content['response']}\n" | |
| return output.strip() | |
| # Интерфейс | |
| demo = gr.Interface( | |
| fn=display_responses, | |
| inputs=gr.Textbox(lines=4, label="Введите клиентский вопрос"), | |
| outputs=gr.Markdown(label="Ответы от разных моделей"), | |
| title="Alpha Bank Assistant — сравнение моделей", | |
| description="Сравнение CoT-ответов от GigaChat, ChatGPT и DeepSeek-подобных моделей на обращение клиента.", | |
| examples=[ | |
| "Как восстановить доступ в мобильный банк?", | |
| "Почему с меня списали комиссию за обслуживание карты?", | |
| "Какие условия по потребительскому кредиту?", | |
| ] | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() | |