import os
import uuid
import traceback
import sys
import io
import zipfile
import cv2
import csv
import pickle
import shutil
from ultralytics import YOLO
from torch import cuda
from flask import Flask, Response, render_template, request, jsonify, send_file, session
from multiprocessing.pool import ThreadPool
from pathlib import Path
from PIL import Image
from datetime import datetime
from werkzeug.utils import secure_filename
from yolo_utils import detect_in_image
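# detect_in_image(model, path) is assumed to return a JSON-serializable list of
# detections shaped like {'bbox': [x1, y1, x2, y2], 'score': float}; that shape is
# inferred from how results are filtered and drawn below, not from yolo_utils itself.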
app = Flask(__name__)
app.secret_key = os.environ.get('FLASK_SECRET_KEY', str(uuid.uuid4()))  # For session security
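# Note: the uuid4 fallback generates a fresh key on every restart, which invalidates
# all existing session cookies; set FLASK_SECRET_KEY in the environment for anything
# beyond local testing.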
APP_ROOT = Path(__file__).parent
UPLOAD_FOLDER = APP_ROOT / 'uploads'
RESULTS_FOLDER = APP_ROOT / 'results'
WEIGHTS_FILE = APP_ROOT / 'weights.pt'
app.config['UPLOAD_FOLDER'] = str(UPLOAD_FOLDER)
app.config['RESULTS_FOLDER'] = str(RESULTS_FOLDER)
app.config['ALLOWED_EXTENSIONS'] = {'png', 'jpg', 'jpeg', 'tif', 'tiff'}
UPLOAD_FOLDER.mkdir(parents=True, exist_ok=True)
RESULTS_FOLDER.mkdir(parents=True, exist_ok=True)
@app.errorhandler(Exception)
def handle_exception(e):
    print(f"Unhandled exception: {str(e)}")
    print(traceback.format_exc())
    return jsonify({"error": "Server error", "log": str(e)}), 500
def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in app.config['ALLOWED_EXTENSIONS']
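# e.g. allowed_file('plate_01.TIF') -> True, allowed_file('notes.txt') -> False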
@app.route('/')
def index():
    return render_template('index.html')
# Load model once at startup, use CUDA if available
MODEL_DEVICE = 'cuda' if cuda.is_available() else 'cpu'
_model = None
def get_model():
    global _model
    if _model is None:
        _model = YOLO(WEIGHTS_FILE)
        if MODEL_DEVICE == 'cuda':
            _model.to('cuda')
    return _model
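# get_model() lazily builds a single shared YOLO instance. There is no lock around the
# None check, so two racing first requests could each construct a model (the last
# assignment wins); wrap the check in a threading.Lock if that matters for deployment.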
def cleanup_session(session_id):
    # Remove files for this session
    upload_dir = Path(app.config['UPLOAD_FOLDER']) / session_id
    results_dir = Path(app.config['RESULTS_FOLDER']) / session_id
    for d in [upload_dir, results_dir]:
        if d.exists():
            shutil.rmtree(d)
# save the uploaded files
@app.route('/upload', methods=['POST'])  # route path assumed; adjust to match the frontend
def upload_files():
    session_id = session['id']
    files = request.files.getlist('files')
    upload_dir = Path(app.config['UPLOAD_FOLDER']) / session_id
    # clear out any existing files for the session
    if upload_dir.exists():
        shutil.rmtree(upload_dir)
    upload_dir.mkdir(parents=True, exist_ok=True)
    # generate new unique filenames via uuid, save the mapping dict of old:new to session
    filename_map = {}
    for f in files:
        orig_name = secure_filename(f.filename)
        # skip empty or disallowed files instead of saving them
        if not orig_name or not allowed_file(orig_name):
            continue
        ext = Path(orig_name).suffix
        unique_name = f"{uuid.uuid4().hex}{ext}"
        file_path = upload_dir / unique_name
        f.save(str(file_path))
        filename_map[orig_name] = unique_name
    session['filename_map'] = filename_map
    return jsonify({'filename_map': filename_map, 'status': 'uploaded'})
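# Illustrative response: {'filename_map': {'plate_01.png': '3f2a9c...e1.png'}, 'status': 'uploaded'}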
# helper to simplify args for pool.imap: imap passes exactly one object per task,
# so the per-image inputs are packed into a single tuple
def process_single_image(args):
    orig_name, img_path, pickle_path, model = args
    img_results = detect_in_image(model, str(img_path))
    with open(pickle_path, 'wb') as pf:
        pickle.dump(img_results, pf)
    return (orig_name, img_results)
@app.route('/process', methods=['POST'])  # path taken from the log messages below
def process_images():
    model = get_model()
    session_id = session['id']
    filename_map = session.get('filename_map', {})
    upload_dir = Path(app.config['UPLOAD_FOLDER']) / session_id
    state = {}
    state['status'] = 'starting'
    state['progress'] = 0
    state['filename_map'] = filename_map
    state['jobId'] = session_id
    session['job_state'] = state
    # create a results_dir, clean out old one if needed
    results_dir = Path(app.config['RESULTS_FOLDER']) / session_id
    if results_dir.exists():
        shutil.rmtree(results_dir)
    results_dir.mkdir(parents=True)
    # set up args list for imap
    n_img = len(filename_map)
    arg_list = [(orig_name,
                 upload_dir / filename_map[orig_name],
                 results_dir / f"{Path(orig_name).stem}_results.pkl",
                 model) for orig_name in filename_map.keys()]
    try:
        all_detections = {}
        state['status'] = 'processing'
        session['job_state'] = state
        if MODEL_DEVICE == 'cuda':
            # run serially on the GPU; a thread pool would just contend for the device
            for idx, args in enumerate(arg_list):
                orig_name, img_results = process_single_image(args)
                all_detections[orig_name] = img_results
                state['progress'] = int((idx + 1) / n_img * 100)
                session['job_state'] = state
        else:
            with ThreadPool() as pool:
                for idx, result in enumerate(pool.imap(process_single_image, arg_list)):
                    state['progress'] = int((idx + 1) / n_img * 100)
                    orig_name, img_results = result
                    all_detections[orig_name] = img_results
                    session['job_state'] = state
        # Save all detections to a pickled file
        detections_path = results_dir / 'all_detections.pkl'
        with open(detections_path, 'wb') as f:
            pickle.dump(all_detections, f)
        state['status'] = 'completed'
        state['progress'] = 100
        session['job_state'] = state
    except Exception as e:
        print(f"Error in /process: {e}")
        print(traceback.format_exc())
        state['status'] = 'error'
        state['error'] = str(e)
        state['progress'] = 100
        session['job_state'] = state
    resp = {
        'status': state.get('status', 'unknown'),
        'progress': state.get('progress', 0),
        'jobId': state.get('jobId'),
        'error': state.get('error'),
    }
    return jsonify(resp)
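# Caveat (based on the code as written): Flask's default session is cookie-based and is
# only serialized into the response, so the intermediate session['job_state'] updates
# inside /process never reach a polling client; /progress sees only the final state
# once /process returns. Live progress would need a server-side store, e.g. a
# module-level dict keyed by session id.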
# Support /progress/<jobId> for frontend polling
@app.route('/progress/<jobId>')
def get_progress_with_id(jobId):
    try:
        job_state = session.get('job_state')
        if not job_state:
            print(f"/progress/{jobId}: No job_state found in session.")
            return jsonify({"status": "error", "error": "No job state"}), 404
        resp = {
            'status': job_state.get('status', 'unknown'),
            'progress': job_state.get('progress', 0),
            'jobId': session.get('id'),
            'error': job_state.get('error'),
        }
        # If completed, load and return all_detections.pkl as JSON
        if job_state.get('status') == 'completed':
            session_id = session['id']
            detections_path = Path(app.config['RESULTS_FOLDER']) / session_id / 'all_detections.pkl'
            if detections_path.exists():
                with open(detections_path, 'rb') as f:
                    all_detections = pickle.load(f)
                resp['results'] = all_detections
        return jsonify(resp)
    except Exception as e:
        print(f"Error in /progress/{jobId}: {e}")
        print(traceback.format_exc())
        return jsonify({"status": "error", "error": str(e)}), 500
# /annotate route for dynamic annotation
@app.route('/annotate', methods=['POST'])
def annotate_image():
    try:
        data = request.get_json()
        filename = data.get('filename')
        confidence = float(data.get('confidence', 0.5))
        session_id = session['id']
        filename_map = session.get('filename_map', {})
        unique_name = filename_map.get(filename)
        if not unique_name:
            return jsonify({'error': 'File not found'}), 404
        # Load detections from pickle
        result_path = Path(app.config['RESULTS_FOLDER']) / session_id / f"{Path(filename).stem}_results.pkl"
        if not result_path.exists():
            return jsonify({'error': 'Results not found'}), 404
        with open(result_path, 'rb') as pf:
            detections = pickle.load(pf)
        # Load image
        img_path = Path(app.config['UPLOAD_FOLDER']) / session_id / unique_name
        img = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED)
        if img is None:
            return jsonify({'error': 'Could not read image'}), 500
        # Filter detections
        filtered = [d for d in detections if d.get('score', 0) >= confidence]
        # Draw boxes
        for det in filtered:
            x1, y1, x2, y2 = map(int, det['bbox'])
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 3)
        # Save the annotated image into this session's results dir so the
        # *_annotated.* glob in /export_images can find it
        annotated_dir = Path(app.config['RESULTS_FOLDER']) / session_id
        annotated_dir.mkdir(parents=True, exist_ok=True)
        out_name = f"{Path(filename).stem}_annotated.png"
        out_file = annotated_dir / out_name
        cv2.imwrite(str(out_file), img)
        # Serve image
        with open(out_file, 'rb') as f:
            return send_file(
                io.BytesIO(f.read()),
                mimetype='image/png',
                as_attachment=False,
                download_name=out_name
            )
    except Exception as e:
        print(f"Error in /annotate: {e}")
        return jsonify({'error': str(e)}), 500
@app.route('/download/<path:filename>')  # route and path converter assumed, given the traversal checks below
def download_file(filename):
    try:
        session_id = session['id']
        if '..' in filename or filename.startswith('/'):
            return jsonify({"error": "Invalid filename"}), 400
        safe_filename = secure_filename(filename)
        file_dir = Path(app.config['RESULTS_FOLDER']) / session_id
        file_path = (file_dir / safe_filename).resolve()
        if not str(file_path).startswith(str(file_dir.resolve())):
            print(f"Attempted path traversal: {session_id}/{filename}")
            return jsonify({"error": "Invalid file path"}), 400
        if not file_path.is_file():
            if not file_dir.exists():
                return jsonify({"error": f"Session directory {session_id} not found"}), 404
            files_in_dir = list(file_dir.iterdir())
            return jsonify({"error": f"File '{filename}' not found in session '{session_id}'. Available: {[f.name for f in files_in_dir]}"}), 404
        # TIF/TIFF files are converted to PNG so browsers can display them inline
        if filename.lower().endswith(('.tif', '.tiff')):
            try:
                with Image.open(file_path) as img:
                    img = img.convert('RGBA') if img.mode in ('RGBA', 'LA') or (img.mode == 'P' and 'transparency' in img.info) else img.convert('RGB')
                    img_byte_arr = io.BytesIO()
                    img.save(img_byte_arr, format='PNG')
                    img_byte_arr.seek(0)
                    return send_file(
                        img_byte_arr,
                        mimetype='image/png',
                        as_attachment=False,
                        download_name=f"{Path(filename).stem}.png"
                    )
            except Exception as e:
                print(f"Error converting TIF to PNG: {e}")
                return jsonify({"error": "Could not convert TIF image"}), 500
        mime_type = None
        if safe_filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            try:
                with Image.open(file_path) as img:
                    mime_type = 'image/jpeg' if img.format == 'JPEG' else 'image/png'
            except Exception as img_err:
                print(f"Could not determine MIME type for {safe_filename}: {img_err}")
        if safe_filename.lower() == "results.csv":
            mime_type = 'text/csv'
            return send_file(
                str(file_path),
                mimetype=mime_type,
                as_attachment=True,
                download_name=safe_filename
            )
        return send_file(str(file_path), mimetype=mime_type)
    except Exception as e:
        error_message = f"File serving error: {str(e)}"
        print(error_message)
        return jsonify({"error": "Server error", "log": error_message}), 500
@app.route('/export_images')  # path assumed from the function name
def export_images():
    try:
        session_id = session['id']
        job_dir = Path(app.config['RESULTS_FOLDER']) / session_id
        if not job_dir.exists():
            return jsonify({"error": f"Session directory {session_id} not found"}), 404
        annotated_files = list(job_dir.glob('*_annotated.*'))
        if not annotated_files:
            return jsonify({"error": "No annotated images found"}), 404
        memory_file = io.BytesIO()
        with zipfile.ZipFile(memory_file, 'w', zipfile.ZIP_DEFLATED) as zf:
            for file_path in annotated_files:
                zf.write(file_path, file_path.name)
        memory_file.seek(0)
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        return send_file(
            memory_file,
            mimetype='application/zip',
            as_attachment=True,
            download_name=f'nemaquant_annotated_{timestamp}.zip'
        )
    except Exception as e:
        error_message = f"Error exporting images: {str(e)}"
        print(error_message)
        return jsonify({"error": "Server error", "log": error_message}), 500
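# Note: export_images zips whatever *_annotated.* files already exist on disk, whereas
# export_images_post (further below) redraws every image at a caller-supplied threshold.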
@app.route('/export_csv', methods=['POST'])  # path assumed from the function name
def export_csv():
    try:
        data = request.json
        session_id = session['id']
        threshold = float(data.get('confidence', 0.5))
        # Load the saved detections from disk; the session cookie only carries job metadata
        detections_path = Path(app.config['RESULTS_FOLDER']) / session_id / 'all_detections.pkl'
        if not detections_path.exists():
            return jsonify({'error': 'Job not found'}), 404
        with open(detections_path, 'rb') as f:
            all_detections = pickle.load(f)
        rows = []
        for orig_name, detections in all_detections.items():
            count = sum(1 for d in detections if d['score'] >= threshold)
            rows.append({'Filename': orig_name, 'EggsDetected': count})
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output = io.StringIO()
        writer = csv.DictWriter(output, fieldnames=['Filename', 'EggsDetected'])
        writer.writeheader()
        writer.writerows(rows)
        output.seek(0)
        return Response(
            output.getvalue(),
            mimetype='text/csv',
            headers={
                'Content-Disposition': f'attachment; filename=nemaquant_results_{timestamp}.csv'
            }
        )
    except Exception as e:
        error_message = f"Error exporting CSV: {str(e)}"
        print(error_message)
        return jsonify({"error": "Server error", "log": error_message}), 500
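# Resulting CSV (illustrative):
#   Filename,EggsDetected
#   plate_01.png,14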
@app.route('/export_images', methods=['POST'])  # shares the path above; method assumed from the function name
def export_images_post():
    try:
        data = request.json
        session_id = session['id']
        threshold = float(data.get('confidence', 0.5))
        job_state = session.get('job_state')
        if not job_state:
            return jsonify({'error': 'Job not found'}), 404
        # Load the saved detections from disk (same pattern as /export_csv)
        detections_path = Path(app.config['RESULTS_FOLDER']) / session_id / 'all_detections.pkl'
        if not detections_path.exists():
            return jsonify({'error': 'Job not found'}), 404
        with open(detections_path, 'rb') as f:
            all_detections = pickle.load(f)
        memory_file = io.BytesIO()
        with zipfile.ZipFile(memory_file, 'w', zipfile.ZIP_DEFLATED) as zf:
            for orig_name, detections in all_detections.items():
                unique_name = job_state['filename_map'][orig_name]
                img_path = Path(app.config['UPLOAD_FOLDER']) / session_id / unique_name
                img = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED)
                if img is None:
                    continue  # uploaded file missing or unreadable; skip it
                filtered = [d for d in detections if d['score'] >= threshold]
                for det in filtered:
                    x1, y1, x2, y2 = map(int, det['bbox'])
                    cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 3)
                out_name = f"{Path(orig_name).stem}.png"
                _, img_bytes = cv2.imencode('.png', img)
                zf.writestr(out_name, img_bytes.tobytes())
        memory_file.seek(0)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        return send_file(
            memory_file,
            mimetype='application/zip',
            as_attachment=True,
            download_name=f'nemaquant_annotated_{timestamp}.zip'
        )
    except Exception as e:
        error_message = f"Error exporting images: {str(e)}"
        print(error_message)
        return jsonify({"error": "Server error", "log": error_message}), 500
def print_startup_info():
    print("----- NemaQuant Flask App Starting -----")
    print(f"Working directory: {os.getcwd()}")
    python_version_single_line = sys.version.replace('\n', ' ')
    print(f"Python version: {python_version_single_line}")
    print(f"Weights file: {WEIGHTS_FILE}")
    print(f"Weights file exists: {WEIGHTS_FILE.exists()}")
    if WEIGHTS_FILE.exists():
        try:
            print(f"Weights file size: {WEIGHTS_FILE.stat().st_size} bytes")
        except Exception as e:
            print(f"Could not get weights file size: {e}")
    is_container = Path('/.dockerenv').exists() or 'DOCKER_HOST' in os.environ
    print(f"Running in container: {is_container}")
    if is_container:
        try:
            user_info = f"{os.getuid()}:{os.getgid()}"
            print(f"User running process: {user_info}")
        except AttributeError:
            print("User running process: UID/GID not available on this OS")
        for path_str in ["/app/uploads", "/app/results"]:
            path_obj = Path(path_str)
            if path_obj.exists():
                stat_info = path_obj.stat()
                permissions = oct(stat_info.st_mode)[-3:]
                owner = f"{stat_info.st_uid}:{stat_info.st_gid}"
                print(f"Permissions for {path_str}: {permissions}")
                print(f"Owner for {path_str}: {owner}")
            else:
                print(f"Directory {path_str} does not exist.")
    nemaquant_script = APP_ROOT / 'nemaquant.py'
    print(f"NemaQuant script exists: {nemaquant_script.exists()}")
    if nemaquant_script.exists():
        try:
            permissions = oct(nemaquant_script.stat().st_mode)[-3:]
            print(f"NemaQuant script permissions: {permissions}")
        except Exception as e:
            print(f"Could not get NemaQuant script details: {e}")
@app.before_request
def ensure_session_id():
    if 'id' not in session:
        session['id'] = str(uuid.uuid4())
# Explicit endpoint for safe session cleanup
@app.route('/cleanup', methods=['POST'])  # path assumed from the function's role
def cleanup_endpoint():
    if 'id' in session:
        cleanup_session(session['id'])
        session.clear()
        return jsonify({'status': 'cleaned up'})
    return jsonify({'error': 'No session to clean up'}), 400
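# Typical client flow against the routes registered above (illustrative):
#   POST /upload (multipart 'files')           -> {'filename_map': ...}
#   POST /process                              -> runs detection, returns final status
#   GET  /progress/<jobId>                     -> poll; includes 'results' once completed
#   POST /annotate {'filename', 'confidence'}  -> annotated PNG for one image
#   POST /export_csv {'confidence'}            -> per-image egg counts as CSV
#   POST /cleanup                              -> delete this session's files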
if __name__ == '__main__':
    print_startup_info()
    # 0.0.0.0:7860 is the standard binding for Hugging Face Spaces; debug=True is for development only
    app.run(host='0.0.0.0', port=7860, debug=True)