|
38 | 38 | @test first(d1).x ≈ first(d2).x |
39 | 39 | end |
40 | 40 |
|
| 41 | +@testset "Parametric Anticipative Solver - ContextualStochasticArgmax" begin |
| 42 | + using DecisionFocusedLearningBenchmarks |
| 43 | + |
| 44 | + b = ContextualStochasticArgmaxBenchmark(; n=5, d=3, seed=0) |
| 45 | + dataset = generate_dataset(b, 2; contexts_per_instance=1, nb_scenarios=1) |
| 46 | + sample = first(dataset) |
| 47 | + scenario = generate_scenario(b, StableRNG(0); sample.context...) |
| 48 | + |
| 49 | + solver = generate_anticipative_solver(b) |
| 50 | + parametric_solver = generate_parametric_anticipative_solver(b) |
| 51 | + |
| 52 | + # 1. Zero perturbation equivalence |
| 53 | + θ_zero = zeros(eltype(scenario), size(scenario)) |
| 54 | + @test parametric_solver(θ_zero, scenario; sample.context...) == |
| 55 | + solver(scenario; sample.context...) |
| 56 | + |
| 57 | + # 2. Extreme perturbation |
| 58 | + θ_extreme = zeros(eltype(scenario), size(scenario)) |
| 59 | + θ_extreme[1] = 1000.0 # Force dimension 1 |
| 60 | + y_extreme = parametric_solver(θ_extreme, scenario; sample.context...) |
| 61 | + |
| 62 | + @test y_extreme[1] == 1.0 # Only dimension 1 should be active |
| 63 | + @test sum(y_extreme) ≈ 1.0 # One-hot preserved |
| 64 | +end |
| 65 | + |
41 | 66 | @testset "csa_saa_policy" begin |
42 | 67 | using DecisionFocusedLearningBenchmarks |
43 | 68 |
|
|
0 commit comments