-
Notifications
You must be signed in to change notification settings - Fork 377
Expand file tree
/
Copy pathstate.py
More file actions
104 lines (87 loc) · 4.17 KB
/
state.py
File metadata and controls
104 lines (87 loc) · 4.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# ============================================================================ #
# Copyright (c) 2022 - 2026 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #
from cudaq.mlir._mlir_libs._quakeDialects import cudaq_runtime
from cudaq.kernel.kernel_decorator import (mk_decorator, isa_kernel_decorator)
from cudaq.handlers import get_target_handler
def get_state(kernel, *args):
    """
    Execute the provided `kernel` and return the resulting :class:`State`
    of the system.

    Args:
      kernel (:class:`Kernel`): The :class:`Kernel` to execute on the QPU.
      *args (Optional[Any]): The concrete values to evaluate the kernel
        function at. Leave empty if the kernel doesn't accept any arguments.

    Returns:
      :class:`State`: The state produced by running `kernel`. Depending on
        the selected target this is a vector or a matrix.

    # Example:
        # `kernel` is a kernel builder created beforehand; it will produce
        # the all `|11...1>` state.
        qubits = kernel.qalloc(3)
        # Prepare qubits in the 1-state.
        kernel.x(qubits)
        # Get the state of the system. This will execute the provided
        # kernel and, depending on the selected target, will return the
        # state as a vector or matrix.
        state = cudaq.get_state(kernel)
        print(state)
    """
    # When the active target handler reports `skip_compilation()`, pass the
    # kernel and its arguments, unprocessed, to the library-mode entry point.
    if get_target_handler().skip_compilation():
        return cudaq_runtime.get_state_library_mode(kernel, *args)

    # Anything that is not already a kernel decorator is wrapped in one so
    # that `prepare_call` and `uniqName` are available below.
    wrapped = kernel if isa_kernel_decorator(kernel) else mk_decorator(kernel)

    call_args, kernel_module = wrapped.prepare_call(*args)
    return cudaq_runtime.get_state_impl(wrapped.uniqName, kernel_module,
                                        *call_args)
def get_state_async(kernel, *args, qpu_id=0):
    """
    Asynchronously retrieve the state generated by the given quantum kernel.

    When targeting a quantum platform with more than one QPU, the optional
    `qpu_id` selects which QPU to use. The call returns a future whose
    result can be retrieved via `future.get()`.

    Args:
      kernel (:class:`Kernel`): The :class:`Kernel` to execute on the QPU.
      *args (Optional[Any]): The concrete values to evaluate the kernel
        function at. Leave empty if the kernel doesn't accept any arguments.
      qpu_id (Optional[int]): The optional identification for which QPU
        on the platform to target. Defaults to zero. Key-word only.

    Returns:
      :class:`AsyncStateResult`: Quantum state data (state vector or density
        matrix).
    """
    # Reuse the kernel when it is already a kernel decorator; otherwise
    # build a decorator around it first.
    wrapped = kernel
    if not isa_kernel_decorator(kernel):
        wrapped = mk_decorator(kernel)

    call_args, kernel_module = wrapped.prepare_call(*args)
    # `qpu_id` is passed ahead of the processed kernel arguments.
    return cudaq_runtime.get_state_async_impl(wrapped.uniqName, kernel_module,
                                              qpu_id, *call_args)
def to_cupy(state, dtype=None):
    """
    A CUDA Quantum state is composed of a list of tensors (e.g. state-vector
    state is composed of a single rank-1 tensor). Map all tensors of `state`
    to CuPy arrays and return them as a list.

    Each array is built on top of the pointer returned by `tensor.data()`
    through `cupy.cuda.UnownedMemory`, with the tensor's `extents` as shape.
    NOTE(review): the memory is registered with `owner=None`, so the arrays
    presumably do not keep `state` alive; callers should hold on to `state`
    for as long as the arrays are in use -- confirm.

    Args:
      state: The state whose tensors should be exposed. It must report
        `is_on_gpu()` as true.
      dtype (Optional): Element type for the resulting arrays. When omitted,
        `cupy.complex128` is used if the current target's precision is
        `SimulationPrecision.fp64`, otherwise `cupy.complex64`.

    Returns:
      list: One CuPy array per tensor of `state`, in `state.getTensors()`
        order.

    Raises:
      ImportError: CuPy is not installed.
      RuntimeError: The state is not on the GPU.
    """
    try:
        import cupy as cp
    except ImportError as err:
        # Everything below needs `cp`. Previously this branch only printed
        # the message and fell through, failing later on the unbound name.
        raise ImportError(
            'to_cupy not supported, CuPy not available. Please install CuPy.'
        ) from err

    if dtype is None:
        # Determine the correct data type based on the cudaq target's precision
        precision = cudaq_runtime.get_target().get_precision()
        if precision == cudaq_runtime.SimulationPrecision.fp64:
            dtype = cp.complex128
        else:
            dtype = cp.complex64

    if not state.is_on_gpu():
        raise RuntimeError(
            "cudaq.to_cupy invoked but the state is not on the GPU.")

    arrays = []
    for tensor in state.getTensors():
        # Size of the tensor's buffer in bytes, as reported by the tensor.
        total_bytes = tensor.get_num_elements() * tensor.get_element_size()
        # Wrap the existing pointer instead of allocating new memory.
        mem = cp.cuda.UnownedMemory(tensor.data(), total_bytes, owner=None)
        memptr = cp.cuda.MemoryPointer(mem, offset=0)
        arrays.append(cp.ndarray(tensor.extents, dtype=dtype, memptr=memptr))
    return arrays