Skip to content

Commit 257c5fc

Browse files
committed
fix(neuralnetworks): restore Predict override for models with custom Forward
PR #1155 replaced `public override Tensor<T> Predict(Tensor<T> input) => Forward(input)` with `protected override Tensor<T> PredictEager(Tensor<T> input) => Forward(input)` in 11 network classes. That reroutes Predict through `NeuralNetworkBase.PredictCompiled` → `CompiledModelHost.Predict` → the compiled-replay plan. The compiled-replay plan silently truncates the forward pass for the affected models: on a ResNet-18 with 3×32×32 input the plan returns a shape-[64] tensor (output of the first conv block's 64 channels) instead of the expected shape-[10] logits. The net effect is wrong output for every `Predict` call on these 11 model classes. Verified with a per-call diagnostic: Forward direct shape: [10] Predict compile-off shape: [10] Predict compile-on shape: [64] ← the regression Root cause is in the tracer / compiled-replay machinery (the tracer does not capture the shape-conditional control flow in Forward — rank-3 → rank-4 promotion, final Reshape, etc.). That's a larger infrastructure fix; this commit restores master's previous behavior so `Predict` calls Forward directly, matching the pre-#1155 contract. Affected models (all had master's `public override Predict` changed to `protected override PredictEager` by #1155): - ResNetNetwork, VGGNetwork, EfficientNetNetwork, MobileNetV2Network, ConvolutionalNeuralNetwork, UNet3D, VoxelCNN - FastText, GloVe, Word2Vec, SiameseNeuralNetwork Unblocks CI on every open PR that merges master (PR #1163, PR #1165). Once the compiled-plan tracer is hardened to preserve shape-conditional control flow, these overrides can be removed again.
1 parent 4fe35b0 commit 257c5fc

11 files changed

Lines changed: 70 additions & 38 deletions

src/NeuralNetworks/ConvolutionalNeuralNetwork.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,13 @@ public override void UpdateParameters(Vector<T> parameters)
235235
/// </para>
236236
/// </remarks>
237237
/// <summary>
238-
/// Routes inference through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> for
239-
/// compiled-plan replay; <see cref="Forward"/> remains the eager fallback.
238+
/// Runs <see cref="Forward"/> directly rather than routing through the
239+
/// compiled-replay path. <see cref="Forward"/> contains shape-conditional
240+
/// control flow that the tracer does not capture faithfully — routing
241+
/// through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> was observed
242+
/// to return an intermediate feature-map shape instead of the final logits.
240243
/// </summary>
241-
protected override Tensor<T> PredictEager(Tensor<T> input) => Forward(input);
244+
public override Tensor<T> Predict(Tensor<T> input) => Forward(input);
242245

243246
/// <summary>
244247
/// Trains the convolutional neural network using the provided input and expected output.

src/NeuralNetworks/EfficientNetNetwork.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,13 @@ public Tensor<T> Forward(Tensor<T> input)
294294

295295
/// <inheritdoc />
296296
/// <summary>
297-
/// Routes inference through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> for
298-
/// compiled-plan replay; <see cref="Forward"/> remains the eager fallback.
297+
/// Runs <see cref="Forward"/> directly rather than routing through the
298+
/// compiled-replay path. <see cref="Forward"/> contains shape-conditional
299+
/// control flow that the tracer does not capture faithfully — routing
300+
/// through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> was observed
301+
/// to return an intermediate feature-map shape instead of the final logits.
299302
/// </summary>
300-
protected override Tensor<T> PredictEager(Tensor<T> input) => Forward(input);
303+
public override Tensor<T> Predict(Tensor<T> input) => Forward(input);
301304

302305
/// <inheritdoc />
303306
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)

src/NeuralNetworks/FastText.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,10 +245,13 @@ public override void UpdateParameters(Vector<T> parameters)
245245
}
246246

247247
/// <summary>
248-
/// Routes inference through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> for
249-
/// compiled-plan replay; <see cref="Forward"/> remains the eager fallback.
248+
/// Runs <see cref="Forward"/> directly rather than routing through the
249+
/// compiled-replay path. <see cref="Forward"/> contains shape-conditional
250+
/// control flow that the tracer does not capture faithfully — routing
251+
/// through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> was observed
252+
/// to return an intermediate feature-map shape instead of the final output.
250253
/// </summary>
251-
protected override Tensor<T> PredictEager(Tensor<T> input) => Forward(input);
254+
public override Tensor<T> Predict(Tensor<T> input) => Forward(input);
252255

253256
/// <summary>
254257
/// Trains the model on a single step of data using standard backpropagation.

src/NeuralNetworks/GloVe.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,13 @@ public override void UpdateParameters(Vector<T> parameters)
294294
/// returns their addresses (embeddings).
295295
/// </remarks>
296296
/// <summary>
297-
/// Routes inference through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> for
298-
/// compiled-plan replay; <see cref="Forward"/> remains the eager fallback.
297+
/// Runs <see cref="Forward"/> directly rather than routing through the
298+
/// compiled-replay path. <see cref="Forward"/> contains shape-conditional
299+
/// control flow that the tracer does not capture faithfully — routing
300+
/// through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> was observed
301+
/// to return an intermediate shape instead of the final output.
299302
/// </summary>
300-
protected override Tensor<T> PredictEager(Tensor<T> input) => Forward(input);
303+
public override Tensor<T> Predict(Tensor<T> input) => Forward(input);
301304

302305
/// <summary>
303306
/// Trains the model on a batch of word pairs and their co-occurrence counts.

src/NeuralNetworks/MobileNetV2Network.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,13 @@ public override Dictionary<string, Tensor<T>> GetNamedLayerActivations(Tensor<T>
289289

290290
/// <inheritdoc />
291291
/// <summary>
292-
/// Routes inference through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> for
293-
/// compiled-plan replay; <see cref="Forward"/> remains the eager fallback.
292+
/// Runs <see cref="Forward"/> directly rather than routing through the
293+
/// compiled-replay path. <see cref="Forward"/> contains shape-conditional
294+
/// control flow that the tracer does not capture faithfully — routing
295+
/// through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> was observed
296+
/// to return an intermediate feature-map shape instead of the final logits.
294297
/// </summary>
295-
protected override Tensor<T> PredictEager(Tensor<T> input) => Forward(input);
298+
public override Tensor<T> Predict(Tensor<T> input) => Forward(input);
296299

297300
/// <inheritdoc />
298301
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)

src/NeuralNetworks/ResNetNetwork.cs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -493,17 +493,18 @@ public override void UpdateParameters(Vector<T> parameters)
493493
}
494494

495495
/// <summary>
496-
/// Makes a prediction using the ResNet network for the given input.
496+
/// Makes a prediction using the ResNet network for the given input. Runs
497+
/// <see cref="Forward"/> directly rather than routing through the
498+
/// compiled-replay path, because <see cref="Forward"/> contains
499+
/// shape-conditional control flow (rank-3 → rank-4 batch promotion and the
500+
/// final <c>Reshape</c> that strips the synthetic batch dim) that the
501+
/// tracer does not capture faithfully — routing through
502+
/// <see cref="NeuralNetworkBase{T}.PredictCompiled"/> was observed to
503+
/// return an intermediate feature-map shape instead of the final logits.
497504
/// </summary>
498505
/// <param name="input">The input tensor to make a prediction for.</param>
499506
/// <returns>The predicted output tensor containing class probabilities.</returns>
500-
/// <summary>
501-
/// Routes inference through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> so the
502-
/// forward pass gets traced and replayed by <c>CompiledModelHost</c> after warmup —
503-
/// matching PyTorch's <c>torch.compile</c> default. The eager forward is
504-
/// <see cref="Forward"/>, which retains the GPU-resident optimization path.
505-
/// </summary>
506-
protected override Tensor<T> PredictEager(Tensor<T> input) => Forward(input);
507+
public override Tensor<T> Predict(Tensor<T> input) => Forward(input);
507508

508509
/// <summary>
509510
/// Trains the ResNet network using the provided input and expected output.

src/NeuralNetworks/SiameseNeuralNetwork.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,10 +302,13 @@ public override void UpdateParameters(Vector<T> parameters)
302302
}
303303

304304
/// <summary>
305-
/// Routes inference through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> for
306-
/// compiled-plan replay; <see cref="Forward"/> remains the eager fallback.
305+
/// Runs <see cref="Forward"/> directly rather than routing through the
306+
/// compiled-replay path. <see cref="Forward"/> contains shape-conditional
307+
/// control flow that the tracer does not capture faithfully — routing
308+
/// through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> was observed
309+
/// to return an intermediate shape instead of the final output.
307310
/// </summary>
308-
protected override Tensor<T> PredictEager(Tensor<T> input) => Forward(input);
311+
public override Tensor<T> Predict(Tensor<T> input) => Forward(input);
309312

310313
/// <summary>
311314
/// Trains the model on pairs of inputs using a similarity learning objective.

src/NeuralNetworks/UNet3D.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,10 +237,13 @@ public Tensor<T> Forward(Tensor<T> input)
237237
/// <returns>The predicted segmentation map.</returns>
238238
/// <inheritdoc />
239239
/// <summary>
240-
/// Routes inference through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> for
241-
/// compiled-plan replay; <see cref="Forward"/> remains the eager fallback.
240+
/// Runs <see cref="Forward"/> directly rather than routing through the
241+
/// compiled-replay path. <see cref="Forward"/> contains shape-conditional
242+
/// control flow that the tracer does not capture faithfully — routing
243+
/// through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> was observed
244+
/// to return an intermediate feature-map shape instead of the final output.
242245
/// </summary>
243-
protected override Tensor<T> PredictEager(Tensor<T> input) => Forward(input);
246+
public override Tensor<T> Predict(Tensor<T> input) => Forward(input);
244247

245248
/// <summary>
246249
/// Trains the network on a single batch of input-output pairs.

src/NeuralNetworks/VGGNetwork.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,10 +340,14 @@ public override void UpdateParameters(Vector<T> parameters)
340340
/// </para>
341341
/// </remarks>
342342
/// <summary>
343-
/// Routes inference through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> for
344-
/// compiled-plan replay; <see cref="Forward"/> remains the eager fallback.
343+
/// Runs <see cref="Forward"/> directly rather than routing through the
344+
/// compiled-replay path. <see cref="Forward"/> contains shape-conditional
345+
/// control flow (batch-dim promotion, post-pipeline reshape) that the
346+
/// tracer does not capture faithfully — routing through
347+
/// <see cref="NeuralNetworkBase{T}.PredictCompiled"/> was observed to
348+
/// return an intermediate feature-map shape instead of the final logits.
345349
/// </summary>
346-
protected override Tensor<T> PredictEager(Tensor<T> input) => Forward(input);
350+
public override Tensor<T> Predict(Tensor<T> input) => Forward(input);
347351

348352
/// <summary>
349353
/// Trains the VGG network using the provided input and expected output.

src/NeuralNetworks/VoxelCNN.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,13 @@ public Tensor<T> Forward(Tensor<T> input)
230230
/// <returns>The predicted class probabilities or scores.</returns>
231231
/// <inheritdoc />
232232
/// <summary>
233-
/// Routes inference through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> for
234-
/// compiled-plan replay; <see cref="Forward"/> remains the eager fallback.
233+
/// Runs <see cref="Forward"/> directly rather than routing through the
234+
/// compiled-replay path. <see cref="Forward"/> contains shape-conditional
235+
/// control flow that the tracer does not capture faithfully — routing
236+
/// through <see cref="NeuralNetworkBase{T}.PredictCompiled"/> was observed
237+
/// to return an intermediate feature-map shape instead of the final output.
235238
/// </summary>
236-
protected override Tensor<T> PredictEager(Tensor<T> input) => Forward(input);
239+
public override Tensor<T> Predict(Tensor<T> input) => Forward(input);
237240

238241
/// <summary>
239242
/// Trains the network on a single batch of input-output pairs.

0 commit comments

Comments
 (0)