File size: 3,109 Bytes
d2f661a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import gc
import gzip
import os
import pickle
from fire import Fire
import numpy as np
from omegaconf import OmegaConf
from ldcast.features import batch, patches, split, transform
from ldcast.models.autoenc import encoder, training
# Project root used to resolve data/cache paths below.
# NOTE(review): the portable __file__-based resolution was replaced with a
# hard-coded absolute path (machine-specific); restore the commented form
# when running from the repository itself.
#file_dir = os.path.dirname(os.path.abspath(__file__))
file_dir = os.path.dirname("/data/data_WF/ldcast_precipitation/ldcast/")
def setup_data(
    var="RZC",
    batch_size=64,
    sampler_file=None,
    num_timesteps=8,
    chunks_file="./preprocess_data/split_chunks.pkl.gz"
):
    """Build the DataModule used for autoencoder training.

    Args:
        var: Variable / data source name (default "RZC").
        batch_size: Number of samples per batch.
        sampler_file: Mapping of split name ("train"/"valid"/"test") to a
            sampler cache path, or None to use the default cache files
            under ``file_dir``.
        num_timesteps: Number of consecutive timesteps per sample.
        chunks_file: Path, relative to ``file_dir``, of the pregenerated
            train/valid/test chunk split (see features.split.get_chunks()).

    Returns:
        A ``split.DataModule``; for the autoencoder the predictors and the
        target are the same variable.
    """
    variables = {
        var: {
            "sources": [var],
            "timesteps": np.arange(num_timesteps),
        }
    }
    predictors = [var]  # autoencoder: predictors == targets
    target = var
    raw_vars = [var]
    # Use a distinct loop name so the comprehension does not shadow `var`.
    raw = {
        v: patches.load_all_patches(
            os.path.join(file_dir, f"./preprocess_data/{v}/"), v
        )
        for v in raw_vars
    }
    # Load pregenerated train/valid/test split data.
    # These can be generated with features.split.get_chunks()
    with gzip.open(os.path.join(file_dir, chunks_file), 'rb') as f:
        chunks = pickle.load(f)
    (raw, _) = split.train_valid_test_split(raw, var, chunks=chunks)
    # Rain-rate transform is scaled using training-set statistics only.
    variables[var]["transform"] = transform.default_rainrate_transform(
        raw["train"][var]["scale"]
    )
    if sampler_file is None:
        # Default sampler caches; resolved against file_dir instead of
        # repeating the absolute project path three times (identical result,
        # since file_dir is that same path).
        sampler_file = {
            split_name: os.path.join(
                file_dir, f"cache/sampler_autoenc_{split_name}.pkl"
            )
            for split_name in ("train", "valid", "test")
        }
    # Log-spaced rain-rate bins (0.2–50 mm/h) for importance sampling.
    bins = np.exp(np.linspace(np.log(0.2), np.log(50), 10))
    datamodule = split.DataModule(
        variables, raw, predictors, target, var,
        sampling_bins=bins, batch_size=batch_size,
        sampler_file=sampler_file,
        valid_seed=1234, test_seed=2345
    )
    gc.collect()  # release large intermediates before training starts
    return datamodule
def setup_model(
    model_dir=None,
):
    """Create the convolutional autoencoder and its training harness.

    Args:
        model_dir: Directory for checkpoints/logs, forwarded to
            ``training.setup_autoenc_training`` (None uses its default).

    Returns:
        An ``(autoencoder, trainer)`` tuple ready for ``trainer.fit``.
    """
    autoencoder, trainer = training.setup_autoenc_training(
        encoder=encoder.SimpleConvEncoder(),
        decoder=encoder.SimpleConvDecoder(),
        model_dir=model_dir,
    )
    gc.collect()
    return (autoencoder, trainer)
def train(
    var="RZC",
    batch_size=64,
    sampler_file=None,
    num_timesteps=8,
    chunks_file="./preprocess_data/split_chunks.pkl.gz",
    model_dir=None,
    ckpt_path=None
):
    """Train the precipitation autoencoder end to end.

    Args:
        var: Variable / data source name, forwarded to setup_data().
        batch_size: Number of samples per batch.
        sampler_file: Optional split-name -> sampler cache path mapping.
        num_timesteps: Number of consecutive timesteps per sample.
        chunks_file: Path of the pregenerated train/valid/test split.
        model_dir: Directory for checkpoints/logs.
        ckpt_path: Optional checkpoint to resume training from.
    """
    print("Loading data...")
    datamodule = setup_data(
        var=var, batch_size=batch_size, sampler_file=sampler_file,
        num_timesteps=num_timesteps, chunks_file=chunks_file
    )
    print("Setting up model...")
    (model, trainer) = setup_model(model_dir=model_dir)
    print("Starting training...")
    # Fix: ckpt_path was accepted but never used. Forward it so training can
    # resume from a checkpoint; Lightning's Trainer.fit treats ckpt_path=None
    # as "start fresh", so default behavior is unchanged.
    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
def main(config=None, **kwargs):
    """CLI entry point: optionally load an OmegaConf file, then train.

    Keyword arguments given on the command line override values loaded
    from the config file.

    NOTE(review): `config.update(kwargs)` relies on the loaded config
    exposing a dict-like `update` — verify against the installed
    omegaconf version.
    """
    if config is None:
        config = {}
    else:
        config = OmegaConf.load(config)
    config.update(kwargs)
    train(**config)


if __name__ == "__main__":
    Fire(main)
|