Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
f9e8ebf
Adding baseic model
97harsh Nov 15, 2020
6bf9fff
Initial commit for DeepAR
97harsh Nov 15, 2020
e8dd08f
Initial commit for DeepAR
97harsh Nov 15, 2020
5bf8088
config file updation
97harsh Nov 15, 2020
f9dc299
trainer_test updated
97harsh Nov 15, 2020
03aef74
trainer_test updated1
97harsh Nov 15, 2020
7163617
removed device argument
97harsh Nov 15, 2020
09ace45
changed number of layers to include dropout functionality
97harsh Nov 15, 2020
1c91826
added back inference params which is required in test
97harsh Nov 15, 2020
a8c4cb7
added back inference params which is required in test
97harsh Nov 15, 2020
7db3cb0
removed scaling in dataset
97harsh Nov 15, 2020
7378272
interpolate param added in test file
97harsh Nov 15, 2020
b719e23
added interpolate parameter
97harsh Nov 15, 2020
90d3b5b
removed tag from wandb in test
97harsh Nov 15, 2020
6d5b860
trying things in test to pass wandb issue
97harsh Nov 15, 2020
b4ef9bf
changed order in circle config file
97harsh Nov 15, 2020
006a5d7
testing changes1 added wandb api key
97harsh Nov 16, 2020
54f8057
added forward params in test json
97harsh Nov 16, 2020
17bbd68
added blank optim params in test
97harsh Nov 16, 2020
bfd9c03
added gaussian loss in model params, changed forward component
97harsh Nov 17, 2020
bf04546
changes to test
97harsh Nov 19, 2020
cfcb836
Gaussian loss initializaton
97harsh Nov 19, 2020
d2ecbea
added target in trainer
97harsh Nov 19, 2020
53b6cc5
Changes to trainer, to make model run work, problems in evaluation
97harsh Nov 19, 2020
d927488
changed pytorch_training to accomodate multi error
97harsh Nov 19, 2020
47d5635
added DeepAR based conditions in pytorchtrain, flake8 C901 added to .…
97harsh Nov 19, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ jobs:
command: |
echo -e 'running da-meta data unit test'
coverage run flood_forecast/trainer.py -p tests/da_meta.json
echo -e 'running Deep_AR_test \n'
coverage run flood_forecast/trainer.py -p tests/DeepAR_test.json
echo -e 'running transformer bottleneck'
coverage run flood_forecast/trainer.py -p tests/transformer_bottleneck.json
echo -e 'running da_rnn probabilistic test'
Expand Down
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[flake8]
max_line_length=120
ignore=E305,W504
ignore=E305,W504,C901
max-complexity=15
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,11 @@ tests/output/
data
mypy
.mypy_cache
*.png
*.png
.vscode
.vscode/*
.idea
.idea/
wandb
model_save
checkpoint.pth
2 changes: 1 addition & 1 deletion flood_forecast/custom/custom_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def forward(self, target: torch.Tensor, output: torch.Tensor):

# Add custom loss function
class GaussianLoss(torch.nn.Module):
def __init__(self, mu, sigma):
def __init__(self, mu=0, sigma=0):
"""Compute the negative log likelihood of Gaussian Distribution
From https://arxiv.org/abs/1907.00235
"""
Expand Down
18 changes: 18 additions & 0 deletions flood_forecast/deep_ar/config/lstm_kwargs.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"batch_size": 64,
"cov_dim": 4,
"embedding_dim": 20,
"learning_rate": 1e-3,
"lstm_dropout": 0.1,
"lstm_hidden_dim": 40,
"lstm_layers": 3,
"num_class": 370,
"num_epochs": 20,
"predict_batch": 256,
"predict_start": 168,
"predict_steps": 24,
"sample_times": 200,
"test_predict_start": 168,
"test_window": 192,
"train_window": 192
}
176 changes: 176 additions & 0 deletions flood_forecast/deep_ar/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@

import torch
import torch.nn as nn
from torch.autograd import Variable


class DeepAR(nn.Module):
    """Autoregressive LSTM that predicts a Gaussian (mu, sigma) for every forecast step.

    The network is unrolled one time step at a time. Up to ``predict_start`` the
    previous observed target is fed back in; afterwards the previously predicted
    mean is fed back instead.
    """

    def __init__(self,
                 num_class: int,
                 cov_dim: int,
                 lstm_dropout: float,
                 embedding_dim: int,
                 lstm_hidden_dim: int,
                 lstm_layers: int,
                 sample_times: int,
                 predict_steps: int,
                 predict_start: int
                 ):
        """Initialize the DeepAR model.

        :param num_class: Number of classes
        :param cov_dim: Number of covariates
        :param lstm_dropout: drop out rate
        :param embedding_dim: dimension of embedding layer
        :param lstm_hidden_dim: hidden dimension of LSTM
        :param lstm_layers: Number of LSTM layers
        :param sample_times: sample time steps
        :param predict_steps: Number of steps to predict
        :param predict_start: Step to start prediction at
        """
        super(DeepAR, self).__init__()
        self.params = {
            "num_class": num_class,
            "cov_dim": cov_dim,
            "lstm_dropout": lstm_dropout,
            "embedding_dim": embedding_dim,
            "lstm_hidden_dim": lstm_hidden_dim,
            "lstm_layers": lstm_layers,
            "sample_times": sample_times,
            "predict_steps": predict_steps,
            "predict_start": predict_start,
        }

        # NOTE: the embedding is not consumed by forward(); it is kept so the
        # module's parameters / state_dict stay the same.
        self.embedding = nn.Embedding(num_class, embedding_dim)

        # Each step sees the covariates plus one scalar: the previous target (or prediction).
        self.lstm = nn.LSTM(input_size=1 + cov_dim,
                            hidden_size=lstm_hidden_dim,
                            num_layers=lstm_layers,
                            bias=True,
                            batch_first=True,
                            dropout=lstm_dropout)
        # Initialize the LSTM forget gate bias to 1, as recommended by
        # http://proceedings.mlr.press/v37/jozefowicz15.pdf
        # The second quarter of every bias vector is the slice that gets filled.
        for name, param in self.lstm.named_parameters():
            if "bias" in name:
                n = param.size(0)
                param.data[n // 4:n // 2].fill_(1.)

        self.relu = nn.ReLU()
        # mu / sigma are read from the hidden state of all layers concatenated together.
        self.distribution_mu = nn.Linear(lstm_hidden_dim * lstm_layers, 1)
        self.distribution_presigma = nn.Linear(lstm_hidden_dim * lstm_layers, 1)
        self.distribution_sigma = nn.Softplus()

    def forward(self, input_data: torch.Tensor, t: torch.Tensor):
        '''
        Predict mu and sigma of the distribution for z_t.
        Args:
            input_data: ([batch_size, predict_start, 1+cov_dim]): history window; channel 0 is the
                target z_{t}, the remaining channels are the covariates x_{t}
            t: ([batch_size, predict_steps, 1+cov_dim]): forecast window with the same channel layout
        Returns:
            mu: ([batch_size, predict_steps, 1]): only predict future steps
            sigma: ([batch_size, predict_steps, 1]): only predict future
        '''
        batch_size = input_data.shape[0]
        predict_start = self.params["predict_start"]
        hidden = self.init_hidden(batch_size)
        cell = self.init_cell(batch_size)
        z_0 = self.init_z0(batch_size)

        # Join history and forecast windows along the time axis.
        covariates = torch.cat((input_data[:, :, 1:], t[:, :, 1:]), dim=1)
        targets = torch.cat((input_data[:, :, 0:1], t[:, :, 0:1]), dim=1)

        mu = None
        mu_steps = []
        sigma_steps = []
        for idx in range(covariates.shape[1]):
            step_covariates = covariates[:, idx:idx + 1, :]
            if idx == 0:  # initial step: no previous value exists
                z = z_0
            elif idx < predict_start:  # conditioning period: feed the previous observed target
                z = targets[:, idx - 1:idx, :]
            else:  # prediction period: feed the previous predicted mean
                z = mu.unsqueeze(1)
            lstm_input = torch.cat((step_covariates, z), dim=2)
            _, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
            # use h from all layers to calculate mu and sigma: (layers, batch, hidden) -> (batch, hidden * layers)
            hidden_flat = hidden.permute(1, 2, 0).contiguous().view(hidden.shape[1], -1)
            pre_sigma = self.distribution_presigma(hidden_flat)
            mu = self.distribution_mu(hidden_flat)
            sigma = self.distribution_sigma(pre_sigma)  # softplus to make sure standard deviation is positive
            mu_steps.append(mu)
            sigma_steps.append(sigma)

        # Stack the per-step (batch, 1) outputs on a new time axis -> (batch, steps, 1).
        # Concatenating `unsqueeze(0)` results along dim 1 would interleave samples and
        # time steps for any batch size above one.
        mu_all = torch.stack(mu_steps, dim=1)
        sigma_all = torch.stack(sigma_steps, dim=1)
        return mu_all[:, predict_start:, :], sigma_all[:, predict_start:, :]

    def init_hidden(self, input_size):
        """Return a zero hidden state of shape (lstm_layers, input_size, lstm_hidden_dim).

        The tensor is allocated with the device and dtype of the model parameters.
        """
        reference = self.distribution_mu.weight
        return reference.new_zeros((self.params["lstm_layers"], input_size, self.params["lstm_hidden_dim"]))

    def init_cell(self, input_size):
        """Return a zero cell state of shape (lstm_layers, input_size, lstm_hidden_dim).

        The tensor is allocated with the device and dtype of the model parameters.
        """
        reference = self.distribution_mu.weight
        return reference.new_zeros((self.params["lstm_layers"], input_size, self.params["lstm_hidden_dim"]))

    def init_z0(self, batch_size):
        """Return the zero "previous target" of shape (batch_size, 1, 1) fed at the first step."""
        return self.distribution_mu.weight.new_zeros((batch_size, 1, 1))

    def test(self, x, v_batch, id_batch, hidden, cell, sampling=False):
        """Roll the model forward over ``predict_steps`` steps, writing each prediction back into ``x``.

        With ``sampling=True`` this draws ``sample_times`` trajectories and returns
        ``(samples, sample_mu, sample_sigma)`` where ``sample_mu`` is the median and
        ``sample_sigma`` the standard deviation over the trajectories. Otherwise it
        returns ``(sample_mu, sample_sigma)`` built from the predicted mean and sigma.
        Outputs are rescaled with ``v_batch[:, 0]`` (multiplier) and ``v_batch[:, 1]`` (offset).

        NOTE(review): this method calls ``self(...)`` with four positional arguments and
        unpacks four return values, but ``forward`` in this class accepts
        ``(input_data, t)`` and returns two tensors, so calling this method as written
        raises a ``TypeError``. It appears to target an earlier single-step ``forward``
        signature -- confirm and port before relying on it.
        """
        batch_size = x.shape[1]
        if sampling:
            samples = torch.zeros(self.params["sample_times"], batch_size, self.params["predict_steps"])
            for j in range(self.params["sample_times"]):
                decoder_hidden = hidden
                decoder_cell = cell
                for t in range(self.params["predict_steps"]):
                    mu_de, sigma_de, decoder_hidden, decoder_cell = self(
                        x[self.params["predict_start"] + t].unsqueeze(0),
                        id_batch, decoder_hidden, decoder_cell)
                    gaussian = torch.distributions.normal.Normal(mu_de, sigma_de)
                    pred = gaussian.sample()  # not scaled
                    samples[j, :, t] = pred * v_batch[:, 0] + v_batch[:, 1]
                    if t < (self.params["predict_steps"] - 1):
                        # feed the sampled value back in as the next step's previous target
                        x[self.params["predict_start"] + t + 1, :, 0] = pred

            sample_mu = torch.median(samples, dim=0)[0]
            sample_sigma = samples.std(dim=0)
            return samples, sample_mu, sample_sigma

        else:
            decoder_hidden = hidden
            decoder_cell = cell
            sample_mu = torch.zeros(batch_size, self.params["predict_steps"])
            sample_sigma = torch.zeros(batch_size, self.params["predict_steps"])
            for t in range(self.params["predict_steps"]):
                mu_de, sigma_de, decoder_hidden, decoder_cell = self(
                    x[self.params["predict_start"] + t].unsqueeze(0),
                    id_batch, decoder_hidden, decoder_cell)
                sample_mu[:, t] = mu_de * v_batch[:, 0] + v_batch[:, 1]
                sample_sigma[:, t] = sigma_de * v_batch[:, 0]
                if t < (self.params["predict_steps"] - 1):
                    # feed the predicted mean back in as the next step's previous target
                    x[self.params["predict_start"] + t + 1, :, 0] = mu_de
            return sample_mu, sample_sigma


def loss_fn(mu: Variable, sigma: Variable, labels: Variable):
    '''
    Compute the Gaussian negative log-likelihood, which is to be minimized. Time steps whose
    label equals 0 are treated as missing and ignored.
    Args:
        mu: (Variable) dimension [batch_size] - estimated mean at time step t
        sigma: (Variable) dimension [batch_size] - estimated standard deviation at time step t
        labels: (Variable) dimension [batch_size] z_t
    Returns:
        loss: (Variable) negative mean log-likelihood over the non-missing labels; a zero scalar
            when every label is missing
    '''
    observed = (labels != 0)
    if not bool(observed.any()):
        # The mean over an empty selection is NaN and would poison the optimiser step;
        # return a zero that is still attached to mu's graph so backward() keeps working.
        return (mu * 0.0).sum()
    distribution = torch.distributions.normal.Normal(mu[observed], sigma[observed])
    likelihood = distribution.log_prob(labels[observed])
    return -torch.mean(likelihood)
9 changes: 6 additions & 3 deletions flood_forecast/model_dict_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from flood_forecast.basic.linear_regression import simple_decode
from flood_forecast.transformer_xl.transformer_basic import greedy_decode
from flood_forecast.da_rnn.model import DARNN
from flood_forecast.custom.custom_opt import RMSELoss, MAPELoss, PenalizedMSELoss, NegativeLogLikelihood
from flood_forecast.custom.custom_opt import GaussianLoss, RMSELoss, MAPELoss, PenalizedMSELoss, NegativeLogLikelihood
from flood_forecast.transformer_xl.transformer_bottleneck import DecoderTransformer
from flood_forecast.custom.dilate_loss import DilateLoss
from flood_forecast.meta_models.basic_ae import AE
from flood_forecast.deep_ar.model import DeepAR
import torch

"""
Expand All @@ -29,7 +30,8 @@
"CustomTransformerDecoder": CustomTransformerDecoder,
"DARNN": DARNN,
"DecoderTransformer": DecoderTransformer,
"BasicAE": AE
"BasicAE": AE,
"DeepAR": DeepAR

}

Expand All @@ -42,7 +44,8 @@
"DilateLoss": DilateLoss,
"L1": L1Loss,
"PenalizedMSELoss": PenalizedMSELoss,
"NegativeLogLikelihood": NegativeLogLikelihood}
"NegativeLogLikelihood": NegativeLogLikelihood,
"GaussianLoss": GaussianLoss}

evaluation_functions_dict = {"NSE": "", "MSE": ""}
decoding_functions = {"greedy_decode": greedy_decode, "simple_decode": simple_decode}
Expand Down
12 changes: 8 additions & 4 deletions flood_forecast/pytorch_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def torch_single_train(model: PyTorchForecast,
labels = trg[:, :, 0]
if isinstance(criterion, GaussianLoss):
g_loss = GaussianLoss(output[0], output[1])
loss = g_loss(labels)
loss = g_loss(labels.unsqueeze(1))
else:
loss = criterion(output, labels.float())
# TODO fix Gaussian loss
Expand Down Expand Up @@ -223,7 +223,9 @@ def compute_validation(validation_loader: DataLoader, # s lint
targ = targ.to(device)
i += 1
if decoder_structure:
if type(model).__name__ == "SimpleTransformer":
if type(model).__name__ == "DeepAR":
output = model(src.float(), targ.float())
elif type(model).__name__ == "SimpleTransformer":
targ_clone = targ.detach().clone()
output = greedy_decode(
model,
Expand Down Expand Up @@ -252,7 +254,9 @@ def compute_validation(validation_loader: DataLoader, # s lint
output_len=1,
probabilistic=probabilistic)[:, :, 0]
else:
if probabilistic:
if type(model).__name__ == "DeepAR":
output = model(src.float(), targ.float())
elif probabilistic:
output_dist = model(src.float())
output = output_dist.mean.detach().numpy()
output_std = output_dist.stddev.detach().numpy()
Expand Down Expand Up @@ -284,7 +288,7 @@ def compute_validation(validation_loader: DataLoader, # s lint
loss = loss.numpy()
elif isinstance(criterion, GaussianLoss):
g_loss = GaussianLoss(output[0], output[1])
loss = g_loss(labels)
loss = g_loss(labels.unsqueeze(1))
else:
loss = criterion(output, labels.float())
loop_loss += len(labels.float()) * loss.item()
Expand Down
8 changes: 7 additions & 1 deletion flood_forecast/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@ def train_function(model_type: str, params: Dict):
dataset_params["validation_path"],
dataset_params["test_path"],
params)
train_transformer_style(trained_model, params["training_params"], params["forward_params"])
takes_target = False
if "takes_target" in trained_model.params:
takes_target = trained_model.params["takes_target"]
train_transformer_style(model=trained_model,
training_params=params["training_params"],
takes_target=takes_target,
forward_params=params["forward_params"])
params["inference_params"]["dataset_params"]["scaling"] = scaler_dict[dataset_params["scaler"]]
test_acc = evaluate_model(
trained_model,
Expand Down
Loading