File size: 3,109 Bytes
d2f661a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gc
import gzip
import os
import pickle

from fire import Fire
import numpy as np
from omegaconf import OmegaConf

from ldcast.features import batch, patches, split, transform
from ldcast.models.autoenc import encoder, training

# Base directory used to resolve relative data/chunk paths below.
# NOTE(review): hard-coded to one machine; the commented line is the portable
# alternative (directory of this script) — confirm before running elsewhere.
#file_dir = os.path.dirname(os.path.abspath(__file__))
file_dir = os.path.dirname("/data/data_WF/ldcast_precipitation/ldcast/")

def setup_data(
    var="RZC",
    batch_size=64,
    sampler_file=None,
    num_timesteps=8,
    chunks_file="./preprocess_data/split_chunks.pkl.gz"
):
    """Build the DataModule for autoencoder training on `var` patches.

    Args:
        var: Variable name to load (also used as both predictor and target).
        batch_size: Training batch size.
        sampler_file: Optional dict mapping split name to sampler cache path;
            defaults to the hard-coded cache locations below.
        num_timesteps: Number of consecutive timesteps per sample.
        chunks_file: Pickled train/valid/test chunk split, relative to file_dir.

    Returns:
        A configured `split.DataModule`.
    """
    # Autoencoder setting: the model reconstructs its own input,
    # so predictors and target are the same variable.
    predictors = [var]
    target = var
    variables = {
        var: {
            "sources": [var],
            "timesteps": np.arange(num_timesteps),
        }
    }

    raw = {}
    for source in [var]:
        raw[source] = patches.load_all_patches(
            os.path.join(file_dir, f"./preprocess_data/{source}/"), source
        )

    # Load the pregenerated train/valid/test split
    # (can be produced with features.split.get_chunks()).
    chunks_path = os.path.join(file_dir, chunks_file)
    with gzip.open(chunks_path, 'rb') as f:
        chunks = pickle.load(f)
    (raw, _) = split.train_valid_test_split(raw, var, chunks=chunks)

    # Normalization is fitted on the training-split scale only.
    variables[var]["transform"] = transform.default_rainrate_transform(
        raw["train"][var]["scale"]
    )

    if sampler_file is None:
        sampler_file = {
            "train": "/data/data_WF/ldcast_precipitation/ldcast/cache/sampler_autoenc_train.pkl",
            "valid": "/data/data_WF/ldcast_precipitation/ldcast/cache/sampler_autoenc_valid.pkl",
            "test": "/data/data_WF/ldcast_precipitation/ldcast/cache/sampler_autoenc_test.pkl",
        }

    # 10 log-spaced rain-rate bins between 0.2 and 50 for importance sampling.
    sampling_bins = np.exp(np.linspace(np.log(0.2), np.log(50), 10))
    datamodule = split.DataModule(
        variables, raw, predictors, target, var,
        sampling_bins=sampling_bins, batch_size=batch_size,
        sampler_file=sampler_file,
        valid_seed=1234, test_seed=2345
    )

    gc.collect()  # release load-time temporaries before training starts
    return datamodule


def setup_model(
    model_dir=None,
):
    """Create the convolutional autoencoder and its trainer.

    Args:
        model_dir: Optional directory for model checkpoints/outputs,
            forwarded to `training.setup_autoenc_training`.

    Returns:
        Tuple of (autoencoder, trainer).
    """
    (autoencoder, trainer) = training.setup_autoenc_training(
        encoder=encoder.SimpleConvEncoder(),
        decoder=encoder.SimpleConvDecoder(),
        model_dir=model_dir,
    )
    gc.collect()  # free construction-time temporaries
    return (autoencoder, trainer)


def train(
    var="RZC", 
    batch_size=64,
    sampler_file=None,
    num_timesteps=8,
    chunks_file="./preprocess_data/split_chunks.pkl.gz",
    model_dir=None,
    ckpt_path=None
):
    """Run autoencoder training end to end: data, model, fit.

    Args:
        var: Variable name to train on.
        batch_size: Training batch size.
        sampler_file: Optional dict of per-split sampler cache paths.
        num_timesteps: Timesteps per sample.
        chunks_file: Path (relative to file_dir) of the pickled data split.
        model_dir: Optional checkpoint/output directory for the model.
        ckpt_path: Optional checkpoint to resume training from.
    """
    print("Loading data...")
    datamodule = setup_data(
        var=var, batch_size=batch_size, sampler_file=sampler_file,
        num_timesteps=num_timesteps, chunks_file=chunks_file
    )

    print("Setting up model...")
    (model, trainer) = setup_model(model_dir=model_dir)

    print("Starting training...")
    # Fix: ckpt_path was accepted but never forwarded, so requesting a resume
    # silently restarted training from scratch. Forward it to trainer.fit().
    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)


def main(config=None, **kwargs):
    """CLI entry point: load an optional config file and apply CLI overrides.

    Args:
        config: Optional path to an OmegaConf-readable config file.
        **kwargs: Individual options; these override values from the file.
    """
    if config is None:
        config = {}
    else:
        config = OmegaConf.load(config)
    config.update(kwargs)  # command-line overrides win over the file
    train(**config)


if __name__ == "__main__":
    # python-fire turns main()'s signature into a command-line interface.
    Fire(main)