Skip to content

Commit 1732b82

Browse files
committed
Corrected peepholeWeight shape and fixed pseudo sample code error for lstm op
1 parent c22a3ac commit 1732b82

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

index.bs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2224,7 +2224,7 @@ partial interface MLGraphBuilder {
22242224
- *options*: an optional {{MLGruOptions}}. The optional parameters of the operation.
22252225
- *bias*: an {{MLOperand}}. The 2-D input bias tensor of shape [num_directions, 4 * hidden_size]. The ordering of the bias vectors in the second dimension of the tensor shape is specified according to the *options.layout* argument.
22262226
- *recurrentBias*: an {{MLOperand}}. The 2-D recurrent bias tensor of shape [num_directions, 4 * hidden_size]. The ordering of the bias vectors in the second dimension of the tensor shape is specified according to the *options.layout* argument.
2227-
- *peepholeWeight*: an {{MLOperand}}. The 2-D weight tensor for peepholes of shape [num_directions, 4 * hidden_size]. The pack ordering of the weight vectors is for the *input (i)*, *output (o)*, and *forget (f)* gate respectively.
2227+
- *peepholeWeight*: an {{MLOperand}}. The 2-D weight tensor for peepholes of shape [num_directions, 3 * hidden_size]. The pack ordering of the weight vectors is for the *input (i)*, *output (o)*, and *forget (f)* gate respectively.
22282228
- *initialHiddenState*: an {{MLOperand}}. The 3-D initial hidden state tensor of shape [num_directions, batch_size, hidden_size]. When not specified, it's assumed to be a tensor filled with zero.
22292229
- *initialCellState*: an {{MLOperand}}. The 3-D initial hidden state tensor of shape [num_directions, batch_size, hidden_size]. When not specified, it's assumed to be a tensor filled with zero.
22302230
- *returnSequence*: a {{boolean}} indicating whether to also return the entire sequence with every output from each time step in it in addition to the output of the last time step. Default to false.
@@ -2267,7 +2267,7 @@ partial interface MLGraphBuilder {
22672267
currentRecurrentBias.push(options.recurrentBias ?
22682268
(builder.squeeze(builder.slice(options.recurrentBias, [dir, 0], [1, 4 * hidden_size]), { axes: [0] })) : null);
22692269
currentPeepholeWeight.push(options.peepholeWeight ?
2270-
(builder.squeeze(builder.slice(options.peepholeWeight, [dir, 0], [1, 4 * hidden_size]), { axes: [0] })) : null);
2270+
(builder.squeeze(builder.slice(options.peepholeWeight, [dir, 0], [1, 3 * hidden_size]), { axes: [0] })) : null);
22712271
}
22722272

22732273
for (let step = 0; step < steps; ++step) {
@@ -2294,8 +2294,8 @@ partial interface MLGraphBuilder {
22942294
let output = builder.reshape(results[0], [1, null, hiddenSize]);
22952295
let cell = builder.reshape(results[1], [1, null, hiddenSize]);
22962296

2297-
nextHidden = (nextHidden ? builder.concat([nextHidden, result], 0) : output);
2298-
nextCell = (nextCell ? builder.concat([nextCell, result], 0) : cell);
2297+
nextHidden = (nextHidden ? builder.concat([nextHidden, output], 0) : output);
2298+
nextCell = (nextCell ? builder.concat([nextCell, cell], 0) : cell);
22992299
}
23002300

23012301
hiddenState = nextHidden;

0 commit comments

Comments
 (0)