import json
import cv2
import sys, os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import torch
import torchvision.transforms as transforms
import torch.multiprocessing as mp
import numpy as np
import glob
from tqdm import tqdm
from multiprocessing import Process, Manager
import argparse
import matplotlib
import time
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.utils import save_json_entry, convert_json_line_to_general, find_optimal_thread_count, convert_to_mp4, load_json_any, save_json_any
from argparse import Namespace
from RAFT.core.raft import RAFT
from RAFT.core.utils.utils import InputPadder
from RAFT.core.utils import flow_viz

def parse_config(argv=None):
    """
    Parse the command-line options of the optical-flow extraction script.

    Parameters:
        argv (list[str] | None): Arguments to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        dict: A dictionary containing the following keys:
            - input_video_root (str)
            - input_json_path (str)
            - output_video_root (str)
            - output_json_path (str)
            - max_threads (int)
            - thread_threshold (int)
            - gpu_max_threads (int)
    """
    parser = argparse.ArgumentParser(
        description="Multi-process RAFT optical flow video processing script"
    )
    parser.add_argument(
        "--input_video_root",
        type=str,
        default="HDV_dataset/HDV_original",
        help="Root directory of input videos (default: HDV_dataset/HDV_original)",
    )
    parser.add_argument(
        "--input_json_path",
        type=str,
        default="HDV_dataset/info_input.json",
        help="Path to input JSON file (default: HDV_dataset/info_input.json)",
    )
    parser.add_argument(
        "--output_video_root",
        type=str,
        default="HDV_dataset/HDV_optical_flow",
        help="Root directory for output videos (default: HDV_dataset/HDV_optical_flow)",
    )
    parser.add_argument(
        "--output_json_path",
        type=str,
        default="HDV_dataset/info_output.json",
        help="Path to output JSON file (default: HDV_dataset/info_output.json)",
    )
    parser.add_argument(
        "--max_threads",
        type=int,
        default=8,
        help="Maximum number of processes (default: 8)",
    )
    parser.add_argument(
        "--thread_threshold",
        type=int,
        default=8,
        help="Maximum number of single process (default: 8)",
    )
    parser.add_argument(
        "--gpu_max_threads",
        type=int,
        default=4,
        help="Maximum number of threads per GPU (default: 4)",
    )

    args = parser.parse_args(argv)

    # The parser defines exactly the seven options documented above, so the
    # namespace converts directly into the config dict.
    return dict(vars(args))

def _load_raft_model(model_path, device, use_small_model):
    """Build a RAFT network, load checkpoint weights onto `device`, and return it in eval mode."""
    args = Namespace(
        small=use_small_model,
        mixed_precision=False,
        alternate_corr=False
    )
    model = RAFT(args)
    # map_location places the weights on this worker's device instead of the
    # device recorded inside the checkpoint, so each GPU process uses its own GPU.
    state_dict = torch.load(model_path, map_location=device)
    # Checkpoints saved from nn.DataParallel prefix every key with 'module.'.
    # Stripping it lets the bare model run on a single GPU or on CPU, with no
    # DataParallel wrapper and no need to parse a GPU index out of the device string.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)
    return model.to(device).eval()


def optical_flow_process_video(video_path, outdir, device, model_path='video2optical_flow/RAFT/models/raft-small.pth', use_small_model=True, iters=20):
    """
    Compute RAFT optical flow between consecutive frames of a video and write
    the colour-coded flow visualisation as a new video.

    Parameters:
    - video_path (str): Path to the input video file.
    - outdir (str): Path of the output video FILE (despite the name). Its parent
      directory is created if needed.
    - device (str): Torch device string, e.g. 'cuda:0' or 'cpu'.
    - model_path (str): Path to the RAFT model checkpoint. Relative paths are
      resolved against the current working directory.
    - use_small_model (bool): Whether to build the small RAFT variant. This must
      match the checkpoint.
    - iters (int): Number of RAFT refinement iterations per frame pair.

    Returns:
    - bool: True once the output video has been fully written.

    Raises:
    - ValueError: if the input video cannot be opened, its first frame cannot be
      read, or the output video writer cannot be opened.
    """
    torch_device = torch.device(device)
    # NOTE(review): the model is rebuilt and reloaded on every call, i.e. once per video.
    model = _load_raft_model(model_path, torch_device, use_small_model)

    def load_image_from_frame(frame):
        # OpenCV decodes frames as BGR, while RAFT expects RGB input. Convert to a
        # float (1, C, H, W) tensor on the target device.
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(rgb).permute(2, 0, 1).float()
        return img[None].to(torch_device)

    raw_video = cv2.VideoCapture(video_path)
    if not raw_video.isOpened():
        raise ValueError(f"Error: Unable to open video file {video_path}")

    out = None
    try:
        # Validate the first frame before touching it or creating any output file.
        ret, prev_frame = raw_video.read()
        if not ret:
            raise ValueError("Error: Unable to read the first frame of the video.")

        # Keep the FPS as a float. Truncating e.g. 29.97 to 29 makes the output
        # drift out of sync with the source video.
        frame_rate = raw_video.get(cv2.CAP_PROP_FPS)

        prev_frame_tensor = load_image_from_frame(prev_frame)
        padder = InputPadder(prev_frame_tensor.shape)
        prev_frame_tensor = padder.pad(prev_frame_tensor)[0]
        # Flow is computed and written at the padded resolution produced by InputPadder.
        frame_height = int(prev_frame_tensor.shape[-2])
        frame_width = int(prev_frame_tensor.shape[-1])

        out_dir = os.path.dirname(outdir)
        if out_dir:  # os.makedirs('') raises, e.g. for a bare file name
            os.makedirs(out_dir, exist_ok=True)
        out = cv2.VideoWriter(outdir, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (frame_width, frame_height))
        if not out.isOpened():
            raise ValueError(f"Error: Unable to open video writer for {outdir}")

        with torch.no_grad():
            while True:
                ret, curr_frame = raw_video.read()
                if not ret:
                    break

                curr_frame_tensor = padder.pad(load_image_from_frame(curr_frame))[0]

                # Compute optical flow; only the upsampled (full-resolution) flow is used.
                _, flow_up = model(prev_frame_tensor, curr_frame_tensor, iters=iters, test_mode=True)

                # Visualise the optical flow as an RGB image, then convert to BGR for OpenCV.
                flo = flow_viz.flow_to_image(flow_up[0].permute(1, 2, 0).cpu().numpy())
                out.write(flo[:, :, [2, 1, 0]].astype(np.uint8))

                prev_frame_tensor = curr_frame_tensor
    finally:
        # Always release native handles, even if decoding or inference fails midway.
        raw_video.release()
        if out is not None:
            out.release()
    return True


from concurrent.futures import ThreadPoolExecutor

def process_videos_in_batch(batch, input_video_root, output_video_root, output_data, gpu_id, gpu_max_threads):
    """
    Process a batch of videos on one explicitly chosen GPU, running several
    videos of the batch concurrently on that GPU with a thread pool.

    Parameters:
    - batch (list[dict]): Items from the input JSON. Each is expected to carry a
      'video_clip_path' relative to `input_video_root`.
    - input_video_root (str): Root directory of the input videos.
    - output_video_root (str): Directory that receives '<basename>_optical_flow.mp4' files.
    - output_data: Forwarded unchanged to `save_json_entry` for every successfully
      processed item (main() passes the output JSON path here).
    - gpu_id (int): CUDA device index to use. Falls back to 'cpu' when CUDA is unavailable.
    - gpu_max_threads (int): Upper bound on concurrently processed videos for this GPU.
    """
    import threading  # local import keeps this block self-contained

    if not batch:
        # ThreadPoolExecutor(max_workers=0) would raise, so an empty batch is a no-op.
        print(f'No videos assigned to GPU {gpu_id}; nothing to do')
        return

    device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
    # Several worker threads call save_json_entry on the same output, so serialise
    # those calls within this process. This does NOT guard against the other GPU
    # worker processes, which receive the same output_data.
    save_lock = threading.Lock()

    def process_single_video(item):
        """Compute the optical-flow video for one item and record it on success."""
        rel_path = item.get("video_clip_path", "")
        if not rel_path:
            print(f"Skipping item without 'video_clip_path': {item.get('id')}")
            return
        input_path = os.path.join(input_video_root, rel_path)
        base_name = os.path.splitext(os.path.basename(rel_path))[0]
        # NOTE(review): only the basename is kept, so clips sharing a file name in
        # different sub-directories overwrite each other's output.
        output_path = os.path.join(output_video_root, f"{base_name}_optical_flow.mp4")

        try:
            success = optical_flow_process_video(input_path, output_path, device)
            if success:
                item["video_optical_flow_path"] = output_path
                with save_lock:
                    save_json_entry(item, output_data)
            else:
                print(f"Failed to process video: {input_path}")
        except Exception as e:
            # Best effort: one broken video must not abort the rest of the batch.
            print(f"Error processing video {input_path} on GPU {gpu_id}: {e}")

    # Process several videos of this batch in parallel on the same GPU.
    max_workers = max(1, min(len(batch), gpu_max_threads))
    print(f'Allocate {max_workers} threads for GPU {gpu_id} to process {len(batch)} videos')
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        list(tqdm(executor.map(process_single_video, batch), total=len(batch), desc=f"Batch on GPU {gpu_id}"))


def main():
    """
    Generate optical-flow videos for every input item whose "id" is not yet
    present in the output JSON. The work is split evenly across all visible
    GPUs, with one worker process per GPU.

    Raises:
    - RuntimeError: if no CUDA GPU is visible.
    """
    # 1. Parse configuration
    config = parse_config()
    input_video_root = config["input_video_root"]
    input_json_path = config["input_json_path"]
    output_video_root = config["output_video_root"]
    output_json_path = config["output_json_path"]
    gpu_max_threads = config["gpu_max_threads"]
    # NOTE: --max_threads / --thread_threshold are still accepted, but scheduling
    # below is fixed at one process per GPU, each with up to gpu_max_threads
    # concurrent videos. Those two options therefore do not affect the run.

    # Detect the number of available GPUs
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        raise RuntimeError("No GPUs available!")
    print(f"Detected {gpu_count} GPUs.")

    # 2. Load input and output JSON. Items already in the output (matched by "id") are skipped.
    input_data, _ = load_json_any(input_json_path)
    output_data, _ = load_json_any(output_json_path)

    processed_ids = {item["id"] for item in output_data if "id" in item}
    to_process_list = [item for item in input_data if item.get("id") not in processed_ids]
    need_to_process = len(to_process_list)

    print(f"Total samples in input JSON: {len(input_data)}")
    print(f"Total samples in output JSON: {len(output_data)}")
    print(f"Number of samples to process: {need_to_process}")

    if need_to_process == 0:
        print("No new video to process. Exiting.")
        return

    # Ceiling division: at most gpu_count non-empty batches. There are fewer
    # batches when there are fewer items than GPUs.
    batch_size = (need_to_process + gpu_count - 1) // gpu_count
    batches = [to_process_list[i:i + batch_size] for i in range(0, need_to_process, batch_size)]
    print(f"Number of processes to use: {len(batches)}")

    processes = []
    start_time = time.time()

    # 3. Dispatch one batch per GPU, each to its own worker process
    for gpu_id, batch in enumerate(batches):
        p = mp.Process(
            target=process_videos_in_batch,
            args=(batch, input_video_root, output_video_root, output_json_path, gpu_id, gpu_max_threads)
        )
        p.start()
        processes.append(p)

    # Wait for all worker processes to finish
    for p in processes:
        p.join()

    end_time = time.time()
    print(f"Processing completed in {end_time - start_time:.2f} seconds.")

    # Per-video errors are caught inside the workers, so a non-zero exit code means
    # a whole worker crashed and the rest of its batch was not processed.
    failed_gpus = [gpu_id for gpu_id, p in enumerate(processes) if p.exitcode != 0]
    if failed_gpus:
        print(f"Warning: worker process(es) for GPU(s) {failed_gpus} exited abnormally; "
              f"their remaining videos were not processed.")



if __name__ == "__main__":
    # CUDA does not support the 'fork' start method in subprocesses, so GPU
    # worker processes must be spawned.
    mp.set_start_method('spawn', force=True)

    # Development fallback: when the script is launched with no command-line
    # arguments, use these hard-coded paths. Explicit arguments always win.
    if len(sys.argv) == 1:
        sys.argv = [sys.argv[0],
                    '--input_video_root', '/home/yexin/data_processing/HDV_dataset/HDV_clip',
                    '--input_json_path', '/home/yexin/data_processing/HDV_dataset/data_json/test.json',
                    '--output_video_root', '/home/yexin/data_processing/HDV_dataset/HDV_optical_flow',
                    '--output_json_path', '/home/yexin/data_processing/HDV_dataset/data_json/test_optical_flow.json',
                    '--max_threads', '1',
                    '--thread_threshold', '1',
                    '--gpu_max_threads', '2']

    main()

'''
Example command-line invocation:

python /home/yexin/data_processing/HDV_Data_Processing/video2optical_flow/v2optical_flow.py \
    --input_video_root /home/yexin/data_processing/HDV_dataset/HDV_clip \
    --input_json_path /home/yexin/data_processing/HDV_dataset/data_json/test.json \
    --output_video_root /home/yexin/data_processing/HDV_dataset/HDV_optical_flow \
    --output_json_path /home/yexin/data_processing/HDV_dataset/data_json/test_optical_flow.json \
    --max_threads 4 \
    --thread_threshold 2 \
    --gpu_max_threads 2
'''
