Skip to content

Commit 54d57b4

Browse files
BruceDaianssiko
authored andcommitted
Corrected peepholeWeight shape and fixed pseudo sample code error for lstm op
1 parent 1a308dd commit 54d57b4

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

index.bs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2227,7 +2227,7 @@ partial interface MLGraphBuilder {
22272227
- *options*: an optional {{MLGruOptions}}. The optional parameters of the operation.
22282228
- *bias*: an {{MLOperand}}. The 2-D input bias tensor of shape [num_directions, 4 * hidden_size]. The ordering of the bias vectors in the second dimension of the tensor shape is specified according to the *options.layout* argument.
22292229
- *recurrentBias*: an {{MLOperand}}. The 2-D recurrent bias tensor of shape [num_directions, 4 * hidden_size]. The ordering of the bias vectors in the second dimension of the tensor shape is specified according to the *options.layout* argument.
2230-
- *peepholeWeight*: an {{MLOperand}}. The 2-D weight tensor for peepholes of shape [num_directions, 4 * hidden_size]. The pack ordering of the weight vectors is for the *input (i)*, *output (o)*, and *forget (f)* gate respectively.
2230+
- *peepholeWeight*: an {{MLOperand}}. The 2-D weight tensor for peepholes of shape [num_directions, 3 * hidden_size]. The pack ordering of the weight vectors is for the *input (i)*, *output (o)*, and *forget (f)* gate respectively.
22312231
- *initialHiddenState*: an {{MLOperand}}. The 3-D initial hidden state tensor of shape [num_directions, batch_size, hidden_size]. When not specified, it's assumed to be a tensor filled with zero.
22322232
- *initialCellState*: an {{MLOperand}}. The 3-D initial hidden state tensor of shape [num_directions, batch_size, hidden_size]. When not specified, it's assumed to be a tensor filled with zero.
22332233
- *returnSequence*: a {{boolean}} indicating whether to also return the entire sequence with every output from each time step in it in addition to the output of the last time step. Default to false.
@@ -2270,7 +2270,7 @@ partial interface MLGraphBuilder {
22702270
currentRecurrentBias.push(options.recurrentBias ?
22712271
(builder.squeeze(builder.slice(options.recurrentBias, [dir, 0], [1, 4 * hidden_size]), { axes: [0] })) : null);
22722272
currentPeepholeWeight.push(options.peepholeWeight ?
2273-
(builder.squeeze(builder.slice(options.peepholeWeight, [dir, 0], [1, 4 * hidden_size]), { axes: [0] })) : null);
2273+
(builder.squeeze(builder.slice(options.peepholeWeight, [dir, 0], [1, 3 * hidden_size]), { axes: [0] })) : null);
22742274
}
22752275

22762276
for (let step = 0; step < steps; ++step) {
@@ -2297,8 +2297,8 @@ partial interface MLGraphBuilder {
22972297
let output = builder.reshape(results[0], [1, null, hiddenSize]);
22982298
let cell = builder.reshape(results[1], [1, null, hiddenSize]);
22992299

2300-
nextHidden = (nextHidden ? builder.concat([nextHidden, result], 0) : output);
2301-
nextCell = (nextCell ? builder.concat([nextCell, result], 0) : cell);
2300+
nextHidden = (nextHidden ? builder.concat([nextHidden, output], 0) : output);
2301+
nextCell = (nextCell ? builder.concat([nextCell, cell], 0) : cell);
23022302
}
23032303

23042304
hiddenState = nextHidden;

0 commit comments

Comments
 (0)