tyrwh commited on
Commit
bae5c31
·
1 Parent(s): 228a7ef

WorkerPool threading and GPU support

Browse files
Files changed (3) hide show
  1. app.py +168 -104
  2. static/script.js +188 -133
  3. yolo_utils.py +3 -10
app.py CHANGED
@@ -6,33 +6,35 @@ import io
6
  import zipfile
7
  import cv2
8
  import csv
9
- import torch
10
  import json
11
  import shutil
 
 
12
  import numpy as np
13
  import pandas as pd
 
14
  from flask import Flask, Response, render_template, request, jsonify, send_from_directory, send_file, session, redirect, url_for
15
  from multiprocessing.pool import ThreadPool
16
- from threading import Thread
17
  from pathlib import Path
18
  from PIL import Image
19
  from datetime import datetime
20
  from werkzeug.utils import secure_filename
21
- from yolo_utils import load_model, detect_image
22
 
23
  app = Flask(__name__)
24
  app.secret_key = os.environ.get('FLASK_SECRET_KEY', str(uuid.uuid4())) # For session security
25
 
26
  APP_ROOT = Path(__file__).parent
27
  UPLOAD_FOLDER = APP_ROOT / 'uploads'
28
- RESULT_FOLDER = APP_ROOT / 'results'
29
  WEIGHTS_FILE = APP_ROOT / 'weights.pt'
30
  app.config['UPLOAD_FOLDER'] = str(UPLOAD_FOLDER)
31
- app.config['RESULT_FOLDER'] = str(RESULT_FOLDER)
32
  app.config['ALLOWED_EXTENSIONS'] = {'png', 'jpg', 'jpeg', 'tif', 'tiff'}
33
 
34
  UPLOAD_FOLDER.mkdir(parents=True, exist_ok=True)
35
- RESULT_FOLDER.mkdir(parents=True, exist_ok=True)
36
 
37
  @app.errorhandler(Exception)
38
  def handle_exception(e):
@@ -47,13 +49,14 @@ def allowed_file(filename):
47
  def index():
48
  return render_template('index.html')
49
 
 
50
  # Load model once at startup, use CUDA if available
51
- MODEL_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
52
  _model = None
53
  def get_model():
54
  global _model
55
  if _model is None:
56
- _model = load_model(WEIGHTS_FILE)
57
  if MODEL_DEVICE == 'cuda':
58
  _model.to('cuda')
59
  return _model
@@ -61,22 +64,23 @@ def get_model():
61
  def cleanup_session(session_id):
62
  # Remove files for this session
63
  upload_dir = Path(app.config['UPLOAD_FOLDER']) / session_id
64
- result_dir = Path(app.config['RESULT_FOLDER']) / session_id
65
- for d in [upload_dir, result_dir]:
66
  if d.exists():
67
  shutil.rmtree(d)
68
 
69
  # save the uploaded files
70
- @app.route('/uploads/<session_id>', methods=['POST'])
71
- def upload_files(session_id):
 
72
  files = request.files.getlist('files')
73
  upload_dir = Path(app.config['UPLOAD_FOLDER']) / session_id
74
  # clear out any existing files for the session
75
  if upload_dir.exists():
76
  shutil.rmtree(upload_dir)
77
  upload_dir.mkdir(parents=True, exist_ok=True)
 
78
  filename_map = {}
79
- file_data = []
80
  for f in files:
81
  orig_name = secure_filename(f.filename)
82
  ext = Path(orig_name).suffix
@@ -84,107 +88,164 @@ def upload_files(session_id):
84
  file_path = upload_dir / unique_name
85
  f.save(str(file_path))
86
  filename_map[orig_name] = unique_name
87
- with open(file_path, 'rb') as imgf:
88
- file_data.append((orig_name, unique_name, imgf.read()))
89
- return filename_map, file_data
90
-
 
 
 
 
 
 
 
91
 
92
- def process_image(args):
93
- orig_name, unique_name, image_bytes = args
94
  model = get_model()
95
- detections = detect_image(model, image_bytes, conf=0.05)
96
- return {'orig_name': orig_name, 'unique_name': unique_name, 'detections': detections}
97
-
98
- def async_process_images(session_id, file_data, state):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  try:
100
- state['status'] = 'running'
101
- state['progress'] = 0
102
- total = len(file_data)
103
- results = []
104
- detections = {}
105
- # Use ThreadPool for CPU, else single-threaded for CUDA
106
  if MODEL_DEVICE == 'cuda':
107
  pool = None
108
- for idx, args in enumerate(file_data):
109
- result = process_image(args)
110
- results.append({
111
- 'filename': result['orig_name'],
112
- 'num_eggs': sum(1 for d in result['detections'] if d.get('class') == 'egg'),
113
- })
114
- detections[result['orig_name']] = result['detections']
115
- state['progress'] = int((idx + 1) / total * 100)
116
  else:
117
  with ThreadPool() as pool:
118
- for idx, result in enumerate(pool.imap(process_image, file_data)):
119
- results.append({
120
- 'filename': result['orig_name'],
121
- 'num_eggs': sum(1 for d in result['detections'] if d.get('class') == 'egg'),
122
- })
123
- detections[result['orig_name']] = result['detections']
124
- state['progress'] = int((idx + 1) / total * 100)
125
- state['status'] = 'success'
126
- state['results'] = results
127
- state['detections'] = detections
128
  state['progress'] = 100
 
129
  except Exception as e:
 
 
130
  state['status'] = 'error'
131
  state['error'] = str(e)
132
  state['progress'] = 100
133
-
134
-
135
- @app.route('/process', methods=['POST'])
136
- def process_images():
 
 
 
 
 
 
 
 
137
  try:
138
- files = request.files.getlist('files')
139
- if not files or files[0].filename == '':
140
- return jsonify({'error': 'No files uploaded'}), 400
141
- # Assign a session ID if not present
142
- if 'id' not in session:
143
- session['id'] = str(uuid.uuid4())
144
- session_id = session['id']
145
- # Clean up any previous state for this session
146
- cleanup_session(session_id)
147
- filename_map, file_data = upload_files(files, session_id)
148
- # Store job state in session
149
- state = {
150
- 'status': 'starting',
151
- 'progress': 0,
152
- 'results': [],
153
- 'filename_map': filename_map,
154
- 'detections': {},
155
  }
156
- session['job_state'] = state
157
- thread = Thread(target=async_process_images, args=(session_id, file_data, state))
158
- thread.daemon = True
159
- thread.start()
160
- return jsonify({'jobId': session_id})
 
 
 
 
161
  except Exception as e:
162
- print(f"Error in /process: {e}")
163
  print(traceback.format_exc())
164
- return jsonify({'error': str(e)}), 500
165
 
166
- @app.route('/progress/<session_id>')
167
- def get_progress(session_id):
168
- # Only allow access to own session
169
- if 'id' not in session or session['id'] != session_id:
170
- return jsonify({"status": "error", "error": "Session not found or expired"}), 404
171
- job_state = session.get('job_state')
172
- if not job_state:
173
- return jsonify({"status": "error", "error": "No job state"}), 404
174
- if 'detections' in job_state:
175
- job_state['detections_by_filename'] = job_state['detections']
176
- return jsonify(job_state)
177
-
178
- @app.route('/results/<session_id>/<path:filename>')
179
- def download_file(session_id, filename):
180
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
- if 'id' not in session or session['id'] != session_id:
183
- return jsonify({"error": "Session not found or expired"}), 404
 
 
184
  if '..' in filename or filename.startswith('/'):
185
  return jsonify({"error": "Invalid filename"}), 400
186
  safe_filename = secure_filename(filename)
187
- file_dir = Path(app.config['RESULT_FOLDER']) / session_id
188
  file_path = (file_dir / safe_filename).resolve()
189
  if not str(file_path).startswith(str(file_dir.resolve())):
190
  print(f"Attempted path traversal: {session_id}/{filename}")
@@ -235,13 +296,11 @@ def download_file(session_id, filename):
235
  print(error_message)
236
  return jsonify({"error": "Server error", "log": error_message}), 500
237
 
238
- @app.route('/export_images/<session_id>')
239
- def export_images(session_id):
240
  try:
241
-
242
- if 'id' not in session or session['id'] != session_id:
243
- return jsonify({"error": "Session not found or expired"}), 404
244
- job_dir = Path(app.config['RESULT_FOLDER']) / session_id
245
  if not job_dir.exists():
246
  return jsonify({"error": f"Session directory {session_id} not found"}), 404
247
 
@@ -273,7 +332,7 @@ def export_images(session_id):
273
  def export_csv():
274
  try:
275
  data = request.json
276
- session_id = session.get('id')
277
  threshold = float(data.get('confidence', 0.5))
278
  job_state = session.get('job_state')
279
  if not job_state:
@@ -304,7 +363,7 @@ def export_csv():
304
  def export_images_post():
305
  try:
306
  data = request.json
307
- session_id = session.get('id')
308
  threshold = float(data.get('confidence', 0.5))
309
  job_state = session.get('job_state')
310
  if not job_state:
@@ -335,6 +394,8 @@ def export_images_post():
335
  print(error_message)
336
  return jsonify({"error": "Server error", "log": error_message}), 500
337
 
 
 
338
  def print_startup_info():
339
  print("----- NemaQuant Flask App Starting -----")
340
  print(f"Working directory: {os.getcwd()}")
@@ -384,11 +445,14 @@ def ensure_session_id():
384
  if 'id' not in session:
385
  session['id'] = str(uuid.uuid4())
386
 
387
- @app.teardown_appcontext
388
- def cleanup_on_teardown(exception):
389
- # If session is gone, clean up files
390
- if 'id' in session and not session.modified:
391
  cleanup_session(session['id'])
 
 
 
392
 
393
  if __name__ == '__main__':
394
  print_startup_info()
 
6
  import zipfile
7
  import cv2
8
  import csv
9
+ import pickle
10
  import json
11
  import shutil
12
+ from ultralytics import YOLO
13
+ from ultralytics.utils import ThreadingLocked
14
  import numpy as np
15
  import pandas as pd
16
+ from torch import cuda
17
  from flask import Flask, Response, render_template, request, jsonify, send_from_directory, send_file, session, redirect, url_for
18
  from multiprocessing.pool import ThreadPool
 
19
  from pathlib import Path
20
  from PIL import Image
21
  from datetime import datetime
22
  from werkzeug.utils import secure_filename
23
+ from yolo_utils import detect_in_image
24
 
25
  app = Flask(__name__)
26
  app.secret_key = os.environ.get('FLASK_SECRET_KEY', str(uuid.uuid4())) # For session security
27
 
28
  APP_ROOT = Path(__file__).parent
29
  UPLOAD_FOLDER = APP_ROOT / 'uploads'
30
+ RESULTS_FOLDER = APP_ROOT / 'results'
31
  WEIGHTS_FILE = APP_ROOT / 'weights.pt'
32
  app.config['UPLOAD_FOLDER'] = str(UPLOAD_FOLDER)
33
+ app.config['RESULTS_FOLDER'] = str(RESULTS_FOLDER)
34
  app.config['ALLOWED_EXTENSIONS'] = {'png', 'jpg', 'jpeg', 'tif', 'tiff'}
35
 
36
  UPLOAD_FOLDER.mkdir(parents=True, exist_ok=True)
37
+ RESULTS_FOLDER.mkdir(parents=True, exist_ok=True)
38
 
39
  @app.errorhandler(Exception)
40
  def handle_exception(e):
 
49
  def index():
50
  return render_template('index.html')
51
 
52
+
53
  # Load model once at startup, use CUDA if available
54
+ MODEL_DEVICE = 'cuda' if cuda.is_available() else 'cpu'
55
  _model = None
56
  def get_model():
57
  global _model
58
  if _model is None:
59
+ _model = YOLO(WEIGHTS_FILE)
60
  if MODEL_DEVICE == 'cuda':
61
  _model.to('cuda')
62
  return _model
 
64
  def cleanup_session(session_id):
65
  # Remove files for this session
66
  upload_dir = Path(app.config['UPLOAD_FOLDER']) / session_id
67
+ results_dir = Path(app.config['RESULTS_FOLDER']) / session_id
68
+ for d in [upload_dir, results_dir]:
69
  if d.exists():
70
  shutil.rmtree(d)
71
 
72
  # save the uploaded files
73
+ @app.route('/uploads', methods=['POST'])
74
+ def upload_files():
75
+ session_id = session['id']
76
  files = request.files.getlist('files')
77
  upload_dir = Path(app.config['UPLOAD_FOLDER']) / session_id
78
  # clear out any existing files for the session
79
  if upload_dir.exists():
80
  shutil.rmtree(upload_dir)
81
  upload_dir.mkdir(parents=True, exist_ok=True)
82
+ # generate new unique filenames via uuid, save the mapping dict of old:new to session
83
  filename_map = {}
 
84
  for f in files:
85
  orig_name = secure_filename(f.filename)
86
  ext = Path(orig_name).suffix
 
88
  file_path = upload_dir / unique_name
89
  f.save(str(file_path))
90
  filename_map[orig_name] = unique_name
91
+ session['filename_map'] = filename_map
92
+ return jsonify({'filename_map': filename_map, 'status': 'uploaded'})
93
+
94
+ # helper function to simplify args for pool.imap
95
+ @ThreadingLocked()
96
+ def process_single_image(args):
97
+ orig_name, img_path, pickle_path, model = args
98
+ img_results = detect_in_image(model, str(img_path))
99
+ with open(pickle_path, 'wb') as pf:
100
+ pickle.dump(img_results, pf)
101
+ return (orig_name, img_results)
102
 
103
+ @app.route('/process', methods=['POST'])
104
+ def process_images():
105
  model = get_model()
106
+ session_id = session['id']
107
+ filename_map = session.get('filename_map', {})
108
+ upload_dir = Path(app.config['UPLOAD_FOLDER']) / session_id
109
+ state = {}
110
+ state['status'] = 'starting'
111
+ state['progress'] = 0
112
+ state['filename_map'] = filename_map
113
+ state['jobId'] = session['id']
114
+ session['job_state'] = state
115
+
116
+ # create a results_dir, clean out old one if needed
117
+ results_dir = Path(app.config['RESULTS_FOLDER']) / session_id
118
+ if results_dir.exists():
119
+ shutil.rmtree(results_dir)
120
+ results_dir.mkdir(parents=True)
121
+
122
+ # set up args list for imap
123
+ n_img = len(filename_map)
124
+ arg_list = [(orig_name,
125
+ upload_dir / filename_map[orig_name],
126
+ results_dir / f"{Path(orig_name).stem}_results.pkl",
127
+ model) for orig_name in filename_map.keys()]
128
  try:
129
+ all_detections = {}
130
+ state['status'] = 'processing'
131
+ session['job_state'] = state
 
 
 
132
  if MODEL_DEVICE == 'cuda':
133
  pool = None
134
+ for idx, args in enumerate(arg_list):
135
+ orig_name, img_results = process_single_image(args)
136
+ all_detections[orig_name] = img_results
137
+ state['progress'] = int((idx + 1) / n_img * 100)
138
+ session['job_state'] = state
 
 
 
139
  else:
140
  with ThreadPool() as pool:
141
+ for idx, result in enumerate(pool.imap(process_single_image, arg_list)):
142
+ state['progress'] = int((idx + 1) / n_img * 100)
143
+ orig_name, img_results = result
144
+ all_detections[orig_name] = img_results
145
+ session['job_state'] = state
146
+ # Save all detections to a pickled file
147
+ detections_path = results_dir / 'all_detections.pkl'
148
+ with open(detections_path, 'wb') as f:
149
+ pickle.dump(all_detections, f)
150
+ state['status'] = 'completed'
151
  state['progress'] = 100
152
+ session['job_state'] = state
153
  except Exception as e:
154
+ print(f"Error in /process: {e}")
155
+ print(traceback.format_exc())
156
  state['status'] = 'error'
157
  state['error'] = str(e)
158
  state['progress'] = 100
159
+ session['job_state'] = state
160
+ resp = {
161
+ 'status': state.get('status', 'unknown'),
162
+ 'progress': state.get('progress', 0),
163
+ 'jobId': state.get('jobId'),
164
+ 'error': state.get('error'),
165
+ }
166
+ return jsonify(resp)
167
+
168
+ # Support /progress/<jobId> for frontend polling
169
+ @app.route('/progress/<jobId>')
170
+ def get_progress_with_id(jobId):
171
  try:
172
+ job_state = session.get('job_state')
173
+ if not job_state:
174
+ print(f"/progress/{jobId}: No job_state found in session.")
175
+ return jsonify({"status": "error", "error": "No job state"}), 404
176
+ resp = {
177
+ 'status': job_state.get('status', 'unknown'),
178
+ 'progress': job_state.get('progress', 0),
179
+ 'jobId': session.get('id'),
180
+ 'error': job_state.get('error'),
 
 
 
 
 
 
 
 
181
  }
182
+ # If completed, load and return all_detections.pkl as JSON
183
+ if job_state.get('status') == 'completed':
184
+ session_id = session['id']
185
+ detections_path = Path(app.config['RESULTS_FOLDER']) / session_id / 'all_detections.pkl'
186
+ if detections_path.exists():
187
+ with open(detections_path, 'rb') as f:
188
+ all_detections = pickle.load(f)
189
+ resp['results'] = all_detections
190
+ return jsonify(resp)
191
  except Exception as e:
192
+ print(f"Error in /progress/{jobId}: {e}")
193
  print(traceback.format_exc())
194
+ return jsonify({"status": "error", "error": str(e)}), 500
195
 
196
+ # /annotate route for dynamic annotation
197
+ @app.route('/annotate', methods=['POST'])
198
+ def annotate_image():
 
 
 
 
 
 
 
 
 
 
 
199
  try:
200
+ data = request.get_json()
201
+ filename = data.get('filename')
202
+ confidence = float(data.get('confidence', 0.5))
203
+ session_id = session['id']
204
+ filename_map = session.get('filename_map', {})
205
+ unique_name = filename_map.get(filename)
206
+ if not unique_name:
207
+ return jsonify({'error': 'File not found'}), 404
208
+ # Load detections from pickle
209
+ result_path = Path(app.config['RESULTS_FOLDER']) / session_id / f"{Path(filename).stem}_results.pkl"
210
+ if not result_path.exists():
211
+ return jsonify({'error': 'Results not found'}), 404
212
+ with open(result_path, 'rb') as pf:
213
+ detections = pickle.load(pf)
214
+ # Load image
215
+ img_path = Path(app.config['UPLOAD_FOLDER']) / session_id / unique_name
216
+ img = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED)
217
+ # Filter detections
218
+ filtered = [d for d in detections if d.get('score', 0) >= confidence]
219
+ # Draw boxes
220
+ for det in filtered:
221
+ x1, y1, x2, y2 = map(int, det['bbox'])
222
+ cv2.rectangle(img, (x1, y1), (x2, y2), (0,0,255), 3)
223
+ # Save annotated image to temp
224
+ annotated_path = Path(app.config['RESULTS_FOLDER']) / 'annotated'
225
+ annotated_path.mkdir(parents=True, exist_ok=True)
226
+ out_name = f"{Path(filename).stem}_annotated.png"
227
+ out_file = annotated_path / out_name
228
+ cv2.imwrite(str(out_file), img)
229
+ # Serve image
230
+ with open(out_file, 'rb') as f:
231
+ return send_file(
232
+ io.BytesIO(f.read()),
233
+ mimetype='image/png',
234
+ as_attachment=False,
235
+ download_name=out_name
236
+ )
237
+ except Exception as e:
238
+ print(f"Error in /annotate: {e}")
239
+ return jsonify({'error': str(e)}), 500
240
 
241
+ @app.route('/results/<path:filename>')
242
+ def download_file(filename):
243
+ try:
244
+ session_id = session['id']
245
  if '..' in filename or filename.startswith('/'):
246
  return jsonify({"error": "Invalid filename"}), 400
247
  safe_filename = secure_filename(filename)
248
+ file_dir = Path(app.config['RESULTS_FOLDER']) / session_id
249
  file_path = (file_dir / safe_filename).resolve()
250
  if not str(file_path).startswith(str(file_dir.resolve())):
251
  print(f"Attempted path traversal: {session_id}/{filename}")
 
296
  print(error_message)
297
  return jsonify({"error": "Server error", "log": error_message}), 500
298
 
299
+ @app.route('/export_images')
300
+ def export_images():
301
  try:
302
+ session_id = session['id']
303
+ job_dir = Path(app.config['RESULTS_FOLDER']) / session_id
 
 
304
  if not job_dir.exists():
305
  return jsonify({"error": f"Session directory {session_id} not found"}), 404
306
 
 
332
  def export_csv():
333
  try:
334
  data = request.json
335
+ session_id = session['id']
336
  threshold = float(data.get('confidence', 0.5))
337
  job_state = session.get('job_state')
338
  if not job_state:
 
363
  def export_images_post():
364
  try:
365
  data = request.json
366
+ session_id = session['id']
367
  threshold = float(data.get('confidence', 0.5))
368
  job_state = session.get('job_state')
369
  if not job_state:
 
394
  print(error_message)
395
  return jsonify({"error": "Server error", "log": error_message}), 500
396
 
397
+
398
+
399
  def print_startup_info():
400
  print("----- NemaQuant Flask App Starting -----")
401
  print(f"Working directory: {os.getcwd()}")
 
445
  if 'id' not in session:
446
  session['id'] = str(uuid.uuid4())
447
 
448
+ # Explicit endpoint for safe session cleanup
449
+ @app.route('/cleanup', methods=['POST'])
450
+ def cleanup_endpoint():
451
+ if 'id' in session:
452
  cleanup_session(session['id'])
453
+ session.clear()
454
+ return jsonify({'status': 'cleaned up'})
455
+ return jsonify({'error': 'No session to clean up'}), 400
456
 
457
  if __name__ == '__main__':
458
  print_startup_info()
static/script.js CHANGED
@@ -35,6 +35,7 @@ document.addEventListener('DOMContentLoaded', () => {
35
  const MAX_ZOOM = 3;
36
  const MIN_ZOOM = 0.5;
37
  let progressInterval = null; // Interval timer for polling
 
38
 
39
  // Panning variables
40
  let isPanning = false;
@@ -256,8 +257,34 @@ document.addEventListener('DOMContentLoaded', () => {
256
  fileInput.click();
257
  });
258
 
259
- fileInput.addEventListener('change', () => {
260
  handleFiles(fileInput.files);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  });
262
 
263
  // Input mode change
@@ -279,6 +306,24 @@ document.addEventListener('DOMContentLoaded', () => {
279
  } else {
280
  uploadText.textContent = `${validFileCount} image${validFileCount === 1 ? '' : 's'} selected`;
281
  startProcessingBtn.disabled = validFileCount === 0;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  }
283
  }
284
 
@@ -345,6 +390,7 @@ document.addEventListener('DOMContentLoaded', () => {
345
  }
346
  // --- ASYNC JOB: Start polling for progress ---
347
  if (data.jobId) {
 
348
  currentJobId = data.jobId;
349
  pollProgress(currentJobId);
350
  } else {
@@ -365,27 +411,22 @@ document.addEventListener('DOMContentLoaded', () => {
365
  // --- Filtering and Table Update ---
366
  function updateResultsTable() {
367
  const threshold = parseFloat(confidenceSlider.value);
368
- // Use allDetections (now an array of {filename, detections})
369
- const grouped = {};
370
- allDetections.forEach(imgResult => {
371
  const filtered = imgResult.detections.filter(det => det.score >= threshold);
372
- grouped[imgResult.filename] = filtered;
 
 
 
 
373
  });
374
- // Build results for table
375
- const prevFilename = (currentImageIndex >= 0 && currentResults[currentImageIndex]) ? currentResults[currentImageIndex].filename : null;
376
- currentResults = Object.keys(grouped).map(filename => ({
377
- filename,
378
- num_eggs: grouped[filename].length,
379
- detections: grouped[filename]
380
- }));
381
  resultsTableBody.innerHTML = '';
382
  currentSortField = null;
383
  currentSortDirection = 'asc';
384
  totalPages = Math.ceil(currentResults.length / RESULTS_PER_PAGE);
385
  currentPage = 1;
386
  displayResultsPage(currentPage);
387
- exportCsvBtn.disabled = true;
388
- exportImagesBtn.disabled = true;
389
  // Try to restore previous image if it still exists
390
  let newIndex = 0;
391
  if (prevFilename) {
@@ -406,18 +447,39 @@ document.addEventListener('DOMContentLoaded', () => {
406
  }
407
 
408
  confidenceSlider.addEventListener('input', () => {
409
- confidenceValue.textContent = confidenceSlider.value;
410
- updateResultsTable();
411
- renderConfidencePlot();
412
- if (currentImageIndex >= 0) displayImage(currentImageIndex);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
  });
414
 
415
  // --- Replace displayImage to use backend-annotated PNG ---
416
  async function displayImage(index) {
417
- if (!currentResults[index]) return;
 
 
 
 
 
 
 
 
418
  currentImageIndex = index;
419
- const result = currentResults[index];
420
- const filename = result.filename;
421
  const confidence = parseFloat(confidenceSlider.value);
422
  try {
423
  const response = await fetch('/annotate', {
@@ -454,7 +516,7 @@ document.addEventListener('DOMContentLoaded', () => {
454
  clearPreview();
455
  }
456
  prevBtn.disabled = index <= 0;
457
- nextBtn.disabled = index >= currentResults.length - 1;
458
  }
459
 
460
  // --- New Polling Function ---
@@ -463,65 +525,65 @@ document.addEventListener('DOMContentLoaded', () => {
463
  clearInterval(progressInterval); // Clear any existing timer
464
  }
465
 
466
- const totalImages = Array.from(fileInput.files).filter(file => {
467
- const allowedTypes = ['image/png', 'image/jpeg', 'image/jpg', 'image/tiff', 'image/tif'];
468
- if (inputMode.value === 'folder' || inputMode.value === 'keyence') {
469
- return allowedTypes.includes(file.type) &&
470
- file.webkitRelativePath &&
471
- !file.webkitRelativePath.startsWith('.');
472
- }
473
- return allowedTypes.includes(file.type);
474
- }).length;
475
-
476
  progressInterval = setInterval(async () => {
477
  try {
478
  const response = await fetch(`/progress/${jobId}`);
479
  if (!response.ok) {
480
- // Handle cases where the progress endpoint itself fails
481
- let errorText = `Progress check failed: ${response.status}`;
482
- try {
483
- const errorData = await response.json();
484
- errorText += `: ${errorData.error || 'Unknown progress error'}`;
485
- } catch(e) { errorText += ` - ${response.statusText}`; }
486
  throw new Error(errorText);
487
  }
488
 
489
  const data = await response.json();
490
-
491
- // Update UI based on status
492
- updateProgress(data.progress || 0, data.status.charAt(0).toUpperCase() + data.status.slice(1));
493
-
494
- // Update image counter based on progress percentage
495
- if (data.status === 'Running' || data.status === 'Starting') {
496
- const processedImages = Math.floor((data.progress / 90) * totalImages); // 90 is the max progress before completion
497
- imageCounter.textContent = `${processedImages} of ${totalImages} images`;
498
- }
499
-
500
- if (data.status === 'success') {
501
- clearInterval(progressInterval);
502
- progressInterval = null;
503
- updateProgress(100, 'Processing complete');
504
- logStatus("Processing finished successfully.");
505
- displayResults(data.results || [], data.detections_by_filename || null);
506
- renderConfidencePlot();
507
- setLoading(false);
508
- } else if (data.status === 'error') {
509
- clearInterval(progressInterval);
510
- progressInterval = null;
511
- logStatus(`Error during processing: ${data.error || 'Unknown error'}`);
512
- updateProgress(data.progress || 100, 'Error'); // Show error state on progress bar
513
- setLoading(false);
514
- } else if (data.status === 'running' || data.status === 'starting') {
515
- // Continue polling
516
- // logStatus(`Processing status: ${data.status} (${data.progress}%)`); // Removed for cleaner log
517
- // You could add snippets from data.log here if the backend provides useful intermediate logs
518
- } else {
519
- // Unknown status - stop polling to prevent infinite loops
520
- clearInterval(progressInterval);
521
- progressInterval = null;
522
- logStatus(`Warning: Unknown job status received: ${data.status}. Stopping progress updates.`);
523
- updateProgress(data.progress || 0, `Unknown (${data.status})`);
524
- setLoading(false);
 
 
 
 
 
 
 
 
 
 
 
525
  }
526
 
527
  } catch (error) {
@@ -531,7 +593,7 @@ document.addEventListener('DOMContentLoaded', () => {
531
  updateProgress(0, 'Polling Error');
532
  setLoading(false);
533
  }
534
- }, 2000); // Poll every 2 seconds
535
  }
536
 
537
  // --- UI Update Functions ---
@@ -614,49 +676,40 @@ document.addEventListener('DOMContentLoaded', () => {
614
  }
615
 
616
  // Results Display
617
- function displayResults(results, detectionsByFilename) {
618
- // If detectionsByFilename is provided, update allDetections for plot logic
619
- if (detectionsByFilename) {
620
- allDetections = Object.keys(detectionsByFilename).map(filename => ({
621
- filename,
622
- detections: detectionsByFilename[filename]
623
- }));
624
- } else {
625
- // Fallback: use results array
626
- allDetections = results.map(r => ({ filename: r.filename, detections: r.detections }));
627
- }
628
- // Now update the table
629
- currentResults = results;
630
  resultsTableBody.innerHTML = '';
631
  currentImageIndex = -1;
632
  currentSortField = null;
633
  currentSortDirection = 'asc';
634
 
635
- if (!results || results.length === 0) {
636
- logStatus("No results to display");
637
- clearPreview();
 
 
 
 
 
 
638
  return;
639
  }
640
 
641
- // Reset sort indicators
642
- document.querySelectorAll('.results-table th[data-sort]').forEach(h => {
643
- h.classList.remove('sort-asc', 'sort-desc');
644
- });
645
-
646
- // Calculate pagination
647
- totalPages = Math.ceil(results.length / RESULTS_PER_PAGE);
648
- currentPage = 1;
649
-
650
- // Display current page
651
- displayResultsPage(currentPage);
652
-
653
- // Enable export buttons
654
- exportCsvBtn.disabled = false;
655
- exportImagesBtn.disabled = false;
656
-
657
- // Show first image
658
- displayImage(0);
659
- logStatus(`Displayed ${results.length} results`);
660
  }
661
 
662
  // Display a specific page of results
@@ -671,31 +724,33 @@ document.addEventListener('DOMContentLoaded', () => {
671
  for (let i = startIndex; i < endIndex; i++) {
672
  const result = currentResults[i];
673
  const row = resultsTableBody.insertRow();
674
- row.innerHTML = `
675
- <td>
676
- <i class="ri-image-line"></i>
677
- ${result.filename}
678
- </td>
679
- <td>${result.num_eggs}</td>
680
- <td class="text-right">
681
- <button class="view-button" title="Click to view image">
682
- <i class="ri-eye-line"></i>
683
- View
684
- </button>
685
- </td>
686
- `;
687
-
688
- // Store the original index to maintain image preview relationship
689
- row.dataset.originalIndex = currentResults.indexOf(result);
690
-
691
- row.addEventListener('click', () => {
692
- const originalIndex = parseInt(row.dataset.originalIndex);
693
- displayImage(originalIndex);
694
- document.querySelectorAll('.results-table tr').forEach(r => r.classList.remove('selected'));
695
- row.classList.add('selected');
 
 
 
696
  });
697
- }
698
-
699
  // Update pagination UI
700
  updatePaginationControls();
701
  }
 
35
  const MAX_ZOOM = 3;
36
  const MIN_ZOOM = 0.5;
37
  let progressInterval = null; // Interval timer for polling
38
+ let filteredValidFiles = [];
39
 
40
  // Panning variables
41
  let isPanning = false;
 
257
  fileInput.click();
258
  });
259
 
260
+ fileInput.addEventListener('change', async () => {
261
  handleFiles(fileInput.files);
262
+ if (filteredValidFiles && filteredValidFiles.length > 0) {
263
+ // Prepare FormData for upload
264
+ const formData = new FormData();
265
+ filteredValidFiles.forEach(f => formData.append('files', f));
266
+ try {
267
+ const response = await fetch('/uploads', {
268
+ method: 'POST',
269
+ body: formData
270
+ });
271
+ if (response.ok) {
272
+ const data = await response.json();
273
+ logStatus('Files uploaded successfully.');
274
+ // Update results table with filenames
275
+ resultsTableBody.innerHTML = '';
276
+ Object.keys(data.filename_map).forEach((filename, idx) => {
277
+ const row = resultsTableBody.insertRow();
278
+ row.dataset.originalIndex = idx;
279
+ row.innerHTML = `<td>${filename}</td><td>Pending</td>`;
280
+ });
281
+ } else {
282
+ logStatus('File upload failed.');
283
+ }
284
+ } catch (err) {
285
+ logStatus('Error uploading files: ' + err);
286
+ }
287
+ }
288
  });
289
 
290
  // Input mode change
 
306
  } else {
307
  uploadText.textContent = `${validFileCount} image${validFileCount === 1 ? '' : 's'} selected`;
308
  startProcessingBtn.disabled = validFileCount === 0;
309
+ // Populate results table with filenames after upload
310
+ resultsTableBody.innerHTML = '';
311
+ filteredValidFiles.forEach((file, idx) => {
312
+ const row = resultsTableBody.insertRow();
313
+ row.dataset.originalIndex = idx;
314
+ row.innerHTML = `
315
+ <td>${file.name}</td>
316
+ <td>Pending</td>
317
+ <td><button class="view-button" data-index="${idx}">View</button></td>
318
+ `;
319
+ });
320
+ // Add click event for View buttons
321
+ resultsTableBody.querySelectorAll('.view-button').forEach(btn => {
322
+ btn.addEventListener('click', (e) => {
323
+ const idx = parseInt(btn.dataset.index, 10);
324
+ displayImage(idx);
325
+ });
326
+ });
327
  }
328
  }
329
 
 
390
  }
391
  // --- ASYNC JOB: Start polling for progress ---
392
  if (data.jobId) {
393
+ logStatus(`Processing started. Job ID: ${data.jobId}`);
394
  currentJobId = data.jobId;
395
  pollProgress(currentJobId);
396
  } else {
 
411
  // --- Filtering and Table Update ---
412
  function updateResultsTable() {
413
  const threshold = parseFloat(confidenceSlider.value);
414
+ // Use allDetections (array of {filename, detections}) for filtering
415
+ const prevFilename = (currentImageIndex >= 0 && currentResults[currentImageIndex]) ? currentResults[currentImageIndex].filename : null;
416
+ currentResults = allDetections.map(imgResult => {
417
  const filtered = imgResult.detections.filter(det => det.score >= threshold);
418
+ return {
419
+ filename: imgResult.filename,
420
+ num_eggs: filtered.length,
421
+ detections: filtered
422
+ };
423
  });
 
 
 
 
 
 
 
424
  resultsTableBody.innerHTML = '';
425
  currentSortField = null;
426
  currentSortDirection = 'asc';
427
  totalPages = Math.ceil(currentResults.length / RESULTS_PER_PAGE);
428
  currentPage = 1;
429
  displayResultsPage(currentPage);
 
 
430
  // Try to restore previous image if it still exists
431
  let newIndex = 0;
432
  if (prevFilename) {
 
447
  }
448
 
449
  confidenceSlider.addEventListener('input', () => {
450
+ confidenceValue.textContent = confidenceSlider.value;
451
+ // 1. Update total eggs detected
452
+ const totalEggsElem = document.getElementById('total-eggs-count');
453
+ if (totalEggsElem && allDetections.length > 0) {
454
+ const threshold = parseFloat(confidenceSlider.value);
455
+ const totalEggs = allDetections.reduce((sum, imgResult) => sum + imgResult.detections.filter(det => det.score >= threshold).length, 0);
456
+ totalEggsElem.textContent = totalEggs;
457
+ }
458
+
459
+ // 2. Redraw confidence plot
460
+ renderConfidencePlot();
461
+
462
+ // 3. Update results table
463
+ updateResultsTable();
464
+
465
+ // 4. Update eggs detected under image preview
466
+ if (currentImageIndex >= 0) {
467
+ displayImage(currentImageIndex);
468
+ }
469
  });
470
 
471
  // --- Replace displayImage to use backend-annotated PNG ---
472
  async function displayImage(index) {
473
+ // Use either currentResults or filteredValidFiles for filename
474
+ let filename;
475
+ if (currentResults && currentResults[index] && currentResults[index].filename) {
476
+ filename = currentResults[index].filename;
477
+ } else if (filteredValidFiles && filteredValidFiles[index]) {
478
+ filename = filteredValidFiles[index].name;
479
+ } else {
480
+ return;
481
+ }
482
  currentImageIndex = index;
 
 
483
  const confidence = parseFloat(confidenceSlider.value);
484
  try {
485
  const response = await fetch('/annotate', {
 
516
  clearPreview();
517
  }
518
  prevBtn.disabled = index <= 0;
519
+ nextBtn.disabled = index >= (currentResults.length > 0 ? currentResults.length : filteredValidFiles.length) - 1;
520
  }
521
 
522
  // --- New Polling Function ---
 
525
  clearInterval(progressInterval); // Clear any existing timer
526
  }
527
 
 
 
 
 
 
 
 
 
 
 
528
  progressInterval = setInterval(async () => {
529
  try {
530
  const response = await fetch(`/progress/${jobId}`);
531
  if (!response.ok) {
532
+ let errorText = `Progress check failed: ${response.status}`;
533
+ try {
534
+ const errorData = await response.json();
535
+ errorText += `: ${errorData.error || 'Unknown progress error'}`;
536
+ } catch(e) { errorText += ` - ${response.statusText}`; }
 
537
  throw new Error(errorText);
538
  }
539
 
540
  const data = await response.json();
541
+ const status = (data.status || '').toLowerCase();
542
+
543
+ switch (status) {
544
+ case 'starting':
545
+ updateProgress(data.progress || 0, 'Starting...');
546
+ logStatus('Job is starting.');
547
+ break;
548
+ case 'processing':
549
+ updateProgress(data.progress || 0, `Processing (${data.progress || 0}%)`);
550
+ logStatus(`Processing images... (${data.progress || 0}%)`);
551
+ // If results are present, update detections and table
552
+ if (data.results) {
553
+ allDetections = Object.entries(data.results).map(([filename, detections]) => ({ filename, detections }));
554
+ updateResultsTable();
555
+ }
556
+ break;
557
+ case 'completed':
558
+ clearInterval(progressInterval);
559
+ progressInterval = null;
560
+ updateProgress(100, 'Processing complete');
561
+ logStatus('Processing finished successfully.');
562
+ if (data.results) {
563
+ allDetections = Object.entries(data.results).map(([filename, detections]) => ({ filename, detections }));
564
+ updateResultsTable();
565
+ }
566
+ renderConfidencePlot();
567
+ setLoading(false);
568
+ break;
569
+ case 'error':
570
+ clearInterval(progressInterval);
571
+ progressInterval = null;
572
+ logStatus(`Error during processing: ${data.error || 'Unknown error'}`);
573
+ updateProgress(data.progress || 100, 'Error');
574
+ setLoading(false);
575
+ break;
576
+ case 'unknown':
577
+ clearInterval(progressInterval);
578
+ progressInterval = null;
579
+ logStatus('Unknown job status. Stopping progress updates.');
580
+ updateProgress(data.progress || 0, 'Unknown status');
581
+ setLoading(false);
582
+ break;
583
+ default:
584
+ logStatus(`Status: ${data.status}`);
585
+ updateProgress(data.progress || 0, data.status ? data.status.charAt(0).toUpperCase() + data.status.slice(1) : '');
586
+ break;
587
  }
588
 
589
  } catch (error) {
 
593
  updateProgress(0, 'Polling Error');
594
  setLoading(false);
595
  }
596
+ }, 1000); // Poll every 1 second
597
  }
598
 
599
  // --- UI Update Functions ---
 
676
  }
677
 
678
  // Results Display
679
+ function displayResults(jobStatus, filenameMap, resultsObj) {
 
 
 
 
 
 
 
 
 
 
 
 
680
  resultsTableBody.innerHTML = '';
681
  currentImageIndex = -1;
682
  currentSortField = null;
683
  currentSortDirection = 'asc';
684
 
685
+ // If job is not completed, show filenames only
686
+ if (jobStatus !== 'completed') {
687
+ Object.keys(filenameMap).forEach((filename, idx) => {
688
+ const row = resultsTableBody.insertRow();
689
+ row.innerHTML = `<td>${filename}</td><td>Pending</td>`;
690
+ });
691
+ exportCsvBtn.disabled = true;
692
+ exportImagesBtn.disabled = true;
693
+ logStatus('Waiting for job to complete...');
694
  return;
695
  }
696
 
697
+ // If job is completed, show filtered detection counts
698
+ if (resultsObj) {
699
+ Object.keys(resultsObj).forEach((filename, idx) => {
700
+ const detections = resultsObj[filename] || [];
701
+ // Filter by confidence threshold
702
+ const threshold = parseFloat(confidenceSlider.value);
703
+ const filtered = detections.filter(d => d.score >= threshold);
704
+ const row = resultsTableBody.insertRow();
705
+ row.innerHTML = `<td>${filename}</td><td>${filtered.length}</td>`;
706
+ });
707
+ exportCsvBtn.disabled = false;
708
+ exportImagesBtn.disabled = false;
709
+ logStatus('Job completed. Results displayed.');
710
+ } else {
711
+ logStatus('No results found.');
712
+ }
 
 
 
713
  }
714
 
715
  // Display a specific page of results
 
724
  for (let i = startIndex; i < endIndex; i++) {
725
  const result = currentResults[i];
726
  const row = resultsTableBody.insertRow();
727
+ row.innerHTML = `
728
+ <td>
729
+ <i class="ri-image-line"></i>
730
+ ${result.filename}
731
+ </td>
732
+ <td>${result.num_eggs}</td>
733
+ <td class="text-right">
734
+ <button class="view-button" data-index="${i}" title="Click to view image">
735
+ <i class="ri-eye-line"></i>
736
+ View
737
+ </button>
738
+ </td>
739
+ `;
740
+ // Store the original index to maintain image preview relationship
741
+ row.dataset.originalIndex = i;
742
+ }
743
+ // Wire up View buttons after rows are created
744
+ resultsTableBody.querySelectorAll('.view-button').forEach(btn => {
745
+ btn.addEventListener('click', (e) => {
746
+ e.stopPropagation();
747
+ const idx = parseInt(btn.getAttribute('data-index'));
748
+ displayImage(idx);
749
+ // Highlight selected row
750
+ resultsTableBody.querySelectorAll('tr').forEach(r => r.classList.remove('selected'));
751
+ btn.closest('tr').classList.add('selected');
752
  });
753
+ });
 
754
  // Update pagination UI
755
  updatePaginationControls();
756
  }
yolo_utils.py CHANGED
@@ -1,14 +1,7 @@
1
- from ultralytics import YOLO
2
- import cv2
3
- import numpy as np
4
 
5
- def load_model(weights_path):
6
- return YOLO(weights_path)
7
-
8
- def detect_image(model, image_bytes, conf=0.05):
9
- # image_bytes: bytes-like object (from Flask upload)
10
- arr = np.frombuffer(image_bytes, np.uint8)
11
- img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
12
  results = model.predict(img, imgsz=1440, max_det=1000, verbose=False, conf=conf)
13
  result = results[0]
14
  detections = []
 
1
+ from PIL import Image
 
 
2
 
3
+ def detect_in_image(model, im_path, conf=0.05):
4
+ img = Image.open(im_path)
 
 
 
 
 
5
  results = model.predict(img, imgsz=1440, max_det=1000, verbose=False, conf=conf)
6
  result = results[0]
7
  detections = []