Spaces:
Runtime error
Runtime error
| import seaborn as sns | |
| from PIL import Image, ImageDraw, ImageFont | |
| import matplotlib.font_manager | |
| import spacy | |
| import re | |
# Load the spaCy pipeline once at import time; text_to_dict() runs it on each
# response to obtain noun chunks.
# NOTE(review): "en_core_web_sm-3.6.0" is not the usual package name
# ("en_core_web_sm"); presumably this is a local model directory shipped next
# to this file -- confirm it is resolvable at runtime.
nlp = spacy.load("en_core_web_sm-3.6.0")
def draw_boxes(image, boxes, texts, output_fn='output.png'):
    """Outline groups of boxes on ``image``, label them, and save the result.

    Args:
        image: PIL image to annotate. It is composited with a transparent
            overlay; the input image itself is not drawn on.
        boxes: One entry per label. Each entry is an iterable of boxes whose
            first four items are ``x0, y0, x1, y1`` given as fractions of the
            image width / height.
        texts: One label per entry of ``boxes``. A falsy label means only the
            outline is drawn. A label may contain several newline-separated
            lines.
        output_fn: Path the annotated RGB image is saved to.
    """
    outline_px = 5
    img_w, img_h = image.size

    # One hue per box group, converted from 0-1 floats to 0-255 integers.
    palette = [
        (int(red * 255), int(green * 255), int(blue * 255))
        for red, green, blue in sns.color_palette("husl", len(boxes))
    ]

    def to_pixels(rel_box):
        # Scale fractional coordinates to absolute pixel positions.
        return (
            int(rel_box[0] * img_w),
            int(rel_box[1] * img_h),
            int(rel_box[2] * img_w),
            int(rel_box[3] * img_h),
        )

    pixel_groups = [[to_pixels(rel_box) for rel_box in group] for group in boxes]

    overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
    canvas = ImageDraw.Draw(overlay)

    # Use the alphabetically first TrueType font found on the system.
    system_fonts = matplotlib.font_manager.findSystemFonts(fontpaths=None, fontext='ttf')
    label_font = ImageFont.truetype(sorted(system_fonts)[0], size=26)

    def paint_label(rect, label):
        # Draw each line of the label on a translucent grey background.
        lines = label.split('\n')
        left, top, right, bottom = rect
        first_line_height = label_font.getbbox(lines[0])[-1]
        if right - left < 100 or bottom - top < 100:
            # Small box: put the label just below the box instead of inside it.
            first_y = bottom
        else:
            # Otherwise stack the lines inside the box, above its bottom edge.
            first_y = bottom - first_line_height * len(lines) - outline_px
        text_x = left + outline_px
        for index, line in enumerate(lines):
            line_w, line_h = label_font.getbbox(line)[-2:]
            text_y = first_y + line_h * index
            canvas.rectangle(
                [text_x, text_y, text_x + line_w, text_y + line_h],
                fill=(128, 128, 128, 160),
            )
            canvas.text((text_x, text_y), line, font=label_font, fill=(255, 255, 255))

    for group, label, outline in zip(pixel_groups, texts, palette):
        for rect in group:
            canvas.rectangle(rect, outline=outline, width=outline_px)
            if label:
                paint_label(rect, label)

    composed = Image.alpha_composite(image.convert('RGBA'), overlay)
    composed.convert('RGB').save(output_fn)
def boxstr_to_boxes(box_str):
    """Parse a ``;``-separated list of boxes into lists of fractional coordinates.

    Each segment is a ``,``-separated list of non-negative integers; every
    value is divided by 1000. A segment is skipped entirely when any of its
    components is not a plain decimal integer (empty component, letters,
    whitespace, signs, ...).

    Args:
        box_str: Text such as ``"100,200,300,400;500,600,700,800"``.

    Returns:
        A list with one list of floats per valid segment, in input order.
    """
    boxes = []
    for segment in box_str.split(';'):
        parts = segment.split(',')
        # Validate every component on its own: checking the comma-stripped
        # string lets inputs like "1,,2" or "1,2," through, and int('') then
        # raises. isdecimal() accepts exactly the digit characters int() can
        # parse, unlike isdigit(), which also accepts e.g. superscript digits.
        if all(part.isdecimal() for part in parts):
            boxes.append([int(part) / 1000 for part in parts])
    return boxes
def text_to_dict(text):
    """Map the phrase preceding each ``[[...]]`` marker in ``text`` to its boxes.

    For every ``[[...]]`` marker, the phrase is the text between the start of
    the last noun chunk that ends at or before the marker and the marker
    itself. If that phrase ends with ``?``, all text before the marker is used
    instead. Phrases are stripped and lower-cased.

    Args:
        text: Response text containing ``[[...]]`` box markers.

    Returns:
        A dict from phrase to the value of ``boxstr_to_boxes`` for the marker's
        content. When the same phrase occurs more than once, the boxes of the
        last occurrence are kept.
    """
    doc = nlp(text)
    # Collect the noun-chunk spans once instead of re-iterating
    # doc.noun_chunks for every box marker.
    chunk_spans = [(chunk.start_char, chunk.end_char) for chunk in doc.noun_chunks]

    phrase_to_boxes = {}
    for match in re.finditer(r'\[\[([^\]]+)\]\]', text):
        box_position = match.start()
        # Start of the latest noun chunk that ends before the marker; fall
        # back to the beginning of the text when there is none.
        nearest_np_start = max(
            (start for start, end in chunk_spans if end <= box_position),
            default=0,
        )
        noun_phrase = text[nearest_np_start:box_position].strip()
        if noun_phrase.endswith('?'):
            noun_phrase = text[:box_position].strip()
        phrase_to_boxes[noun_phrase.lower()] = boxstr_to_boxes(match.group(1))
    return phrase_to_boxes
def parse_response(img, response, output_fn='output.png'):
    """Resize ``img``, draw the boxes described in ``response``, and save it.

    Args:
        img: PIL image the response refers to.
        response: Text containing phrases followed by ``[[...]]`` box markers,
            as understood by ``text_to_dict``.
        output_fn: Path the annotated image is written to by ``draw_boxes``.
    """
    img = img.convert('RGB')
    width, height = img.size
    # Scale (up or down) so the image fits within 1920x1080, keeping the
    # aspect ratio.
    ratio = min(1920 / width, 1080 / height)
    new_size = (int(width * ratio), int(height * ratio))
    new_img = img.resize(new_size, Image.LANCZOS)

    # The boxes come solely from text_to_dict. The previous separate regex
    # pass over `response` was dead code: its result was always overwritten,
    # and it could raise on malformed markers.
    dic = text_to_dict(response)
    if dic:
        texts, boxes = zip(*dic.items())
    else:
        texts, boxes = [], []
    draw_boxes(new_img, boxes, texts, output_fn=output_fn)