tyrwh committed
Commit 228a7ef · Parent: 37dd438

Big overhaul to app.py

Drop the Redis-backed job store and track job state in the Flask session instead: uploads and results now live in session-scoped directories, /process keys work off a session ID, and ownership checks guard the progress/download/export routes. Inference runs single-threaded on CUDA when available (via torch) and in a ThreadPool on CPU, replacing the multiprocessing Pool. Also adds a session secret key, a before_request hook that assigns session IDs, and a teardown hook that cleans up session files.
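Deployment note: the new session handling reads FLASK_SECRET_KEY from the environment (the variable name comes from the diff below). A minimal sketch of one way to supply it, assuming any sufficiently random string is acceptable; secrets.token_hex is just one option:

# Sketch: export a stable secret before launching the app so session cookies
# remain valid across restarts (the in-diff fallback generates a fresh random
# key on every start, which invalidates existing cookies).
import os
import secrets

os.environ.setdefault('FLASK_SECRET_KEY', secrets.token_hex(32))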

Files changed (1): app.py (+120 -116)
app.py CHANGED
@@ -1,28 +1,27 @@
-from flask import Flask, render_template, request, jsonify, send_from_directory, send_file, Response
-from multiprocessing import Pool, cpu_count
-from threading import Thread
-from pathlib import Path
-from PIL import Image
-from datetime import datetime
 import os
-import tempfile
 import uuid
-import pandas as pd
-from werkzeug.utils import secure_filename
 import traceback
 import sys
 import io
 import zipfile
 import cv2
 import csv
-import numpy as np
-import redis
+import torch
 import json
 import shutil
-
+import numpy as np
+import pandas as pd
+from flask import Flask, Response, render_template, request, jsonify, send_from_directory, send_file, session, redirect, url_for
+from multiprocessing.pool import ThreadPool
+from threading import Thread
+from pathlib import Path
+from PIL import Image
+from datetime import datetime
+from werkzeug.utils import secure_filename
 from yolo_utils import load_model, detect_image
 
 app = Flask(__name__)
+app.secret_key = os.environ.get('FLASK_SECRET_KEY', str(uuid.uuid4())) # For session security
 
 APP_ROOT = Path(__file__).parent
 UPLOAD_FOLDER = APP_ROOT / 'uploads'
@@ -35,9 +34,6 @@ app.config['ALLOWED_EXTENSIONS'] = {'png', 'jpg', 'jpeg', 'tif', 'tiff'}
 UPLOAD_FOLDER.mkdir(parents=True, exist_ok=True)
 RESULT_FOLDER.mkdir(parents=True, exist_ok=True)
 
-# Redis client (localhost:6379, db=0, no password)
-redis_client = redis.Redis(host='localhost', port=6379, db=0, decode_responses=True)
-
 @app.errorhandler(Exception)
 def handle_exception(e):
     print(f"Unhandled exception: {str(e)}")
@@ -51,73 +47,90 @@ def allowed_file(filename):
 def index():
     return render_template('index.html')
 
-# Global model for each process
+# Load model once at startup, use CUDA if available
+MODEL_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 _model = None
 def get_model():
     global _model
     if _model is None:
         _model = load_model(WEIGHTS_FILE)
+        if MODEL_DEVICE == 'cuda':
+            _model.to('cuda')
     return _model
 
-def cleanup_job(job_id):
-    # Remove files
-    upload_dir = os.path.join(app.config['UPLOAD_FOLDER'], job_id)
-    if os.path.exists(upload_dir):
+def cleanup_session(session_id):
+    # Remove files for this session
+    upload_dir = Path(app.config['UPLOAD_FOLDER']) / session_id
+    result_dir = Path(app.config['RESULT_FOLDER']) / session_id
+    for d in [upload_dir, result_dir]:
+        if d.exists():
+            shutil.rmtree(d)
+
+# save the uploaded files
+@app.route('/uploads/<session_id>', methods=['POST'])
+def upload_files(session_id):
+    files = request.files.getlist('files')
+    upload_dir = Path(app.config['UPLOAD_FOLDER']) / session_id
+    # clear out any existing files for the session
+    if upload_dir.exists():
         shutil.rmtree(upload_dir)
-    # Remove Redis state
-    redis_client.delete(f"job:{job_id}")
-
-@app.route('/cleanup/<job_id>', methods=['POST'])
-def cleanup_job_endpoint(job_id):
-    cleanup_job(job_id)
-    return jsonify({'status': 'cleaned'})
-
-def get_job_state(job_id):
-    data = redis_client.get(f"job:{job_id}")
-    return json.loads(data) if data else None
-
-def set_job_state(job_id, state):
-    redis_client.set(f"job:{job_id}", json.dumps(state))
+    upload_dir.mkdir(parents=True, exist_ok=True)
+    filename_map = {}
+    file_data = []
+    for f in files:
+        orig_name = secure_filename(f.filename)
+        ext = Path(orig_name).suffix
+        unique_name = f"{uuid.uuid4().hex}{ext}"
+        file_path = upload_dir / unique_name
+        f.save(str(file_path))
+        filename_map[orig_name] = unique_name
+        with open(file_path, 'rb') as imgf:
+            file_data.append((orig_name, unique_name, imgf.read()))
+    return filename_map, file_data
 
-all_detections = {}
 
 def process_image(args):
     orig_name, unique_name, image_bytes = args
     model = get_model()
     detections = detect_image(model, image_bytes, conf=0.05)
-    # Save original image to uploads for later annotation (already saved)
     return {'orig_name': orig_name, 'unique_name': unique_name, 'detections': detections}
 
-def async_process_images(job_id, file_data):
+def async_process_images(session_id, file_data, state):
     try:
-        job_state = get_job_state(job_id)
-        job_state['status'] = 'running'
-        job_state['progress'] = 0
-        set_job_state(job_id, job_state)
+        state['status'] = 'running'
+        state['progress'] = 0
         total = len(file_data)
        results = []
         detections = {}
-        with Pool(processes=min(cpu_count(), total)) as pool:
-            for idx, result in enumerate(pool.imap(process_image, file_data)):
+        # Use ThreadPool for CPU, else single-threaded for CUDA
+        if MODEL_DEVICE == 'cuda':
+            pool = None
+            for idx, args in enumerate(file_data):
+                result = process_image(args)
                 results.append({
                     'filename': result['orig_name'],
                     'num_eggs': sum(1 for d in result['detections'] if d.get('class') == 'egg'),
                 })
                 detections[result['orig_name']] = result['detections']
-                # Update progress
-                job_state['progress'] = int((idx + 1) / total * 100)
-                set_job_state(job_id, job_state)
-        job_state['status'] = 'success'
-        job_state['results'] = results
-        job_state['detections'] = detections
-        job_state['progress'] = 100
-        set_job_state(job_id, job_state)
+                state['progress'] = int((idx + 1) / total * 100)
+        else:
+            with ThreadPool() as pool:
+                for idx, result in enumerate(pool.imap(process_image, file_data)):
+                    results.append({
+                        'filename': result['orig_name'],
+                        'num_eggs': sum(1 for d in result['detections'] if d.get('class') == 'egg'),
+                    })
+                    detections[result['orig_name']] = result['detections']
+                    state['progress'] = int((idx + 1) / total * 100)
+        state['status'] = 'success'
+        state['results'] = results
+        state['detections'] = detections
+        state['progress'] = 100
     except Exception as e:
-        job_state = get_job_state(job_id) or {}
-        job_state['status'] = 'error'
-        job_state['error'] = str(e)
-        job_state['progress'] = 100
-        set_job_state(job_id, job_state)
+        state['status'] = 'error'
+        state['error'] = str(e)
+        state['progress'] = 100
+
 
 @app.route('/process', methods=['POST'])
 def process_images():
@@ -125,80 +138,62 @@ def process_images():
         files = request.files.getlist('files')
         if not files or files[0].filename == '':
             return jsonify({'error': 'No files uploaded'}), 400
-        job_id = str(uuid.uuid4())
-        # Clean up any previous state for this job
-        cleanup_job(job_id)
-        filename_map, file_data = save_uploaded_files(files, job_id)
-        # Store initial job state in Redis
-        job_state = {
+        # Assign a session ID if not present
+        if 'id' not in session:
+            session['id'] = str(uuid.uuid4())
+        session_id = session['id']
+        # Clean up any previous state for this session
+        cleanup_session(session_id)
+        filename_map, file_data = upload_files(files, session_id)
+        # Store job state in session
+        state = {
             'status': 'starting',
             'progress': 0,
             'results': [],
             'filename_map': filename_map,
             'detections': {},
         }
-        set_job_state(job_id, job_state)
-        thread = Thread(target=async_process_images, args=(job_id, file_data))
+        session['job_state'] = state
+        thread = Thread(target=async_process_images, args=(session_id, file_data, state))
         thread.daemon = True
         thread.start()
-        return jsonify({'jobId': job_id})
+        return jsonify({'jobId': session_id})
     except Exception as e:
         print(f"Error in /process: {e}")
         print(traceback.format_exc())
         return jsonify({'error': str(e)}), 500
 
-def save_uploaded_files(files, job_id):
-    upload_dir = os.path.join(app.config['UPLOAD_FOLDER'], job_id)
-    if os.path.exists(upload_dir):
-        shutil.rmtree(upload_dir)
-    os.makedirs(upload_dir, exist_ok=True)
-    filename_map = {}
-    file_data = []
-    for f in files:
-        orig_name = secure_filename(f.filename)
-        ext = os.path.splitext(orig_name)[1]
-        unique_name = f"{uuid.uuid4().hex}{ext}"
-        file_path = os.path.join(upload_dir, unique_name)
-        f.save(file_path)
-        filename_map[orig_name] = unique_name
-        with open(file_path, 'rb') as imgf:
-            file_data.append((orig_name, unique_name, imgf.read()))
-    return filename_map, file_data
-
-@app.route('/progress/<job_id>')
-def get_progress(job_id):
-    job_state = get_job_state(job_id)
+@app.route('/progress/<session_id>')
+def get_progress(session_id):
+    # Only allow access to own session
+    if 'id' not in session or session['id'] != session_id:
+        return jsonify({"status": "error", "error": "Session not found or expired"}), 404
+    job_state = session.get('job_state')
     if not job_state:
-        return jsonify({"status": "error", "error": "Job ID not found"}), 404
-    # Add a mapping from filename to detections for frontend plotting
+        return jsonify({"status": "error", "error": "No job state"}), 404
     if 'detections' in job_state:
         job_state['detections_by_filename'] = job_state['detections']
     return jsonify(job_state)
 
-@app.route('/results/<job_id>/<path:filename>')
-def download_file(job_id, filename):
+@app.route('/results/<session_id>/<path:filename>')
+def download_file(session_id, filename):
     try:
-        try:
-            uuid.UUID(job_id, version=4)
-        except ValueError:
-            return jsonify({"error": "Invalid job ID format"}), 400
 
+        if 'id' not in session or session['id'] != session_id:
+            return jsonify({"error": "Session not found or expired"}), 404
         if '..' in filename or filename.startswith('/'):
             return jsonify({"error": "Invalid filename"}), 400
-
         safe_filename = secure_filename(filename)
-        file_dir = Path(app.config['RESULT_FOLDER']) / job_id
+        file_dir = Path(app.config['RESULT_FOLDER']) / session_id
         file_path = (file_dir / safe_filename).resolve()
-
         if not str(file_path).startswith(str(file_dir.resolve())):
-            print(f"Attempted path traversal: {job_id}/{filename}")
+            print(f"Attempted path traversal: {session_id}/{filename}")
            return jsonify({"error": "Invalid file path"}), 400
-
         if not file_path.is_file():
             if not file_dir.exists():
-                return jsonify({"error": f"Job directory {job_id} not found"}), 404
+                return jsonify({"error": f"Session directory {session_id} not found"}), 404
             files_in_dir = list(file_dir.iterdir())
-            return jsonify({"error": f"File '{filename}' not found in job '{job_id}'. Available: {[f.name for f in files_in_dir]}"}), 404
+            return jsonify({"error": f"File '{filename}' not found in session '{session_id}'. Available: {[f.name for f in files_in_dir]}"}), 404
 
         if filename.lower().endswith(('.tif', '.tiff')):
             try:
@@ -240,17 +235,15 @@ def download_file(job_id, filename):
         print(error_message)
         return jsonify({"error": "Server error", "log": error_message}), 500
 
-@app.route('/export_images/<job_id>')
-def export_images(job_id):
+@app.route('/export_images/<session_id>')
+def export_images(session_id):
     try:
-        try:
-            uuid.UUID(job_id, version=4)
-        except ValueError:
-            return jsonify({"error": "Invalid job ID format"}), 400
 
-        job_dir = Path(app.config['RESULT_FOLDER']) / job_id
+        if 'id' not in session or session['id'] != session_id:
+            return jsonify({"error": "Session not found or expired"}), 404
+        job_dir = Path(app.config['RESULT_FOLDER']) / session_id
         if not job_dir.exists():
-            return jsonify({"error": f"Job directory {job_id} not found"}), 404
+            return jsonify({"error": f"Session directory {session_id} not found"}), 404
 
         annotated_files = list(job_dir.glob('*_annotated.*'))
         if not annotated_files:
@@ -280,9 +273,9 @@
 def export_csv():
     try:
         data = request.json
-        job_id = data['jobId']
+        session_id = session.get('id')
         threshold = float(data.get('confidence', 0.5))
-        job_state = get_job_state(job_id)
+        job_state = session.get('job_state')
         if not job_state:
             return jsonify({'error': 'Job not found'}), 404
         rows = []
@@ -311,17 +304,17 @@
 def export_images_post():
     try:
         data = request.json
-        job_id = data['jobId']
+        session_id = session.get('id')
         threshold = float(data.get('confidence', 0.5))
-        job_state = get_job_state(job_id)
+        job_state = session.get('job_state')
         if not job_state:
             return jsonify({'error': 'Job not found'}), 404
         memory_file = io.BytesIO()
         with zipfile.ZipFile(memory_file, 'w', zipfile.ZIP_DEFLATED) as zf:
             for orig_name, detections in job_state['detections'].items():
                 unique_name = job_state['filename_map'][orig_name]
-                img_path = os.path.join(app.config['UPLOAD_FOLDER'], job_id, unique_name)
-                img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+                img_path = Path(app.config['UPLOAD_FOLDER']) / session_id / unique_name
+                img = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED)
                 filtered = [d for d in detections if d['score'] >= threshold]
                 for det in filtered:
                     x1, y1, x2, y2 = map(int, det['bbox'])
@@ -356,7 +349,7 @@ def print_startup_info():
     except Exception as e:
        print(f"Could not get weights file size: {e}")
 
-    is_container = os.path.exists('/.dockerenv') or 'DOCKER_HOST' in os.environ
+    is_container = Path('/.dockerenv').exists() or 'DOCKER_HOST' in os.environ
     print(f"Running in container: {is_container}")
 
     if is_container:
@@ -386,6 +379,17 @@ def print_startup_info():
     except Exception as e:
         print(f"Could not get NemaQuant script details: {e}")
 
+@app.before_request
+def ensure_session_id():
+    if 'id' not in session:
+        session['id'] = str(uuid.uuid4())
+
+@app.teardown_appcontext
+def cleanup_on_teardown(exception):
+    # If session is gone, clean up files
+    if 'id' in session and not session.modified:
+        cleanup_session(session['id'])
+
 if __name__ == '__main__':
     print_startup_info()
     app.run(host='0.0.0.0', port=7860, debug=True)
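For reference, a minimal client-side sketch of the intended flow through the new session-based endpoints. Assumptions beyond the diff: the server runs locally on port 7860 (per app.run above), sample.jpg is a hypothetical local file, and the client keeps cookies between requests so the /progress ownership check passes.

# Sketch of driving the endpoints added in this commit.
# requests.Session() preserves the Flask session cookie set by /process,
# which /progress/<jobId> requires to match the session_id in the URL.
import time
import requests

BASE = 'http://localhost:7860'
client = requests.Session()

with open('sample.jpg', 'rb') as f:
    resp = client.post(f'{BASE}/process', files={'files': ('sample.jpg', f)})
job_id = resp.json()['jobId']

# Poll until the background thread reports a terminal status
while True:
    state = client.get(f'{BASE}/progress/{job_id}').json()
    if state.get('status') in ('success', 'error'):
        break
    time.sleep(1)
print(state.get('results'))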