Skip to content

add support for complex tensor types in getConstantAttr#2790

Open
mofeing wants to merge 19 commits intoEnzymeAD:mainfrom
mofeing:ss/getConstantAttr-complex
Open

add support for complex tensor types in getConstantAttr#2790
mofeing wants to merge 19 commits intoEnzymeAD:mainfrom
mofeing:ss/getConstantAttr-complex

Conversation

@mofeing
Copy link
Copy Markdown
Collaborator

@mofeing mofeing commented Apr 15, 2026

diff rules that use HLOConstantFP in Enzyme-JAX are broken for complex numbers, due to mlir::enzyme::getConstantAttr not being able to handle them.

for example, the diff rule of stablehlo.rsqrt is defined as the following:

def : HLODerivative<"RsqrtOp", (Op $x), [
    (CheckedDiv (DiffeRet), (Mul (HLOConstantFP<"-2"> $x), (Mul $x, (Sqrt $x))))
]>;

which works for real numbers. but for complex numbers, like the following example,

func.func @rsqrt_complex(%x : tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
  %y = stablehlo.rsqrt %x : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
  func.return %y : tensor<2xcomplex<f32>>
}

it errors:

mofeing@hydra:~/Enzyme-JAX$ ./bazel-bin/enzymexlamlir-opt --enzyme-wrap="infn=rsqrt_complex outfn= retTys=enzyme_dup argTys=enzyme_dup mode=ForwardMode" --enzyme-hlo-opt --cse test/lit_tests/diffrules/stablehlo/rsqrt.mlir 
 unsupported eltype: <<NULL TYPE>> of type tensor<2xcomplex<f32>>

we haven't checked this because we don't test complex numbers on diff rules that do not have a specific behaviour for them (i.e. SelectIfComplex).

this pr adds support for complex numbers in getConstantAttr. the current semantics is to create a ComplexAttr or a DenseElementsAttr with ComplexType eltype with the given value as the real part. not sure if we want to write a way to set the imaginary part.

@mofeing mofeing marked this pull request as draft April 15, 2026 15:21
@mofeing mofeing marked this pull request as ready for review April 15, 2026 15:48
Copy link
Copy Markdown
Member

@wsmoses wsmoses left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a relevant test?

@mofeing mofeing force-pushed the ss/getConstantAttr-complex branch from c45359c to 8628eee Compare May 7, 2026 12:12
@mofeing mofeing changed the title add support for complex and complex tensor types to getConstantAttr add support for complex tensor types in getConstantAttr May 7, 2026
ArrayRef<APFloat>(values));
} else if (auto CET = dyn_cast<ComplexType>(T.getElementType())) {
auto ET = cast<FloatType>(CET.getElementType());
std::complex<APFloat> values[] = {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you need to use mlir::Complex not std::complex

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants