import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# Model configuration
MODEL_PATH = "jzhang533/PaddleOCR-VL-For-Manga"
# Load model and processor
print(f"Loading model from {MODEL_PATH}...")
model = None
processor = None
def load_model():
    global model, processor
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    ).eval()
    processor = AutoProcessor.from_pretrained(
        MODEL_PATH, trust_remote_code=True, use_fast=True
    )
    # Set pad_token_id to avoid a warning during generation
    if model.generation_config.pad_token_id is None:
        model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
    print("Model loaded successfully!")
# Load model on startup
load_model()
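
# Note: with the @spaces.GPU decorator below, a ZeroGPU Space allocates a GPU
# only for the duration of each call, so the model may initially live on CPU.
# (This is an assumption about the ZeroGPU runtime; the CPU-to-GPU check inside
# perform_ocr handles either case.)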
@spaces.GPU(duration=120)
def perform_ocr(image):
    """
    Perform OCR on the provided manga image.

    Args:
        image: PIL Image or numpy array

    Returns:
        str: Recognized text from the image
    """
    if image is None:
        return "Please upload an image first."

    # Ensure model is on GPU
    if model.device.type == "cpu" and torch.cuda.is_available():
        print("Moving model to GPU...")
        model.to("cuda")

    # Convert to PIL Image if needed
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Ensure RGB format
    image = image.convert("RGB")

    # Prepare the prompt
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "OCR:"},
            ],
        }
    ]

    # Process inputs
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(text=[text], images=[image], return_tensors="pt")
    inputs = {
        k: (v.to(model.device) if isinstance(v, torch.Tensor) else v)
        for k, v in inputs.items()
    }

    # Generate text
    with torch.inference_mode():
        generated = model.generate(
            **inputs,
            max_new_tokens=2048,
            do_sample=False,
            use_cache=True,
        )

    # Decode only the newly generated tokens, skipping the prompt
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = generated[:, input_length:]
    answer = processor.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    return answer
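
# A minimal sketch for testing the OCR function without the Gradio UI
# (assumes one of the bundled example images exists; uncomment to run):
#
#   result = perform_ocr(Image.open("examples/01.png"))
#   print(result)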
# Create Gradio interface
with gr.Blocks(title="PaddleOCR-VL For Manga") as demo:
    gr.Markdown(
        """
        # PaddleOCR-VL-For-Manga Demo

        **Model**: [jzhang533/PaddleOCR-VL-For-Manga](https://huggingface.co/jzhang533/PaddleOCR-VL-For-Manga)

        ### Features:
        - Fine-tuned from [PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL)
        - Trained on Manga109-s dataset + 1.5M synthetic samples
        - Achieves 70% full-sentence accuracy on manga crops
        """
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Manga Image", height=400)
            submit_btn = gr.Button("Recognize Text 📖", variant="primary")
            gr.Examples(
                examples=[
                    ["examples/01.png"],
                    ["examples/02.png"],
                    ["examples/03.png"],
                    ["examples/04.png"],
                    ["examples/05.png"],
                ],
                inputs=input_image,
                label="Example Images",
            )

        with gr.Column():
            output_text = gr.Textbox(
                label="Recognized Text",
                placeholder="The recognized Japanese text will appear here...",
                lines=15,
                max_lines=20,
            )

    # Connect the button to the function
    submit_btn.click(fn=perform_ocr, inputs=input_image, outputs=output_text)

    # Also trigger on image upload
    input_image.change(fn=perform_ocr, inputs=input_image, outputs=output_text)
# Launch the app
if __name__ == "__main__":
    demo.launch()
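
# If requests pile up on a busy Space, Gradio's built-in queue is one option;
# a minimal sketch (the max_size value is an arbitrary choice, not tuned):
#
#   demo.queue(max_size=20).launch()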