Skip to content

Commit dd2325d

Browse files
authored
test(autojac): Add missing TensorDict tests (#271)
* Add test_immutability * Add test_empty_tensor_dict
1 parent 5706a93 commit dd2325d

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

tests/unit/autojac/_transform/test_tensor_dict.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,20 @@ def _assert_class_checks_properly(
110110

111111
def _make_tensor_dict(value_shapes: list[list[int]]) -> dict[Tensor, Tensor]:
112112
return {torch.zeros(key): torch.zeros(value) for key, value in zip(_key_shapes, value_shapes)}
113+
114+
115+
def test_immutability():
116+
"""Tests that it's impossible to modify an existing TensorDict."""
117+
118+
t = Gradients({})
119+
with raises(TypeError):
120+
t[torch.ones(1)] = torch.ones(1)
121+
122+
assert t == Gradients({})
123+
124+
125+
def test_empty_tensor_dict():
126+
"""Tests that it's impossible to instantiate a non-empty EmptyTensorDict."""
127+
128+
with raises(ValueError):
129+
_ = EmptyTensorDict({torch.ones(1): torch.ones(1)})

0 commit comments

Comments
 (0)