from datetime import datetime, timedelta
import glob
import os
import pathlib
import re

import cartopy
import dask.array as da
from fire import Fire
import h5py
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import torch
import wradlib as wrl
import xarray as xr

from ldcast import forecast
from ldcast.visualization import plots

MAP_PROJECTION = cartopy.crs.PlateCarree()

# Filename suffixes selecting which radar frames are used as forecast start times.
_START_SUFFIXES = ("00.npz", "03.npz")
# Number of future steps requested from the forecaster.
_FUTURE_TIMESTEPS = 20
# Index (along axis 0 of the prediction) of the single frame that is saved.
# With 20 steps, index -3 is step 18, i.e. 18 * interval after the start time.
_OUTPUT_STEP = -3


def get_date_time(name):
    """Parse the leading ``YYYYMMDDHHMM`` of a filename into a ``datetime``.

    Only the first 12 characters are read, e.g. ``"202012312140.npz"`` gives
    ``datetime(2020, 12, 31, 21, 40)``.

    Raises:
        ValueError: if the first 12 characters are not a valid timestamp.
    """
    return datetime.strptime(name[:12], "%Y%m%d%H%M")


def _list_start_files(data_dir):
    """Return the sorted filenames in ``data_dir`` ending with a start-time suffix."""
    return sorted(
        filename
        for filename in os.listdir(data_dir)
        if filename.endswith(_START_SUFFIXES)
    )


def _past_frame_paths(data_dir, start, interval, past_timesteps):
    """Return the ``past_timesteps`` input-frame paths ending at ``start``, oldest first."""
    paths = []
    for step in range(past_timesteps):
        frame_time = start + (step - (past_timesteps - 1)) * interval
        paths.append(os.path.join(data_dir, frame_time.strftime("%Y%m%d%H%M") + ".npz"))
    return paths


def _load_precipitation(paths):
    """Load each file's ``'precipitation'`` array and stack them along a new axis 0."""
    frames = []
    for path in paths:
        # np.load on an .npz returns an NpzFile holding an open handle; close it
        # promptly (torch.tensor copies the data, so nothing outlives the file).
        with np.load(path) as data:
            frames.append(torch.tensor(data["precipitation"]))
    return torch.stack(frames, dim=0)


def demo(
    ldm_weights_fn="/data/data_WF/ablation/ablation_time/genforecast-radaronly-256x256-20step.pt",
    autoenc_weights_fn="/home/mmhk20/weather_forecast/ldcast/models/autoenc/autoenc-32-0.01.pt",
    num_diffusion_iters=50,
    out_dir="/data/AI102024/test/pred_rad",
    data_dir="/data/AI102024/test/rad",
    t0=datetime(2020, 12, 31, 21, 40),
    interval=timedelta(minutes=10),
    past_timesteps=4,
    crop_box=((0, 640), (0, 640)),
    draw_border=False,
    ensemble_members=1,
    start_index=635,
    gpu=3,
):
    """Run the ldcast forecaster over every start-time file in ``data_dir``.

    For each selected file (see ``_START_SUFFIXES``), starting at position
    ``start_index`` of the sorted list:

    1. Build the input window: the ``past_timesteps`` frames spaced by
       ``interval`` and ending at the file's timestamp.
    2. If any frame in the window is missing, skip that start time.
    3. Otherwise, run the forecaster and save prediction frame ``_OUTPUT_STEP``
       to ``out_dir/<YYYYMMDDHHMM>.npz`` under the key ``'precipitation'``.
       The filename is the start time plus ``18 * interval`` (180 min by default).

    Args:
        ldm_weights_fn: path to the latent-diffusion weights passed to ``forecast.Forecast``.
        autoenc_weights_fn: path to the autoencoder weights passed to ``forecast.Forecast``.
        num_diffusion_iters: sampler iterations passed to each forecast call.
        out_dir: output directory (created if missing).
        data_dir: directory of input ``YYYYMMDDHHMM.npz`` files with a
            ``'precipitation'`` array.
        t0, crop_box, draw_border, ensemble_members: unused; kept so the
            ``Fire`` command-line interface stays unchanged.
        interval: spacing between input frames; also scales the saved lead time.
        past_timesteps: number of input frames per forecast.
        start_index: index into the sorted start-file list to resume from.
        gpu: device id passed to ``forecast.Forecast``.
    """
    sorted_files = _list_start_files(data_dir)
    fc1 = forecast.Forecast(
        ldm_weights_fn=ldm_weights_fn,
        autoenc_weights_fn=autoenc_weights_fn,
        future_timesteps=_FUTURE_TIMESTEPS,
        gpu=gpu,
    )
    # Fail fast / avoid crashing at the first save after an expensive model run.
    os.makedirs(out_dir, exist_ok=True)
    # Frame _OUTPUT_STEP of a _FUTURE_TIMESTEPS-long prediction lies this far ahead.
    lead_time = (_FUTURE_TIMESTEPS + _OUTPUT_STEP + 1) * interval

    print(len(sorted_files))
    for file_index in range(start_index, len(sorted_files)):
        print(file_index)
        print(sorted_files[file_index])
        name_datetime = get_date_time(sorted_files[file_index])

        past_paths = _past_frame_paths(data_dir, name_datetime, interval, past_timesteps)
        # Only forecast when the whole input window is present on disk.
        if not all(os.path.exists(path) for path in past_paths):
            continue

        with torch.no_grad():
            R_past = _load_precipitation(past_paths)
            # The original called R_past.to('cuda:3') and discarded the result
            # (Tensor.to is not in-place), so the input was never moved.
            # Device placement is left to the forecaster built with `gpu`.
            # NOTE(review): confirm fc1 accepts CPU tensors as before.
            R_pred = fc1(R_past, num_diffusion_iters=num_diffusion_iters)

        pred_time = name_datetime + lead_time
        path_save = os.path.join(out_dir, pred_time.strftime("%Y%m%d%H%M") + ".npz")
        print(path_save)
        np.savez(path_save, precipitation=R_pred[_OUTPUT_STEP])


if __name__ == "__main__":
    Fire(demo)
#31779