|
13 | 13 | from coremltools.converters.mil.mil import Function, Program, types |
14 | 14 | from coremltools.converters.mil.mil.passes.tests.test_passes import CONSTEXPR_FUNCS |
15 | 15 | from coremltools.converters.mil.mil.scope import ScopeInfo, ScopeSource, add_graph_pass_scope |
| 16 | +from coremltools.converters.mil.mil.var import ComplexVar |
16 | 17 |
|
17 | 18 | np.random.seed(0) |
18 | 19 |
|
@@ -427,6 +428,36 @@ def test_type_domain_validation(): |
427 | 428 | def prog(x): |
428 | 429 | res = mb.rsqrt(x=x, epsilon=1) |
429 | 430 | return res |
| 431 | + |
| 432 | + @staticmethod |
| 433 | + def test_const_complex_var_initialization(): |
| 434 | + """ |
| 435 | + Test that a const operation with a complex value correctly initializes |
| 436 | + the real and imag parts of the output ComplexVar. |
| 437 | + """ |
| 438 | + complex_val_np = np.array([[1+2j, 3-4j], [-5+6j, 7+8j]], dtype=np.complex64) |
| 439 | + @mb.program(input_specs=[]) |
| 440 | + def prog(): |
| 441 | + complex_const_var = mb.const(val=complex_val_np, name="my_complex_const") |
| 442 | + return complex_const_var |
| 443 | + |
| 444 | + main_func = prog.functions["main"] |
| 445 | + output_var = main_func.outputs[0] |
| 446 | + assert isinstance(output_var, ComplexVar), \ |
| 447 | + f"Output var should be ComplexVar, got {type(output_var)}" |
| 448 | + assert output_var.op.op_type == "const", \ |
| 449 | + f"Expected op_type const, got {output_var.op.op_type}" |
| 450 | + expected_real_part = np.real(complex_val_np) |
| 451 | + expected_imag_part = np.imag(complex_val_np) |
| 452 | + assert output_var.real is not None, "ComplexVar.real should not be None" |
| 453 | + assert output_var.imag is not None, "ComplexVar.imag should not be None" |
| 454 | + np.testing.assert_array_equal(output_var.real, expected_real_part, |
| 455 | + err_msg="Real part of ComplexVar does not match expected.") |
| 456 | + np.testing.assert_array_equal(output_var.imag, expected_imag_part, |
| 457 | + err_msg="Imaginary part of ComplexVar does not match expected.") |
| 458 | + const_op = output_var.op |
| 459 | + np.testing.assert_array_equal(const_op.val.val, complex_val_np, |
| 460 | + err_msg="Value of const op does not match original complex numpy array.") |
430 | 461 |
|
431 | 462 | @staticmethod |
432 | 463 | def test_get_dialect_namespaces(): |
|
0 commit comments