# Source: ldcast_code/scripts/genpredict_2.py
# Uploaded by weatherforecast1024 with huggingface_hub ("Upload folder using
# huggingface_hub"), commit d2f661a (verified).
from datetime import datetime, timedelta
import glob
import os
from fire import Fire
import h5py
from matplotlib import pyplot as plt
import numpy as np
import pathlib
import cartopy
import wradlib as wrl
import dask.array as da
import pandas as pd
import xarray as xr
import re
from datetime import datetime, timedelta
from ldcast import forecast
from ldcast.visualization import plots
import torch
# cartopy PlateCarree coordinate reference system.
# NOTE(review): not referenced by any code visible in this script; presumably a
# leftover from a plotting variant — confirm before removing.
MAP_PROJECTION = cartopy.crs.PlateCarree()
def get_date_time(name):
    """Parse the leading ``YYYYmmddHHMM`` timestamp of a file name.

    Only the first 12 characters of *name* are inspected, so any suffix
    (for example ``.npz``) is ignored.

    Args:
        name: string that starts with a 12-digit timestamp,
            e.g. ``"202012312140.npz"``.

    Returns:
        A naive ``datetime`` built from the year, month, day, hour and minute.

    Raises:
        ValueError: if a field is not an integer or is out of range.
    """
    # (start, stop) character offsets of year, month, day, hour, minute.
    field_bounds = ((0, 4), (4, 6), (6, 8), (8, 10), (10, 12))
    year, month, day, hour, minute = (int(name[lo:hi]) for lo, hi in field_bounds)
    return datetime(year, month, day, hour, minute)
def _past_frame_paths(data_dir, t_end, past_timesteps, interval):
    """Return existing ``.npz`` paths for the input window ending at *t_end*.

    Candidate files are ``<data_dir>/<YYYYmmddHHMM>.npz`` at times
    ``t_end - (past_timesteps - 1) * interval, ..., t_end`` (oldest first).
    Missing files are left out, so the result may be shorter than
    *past_timesteps*.
    """
    paths = []
    for step in range(past_timesteps):
        frame_time = t_end + (step - (past_timesteps - 1)) * interval
        path = os.path.join(data_dir, frame_time.strftime('%Y%m%d%H%M') + '.npz')
        if os.path.exists(path):
            paths.append(path)
    return paths


def _load_precipitation(paths):
    """Stack the ``'precipitation'`` array of each ``.npz`` file into one tensor.

    Arrays are stacked along a new leading dimension in the order of *paths*.
    Each archive is closed right after reading; ``np.load`` otherwise keeps
    the underlying file handle open until garbage collection.
    """
    frames = []
    for path in paths:
        with np.load(path) as archive:
            # torch.tensor copies the data, so closing the archive is safe.
            frames.append(torch.tensor(archive['precipitation']))
    return torch.stack(frames, dim=0)


def demo(
    ldm_weights_fn="/data/data_WF/ablation/ablation_time/genforecast-radaronly-256x256-20step.pt",
    autoenc_weights_fn="/home/mmhk20/weather_forecast/ldcast/models/autoenc/autoenc-32-0.01.pt",
    num_diffusion_iters=50,
    out_dir="/data/AI102024/test/pred_rad",
    data_dir="/data/AI102024/test/rad",
    t0=datetime(2020, 12, 31, 21, 40),
    interval=timedelta(minutes=10),
    past_timesteps=4,
    crop_box=((0, 640), (0, 640)),
    draw_border=False,
    ensemble_members=1,
    gpu=3,
    start_index=635,
):
    """Run the LDCast forecaster over a directory of radar frames.

    Every file in *data_dir* whose name ends in ``00.npz`` or ``03.npz`` is
    a forecast start time. For ``YYYYmmddHHMM.npz`` names, that is minute 00
    or 03. Files are processed in sorted order from position *start_index*.
    For each start time, the *past_timesteps* frames ending at that time,
    spaced by *interval*, are loaded from their ``precipitation`` arrays,
    stacked and passed to ``forecast.Forecast``. Element ``[-3]`` of the
    prediction is saved to ``<out_dir>/<start time + 180 min>.npz`` under
    the key ``precipitation``. Start times with any missing input frame are
    skipped.

    Args:
        ldm_weights_fn: path to the latent diffusion model weights.
        autoenc_weights_fn: path to the autoencoder weights.
        num_diffusion_iters: number of diffusion sampling iterations.
        out_dir: output directory (created if it does not exist).
        data_dir: directory holding the input ``YYYYmmddHHMM.npz`` files.
        t0: unused; kept for command-line compatibility.
        interval: spacing between consecutive input frames.
        past_timesteps: number of input frames per forecast.
            NOTE(review): ``Forecast`` is not told this value; values other
            than 4 presumably need a model trained for that many inputs.
        crop_box: unused; kept for command-line compatibility.
        draw_border: unused; kept for command-line compatibility.
        ensemble_members: unused; kept for command-line compatibility.
        gpu: GPU index passed to ``forecast.Forecast``.
        start_index: position in the sorted file list to start from.
    """
    # NOTE(review): a minute value of "03" looks unusual next to 10-minute
    # input spacing — confirm this filter is intended.
    sorted_files = sorted(
        fname for fname in os.listdir(data_dir)
        if fname.endswith(("00.npz", "03.npz"))
    )
    os.makedirs(out_dir, exist_ok=True)

    fc1 = forecast.Forecast(
        ldm_weights_fn=ldm_weights_fn,
        autoenc_weights_fn=autoenc_weights_fn,
        future_timesteps=20,
        gpu=gpu,
    )

    # Prediction element [-3] is saved with a fixed 180-minute offset.
    # NOTE(review): presumably axis 0 of R_pred holds the 20 future steps at
    # 10-minute spacing, making [-3] the 18th step (+180 min) — confirm.
    pred_index = -3
    lead_time = timedelta(minutes=180)

    print(len(sorted_files))
    for i, fname in enumerate(sorted_files[start_index:], start=start_index):
        print(i)
        print(fname)
        name_datetime = get_date_time(fname)
        past_paths = _past_frame_paths(data_dir, name_datetime, past_timesteps, interval)
        if len(past_paths) != past_timesteps:
            print(f"skipping {fname}: missing input frames")
            continue
        with torch.no_grad():
            # Input deliberately stays on the CPU: Tensor.to is not in-place,
            # so a bare `R_past.to(...)` never moved it. Device placement is
            # presumably handled by Forecast via `gpu`.
            R_past = _load_precipitation(past_paths)
            R_pred = fc1(R_past, num_diffusion_iters=num_diffusion_iters)
            npz_data = {'precipitation': R_pred[pred_index]}
        pred_time = name_datetime + lead_time
        path_save = os.path.join(out_dir, pred_time.strftime('%Y%m%d%H%M') + '.npz')
        print(path_save)
        np.savez(path_save, **npz_data)
# Command-line entry point: python-fire exposes demo()'s keyword arguments
# as flags (e.g. --data_dir=... --out_dir=...).
if __name__ == "__main__":
    Fire(demo)
# NOTE(review): the number below was left by the author without explanation
# (possibly a file count or index) — confirm its meaning or remove it.
#31779