Skip to content

Commit d59e029

Browse files
franklinicclaude
andcommitted
fix: address 5 CodeRabbit review comments on PR #1155
- NBEATSBlock ctor: extract CreateInputShape / CreateOutputShape static factories that validate lookbackWindow / forecastHorizon BEFORE the base(...) call. Invalid values now surface as ArgumentException with the right nameof(...) tag instead of a downstream LayerBase<T> shape error. - InterfaceGuard: class visibility reduced from public to internal to match the AiDotNet facade pattern. InternalsVisibleTo on src/AiDotNet.csproj already grants access to AiDotNetTests / AiDotNetTestConsole / AiDotNet.Serving / AiDotNetBenchmarkTests, so the 58 existing test call sites still compile. Doc remark added explaining the visibility choice. - PretrainedTeacherModel + TransformerTeacherModel: reworded "auto-compiles via Tensors' AutoTracer" remarks. The wrapper only invokes the delegate; whether auto-compile actually happens depends entirely on what's inside the delegate. Removed the unconditional guarantee and added a note that external paths (ONNX, REST, etc.) won't pick up engine optimizations. - SelfTeacherModel.GetLogits: rewrote XML-doc so summary/returns/exception match the throw-only behavior (method has no underlying model to run and always throws InvalidOperationException). Previous summary said "Gets logits from the underlying model" which was misleading. Verify: dotnet build net10.0 + net471 — 0 errors. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 1f53769 commit d59e029

5 files changed

Lines changed: 74 additions & 22 deletions

File tree

src/Helpers/InterfaceGuard.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,15 @@ namespace AiDotNet.Helpers;
66
/// callers must validate capabilities before use. These methods provide clear error
77
/// messages when a model doesn't support the requested capability.
88
/// </summary>
9-
public static class InterfaceGuard
9+
/// <remarks>
10+
/// <para>
11+
/// <b>Visibility:</b> <c>internal</c> to match the facade pattern — users interact with
12+
/// <c>AiModelBuilder</c> / <c>AiModelResult</c>, and the InternalsVisibleTo attribute on
13+
/// AiDotNet.csproj exposes this helper to the test/console/serving assemblies that need
14+
/// capability checks from outside the main assembly.
15+
/// </para>
16+
/// </remarks>
17+
internal static class InterfaceGuard
1018
{
1119
/// <summary>
1220
/// Returns the model as IParameterizable or throws with a clear message.

src/KnowledgeDistillation/Teachers/PretrainedTeacherModel.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,14 @@ namespace AiDotNet.KnowledgeDistillation.Teachers;
1313
/// </summary>
1414
/// <remarks>
1515
/// <para>
16-
/// Construction uses a <c>Func&lt;&gt;</c> forward-pass delegate. Inference goes
17-
/// through the standard model path, which auto-compiles via Tensors' AutoTracer
18-
/// once the input-shape pattern repeats.
16+
/// This wrapper takes a <c>Func&lt;Vector&lt;T&gt;, Vector&lt;T&gt;&gt;</c> forward-pass
17+
/// delegate and invokes it directly on every <see cref="GetLogits"/> call.
18+
/// The wrapper itself performs no caching or graph compilation — any
19+
/// optimizations (including Tensors' AutoTracer auto-compile) depend entirely
20+
/// on what happens inside the supplied delegate. A delegate that wraps a
21+
/// standard neural-network model's <c>Predict</c> path will pick up those
22+
/// engine-level optimizations; a delegate that invokes external code
23+
/// (pre-converted ONNX, a REST call, etc.) will not.
1924
/// </para>
2025
/// </remarks>
2126
[ModelDomain(ModelDomain.MachineLearning)]

src/KnowledgeDistillation/Teachers/SelfTeacherModel.cs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,19 @@ public void CachePredictions(Vector<T>[] predictions)
7575
}
7676

7777
/// <summary>
78-
/// Gets logits from the underlying model.
78+
/// Not supported for <see cref="SelfTeacherModel{T}"/> — always throws.
7979
/// </summary>
80-
/// <param name="input">Input to the model.</param>
81-
/// <returns>The logits from the underlying model.</returns>
82-
/// <exception cref="InvalidOperationException">Thrown when no underlying model is configured.</exception>
80+
/// <param name="input">Ignored.</param>
81+
/// <returns>This method does not return; it always throws.</returns>
82+
/// <exception cref="InvalidOperationException">
83+
/// Always thrown. <see cref="SelfTeacherModel{T}"/> serves pre-computed
84+
/// predictions by index via <see cref="GetCachedPrediction"/> and cannot
85+
/// evaluate a fresh input vector — it has no underlying model to run.
86+
/// </exception>
8387
/// <remarks>
84-
/// <para>Not supported for <see cref="SelfTeacherModel{T}"/> — callers must use
85-
/// <see cref="GetCachedPrediction"/>, which returns a pre-computed prediction by index.</para>
88+
/// <para>Callers must use <see cref="GetCachedPrediction"/> instead, which
89+
/// returns a prediction from the cache populated via
90+
/// <see cref="CachePredictions"/>.</para>
8691
/// </remarks>
8792
public override Vector<T> GetLogits(Vector<T> input)
8893
{

src/KnowledgeDistillation/Teachers/TransformerTeacherModel.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ namespace AiDotNet.KnowledgeDistillation.Teachers;
1414
/// <typeparam name="T">The numeric type for calculations (e.g., double, float).</typeparam>
1515
/// <remarks>
1616
/// <para>
17-
/// Construction uses a <c>Func&lt;&gt;</c> forward-pass delegate. Inference goes
18-
/// through the standard model path, which auto-compiles via Tensors' AutoTracer
19-
/// once the input-shape pattern repeats.
17+
/// This wrapper takes a <c>Func&lt;Vector&lt;T&gt;, Vector&lt;T&gt;&gt;</c> forward-pass
18+
/// delegate and invokes it directly on every <see cref="GetLogits"/> call.
19+
/// The wrapper performs no caching or graph compilation itself — any
20+
/// optimizations (including Tensors' AutoTracer auto-compile) depend on what
21+
/// the supplied delegate does internally.
2022
/// </para>
2123
/// <para>For attention-based distillation strategies that need attention weights, implement
2224
/// a custom IDistillationStrategy that can extract attention from the underlying model.</para>

src/TimeSeries/NBEATSBlock.cs

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,38 @@ public override int ParameterCount
119119
/// - useInterpretableBasis: Whether to use human-understandable basis functions
120120
/// </para>
121121
/// </remarks>
122+
/// <summary>
123+
/// Validates <paramref name="lookbackWindow"/> and returns the corresponding
124+
/// LayerBase input shape. Runs BEFORE the base ctor so invalid values surface
125+
/// as <see cref="ArgumentException"/> with the argument name instead of a
126+
/// downstream shape error.
127+
/// </summary>
128+
private static int[] CreateInputShape(int lookbackWindow)
129+
{
130+
if (lookbackWindow <= 0)
131+
{
132+
throw new ArgumentException("Lookback window must be positive.", nameof(lookbackWindow));
133+
}
134+
return new[] { lookbackWindow };
135+
}
136+
137+
/// <summary>
138+
/// Validates <paramref name="forecastHorizon"/> (and re-checks lookback for
139+
/// consistency) and returns the corresponding LayerBase output shape.
140+
/// </summary>
141+
private static int[] CreateOutputShape(int lookbackWindow, int forecastHorizon)
142+
{
143+
if (lookbackWindow <= 0)
144+
{
145+
throw new ArgumentException("Lookback window must be positive.", nameof(lookbackWindow));
146+
}
147+
if (forecastHorizon <= 0)
148+
{
149+
throw new ArgumentException("Forecast horizon must be positive.", nameof(forecastHorizon));
150+
}
151+
return new[] { lookbackWindow + forecastHorizon };
152+
}
153+
122154
public NBEATSBlock(
123155
int lookbackWindow,
124156
int forecastHorizon,
@@ -128,16 +160,16 @@ public NBEATSBlock(
128160
int thetaSizeForecast,
129161
bool useInterpretableBasis,
130162
int polynomialDegree = 3)
131-
: base(new[] { lookbackWindow }, new[] { lookbackWindow + forecastHorizon })
163+
: base(
164+
CreateInputShape(lookbackWindow),
165+
CreateOutputShape(lookbackWindow, forecastHorizon))
132166
{
133-
if (lookbackWindow <= 0)
134-
{
135-
throw new ArgumentException("Lookback window must be positive.", nameof(lookbackWindow));
136-
}
137-
if (forecastHorizon <= 0)
138-
{
139-
throw new ArgumentException("Forecast horizon must be positive.", nameof(forecastHorizon));
140-
}
167+
// Primary-argument validation happens inside the static shape factories
168+
// above so `lookbackWindow` / `forecastHorizon` are rejected BEFORE
169+
// LayerBase<T> consumes them — users see the nameof(...)-tagged
170+
// ArgumentException instead of a downstream shape error from the base.
171+
// (The two blocks that previously validated those here are now in
172+
// CreateInputShape / CreateOutputShape below.)
141173
if (hiddenLayerSize <= 0)
142174
{
143175
throw new ArgumentException("Hidden layer size must be positive.", nameof(hiddenLayerSize));

0 commit comments

Comments
 (0)