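"""Gradio Space for collecting AI image-editing results.

For each post in the taesiri/IERv2-Subset dataset, annotators can upload up to
ten edited-image responses (image + model name + notes). Submissions are saved
under ./data/<post_id>/ and periodically committed to an intermediate dataset
repo on the Hugging Face Hub.
"""
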
import gradio as gr
import json
import os
import shutil
import threading
import time
from pathlib import Path

import git
from datasets import load_dataset
from huggingface_hub import CommitScheduler, HfApi

from utils import process_and_push_dataset

api = HfApi(token=os.environ["HF_TOKEN"])

VALID_DATASET = load_dataset("taesiri/IERv2-Subset", split="train")
VALID_DATASET_POST_IDS = (
    load_dataset("taesiri/IERv2-Subset", split="train", columns=["post_id"])
    .to_pandas()["post_id"]
    .tolist()
)
POST_ID_TO_ID_MAP = {post_id: idx for idx, post_id in enumerate(VALID_DATASET_POST_IDS)}

DATASET_REPO = "taesiri/AIImageEditingResults_Intemediate"
FINAL_DATASET_REPO = "taesiri/AIImageEditingResults"
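
# Raw per-post folders are committed to the intermediate repo by the scheduler
# below; process_and_push_dataset (currently disabled) would consolidate them
# into the final repo.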

# Download existing data from the hub
def sync_with_hub():
    """
    Synchronize local data with the hub by cloning the dataset repo.
    """
    print("Starting sync with hub...")
    data_dir = Path("./data")
    if data_dir.exists():
        # Back up existing data before merging
        backup_dir = Path("./data_backup")
        if backup_dir.exists():
            shutil.rmtree(backup_dir)
        shutil.copytree(data_dir, backup_dir)

    # Clone/pull the latest data from the hub, authenticating with the token
    # embedded in the remote URL (https://<user>:<token>@huggingface.co/...)
    token = os.environ["HF_TOKEN"]
    username = DATASET_REPO.split("/")[0]
    repo_url = f"https://{username}:{token}@huggingface.co/datasets/{DATASET_REPO}"
    hub_data_dir = Path("hub_data")
    if hub_data_dir.exists():
        # The repo already exists locally, so pull the latest changes
        print("Pulling latest changes...")
        repo = git.Repo(hub_data_dir)
        origin = repo.remotes.origin
        # Refresh the remote URL so it carries the current token
        if "https://" in origin.url:
            origin.set_url(repo_url)
        origin.pull()
    else:
        # First run: clone the repo with the token
        print("Cloning repository...")
        git.Repo.clone_from(repo_url, hub_data_dir)

    # Merge hub data into local data; existing local folders take precedence
    hub_data_source = hub_data_dir / "data"
    if hub_data_source.exists():
        # Create the data dir if it doesn't exist
        data_dir.mkdir(exist_ok=True)
        # Copy per-post folders from the hub
        for item in hub_data_source.glob("*"):
            if item.is_dir():
                dest = data_dir / item.name
                if not dest.exists():  # Only copy if it doesn't exist locally
                    shutil.copytree(item, dest)

    # Clean up the cloned repo
    if hub_data_dir.exists():
        shutil.rmtree(hub_data_dir)
    print("Finished syncing with hub!")
scheduler = CommitScheduler(
    repo_id=DATASET_REPO,
    repo_type="dataset",
    folder_path="./data",
    path_in_repo="data",
    every=1,  # minutes
)

def load_question_data(question_id):
    """
    Load a specific question's data.
    Returns a list of values for all 11 form fields.
    """
    if not question_id:
        return [None] * 11  # One placeholder per form field

    # Extract the ID part before the colon from the dropdown selection
    question_id = (
        question_id.split(":")[0].strip() if ":" in question_id else question_id
    )
    json_path = os.path.join("./data", question_id, "question.json")
    if not os.path.exists(json_path):
        print(f"Question file not found: {json_path}")
        return [None] * 11

    try:
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.loads(f.read().strip())

        # Resolve stored image paths relative to the question's folder
        def load_image(image_path):
            if not image_path:
                return None
            full_path = os.path.join(
                "./data", question_id, os.path.basename(image_path)
            )
            return full_path if os.path.exists(full_path) else None

        question_images = data.get("question_images", [])
        rationale_images = data.get("rationale_images", [])

        return [
            (
                ",".join(data["question_categories"])
                if isinstance(data["question_categories"], list)
                else data["question_categories"]
            ),
            data["question"],
            data["final_answer"],
            data.get("rationale_text", ""),
            load_image(question_images[0] if question_images else None),
            load_image(question_images[1] if len(question_images) > 1 else None),
            load_image(question_images[2] if len(question_images) > 2 else None),
            load_image(question_images[3] if len(question_images) > 3 else None),
            load_image(rationale_images[0] if rationale_images else None),
            load_image(rationale_images[1] if len(rationale_images) > 1 else None),
            question_id,
        ]
    except Exception as e:
        print(f"Error loading question {question_id}: {str(e)}")
        return [None] * 11
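
# NOTE: load_question_data above is not referenced by the UI in this file.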

def load_post_image(post_id):
    if not post_id:
        # 33 outputs: source image + instruction + simplified instruction
        # + 10 (image, model name, notes) triplets
        return [None] * 33

    idx = POST_ID_TO_ID_MAP[post_id]
    source_image = VALID_DATASET[idx]["image"]
    instruction = VALID_DATASET[idx]["instruction"]
    simplified_instruction = VALID_DATASET[idx]["simplified_instruction"]

    # Load existing responses, if any
    post_folder = os.path.join("./data", str(post_id))
    metadata_path = os.path.join(post_folder, "metadata.json")
    if os.path.exists(metadata_path):
        with open(metadata_path, "r") as f:
            metadata = json.load(f)

        # Initialize all ten response slots with empty values
        responses = [(None, "", "")] * 10

        # Fill in existing responses
        for response in metadata["responses"]:
            idx = response["response_id"]
            if idx < 10:  # Ensure we don't exceed our UI limit
                image_path = os.path.join(post_folder, response["image_path"])
                responses[idx] = (
                    image_path,
                    response["answer_text"],
                    response.get("notes", ""),
                )

        # Flatten responses for output
        flat_responses = [item for triplet in responses for item in triplet]
        return [source_image, instruction, simplified_instruction] + flat_responses

    # No existing responses: return the source image, instructions, and empty slots
    return [source_image, instruction, simplified_instruction] + [None] * 30

def generate_json_files(source_image, responses, post_id):
    """
    Save the source image and multiple responses to the data directory.

    Args:
        source_image: Path to the source image (or a PIL image)
        responses: List of (image, answer, notes) tuples
        post_id: The post ID from the dataset
    """
    # Create the parent data folder if it doesn't exist
    parent_data_folder = "./data"
    os.makedirs(parent_data_folder, exist_ok=True)

    # Create/clear the post_id folder
    post_folder = os.path.join(parent_data_folder, str(post_id))
    if os.path.exists(post_folder):
        shutil.rmtree(post_folder)
    os.makedirs(post_folder)

    # Save the source image (the Image components use type="filepath", so this
    # is normally a path string; otherwise assume a PIL image and use save())
    source_image_path = os.path.join(post_folder, "source_image.png")
    if isinstance(source_image, str):
        shutil.copy2(source_image, source_image_path)
    else:
        source_image.save(source_image_path)

    # Create the responses data
    responses_data = []
    for idx, (response_image, answer_text, notes) in enumerate(responses):
        if response_image and answer_text:  # Only process if both image and text exist
            response_folder = os.path.join(post_folder, f"response_{idx}")
            os.makedirs(response_folder)

            # Save the response image
            response_image_path = os.path.join(response_folder, "response_image.png")
            if isinstance(response_image, str):
                shutil.copy2(response_image, response_image_path)
            else:
                response_image.save(response_image_path)

            # Add to responses data
            responses_data.append(
                {
                    "response_id": idx,
                    "answer_text": answer_text,
                    "notes": notes,
                    "image_path": f"response_{idx}/response_image.png",
                }
            )

    # Create the metadata JSON
    metadata = {
        "post_id": post_id,
        "source_image": "source_image.png",
        "responses": responses_data,
    }

    # Save metadata
    with open(os.path.join(post_folder, "metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, ensure_ascii=False, indent=2)

    return post_folder
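
# Resulting on-disk layout for each submission:
#
#   data/<post_id>/
#       metadata.json
#       source_image.png
#       response_0/response_image.png
#       response_1/response_image.png
#       ...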

def get_statistics():
    """
    Scan the data folder and return statistics about the responses.
    """
    data_dir = Path("./data")
    if not data_dir.exists():
        return "No data directory found"

    # Folder names are strings, so compare against stringified post IDs
    valid_post_ids = set(map(str, VALID_DATASET_POST_IDS))
    total_expected_posts = len(VALID_DATASET_POST_IDS)
    processed_post_ids = set()
    posts_with_responses = 0
    total_responses = 0
    responses_per_post = []  # Number of responses for each processed post

    for metadata_file in data_dir.glob("*/metadata.json"):
        post_id = metadata_file.parent.name
        if post_id in valid_post_ids:  # Only count valid posts
            processed_post_ids.add(post_id)
            try:
                with open(metadata_file, "r") as f:
                    metadata = json.load(f)
                num_responses = len(metadata.get("responses", []))
                responses_per_post.append(num_responses)
                if num_responses > 0:
                    posts_with_responses += 1
                total_responses += num_responses
            except Exception:
                continue

    missing_posts = valid_post_ids - processed_post_ids
    total_processed = len(processed_post_ids)

    # Calculate additional statistics
    if responses_per_post:
        responses_per_post.sort()
        median_responses = responses_per_post[len(responses_per_post) // 2]
        max_responses = max(responses_per_post)
        avg_responses = (
            total_responses / posts_with_responses if posts_with_responses > 0 else 0
        )
    else:
        median_responses = max_responses = avg_responses = 0

    stats = f"""
📊 Collection Statistics:

Dataset Coverage:
- Total Expected Posts: {total_expected_posts}
- Posts Processed: {total_processed}
- Missing Posts: {len(missing_posts)} ({', '.join(list(missing_posts)[:5])}{'...' if len(missing_posts) > 5 else ''})
- Coverage Rate: {(total_processed / total_expected_posts * 100):.2f}%

Response Statistics:
- Posts with Responses: {posts_with_responses}
- Posts without Responses: {total_processed - posts_with_responses}
- Total Individual Responses: {total_responses}

Response Distribution:
- Median Responses per Post: {median_responses}
- Average Responses per Post: {avg_responses:.2f}
- Maximum Responses for a Post: {max_responses}
"""
    return stats

# Build the Gradio app
with gr.Blocks() as demo:
    gr.Markdown("# Image Response Collector")

    # Source image selection at the top
    with gr.Row():
        with gr.Column():
            post_id_dropdown = gr.Dropdown(
                label="Select Post ID to Load Image",
                choices=VALID_DATASET_POST_IDS,
                type="value",
                allow_custom_value=False,
            )
            instruction_text = gr.Textbox(label="Instruction", interactive=False)
            simplified_instruction_text = gr.Textbox(
                label="Simplified Instruction", interactive=False
            )
            source_image = gr.Image(label="Source Image", type="filepath", height=300)

    # Responses in tabs
    with gr.Tabs() as response_tabs:
        responses = []
        for i in range(10):
            with gr.Tab(f"Response {i+1}"):
                img = gr.Image(
                    label=f"Response Image {i+1}", type="filepath", height=300
                )
                txt = gr.Textbox(label=f"Model Name {i+1}", lines=2)
                notes = gr.Textbox(label=f"Miscellaneous Notes {i+1}", lines=3)
                responses.append((img, txt, notes))

    with gr.Row():
        submit_btn = gr.Button("Submit All Responses")
        clear_btn = gr.Button("Clear Form")

    # Statistics accordion
    with gr.Accordion("Collection Statistics", open=False):
        stats_text = gr.Markdown("Loading statistics...")
        refresh_stats_btn = gr.Button("Refresh Statistics")

    refresh_stats_btn.click(fn=get_statistics, outputs=[stats_text])

    # Populate the statistics on page load (registered inside the Blocks context)
    demo.load(
        fn=get_statistics,
        outputs=[stats_text],
    )

    def submit_responses(
        source_img, post_id, instruction, simplified_instruction, *response_data
    ):
        if not source_img:
            gr.Warning("Please select a source image first!")
            return
        if not post_id:
            gr.Warning("Please select a post ID first!")
            return

        # Convert flat response_data into triplets of (image, text, notes)
        response_triplets = list(
            zip(response_data[::3], response_data[1::3], response_data[2::3])
        )

        # Check for responses with images but no model names
        # (guard against None: cleared Textboxes may return None instead of "")
        incomplete_responses = [
            i + 1
            for i, (img, txt, _) in enumerate(response_triplets)
            if img is not None and not (txt or "").strip()
        ]
        if incomplete_responses:
            gr.Warning(
                f"Please provide model names for responses: {', '.join(map(str, incomplete_responses))}!"
            )
            return

        # Keep only complete responses (both an image and a model name)
        valid_responses = [
            (img, txt, notes)
            for img, txt, notes in response_triplets
            if img is not None and (txt or "").strip()
        ]
        if not valid_responses:
            gr.Warning("Please provide at least one response (image + model name)!")
            return

        # Write the images and metadata JSON for this post
        generate_json_files(source_img, valid_responses, post_id)
        gr.Info("Responses saved successfully! 🎉")

    def clear_form():
        # 33 outputs: source image + 2 instruction fields + 10 triplets
        return [None] * 33

    # Connect components
    post_id_dropdown.change(
        fn=load_post_image,
        inputs=[post_id_dropdown],
        outputs=[source_image, instruction_text, simplified_instruction_text]
        + [comp for triplet in responses for comp in triplet],
    )

    submit_inputs = [
        source_image,
        post_id_dropdown,
        instruction_text,
        simplified_instruction_text,
    ] + [comp for triplet in responses for comp in triplet]
    submit_btn.click(fn=submit_responses, inputs=submit_inputs)

    clear_outputs = [source_image, instruction_text, simplified_instruction_text] + [
        comp for triplet in responses for comp in triplet
    ]
    clear_btn.click(fn=clear_form, outputs=clear_outputs)

def process_thread():
    # Periodically process ./data and push it to the final dataset repo.
    # The push itself is currently disabled.
    while True:
        try:
            pass
            # process_and_push_dataset(
            #     "./data",
            #     FINAL_DATASET_REPO,
            #     token=os.environ["HF_TOKEN"],
            #     private=True,
            # )
        except Exception as e:
            print(f"Error in process thread: {e}")
        time.sleep(120)  # Sleep for 2 minutes

if __name__ == "__main__":
    print("Initializing app...")
    sync_with_hub()  # Sync before launching the app
    print("Starting Gradio interface...")

    # Start the background processing thread when the app starts
    processing_thread = threading.Thread(target=process_thread, daemon=True)
    processing_thread.start()

    demo.launch()