Skip to content

Commit f1321f9

Browse files
authored
Support phase with expression (#66)
* ready for testing gadget extraction * support plotting circuit during extraction * support extract ZX-diagrams with phase gadgets * add full reduction * fix function name typos * update for Expr * fix test * fix conversion from ZXDiagram to QCircuit * rework for `Phase` * put type args into field * update simplification * update show
1 parent d3033c0 commit f1321f9

12 files changed

Lines changed: 395 additions & 96 deletions

src/ZXCalculus.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module ZXCalculus
22

3+
include("phase.jl")
34
include("abstract_zx_diagram.jl")
45
include("zx_layout.jl")
56
include("zx_diagram.jl")

src/circuit_extraction.jl

Lines changed: 108 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,42 @@ Extract circuit from a graph-like ZX-diagram.
77
"""
88
function circuit_extraction(zxg::ZXGraph{T, P}) where {T, P}
99
nzxg = copy(zxg)
10-
nbits = nqubits(zxg)
10+
nbits = nqubits(nzxg)
11+
gads = Set{T}()
12+
for v in spiders(nzxg)
13+
if spider_type(nzxg, v) == SpiderType.Z && degree(nzxg, v) == 1
14+
push!(gads, v, neighbors(zxg, v)[1])
15+
end
16+
end
1117

12-
cir = ZXDiagram(nbits)
18+
# TODO: extract a QCircuit instead
19+
# cir = QCircuit(nbits)
1320
Outs = get_outputs(nzxg)
1421
Ins = get_inputs(nzxg)
15-
if length(Outs) != length(Ins)
16-
return cir
17-
end
1822
if nbits == 0
1923
nbits = length(Outs)
2024
end
25+
cir = ZXDiagram(nbits)
26+
if length(Outs) != length(Ins)
27+
return cir
28+
end
2129
for v1 in Ins
2230
@inbounds v2 = neighbors(nzxg, v1)[1]
2331
if !is_hadamard(nzxg, v1, v2)
2432
insert_spider!(nzxg, v1, v2)
2533
end
2634
end
2735
@inbounds frontier = [neighbors(nzxg, v)[1] for v in Outs]
28-
29-
extracted = copy(Outs)
36+
qubit_map = Dict(zip(frontier, 1:nbits))
3037

3138
for i = 1:nbits
3239
@inbounds w = neighbors(nzxg, Outs[i])[1]
3340
@inbounds if is_hadamard(nzxg, w, Outs[i])
3441
pushfirst_gate!(cir, Val{:H}(), i)
3542
end
36-
pushfirst_gate!(cir, Val{:Z}(), i, phase(nzxg, w))
37-
set_phase!(nzxg, w, zero(P))
43+
if phase(nzxg, w) != 0
44+
pushfirst_gate!(cir, Val{:Z}(), i, phase(nzxg, w))
45+
set_phase!(nzxg, w, zero(P)) end
3846
@inbounds rem_edge!(nzxg, w, Outs[i])
3947
end
4048
for i = 1:nbits
@@ -47,11 +55,21 @@ function circuit_extraction(zxg::ZXGraph{T, P}) where {T, P}
4755
end
4856
end
4957
end
50-
extracted = [extracted; frontier]
5158

52-
while !isempty(setdiff(spiders(nzxg), extracted))
53-
frontier = update_frontier!(nzxg, frontier, cir)
54-
extracted = union!(extracted, frontier)
59+
old_frontier = copy(frontier)
60+
max_iter = 1000
61+
current_iter = 1
62+
while !isempty(frontier)
63+
update_frontier!(nzxg, gads, frontier, qubit_map, cir)
64+
if frontier != old_frontier
65+
old_frontier = copy(frontier)
66+
current_iter = 1
67+
else
68+
current_iter += 1
69+
if current_iter > max_iter
70+
error("Circuit extraction failed!")
71+
end
72+
end
5573
end
5674

5775
frontier = T[]
@@ -61,8 +79,8 @@ function circuit_extraction(zxg::ZXGraph{T, P}) where {T, P}
6179
push!(frontier, nb[])
6280
end
6381
end
64-
sort!(frontier, by = (v->qubit_loc(nzxg, v)))
65-
M = biadjancency(nzxg, frontier, Ins)
82+
sort!(frontier, by = (v->qubit_map[v]))
83+
M = biadjacency(nzxg, frontier, Ins)
6684
M, steps = gaussian_elimination(M)
6785
for step in steps
6886
if step.op == :addto
@@ -84,28 +102,65 @@ function circuit_extraction(zxg::ZXGraph{T, P}) where {T, P}
84102
end
85103

86104
"""
87-
update_frontier!(zxg, frontier, cir)
105+
update_frontier!(zxg, frontier, qubit_map, cir)
88106
89107
Update frontier. This is a important step in the circuit extraction algorithm.
90108
For more detail, please check the paper [arXiv:1902.03178](https://arxiv.org/abs/1902.03178).
91109
"""
92-
function update_frontier!(zxg::ZXGraph{T, P}, frontier::Vector{T}, cir::ZXDiagram{T, P}) where {T, P}
93-
frontier = frontier[[spider_type(zxg, f) == SpiderType.Z && (degree(zxg, f)) > 0 for f in frontier]]
110+
function update_frontier!(zxg::ZXGraph{T, P}, gads::Set{T}, frontier::Vector{T}, qubit_map::Dict{T, Int}, cir::ZXDiagram{T, P}) where {T, P}
111+
# TODO: use inplace methods
112+
deleteat!(frontier, [spider_type(zxg, f) != SpiderType.Z || (degree(zxg, f)) == 0 for f in frontier])
113+
114+
for i = 1:length(frontier)
115+
v = frontier[i]
116+
nb_v = neighbors(zxg, v)
117+
u = findfirst([u in gads for u in nb_v])
118+
if u !== nothing
119+
u = nb_v[u]
120+
gad_u = zero(T)
121+
for w in neighbors(zxg, u)
122+
if w in gads
123+
gad_u = w
124+
break
125+
end
126+
end
127+
rewrite!(Rule{:pivot}(), zxg, [u, gad_u, v])
128+
pop!(gads, u)
129+
pop!(gads, gad_u)
130+
frontier[i] = u
131+
qubit_map[u] = qubit_map[v]
132+
pushfirst_gate!(cir, Val(:H), qubit_map[u])
133+
delete!(qubit_map, v)
134+
for j = 1:length(frontier)
135+
for k = j+1:length(frontier)
136+
if is_hadamard(zxg, frontier[j], frontier[k])
137+
pushfirst_gate!(cir, Val(:CZ), qubit_map[frontier[j]], qubit_map[frontier[k]])
138+
rem_edge!(zxg, frontier[j], frontier[k])
139+
end
140+
end
141+
end
142+
143+
return frontier
144+
end
145+
end
146+
94147
SetN = Set{T}()
95148
for f in frontier
96149
union!(SetN, neighbors(zxg, f))
97150
end
98151
N = collect(SetN)
99-
sort!(N, by = v -> qubit_loc(zxg, v))
100-
M = biadjancency(zxg, frontier, N)
152+
153+
# TODO: qubit_loc is not necessary here
154+
# sort!(N, by = v -> qubit_loc(zxg, v))
155+
M = biadjacency(zxg, frontier, N)
101156
M0, steps = gaussian_elimination(M)
102157
ws = T[]
103158
@inbounds for i = 1:length(frontier)
104159
if sum(M0[i,:]) == 1
105160
push!(ws, N[findfirst(isone, M0[i,:])])
106161
end
107162
end
108-
M1 = biadjancency(zxg, frontier, ws)
163+
# M1 = biadjacency(zxg, frontier, ws)
109164
@inbounds for e in findall(M .== 1)
110165
if has_edge(zxg, frontier[e[1]], N[e[2]])
111166
rem_edge!(zxg, frontier[e[1]], N[e[2]])
@@ -117,65 +172,60 @@ function update_frontier!(zxg::ZXGraph{T, P}, frontier::Vector{T}, cir::ZXDiagra
117172

118173
@inbounds for step in steps
119174
if step.op == :addto
120-
ctrl = qubit_loc(zxg, frontier[step.r2])
121-
loc = qubit_loc(zxg, frontier[step.r1])
175+
ctrl = qubit_map[frontier[step.r2]]
176+
loc = qubit_map[frontier[step.r1]]
122177
pushfirst_gate!(cir, Val{:CNOT}(), loc, ctrl)
123178
else
124-
q1 = qubit_loc(zxg, frontier[step.r1])
125-
q2 = qubit_loc(zxg, frontier[step.r2])
179+
q1 = qubit_map[frontier[step.r1]]
180+
q2 = qubit_map[frontier[step.r2]]
126181

127182
pushfirst_gate!(cir, Val{:CNOT}(), q2, q1)
128183
pushfirst_gate!(cir, Val{:CNOT}(), q1, q2)
129184
pushfirst_gate!(cir, Val{:CNOT}(), q2, q1)
130185
end
131186
end
132-
old_frontier = copy(frontier)
133-
@inbounds for w in ws
134-
nb_w = neighbors(zxg, w)
135-
v = intersect(nb_w, old_frontier)[1]
136-
if (degree(zxg, v)) == 1
137-
qubit_v = qubit_loc(zxg, v)
138-
qubit_w = qubit_loc(zxg, w)
139-
pushfirst_gate!(cir, Val{:H}(), qubit_v)
140-
if spider_type(zxg, w) == SpiderType.Z
141-
pushfirst_gate!(cir, Val{:Z}(), qubit_v, phase(zxg, w))
187+
188+
for i in 1:length(frontier)
189+
v = frontier[i]
190+
if degree(zxg, v) > 1
191+
continue
192+
end
193+
w = neighbors(zxg, v)[1]
194+
if is_hadamard(zxg, v, w)
195+
pushfirst_gate!(cir, Val(:H), qubit_map[v])
196+
end
197+
if spider_type(zxg, w) == SpiderType.Z
198+
qubit_map[w] = qubit_map[v]
199+
if phase(zxg, w) != 0
200+
pushfirst_gate!(cir, Val{:Z}(), qubit_map[w], phase(zxg, w))
142201
set_phase!(zxg, w, zero(P))
143202
end
144-
if qubit_v != qubit_w && spider_type(zxg, w) == SpiderType.Z
145-
loc_v = column_loc(zxg, v)
146-
loc_w = column_loc(zxg, w)
147-
set_loc!(zxg.layout, w, qubit_v, loc_v)
148-
set_column!(zxg.layout, v, loc_v+1//2)
149-
# deleteat!(zxg.layout.spider_seq[qubit_w], loc_w)
150-
# insert!(zxg.layout.spider_seq[qubit_v], loc_v, w)
151-
end
152203
rem_edge!(zxg, v, w)
153-
if spider_type(zxg, w) == SpiderType.In
154-
add_edge!(zxg, w, v, EdgeType.SIM)
155-
end
156-
deleteat!(frontier, frontier .== v)
157-
push!(frontier, w)
204+
else
205+
rem_edge!(zxg, v, w)
206+
add_edge!(zxg, v, w, EdgeType.SIM)
158207
end
208+
frontier[i] = w
159209
end
160-
@inbounds for i1 = 1:length(ws)
161-
for i2 = i1+1:length(ws)
162-
if has_edge(zxg, ws[i1], ws[i2])
163-
pushfirst_gate!(cir, Val{:CZ}(), qubit_loc(zxg, ws[i1]),
164-
qubit_loc(zxg, ws[i2]))
165-
rem_edge!(zxg, ws[i1], ws[i2])
210+
211+
@inbounds for i1 = 1:length(frontier)
212+
for i2 = i1+1:length(frontier)
213+
if has_edge(zxg, frontier[i1], frontier[i2])
214+
pushfirst_gate!(cir, Val{:CZ}(), qubit_map[frontier[i1]],
215+
qubit_map[frontier[i2]])
216+
rem_edge!(zxg, frontier[i1], frontier[i2])
166217
end
167218
end
168219
end
169-
sort!(frontier, by = v -> qubit_loc(zxg, v))
170220
return frontier
171221
end
172222

173223
"""
174-
biadjancency(zxg, F, N)
224+
biadjacency(zxg, F, N)
175225
176-
Return the biadjancency matrix of `zxg` from vertices in `F` to vertices in `N`.
226+
Return the biadjacency matrix of `zxg` from vertices in `F` to vertices in `N`.
177227
"""
178-
function biadjancency(zxg::ZXGraph{T, P}, F::Vector{T}, N::Vector{T}) where {T, P}
228+
function biadjacency(zxg::ZXGraph{T, P}, F::Vector{T}, N::Vector{T}) where {T, P}
179229
M = zeros(Int, length(F), length(N))
180230

181231
for i = 1:length(F)

src/phase.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import Base: +, -, *, /, ==, isless, rem, convert, zero, one, iseven
2+
import Base: show
3+
4+
"""
5+
Phase
6+
The type supports manipulating phases as expressions.
7+
"""
8+
struct Phase
9+
ex
10+
type
11+
end
12+
13+
Phase(p::T) where {T} = Phase(p, T)
14+
15+
Phase(p::Phase) = p
16+
17+
function show(io::IO, p::Phase)
18+
if p.ex isa Number
19+
print(io, "Phase($(p.ex))")
20+
else
21+
print(io, "Phase($(p.ex)::$(p.type))")
22+
end
23+
end
24+
25+
function +(p1::Phase, p2::Phase)
26+
T1 = p1.type
27+
T2 = p2.type
28+
if p1.ex isa Number && p2.ex isa Number
29+
return Phase(p1.ex + p2.ex)
30+
end
31+
32+
T = Base.promote_op(+, T1, T2)
33+
return Phase(Expr(:call, :+, p1.ex, p2.ex), T)
34+
end
35+
+(p1::Phase, p2::Number) = p1 + Phase(p2)
36+
+(p1::Number, p2::Phase) = Phase(p1) + p2
37+
38+
function -(p1::Phase, p2::Phase)
39+
T1 = p1.type
40+
T2 = p2.type
41+
if p1.ex isa Number && p2.ex isa Number
42+
return Phase(p1.ex - p2.ex)
43+
end
44+
45+
T = Base.promote_op(-, T1, T2)
46+
return Phase(Expr(:call, :-, p1.ex, p2.ex), T)
47+
end
48+
-(p1::Phase, p2::Number) = p1 - Phase(p2)
49+
-(p1::Number, p2::Phase) = Phase(p1) - p2
50+
51+
function *(p1::Phase, p2::Phase)
52+
T1 = p1.type
53+
T2 = p2.type
54+
if p1.ex isa Number && p2.ex isa Number
55+
return Phase(p1.ex * p2.ex)
56+
end
57+
58+
T = Base.promote_op(*, T1, T2)
59+
return Phase(Expr(:call, :*, p1.ex, p2.ex), T)
60+
end
61+
*(p1::Phase, p2::Number) = p1 * Phase(p2)
62+
*(p1::Number, p2::Phase) = Phase(p1) * p2
63+
64+
function /(p1::Phase, p2::Phase)
65+
T1 = p1.type
66+
T2 = p2.type
67+
if p1.ex isa Number && p2.ex isa Number
68+
return Phase(p1.ex / p2.ex)
69+
end
70+
71+
T = Base.promote_op(/, T1, T2)
72+
return Phase(Expr(:call, :/, p1.ex, p2.ex), T)
73+
end
74+
/(p1::Phase, p2::Number) = p1 / Phase(p2)
75+
/(p1::Number, p2::Phase) = Phase(p1) / p2
76+
77+
78+
function -(p::Phase)
79+
T0 = p.type
80+
if p.ex isa Number
81+
return Phase(-p.ex)
82+
end
83+
84+
T = Base.promote_op(-, T0)
85+
return Phase(Expr(:call, :-, p.ex), T)
86+
end
87+
88+
==(p1::Phase, p2::Phase) = p1.ex == p2.ex
89+
==(p1::Phase, p2::Number) = (p1 == Phase(p2))
90+
==(p1::Number, p2::Phase) = (Phase(p1) == p2)
91+
92+
isless(p1::Phase, p2::Number) = (p1.ex isa Number) && p1.ex < p2
93+
function rem(p::Phase, d::Number)
94+
if p.ex isa Number
95+
return Phase(rem(p.ex, d))
96+
end
97+
return p
98+
end
99+
100+
convert(::Type{Phase}, p) = Phase(p)
101+
convert(::Type{Phase}, p::Phase) = p
102+
103+
zero(::Phase) = Phase(0//1)
104+
zero(::Type{Phase}) = Phase(0//1)
105+
one(::Phase) = Phase(1//1)
106+
one(::Type{Phase}) = Phase(1//1)
107+
108+
iseven(p::Phase) = (p.ex isa Number) && (-1)^p.ex > 0

0 commit comments

Comments
 (0)