We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 5706a93 commit ed171e6Copy full SHA for ed171e6
1 file changed
tests/unit/autojac/_transform/test_tensor_dict.py
@@ -110,3 +110,13 @@ def _assert_class_checks_properly(
110
111
def _make_tensor_dict(value_shapes: list[list[int]]) -> dict[Tensor, Tensor]:
112
return {torch.zeros(key): torch.zeros(value) for key, value in zip(_key_shapes, value_shapes)}
113
+
114
115
+def test_immutability():
116
+ """Tests that it's impossible to modify an existing TensorDict."""
117
118
+ t = Gradients({})
119
+ with raises(TypeError):
120
+ t[torch.ones(1)] = torch.ones(1)
121
122
+ assert t == Gradients({})
0 commit comments