@@ -425,6 +425,13 @@ def test_deterministic(self, device):
425425 False ,
426426 2 * 2 * 6 ,
427427 ), # non-defaults
428+ (
429+ {"in_dim" : 3 , "freqs" : torch .tensor ([1.0 , 2.0 , 4.0 ])},
430+ 3 ,
431+ 3 ,
432+ True ,
433+ 3 + 2 * 3 * 3 ,
434+ ), # explicit freqs (num_bands inferred from the schedule length)
428435 ],
429436)
430437def test_fourier_positional_embedding_constructor_attrs (
@@ -435,79 +442,60 @@ def test_fourier_positional_embedding_constructor_attrs(
435442 assert emb .num_bands == exp_num_bands
436443 assert emb .include_input == exp_include_input
437444 assert emb .out_dim == exp_out_dim
438-
439-
440- def test_fourier_positional_embedding_out_dim_and_shape (device ):
441- emb = FourierPositionalEmbedding (in_dim = 3 , num_bands = 4 ).to (device )
442- assert emb .num_bands == 4
443- assert emb .out_dim == 3 + 2 * 3 * 4 # 27
444- y = emb (torch .randn (5 , 3 , device = device ))
445- assert y .shape == (5 , 27 )
446445 # No learnable parameters.
447446 assert sum (p .numel () for p in emb .parameters ()) == 0
448447
449448
450- def test_fourier_positional_embedding_no_include_input (device ):
451- emb = FourierPositionalEmbedding (in_dim = 2 , num_bands = 3 , include_input = False ).to (
452- device
453- )
454- assert emb .out_dim == 2 * 2 * 3 # 12
455- assert emb (torch .zeros (4 , 2 , device = device )).shape == (4 , 12 )
456-
457-
458- def test_fourier_positional_embedding_leading_dims (device ):
459- emb = FourierPositionalEmbedding (in_dim = 3 , num_bands = 2 ).to (device )
460- for shape in [(3 ,), (7 , 3 ), (2 , 7 , 3 )]:
461- out = emb (torch .randn (* shape , device = device ))
462- assert out .shape == (* shape [:- 1 ], emb .out_dim )
463-
464-
465- def test_fourier_positional_embedding_values (device ):
466- # Single coord, single band at base_freq=pi -> [sin(pi*x), cos(pi*x)].
467- emb = FourierPositionalEmbedding (in_dim = 1 , num_bands = 1 , include_input = False ).to (
468- device
469- )
470- x = torch .tensor ([[0.5 ]], device = device )
471- f = math .pi
472- torch .testing .assert_close (
473- emb (x ),
474- torch .tensor ([[math .sin (f * 0.5 ), math .cos (f * 0.5 )]], device = device ),
475- )
476-
477-
478- def test_fourier_positional_embedding_explicit_freqs (device ):
479- emb = FourierPositionalEmbedding (
480- in_dim = 3 , freqs = torch .tensor ([1.0 , 2.0 , 4.0 ]), include_input = True
481- ).to (device )
482- assert emb .num_bands == 3
483- assert emb .out_dim == 3 + 2 * 3 * 3
484- assert emb (torch .randn (8 , 3 , device = device )).shape == (8 , emb .out_dim )
485-
486-
487- def test_fourier_positional_embedding_axis_major_layout (device ):
488- # With include_input=False the layout is axis-major: for each axis, the
489- # num_bands sines followed by the num_bands cosines.
449+ @pytest .mark .parametrize (
450+ "in_dim, freqs, include_input, x, expected" ,
451+ [
452+ # include_input=False, axis-major layout: per axis, sines then cosines.
453+ (
454+ 2 ,
455+ [1.0 , 2.0 ],
456+ False ,
457+ [[0.3 , 0.7 ]],
458+ [
459+ [
460+ math .sin (1.0 * 0.3 ),
461+ math .sin (2.0 * 0.3 ),
462+ math .cos (1.0 * 0.3 ),
463+ math .cos (2.0 * 0.3 ),
464+ math .sin (1.0 * 0.7 ),
465+ math .sin (2.0 * 0.7 ),
466+ math .cos (1.0 * 0.7 ),
467+ math .cos (2.0 * 0.7 ),
468+ ]
469+ ],
470+ ),
471+ # Single coordinate and band.
472+ (
473+ 1 ,
474+ [math .pi ],
475+ False ,
476+ [[0.5 ]],
477+ [[math .sin (math .pi * 0.5 ), math .cos (math .pi * 0.5 )]],
478+ ),
479+ # include_input=True prepends the raw coordinate.
480+ (
481+ 1 ,
482+ [1.0 ],
483+ True ,
484+ [[0.5 ]],
485+ [[0.5 , math .sin (0.5 ), math .cos (0.5 )]],
486+ ),
487+ ],
488+ )
489+ def test_fourier_positional_embedding_forward_values (
490+ device , in_dim , freqs , include_input , x , expected
491+ ):
492+ # Known-reference forward values across configs (layout, single band,
493+ # and include_input prepend).
490494 emb = FourierPositionalEmbedding (
491- in_dim = 2 , num_bands = 2 , include_input = False , freqs = torch .tensor ([ 1.0 , 2.0 ])
495+ in_dim = in_dim , freqs = torch .tensor (freqs ), include_input = include_input
492496 ).to (device )
493- x = torch .tensor ([[0.3 , 0.7 ]], device = device )
494- out = emb (x )
495- expected = torch .tensor (
496- [
497- [
498- math .sin (1.0 * 0.3 ),
499- math .sin (2.0 * 0.3 ),
500- math .cos (1.0 * 0.3 ),
501- math .cos (2.0 * 0.3 ),
502- math .sin (1.0 * 0.7 ),
503- math .sin (2.0 * 0.7 ),
504- math .cos (1.0 * 0.7 ),
505- math .cos (2.0 * 0.7 ),
506- ]
507- ],
508- device = device ,
509- )
510- torch .testing .assert_close (out , expected )
497+ out = emb (torch .tensor (x , device = device ))
498+ torch .testing .assert_close (out , torch .tensor (expected , device = device ))
511499
512500
513501def test_fourier_positional_embedding_validation (device ):
@@ -518,6 +506,9 @@ def test_fourier_positional_embedding_validation(device):
518506 FourierPositionalEmbedding (in_dim = 0 )
519507 with pytest .raises (ValueError ):
520508 FourierPositionalEmbedding (in_dim = 3 , num_bands = 0 )
509+ # Explicit freqs must be 1-D of shape (F,).
510+ with pytest .raises (ValueError ):
511+ FourierPositionalEmbedding (in_dim = 3 , freqs = torch .ones (2 , 3 ))
521512
522513
523514def test_fourier_positional_embedding_state_dict_roundtrip (device ):
@@ -540,12 +531,13 @@ def test_fourier_positional_embedding_forward_accuracy(device):
540531 # MOD-008b: compare the forward output against committed reference data.
541532 model = FourierPositionalEmbedding (in_dim = 3 , num_bands = 4 ).to (device )
542533 model .eval ()
543- # Deterministic, reproducible input (the layer has no random parameters).
544- x = torch .linspace (- 1.0 , 1.0 , steps = 24 , device = device ).reshape (8 , 3 )
534+ # Deterministic, reproducible input; a 3-D shape also exercises arbitrary
535+ # leading (batch) dimensions against the reference.
536+ x = torch .linspace (- 1.0 , 1.0 , steps = 2 * 4 * 3 , device = device ).reshape (2 , 4 , 3 )
545537 assert validate_forward_accuracy (
546538 model ,
547539 (x ,),
548- file_name = "nn/module/data/fourier_positional_embedding_in3_nb4_bs8 .pth" ,
540+ file_name = "nn/module/data/fourier_positional_embedding_in3_nb4_b2x4 .pth" ,
549541 rtol = 1e-4 ,
550542 atol = 1e-4 ,
551543 )
0 commit comments