|
130 | 130 | .. autoclass:: DirectionalSourceDerivative |
131 | 131 | :show-inheritance: |
132 | 132 | :members: mapper_method,directional_kind |
133 | | -.. autoclass:: DirectionalTargetDerivative |
134 | | - :show-inheritance: |
135 | | - :undoc-members: |
136 | | - :members: mapper_method,directional_kind,target_array_name |
137 | 133 |
|
138 | 134 | Transforming kernels |
139 | 135 | -------------------- |
@@ -1351,86 +1347,6 @@ def replace_inner_kernel(self, new_inner_kernel: Kernel) -> Kernel: |
1351 | 1347 | return type(self)(new_inner_kernel, dir_vec_name=self.dir_vec_name) |
1352 | 1348 |
|
1353 | 1349 |
|
1354 | | -class DirectionalTargetDerivative(DirectionalDerivative): |
1355 | | - mapper_method: ClassVar[str] = "map_directional_target_derivative" |
1356 | | - directional_kind: ClassVar[Literal["src", "tgt"]] = "tgt" |
1357 | | - target_array_name: ClassVar[str] = "targets" |
1358 | | - |
1359 | | - @override |
1360 | | - def get_code_transformer(self) -> Callable[[Expression], Expression]: |
1361 | | - from sumpy.codegen import VectorComponentRewriter |
1362 | | - vcr = VectorComponentRewriter(frozenset([self.dir_vec_name])) |
1363 | | - via = _VectorIndexAdder(self.dir_vec_name, (prim.Variable("itgt"),)) |
1364 | | - |
1365 | | - inner_transform = self.inner_kernel.get_code_transformer() |
1366 | | - |
1367 | | - def transform(expr: Expression) -> Expression: |
1368 | | - return via(vcr(inner_transform(expr))) |
1369 | | - |
1370 | | - return transform |
1371 | | - |
1372 | | - @overload |
1373 | | - def postprocess_at_target( |
1374 | | - self, expr: sym.Expr, bvec: sp.Matrix, |
1375 | | - ) -> sym.Expr: ... |
1376 | | - |
1377 | | - @overload |
1378 | | - def postprocess_at_target( |
1379 | | - self, expr: ExprDerivativeTaker, bvec: sp.Matrix, |
1380 | | - ) -> DifferentiatedExprDerivativeTaker: ... |
1381 | | - |
1382 | | - @override |
1383 | | - def postprocess_at_target( |
1384 | | - self, expr: sym.Expr | ExprDerivativeTaker, bvec: sp.Matrix, |
1385 | | - ) -> sym.Expr | DifferentiatedExprDerivativeTaker: |
1386 | | - dir_vec = sym.make_sym_vector(self.dir_vec_name, self.dim) |
1387 | | - target_vec = sym.make_sym_vector(self.target_array_name, self.dim) |
1388 | | - |
1389 | | - inner_expr = self.inner_kernel.postprocess_at_target(expr, bvec) |
1390 | | - |
1391 | | - # bvec = tgt - center |
1392 | | - if not isinstance(inner_expr, DifferentiatedExprDerivativeTaker): |
1393 | | - result = 0 |
1394 | | - for axis in range(self.dim): |
1395 | | - # Since `bvec` and `tgt` are two different symbolic variables |
1396 | | - # need to differentiate by both to get the correct answer |
1397 | | - result += ( |
1398 | | - (inner_expr.diff(bvec[axis]) + inner_expr.diff(target_vec[axis])) |
1399 | | - * dir_vec[axis]) |
1400 | | - |
1401 | | - assert isinstance(result, sym.Expr) |
1402 | | - return result |
1403 | | - |
1404 | | - new_transformation: DerivativeCoeffDict = defaultdict(lambda: 0) |
1405 | | - for axis in range(self.dim): |
1406 | | - axis_transformation = diff_derivative_coeff_dict( |
1407 | | - inner_expr.derivative_coeff_dict, axis, target_vec) |
1408 | | - for mi, coeff in axis_transformation.items(): |
1409 | | - new_transformation[mi] += coeff * dir_vec[axis] |
1410 | | - |
1411 | | - return DifferentiatedExprDerivativeTaker( |
1412 | | - inner_expr.taker, dict(new_transformation)) |
1413 | | - |
1414 | | - @override |
1415 | | - def get_args(self) -> Sequence[KernelArgument]: |
1416 | | - return [ |
1417 | | - KernelArgument( |
1418 | | - loopy_arg=lp.GlobalArg( |
1419 | | - self.dir_vec_name, |
1420 | | - None, |
1421 | | - shape=(self.dim, "ntargets"), |
1422 | | - offset=lp.auto |
1423 | | - ), |
1424 | | - ), |
1425 | | - *self.inner_kernel.get_args() |
1426 | | - ] |
1427 | | - |
1428 | | - @override |
1429 | | - def prepare_loopy_kernel(self, loopy_knl: lp.TranslationUnit) -> lp.TranslationUnit: |
1430 | | - loopy_knl = self.inner_kernel.prepare_loopy_kernel(loopy_knl) |
1431 | | - return lp.tag_array_axes(loopy_knl, self.dir_vec_name, "sep,C") |
1432 | | - |
1433 | | - |
1434 | 1350 | class DirectionalSourceDerivative(DirectionalDerivative): |
1435 | 1351 | mapper_method: ClassVar[str] = "map_directional_source_derivative" |
1436 | 1352 | directional_kind: ClassVar[Literal["src", "tgt"]] = "src" |
|
0 commit comments