# ldcast_code / ldcast / forecast.py
# Uploaded via huggingface_hub by weatherforecast1024 (commit d2f661a, verified)
import contextlib
import gc
import numpy as np
import torch
import torch.multiprocessing as mp
from .features.transform import Antialiasing
from .models.autoenc import autoenc, encoder
from .models.genforecast import analysis, unet
from .models.diffusion import diffusion, plms
class Forecast:
    """Precipitation nowcasting with a latent diffusion model (LDM).

    Assembles a pretrained autoencoder, an AFNO analysis/forecaster network
    and a U-Net denoiser into a ``LatentDiffusion`` model, and exposes
    ``__call__`` to map past precipitation fields to predicted future fields
    via PLMS sampling.
    """

    def __init__(
        self,
        *,
        ldm_weights_fn,
        autoenc_weights_fn,
        gpu='auto',
        past_timesteps=4,
        future_timesteps=20,
        autoenc_time_ratio=4,
        autoenc_hidden_dim=32,
        verbose=True,
        R_min_value=0.1,
        R_zero_value=0.02,
        R_min_output=0.1,
        R_max_output=118.428,
        log_R_mean=-0.051,
        log_R_std=0.528,
    ):
        """Build the LDM, load weights and move it to the requested device.

        Parameters
        ----------
        ldm_weights_fn : path to the LDM state dict (plain state dict, not a
            Lightning checkpoint).
        autoenc_weights_fn : path to the autoencoder state dict.
        gpu : 'auto' to use CUDA if available, an integer device index for a
            specific GPU, or None to leave the model on CPU.
        past_timesteps / future_timesteps : number of input / output frames.
        autoenc_time_ratio : temporal compression factor of the autoencoder.
        autoenc_hidden_dim : channel width of the autoencoder latent space.
        verbose : print the selected device and show sampling progress.
        R_* / log_R_* : precipitation-rate transform constants (log10 space
            normalization and output clipping).
        """
        self.ldm_weights_fn = ldm_weights_fn
        self.autoenc_weights_fn = autoenc_weights_fn
        self.verbose = verbose
        self.R_min_value = R_min_value
        self.R_zero_value = R_zero_value
        self.R_min_output = R_min_output
        self.R_max_output = R_max_output
        self.log_R_mean = log_R_mean
        self.log_R_std = log_R_std
        self.past_timesteps = past_timesteps
        self.future_timesteps = future_timesteps
        self.autoenc_time_ratio = autoenc_time_ratio
        self.autoenc_hidden_dim = autoenc_hidden_dim
        self.antialiasing = Antialiasing()
        # setup LDM
        self.ldm = self._init_model()
        if gpu is not None:
            if gpu == 'auto':
                # use the default CUDA device if one is present, else stay on CPU
                if torch.cuda.device_count() > 0:
                    self.ldm.to(device="cuda")
            else:
                # explicit device index requested by the caller
                self.ldm.to(device=f"cuda:{gpu}")
        # setup sampler
        self.sampler = plms.PLMSSampler(self.ldm)
        if self.verbose:
            # report which device the model ended up on
            # (fix: was an unconditional debug print ignoring `verbose`)
            print(self.ldm.device)
        gc.collect()

    def _init_model(self):
        """Construct autoencoder, forecaster and denoiser, load both weight
        files (expected to be plain state dicts) and return the assembled
        ``LatentDiffusion`` model."""
        # setup autoencoder
        enc = encoder.SimpleConvEncoder()
        dec = encoder.SimpleConvDecoder()
        autoencoder_obs = autoenc.AutoencoderKL(enc, dec)
        autoencoder_obs.load_state_dict(torch.load(self.autoenc_weights_fn))
        autoencoders = [autoencoder_obs]
        input_patches = [self.past_timesteps // self.autoenc_time_ratio]
        input_size_ratios = [1]
        embed_dim = [128]
        analysis_depth = [4]
        # setup forecaster
        analysis_net = analysis.AFNONowcastNetCascade(
            autoencoders,
            input_patches=input_patches,
            input_size_ratios=input_size_ratios,
            train_autoenc=False,
            output_patches=self.future_timesteps // self.autoenc_time_ratio,
            cascade_depth=3,
            embed_dim=embed_dim,
            analysis_depth=analysis_depth
        )
        # setup denoiser
        denoiser = unet.UNetModel(
            in_channels=autoencoder_obs.hidden_width,
            model_channels=256, out_channels=autoencoder_obs.hidden_width,
            num_res_blocks=2, attention_resolutions=(1, 2),
            dims=3, channel_mult=(1, 2, 4), num_heads=8,
            num_timesteps=self.future_timesteps // self.autoenc_time_ratio,
            context_ch=analysis_net.cascade_dims
        )
        # create LDM
        ldm = diffusion.LatentDiffusion(
            denoiser, autoencoder_obs, context_encoder=analysis_net
        )
        ldm.load_state_dict(torch.load(self.ldm_weights_fn))
        return ldm

    def __call__(
        self,
        R_past,
        num_diffusion_iters=50
    ):
        """Generate one forecast from past precipitation.

        Parameters
        ----------
        R_past : past precipitation fields, indexed (batch, time, h, w)
            (assumed — TODO confirm against callers).
        num_diffusion_iters : number of PLMS sampling steps.

        Returns
        -------
        NumPy array of predicted precipitation for the first batch element.
        """
        # preprocess inputs and setup correct input shape
        x = self.transform_precip(R_past)
        timesteps = self.input_timesteps(x)
        future_patches = self.future_timesteps // self.autoenc_time_ratio
        # latent shape: (channels, time patches, h/4, w/4) — the 4x spatial
        # reduction matches the autoencoder used here
        gen_shape = (self.autoenc_hidden_dim, future_patches) + \
            (x.shape[-2]//4, x.shape[-1]//4)
        x = [[x, timesteps]]
        # run LDM sampler; the sampler's own prints are suppressed,
        # the progress bar (controlled by `progbar`) is not
        with contextlib.redirect_stdout(None):
            (s, intermediates) = self.sampler.sample(
                num_diffusion_iters,
                x[0][0].shape[0],
                gen_shape,
                x,
                progbar=self.verbose
            )
        # postprocess outputs
        y_pred = self.ldm.autoencoder.decode(s)
        R_pred = self.inv_transform_precip(y_pred)
        return R_pred[0, ...]

    def transform_precip(self, R):
        """Normalize precipitation to standardized log10 space and add
        batch/channel dims; returns a tensor on the model's device.

        NOTE(review): assumes R is a CPU torch tensor — `clone().detach()`
        requires a tensor, while `np.log10` converts it to a NumPy array;
        confirm against callers.
        """
        x = R.clone().detach()
        # `~(x >= thr)` also catches NaNs, which fail the comparison
        x[~(x >= self.R_min_value)] = self.R_zero_value
        x = np.log10(x)
        x -= self.log_R_mean
        x /= self.log_R_std
        x = x.reshape((1,) + x.shape)
        x = self.antialiasing(x)
        x = x.reshape((1,) + x.shape)
        return torch.Tensor(x).to(device=self.ldm.device)

    def inv_transform_precip(self, x):
        """Invert the log10 normalization and clip the output range.

        NOTE(review): modifies `x` in place before exponentiating.
        Returns a NumPy array with the channel axis removed.
        """
        x *= self.log_R_std
        x += self.log_R_mean
        R = torch.pow(10, x)
        if self.R_min_output:
            # zero out drizzle below the minimum detectable rate
            R[R < self.R_min_output] = 0.0
        if self.R_max_output is not None:
            # cap at the maximum representable rate
            R[R > self.R_max_output] = self.R_max_output
        R = R[:, 0, ...]
        return R.to(device='cpu').numpy()

    def input_timesteps(self, x):
        """Return relative time indices for the past frames: the last input
        frame gets t=0, earlier frames negative values; shape (batch, T)."""
        batch_size = x.shape[0]
        t0 = -x.shape[2] + 1
        t1 = 1
        timesteps = torch.arange(t0, t1,
            dtype=x.dtype, device=self.ldm.device)
        return timesteps.unsqueeze(0).expand(batch_size, -1)
class ForecastDistributed:
    """Distribute forecasts over worker processes, one per GPU (or a single
    CPU worker when no GPU is present).

    Workers are spawned at construction time; samples and ensemble members
    are farmed out over a multiprocessing queue and gathered in ``__call__``.
    """

    def __init__(
        self,
        ldm_weights_fn,
        autoenc_weights_fn,
        past_timesteps=4,
        future_timesteps=8,
        autoenc_time_ratio=4,
        autoenc_hidden_dim=32,
        verbose=True,
        R_min_value=0.1,
        R_zero_value=0.02,
        R_min_output=0.1,
        R_max_output=118.428,
        log_R_mean=-0.051,
        log_R_std=0.528,
    ):
        """Spawn the worker processes and block until all report ready.

        Parameters mirror ``Forecast.__init__`` and are forwarded to each
        worker's ``Forecast`` instance.
        """
        self.verbose = verbose
        self.R_min_value = R_min_value
        self.R_zero_value = R_zero_value
        self.R_min_output = R_min_output
        self.R_max_output = R_max_output
        self.log_R_mean = log_R_mean
        self.log_R_std = log_R_std
        self.past_timesteps = past_timesteps
        self.future_timesteps = future_timesteps
        self.autoenc_time_ratio = autoenc_time_ratio
        self.autoenc_hidden_dim = autoenc_hidden_dim
        # start worker processes; 'spawn' is required for CUDA in subprocesses
        context = mp.get_context('spawn')
        self.input_queue = context.Queue()
        self.output_queue = context.Queue()
        process_kwargs = {
            "past_timesteps": past_timesteps,
            "future_timesteps": future_timesteps,
            "ldm_weights_fn": ldm_weights_fn,
            "autoenc_weights_fn": autoenc_weights_fn,
            "autoenc_time_ratio": autoenc_time_ratio,
            "autoenc_hidden_dim": autoenc_hidden_dim,
            "R_min_value": R_min_value,
            "R_zero_value": R_zero_value,
            "R_min_output": R_min_output,
            "R_max_output": R_max_output,
            "log_R_mean": log_R_mean,
            "log_R_std": log_R_std,
            # fix: was hard-coded True, ignoring the `verbose` parameter
            "verbose": verbose
        }
        # one worker per GPU, falling back to a single CPU worker
        # (fix: was max(0, ...), which spawned no workers on CPU-only hosts
        # and made __call__ deadlock waiting on the output queue)
        self.num_procs = max(1, torch.cuda.device_count())
        self.compute_procs = mp.spawn(
            _compute_process,
            args=(self.input_queue, self.output_queue, process_kwargs),
            nprocs=self.num_procs,
            join=False
        )
        # wait for every worker to signal readiness before accepting work
        for _ in range(self.num_procs):
            self.output_queue.get()
        gc.collect()

    def __call__(
        self,
        R_past,
        ensemble_members=1,
        num_diffusion_iters=50
    ):
        """Forecast all samples in `R_past`, `ensemble_members` times each.

        Returns an array of shape
        (samples, future_timesteps, h, w, ensemble_members).
        """
        # send samples to compute processes; (i, j) tags let results be
        # reassembled regardless of completion order
        for (i, R_past_sample) in enumerate(R_past):
            for j in range(ensemble_members):
                self.input_queue.put((R_past_sample, num_diffusion_iters, i, j))
        # build output array
        pred_shape = (R_past.shape[0], self.future_timesteps) + \
            R_past.shape[2:] + (ensemble_members,)
        R_pred = np.empty(pred_shape, R_past.dtype)
        # gather outputs from processes
        predictions_needed = R_past.shape[0] * ensemble_members
        for _ in range(predictions_needed):
            (R_pred_sample, i, j) = self.output_queue.get()
            R_pred[i, ..., j] = R_pred_sample
        return R_pred

    def __del__(self):
        # best-effort shutdown: __init__ may have failed before the queues
        # and processes existed (fix: avoid AttributeError during teardown)
        if getattr(self, "compute_procs", None) is None:
            return
        # a None sentinel per worker tells each one to exit its loop
        for _ in range(self.num_procs):
            self.input_queue.put(None)
        self.compute_procs.join()
def _compute_process(process_index, input_queue, output_queue, kwargs):
    """Worker loop: build a Forecast (pinned to GPU `process_index` when CUDA
    is available, CPU otherwise), signal readiness, then serve forecast
    requests from `input_queue` until a None sentinel arrives."""
    gpu = process_index if torch.cuda.device_count() else None
    forecast = Forecast(gpu=gpu, **kwargs)
    output_queue.put("Ready")  # signal process ready to accept inputs
    while True:
        item = input_queue.get()
        if item is None:
            # shutdown sentinel from ForecastDistributed.__del__
            break
        (R_past, num_diffusion_iters, sample_idx, member_idx) = item
        R_pred = forecast(R_past, num_diffusion_iters=num_diffusion_iters)
        output_queue.put((R_pred, sample_idx, member_idx))