|
7 | 7 | from conftest import assert_structure |
8 | 8 | from devito import ( |
9 | 9 | DefaultDimension, Dimension, Eq, Function, Grid, MatrixSparseTimeFunction, Operator, |
10 | | - PrecomputedSparseFunction, PrecomputedSparseTimeFunction, SparseFunction, |
| 10 | + PrecomputedSparseFunction, PrecomputedSparseTimeFunction, Real, SparseFunction, |
11 | 11 | SparseTimeFunction, SubDomain, TimeFunction, switchconfig |
12 | 12 | ) |
13 | 13 | from devito.operations.interpolators import LinearInterpolator, SincInterpolator |
|
17 | 17 | from examples.seismic.acoustic import AcousticWaveSolver, acoustic_setup |
18 | 18 |
|
19 | 19 |
|
| 20 | +class SparseFirst(SparseFunction): |
| 21 | + """ Custom sparse class with the sparse dimension as the first one""" |
| 22 | + _sparse_position = 0 |
| 23 | + |
| 24 | + |
20 | 25 | def unit_box(name='a', shape=(11, 11), grid=None, space_order=1): |
21 | 26 | """Create a field with value 0. to 1. in each dimension""" |
22 | 27 | grid = grid or Grid(shape=shape) |
@@ -698,11 +703,6 @@ def test_sparse_first(): |
698 | 703 | """ |
699 | 704 | Tests custom sprase function with sparse dimension as first index. |
700 | 705 | """ |
701 | | - |
702 | | - class SparseFirst(SparseFunction): |
703 | | - """ Custom sparse class with the sparse dimension as the first one""" |
704 | | - _sparse_position = 0 |
705 | | - |
706 | 706 | dr = Dimension("cd") |
707 | 707 | ds = DefaultDimension("ps", default_value=3) |
708 | 708 | grid = Grid((11, 11)) |
@@ -870,6 +870,27 @@ def test_interp_complex(dtype): |
870 | 870 | assert np.isclose(sc.data[0], fc.data[5, 5, 5]) |
871 | 871 |
|
872 | 872 |
|
| 873 | +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) |
| 874 | +def test_interp_complex_and_real(dtype): |
| 875 | + grid = Grid((11, 11, 11)) |
| 876 | + |
| 877 | + sc = SparseFunction(name="sc", grid=grid, npoint=1, dtype=dtype) |
| 878 | + sc.coordinates.data[:] = [.5, .5, .5] |
| 879 | + scre = SparseFunction(name="sce", grid=grid, npoint=1, coordinates=sc.coordinates) |
| 880 | + |
| 881 | + fc = Function(name="fc", grid=grid, npoint=2, dtype=dtype) |
| 882 | + fc.data[:] = np.random.randn(*grid.shape) + 1j * np.random.randn(*grid.shape) |
| 883 | + exprs = sc.interpolate(expr=fc) + scre.interpolate(expr=Real(fc)) |
| 884 | + opC = Operator(exprs, name="OpC") |
| 885 | + opC() |
| 886 | + |
| 887 | + assert np.isclose(sc.data[0], fc.data[5, 5, 5]) |
| 888 | + assert np.isclose(scre.data[0], fc.data[5, 5, 5].real) |
| 889 | + |
| 890 | + assert_structure(opC, ['p_sc', 'p_sc,rscx,rscy,rscz', 'p_sc,rscex,rscey,rscez'], |
| 891 | + 'p_sc,rscx,rscy,rscz,rscex,rscey,rscez') |
| 892 | + |
| 893 | + |
873 | 894 | class SD0(SubDomain): |
874 | 895 | name = 'sd0' |
875 | 896 |
|
|
0 commit comments