import asyncio
import json
import logging

import click
import requests
from dotenv import load_dotenv

from hugginggpt import generate_response, infer, plan_tasks
from hugginggpt.history import ConversationHistory
from hugginggpt.llm_factory import LLMs, create_llms
from hugginggpt.log import setup_logging
from hugginggpt.model_inference import TaskSummary
from hugginggpt.model_selection import select_hf_models
from hugginggpt.response_generation import format_response

load_dotenv()
setup_logging()
logger = logging.getLogger(__name__)
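

# CLI entrypoint: --prompt runs a single request in standalone mode;
# without it, an interactive conversation loop starts.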
@click.command()
@click.option("--prompt", type=str, default=None, help="Request to run in standalone mode")
def main(prompt):
    _print_banner()
    llms = create_llms()

    if prompt:
        standalone_mode(user_input=prompt, llms=llms)
    else:
        interactive_mode(llms=llms)
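

# Process a single prompt end-to-end, print the formatted response, and return it.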
def standalone_mode(user_input: str, llms: LLMs) -> str:
    try:
        response, task_summaries = compute(
            user_input=user_input,
            history=ConversationHistory(),
            llms=llms,
        )
        pretty_response = format_response(response)
        print(pretty_response)
        return pretty_response
    except Exception as e:
        logger.exception("")
        print(
            f"Sorry, encountered error: {e}. Please try again. Check logs if problem persists."
        )
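

# Conversation loop: keeps history across turns so follow-up requests can
# reference earlier ones; the user leaves by typing 'exit'.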
def interactive_mode(llms: LLMs):
    print("Please enter your request. End the conversation with 'exit'")
    history = ConversationHistory()
    while True:
        try:
            user_input = click.prompt("User")
            if user_input.lower() == "exit":
                break

            logger.info(f"User input: {user_input}")
            response, task_summaries = compute(
                user_input=user_input,
                history=history,
                llms=llms,
            )
            pretty_response = format_response(response)
            print(f"Assistant:{pretty_response}")

            history.add(role="user", content=user_input)
            history.add(role="assistant", content=response)
        except Exception as e:
            logger.exception("")
            print(
                f"Sorry, encountered error: {e}. Please try again. Check logs if problem persists."
            )
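

# Core HuggingGPT pipeline: plan tasks, select a model per task, run
# inference for each task, then generate a natural-language response.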
def compute(
    user_input: str,
    history: ConversationHistory,
    llms: LLMs,
) -> tuple[str, list[TaskSummary]]:
    tasks = plan_tasks(
        user_input=user_input, history=history, llm=llms.task_planning_llm
    )
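
    # Run tasks in dependency order, so that resources generated by earlier
    # tasks exist before the tasks that consume them run.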
    tasks = sorted(tasks, key=lambda t: max(t.dep))
    logger.info(f"Sorted tasks: {tasks}")
    hf_models = asyncio.run(
        select_hf_models(
            user_input=user_input,
            tasks=tasks,
            model_selection_llm=llms.model_selection_llm,
            output_fixing_llm=llms.output_fixing_llm,
        )
    )
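
    # Execute the tasks sequentially, reusing one HTTP session for all
    # inference requests.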
    task_summaries = []
    with requests.Session() as session:
        for task in tasks:
            logger.info(f"Starting task: {task}")
            if task.depends_on_generated_resources():
                task = task.replace_generated_resources(task_summaries=task_summaries)
            model = hf_models[task.id]
            inference_result = infer(
                task=task,
                model_id=model.id,
                llm=llms.model_inference_llm,
                session=session,
            )
            task_summaries.append(
                TaskSummary(
                    task=task,
                    model=model,
                    inference_result=json.dumps(inference_result),
                )
            )
            logger.info(f"Finished task: {task}")
    logger.info("Finished all tasks")
    logger.debug(f"Task summaries: {task_summaries}")
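
    # Compose the final user-facing answer from the collected task results.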
    response = generate_response(
        user_input=user_input,
        task_summaries=task_summaries,
        llm=llms.response_generation_llm,
    )
    return response, task_summaries


def _print_banner():
    with open("resources/banner.txt", "r") as f:
        banner = f.read()
        logger.info("\n" + banner)


if __name__ == "__main__":
    main()
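
# Example usage (assumes this module is saved as main.py and that a .env
# file provides the required API keys):
#
#   python main.py --prompt "Draw a picture of a cat"   # standalone mode
#   python main.py                                      # interactive mode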