Skip to content

Commit 65cac79

Browse files
committed
⏪ unrelated changes
1 parent 4418db3 commit 65cac79

6 files changed

Lines changed: 83 additions & 156 deletions

File tree

src/mqt/predictor/reward.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
if TYPE_CHECKING:
2525
from qiskit import QuantumCircuit
26+
from qiskit.circuit import QuantumRegister, Qubit
2627
from qiskit.transpiler import Target
2728
from sklearn.ensemble import RandomForestRegressor
2829

@@ -61,22 +62,44 @@ def expected_fidelity(qc: QuantumCircuit, device: Target, precision: int = 10) -
6162

6263
if gate_type != "barrier":
6364
assert len(qargs) in [1, 2]
64-
first_qubit_idx = qc.find_bit(qargs[0]).index
65+
first_qubit_idx = calc_qubit_index(qargs, qc.qregs, 0)
6566

6667
if len(qargs) == 1:
6768
specific_fidelity = 1 - device[gate_type][first_qubit_idx,].error
6869
else:
69-
second_qubit_idx = qc.find_bit(qargs[1]).index
70-
try:
70+
second_qubit_idx = calc_qubit_index(qargs, qc.qregs, 1)
7171
specific_fidelity = 1 - device[gate_type][first_qubit_idx, second_qubit_idx].error
72-
except KeyError:
73-
msg = f"Error rate for gate {gate_type} on qubits {first_qubit_idx} and {second_qubit_idx} not found in device properties."
74-
raise KeyError(msg) from None
72+
7573
res *= specific_fidelity
7674

7775
return float(np.round(res, precision).item())
7876

7977

78+
def calc_qubit_index(qargs: list[Qubit], qregs: list[QuantumRegister], index: int) -> int:
79+
"""Calculates the global qubit index for a given quantum circuit and qubit index.
80+
81+
Arguments:
82+
qargs: The qubits of the quantum circuit.
83+
qregs: The quantum registers of the quantum circuit.
84+
index: The index of the qubit in the qargs list.
85+
86+
Returns:
87+
The global qubit index of the given qubit in the quantum circuit.
88+
89+
Raises:
90+
ValueError: If the qubit index is not found in the quantum registers.
91+
"""
92+
offset = 0
93+
for reg in qregs:
94+
if qargs[index] not in reg:
95+
offset += reg.size
96+
else:
97+
qubit_index: int = offset + reg.index(qargs[index])
98+
return qubit_index
99+
error_msg = f"Global qubit index for local qubit {index} index not found."
100+
raise ValueError(error_msg)
101+
102+
80103
def estimated_success_probability(qc: QuantumCircuit, device: Target, precision: int = 10) -> float:
81104
"""Calculates the estimated success probability of a given quantum circuit on a given device.
82105
@@ -102,7 +125,7 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision:
102125
if gate_type == "barrier" or gate_type == "id":
103126
continue
104127
assert len(qargs) in (1, 2)
105-
first_qubit_idx = qc.find_bit(qargs[0]).index
128+
first_qubit_idx = calc_qubit_index(qargs, qc.qregs, 0)
106129
active_qubits.add(first_qubit_idx)
107130

108131
if len(qargs) == 1: # single-qubit gate
@@ -117,7 +140,7 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision:
117140
))
118141
exec_time_per_qubit[first_qubit_idx] += duration
119142
else: # multi-qubit gate
120-
second_qubit_idx = qc.find_bit(qargs[1]).index
143+
second_qubit_idx = calc_qubit_index(qargs, qc.qregs, 1)
121144
active_qubits.add(second_qubit_idx)
122145
duration = device[gate_type][first_qubit_idx, second_qubit_idx].duration
123146
op_times.append((gate_type, [first_qubit_idx, second_qubit_idx], duration, "s"))
@@ -168,7 +191,7 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision:
168191
continue
169192

170193
assert len(qargs) in (1, 2)
171-
first_qubit_idx = scheduled_circ.find_bit(qargs[0]).index
194+
first_qubit_idx = calc_qubit_index(qargs, qc.qregs, 0)
172195

173196
if len(qargs) == 1:
174197
if gate_type == "measure":
@@ -190,7 +213,7 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision:
190213
continue
191214
res *= 1 - device[gate_type][first_qubit_idx,].error
192215
else:
193-
second_qubit_idx = scheduled_circ.find_bit(qargs[1]).index
216+
second_qubit_idx = calc_qubit_index(qargs, qc.qregs, 1)
194217
res *= 1 - device[gate_type][first_qubit_idx, second_qubit_idx].error
195218

196219
if qiskit_version >= "2.0.0":

src/mqt/predictor/rl/actions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from __future__ import annotations
1212

1313
import os
14+
import sys
1415
from collections import defaultdict
1516
from dataclasses import dataclass
1617
from enum import Enum
@@ -89,8 +90,7 @@
8990
from qiskit.passmanager.base_tasks import Task
9091

9192
TaskList = list[Task | TketBasePass | PreProcessTKETRoutingAfterQiskitLayout]
92-
from qiskit.passmanager import PropertySet
93-
93+
9494

9595
class CompilationOrigin(str, Enum):
9696
"""Enumeration of the origin of the compilation action."""
@@ -146,7 +146,7 @@ class DeviceDependentAction(Action):
146146
Callable[..., tuple[Any, ...] | Circuit],
147147
]
148148
)
149-
do_while: Callable[[PropertySet], bool] | None = None
149+
do_while: Callable[[dict[str, Circuit]], bool] | None = None
150150

151151

152152
# Registry of actions
@@ -332,7 +332,7 @@ def remove_action(name: str) -> None:
332332
circuit,
333333
optimization_level=1 if os.getenv("GITHUB_ACTIONS") == "true" else 2,
334334
synthesis_epsilon=1e-1 if os.getenv("GITHUB_ACTIONS") == "true" else 1e-8,
335-
max_synthesis_size=3,
335+
max_synthesis_size=2 if os.getenv("GITHUB_ACTIONS") == "true" else 3,
336336
seed=10,
337337
num_workers=1 if os.getenv("GITHUB_ACTIONS") == "true" else -1,
338338
),
@@ -431,7 +431,7 @@ def remove_action(name: str) -> None:
431431
with_mapping=True,
432432
optimization_level=1 if os.getenv("GITHUB_ACTIONS") == "true" else 2,
433433
synthesis_epsilon=1e-1 if os.getenv("GITHUB_ACTIONS") == "true" else 1e-8,
434-
max_synthesis_size=3,
434+
max_synthesis_size=2 if os.getenv("GITHUB_ACTIONS") == "true" and sys.platform != "linux" else 3,
435435
seed=10,
436436
num_workers=1 if os.getenv("GITHUB_ACTIONS") == "true" else -1,
437437
)
@@ -461,7 +461,7 @@ def remove_action(name: str) -> None:
461461
model=MachineModel(bqskit_circuit.num_qudits, gate_set=get_bqskit_native_gates(device)),
462462
optimization_level=1 if os.getenv("GITHUB_ACTIONS") == "true" else 2,
463463
synthesis_epsilon=1e-1 if os.getenv("GITHUB_ACTIONS") == "true" else 1e-8,
464-
max_synthesis_size=3,
464+
max_synthesis_size=2 if os.getenv("GITHUB_ACTIONS") == "true" and sys.platform != "linux" else 3,
465465
seed=10,
466466
num_workers=1 if os.getenv("GITHUB_ACTIONS") == "true" else -1,
467467
)

src/mqt/predictor/rl/predictor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,11 @@ def train_model(
9797
verbose: The verbosity level. Defaults to 2.
9898
test: Whether to train the model for testing purposes. Defaults to False.
9999
"""
100-
set_random_seed(0) # for reproducibility
101-
if test:
102-
n_steps = 512
100+
if test:
101+
set_random_seed(0) # for reproducibility
102+
n_steps = 10
103103
n_epochs = 1
104-
batch_size = 16
104+
batch_size = 10
105105
progress_bar = False
106106
else:
107107
# default PPO values

0 commit comments

Comments
 (0)