This repository was archived by the owner on Sep 10, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 30
Expand file tree
/
Copy pathsrnn.py
More file actions
127 lines (112 loc) · 5.28 KB
/
Copy pathsrnn.py
File metadata and controls
127 lines (112 loc) · 5.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# -*- coding: utf-8 -*-
# Copyright 2018 Google, Inc.,
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""SRNN as described in:

Marco Fraccaro, Søren Kaae Sønderby, Ulrich Paquet, Ole Winther.
Sequential Neural Models with Stochastic Layers.
https://arxiv.org/abs/1605.07571

Notation:

 - z_0:T are the hidden states (random variables)
 - d_1:T and e_1:T are deterministic RNN outputs
 - x_1:T are the observed states
 - c_1:T are the per-timestep inputs

    Generative model               Inference model
  =====================         =====================
  z_0 -> z_1 -----> z_t       z_0 -> z_1 ---------> z_t
         | ^        | ^               ^              ^
         v |        v |               |              |
   x_1 <-. |  x_t <-. |               |              |
          \|         \|              e_1 <--------- e_t
           *          *             / ^            / ^
           |          |          x_1  |         x_t  |
          d_1 -----> d_t             d_1 ---------> d_t
           ^          ^               ^              ^
           |          |               |              |
          c_1        c_t             c_1            c_t
"""
import sonnet as snt
import tensorflow as tf
from .. import latent as latent_mod
from .. import util
from .. import vae_module
class SRNN(vae_module.VAECore):
    """Stochastic recurrent network core; see the module docstring.

    Generation runs a deterministic RNN (`d_core`) over the inputs and a
    prior network (`latent_p`) over the latents. Inference additionally runs
    `e_core` through `util.reverse_dynamic_rnn` over the encoded observations
    and the inputs, and builds an approximate posterior with `latent_q`.
    """

    def __init__(self, hparams, obs_encoder, obs_decoder, name=None):
        super(SRNN, self).__init__(hparams, obs_encoder, obs_decoder, name)
        # All sub-modules are constructed inside this module's variable scope.
        with self._enter_variable_scope():
            # Deterministic RNNs that produce d_1:T and e_1:T respectively.
            self._d_core = util.make_rnn(hparams, name="d_core")
            self._e_core = util.make_rnn(hparams, name="e_core")
            # Latent networks. The prior is fed (d_t, previous latent) and the
            # posterior is fed (e_t, previous latent); see `_inf_step` below.
            self._latent_p = latent_mod.LatentDecoder(
                hparams, name="latent_p")
            self._latent_q = latent_mod.LatentDecoder(
                hparams, name="latent_q")

    @property
    def state_size(self):
        """Pair of sizes: (d_core state size, latent event size)."""
        rnn_part = self._d_core.state_size
        latent_part = self._latent_p.event_size
        return (rnn_part, latent_part)

    def _build(self, input_, state):
        """Runs a single generative step.

        Args:
          input_: the per-timestep input (possibly nested); its features are
            concatenated before being fed to `d_core`.
          state: a `(d_state, latent)` pair.

        Returns:
          A pair `(obs_output, (d_state, latent_params))`:
            obs_output: the observation decoder applied to the concatenation
              of the new `d_core` output and the incoming `latent`.
            d_state: the updated `d_core` state.
            latent_params: `latent_p` parameters computed from the new
              `d_core` output and the incoming `latent`; this is the state
              argument layout that `_next_state` unpacks.
        """
        d_state_in, latent_in = state
        d_output, d_state_out = self._d_core(
            util.concat_features(input_), d_state_in)
        prior_params = self._latent_p(d_output, latent_in)
        decoder_features = util.concat_features((d_output, latent_in))
        obs_output = self._obs_decoder(decoder_features)
        return obs_output, (d_state_out, prior_params)

    def _next_state(self, state_arg, event=None):
        """Maps `(d_state, latent_params)` to `(d_state, latent_dist)`.

        `latent_dist` is `latent_p.dist(latent_params)`; `event` is ignored.
        """
        del event  # Unused by this core.
        carried_d_state, prior_params = state_arg
        latent_dist = self._latent_p.dist(prior_params, name="latent")
        return carried_d_state, latent_dist

    def _initial_state(self, batch_size):
        """Builds the initial `(d_state, latent_dist)` pair.

        The d part is `d_core`'s own initial state. The latent part comes from
        running `latent_p` on all-zero tensors shaped like (d_core output,
        latent event), each with a leading `batch_size` dimension.
        """
        start_d_state = self._d_core.initial_state(batch_size)

        def _zeros_like_size(size):
            # Leading batch dimension followed by the per-example shape.
            full_shape = [batch_size] + tf.TensorShape(size).as_list()
            return tf.zeros(full_shape, name="latent_input")

        zero_inputs = snt.nest.map(
            _zeros_like_size,
            (self._d_core.output_size, self._latent_p.event_size))
        start_params = self._latent_p(zero_inputs)
        return self._next_state((start_d_state, start_params), event=None)

    def _infer_latents(self, inputs, observed):
        """Samples latents from the approximate posterior for whole sequences.

        `tf.nn.dynamic_rnn` is called with its default `time_major=False`, so
        `inputs` and `observed` are expected batch-major (batch, time, ...).

        Args:
          inputs: per-timestep inputs c_1:T (possibly nested).
          observed: observations x_1:T (possibly nested).

        Returns:
          `((d_states, latents), kls)`:
            d_states: per-step `d_core` states, as emitted by the
              `util.state_recording_rnn` wrapper.
            latents: latent samples drawn from the posterior at each step.
            kls: per-step divergence terms from `util.calc_kl`.
        """
        hparams = self._hparams
        batch_size = util.batch_size_from_nested_tensors(observed)
        # NOTE(review): the second element is used as a tensor-valued RNN
        # state below; presumably the base-class `initial_state` turns the
        # distribution from `_initial_state` into a sample — confirm.
        d_start, z_start = self.initial_state(batch_size)

        # Forward pass of d_core over the inputs, recording its state per step.
        (d_outputs, d_state_seq), _ = tf.nn.dynamic_rnn(
            util.state_recording_rnn(self._d_core),
            util.concat_features(inputs),
            initial_state=d_start)

        # e_core over the encoded observations and the inputs, via
        # util.reverse_dynamic_rnn (e_t <- e_{t+1} in the module diagram).
        encoded_obs = snt.BatchApply(self._obs_encoder, n_dims=2)(observed)
        e_outputs, _ = util.reverse_dynamic_rnn(
            self._e_core,
            util.concat_features((encoded_obs, inputs)),
            initial_state=self._e_core.initial_state(batch_size))

        def _inf_step(d_e_outputs, prev_latent):
            """Posterior step: (d_t, e_t), z_{t-1} -> ((z_t, kl_t), z_t)."""
            d_t, e_t = d_e_outputs
            prior = self._latent_p.dist(self._latent_p(d_t, prev_latent))
            post_loc, post_scale = self._latent_q(e_t, prev_latent)
            if hparams.srnn_use_res_q:
                # Residual posterior: q's location is an offset from p's.
                post_loc = post_loc + prior.loc
            posterior = self._latent_q.dist(
                (post_loc, post_scale), name="q_z_dist")
            z_t = posterior.sample()
            kl_t = util.calc_kl(hparams, z_t, posterior, prior)
            return (z_t, kl_t), z_t

        inf_core = util.WrapRNNCore(
            _inf_step,
            state_size=tf.TensorShape(hparams.latent_size),  # previous latent
            output_size=(
                tf.TensorShape(hparams.latent_size),  # sampled latent
                tf.TensorShape([]),  # scalar divergence
            ),
            name="inf_z_core")
        (latents, kls), _ = util.heterogeneous_dynamic_rnn(
            inf_core,
            (d_outputs, e_outputs),
            initial_state=z_start,
            output_dtypes=(self._latent_q.event_dtype, tf.float32))
        return (d_state_seq, latents), kls