From ed171e6b21a4ac88e0990704f48fca1296e43999 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 25 Mar 2025 16:42:16 +0100 Subject: [PATCH 1/2] Add test_immutability --- tests/unit/autojac/_transform/test_tensor_dict.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/unit/autojac/_transform/test_tensor_dict.py b/tests/unit/autojac/_transform/test_tensor_dict.py index 10d49d3a1..59f80ce0e 100644 --- a/tests/unit/autojac/_transform/test_tensor_dict.py +++ b/tests/unit/autojac/_transform/test_tensor_dict.py @@ -110,3 +110,13 @@ def _assert_class_checks_properly( def _make_tensor_dict(value_shapes: list[list[int]]) -> dict[Tensor, Tensor]: return {torch.zeros(key): torch.zeros(value) for key, value in zip(_key_shapes, value_shapes)} + + +def test_immutability(): + """Tests that it's impossible to modify an existing TensorDict.""" + + t = Gradients({}) + with raises(TypeError): + t[torch.ones(1)] = torch.ones(1) + + assert t == Gradients({}) From a7bd0c076bdb388ecee97af062c5d6bab0ce5bdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 25 Mar 2025 17:23:39 +0100 Subject: [PATCH 2/2] Add test_empty_tensor_dict --- tests/unit/autojac/_transform/test_tensor_dict.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/unit/autojac/_transform/test_tensor_dict.py b/tests/unit/autojac/_transform/test_tensor_dict.py index 59f80ce0e..70707075c 100644 --- a/tests/unit/autojac/_transform/test_tensor_dict.py +++ b/tests/unit/autojac/_transform/test_tensor_dict.py @@ -120,3 +120,10 @@ def test_immutability(): t[torch.ones(1)] = torch.ones(1) assert t == Gradients({}) + + +def test_empty_tensor_dict(): + """Tests that it's impossible to instantiate a non-empty EmptyTensorDict.""" + + with raises(ValueError): + _ = EmptyTensorDict({torch.ones(1): torch.ones(1)})