# File size: 2,716 Bytes
# d2f661a
from datetime import datetime, timedelta
import glob
import os
from fire import Fire
import h5py
from matplotlib import pyplot as plt
import numpy as np
import pathlib
import cartopy
import wradlib as wrl
import dask.array as da
import pandas as pd
import xarray as xr
import re
from datetime import datetime, timedelta
from ldcast import forecast
from ldcast.visualization import plots
import torch
# Cartopy Plate Carree (equirectangular longitude/latitude) map projection.
# NOTE(review): not referenced by the code in this script -- possibly left over
# from plotting code; confirm before removing.
MAP_PROJECTION = cartopy.crs.PlateCarree()
def get_date_time(name):
    """Build a ``datetime`` from the ``YYYYMMDDHHMM`` prefix of *name*.

    Only the first twelve characters are read, so any trailing suffix such as
    ``.npy`` is ignored.

    Raises:
        ValueError: if a field is not an integer or the resulting date/time
            is invalid.
    """
    # (start, stop) character spans for year, month, day, hour and minute,
    # parsed in that order.
    spans = ((0, 4), (4, 6), (6, 8), (8, 10), (10, 12))
    fields = [int(name[lo:hi]) for lo, hi in spans]
    return datetime(*fields)
def _select_input_files(data_dir):
    """Return the sorted filenames in *data_dir* that serve as forecast anchors.

    A file is kept when its name ends with ``00.npy`` or ``03.npy`` and does
    not start with ``2020``.
    """
    # NOTE(review): if names end in ...HHMM.npy, a "03" minute suffix sits
    # oddly next to "00" -- confirm it is not meant to be "30".
    return sorted(
        name
        for name in os.listdir(data_dir)
        if name.endswith(("00.npy", "03.npy")) and not name.startswith("2020")
    )


def _past_frame_paths(data_dir, last_time, past_timesteps, interval):
    """Return existing ``.npy`` paths for the frames ending at *last_time*, oldest first.

    Frame ``k`` (0-based) is looked up at
    ``last_time - (past_timesteps - 1 - k) * interval`` under the name
    ``YYYYMMDDHHMM.npy`` inside *data_dir*. Missing frames are skipped, so the
    returned list can be shorter than *past_timesteps*.
    """
    paths = []
    for k in range(past_timesteps):
        frame_time = last_time - (past_timesteps - 1 - k) * interval
        path = os.path.join(data_dir, frame_time.strftime("%Y%m%d%H%M") + ".npy")
        if os.path.exists(path):
            paths.append(path)
    return paths


def demo(
    ldm_weights_fn="/data/data_WF/ablation/ablation_time/genforecast-radaronly-256x256-20step.pt",
    autoenc_weights_fn="/home/mmhk20/weather_forecast/ldcast/models/autoenc/autoenc-32-0.01.pt",
    num_diffusion_iters=50,
    out_dir="/data/data_WF/ablation/ablation_time/train/predict_radar",
    data_dir="/data/data_WF/ablation/ablation_time/train/GT_radar",
    t0=datetime(2020, 12, 31, 21, 40),
    interval=timedelta(minutes=10),
    past_timesteps=4,
    crop_box=((0, 640), (0, 640)),
    draw_border=False,
    ensemble_members=6,
    start_index=27600,
    end_index=27750,
):
    """Run the forecaster over a slice of radar frames and save one lead time each.

    The anchor files are ``_select_input_files(data_dir)``, and
    ``sorted_files[start_index:end_index]`` is processed. For each anchor:

    - The *past_timesteps* frames ending at the anchor's timestamp, spaced
      *interval* apart, are loaded, stacked and passed to
      ``forecast.Forecast``.
    - Anchors with any missing past frame are skipped.
    - ``R_pred[-3]`` of each forecast is saved to *out_dir* as
      ``YYYYMMDDHHMM.npy``, named after the anchor time plus 180 minutes.

    Args:
        ldm_weights_fn: Weights file passed to ``forecast.Forecast`` as
            ``ldm_weights_fn``.
        autoenc_weights_fn: Weights file passed to ``forecast.Forecast`` as
            ``autoenc_weights_fn``.
        num_diffusion_iters: Forwarded to each forecaster call.
        out_dir: Output directory; created if it does not exist.
        data_dir: Directory holding the input ``YYYYMMDDHHMM.npy`` frames.
        interval: Spacing between consecutive past frames.
        past_timesteps: Number of past frames required per forecast.
            Presumably it must match what the forecaster expects -- confirm.
        start_index: First index into the sorted anchor list (inclusive).
        end_index: Stop index into the sorted anchor list (exclusive). It is
            clamped to the number of anchors found.
        t0, crop_box, draw_border, ensemble_members: Accepted (e.g. as CLI
            flags) but not used by this function.
    """
    sorted_files = _select_input_files(data_dir)

    fc1 = forecast.Forecast(
        ldm_weights_fn=ldm_weights_fn,
        autoenc_weights_fn=autoenc_weights_fn,
        future_timesteps=20,
        gpu=0,
    )

    # np.save does not create directories, so make sure the target exists.
    os.makedirs(out_dir, exist_ok=True)

    # Clamp so a short directory listing ends the loop instead of raising IndexError.
    stop = min(end_index, len(sorted_files))
    for idx in range(start_index, stop):
        print(idx)
        print(sorted_files[idx])
        name_datetime = get_date_time(sorted_files[idx])
        past_paths = _past_frame_paths(data_dir, name_datetime, past_timesteps, interval)
        if len(past_paths) != past_timesteps:
            continue  # at least one past frame is missing

        with torch.no_grad():
            # Stack the frames along a new leading axis, oldest frame first.
            R_past = torch.stack([torch.tensor(np.load(p)) for p in past_paths], dim=0)
            # R_past is left on the CPU; presumably forecast.Forecast (built
            # with gpu=0) handles device placement -- confirm.
            R_pred = fc1(R_past, num_diffusion_iters=num_diffusion_iters)

        # NOTE(review): assumes R_pred is indexed by future step at 10-minute
        # spacing, so of the 20 future steps index -3 lands at +180 minutes.
        # Confirm.
        pred_time = name_datetime + timedelta(minutes=180)
        path_save = os.path.join(out_dir, pred_time.strftime("%Y%m%d%H%M") + ".npy")
        np.save(path_save, R_pred[-3])
if __name__ == "__main__":
    # Expose demo()'s keyword arguments as command-line flags via python-fire.
    Fire(component=demo)
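#18080  NOTE(review): unexplained number left by the author (possibly an earlier file index); confirm or remove.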