Emeritus-21 commited on
Commit
f263567
Β·
verified Β·
1 Parent(s): 810b4a4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -0
app.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from threading import Thread
4
+ import gradio as gr
5
+ import spaces
6
+ from PIL import Image
7
+ import torch
8
+ from transformers import (
9
+ AutoProcessor,
10
+ AutoModelForImageTextToText,
11
+ Qwen2_5_VLForConditionalGeneration,
12
+ TextIteratorStreamer,
13
+ )
14
+ MODEL_PATHS = {
15
+ "Model 3 (structured handwritting)": (
16
+ "Emeritus-21/Finetuned-full-HTR-model",
17
+ AutoModelForImageTextToText,
18
+ ),
19
+ }
20
+
21
+ MAX_NEW_TOKENS_DEFAULT = 512
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+
24
+ # ---------------------------
25
+ # Preload models at startup
26
+ # ---------------------------
27
+ _loaded_processors = {}
28
+ _loaded_models = {}
29
+
30
+ print("πŸš€ Preloading models into GPU/CPU memory...")
31
+
32
+ for name, (repo_id, cls) in MODEL_PATHS.items():
33
+ try:
34
+ print(f"Loading {name} ...")
35
+ processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
36
+ model = cls.from_pretrained(
37
+ repo_id,
38
+ trust_remote_code=True,
39
+ torch_dtype=torch.float16
40
+ ).to(device).eval()
41
+ _loaded_processors[name] = processor
42
+ _loaded_models[name] = model
43
+ print(f"βœ… {name} ready.")
44
+ except Exception as e:
45
+ print(f"⚠️ Failed to load {name}: {e}")
46
+
47
+ # ---------------------------
48
+ # Warmup (GPU)
49
+ # ---------------------------
50
+ #@spaces.GPU
51
+ def warmup():
52
+ try:
53
+ default_model_choice = list(MODEL_PATHS.keys())[0]
54
+ processor = _loaded_processors[default_model_choice]
55
+ model = _loaded_models[default_model_choice]
56
+
57
+ messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
58
+ chat_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
59
+ inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
60
+
61
+ with torch.inference_mode():
62
+ _ = model.generate(**inputs, max_new_tokens=1)
63
+
64
+ return f"GPU warm and {default_model_choice} ready."
65
+ except Exception as e:
66
+ return f"Warmup skipped: {e}"
67
+
68
+ # ---------------------------
69
+ # OCR Function (RAW ONLY)
70
+ # ---------------------------
71
+ #@spaces.GPU
72
+ def ocr_image(image: Image.Image, model_choice: str, query: str = None,
73
+ max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
74
+ temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0):
75
+
76
+ if image is None:
77
+ yield "Please upload an image."
78
+ return
79
+
80
+ if model_choice not in _loaded_models:
81
+ yield f"Invalid model: {model_choice}"
82
+ return
83
+
84
+ processor = _loaded_processors[model_choice]
85
+ model = _loaded_models[model_choice]
86
+
87
+ if query and query.strip():
88
+ prompt = query.strip()
89
+ else:
90
+ prompt = (
91
+ "You are a professional Handwritten OCR system.\n"
92
+ "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
93
+ "- Preserve original structure and line breaks.\n"
94
+ "- Keep spacing, bullet points, numbering, and indentation.\n"
95
+ "- Render tables as Markdown tables if present.\n"
96
+ "- Do NOT autocorrect spelling or grammar.\n"
97
+ "- Do NOT merge lines.\n"
98
+ "Return RAW transcription only."
99
+ )
100
+
101
+ messages = [{
102
+ "role": "user",
103
+ "content": [
104
+ {"type": "image", "image": image},
105
+ {"type": "text", "text": prompt}
106
+ ]
107
+ }]
108
+
109
+ chat_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
110
+ inputs = processor(text=[chat_prompt], images=[image], return_tensors="pt").to(device)
111
+
112
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
113
+
114
+ generation_kwargs = dict(
115
+ **inputs,
116
+ streamer=streamer,
117
+ max_new_tokens=max_new_tokens,
118
+ do_sample=False,
119
+ temperature=temperature,
120
+ top_p=top_p,
121
+ top_k=top_k,
122
+ repetition_penalty=repetition_penalty
123
+ )
124
+
125
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
126
+ thread.start()
127
+
128
+ buffer = ""
129
+ for new_text in streamer:
130
+ new_text = new_text.replace("<|im_end|>", "")
131
+ buffer += new_text
132
+ time.sleep(0.01)
133
+ yield buffer
134
+
135
+ # ---------------------------
136
+ # Gradio Interface
137
+ # ---------------------------
138
+ with gr.Blocks() as demo:
139
+ gr.Markdown("## wilson Handwritten OCR ")
140
+
141
+ model_choice = gr.Radio(
142
+ choices=list(MODEL_PATHS.keys()),
143
+ value=list(MODEL_PATHS.keys())[0],
144
+ label="Select OCR Model"
145
+ )
146
+
147
+ with gr.Tab("πŸ–Ό Image Inference"):
148
+ query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
149
+ image_input = gr.Image(type="pil", label="Upload Handwritten Image")
150
+
151
+ with gr.Accordion("βš™οΈ Advanced Options", open=False):
152
+ max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
153
+ temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
154
+ top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
155
+ top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
156
+ repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
157
+
158
+ with gr.Row():
159
+ extract_btn = gr.Button("πŸ“€ Extract RAW Text", variant="primary")
160
+ clear_btn = gr.Button("🧹 Clear")
161
+
162
+ raw_output = gr.Textbox(label="πŸ“œ RAW Structured Output (exact as written)", lines=18, show_copy_button=True)
163
+
164
+ extract_btn.click(
165
+ fn=ocr_image,
166
+ inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
167
+ outputs=[raw_output],
168
+ api_name="ocr_image" # <--- THIS IS THE CRUCIAL FIX
169
+ )
170
+
171
+ clear_btn.click(
172
+ fn=lambda: ("", None, ""),
173
+ outputs=[raw_output, image_input, query_input]
174
+ )
175
+
176
+ if __name__ == "__main__":
177
+ demo.queue(max_size=50).launch(share=True, ssr_mode=False, show_error=True)