|
| 1 | +from __future__ import annotations |
| 2 | +from dataclasses import dataclass |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +np.seterr(divide='ignore', invalid='ignore') |
| 6 | + |
| 7 | +from scipy.special import factorial |
| 8 | +from concurrent.futures import ProcessPoolExecutor, as_completed |
| 9 | + |
| 10 | +from pybdr.dynamic_system import LinSys, NonLinSys |
| 11 | +from pybdr.geometry import Geometry, Zonotope, Interval |
| 12 | +from pybdr.geometry.operation import cvt2 |
| 13 | +from typing import Callable |
| 14 | +from pybdr.model import Model |
| 15 | +from functools import partial |
| 16 | +from .algorithm import Algorithm |
| 17 | +from .alk2011hscc import ALK2011HSCC |
| 18 | + |
| 19 | + |
| 20 | +class ASB2008CDCParallel: |
| 21 | + @dataclass |
| 22 | + class Options(Algorithm.Options): |
| 23 | + taylor_terms: int = 4 # for linearization |
| 24 | + tensor_order: int = 2 # for error approximation |
| 25 | + u_trans: np.ndarray = None |
| 26 | + factors: np.ndarray = None |
| 27 | + max_err: np.ndarray = None |
| 28 | + lin_err_x = None |
| 29 | + lin_err_u = None |
| 30 | + lin_err_f0 = None |
| 31 | + |
| 32 | + def _validate_misc(self, dim: int): |
| 33 | + assert self.tensor_order == 2 or self.tensor_order == 3 |
| 34 | + self.max_err = ( |
| 35 | + np.full(dim, np.inf) if self.max_err is None else self.max_err |
| 36 | + ) |
| 37 | + i = np.arange(1, self.taylor_terms + 2) |
| 38 | + self.factors = np.power(self.step, i) / factorial(i) |
| 39 | + return True |
| 40 | + |
| 41 | + def validation(self, dim: int): |
| 42 | + assert self._validate_time_related() |
| 43 | + assert self._validate_misc(dim) |
| 44 | + return True |
| 45 | + |
| 46 | + @staticmethod |
| 47 | + def linearize(dyn: Callable, dims, r: Geometry.Base, opt: Options): |
| 48 | + opt.lin_err_u = opt.u_trans if opt.u_trans is not None else opt.u.c |
| 49 | + sys = Model(dyn, dims) |
| 50 | + f0 = sys.evaluate((r.c, opt.lin_err_u), "numpy", 0, 0) |
| 51 | + opt.lin_err_x = r.c + f0 * 0.5 * opt.step |
| 52 | + opt.lin_err_f0 = sys.evaluate((opt.lin_err_x, opt.lin_err_u), "numpy", 0, 0) |
| 53 | + a = sys.evaluate((opt.lin_err_x, opt.lin_err_u), "numpy", 1, 0) |
| 54 | + b = sys.evaluate((opt.lin_err_x, opt.lin_err_u), "numpy", 1, 1) |
| 55 | + assert not (np.any(np.isnan(a))) or np.any(np.isnan(b)) |
| 56 | + lin_sys = LinSys(xa=a) |
| 57 | + lin_opt = ALK2011HSCC.Options() |
| 58 | + lin_opt.step = opt.step |
| 59 | + lin_opt.taylor_terms = opt.taylor_terms |
| 60 | + lin_opt.factors = opt.factors |
| 61 | + lin_opt.u = b @ (opt.u + opt.u_trans - opt.lin_err_u) |
| 62 | + lin_opt.u -= lin_opt.u.c |
| 63 | + lin_opt.u_trans = Zonotope( |
| 64 | + opt.lin_err_f0 + lin_opt.u.c, np.zeros((opt.lin_err_f0.shape[0], 1)) |
| 65 | + ) |
| 66 | + return lin_sys, lin_opt |
| 67 | + |
| 68 | + @staticmethod |
| 69 | + def abstract_err(dyn, dims, r: Geometry.Base, opt: Options): |
| 70 | + sys = Model(dyn, dims) |
| 71 | + ihx = cvt2(r, Geometry.TYPE.INTERVAL) |
| 72 | + total_int_x = ihx + opt.lin_err_x |
| 73 | + |
| 74 | + ihu = cvt2(opt.u, Geometry.TYPE.INTERVAL) |
| 75 | + total_int_u = ihu + opt.lin_err_u |
| 76 | + |
| 77 | + if opt.tensor_order == 2: |
| 78 | + dx = np.maximum(abs(ihx.inf), abs(ihx.sup)) |
| 79 | + du = np.maximum(abs(ihu.inf), abs(ihu.sup)) |
| 80 | + |
| 81 | + # evaluate the hessian matrix with the selected range-bounding technique |
| 82 | + hx = sys.evaluate((total_int_x, total_int_u), "interval", 2, 0) |
| 83 | + hu = sys.evaluate((total_int_x, total_int_u), "interval", 2, 1) |
| 84 | + xx = np.maximum(abs(hx.inf), abs(hx.sup)) |
| 85 | + uu = np.maximum(abs(hu.inf), abs(hu.sup)) |
| 86 | + |
| 87 | + err_lagrange = 0.5 * (dx @ xx @ dx + du @ uu @ du) |
| 88 | + |
| 89 | + verr_dyn = Zonotope(np.zeros(sys.dim), np.diag(err_lagrange)) |
| 90 | + return err_lagrange, verr_dyn |
| 91 | + elif opt.tensor_order == 3: |
| 92 | + r_red = r.reduce(Zonotope.REDUCE_METHOD, Zonotope.ERROR_ORDER) |
| 93 | + z = r_red.card_prod(opt.u) |
| 94 | + # evaluate hessian |
| 95 | + hx = sys.evaluate((opt.lin_err_x, opt.lin_err_u), "numpy", 2, 0) |
| 96 | + hu = sys.evaluate((opt.lin_err_x, opt.lin_err_u), "numpy", 2, 1) |
| 97 | + # evaluate third order |
| 98 | + tx = sys.evaluate((total_int_x, total_int_u), "interval", 3, 0) |
| 99 | + tu = sys.evaluate((total_int_x, total_int_u), "interval", 3, 1) |
| 100 | + |
| 101 | + # second order error |
| 102 | + err_sec = 0.5 * z.quad_map([hx, hu]) |
| 103 | + xx = Interval.sum((ihx @ tx @ ihx) * ihx, axis=1) |
| 104 | + uu = Interval.sum((ihu @ tu @ ihu) * ihu, axis=1) |
| 105 | + err_lagr = (xx + uu) / 6 |
| 106 | + err_lagr = cvt2(err_lagr, Geometry.TYPE.ZONOTOPE) |
| 107 | + |
| 108 | + # overall linearization error |
| 109 | + verr_dyn = err_sec + err_lagr |
| 110 | + verr_dyn = verr_dyn.reduce( |
| 111 | + Zonotope.REDUCE_METHOD, Zonotope.INTERMEDIATE_ORDER |
| 112 | + ) |
| 113 | + true_err = abs(cvt2(verr_dyn, Geometry.TYPE.INTERVAL)).sup |
| 114 | + return true_err, verr_dyn |
| 115 | + else: |
| 116 | + raise Exception("unsupported tensor order") |
| 117 | + |
| 118 | + @classmethod |
| 119 | + def linear_reach(cls, dyn: Callable, dims, r, err, opt: Options): |
| 120 | + lin_sys, lin_opt = cls.linearize(dyn, dims, r, opt) |
| 121 | + r_delta = r - opt.lin_err_x |
| 122 | + r_ti, r_tp = ALK2011HSCC.reach_one_step(lin_sys, r_delta, lin_opt) |
| 123 | + |
| 124 | + perf_ind_cur, perf_ind = np.inf, 0 |
| 125 | + applied_err, abstract_err, v_err_dyn = None, err, None |
| 126 | + |
| 127 | + while perf_ind_cur > 1 >= perf_ind: |
| 128 | + applied_err = 1.1 * abstract_err |
| 129 | + v_err = Zonotope(0 * applied_err, np.diag(applied_err)) |
| 130 | + r_all_err = ALK2011HSCC.error_solution(v_err, lin_opt) |
| 131 | + r_max = r_ti + r_all_err |
| 132 | + true_err, v_err_dyn = cls.abstract_err(dyn, dims, r_max, opt) |
| 133 | + |
| 134 | + # compare linearization error with the maximum allowed error |
| 135 | + temp = true_err / applied_err |
| 136 | + temp[np.isnan(temp)] = -np.inf |
| 137 | + perf_ind_cur = np.max(temp) |
| 138 | + perf_ind = np.max(true_err / opt.max_err) |
| 139 | + abstract_err = true_err |
| 140 | + |
| 141 | + # exception for set explosion |
| 142 | + if np.any(abstract_err > 1e100): |
| 143 | + raise Exception("Set Explosion") |
| 144 | + # translate reachable sets by linearization point |
| 145 | + r_ti += opt.lin_err_x |
| 146 | + r_tp += opt.lin_err_x |
| 147 | + |
| 148 | + # compute the reachable set due to the linearization error |
| 149 | + r_err = ALK2011HSCC.error_solution(v_err_dyn, lin_opt) |
| 150 | + |
| 151 | + # add the abstraction error to the reachable sets |
| 152 | + r_ti += r_err |
| 153 | + r_tp += r_err |
| 154 | + # determine the best dimension to split the set in order to reduce the |
| 155 | + # linearization error |
| 156 | + dim_for_split = [] |
| 157 | + if perf_ind > 1: |
| 158 | + raise NotImplementedError # TODO |
| 159 | + # store the linearization error |
| 160 | + r_ti = r_ti.reduce(Zonotope.REDUCE_METHOD, Zonotope.ORDER) |
| 161 | + r_tp = r_tp.reduce(Zonotope.REDUCE_METHOD, Zonotope.ORDER) |
| 162 | + return r_ti, r_tp, abstract_err, dim_for_split |
| 163 | + |
| 164 | + @classmethod |
| 165 | + def reach_one_step(cls, dyn: Callable, dims, x, err, opt: Options): |
| 166 | + |
| 167 | + r_ti, r_tp, abst_err, dims = cls.linear_reach(dyn, dims, x, err, opt) |
| 168 | + # check if the initial set has to be split |
| 169 | + if len(dims) <= 0: |
| 170 | + return r_ti, r_tp |
| 171 | + else: |
| 172 | + raise NotImplementedError # TODO |
| 173 | + |
| 174 | + @classmethod |
| 175 | + def reach(cls, dyn: Callable, dims, opts: Options, x: Zonotope): |
| 176 | + m = Model(dyn, dims) |
| 177 | + assert opts.validation(m.dim) |
| 178 | + |
| 179 | + ti_set, tp_set = [], [x] |
| 180 | + |
| 181 | + next_tp = x |
| 182 | + |
| 183 | + for step in range(opts.steps_num): |
| 184 | + next_ti, next_tp = cls.reach_one_step(dyn, dims, next_tp, np.zeros(x.shape), opts) |
| 185 | + ti_set.append(next_ti) |
| 186 | + tp_set.append(next_tp) |
| 187 | + |
| 188 | + return ti_set, tp_set |
| 189 | + |
| 190 | + @classmethod |
| 191 | + def reach_parallel(cls, dyn: Callable, dims, opts: Options, xs: [Zonotope]): |
| 192 | + # init containers for storing the results |
| 193 | + ri = [] |
| 194 | + |
| 195 | + partial_reach = partial(cls.reach, dyn, dims, opts) |
| 196 | + |
| 197 | + with ProcessPoolExecutor() as executor: |
| 198 | + futures = [executor.submit(partial_reach, x) for x in xs] |
| 199 | + |
| 200 | + for future in as_completed(futures): |
| 201 | + try: |
| 202 | + ri.append(future.result()) |
| 203 | + except Exception as e: |
| 204 | + raise e |
| 205 | + |
| 206 | + ri = [list(group) for group in zip(*ri)] |
| 207 | + |
| 208 | + return ri |
0 commit comments