|
7 | 7 | from conftest import assert_structure |
8 | 8 | from devito import ( |
9 | 9 | DefaultDimension, Dimension, Eq, Function, Grid, MatrixSparseTimeFunction, Operator, |
10 | | - PrecomputedSparseFunction, PrecomputedSparseTimeFunction, SparseFunction, |
| 10 | + PrecomputedSparseFunction, PrecomputedSparseTimeFunction, Real, SparseFunction, |
11 | 11 | SparseTimeFunction, SubDomain, TimeFunction, switchconfig |
12 | 12 | ) |
13 | 13 | from devito.operations.interpolators import LinearInterpolator, SincInterpolator |
|
17 | 17 | from examples.seismic.acoustic import AcousticWaveSolver, acoustic_setup |
18 | 18 |
|
19 | 19 |
|
| 20 | +class SparseFirst(SparseFunction): |
| 21 | + """ Custom sparse class with the sparse dimension as the first one""" |
| 22 | + _sparse_position = 0 |
| 23 | + |
| 24 | + |
20 | 25 | def unit_box(name='a', shape=(11, 11), grid=None, space_order=1): |
21 | 26 | """Create a field with value 0. to 1. in each dimension""" |
22 | 27 | grid = grid or Grid(shape=shape) |
@@ -698,11 +703,6 @@ def test_sparse_first(): |
698 | 703 | """ |
699 | 704 | Tests custom sprase function with sparse dimension as first index. |
700 | 705 | """ |
701 | | - |
702 | | - class SparseFirst(SparseFunction): |
703 | | - """ Custom sparse class with the sparse dimension as the first one""" |
704 | | - _sparse_position = 0 |
705 | | - |
706 | 706 | dr = Dimension("cd") |
707 | 707 | ds = DefaultDimension("ps", default_value=3) |
708 | 708 | grid = Grid((11, 11)) |
@@ -870,6 +870,28 @@ def test_interp_complex(dtype): |
870 | 870 | assert np.isclose(sc.data[0], fc.data[5, 5, 5]) |
871 | 871 |
|
872 | 872 |
|
| 873 | +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) |
| 874 | +def test_interp_complex_and_real(dtype): |
| 875 | + grid = Grid((11, 11, 11)) |
| 876 | + |
| 877 | + sc = SparseFunction(name="sc", grid=grid, npoint=1, dtype=dtype) |
| 878 | + sc.coordinates.data[:] = [.5, .5, .5] |
| 879 | + scre = SparseFunction(name="sce", grid=grid, npoint=1, coordinates=sc.coordinates) |
| 880 | + |
| 881 | + fc = Function(name="fc", grid=grid, npoint=2, dtype=dtype) |
| 882 | + fc.data[:] = np.random.randn(*grid.shape) + 1j * np.random.randn(*grid.shape) |
| 883 | + exprs = sc.interpolate(expr=fc) + scre.interpolate(expr=Real(fc)) |
| 884 | + opC = Operator(exprs, name="OpC") |
| 885 | + opC() |
| 886 | + |
| 887 | + assert np.isclose(sc.data[0], fc.data[5, 5, 5]) |
| 888 | + assert np.isclose(scre.data[0], fc.data[5, 5, 5].real) |
| 889 | + |
| 890 | + assert_structure(opC, ['p_sc', 'p_sc,rp_scx,rp_scy,rp_scz', |
| 891 | + 'p_sc,rp_scx,rp_scy,rp_scz'], |
| 892 | + 'p_sc,rp_scx,rp_scy,rp_scz,rp_scx,rp_scy,rp_scz') |
| 893 | + |
| 894 | + |
873 | 895 | class SD0(SubDomain): |
874 | 896 | name = 'sd0' |
875 | 897 |
|
@@ -989,9 +1011,9 @@ def test_interpolate_subdomain(self): |
989 | 1011 | assert np.all(np.isclose(sr1.data, check1)) |
990 | 1012 | assert np.all(np.isclose(sr2.data, check2)) |
991 | 1013 | assert_structure(op, |
992 | | - ['p_sr0', 'p_sr0rsr0xrsr0y', 'p_sr1', |
993 | | - 'p_sr1rsr1xrsr1y', 'p_sr2', 'p_sr2rsr2xrsr2y'], |
994 | | - 'p_sr0rsr0xrsr0yp_sr1rsr1xrsr1yp_sr2rsr2xrsr2y') |
| 1014 | + ['p_sr0', 'p_sr0rp_sr0xrp_sr0y', 'p_sr1', |
| 1015 | + 'p_sr1rp_sr1xrp_sr1y', 'p_sr2', 'p_sr2rp_sr2xrp_sr2y'], |
| 1016 | + 'p_sr0rp_sr0xrp_sr0yp_sr1rp_sr1xrp_sr1yp_sr2rp_sr2xrp_sr2y') |
995 | 1017 |
|
996 | 1018 | def test_interpolate_subdomain_sinc(self): |
997 | 1019 | """ |
@@ -1032,9 +1054,9 @@ def test_interpolate_subdomain_sinc(self): |
1032 | 1054 | assert np.all(np.isclose(sr0.data, sr2.data)) |
1033 | 1055 | assert np.all(np.isclose(sr1.data, sr2.data)) |
1034 | 1056 | assert_structure(op, |
1035 | | - ['p_sr0', 'p_sr0rsr0xrsr0y', 'p_sr1', |
1036 | | - 'p_sr1rsr1xrsr1y', 'p_sr2', 'p_sr2rsr2xrsr2y'], |
1037 | | - 'p_sr0rsr0xrsr0yp_sr1rsr1xrsr1yp_sr2rsr2xrsr2y') |
| 1057 | + ['p_sr0', 'p_sr0rp_sr0xrp_sr0y', 'p_sr1', |
| 1058 | + 'p_sr1rp_sr1xrp_sr1y', 'p_sr2', 'p_sr2rp_sr2xrp_sr2y'], |
| 1059 | + 'p_sr0rp_sr0xrp_sr0yp_sr1rp_sr1xrp_sr1yp_sr2rp_sr2xrp_sr2y') |
1038 | 1060 |
|
1039 | 1061 | def test_inject_subdomain(self): |
1040 | 1062 | """ |
@@ -1080,8 +1102,8 @@ def test_inject_subdomain(self): |
1080 | 1102 | assert np.all(np.isclose(f0.data, check0)) |
1081 | 1103 | assert np.all(np.isclose(f1.data, check1)) |
1082 | 1104 | assert_structure(op, |
1083 | | - ['p_sr0rsr0xrsr0y'], |
1084 | | - 'p_sr0rsr0xrsr0y') |
| 1105 | + ['p_sr0rp_sr0xrp_sr0y'], |
| 1106 | + 'p_sr0rp_sr0xrp_sr0y') |
1085 | 1107 |
|
1086 | 1108 | def test_inject_subdomain_sinc(self): |
1087 | 1109 | """ |
@@ -1112,8 +1134,8 @@ def test_inject_subdomain_sinc(self): |
1112 | 1134 | assert np.all(np.isclose(f0.data, f2.data[:9, -9:])) |
1113 | 1135 | assert np.all(np.isclose(f1.data, f2.data[1:-1, 1:-1])) |
1114 | 1136 | assert_structure(op, |
1115 | | - ['p_sr0rsr0xrsr0y'], |
1116 | | - 'p_sr0rsr0xrsr0y') |
| 1137 | + ['p_sr0rp_sr0xrp_sr0y'], |
| 1138 | + 'p_sr0rp_sr0xrp_sr0y') |
1117 | 1139 |
|
1118 | 1140 | @pytest.mark.xfail(reason="OOB issue") |
1119 | 1141 | @pytest.mark.parallel(mode=4) |
|
0 commit comments