import sys
sys.path.append('core')

import argparse
import os
# NOTE(review): CUDA / PyTorch select GPUs through CUDA_VISIBLE_DEVICES; this
# lowercase 'visible_devices' key is presumably ignored, so the line likely has
# no effect on which GPU is used. Confirm the intent before renaming it:
# CUDA_VISIBLE_DEVICES='1' would hide GPU 0 and fail on single-GPU hosts.
os.environ['visible_devices'] = '1'
import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder

# Device used for the model and all frame tensors; falls back to the CPU when
# CUDA is unavailable instead of crashing on the first `.to('cuda')`.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

def load_image_from_frame(frame, device=None):
    """Convert an OpenCV video frame into a batched RAFT input tensor.

    Args:
        frame: ``H x W x 3`` array as returned by ``cv2.VideoCapture.read``
            (channels in BGR order).
        device: target device for the tensor; ``None`` (the default) uses the
            module-level ``DEVICE``.

    Returns:
        float32 tensor of shape ``(1, 3, H, W)`` in RGB channel order, with the
        frame's original value range (no normalization is applied here).
    """
    if device is None:
        device = DEVICE
    # OpenCV decodes to BGR, but RAFT's reference pipeline feeds RGB images
    # (and the visualization later converts RGB -> BGR for writing). Reverse
    # the channel axis; the contiguous copy is needed because torch.from_numpy
    # does not accept arrays with negative strides.
    rgb = np.ascontiguousarray(frame[:, :, ::-1])
    img = torch.from_numpy(rgb).permute(2, 0, 1).float()
    # Add the leading batch dimension expected by the model.
    return img[None].to(device)

def viz_and_write(writer, img, flo):
    """Color-code a flow field and append it as one frame to ``writer``.

    Args:
        writer: an opened ``cv2.VideoWriter`` whose frame size must match the
            spatial size of ``flo``.
        img: first image of the frame pair. Currently unused and kept only for
            call compatibility; it is not written to the video.
        flo: batched flow tensor; assumes shape ``(1, 2, H, W)`` — TODO
            confirm. Only the first batch element is visualized.
    """
    # ``img`` is intentionally not moved to the host: only the flow is written,
    # so copying the frame off the device every iteration was wasted work.
    flow_hw2 = flo[0].permute(1, 2, 0).cpu().numpy()

    # Map flow to an RGB image.
    flow_rgb = flow_viz.flow_to_image(flow_hw2)

    # Reorder RGB -> BGR for OpenCV; fancy indexing also yields a fresh,
    # contiguous array that VideoWriter.write accepts.
    flow_bgr = flow_rgb[:, :, [2, 1, 0]].astype(np.uint8)

    writer.write(flow_bgr)

def _load_model(args):
    """Build RAFT from ``args``, restore ``args.model`` and return it ready for inference.

    The state dict is loaded into a ``DataParallel`` wrapper, so checkpoint keys
    are expected to carry the wrapper's ``module.`` prefix. The bare module is
    returned on ``DEVICE`` in eval mode.
    """
    model = torch.nn.DataParallel(RAFT(args))
    # map_location remaps stored tensors onto DEVICE, so a checkpoint saved on
    # a GPU can also be restored on another device (e.g. CPU).
    model.load_state_dict(torch.load(args.model, map_location=DEVICE))
    model = model.module
    model.to(DEVICE)
    model.eval()
    return model


def process_video(args):
    """Write a color-coded RAFT optical-flow video for ``args.path``.

    Flow is computed between every pair of consecutive frames and written to
    ``<args.output_path>/output_optical_flow.mp4`` at the input's resolution
    (one output frame per input pair).

    Args:
        args: argparse namespace. Reads ``model`` (checkpoint path), ``path``
            (input video), ``output_path`` (output directory; the current
            directory when empty/None) and optionally ``iters`` (RAFT refinement
            iterations, default 20). The namespace is also passed to ``RAFT``.

    Failures to open or read the input, or to open the output writer, are
    reported on stdout and the function returns early. Capture and writer
    handles are always released, even if the model raises.
    """
    model = _load_model(args)

    cap = cv2.VideoCapture(args.path)
    if not cap.isOpened():
        print(f"Error: Unable to open video file {args.path}")
        return

    out = None
    try:
        # Keep FPS as a float (e.g. 29.97) so the output duration matches the
        # input; some containers report 0, so fall back to 30 FPS.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        # The frame count may be 0/negative when unknown; use an open-ended bar.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        progress_total = total_frames - 1 if total_frames > 1 else None

        out_path = args.output_path or '.'
        out_name = os.path.join(out_path, 'output_optical_flow.mp4')
        os.makedirs(out_path, exist_ok=True)

        ret, prev_frame = cap.read()
        if not ret:
            print("Error: Unable to read the first frame from the video.")
            return

        prev_tensor = load_image_from_frame(prev_frame)
        # The padder is built once from the first frame, assuming every frame
        # of the video has the same size.
        padder = InputPadder(prev_tensor.shape)
        prev_padded = padder.pad(prev_tensor)[0]

        # The flow is cropped back to the input size below, so the writer uses
        # the original frame dimensions rather than the padded ones.
        height, width = prev_frame.shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(out_name, fourcc, fps, (width, height))
        if not out.isOpened():
            print(f"Error: Unable to open video writer for {out_name}")
            return

        iters = getattr(args, 'iters', 20)
        with torch.no_grad(), tqdm(total=progress_total, desc="Processing frames",
                                   unit="frame") as pbar:
            while True:
                ret, curr_frame = cap.read()
                if not ret:
                    break

                curr_tensor = load_image_from_frame(curr_frame)
                curr_padded = padder.pad(curr_tensor)[0]

                _, flow_up = model(prev_padded, curr_padded, iters=iters, test_mode=True)

                # Drop the padded border, whose flow comes from replicated edge
                # pixels, before visualizing.
                viz_and_write(out, prev_tensor, padder.unpad(flow_up))

                # Each frame is padded once and reused as the next pair's first image.
                prev_tensor, prev_padded = curr_tensor, curr_padded
                pbar.update(1)
    finally:
        cap.release()
        if out is not None:
            out.release()

    print("Optical flow video has been saved as {}".format(out_name))

if __name__ == '__main__':
    # Arguments used when the script is launched without any. Flags given on
    # the command line take precedence (e.g. '--model models/raft-sintel.pth'
    # without '--small' for the full model).
    demo_argv = ['--model', 'models/raft-small.pth',
                 '--path', 'dataset/videos/test_video_6s/201368-915360232_small_clip_14.mp4',
                 '--output_path', 'outputs', '--small']

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help="restore checkpoint")
    parser.add_argument('--path', help="video file for processing")
    parser.add_argument('--output_path', help="the path to save the output video")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
    # Fall back to the demo arguments only when no CLI arguments were supplied,
    # instead of unconditionally overwriting sys.argv.
    args = parser.parse_args(sys.argv[1:] or demo_argv)

    process_video(args)
