Spaces:
Running
on
Zero
Running
on
Zero
| import spaces | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoModelForCausalLM, AutoProcessor | |
# Hugging Face Hub repository id of the fine-tuned OCR checkpoint.
MODEL_PATH = "jzhang533/PaddleOCR-VL-For-Manga"

# Announce which checkpoint is about to be loaded.
print(f"Loading model from {MODEL_PATH}...")

# Module-level handles; both stay None until load_model() assigns them.
model = processor = None
def load_model():
    """Load the OCR model and its processor into the module-level globals.

    Assigns the global ``model`` (put in eval mode) and ``processor``, both
    loaded from ``MODEL_PATH``, and fills in a missing generation
    ``pad_token_id`` with the tokenizer's EOS token id.
    """
    global model, processor

    print("Loading model...")

    model_options = {
        "trust_remote_code": True,
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
    }
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, **model_options).eval()

    processor = AutoProcessor.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        use_fast=True,
    )

    # Set pad_token_id to avoid a warning during generation.
    generation_config = model.generation_config
    if generation_config.pad_token_id is None:
        generation_config.pad_token_id = processor.tokenizer.eos_token_id

    print("Model loaded successfully!")
# Load the model and processor once, at import time, so they are ready
# before any OCR request is handled.
load_model()
@spaces.GPU
def perform_ocr(image):
    """Perform OCR on the provided manga image.

    The function is wrapped with ``spaces.GPU`` because this app runs on a
    ZeroGPU Space, where a GPU is only attached while a decorated function
    is executing. Outside ZeroGPU the decorator has no effect.

    Args:
        image: PIL Image or numpy array, or ``None`` when nothing has been
            uploaded.

    Returns:
        str: Recognized text from the image, or a prompt asking the user to
        upload an image when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image first."

    # Ensure model is on GPU: it may still sit on the CPU if no GPU was
    # visible when it was loaded.
    if model.device.type == "cpu" and torch.cuda.is_available():
        print("Moving model to GPU...")
        model.to("cuda")

    # Convert to PIL Image if needed (e.g. a numpy array was passed in).
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Ensure RGB format.
    image = image.convert("RGB")

    # Prepare the prompt: one user turn holding the image and the OCR instruction.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "OCR:"},
            ],
        }
    ]

    # Render the chat template to text, then build the model inputs from the
    # text and the image together.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(text=[text], images=[image], return_tensors="pt")

    # Move tensor inputs to the model's device; leave non-tensor values as-is.
    inputs = {
        k: (v.to(model.device) if isinstance(v, torch.Tensor) else v)
        for k, v in inputs.items()
    }

    # Generate text greedily (no sampling).
    with torch.inference_mode():
        generated = model.generate(
            **inputs,
            max_new_tokens=2048,
            do_sample=False,
            use_cache=True,
        )

    # Drop the first `input_length` positions (the prompt tokens) so only the
    # newly generated tokens are decoded.
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = generated[:, input_length:]
    answer = processor.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return answer
# Build the Gradio UI: a header, an input column (image, button, examples)
# and an output column (recognized text).
with gr.Blocks(title="PaddleOCR-VL For Manga") as demo:
    gr.Markdown(
        """
        # PaddleOCR-VL-For-Manga Demo
        **Model**: [jzhang533/PaddleOCR-VL-For-Manga](https://huggingface.co/jzhang533/PaddleOCR-VL-For-Manga)
        ### Features:
        - Fine-tuned from [PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL)
        - Trained on Manga109-s dataset + 1.5M synthetic samples
        - Achieves 70% full-sentence accuracy on manga crops
        """
    )

    with gr.Row():
        # Input side: image upload, trigger button and bundled examples.
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Upload Manga Image",
                height=400,
            )
            # NOTE(review): the trailing "π" looks like a mis-encoded emoji —
            # confirm the intended button label.
            submit_btn = gr.Button("Recognize Text π", variant="primary")
            gr.Examples(
                examples=[
                    ["examples/01.png"],
                    ["examples/02.png"],
                    ["examples/03.png"],
                    ["examples/04.png"],
                    ["examples/05.png"],
                ],
                inputs=input_image,
                label="Example Images",
            )

        # Output side: the text recognized by the model.
        with gr.Column():
            output_text = gr.Textbox(
                label="Recognized Text",
                placeholder="The recognized Japanese text will appear here...",
                lines=15,
                max_lines=20,
            )

    # Both the button click and any change of the image run the same OCR call.
    _ocr_event = dict(fn=perform_ocr, inputs=input_image, outputs=output_text)
    submit_btn.click(**_ocr_event)
    input_image.change(**_ocr_event)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()