Skip to content

Commit ce4c8b8

Browse files
committed
fixes #2530
1 parent ee515ab commit ce4c8b8

2 files changed

Lines changed: 42 additions & 7 deletions

File tree

coremltools/converters/mil/mil/operation.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -302,15 +302,19 @@ def type_value_inference(self, overwrite_output=False):
302302
raise ValueError(msg)
303303
else:
304304
if types.is_tensor(sym_type) and types.is_complex(sym_type.T[0]):
305-
# Only `complex` op needs to maintain the real/imag data in the ComplexVar.
305+
# Only `complex` and ``const ops need to maintain the real/imag data in the ComplexVar.
306306
# For other ops, this ComplexVar is just a placeholder here, which will be
307307
# replaced by a newly created ComplexVar during complex ops lowering pass.
308-
real_data = (
309-
self.real_data if self.op_type == "complex" else None
310-
)
311-
imag_data = (
312-
self.imag_data if self.op_type == "complex" else None
313-
)
308+
if self.op_type == "complex":
309+
real_data = self.real_data
310+
imag_data = self.imag_data
311+
elif self.op_type == "const":
312+
real_data = np.real(sym_val.val)
313+
imag_data = np.imag(sym_val.val)
314+
else:
315+
real_data = None
316+
imag_data = None
317+
314318
new_var = ComplexVar(
315319
name,
316320
sym_type,

coremltools/converters/mil/mil/tests/test_programs.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from coremltools.converters.mil.mil import Function, Program, types
1414
from coremltools.converters.mil.mil.passes.tests.test_passes import CONSTEXPR_FUNCS
1515
from coremltools.converters.mil.mil.scope import ScopeInfo, ScopeSource, add_graph_pass_scope
16+
from coremltools.converters.mil.mil.var import ComplexVar
1617

1718
np.random.seed(0)
1819

@@ -427,6 +428,36 @@ def test_type_domain_validation():
427428
def prog(x):
428429
res = mb.rsqrt(x=x, epsilon=1)
429430
return res
431+
432+
@staticmethod
433+
def test_const_complex_var_initialization():
434+
"""
435+
Test that a const operation with a complex value correctly initializes
436+
the real and imag parts of the output ComplexVar.
437+
"""
438+
complex_val_np = np.array([[1+2j, 3-4j], [-5+6j, 7+8j]], dtype=np.complex64)
439+
@mb.program(input_specs=[])
440+
def prog():
441+
complex_const_var = mb.const(val=complex_val_np, name="my_complex_const")
442+
return complex_const_var
443+
444+
main_func = prog.functions["main"]
445+
output_var = main_func.outputs[0]
446+
assert isinstance(output_var, ComplexVar), \
447+
f"Output var should be ComplexVar, got {type(output_var)}"
448+
assert output_var.op.op_type == "const", \
449+
f"Expected op_type const, got {output_var.op.op_type}"
450+
expected_real_part = np.real(complex_val_np)
451+
expected_imag_part = np.imag(complex_val_np)
452+
assert output_var.real is not None, "ComplexVar.real should not be None"
453+
assert output_var.imag is not None, "ComplexVar.imag should not be None"
454+
np.testing.assert_array_equal(output_var.real, expected_real_part,
455+
err_msg="Real part of ComplexVar does not match expected.")
456+
np.testing.assert_array_equal(output_var.imag, expected_imag_part,
457+
err_msg="Imaginary part of ComplexVar does not match expected.")
458+
const_op = output_var.op
459+
np.testing.assert_array_equal(const_op.val.val, complex_val_np,
460+
err_msg="Value of const op does not match original complex numpy array.")
430461

431462
@staticmethod
432463
def test_get_dialect_namespaces():

0 commit comments

Comments
 (0)