Why call 'self.lstm_cell(X, state)' directly? #20

@fanyuzeng

I ran the program successfully. However, I found something that seems abnormal:

import tensorflow as tf  # TensorFlow 1.x API

class RecurrentController(BaseController):
    def network_vars(self):
        self.lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(256)
        self.state = self.lstm_cell.zero_state(self.batch_size, tf.float32)

    def network_op(self, X, state):
        X = tf.convert_to_tensor(X)
        return self.lstm_cell(X, state)

    def get_state(self):
        return self.state

    def update_state(self, new_state):
        return tf.no_op()
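As a rough size check on network_vars: BasicLSTMCell(256) builds a single LSTM layer with 256 hidden units, and its variables are one kernel and one bias shared by all four gates. The sketch below computes those shapes in plain Python, assuming BasicLSTMCell's layout of concatenating [X, h] before a single matmul. The helper name basic_lstm_cell_shapes and the input width of 64 are hypothetical, since the issue doesn't say how wide X is.

```python
def basic_lstm_cell_shapes(input_size, num_units):
    """Variable shapes for one BasicLSTMCell layer (a sketch, not TF code).

    The cell concatenates [inputs, h] and multiplies by one kernel that
    produces all four gates (i, j, f, o) at once, then adds one bias.
    """
    kernel_shape = (input_size + num_units, 4 * num_units)
    bias_shape = (4 * num_units,)
    num_params = kernel_shape[0] * kernel_shape[1] + bias_shape[0]
    return kernel_shape, bias_shape, num_params


# input_size=64 is a hypothetical value; the issue does not give X's width.
kernel_shape, bias_shape, num_params = basic_lstm_cell_shapes(64, 256)
print(kernel_shape, bias_shape, num_params)  # (320, 1024) (1024,) 328704
```

So the 256 is the width of one layer's hidden state, not a count of separate cells. With the default state_is_tuple=True, zero_state(batch_size, tf.float32) returns an LSTMStateTuple (c, h), each of shape [batch_size, 256].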

In RecurrentController, tf.nn.rnn_cell.BasicLSTMCell(256) creates an LSTM cell with 256 units, but network_op directly returns self.lstm_cell(X, state) without using tf.nn.static_rnn or tf.nn.dynamic_rnn.
tf.nn.static_rnn and tf.nn.dynamic_rnn return the final state, but as far as I can tell self.lstm_cell(X, state) does not.
I wonder whether this is a bug and whether tf.nn.static_rnn or tf.nn.dynamic_rnn should be used instead, as follows:

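def network_op(self, X, state):
    X = tf.convert_to_tensor(X)
    # With time_major=False, X must have shape [batch_size, max_time, input_size];
    # dynamic_rnn returns a tuple (outputs, final_state).
    return tf.nn.dynamic_rnn(self.lstm_cell, X, initial_state=state, time_major=False)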
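For comparison, here is a minimal pure-Python sketch of what calling an LSTM cell once does versus what dynamic_rnn does over a sequence. It assumes BasicLSTMCell's gate equations (gate order i, j, f, o, forget_bias=1.0) and drops the batch dimension; lstm_step and unroll are hypothetical helpers, not TensorFlow code. The sketch suggests that a single cell call also returns (output, new_state), and that dynamic_rnn is conceptually a loop of such calls. Whether the DNC model already runs its own per-time-step loop around network_op is an assumption here, since the caller isn't shown in this issue.

```python
import math
import random


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def lstm_step(x, state, kernel, bias, forget_bias=1.0):
    """One BasicLSTMCell-style step for a single example (hypothetical helper).

    Returns (output, new_state) with new_state = (c, h), like an LSTMStateTuple.
    """
    c, h = state
    n = len(h)
    z = x + h  # concatenate inputs and previous h (list concatenation)
    gates = [bias[col] + sum(z[row] * kernel[row][col] for row in range(len(z)))
             for col in range(4 * n)]
    i, j, f, o = (gates[k * n:(k + 1) * n] for k in range(4))
    new_c = [c[u] * sigmoid(f[u] + forget_bias) + sigmoid(i[u]) * math.tanh(j[u])
             for u in range(n)]
    new_h = [math.tanh(new_c[u]) * sigmoid(o[u]) for u in range(n)]
    return new_h, (new_c, new_h)


def unroll(cell, inputs, initial_state):
    """Conceptually what dynamic_rnn does: call the cell once per time step."""
    state = initial_state
    outputs = []
    for x_t in inputs:
        out, state = cell(x_t, state)
        outputs.append(out)
    return outputs, state


random.seed(0)
input_size, num_units = 3, 4  # tiny hypothetical sizes instead of 256
kernel = [[random.uniform(-0.5, 0.5) for _ in range(4 * num_units)]
          for _ in range(input_size + num_units)]
bias = [0.0] * (4 * num_units)
zero = ([0.0] * num_units, [0.0] * num_units)  # like zero_state: (c, h)


def cell(x, s):
    return lstm_step(x, s, kernel, bias)


sequence = [[1.0, 0.5, -0.5], [0.0, 1.0, 0.2], [-1.0, 0.3, 0.8]]

# Whole sequence in one call (the dynamic_rnn-style proposal).
outputs, final_state = unroll(cell, sequence, zero)

# One step per call, with the caller carrying the state (the current network_op style).
state = zero
for x_t in sequence:
    step_output, state = cell(x_t, state)
```

Under these assumptions both styles perform the same arithmetic and end with the same state; the difference is who owns the time loop. Note that dynamic_rnn with time_major=False expects X to carry a time axis, so swapping it into a per-step network_op would also change the expected input shape.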