import json
import cv2
import torch
import torchvision.transforms as transforms
import torch.multiprocessing as mp
import numpy as np
import glob
from tqdm import tqdm
from multiprocessing import Process, Manager
import argparse
import matplotlib
import sys, os
import time
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.utils import save_json_entry, convert_json_line_to_general, find_optimal_thread_count, convert_to_mp4, load_json_any, save_json_any
from Depth_Anything_V2.depth_anything_v2.dpt import DepthAnythingV2

def parse_config(argv=None):
    """
    Parse the command-line options of the video-to-depth script.

    Parameters:
        argv (list[str] | None): Arguments to parse, without the program name.
            ``None`` (the default) parses ``sys.argv[1:]``, as argparse does.

    Returns:
        dict: A dictionary containing the following keys:
            - input_video_root (str): root directory that relative video paths are resolved against
            - input_json_path (str): JSON file listing the videos to process
            - output_video_root (str): directory that receives the depth videos
            - output_json_path (str): JSON file that processed entries are appended to
            - max_threads (int): maximum number of worker processes
            - thread_threshold (int): threshold forwarded to find_optimal_thread_count
    """
    parser = argparse.ArgumentParser(
        description="Multi-process Depth Anything V2 video-to-depth processing script"
    )
    parser.add_argument(
        "--input_video_root",
        type=str,
        default="HDV_dataset/HDV_original",
        help="Root directory of input videos (default: HDV_dataset/HDV_original)",
    )
    parser.add_argument(
        "--input_json_path",
        type=str,
        default="HDV_dataset/info_input.json",
        help="Path to input JSON file (default: HDV_dataset/info_input.json)",
    )
    parser.add_argument(
        "--output_video_root",
        type=str,
        default="HDV_dataset/HDV_depth",
        help="Root directory for output videos (default: HDV_dataset/HDV_depth)",
    )
    parser.add_argument(
        "--output_json_path",
        type=str,
        default="HDV_dataset/info_output.json",
        help="Path to output JSON file (default: HDV_dataset/info_output.json)",
    )
    parser.add_argument(
        "--max_threads",
        type=int,
        default=8,
        help="Maximum number of worker processes (default: 8)",
    )
    parser.add_argument(
        "--thread_threshold",
        type=int,
        default=8,
        help="Threshold forwarded to find_optimal_thread_count when choosing "
             "the number of worker processes (default: 8)",
    )

    args = parser.parse_args(argv)

    config = {
        "input_video_root": args.input_video_root,
        "input_json_path": args.input_json_path,
        "output_video_root": args.output_video_root,
        "output_json_path": args.output_json_path,
        "max_threads": args.max_threads,
        "thread_threshold": args.thread_threshold,
    }
    return config


def _depth_to_uint8(depth):
    """
    Min-max normalize a depth map to uint8 values in [0, 255].

    A constant depth map (max == min) would otherwise divide by zero and yield NaNs,
    so it is mapped to all zeros. Values are truncated (not rounded) by ``astype``.
    """
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range == 0:
        return np.zeros(depth.shape, dtype=np.uint8)
    return ((depth - depth_min) / depth_range * 255.0).astype(np.uint8)


def _colorize_depth(depth_u8, cmap, grayscale):
    """
    Convert a single-channel uint8 depth map into a 3-channel uint8 BGR image.

    With ``grayscale`` the channel is replicated three times; otherwise ``cmap`` is
    applied and its RGB channels are reordered to BGR for OpenCV.
    """
    if grayscale:
        return np.repeat(depth_u8[..., np.newaxis], 3, axis=-1)
    # matplotlib colormaps return RGBA floats in [0, 1]: drop alpha, scale, flip RGB -> BGR.
    return (cmap(depth_u8)[:, :, :3] * 255)[:, :, ::-1].astype(np.uint8)


def depth_process_video_rgb(video_path, outdir, device, input_size=518, encoder='vits', pred_only=True, grayscale=False,
                            checkpoint_path=None):
    """
    Run Depth Anything V2 on a video and write a depth-visualization video.

    Parameters:
    - video_path (str): Path to the input video file. If it is not a file, it is treated
      as a directory and searched recursively. The first file that OpenCV can open is
      converted; there is only one output path, so at most one video is written.
    - outdir (str): Path of the output .mp4 file (a file path, despite the name).
      Its parent directory is created if needed.
    - device (str): Device for inference, e.g. "cuda:0" or "cpu"; passed to ``.to()``
      and ``infer_image``.
    - input_size (int): Size of the input image for depth inference.
    - encoder (str): Type of encoder to use ('vits', 'vitb', 'vitl', 'vitg').
    - pred_only (bool): If True, write only the depth frames. Otherwise each output frame is
      raw frame | 50 px white margin | depth.
    - grayscale (bool): If True, use grayscale for depth visualization. Otherwise use the
      'Spectral_r' colormap.
    - checkpoint_path (str | None): Path to the model weights. ``None`` uses
      ``video2depth/Depth_Anything_V2/checkpoints/depth_anything_v2_{encoder}.pth``.

    Returns:
    - bool: True if at least one frame was written. Returns False if no input could be
      opened, the output writer could not be opened, or no frame could be read.
    """
    if os.path.isfile(video_path):
        candidates = [video_path]
    else:
        # glob also returns sub-directories, which can never be opened as videos.
        candidates = sorted(
            path for path in glob.glob(os.path.join(video_path, '**/*'), recursive=True)
            if os.path.isfile(path)
        )
    if not candidates:
        return False

    model_configs = {
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
        'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
    }

    if checkpoint_path is None:
        # NOTE(review): relative to the current working directory, so the default only
        # resolves when the script is launched from the directory containing video2depth/.
        checkpoint_path = f'video2depth/Depth_Anything_V2/checkpoints/depth_anything_v2_{encoder}.pth'

    depth_anything = DepthAnythingV2(**model_configs[encoder])
    depth_anything.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    depth_anything = depth_anything.to(device).eval()

    out_parent = os.path.dirname(outdir)
    if out_parent:  # os.makedirs('') raises for a bare file name in the CWD
        os.makedirs(out_parent, exist_ok=True)

    margin_width = 50
    cmap = matplotlib.colormaps.get_cmap('Spectral_r')

    for filename in candidates:
        raw_video = cv2.VideoCapture(filename)
        if not raw_video.isOpened():
            # Not a readable video: try the next candidate instead of writing an empty file.
            raw_video.release()
            continue

        frame_width = int(raw_video.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(raw_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates such as 29.97; int() would change the output's playback speed.
        frame_rate = raw_video.get(cv2.CAP_PROP_FPS)

        output_width = frame_width if pred_only else frame_width * 2 + margin_width
        out = cv2.VideoWriter(outdir, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (output_width, frame_height))

        frames_written = 0
        try:
            if not out.isOpened():
                return False
            # The white separator is identical for every frame, so build it once.
            split_region = None if pred_only else np.full((frame_height, margin_width, 3), 255, dtype=np.uint8)

            while True:
                ret, raw_frame = raw_video.read()
                if not ret:
                    break

                depth = depth_anything.infer_image(raw_frame, device, input_size)
                depth_vis = _colorize_depth(_depth_to_uint8(depth), cmap, grayscale)

                if pred_only:
                    out.write(depth_vis)
                else:
                    out.write(cv2.hconcat([raw_frame, split_region, depth_vis]))
                frames_written += 1
        finally:
            # Release both handles even if inference raises mid-video.
            raw_video.release()
            out.release()

        # Only one output path exists, so stop after the first readable video.
        return frames_written > 0

    return False



def process_videos_in_batch(batch, input_video_root, output_video_root, output_data, gpu_id):
    """
    Process a batch of videos, explicitly pinned to one GPU.

    Parameters:
    - batch (list[dict]): Entries from the input JSON. Each entry is expected to carry a
      "video_clip_path" relative to ``input_video_root``.
    - input_video_root (str): Root directory that "video_clip_path" is resolved against.
    - output_video_root (str): Directory that receives "<stem>_depth.mp4" files.
    - output_data: Passed through to ``save_json_entry``; main() passes the output JSON path.
    - gpu_id (int): CUDA device index. Falls back to "cpu" when CUDA is unavailable.

    On success an entry gets a "video_depth_path" key and is saved with
    ``save_json_entry``. Failures are printed and do not stop the rest of the batch.
    """
    device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
    for item in tqdm(batch, desc=f"video to depth on GPU {gpu_id}"):
        rel_path = item.get("video_clip_path", "")
        if not rel_path:
            # An empty path would resolve to input_video_root itself (a directory) and
            # the output name would degenerate to "_depth.mp4".
            print(f"Skipping entry without video_clip_path: id={item.get('id')}")
            continue

        input_path = os.path.join(input_video_root, rel_path)
        base_name = os.path.splitext(os.path.basename(rel_path))[0]
        out_name = f"{base_name}_depth.mp4"
        output_path = os.path.join(output_video_root, out_name)

        try:
            success = depth_process_video_rgb(input_path, output_path, device)
            if success:
                item["video_depth_path"] = output_path
                save_json_entry(item, output_data)
            else:
                print(f"Failed to process video: {input_path}")
        except Exception as e:
            # Worker-level boundary: report and continue with the next video.
            print(f"Error processing video {input_path} on GPU {gpu_id}: {e}")



def main():
    """
    Entry point: compute depth videos for all input entries not yet in the output JSON.

    Steps:
    - Read the input and output JSON files.
    - Skip input entries whose "id" already appears in the output.
    - Split the remaining entries into batches and run each batch in its own process,
      assigning GPUs round-robin.
    - After all workers exit, pass the output JSON path to ``convert_json_line_to_general``.

    Raises:
        RuntimeError: If no CUDA GPU is detected.
    """
    # 1. Parse configuration
    config = parse_config()
    input_video_root = config["input_video_root"]
    input_json_path = config["input_json_path"]
    output_video_root = config["output_video_root"]
    output_json_path = config["output_json_path"]
    max_threads = config["max_threads"]
    thread_threshold = config["thread_threshold"]

    # Detect the number of available GPUs
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        raise RuntimeError("No GPUs available!")
    print(f"Detected {gpu_count} GPUs.")

    # 2. Load input JSON and output JSON (the line-based format flags are not needed here)
    input_data, _ = load_json_any(input_json_path)
    output_data, _ = load_json_any(output_json_path)

    # Resume support: entries whose "id" is already in the output JSON are skipped.
    processed_ids = {item["id"] for item in output_data if "id" in item}
    to_process_list = [item for item in input_data if item.get("id") not in processed_ids]

    total_input = len(input_data)
    total_output = len(output_data)
    need_to_process = len(to_process_list)

    print(f"Total samples in input JSON: {total_input}")
    print(f"Total samples in output JSON: {total_output}")
    print(f"Number of samples to process: {need_to_process}")

    if need_to_process == 0:
        print("No new video to process. Exiting.")
        return

    threads_to_use = find_optimal_thread_count(need_to_process, max_threads, thread_threshold)
    # Guard: fewer than one process would make the batch-size computation divide by zero.
    threads_to_use = max(1, threads_to_use)
    print(f"Number of processes to use: {threads_to_use}")

    # 3. Split the tasks into at most threads_to_use contiguous batches (ceiling division)
    batch_size = (need_to_process + threads_to_use - 1) // threads_to_use
    batches = [to_process_list[i:i + batch_size] for i in range(0, need_to_process, batch_size)]

    processes = []
    start_time = time.time()

    # One worker process per batch; GPUs are assigned round-robin.
    for i, batch in enumerate(batches):
        gpu_id = i % gpu_count
        p = Process(target=process_videos_in_batch, args=(batch, input_video_root, output_video_root, output_json_path, gpu_id))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    # Workers catch per-video exceptions, so a non-zero exit code means the process died
    # (e.g. was killed) and part of its batch may be missing. Entries it already saved
    # are skipped on a re-run.
    for i, p in enumerate(processes):
        if p.exitcode != 0:
            print(f"Warning: worker {i} (GPU {i % gpu_count}) exited with code {p.exitcode}; "
                  f"some videos in its batch may not have been processed.")

    # NOTE(review): presumably rewrites the line-based entries appended by the workers
    # into the regular JSON format; confirm in utils.utils.
    convert_json_line_to_general(output_json_path)
    end_time = time.time()
    print("video to depth time:", end_time - start_time)



if __name__ == "__main__":
    # Workers use CUDA, and PyTorch does not support CUDA in fork()ed children,
    # so worker processes must be started with 'spawn'.
    mp.set_start_method('spawn', force=True)

    # Developer defaults: applied only when the script is launched without any
    # command-line arguments, so explicitly passed arguments are honoured.
    if len(sys.argv) == 1:
        sys.argv = [sys.argv[0],
                    '--input_video_root', '/home/yexin/data_processing/HDV_dataset/HDV_clip',
                    '--input_json_path', '/home/yexin/data_processing/HDV_dataset/data_json/test.json',
                    '--output_video_root', '/home/yexin/data_processing/HDV_dataset/HDV_depth',
                    '--output_json_path', '/home/yexin/data_processing/HDV_dataset/data_json/test_depth.json',
                    '--max_threads', '4',
                    '--thread_threshold', '2']

    main()

'''
Example command-line invocation:

python /home/yexin/data_processing/HDV_Data_Processing/video2depth/v2depth.py \
    --input_video_root /home/yexin/data_processing/HDV_dataset/HDV_clip \
    --input_json_path /home/yexin/data_processing/HDV_dataset/data_json/test.json \
    --output_video_root /home/yexin/data_processing/HDV_dataset/HDV_depth \
    --output_json_path /home/yexin/data_processing/HDV_dataset/data_json/test_depth.json \
    --max_threads 4 \
    --thread_threshold 2
'''
