Skip to content

Commit 26f5f2d

Browse files
committed
fix parallel reach plot bug
1 parent 74a9ab9 commit 26f5f2d

5 files changed

Lines changed: 55 additions & 27 deletions

File tree

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
11
import numpy as np
22
from pybdr.geometry import Interval, Zonotope, Geometry
33
from pybdr.geometry.operation import cvt2, partition, boundary
4-
from pybdr.algorithm import ASB2008CDC
4+
from pybdr.algorithm import ASB2008CDCParallel
55
from pybdr.dynamic_system import NonLinSys
66
from pybdr.model import *
77
from pybdr.util.visualization import plot, plot_cmp
88
from pybdr.util.functional import performance_counter, performance_counter_start
99

1010
if __name__ == '__main__':
11+
# tuples_list = [(1, 'a', 'x', 'u'), (2, 'b', 'y', 'v'), (3, 'c', 'z', 'w')]
12+
#
13+
# lists_list = [list(t) for t in zip(*tuples_list)]
14+
#
15+
# print(lists_list)
16+
#
17+
# exit(False)
1118
time_start = performance_counter_start()
1219
# init dynamic system
13-
system = NonLinSys(Model(lotka_volterra_2d, [2, 1]))
20+
# system = NonLinSys(Model(lotka_volterra_2d, [2, 1]))
1421

1522
# settings for the computation
16-
options = ASB2008CDC.Options()
23+
options = ASB2008CDCParallel.Options()
1724
options.t_end = 2.2
1825
options.step = 0.005
1926
options.tensor_order = 3
@@ -26,29 +33,37 @@
2633
Zonotope.REDUCE_METHOD = Zonotope.REDUCE_METHOD.GIRARD
2734
Zonotope.ORDER = 50
2835

29-
z = Interval.identity(2) * 0.5 + 3
36+
init_set = Interval.identity(2) * 0.5 + 3
3037

31-
options.r0 = [cvt2(z, Geometry.TYPE.ZONOTOPE)]
32-
_, tp_whole, _, _ = ASB2008CDC.reach(system, options)
38+
_, tp_whole = ASB2008CDCParallel.reach(lotka_volterra_2d, [2, 1], options,
39+
cvt2(init_set, Geometry.TYPE.ZONOTOPE))
3340

41+
# NAIVE PARTITION
3442
# --------------------------------------------------------
3543
# options.r0 = partition(z, 1, Geometry.TYPE.ZONOTOPE)
44+
# xs = partition(init_set, 1, Geometry.TYPE.ZONOTOPE)
3645
# 4
37-
# ASB2008CDC cost: 43.344963666000005s
46+
# ASB2008CDCParallel.reach_parallel cost: 20.126129667s
3847
# --------------------------------------------------------
39-
# options.r0 = partition(z, 0.5, Geometry.TYPE.ZONOTOPE)
48+
# xs = partition(init_set, 0.5, Geometry.TYPE.ZONOTOPE)
4049
# 9
41-
# ASB2008CDC cost: 3868912500001s
50+
# ASB2008CDCParallel.reach_parallel cost: 23.938516459000002s
4251
# --------------------------------------------------------
43-
options.r0 = partition(z, 0.2, Geometry.TYPE.ZONOTOPE)
52+
xs = partition(init_set, 0.2, Geometry.TYPE.ZONOTOPE)
4453
# 36
45-
# ASB2008CDC cost: 317.59988937500003s
54+
# ASB2008CDCParallel.reach_parallel cost: 65.447113125s
4655
# --------------------------------------------------------
47-
# options.r0 = partition(z, 0.1, Geometry.TYPE.ZONOTOPE)
4856

49-
print(len(options.r0))
57+
# BOUNDAYR ANALYSIS
58+
# --------------------------------------------------------
59+
xs = boundary(init_set, 1, Geometry.TYPE.ZONOTOPE)
60+
# 8
61+
# ASB2008CDCParallel.reach_parallel cost: 22.185758250000003s
62+
63+
print(len(xs))
64+
65+
_, tp_part_00 = ASB2008CDCParallel.reach_parallel(lotka_volterra_2d, [2, 1], options, xs)
5066

51-
_, tp_part_00, _, _ = ASB2008CDC.reach(system, options)
52-
performance_counter(time_start, "ASB2008CDC")
67+
performance_counter(time_start, "ASB2008CDCParallel.reach_parallel")
5368

5469
plot_cmp([tp_whole, tp_part_00], [0, 1], cs=["#FF5722", "#303F9F"])

pybdr/algorithm/asb2008cdc_parallel.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,10 @@ def reach(cls, dyn: Callable, dims, opts: Options, x: Zonotope):
189189

190190
@classmethod
191191
def reach_parallel(cls, dyn: Callable, dims, opts: Options, xs: [Zonotope]):
192+
193+
def ll_decompose(ll):
194+
return [list(group) for group in zip(*ll)]
195+
192196
# init containers for storing the results
193197
ri = []
194198

@@ -203,6 +207,6 @@ def reach_parallel(cls, dyn: Callable, dims, opts: Options, xs: [Zonotope]):
203207
except Exception as e:
204208
raise e
205209

206-
ri = [list(group) for group in zip(*ri)]
210+
ri = ll_decompose(ri)
207211

208-
return ri
212+
return ll_decompose(ri[0]), ll_decompose(ri[1])

pybdr/algorithm/reach_linear_zono_algo3_parallel.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,9 @@ def reach(cls, lin_sys: LinearSystemSimple, opts: Settings, x: Zonotope):
161161
@classmethod
162162
def reach_parallel(cls, lin_sys: LinearSystemSimple, opts: Settings, xs: [Zonotope]):
163163

164+
def ll_decompose(ll):
165+
return [list(group) for group in zip(*ll)]
166+
164167
with ProcessPoolExecutor() as executor:
165168
partial_reach = partial(cls.reach, lin_sys, opts)
166169

@@ -174,6 +177,6 @@ def reach_parallel(cls, lin_sys: LinearSystemSimple, opts: Settings, xs: [Zonoto
174177
except Exception as exc:
175178
raise exc
176179

177-
ri = [list(group) for group in zip(*ri)]
180+
ri = ll_decompose(ri)
178181

179-
return ri
182+
return ri[0], ll_decompose(ri[1]), ri[2], ri[3]

pybdr/util/visualization/plot.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,13 @@ def __2d_plot_cmp(collections, dims, width, height, xlim, ylim, cs, filled):
139139
for i in range(len(collections)):
140140
this_color = plt.cm.turbo(i / len(collections)) if cs is None else cs[i]
141141

142-
geos = (
143-
[collections[i]]
144-
if not isinstance(collections[i], list)
145-
else list(itertools.chain.from_iterable(collections[i]))
146-
)
142+
if isinstance(collections[i][0], list):
143+
geos = list(itertools.chain.from_iterable(collections[i]))
144+
else:
145+
geos = collections[i]
146+
147+
# geos = [collections[i]] if not isinstance(collections[i], list) else list(
148+
# itertools.chain.from_iterable(collections[i]))
147149

148150
for geo in geos:
149151
if isinstance(geo, np.ndarray):

test/algorithm/test_reach_linear_zono_algo3_parallel.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,22 @@ def test_reach_linear_zono_algo3_case_00():
1616
lin_sys = LinearSystemSimple(xa, ub)
1717
opts = ReachLinearZonoAlgo3Parallel.Settings()
1818
opts.t_end = 5
19-
opts.step = 0.04
19+
opts.step = 0.01
2020
opts.eta = 4
2121
# opts.x0 = cvt2(Interval([0.9, 0.9], [1.1, 1.1]), Geometry.TYPE.ZONOTOPE)
2222
x0 = Interval([0.9, 0.9], [1.1, 1.1])
2323
opts.u = cvt2(ub @ Interval(0.1, 0.3), Geometry.TYPE.ZONOTOPE) # u must contain origin
24-
x0_bounds = boundary(x0, 0.0001, Geometry.TYPE.ZONOTOPE)
24+
x0_bounds = boundary(x0, 0.1, Geometry.TYPE.ZONOTOPE)
2525

2626
print(len(x0_bounds))
2727

2828
_, ri, _, _ = ReachLinearZonoAlgo3Parallel.reach_parallel(lin_sys, opts, x0_bounds)
2929

3030
performance_counter(time_tag, "reach_linear_zono_algo3_parallel")
3131

32-
# plot(ri, [0, 1])
32+
print(len(ri))
33+
34+
plot(ri, [0, 1])
3335

3436

3537
def test_reach_linear_zono_algo3_case_01():
@@ -49,6 +51,8 @@ def test_reach_linear_zono_algo3_case_01():
4951

5052
x0_bounds = boundary(x0, 0.1, Geometry.TYPE.ZONOTOPE)
5153

54+
performance_counter(time_tag, "boundary")
55+
5256
print(len(x0_bounds))
5357

5458
_, ri, _, _ = ReachLinearZonoAlgo3Parallel.reach_parallel(lin_sys, opts, x0_bounds)

0 commit comments

Comments
 (0)