|
|
from datetime import datetime, timedelta |
|
|
import glob |
|
|
import os |
|
|
|
|
|
from fire import Fire |
|
|
import h5py |
|
|
from matplotlib import pyplot as plt |
|
|
import numpy as np |
|
|
import pathlib |
|
|
import cartopy |
|
|
import wradlib as wrl |
|
|
import dask.array as da |
|
|
import pandas as pd |
|
|
import xarray as xr |
|
|
import re |
|
|
from datetime import datetime, timedelta |
|
|
from ldcast import forecast |
|
|
from ldcast.visualization import plots |
|
|
import torch |
|
|
# Plate carrée (equirectangular) map projection centred on longitude 0.
# NOTE(review): not referenced by get_date_time() or demo(); possibly left
# over from plotting code.
MAP_PROJECTION = cartopy.crs.PlateCarree(central_longitude=0.0)
|
|
|
|
|
|
|
|
def get_date_time(name):
    """Parse the leading ``YYYYmmddHHMM`` timestamp of a file name.

    Only the first 12 characters of ``name`` are read, so a trailing
    extension such as ``.npz`` is ignored.

    Args:
        name: file name that starts with a 12-digit timestamp,
            e.g. ``"202012312140.npz"``.

    Returns:
        A naive ``datetime`` built from year, month, day, hour and minute.

    Raises:
        ValueError: if a field is not an integer or is out of range for
            ``datetime``.
    """
    # (start, stop) slices for year, month, day, hour, minute.
    spans = ((0, 4), (4, 6), (6, 8), (8, 10), (10, 12))
    fields = [int(name[lo:hi]) for lo, hi in spans]
    return datetime(*fields)
|
|
|
|
|
def _past_frame_paths(data_dir, anchor, past_timesteps, interval):
    """Return the input files for the ``past_timesteps`` frames ending at ``anchor``.

    Frame ``k`` (``0 <= k < past_timesteps``) is looked up at time
    ``anchor - (past_timesteps - 1 - k) * interval`` as
    ``<data_dir>/<YYYYmmddHHMM>.npz``. Paths are returned oldest first.
    Returns ``None`` if any frame is missing.
    """
    paths = []
    for k in range(past_timesteps):
        frame_time = anchor - (past_timesteps - 1 - k) * interval
        path = os.path.join(data_dir, frame_time.strftime("%Y%m%d%H%M") + ".npz")
        if not os.path.exists(path):
            return None
        paths.append(path)
    return paths


def _load_precipitation(paths):
    """Load the ``'precipitation'`` array of each ``.npz`` file and stack them along a new axis 0."""
    frames = []
    for path in paths:
        # np.load on an .npz returns an NpzFile holding an open file handle;
        # close it as soon as the array has been copied into a tensor.
        with np.load(path) as npz:
            frames.append(torch.tensor(npz["precipitation"]))
    return torch.stack(frames, dim=0)


def demo(
    ldm_weights_fn="/data/data_WF/ablation/ablation_time/genforecast-radaronly-256x256-20step.pt",
    autoenc_weights_fn="/home/mmhk20/weather_forecast/ldcast/models/autoenc/autoenc-32-0.01.pt",
    num_diffusion_iters=50,
    out_dir="/data/AI102024/test/pred_rad",
    data_dir="/data/AI102024/test/rad",
    t0=datetime(2020, 12, 31, 21, 40),
    interval=timedelta(minutes=10),
    past_timesteps=4,
    crop_box=((0, 640), (0, 640)),
    draw_border=False,
    ensemble_members=1,
    start_index=635,
    gpu=3,
):
    """Run the forecaster over the radar archive and save one lead time per anchor.

    Anchor files are the entries of ``data_dir`` whose names end in
    ``00.npz`` or ``03.npz``, processed in sorted name order starting at
    position ``start_index``. For each anchor, the ``past_timesteps`` frames
    ending at the anchor time and spaced ``interval`` apart are loaded from
    ``<YYYYmmddHHMM>.npz`` files (array key ``'precipitation'``). Anchors with
    any missing frame are skipped. The forecast frame ``R_pred[-3]`` is saved
    to ``out_dir`` as ``<anchor + 180 min>.npz`` under key ``'precipitation'``.

    Args:
        ldm_weights_fn: diffusion-model weights, passed to ``forecast.Forecast``.
        autoenc_weights_fn: autoencoder weights, passed to ``forecast.Forecast``.
        num_diffusion_iters: passed to each forecaster call.
        out_dir: output directory; created if it does not exist.
        data_dir: directory of input ``.npz`` files named by timestamp.
        t0, crop_box, draw_border, ensemble_members: accepted for CLI
            compatibility; not used by this function.
        interval: spacing between input frames, as a ``timedelta`` or a
            number of minutes.
        past_timesteps: number of input frames per forecast.
            NOTE(review): presumably must match what the loaded model
            expects (4 by default) — confirm before changing.
        start_index: position in the sorted anchor list to start from
            (lets a long run be resumed).
        gpu: passed to ``forecast.Forecast(gpu=...)``.
    """
    if not isinstance(interval, timedelta):
        # Fire passes command-line values as plain numbers; treat them as minutes.
        interval = timedelta(minutes=interval)

    sorted_files = sorted(
        name for name in os.listdir(data_dir) if name.endswith(("00.npz", "03.npz"))
    )

    fc1 = forecast.Forecast(
        ldm_weights_fn=ldm_weights_fn,
        autoenc_weights_fn=autoenc_weights_fn,
        future_timesteps=20,
        gpu=gpu,
    )

    # Create the output directory up front so np.savez cannot fail on a fresh run.
    os.makedirs(out_dir, exist_ok=True)

    print(len(sorted_files))
    for i in range(start_index, len(sorted_files)):
        file_name = sorted_files[i]
        print(i)
        print(file_name)

        name_datetime = get_date_time(file_name)
        past_paths = _past_frame_paths(data_dir, name_datetime, past_timesteps, interval)
        if past_paths is None:
            continue  # incomplete input history for this anchor

        with torch.no_grad():
            R_past = _load_precipitation(past_paths)
            # R_past is passed on the CPU. Tensor.to() returns a new tensor
            # rather than moving in place, so an unassigned .to('cuda:...')
            # has no effect. NOTE(review): presumably Forecast moves inputs to
            # its own device (gpu=...) — confirm before adding a device move.
            R_pred = fc1(R_past, num_diffusion_iters=num_diffusion_iters)

        # NOTE(review): with future_timesteps=20, R_pred[-3] is output frame 17.
        # If output frames are 10 min apart starting at anchor + 10 min, that
        # frame is anchor + 180 min, the timestamp used for the file name.
        pred_time = name_datetime + timedelta(minutes=180)
        path_save = os.path.join(out_dir, pred_time.strftime("%Y%m%d%H%M") + ".npz")
        print(path_save)
        np.savez(path_save, precipitation=R_pred[-3])
|
|
if __name__ == "__main__":
    # python-fire exposes demo()'s parameters as command-line flags.
    Fire(component=demo)
|
|
|
|
|
|