diff --git a/src/braket/parametric/free_parameter_expression.py b/src/braket/parametric/free_parameter_expression.py index 53fc13305..5c53796a3 100644 --- a/src/braket/parametric/free_parameter_expression.py +++ b/src/braket/parametric/free_parameter_expression.py @@ -54,6 +54,7 @@ def __init__(self, expression: FreeParameterExpression | Number | sympy.Expr | s ast.Add: self.__add__, ast.Sub: self.__sub__, ast.Mult: self.__mul__, + ast.Div: self.__truediv__, ast.Pow: self.__pow__, ast.USub: self.__neg__, } diff --git a/test/unit_tests/braket/parametric/test_free_parameter_expression.py b/test/unit_tests/braket/parametric/test_free_parameter_expression.py index da0ed39d8..b9ffa44aa 100644 --- a/test/unit_tests/braket/parametric/test_free_parameter_expression.py +++ b/test/unit_tests/braket/parametric/test_free_parameter_expression.py @@ -52,9 +52,16 @@ def test_equality_str(): assert hasattr(expr_1.expression, "free_symbols") and hasattr(expr_2.expression, "free_symbols") +def test_truediv_str(): + FreeParameterExpression("theta/1") + expr_1 = FreeParameterExpression("theta/alpha") + expr_2 = FreeParameterExpression(FreeParameter("theta") / FreeParameter("alpha")) + assert expr_1 == expr_2 + + @pytest.mark.xfail(raises=ValueError) def test_unsupported_bin_op_str(): - FreeParameterExpression("theta/1") + FreeParameterExpression("theta//1") @pytest.mark.xfail(raises=ValueError)