diff --git a/tests/unit/autojac/_transform/test_tensor_dict.py b/tests/unit/autojac/_transform/test_tensor_dict.py index 10d49d3a1..70707075c 100644 --- a/tests/unit/autojac/_transform/test_tensor_dict.py +++ b/tests/unit/autojac/_transform/test_tensor_dict.py @@ -110,3 +110,20 @@ def _assert_class_checks_properly( def _make_tensor_dict(value_shapes: list[list[int]]) -> dict[Tensor, Tensor]: return {torch.zeros(key): torch.zeros(value) for key, value in zip(_key_shapes, value_shapes)} + + +def test_immutability(): + """Tests that it's impossible to modify an existing TensorDict.""" + + t = Gradients({}) + with raises(TypeError): + t[torch.ones(1)] = torch.ones(1) + + assert t == Gradients({}) + + +def test_empty_tensor_dict(): + """Tests that it's impossible to instantiate a non-empty EmptyTensorDict.""" + + with raises(ValueError): + _ = EmptyTensorDict({torch.ones(1): torch.ones(1)})