Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions devito/ir/cgen/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Utilities to turn SymPy objects into C strings.
"""
from contextlib import suppress
from ctypes import _Pointer

import numpy as np
import sympy
Expand Down Expand Up @@ -110,11 +111,17 @@ def parenthesize(self, item, level, strict=False):
return super().parenthesize(item, level, strict=strict)

def _print_PyCPointerType(self, expr):
ctype = f'{self._print_type(expr._type_)}'
base_type, nstart = expr, 0
while issubclass(base_type, _Pointer):
base_type = base_type._type_
nstart += 1

ctype = f'{self._print_type(base_type)}'
stars = '*' * nstart
if ctype.endswith('*'):
return f'{ctype}*'
return f'{ctype}{stars}'
else:
return f'{ctype} *'
return f'{ctype} {stars}'

def _print_type(self, expr):
with suppress(TypeError):
Expand Down
6 changes: 3 additions & 3 deletions examples/seismic/self_adjoint/sa_03_iso_correctness.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -578,9 +578,9 @@
"output_type": "stream",
"text": [
"Operator `IsoFwdOperator` ran in 0.04 s\n",
"No source type defined, returning uninitiallized (zero) source\n",
"No source type defined, returning uninitialized (zero) source\n",
"Operator `IsoAdjOperator` ran in 0.03 s\n",
"No source type defined, returning uninitiallized (zero) source\n",
"No source type defined, returning uninitialized (zero) source\n",
"Operator `IsoAdjOperator` ran in 0.03 s\n"
]
},
Expand Down Expand Up @@ -639,7 +639,7 @@
"output_type": "stream",
"text": [
"Operator `IsoFwdOperator` ran in 0.03 s\n",
"No source type defined, returning uninitiallized (zero) source\n",
"No source type defined, returning uninitialized (zero) source\n",
"Operator `IsoAdjOperator` ran in 0.03 s\n"
]
},
Expand Down
2 changes: 1 addition & 1 deletion examples/seismic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def src(self):
def new_src(self, name='src', src_type='self', coordinates=None):
coords = coordinates or self.src_positions
if self.src_type is None or src_type is None:
warning("No source type defined, returning uninitiallized (zero) source")
warning("No source type defined, returning uninitialized (zero) source")
src = PointSource(name=name, grid=self.grid,
time_range=self.time_axis, npoint=self.nsrc,
coordinates=coords,
Expand Down
13 changes: 12 additions & 1 deletion tests/test_iet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ctypes import c_void_p
from ctypes import POINTER, c_void_p

import cgen
import numpy as np
Expand All @@ -20,6 +20,7 @@
FLOAT, Byref, Class, FieldFromComposite, InlineIf, ListInitializer, Macro, SizeOf,
String
)
from devito.symbolics.extended_dtypes import c_complex
from devito.tools import CustomDtype, as_tuple, dtype_to_ctype
from devito.types import Array, CustomDimension, LocalObject, Pointer, Symbol
from devito.types.misc import FunctionMap
Expand Down Expand Up @@ -528,6 +529,16 @@ def test_codegen_quality0():
assert foo1.parameters[0] is a


def test_complex_array():
grid = Grid(shape=(4, 4, 4))
_, y, z = grid.dimensions

a = Array(name='a', dimensions=grid.dimensions, dtype=POINTER(c_complex))

assert str(Definition(a)) == \
"float _Complex **restrict a_vec __attribute__ ((aligned (64)));"


def test_special_array_definition():

class MyArray(Array):
Expand Down
Loading