Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 214 additions & 10 deletions frontend/catalyst/from_plxpr/qref_jax_primitives.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2026 Xanadu Quantum Technologies Inc.

Check notice on line 1 in frontend/catalyst/from_plxpr/qref_jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr/qref_jax_primitives.py#L1

Too many lines in module (1102/1000) (too-many-lines)

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -811,6 +811,14 @@
return []


def _is_custom_op(op_cls, params):
if op_cls.static_argnames or op_cls.hybrid_argnames or op_cls.compilable_argnames:
return False
if op_cls.wire_argnames != ("wires",):
return False
return all(p.shape == () and "float" in p.dtype.name for p in params)


def _qref_operator_p_lowering(
jax_ctx: mlir.LoweringRuleContext,
*args,
Expand All @@ -820,6 +828,18 @@
wire_lens,
**static_data,
):
ctx = jax_ctx.module_context.context
ctx.allow_unregistered_dialects = True
if op_cls.__name__ in _SPECIAL_LOWERINGS:
return _SPECIAL_LOWERINGS[op_cls.__name__](
jax_ctx,
*args,
op_cls=op_cls,
hybrid_lens=hybrid_lens,
hybrid_trees=hybrid_trees,
wire_lens=wire_lens,
**static_data,
)
params = args[: len(op_cls.dynamic_argnames)]
qubits = args[len(op_cls.dynamic_argnames) :]

Expand All @@ -841,24 +861,208 @@

processed_qubit_map = get_mlir_attribute_from_pyval(qubit_map)

OperatorOp(
op_name=name_attr,
params=params,
ctrl_qubits = []
ctrl_values = []
adjoint = False

if _is_custom_op(op_cls, params):
params = [extract_scalar(safe_cast_to_f64(p, op_cls), op_cls) for p in params]
CustomOp(
params=params,
qubits=qubits,
gate_name=name_attr,
ctrl_qubits=ctrl_qubits,
ctrl_values=ctrl_values,
adjoint=adjoint,
)
else:
OperatorOp(
op_name=name_attr,
params=params,
qubits=qubits,
qreg=None,
forward_args=[],
ctrl_qubits=[],
ctrl_values=[],
adjoint=False,
UID=None,
arr_qubit_indices=[],
param_map=processed_param_map,
static_data=processed_static_data,
qubit_map=processed_qubit_map,
)
return []


_SPECIAL_LOWERINGS = {}


def _register_special_lowering(op_name):
def decorator(f):
_SPECIAL_LOWERINGS[op_name] = f
return f

return decorator


@_register_special_lowering("MultiRZ")
def _multirz_lowering(
jax_ctx: mlir.LoweringRuleContext,
*args,
op_cls,
hybrid_lens,
hybrid_trees,
wire_lens
):
theta = (extract_scalar(safe_cast_to_f64(args[0], "MultiRZ"), "MultiRZ"),)
qubits = args[1:]
MultiRZOp(
theta=theta,
qubits=qubits,
qreg=None,
forward_args=[],
ctrl_qubits=[],
ctrl_values=[],
adjoint=False,
UID=None,
arr_qubit_indices=[],
param_map=processed_param_map,
static_data=processed_static_data,
qubit_map=processed_qubit_map,
)
return []


@_register_special_lowering("PCPhase")
def _pcphase_lowering(
jax_ctx: mlir.LoweringRuleContext,
*args,
op_cls,
hybrid_lens,
hybrid_trees,
wire_lens,
):
qubits = args[2:]
PCPhaseOp(
theta=extract_scalar(safe_cast_to_f64(args[0], "PCPhase"), "PCPhase"),
dim=extract_scalar(safe_cast_to_f64(args[0], "PCPhase"), "PCPhase"),
qubits=qubits,
ctrl_qubits=[],
ctrl_values=[],
adjoint=False,
)
return ()


@_register_special_lowering("GlobalPhase")
def _special_gphase_lowering(
jax_ctx: mlir.LoweringRuleContext,
*args,
op_cls,
hybrid_lens,
hybrid_trees,
wire_lens,
):
GlobalPhaseOp(
angle=extract_scalar(safe_cast_to_f64(args[0], "GlobalPhase"), "GlobalPhase"),
ctrl_qubits=[],
ctrl_values=[],
adjoint=False,
)
return ()


@_register_special_lowering("QubitUnitary")
def _special_unitary_lowering(
jax_ctx: mlir.LoweringRuleContext,
matrix,
*qubits,
op_cls,
hybrid_lens,
hybrid_trees,
wire_lens,
):
ctrl_qubits = []
ctrl_values = []

for q in qubits:
assert ir.OpaqueType.isinstance(q.type)
assert ir.OpaqueType(q.type).dialect_namespace == "qref"
assert ir.OpaqueType(q.type).data == "bit"

matrix_type = matrix.type
is_tensor = ir.RankedTensorType.isinstance(matrix_type)
shape = ir.RankedTensorType(matrix_type).shape if is_tensor else None
is_2d_tensor = len(shape) == 2 if is_tensor else False
if not is_2d_tensor:
raise TypeError("QubitUnitary must be a 2 dimensional tensor.")

possibly_complex_type = ir.RankedTensorType(matrix_type).element_type
is_complex = ir.ComplexType.isinstance(possibly_complex_type)
is_f64_type = False

if is_complex:
complex_type = ir.ComplexType(possibly_complex_type)
possibly_f64_type = complex_type.element_type
is_f64_type = ir.F64Type.isinstance(possibly_f64_type)

is_complex_f64_type = is_complex and is_f64_type
if not is_complex_f64_type:
f64_type = ir.F64Type.get()
complex_f64_type = ir.ComplexType.get(f64_type)
tensor_complex_f64_type = ir.RankedTensorType.get(shape, complex_f64_type)
matrix = StableHLOConvertOp(tensor_complex_f64_type, matrix).result

ctrl_values_i1 = [
TensorExtractOp(ir.IntegerType.get_signless(1), v, []).result for v in ctrl_values
]

QubitUnitaryOp(
matrix=matrix,
qubits=qubits,
ctrl_qubits=ctrl_qubits,
ctrl_values=ctrl_values_i1,
adjoint=False,
)

return ()


@_register_special_lowering("PauliRot")
def _special_paulirot_lowering(
jax_ctx: mlir.LoweringRuleContext,
angle,
*qubits,
op_cls,
hybrid_lens,
hybrid_trees,
wire_lens,
pauli_word,
):
pauli_word = unflatten(*pauli_word)
ctrl_qubits = []
ctrl_values = []

for q in qubits:
assert ir.OpaqueType.isinstance(q.type)
assert ir.OpaqueType(q.type).dialect_namespace == "qref"
assert ir.OpaqueType(q.type).data == "bit"

angle = safe_cast_to_f64(angle, "PauliRot")
angle = extract_scalar(angle, "PauliRot")
assert ir.F64Type.isinstance(angle.type)

pauli_word = ir.ArrayAttr.get([ir.StringAttr.get(p) for p in pauli_word])

ctrl_values_i1 = [
TensorExtractOp(ir.IntegerType.get_signless(1), v, []).result for v in ctrl_values
]

PauliRotOp(
angle=angle,
pauli_product=pauli_word,
qubits=qubits,
ctrl_qubits=ctrl_qubits,
ctrl_values=ctrl_values_i1,
adjoint=False,
)

return ()


#
# hermitian observable
#
Expand Down
Loading