Spaces:
Runtime error
Runtime error
| import socketio | |
| import requests | |
| import json | |
| import time | |
| import random | |
| import base64 | |
| import io | |
| import PIL | |
| from PIL import Image | |
| from io import BytesIO | |
| import gradio as gr | |
| from requests_toolbelt.multipart.encoder import MultipartEncoder | |
| from constant import * | |
def login(email, password):
    """Authenticate against the Rodin backend and return ``(user_uuid, token)``.

    ``email`` is optional: it is only added to the JSON body when truthy.

    Raises:
        json.JSONDecodeError: if the reply body is not JSON (the raw response
            object is logged before re-raising).
        Exception: carrying the server's ``error`` value when it is truthy.
        KeyError: if ``user_uuid`` or ``token`` is missing from the reply.
    """
    credentials = {'password': password}
    if email:
        credentials['email'] = email
    reply = requests.post(f"{BASE_URL}/user/login", json=credentials)
    try:
        body = reply.json()
    except json.JSONDecodeError:
        log("ERROR", f"Error in login: {reply}")
        raise
    server_error = body['error'] if 'error' in body else None
    if server_error:
        raise Exception(server_error)
    log("INFO", "Logged successfully")
    return body['user_uuid'], body['token']
def rodin_history(task_uuid, token):
    """Fetch the generation history for ``task_uuid`` and return the decoded JSON body.

    The UUID is sent as a form field and ``token`` as a Bearer credential.
    A non-JSON reply propagates the decoder error from ``Response.json()``.
    """
    reply = requests.post(
        f"{BASE_URL}/task/rodin_history",
        data={"uuid": task_uuid},
        headers={'Authorization': f'Bearer {token}'},
    )
    return reply.json()
def rodin_preprocess_image(generate_prompt, image, name, token):
    """Upload one image to the preprocessing endpoint and return the raw ``requests.Response``.

    Args:
        generate_prompt: truthy -> the ``generate_prompt`` field is "true",
            otherwise "false".
        image: payload used as the body of the ``images`` file part.
        name: filename reported for the ``images`` part.
        token: bearer token for the Authorization header.
    """
    # NOTE(review): the part is labelled image/jpeg, but load_image() in this
    # file produces PNG bytes — confirm the server ignores the declared type.
    prompt_flag = "true" if generate_prompt else "false"
    encoder = MultipartEncoder(
        fields={
            'generate_prompt': prompt_flag,
            'images': (name, image, 'image/jpeg'),
        }
    )
    return requests.post(
        f"{BASE_URL}/task/rodin_mesh_image_process",
        data=encoder,
        headers={
            'Content-Type': encoder.content_type,
            'Authorization': f'Bearer {token}',
        },
    )
def crop_image(image, type):
    """Cut a row of square tiles out of ``image`` and paste them side by side.

    Sixteen tiles (11520 // 720) of 360x360 px are taken, one every 720 px
    horizontally. They are pasted left to right into a new
    5760x360 RGB image.

    Args:
        image: PIL image to crop from; ``None`` means nothing was generated yet.
        type: ``(upper, left_offset)`` pair. ``type[0]`` is the top edge of
            every tile, and ``type[1]`` is added to each tile's left edge.
            The name shadows the builtin but is kept for caller compatibility.

    Returns:
        PIL.Image.Image: the assembled strip.

    Raises:
        gr.Error: if ``image`` is None.
    """
    # Identity check: ``==`` would dispatch to the object's __eq__ (and is
    # ambiguous for array-like inputs).
    if image is None:
        raise gr.Error("Please generate the object first")
    tile_size = 360    # width and height of each cropped tile
    stride = 720       # horizontal distance between consecutive tiles
    tile_count = 11520 // stride
    upper, offset = type[0], type[1]
    strip = Image.new('RGB', (tile_size * tile_count, tile_size))
    for i in range(tile_count):
        left = i * stride + offset
        tile = image.crop((left, upper, left + tile_size, upper + tile_size))
        strip.paste(tile, (i * tile_size, 0))
    return strip
# Perform Rodin mesh operation
def rodin_mesh(prompt, group_uuid, settings, images, name, token):
    """Submit a mesh-generation job and return the raw ``requests.Response``.

    Args:
        prompt: text prompt for the job.
        group_uuid: sent as the ``group_uuid`` form field (callers in this file
            pass None, which MultipartEncoder encodes as an empty value).
        settings: object serialised with ``json.dumps`` into the ``settings`` field.
        images: base64 strings, optionally ``data:...;base64,``-prefixed. Each
            one becomes its own ``images`` part.
        name: filename used for every image part.
        token: bearer token for the Authorization header.
    """
    binaries = [convert_base64_to_binary(img) for img in images]
    # Use a list of (key, value) tuples rather than a dict: a dict can hold the
    # key 'images' only once, so all but the last image would be dropped.
    # MultipartEncoder emits a separate part for every tuple.
    fields = [
        ('prompt', prompt),
        ('group_uuid', group_uuid),
        ('settings', json.dumps(settings)),
    ]
    fields.extend(('images', (name, image, 'image/jpeg')) for image in binaries)
    m = MultipartEncoder(fields=fields)
    headers = {
        'Content-Type': m.content_type,
        'Authorization': f'Bearer {token}',
    }
    return requests.post(f"{BASE_URL}/task/rodin_mesh", data=m, headers=headers)
# Convert base64 to binary since the result from `rodin_preprocess_image` is encoded with base64
def convert_base64_to_binary(base64_string):
    """Decode a base64 string (optionally a data URI) into an ``io.BytesIO``.

    If the input contains a comma, only the segment after the first comma
    (up to any second comma) is decoded. This strips a
    ``data:<mime>;base64,`` prefix.
    """
    payload = base64_string.split(',')[1] if ',' in base64_string else base64_string
    return io.BytesIO(base64.b64decode(payload))
def rodin_update(prompt, task_uuid, token, settings):
    """Re-run an existing task with new settings and return the raw ``requests.Response``.

    Args:
        prompt: text prompt for the update.
        task_uuid: UUID of the task to update.
        token: bearer token for the Authorization header.
        settings: dict/list of generation settings (serialised to JSON), or an
            already-serialised string/bytes value (sent unchanged).
    """
    # requests form-encodes an iterable value by iterating over it. A raw dict
    # would therefore be sent as its keys only
    # ("settings=view_weights&settings=seed..."), losing every value.
    # Serialise it to JSON, the same way rodin_mesh sends its 'settings' field.
    if isinstance(settings, (dict, list, tuple)):
        settings = json.dumps(settings)
    headers = {'Authorization': f'Bearer {token}'}
    form = {"uuid": task_uuid, "prompt": prompt, "settings": settings}
    return requests.post(f"{BASE_URL}/task/rodin_update", data=form, headers=headers)
def load_image(img_path, max_side=512):
    """Open an image, rescale it so its longer side is ``max_side`` px, and return PNG bytes.

    The scale factor is applied unconditionally, so smaller images are
    enlarged as well. Each output dimension is at least 1 px.

    Args:
        img_path: anything accepted by ``PIL.Image.open`` (path or file object).
        max_side: target length of the longer side; defaults to 512.

    Returns:
        bytes: the resized image encoded as PNG.

    Raises:
        gr.Error: if Pillow cannot identify the image format.
    """
    try:
        image = Image.open(img_path)
    except PIL.UnidentifiedImageError as e:
        raise gr.Error("Unsupported Image Format") from e
    # The context manager releases the file handle Pillow opened, even for
    # multi-frame formats that keep it open after load().
    with image:
        width, height = image.size
        scale = max_side / max(width, height)
        # Guard against 0-px dimensions for extremely thin images; resize()
        # rejects non-positive sizes.
        new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
        resized_image = image.resize(new_size)
    # Encode the resized copy as PNG into an in-memory buffer.
    byte_io = BytesIO()
    resized_image.save(byte_io, format='PNG')
    return byte_io.getvalue()
def log(level, info_text):
    """Print one line of the form ``[ LEVEL ] - YYYYmmdd_HH:MM:SS - text`` using local time."""
    stamp = time.strftime('%Y%m%d_%H:%M:%S', time.localtime())
    print("[ {} ] - {} - {}".format(level, stamp, info_text))
class Generator:
    """Drives the Rodin pipeline: image preprocessing, then mesh generation or update.

    Stores the bearer token along with the credentials used to refresh it
    when the preprocessing endpoint answers with status 401.
    """

    def __init__(self, user_id, password, token) -> None:
        # Logging in at construction is disabled; the caller supplies a token.
        # _, self.token = login(user_id, password)
        self.token = token
        self.user_id = user_id
        self.password = password
        self.task_uuid = None
        self.processed_image = None

    def preprocess(self, prompt, image_path, processed_image, task_uuid=""):
        """Return ``(prompt, processed_image)`` for ``image_path``, calling the backend if needed.

        If ``processed_image`` and ``prompt`` are both set and no ``task_uuid``
        is given, the cached values are returned as-is. Otherwise the image is
        uploaded for preprocessing. The server-generated prompt replaces
        ``prompt`` unless both ``prompt`` and ``task_uuid`` were supplied.
        A 401 reply triggers a re-login and another attempt, with at most
        4 attempts in total.

        Returns:
            tuple: ``(prompt, processed_image)``; ``processed_image`` is a
            ``data:image/png;base64,...`` string.

        Raises:
            gr.Error: on missing input, unsupported format, or backend failure.
        """
        import os  # needed for os.path.basename; not imported at module level

        if image_path is None:
            raise gr.Error("Please upload an image first")
        if processed_image and prompt and not task_uuid:
            log("INFO", "Using cached image and prompt...")
            return prompt, processed_image
        log("INFO", "Preprocessing image...")
        # Only ask the server for a prompt when we are not refining an existing task.
        generate_prompt = not (prompt and task_uuid)
        for _attempt in range(4):
            image_file = load_image(image_path)
            log("INFO", "Image loaded, processing...")
            res = None  # stays None if the request itself raises
            try:
                res = rodin_preprocess_image(
                    generate_prompt=generate_prompt,
                    image=image_file,
                    name=os.path.basename(image_path),
                    token=self.token,
                )
                preprocess_response = res.json()
                log("INFO", f"Image preprocessed: {preprocess_response.get('statusCode')}")
            except Exception as e:
                log("ERROR", f"Error in image preprocessing: {res if res is not None else e}")
                raise gr.Error("Error in image preprocessing, please try again.") from e

            if 'error' in preprocess_response:
                log("ERROR", f"Error in image preprocessing: {preprocess_response}")
                raise gr.Error("Error in image preprocessing, please try again.")
            status = preprocess_response.get("statusCode")
            if status == 400:
                # 'message' may be absent; treat that as a generic failure.
                if "InvalidFile.Content" in (preprocess_response.get("message") or ""):
                    raise gr.Error("Unsupported Image Format")
                log("ERROR", f"Error in image preprocessing: {preprocess_response}")
                raise gr.Error("Busy connection, please try again later.")
            if status == 401:
                log("WARNING", "Token expired. Logging in again...")
                _, self.token = login(self.user_id, self.password)
                continue

            encoded = preprocess_response.get('processed_image', None)
            if not isinstance(encoded, str):
                log("ERROR", f"Error in image preprocessing: {preprocess_response}")
                raise gr.Error("Busy connection, please try again later.")
            if generate_prompt:
                prompt = preprocess_response.get('prompt', None)
            return prompt, "data:image/png;base64," + encoded
        raise gr.Error("Failed to preprocess image")

    def generate_mesh(self, prompt, processed_image, task_uuid=""):
        """Create a new mesh (no ``task_uuid``) or update an existing one, then fetch its preview.

        Both paths block in ``JobStatusChecker.start()`` until the status
        socket disconnects. Afterwards the ``preview_image`` URL of the last
        history entry is downloaded.

        Returns:
            tuple: ``(preview PIL image, task_uuid, crop_image(preview, DEFAULT))``.

        Raises:
            gr.Error: if submission, status tracking or history lookup fails.
            RuntimeError: if the preview download does not return HTTP 200.
        """
        log("INFO", "Generating mesh...")
        # Falsy check so None is treated like the "" default instead of being
        # sent to rodin_update as a task UUID.
        if not task_uuid:
            # One weight per image; for multiple images use e.g. [0.5, 0.5].
            settings = {'view_weights': [1]}
            images = [processed_image]  # every image must be preprocessed first
            res = rodin_mesh(prompt=prompt, group_uuid=None, settings=settings,
                             images=images, name="images.jpeg", token=self.token)
            try:
                mesh_response = res.json()
                # The task_uuid stays the same during the whole generation process.
                task_uuid = mesh_response['uuid']
                progress_checker = JobStatusChecker(BASE_URL, mesh_response['job']['subscription_key'])
                progress_checker.start()
            except Exception as e:
                log("ERROR", f"Error in generating mesh: {e} and response: {res}")
                raise gr.Error("Error in generating mesh, please try again later.") from e
        else:
            settings = {
                "view_weights": [1],
                "seed": random.randint(0, 10000),  # customise the seed here
                "escore": 5.5,  # temperature
            }
            res = rodin_update(prompt, task_uuid, self.token, settings)
            try:
                update_response = res.json()
                checker = JobStatusChecker(BASE_URL, update_response['job']['subscription_key'])
                checker.start()
            except Exception as e:
                log("ERROR", f"Error in updating mesh: {e}")
                raise gr.Error("Error in generating mesh, please try again later.") from e

        history = None  # stays None if rodin_history itself raises
        try:
            history = rodin_history(task_uuid, self.token)
            # The most recently inserted history entry holds the latest preview.
            preview_image = next(reversed(history.items()))[1]["preview_image"]
        except Exception as e:
            log("ERROR", f"Error in generating mesh: {history if history is not None else e}")
            raise gr.Error("Busy connection, please try again later.") from e

        # Read the full body (``content`` also undoes any Content-Encoding,
        # which ``raw`` does not) so the image remains usable after the
        # response is closed; the with-block closes it on every path.
        with requests.get(preview_image) as response:
            if response.status_code != 200:
                log("ERROR", f"Error in generating mesh: {response}")
                raise RuntimeError(f"Preview download failed with HTTP {response.status_code}")
            image = Image.open(BytesIO(response.content))
        return image, task_uuid, crop_image(image, DEFAULT)
class JobStatusChecker:
    """Block until the scheduler socket reports a finished job or the connection ends.

    ``start()`` connects to ``{base_url}/scheduler_socket?subscription=<key>``
    on the ``/api/scheduler_socket`` namespace and waits. The message handler
    disconnects the client once an event payload has
    ``jobStatus == 'Succeeded'``, which lets ``start()`` return.
    """

    def __init__(self, base_url, subscription_key):
        self.base_url = base_url
        self.subscription_key = subscription_key
        self.sio = socketio.Client(logger=True, engineio_logger=True)

        # Catch-all-namespace handlers receive the namespace (and, in newer
        # releases, a disconnect reason) as arguments, so accept and ignore them.
        def connect(*_args):
            print("Connected to the server.")

        def disconnect(*_args):
            print("Disconnected from server.")

        def message(*args, **kwargs):
            # Catch-all event+namespace handlers are called as
            # (event, namespace, *payload), so the first payload is args[2].
            if len(args) > 2:
                data = args[2]
                if isinstance(data, dict) and data.get('jobStatus') == 'Succeeded':
                    print("Job Succeeded! Please find the SDF image in history")
                    self.sio.disconnect()
            else:
                print("Received event with insufficient arguments.")

        # Defining the handlers is not enough; they must be registered.
        # Otherwise nothing calls disconnect() and start() waits until the
        # server closes the connection.
        # NOTE(review): catch-all namespace ('*') handlers require a recent
        # python-socketio 5.x release — confirm the installed version.
        self.sio.on('connect', connect, namespace='*')
        self.sio.on('disconnect', disconnect, namespace='*')
        self.sio.on('*', message, namespace='*')

    def start(self):
        """Connect over websocket and block until the client disconnects."""
        self.sio.connect(
            f"{self.base_url}/scheduler_socket?subscription={self.subscription_key}",
            namespaces=['/api/scheduler_socket'],
            transports='websocket',
        )
        self.sio.wait()