|
26 | 26 |
|
27 | 27 |
|
28 | 28 | # Public interface |
29 | | -__all__ = ['Jacobi', |
| 29 | +__all__ = ['CardinalBasis', |
| 30 | + 'Jacobi', |
30 | 31 | 'Legendre', |
31 | 32 | 'Ultraspherical', |
32 | 33 | 'Chebyshev', |
@@ -334,10 +335,166 @@ def enum_indices(tensorsig): |
334 | 335 | # self._ncc_matrices = [self._ncc_matrix_recursion(ncc.data[ind], ncc.domain.full_bases, operand.domain.full_bases, separability, **kw) for ind in np.ndindex(*tshape)] |
335 | 336 |
|
336 | 337 |
|
337 | | -class IntervalBasis(Basis): |
| 338 | +class CardinalBasis(Basis): |
| 339 | + """Cardinal basis.""" |
338 | 340 |
|
339 | 341 | dim = 1 |
| 342 | + group_shape = (1,) |
| 343 | + subaxis_dependence = [False] |
| 344 | + |
| 345 | + def __init__(self, coord, size): |
| 346 | + self.coord = coord |
| 347 | + self.coordsys = coord |
| 348 | + self.size = size |
| 349 | + self.shape = (size,) |
| 350 | + self.dealias = (1,) |
| 351 | + super().__init__(coord) |
| 352 | + |
| 353 | + def __add__(self, other): |
| 354 | + if other is None or other is self: |
| 355 | + return self |
| 356 | + return NotImplemented |
| 357 | + |
| 358 | + def __mul__(self, other): |
| 359 | + if other is None or other is self: |
| 360 | + return self |
| 361 | + return NotImplemented |
| 362 | + |
| 363 | + def __rmatmul__(self, other): |
| 364 | + # NCC (other) * operand (self) |
| 365 | + if other is None or other is self: |
| 366 | + return self |
| 367 | + return NotImplemented |
| 368 | + |
| 369 | + def elements_to_groups(self, grid_space, elements): |
| 370 | + # No permutations |
| 371 | + return elements |
| 372 | + |
| 373 | + def valid_elements(self, tensorsig, grid_space, elements): |
| 374 | + # No invalid modes |
| 375 | + vshape = tuple(cs.dim for cs in tensorsig) + elements[0].shape |
| 376 | + return np.ones(shape=vshape, dtype=bool) |
| 377 | + |
| 378 | + def matrix_dependence(self, matrix_coupling): |
| 379 | + return matrix_coupling |
| 380 | + |
| 381 | + def global_grids(self, dist, scales): |
| 382 | + """Global grids.""" |
| 383 | + return (self.global_grid(dist, scales[0]),) |
| 384 | + |
| 385 | + def global_grid(self, dist, scale): |
| 386 | + """Global grid.""" |
| 387 | + if scale != 1: |
| 388 | + raise NotImplementedError("Cardinal basis only supports scale=1.") |
| 389 | + return np.arange(self.size) |
| 390 | + |
| 391 | + def local_grids(self, dist, scales): |
| 392 | + """Local grids.""" |
| 393 | + return (self.local_grid(dist, scales[0]),) |
| 394 | + |
| 395 | + def local_grid(self, dist, scale): |
| 396 | + """Local grid.""" |
| 397 | + if scale != 1: |
| 398 | + raise NotImplementedError("Cardinal basis only supports scale=1.") |
| 399 | + local_elements = dist.grid_layout.local_elements(self.domain(dist), scales=scale) |
| 400 | + return np.arange(self.size)[local_elements[dist.get_basis_axis(self)]] |
| 401 | + |
| 402 | + def local_modes(self, dist): |
| 403 | + """Local grid.""" |
| 404 | + local_elements = dist.coeff_layout.local_elements(self.domain(dist), scales=1) |
| 405 | + return reshape_vector(local_elements[dist.get_basis_axis(self)], dim=dist.dim, axis=dist.get_basis_axis(self)) |
| 406 | + |
| 407 | + def global_shape(self, grid_space, scales): |
| 408 | + return self.shape |
| 409 | + |
| 410 | + def chunk_shape(self, grid_space): |
| 411 | + return (1,) |
| 412 | + |
| 413 | + def forward_transform(self, field, axis, gdata, cdata): |
| 414 | + """Forward transform field data.""" |
| 415 | + np.copyto(cdata, gdata) |
| 416 | + |
| 417 | + def backward_transform(self, field, axis, cdata, gdata): |
| 418 | + """Backward transform field data.""" |
| 419 | + np.copyto(gdata, cdata) |
| 420 | + |
| 421 | + |
| 422 | +class ConvertConstantCardinal(operators.ConvertConstant, operators.SpectralOperator1D): |
| 423 | + """Convert constant to Cardinal basis.""" |
| 424 | + |
| 425 | + output_basis_type = CardinalBasis |
| 426 | + subaxis_dependence = [True] |
| 427 | + subaxis_coupling = [True] |
| 428 | + |
| 429 | + @staticmethod |
| 430 | + def _full_matrix(input_basis, output_basis): |
| 431 | + return np.ones((output_basis.size, 1)) |
| 432 | + |
| 433 | + |
| 434 | +class InterpolateCardinal(operators.Interpolate, operators.SpectralOperator1D): |
| 435 | + """Interpolate Cardinal basis.""" |
| 436 | + |
| 437 | + input_basis_type = CardinalBasis |
| 438 | + basis_subaxis = 0 |
| 439 | + subaxis_dependence = [True] |
| 440 | + subaxis_coupling = [True] |
| 441 | + |
| 442 | + def __init__(self, operand, coord, position, out=None): |
| 443 | + if not isinstance(position, (int, np.integer)): |
| 444 | + raise TypeError("Cardinal interpolation position must be an integer") |
| 445 | + super().__init__(operand, coord, position, out=out) |
| 446 | + |
| 447 | + @staticmethod |
| 448 | + def _output_basis(input_basis, position): |
| 449 | + return None |
| 450 | + |
| 451 | + @staticmethod |
| 452 | + def _full_matrix(input_basis, output_basis, position): |
| 453 | + interp_vector = np.zeros(input_basis.size) |
| 454 | + interp_vector[position] = 1 |
| 455 | + return interp_vector[None, :] |
| 456 | + |
| 457 | + |
| 458 | +class IntegrateCardinal(operators.Integrate, operators.SpectralOperator1D): |
| 459 | + """Cardinal basis integration.""" |
| 460 | + |
| 461 | + input_coord_type = Coordinate |
| 462 | + input_basis_type = CardinalBasis |
| 463 | + subaxis_dependence = [True] |
| 464 | + subaxis_coupling = [True] |
| 465 | + |
| 466 | + @staticmethod |
| 467 | + def _output_basis(input_basis): |
| 468 | + return None |
| 469 | + |
| 470 | + @staticmethod |
| 471 | + def _full_matrix(input_basis, output_basis): |
| 472 | + integ_vector = np.ones(input_basis.size) |
| 473 | + return integ_vector[None, :] |
| 474 | + |
| 475 | + |
| 476 | +class AverageCardinal(operators.Average, operators.SpectralOperator1D): |
| 477 | + """Cardinal basis averaging.""" |
| 478 | + |
| 479 | + input_coord_type = Coordinate |
| 480 | + input_basis_type = CardinalBasis |
340 | 481 | subaxis_dependence = [True] |
| 482 | + subaxis_coupling = [True] |
| 483 | + |
| 484 | + @staticmethod |
| 485 | + def _output_basis(input_basis): |
| 486 | + return None |
| 487 | + |
| 488 | + @staticmethod |
| 489 | + def _full_matrix(input_basis, output_basis): |
| 490 | + ave_vector = np.ones(input_basis.size) / input_basis.size |
| 491 | + return ave_vector[None, :] |
| 492 | + |
| 493 | + |
| 494 | +class IntervalBasis(Basis): |
| 495 | + |
| 496 | + dim = 1 |
| 497 | + subaxis_dependence = [False] |
341 | 498 |
|
342 | 499 | def __init__(self, coord, size, bounds, dealias): |
343 | 500 | self.coord = coord |
@@ -6084,15 +6241,16 @@ def cfl_spacing(self): |
6084 | 6241 | velocity = self.operand |
6085 | 6242 | coordsys = velocity.tensorsig[0] |
6086 | 6243 | spacing = [] |
6087 | | - for i, c in enumerate(coordsys.coords): |
| 6244 | + for c in coordsys.coords: |
6088 | 6245 | basis = velocity.domain.get_basis(c) |
6089 | 6246 | if basis: |
6090 | 6247 | dealias = basis.dealias[0] |
6091 | 6248 | axis_spacing = basis.local_grid_spacing(self.dist, dealias) * dealias |
6092 | 6249 | N = basis.grid_shape((dealias,))[0] |
6093 | 6250 | if isinstance(basis, Jacobi) and basis.a == -1/2 and basis.b == -1/2: |
6094 | 6251 | #Special case for ChebyshevT (a=b=-1/2) |
6095 | | - local_elements = self.dist.grid_layout.local_elements(basis.domain(self.dist), scales=dealias)[i] |
| 6252 | + axis = self.dist.get_basis_axis(basis) |
| 6253 | + local_elements = self.dist.grid_layout.local_elements(basis.domain(self.dist), scales=dealias)[axis] |
6096 | 6254 | i = np.arange(N)[local_elements].reshape(axis_spacing.shape) |
6097 | 6255 | theta = np.pi * (i + 1/2) / N |
6098 | 6256 | axis_spacing[:] = dealias * basis.COV.stretch * np.sin(theta) * np.pi / N |
|
0 commit comments