Skip to content

Commit ed171e6

Browse files
committed
Add test_immutability
1 parent 5706a93 commit ed171e6

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

tests/unit/autojac/_transform/test_tensor_dict.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,13 @@ def _assert_class_checks_properly(
110110

111111
def _make_tensor_dict(value_shapes: list[list[int]]) -> dict[Tensor, Tensor]:
112112
return {torch.zeros(key): torch.zeros(value) for key, value in zip(_key_shapes, value_shapes)}
113+
114+
115+
def test_immutability():
116+
"""Tests that it's impossible to modify an existing TensorDict."""
117+
118+
t = Gradients({})
119+
with raises(TypeError):
120+
t[torch.ones(1)] = torch.ones(1)
121+
122+
assert t == Gradients({})

0 commit comments

Comments
 (0)