Spaces:
Running
Running
| import gradio as gr | |
| from PIL import Image, ImageDraw | |
| from inference import generate_image | |
# Radio-button label -> zero-based task index, in the order the labels are listed.
TASK_TO_INDEX = {
    label: position
    for position, label in enumerate(("Task 1", "Task 2", "Task 3", "Task 4"))
}

# Task index -> (x, y) position that find_optimal_latent marks on the heatmap.
# NOTE(review): presumably pixel coordinates on imgs/heatmap_<index>.png - confirm.
TASK_OPTIMAL_COORDS = {
    0: (325, 326),
    1: (59, 1126),
    2: (47, 102),
    3: (497, 933),
}
def create_marker_overlay(image_path: str, x: int, y: int) -> Image.Image:
    """Return a copy of the image at ``image_path`` with a red crosshair at (x, y).

    Args:
        image_path: Path of the image file to open.
        x: Horizontal position of the crosshair centre.
        y: Vertical position of the crosshair centre.

    Returns:
        A new image; the file on disk is not modified.
    """
    marker_size = 10  # each crosshair arm extends this far from the centre
    marker_color = "red"

    # Open the file in a context manager so its handle is released as soon as
    # the pixel data has been copied, even if the copy raises.
    with Image.open(image_path) as base_image:
        marked_image = base_image.copy()

    draw = ImageDraw.Draw(marked_image)
    # Horizontal arm, then vertical arm.
    draw.line([x - marker_size, y, x + marker_size, y], fill=marker_color, width=2)
    draw.line([x, y - marker_size, x, y + marker_size], fill=marker_color, width=2)
    return marked_image
def update_reference_image(choice: int) -> tuple[str, int, str]:
    """Build the image paths for the task with index ``choice``.

    Returns:
        A ``(pattern image path, choice, heatmap image path)`` tuple; ``choice``
        is passed through unchanged.
    """
    pattern_file = "imgs/pattern_{}.png".format(choice)
    heatmap_file = "imgs/heatmap_{}.png".format(choice)
    return (pattern_file, choice, heatmap_file)
def update_marker(image_idx: int, evt: gr.SelectData) -> tuple[Image.Image, tuple[int, int]]:
    """Redraw the heatmap of task ``image_idx`` with a marker at the selected point.

    Args:
        image_idx: Index used to pick ``imgs/heatmap_<image_idx>.png``.
        evt: Selection event; the first two entries of ``evt.index`` are used
            as the x and y position.

    Returns:
        The marked heatmap image and the ``(x, y)`` position that was marked.
    """
    selected_point = (evt.index[0], evt.index[1])
    marked_heatmap = create_marker_overlay(
        f"imgs/heatmap_{image_idx}.png", selected_point[0], selected_point[1]
    )
    return marked_heatmap, selected_point
def generate_output_image(image_idx: int, coords: tuple[int, int]) -> Image.Image:
    """Generate the output image for task ``image_idx`` at the given position.

    Both coordinates are divided by 1155 before being handed to
    ``generate_image``.
    NOTE(review): 1155 is presumably the heatmap side length in pixels, which
    would map the position into [0, 1] - confirm against the image assets.

    Args:
        image_idx: Task index forwarded to ``generate_image``.
        coords: ``(x, y)`` position to normalise.

    Returns:
        Whatever ``generate_image`` returns for the normalised position.
    """
    x, y = coords
    return generate_image(image_idx, x / 1155, y / 1155)
def find_optimal_latent(image_idx: int) -> tuple[Image.Image, tuple[int, int], Image.Image]:
    """Mark and render the stored optimal position for task ``image_idx``.

    Looks the position up in ``TASK_OPTIMAL_COORDS``, draws it on the task's
    heatmap, and generates the output image for it.

    Returns:
        ``(marked heatmap, (x, y), generated output image)``.

    Raises:
        KeyError: If ``image_idx`` has no entry in ``TASK_OPTIMAL_COORDS``.
    """
    best_x, best_y = TASK_OPTIMAL_COORDS[image_idx]
    best_point = (best_x, best_y)
    heatmap_with_marker = create_marker_overlay(
        f"imgs/heatmap_{image_idx}.png", best_x, best_y
    )
    generated = generate_output_image(image_idx, best_point)
    return heatmap_with_marker, best_point, generated
# Stylesheet passed to gr.Blocks(css=...). Every class below is attached to a
# layout element through `elem_classes`.
CUSTOM_CSS = """
.container {
    max-width: 1200px !important;
    width: 100% !important;
    margin-left: auto !important;
    margin-right: auto !important;
    padding: 0 1rem !important;
}
.diagram-container {
    width: 100% !important;
    max-width: 1000px !important;
    margin: 2rem auto !important;
}
.diagram-container img {
    width: 100% !important;
    height: auto !important;
    display: block !important;
    margin: 0 auto !important;
    cursor: default !important;
}
.radio-container {
    width: 100% !important;
    max-width: 450px !important;
    margin-bottom: 1rem !important;
}
.image-preview-container {
    width: 100% !important;
    max-width: 450px !important;
}
.image-preview-container img {
    width: 100% !important;
    height: 100% !important;
    object-fit: contain !important;
    cursor: default !important;
}
.coordinate-container {
    width: 100% !important;
    aspect-ratio: 1 !important;
    position: relative !important;
    max-width: 550px !important;
}
.coordinate-container img {
    width: 100% !important;
    height: 100% !important;
    object-fit: contain !important;
}
.button-container {
    width: 100% !important;
    max-width: 450px !important;
    display: flex !important;
    justify-content: center !important;
    margin-bottom: 1rem !important;
}
/* Disabled rule, kept for reference. CSS has no "#" line comments, so the
   rule is wrapped in a block comment instead.
.documentation {
    margin-top: 2rem !important;
    padding: 1rem !important;
    background-color: #f8f9fa !important;
    border-radius: 8px !important;
}
*/
.optimal-button {
    width: 200px !important;
}
"""

# Keyword arguments shared by every display-only gr.Image in the layout.
STATIC_IMAGE_KWARGS = dict(
    show_label=False,
    interactive=False,
    show_download_button=False,
    show_fullscreen_button=False,
    show_share_button=False,
)

# NOTE(review): the nesting of the layout containers below was reconstructed
# from a listing that had lost its indentation - confirm it against the
# deployed app before merging.
with gr.Blocks(css=CUSTOM_CSS) as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown(
            """
            # Interactive Visualization of a Latent Program Network (LPN)
            ## Introduction
            The LPN is an architecture for inductive program synthesis that builds in test-time adaption
            by learning a latent space that can be used for search.
            This interactive demo showcases a latent traversal of the LPN in the latent program space.
            More specifically, the decoder of the LPN is conditioned on a latent vector representing
            an abstract program, which is then used to generate an output.
            """
        )
        with gr.Column(elem_classes="diagram-container"):
            gr.Image(
                value="imgs/lpn_diagram.png",
                container=False,
                **STATIC_IMAGE_KWARGS,
            )
        gr.Markdown(
            """
            ### How to Use
            1. Choose a pattern task using the radio buttons
            2. View the input-output pairs for your selected task
            3. The goal is to find the latent that will generate the right third image for the given input
            4. Click anywhere on the latent space to specify coordinates for the latent
            5. See the generated image based on your selected latent
            Use the "Find Optimal Latent" button to find the latent that maximizes likelihood of generating the other input-output pairs.
            """
        )
        with gr.Row():
            # Left column for controls
            with gr.Column(scale=1):
                # Index of the selected task; written by task_select.change and
                # read by the heatmap and button handlers.
                selected_idx = gr.State(value=0)
                # Last marked (x, y) position; written by update_marker and
                # find_optimal_latent.
                coords = gr.State()
                with gr.Column(elem_classes="radio-container"):
                    task_select = gr.Radio(
                        choices=["Task 1", "Task 2", "Task 3", "Task 4"],
                        value="Task 1",
                        label="Select Task",
                        interactive=True,
                    )
                gr.Markdown("### Latent Space Search")
                gr.Markdown(
                    "Click anywhere in the 2D latent space below to condition the decoder on a specific latent vector. "
                    "The heatmap shows the decoder log-likelihood of generating the first two input-output pairs conditioning on any point in the latent space. "
                    "The goal is to find the latent that generates the third image for the given input."
                )
                with gr.Column(elem_classes="coordinate-container"):
                    # Heatmap the user clicks on; its select event drives generation.
                    coord_selector = gr.Image(
                        value="imgs/heatmap_0.png",
                        sources=[],
                        container=True,
                        **STATIC_IMAGE_KWARGS,
                    )
                with gr.Column(elem_classes="button-container"):
                    optimal_button = gr.Button("Find Optimal Latent", elem_classes="optimal-button")
            # Right column for images
            with gr.Column(scale=1):
                gr.Markdown("### Input-Output Pairs")
                with gr.Column(elem_classes="image-preview-container"):
                    reference_image = gr.Image(
                        value="imgs/pattern_0.png",
                        **STATIC_IMAGE_KWARGS,
                    )
                gr.Markdown("### Generated Output")
                with gr.Column(elem_classes="image-preview-container"):
                    output_image = gr.Image(**STATIC_IMAGE_KWARGS)
    with gr.Column(elem_classes="container"):
        gr.Markdown(
            """
            ### Technical Details
            For more information, please refer to our [paper](https://arxiv.org/pdf/2411.08706) or GitHub [repository](https://github.com/clement-bonnet/lpn).
            """
        )

    # Event handlers (registered inside the Blocks context).

    # Switching task swaps the reference image and heatmap and stores the new index.
    task_select.change(
        fn=lambda task_label: update_reference_image(TASK_TO_INDEX[task_label]),
        inputs=[task_select],
        outputs=[reference_image, selected_idx, coord_selector],
    )

    # Clicking the heatmap first redraws the marker and stores the position,
    # then generates the output image from the stored position.
    coord_selector.select(
        fn=update_marker,
        inputs=[selected_idx],
        outputs=[coord_selector, coords],
        trigger_mode="multiple",
    ).then(
        fn=generate_output_image,
        inputs=[selected_idx, coords],
        outputs=output_image,
    )

    # The button marks the stored optimum and generates its output in one step.
    optimal_button.click(
        fn=find_optimal_latent,
        inputs=[selected_idx],
        outputs=[coord_selector, coords, output_image],
    )

demo.launch()