Skip to content

Commit 34e6682

Browse files
committed
address review comments
1 parent da24d7e commit 34e6682

4 files changed

Lines changed: 76 additions & 78 deletions

File tree

physicsnemo/nn/module/embedding_layers.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,10 @@ class FourierPositionalEmbedding(Module):
299299
freq_scale : float, optional, default=2.0
300300
Geometric ratio between consecutive band frequencies for the
301301
generated schedule.
302-
freqs : torch.Tensor or sequence of float, optional, default=None
303-
Explicit 1-D frequency schedule. Overrides ``num_bands``,
304-
``base_freq`` and ``freq_scale`` when provided.
302+
freqs : torch.Tensor, optional, default=None
303+
Explicit 1-D frequency schedule of shape :math:`(F,)` with
304+
:math:`F \geq 1`. Overrides ``num_bands``, ``base_freq`` and
305+
``freq_scale`` when provided. A non-1-D ``freqs`` raises ``ValueError``.
305306
306307
Forward
307308
-------
@@ -334,7 +335,7 @@ def __init__(
334335
include_input: bool = True,
335336
base_freq: float = math.pi,
336337
freq_scale: float = 2.0,
337-
freqs: Tensor | None = None,
338+
freqs: Float[Tensor, "num_freqs"] | None = None, # noqa: F821
338339
):
339340
super().__init__()
340341
if in_dim < 1:
@@ -346,9 +347,12 @@ def __init__(
346347
num_bands, dtype=torch.float32
347348
)
348349
else:
349-
freqs = torch.as_tensor(freqs, dtype=torch.float32).flatten()
350-
if freqs.numel() < 1:
351-
raise ValueError("freqs must contain at least one frequency.")
350+
freqs = freqs.to(torch.float32)
351+
if freqs.ndim != 1 or freqs.numel() < 1:
352+
raise ValueError(
353+
"freqs must be a 1-D tensor of shape (F,) with F >= 1, "
354+
f"got shape {tuple(freqs.shape)}."
355+
)
352356
self.in_dim = int(in_dim)
353357
self.include_input = bool(include_input)
354358
# Persistent so an explicitly supplied ``freqs`` schedule survives a
@@ -367,7 +371,9 @@ def out_dim(self) -> int:
367371
base = self.in_dim if self.include_input else 0
368372
return base + 2 * self.in_dim * self.num_bands
369373

370-
def forward(self, x: Float[Tensor, "... in_dim"]) -> Float[Tensor, "... out_dim"]:
374+
def forward(
375+
self, x: Float[Tensor, "*dims in_dim"]
376+
) -> Float[Tensor, "*dims out_dim"]:
371377
r"""Encode coordinates ``x``; see the class docstring for shapes."""
372378
# Skip validation when running under torch.compile (MOD-005).
373379
if not torch.compiler.is_compiling():
Binary file not shown.
-2.77 KB
Binary file not shown.

test/nn/module/test_embedding_layers.py

Lines changed: 62 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,13 @@ def test_deterministic(self, device):
425425
False,
426426
2 * 2 * 6,
427427
), # non-defaults
428+
(
429+
{"in_dim": 3, "freqs": torch.tensor([1.0, 2.0, 4.0])},
430+
3,
431+
3,
432+
True,
433+
3 + 2 * 3 * 3,
434+
), # explicit freqs (num_bands inferred from the schedule length)
428435
],
429436
)
430437
def test_fourier_positional_embedding_constructor_attrs(
@@ -435,79 +442,60 @@ def test_fourier_positional_embedding_constructor_attrs(
435442
assert emb.num_bands == exp_num_bands
436443
assert emb.include_input == exp_include_input
437444
assert emb.out_dim == exp_out_dim
438-
439-
440-
def test_fourier_positional_embedding_out_dim_and_shape(device):
441-
emb = FourierPositionalEmbedding(in_dim=3, num_bands=4).to(device)
442-
assert emb.num_bands == 4
443-
assert emb.out_dim == 3 + 2 * 3 * 4 # 27
444-
y = emb(torch.randn(5, 3, device=device))
445-
assert y.shape == (5, 27)
446445
# No learnable parameters.
447446
assert sum(p.numel() for p in emb.parameters()) == 0
448447

449448

450-
def test_fourier_positional_embedding_no_include_input(device):
451-
emb = FourierPositionalEmbedding(in_dim=2, num_bands=3, include_input=False).to(
452-
device
453-
)
454-
assert emb.out_dim == 2 * 2 * 3 # 12
455-
assert emb(torch.zeros(4, 2, device=device)).shape == (4, 12)
456-
457-
458-
def test_fourier_positional_embedding_leading_dims(device):
459-
emb = FourierPositionalEmbedding(in_dim=3, num_bands=2).to(device)
460-
for shape in [(3,), (7, 3), (2, 7, 3)]:
461-
out = emb(torch.randn(*shape, device=device))
462-
assert out.shape == (*shape[:-1], emb.out_dim)
463-
464-
465-
def test_fourier_positional_embedding_values(device):
466-
# Single coord, single band at base_freq=pi -> [sin(pi*x), cos(pi*x)].
467-
emb = FourierPositionalEmbedding(in_dim=1, num_bands=1, include_input=False).to(
468-
device
469-
)
470-
x = torch.tensor([[0.5]], device=device)
471-
f = math.pi
472-
torch.testing.assert_close(
473-
emb(x),
474-
torch.tensor([[math.sin(f * 0.5), math.cos(f * 0.5)]], device=device),
475-
)
476-
477-
478-
def test_fourier_positional_embedding_explicit_freqs(device):
479-
emb = FourierPositionalEmbedding(
480-
in_dim=3, freqs=torch.tensor([1.0, 2.0, 4.0]), include_input=True
481-
).to(device)
482-
assert emb.num_bands == 3
483-
assert emb.out_dim == 3 + 2 * 3 * 3
484-
assert emb(torch.randn(8, 3, device=device)).shape == (8, emb.out_dim)
485-
486-
487-
def test_fourier_positional_embedding_axis_major_layout(device):
488-
# With include_input=False the layout is axis-major: for each axis, the
489-
# num_bands sines followed by the num_bands cosines.
449+
@pytest.mark.parametrize(
450+
"in_dim, freqs, include_input, x, expected",
451+
[
452+
# include_input=False, axis-major layout: per axis, sines then cosines.
453+
(
454+
2,
455+
[1.0, 2.0],
456+
False,
457+
[[0.3, 0.7]],
458+
[
459+
[
460+
math.sin(1.0 * 0.3),
461+
math.sin(2.0 * 0.3),
462+
math.cos(1.0 * 0.3),
463+
math.cos(2.0 * 0.3),
464+
math.sin(1.0 * 0.7),
465+
math.sin(2.0 * 0.7),
466+
math.cos(1.0 * 0.7),
467+
math.cos(2.0 * 0.7),
468+
]
469+
],
470+
),
471+
# Single coordinate and band.
472+
(
473+
1,
474+
[math.pi],
475+
False,
476+
[[0.5]],
477+
[[math.sin(math.pi * 0.5), math.cos(math.pi * 0.5)]],
478+
),
479+
# include_input=True prepends the raw coordinate.
480+
(
481+
1,
482+
[1.0],
483+
True,
484+
[[0.5]],
485+
[[0.5, math.sin(0.5), math.cos(0.5)]],
486+
),
487+
],
488+
)
489+
def test_fourier_positional_embedding_forward_values(
490+
device, in_dim, freqs, include_input, x, expected
491+
):
492+
# Known-reference forward values across configs (layout, single band,
493+
# and include_input prepend).
490494
emb = FourierPositionalEmbedding(
491-
in_dim=2, num_bands=2, include_input=False, freqs=torch.tensor([1.0, 2.0])
495+
in_dim=in_dim, freqs=torch.tensor(freqs), include_input=include_input
492496
).to(device)
493-
x = torch.tensor([[0.3, 0.7]], device=device)
494-
out = emb(x)
495-
expected = torch.tensor(
496-
[
497-
[
498-
math.sin(1.0 * 0.3),
499-
math.sin(2.0 * 0.3),
500-
math.cos(1.0 * 0.3),
501-
math.cos(2.0 * 0.3),
502-
math.sin(1.0 * 0.7),
503-
math.sin(2.0 * 0.7),
504-
math.cos(1.0 * 0.7),
505-
math.cos(2.0 * 0.7),
506-
]
507-
],
508-
device=device,
509-
)
510-
torch.testing.assert_close(out, expected)
497+
out = emb(torch.tensor(x, device=device))
498+
torch.testing.assert_close(out, torch.tensor(expected, device=device))
511499

512500

513501
def test_fourier_positional_embedding_validation(device):
@@ -518,6 +506,9 @@ def test_fourier_positional_embedding_validation(device):
518506
FourierPositionalEmbedding(in_dim=0)
519507
with pytest.raises(ValueError):
520508
FourierPositionalEmbedding(in_dim=3, num_bands=0)
509+
# Explicit freqs must be 1-D of shape (F,).
510+
with pytest.raises(ValueError):
511+
FourierPositionalEmbedding(in_dim=3, freqs=torch.ones(2, 3))
521512

522513

523514
def test_fourier_positional_embedding_state_dict_roundtrip(device):
@@ -540,12 +531,13 @@ def test_fourier_positional_embedding_forward_accuracy(device):
540531
# MOD-008b: compare the forward output against committed reference data.
541532
model = FourierPositionalEmbedding(in_dim=3, num_bands=4).to(device)
542533
model.eval()
543-
# Deterministic, reproducible input (the layer has no random parameters).
544-
x = torch.linspace(-1.0, 1.0, steps=24, device=device).reshape(8, 3)
534+
# Deterministic, reproducible input; a 3-D shape also exercises arbitrary
535+
# leading (batch) dimensions against the reference.
536+
x = torch.linspace(-1.0, 1.0, steps=2 * 4 * 3, device=device).reshape(2, 4, 3)
545537
assert validate_forward_accuracy(
546538
model,
547539
(x,),
548-
file_name="nn/module/data/fourier_positional_embedding_in3_nb4_bs8.pth",
540+
file_name="nn/module/data/fourier_positional_embedding_in3_nb4_b2x4.pth",
549541
rtol=1e-4,
550542
atol=1e-4,
551543
)

0 commit comments

Comments
 (0)