1+ import pytest
2+ from sympy import Symbol , sqrt , sin as sympy_sin
3+
4+ from ..utility .expression_utilities import (
5+ compute_relative_tolerance_from_significant_decimals ,
6+ convert_absolute_notation ,
7+ convert_unicode_dashes ,
8+ create_expression_set ,
9+ extract_latex ,
10+ find_matching_parenthesis ,
11+ latex_symbols ,
12+ preprocess_expression ,
13+ protect_elementary_functions_substitutions ,
14+ substitute ,
15+ substitute_input_symbols ,
16+ substitutions_sort_key ,
17+ sympy_symbols ,
18+ sympy_to_latex ,
19+ transform_unicode_greek_symbols ,
20+ )
21+
22+
23+ class TestConvertUnicodeDashes :
24+
25+ @pytest .mark .parametrize (
26+ "expr, expected" ,
27+ [
28+ # No dashes at all
29+ ("x+y" , []),
30+ # ASCII hyphen-minus is not a unicode dash — no substitution
31+ ("x-y" , []),
32+ # Empty string
33+ ("" , []),
34+ # HYPHEN (U+2010)
35+ ("x‐y" , [("‐" , "-" )]),
36+ # NON-BREAKING HYPHEN (U+2011)
37+ ("x‑y" , [("‑" , "-" )]),
38+ # FIGURE DASH (U+2012)
39+ ("x‒y" , [("‒" , "-" )]),
40+ # EN DASH (U+2013)
41+ ("x–y" , [("–" , "-" )]),
42+ # EM DASH (U+2014)
43+ ("x—y" , [("—" , "-" )]),
44+ # MINUS SIGN (U+2212)
45+ ("x−y" , [("−" , "-" )]),
46+ # SMALL HYPHEN-MINUS (U+FE63)
47+ ("x﹣y" , [("﹣" , "-" )]),
48+ # FULLWIDTH HYPHEN-MINUS (U+FF0D)
49+ ("x-y" , [("-" , "-" )]),
50+ # Multiple different unicode dashes in one expression
51+ ("x–y−z" , [("–" , "-" ), ("−" , "-" )]),
52+ # Repeated occurrences of the same dash — still one substitution pair
53+ ("x−y−z" , [("−" , "-" )]),
54+ ]
55+ )
56+ def test_convert_unicode_dashes (self , expr , expected ):
57+ result = convert_unicode_dashes (expr )
58+ assert result == expected
59+
60+
61+ class TestConvertAbsoluteNotation :
62+
63+ @pytest .mark .parametrize (
64+ "expr, expected_expr, has_feedback" ,
65+ [
66+ # No pipes — unchanged
67+ ("x+y" , "x+y" , False ),
68+ # Exactly two pipes — simple conversion
69+ ("|x|" , "Abs(x)" , False ),
70+ ("|x+y|" , "Abs(x+y)" , False ),
71+ # Two non-adjacent absolute values
72+ ("|x|+|y|" , "Abs(x)+Abs(y)" , False ),
73+ ]
74+ )
75+ def test_convert_absolute_notation (self , expr , expected_expr , has_feedback ):
76+ result_expr , feedback = convert_absolute_notation (expr , "response" )
77+ assert result_expr == expected_expr
78+ assert (feedback is not None ) == has_feedback
79+
80+ def test_ambiguous_pipes_produce_feedback (self ):
81+ # More than 2 pipes with ambiguous positions produces feedback
82+ expr , feedback = convert_absolute_notation ("|x|y|z|" , "response" )
83+ assert feedback is not None
84+ assert feedback [0 ] == "ABSOLUTE_VALUE_NOTATION_AMBIGUITY"
85+
86+
87+ class TestTransformUnicodeGreekSymbols :
88+
89+ def test_no_greek_symbols_returns_empty (self ):
90+ assert transform_unicode_greek_symbols ("x + y" ) == []
91+
92+ def test_named_greek_symbol_returns_self_substitution (self ):
93+ result = transform_unicode_greek_symbols ("alpha + 1" )
94+ assert ("alpha" , " alpha " ) in result
95+
96+ def test_unicode_greek_alias_maps_to_name (self ):
97+ # α is an alias for "alpha"
98+ result = transform_unicode_greek_symbols ("α" )
99+ assert ("α" , " alpha " ) in result
100+
101+ def test_multiple_greek_symbols (self ):
102+ result = transform_unicode_greek_symbols ("alpha + beta" )
103+ assert ("alpha" , " alpha " ) in result
104+ assert ("beta" , " beta " ) in result
105+
106+ def test_unicode_beta_alias (self ):
107+ result = transform_unicode_greek_symbols ("β" )
108+ assert ("β" , " beta " ) in result
109+
110+
111+ class TestProtectElementaryFunctionsSubstitutions :
112+
113+ def test_no_functions_returns_empty (self ):
114+ assert protect_elementary_functions_substitutions ("x + y" ) == []
115+
116+ def test_sin_generates_self_substitution (self ):
117+ result = protect_elementary_functions_substitutions ("sin(x)" )
118+ assert ("sin" , " sin " ) in result
119+
120+ def test_alias_maps_to_canonical_name (self ):
121+ # arctan is an alias for atan
122+ result = protect_elementary_functions_substitutions ("arctan(x)" )
123+ assert ("arctan" , " atan " ) in result
124+
125+ def test_multiple_functions (self ):
126+ result = protect_elementary_functions_substitutions ("sin(x) + cos(x)" )
127+ assert ("sin" , " sin " ) in result
128+ assert ("cos" , " cos " ) in result
129+
130+
131+ class TestSubstituteInputSymbols :
132+
133+ def test_plain_expression_unchanged (self ):
134+ result = substitute_input_symbols ("x+y" , {})
135+ assert result == ["x+y" ]
136+
137+ def test_lambda_replaced_with_lamda (self ):
138+ result = substitute_input_symbols ("lambda" , {})
139+ assert result == ["lamda" ]
140+
141+ def test_alias_replaced_with_symbol_code (self ):
142+ params = {"symbols" : {"x" : {"latex" : r"\(x\)" , "aliases" : ["x_var" ]}}}
143+ result = substitute_input_symbols (["x_var" ], params )
144+ assert result == ["x" ]
145+
146+ def test_symbol_code_preserved (self ):
147+ params = {"symbols" : {"x" : {"latex" : r"\(x\)" , "aliases" : ["x_var" ]}}}
148+ result = substitute_input_symbols (["x" ], params )
149+ assert result == ["x" ]
150+
151+ def test_list_input_accepted (self ):
152+ result = substitute_input_symbols (["x" , "y" ], {})
153+ assert result == ["x" , "y" ]
154+
155+
156+ class TestFindMatchingParenthesis :
157+
158+ @pytest .mark .parametrize (
159+ "string, index, delimiters, expected" ,
160+ [
161+ # Simple pair
162+ ("(x)" , 0 , None , 2 ),
163+ # Longer content
164+ ("(x+y)" , 0 , None , 4 ),
165+ # Nested — outer pair
166+ ("((x+y))" , 0 , None , 6 ),
167+ # Nested — inner pair
168+ ("((x+y))" , 1 , None , 5 ),
169+ # No closing delimiter
170+ ("(x" , 0 , None , - 1 ),
171+ # Custom square-bracket delimiters
172+ ("[x+y]" , 0 , ("[" , "]" ), 4 ),
173+ # Starts mid-string
174+ ("a(b+c)d" , 1 , None , 5 ),
175+ ]
176+ )
177+ def test_find_matching_parenthesis (self , string , index , delimiters , expected ):
178+ if delimiters is None :
179+ result = find_matching_parenthesis (string , index )
180+ else :
181+ result = find_matching_parenthesis (string , index , delimiters )
182+ assert result == expected
183+
184+
185+ class TestSubstitute :
186+
187+ @pytest .mark .parametrize (
188+ "string, substitutions, expected" ,
189+ [
190+ # Empty substitutions — unchanged
191+ ("hello" , [], "hello" ),
192+ # Single substitution
193+ ("hello world" , [("world" , "earth" )], "hello earth" ),
194+ # Whole string replaced
195+ ("abc" , [("abc" , "xyz" )], "xyz" ),
196+ # No match — unchanged
197+ ("hello" , [("world" , "earth" )], "hello" ),
198+ # Earlier substitution in list wins at the same position
199+ ("abc" , [("a" , "p" ), ("ab" , "q" )], "pbc" ),
200+ # Longer substitution listed first wins
201+ ("abc" , [("ab" , "q" ), ("a" , "p" )], "qc" ),
202+ # Multiple non-overlapping substitutions
203+ ("a b c" , [("a" , "x" ), ("b" , "y" ), ("c" , "z" )], "x y z" ),
204+ # List input is joined into a single result
205+ (["hello" , " " , "world" ], [("world" , "earth" )], "hello earth" ),
206+ ]
207+ )
208+ def test_substitute (self , string , substitutions , expected ):
209+ assert substitute (string , substitutions ) == expected
210+
211+ def test_lookahead_tuple_matches_with_following_context (self ):
212+ # (("sin", ["("]), " sin ") matches "sin(" but not "sin " or "sinx"
213+ subs = [(("sin" , ["(" ]), " sin " )]
214+ assert substitute ("sin(x)" , subs ) == " sin (x)"
215+
216+ def test_lookahead_tuple_does_not_match_without_context (self ):
217+ subs = [(("sin" , ["(" ]), " sin " )]
218+ assert substitute ("sinx" , subs ) == "sinx"
219+
220+
221+ class TestComputeRelativeTolerance :
222+
223+ @pytest .mark .parametrize (
224+ "string, expected" ,
225+ [
226+ # Non-numeric → 0
227+ ("not_a_number" , 0 ),
228+ # 1 sig char, below DEFAULT_SIGNIFICANT_FIGURES=2 floor → 5e-2
229+ ("1" , 5e-2 ),
230+ # "0.5" → chars "05" → lstrip "5" → len 1 → floor applies → 5e-2
231+ ("0.5" , 5e-2 ),
232+ # "1.5" → chars "15" → len 2 → max(2,2)=2 → 5e-2
233+ ("1.5" , 5e-2 ),
234+ # "100" → chars "100" (lstrip removes nothing, '1' stops it) → len 3 → 5e-3
235+ ("100" , 5e-3 ),
236+ # "1.23" → chars "123" (decimal removed) → len 3 → 5e-3
237+ ("1.23" , 5e-3 ),
238+ # Scientific notation: mantissa "1.23" → len 3 → 5e-3
239+ ("1.23e5" , 5e-3 ),
240+ # Negative: "-1.23" → lstrip removes "-" → "123" → len 3 → 5e-3
241+ ("-1.23" , 5e-3 ),
242+ ]
243+ )
244+ def test_relative_tolerance (self , string , expected ):
245+ result = compute_relative_tolerance_from_significant_decimals (string )
246+ assert result == pytest .approx (expected )
247+
248+
249+ class TestSympySymbols :
250+
251+ def test_returns_symbol_objects (self ):
252+ result = sympy_symbols ({"x" : {}, "y" : {}})
253+ assert result == {"x" : Symbol ("x" ), "y" : Symbol ("y" )}
254+
255+ def test_empty_dict (self ):
256+ assert sympy_symbols ({}) == {}
257+
258+ def test_symbol_names_preserved (self ):
259+ result = sympy_symbols ({"alpha" : {}})
260+ assert result ["alpha" ] == Symbol ("alpha" )
261+
262+
263+ class TestExtractLatex :
264+
265+ @pytest .mark .parametrize (
266+ "symbol, expected" ,
267+ [
268+ # LaTeX delimiters removed
269+ (r"\(x^2\)" , "x^2" ),
270+ ("$x^2$" , "x^2" ),
271+ ("$$x^2$$" , "x^2" ),
272+ (r"\(\alpha\)" , r"\alpha" ),
273+ # No delimiters — returned as-is
274+ ("x^2" , "x^2" ),
275+ ("plain" , "plain" ),
276+ ]
277+ )
278+ def test_extract_latex (self , symbol , expected ):
279+ assert extract_latex (symbol ) == expected
280+
281+
282+ class TestLatexSymbols :
283+
284+ def test_maps_symbol_to_latex_string (self ):
285+ syms = {"x" : {"latex" : r"\(x\)" , "aliases" : []}}
286+ result = latex_symbols (syms )
287+ assert result == {Symbol ("x" ): "x" }
288+
289+ def test_greek_latex_preserved (self ):
290+ syms = {"alpha" : {"latex" : r"\(\alpha\)" , "aliases" : []}}
291+ result = latex_symbols (syms )
292+ assert result == {Symbol ("alpha" ): r"\alpha" }
293+
294+ def test_empty_dict (self ):
295+ assert latex_symbols ({}) == {}
296+
297+
298+ class TestSympyToLatex :
299+
300+ def test_simple_power (self ):
301+ expr = Symbol ("x" ) ** 2
302+ syms = {"x" : {"latex" : r"\(x\)" , "aliases" : []}}
303+ result = sympy_to_latex (expr , syms )
304+ assert result == "x^{2}"
305+
306+ def test_custom_latex_name_used (self ):
307+ expr = Symbol ("alpha" )
308+ syms = {"alpha" : {"latex" : r"\(\alpha\)" , "aliases" : []}}
309+ result = sympy_to_latex (expr , syms )
310+ assert result == r"\alpha"
311+
312+ def test_sqrt (self ):
313+ expr = sqrt (Symbol ("x" ))
314+ syms = {"x" : {"latex" : r"\(x\)" , "aliases" : []}}
315+ result = sympy_to_latex (expr , syms )
316+ assert result == r"\sqrt{x}"
317+
318+
319+ class TestSubstitutionsSortKey :
320+
321+ def test_longer_left_element_sorts_first (self ):
322+ long_sub = ("abc" , "p" )
323+ short_sub = ("ab" , "p" )
324+ assert substitutions_sort_key (long_sub ) < substitutions_sort_key (short_sub )
325+
326+ def test_equal_left_length_longer_right_sorts_first (self ):
327+ long_right = ("ab" , "pqr" )
328+ short_right = ("ab" , "p" )
329+ assert substitutions_sort_key (long_right ) < substitutions_sort_key (short_right )
330+
331+ def test_sort_orders_longer_substitutions_first (self ):
332+ subs = [("a" , "x" ), ("abc" , "y" ), ("ab" , "z" )]
333+ subs .sort (key = substitutions_sort_key )
334+ assert subs [0 ][0 ] == "abc"
335+ assert subs [1 ][0 ] == "ab"
336+ assert subs [2 ][0 ] == "a"
337+
338+
339+ class TestCreateExpressionSet :
340+
341+ def test_plain_string_wrapped_in_list (self ):
342+ result = create_expression_set ("x+y" , {})
343+ assert result == ["x+y" ]
344+
345+ def test_set_notation_split_into_list (self ):
346+ result = create_expression_set ("{x, y}" , {})
347+ assert sorted (result ) == ["x" , "y" ]
348+
349+ def test_list_input_accepted (self ):
350+ result = create_expression_set (["x" , "y" ], {})
351+ assert sorted (result ) == ["x" , "y" ]
352+
353+ def test_plus_minus_expands_to_two_expressions (self ):
354+ params = {"plus_minus" : "±" }
355+ result = create_expression_set ("±x" , params )
356+ assert sorted (result ) == sorted (["+x" , "-x" ]) or sorted (result ) == sorted (["x" , "-x" ])
357+ assert len (result ) == 2
358+
359+
360+ class TestPreprocessExpression :
361+
362+ def test_plain_expression_succeeds (self ):
363+ success , expr , feedback = preprocess_expression ("response" , "x+y" , {})
364+ assert success is True
365+ assert expr == "x+y"
366+ assert feedback is None
367+
368+ def test_absolute_value_notation_converted (self ):
369+ success , expr , feedback = preprocess_expression ("response" , "|x|" , {})
370+ assert success is True
371+ assert expr == "Abs(x)"
372+ assert feedback is None
373+
374+ def test_ambiguous_pipes_returns_failure (self ):
375+ success , expr , feedback = preprocess_expression ("response" , "|x|y|z|" , {})
376+ assert success is False
377+ assert feedback is not None
378+ assert feedback [0 ] == "ABSOLUTE_VALUE_NOTATION_AMBIGUITY"
0 commit comments