Skip to content

Commit 2cb72a8

Browse files
committed
Pool allocator: add a clamp statement 'ISTSZ = MAX(ISTSZ, 1)' to ensure size is at least 1
1 parent ec59ed0 commit 2cb72a8

2 files changed

Lines changed: 122 additions & 4 deletions

File tree

loki/transformations/temporaries/pool_allocator.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,18 @@ def _get_stack_storage_and_size_var(self, routine, stack_size):
443443
stack_type_bytes = InlineCall(Variable(name='C_SIZEOF'),
444444
parameters=as_tuple(stack_type_bytes))
445445
stack_size_assign = Assignment(lhs=stack_size_var, rhs=stack_size)
446-
body_prepend += [stack_size_assign]
446+
# Ensure the stack size is at least 1 to avoid zero-sized allocations,
447+
# which cause runtime errors when the stack storage is accessed at index 1
448+
# (e.g. LOC(ZSTACK(1, IBL))).
449+
stack_size_clamp = Assignment(
450+
lhs=stack_size_var,
451+
rhs=InlineCall(
452+
function=Variable(name='MAX'),
453+
parameters=(stack_size_var, IntLiteral(1)),
454+
kw_parameters=()
455+
)
456+
)
457+
body_prepend += [stack_size_assign, stack_size_clamp]
447458
variables_append += [stack_size_var]
448459

449460
if self.stack_storage_name in variable_map:

loki/transformations/temporaries/tests/test_pool_allocator.py

Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# granted to it by virtue of its status as an intergovernmental organisation
66
# nor does it submit to any jurisdiction.
77

8+
# pylint: disable=too-many-lines
89
import pytest
910

1011
from loki.expression.parser import parse_expr
@@ -65,9 +66,15 @@ def check_stack_created_in_driver(
6566

6667
# # Check the stack size
6768
assignments = FindNodes(Assignment).visit(driver.body)
68-
for assignment in assignments:
69-
if assignment.lhs == 'istsz':
70-
assert simplify(assignment.rhs) == simplify(stack_size)
69+
istsz_assigns = [a for a in assignments if a.lhs == 'istsz']
70+
assert len(istsz_assigns) >= 1
71+
# First assignment is the size computation
72+
assert simplify(istsz_assigns[0].rhs) == simplify(stack_size)
73+
# When the transformation generates the stack (not pre-existing),
74+
# a MAX(..., 1) clamp statement follows to prevent zero-sized allocations
75+
if len(istsz_assigns) == 2:
76+
assert isinstance(istsz_assigns[1].rhs, InlineCall)
77+
assert istsz_assigns[1].rhs.function == 'MAX'
7178

7279
# # Check for stack assignment inside loop
7380
loops = FindNodes(Loop).visit(driver.body)
@@ -1429,3 +1436,103 @@ def test_pool_allocator_args_vs_kwargs(tmp_path, frontend, block_dim_alt, cray_p
14291436

14301437
# check that array size was imported to the driver
14311438
assert 'n' in driver.imported_symbols
1439+
1440+
1441+
@pytest.mark.parametrize('frontend', available_frontends())
1442+
def test_pool_allocator_zero_size_allocation(tmp_path, frontend, block_dim, horizontal):
1443+
"""
1444+
When the kernel has no stack-allocatable temporaries (e.g. all arrays
1445+
have been demoted to scalars by a prior k-caching transformation), the
1446+
computed stack size is 0. A zero-sized ALLOCATE causes a Fortran
1447+
runtime error when the stack is accessed at index 1
1448+
(``LOC(ZSTACK(1, IBL))``).
1449+
1450+
Verify that the pool allocator inserts a clamp statement
1451+
``ISTSZ = MAX(ISTSZ, 1)`` after the size assignment to guarantee a
1452+
minimum allocation of 1.
1453+
"""
1454+
fcode_driver = """
1455+
subroutine driver(NLON, NZ, NB, FIELD1)
1456+
use kernel_mod, only: kernel
1457+
implicit none
1458+
INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13,300)
1459+
INTEGER, INTENT(IN) :: NLON, NZ, NB
1460+
real(kind=jprb), intent(inout) :: field1(nlon, nb)
1461+
integer :: b
1462+
do b=1,nb
1463+
call KERNEL(1, nlon, nlon, nz, field1(:,b))
1464+
end do
1465+
end subroutine driver
1466+
""".strip()
1467+
fcode_kernel = """
1468+
module kernel_mod
1469+
implicit none
1470+
contains
1471+
subroutine kernel(start, end, klon, klev, field1)
1472+
implicit none
1473+
integer, parameter :: jprb = selected_real_kind(13,300)
1474+
integer, intent(in) :: start, end, klon, klev
1475+
real(kind=jprb), intent(inout) :: field1(klon)
1476+
real(kind=jprb) :: scalar_tmp
1477+
integer :: jl
1478+
1479+
do jl=start,end
1480+
scalar_tmp = field1(jl) * 2.0_jprb
1481+
field1(jl) = scalar_tmp
1482+
end do
1483+
end subroutine kernel
1484+
end module kernel_mod
1485+
""".strip()
1486+
1487+
config = {
1488+
'default': {
1489+
'mode': 'idem',
1490+
'role': 'kernel',
1491+
'expand': True,
1492+
'strict': False,
1493+
'enable_imports': True,
1494+
},
1495+
'routines': {
1496+
'driver': {'role': 'driver'}
1497+
}
1498+
}
1499+
1500+
(tmp_path/'driver.F90').write_text(fcode_driver)
1501+
(tmp_path/'kernel_mod.F90').write_text(fcode_kernel)
1502+
scheduler = Scheduler(
1503+
paths=[tmp_path], config=SchedulerConfig.from_dict(config),
1504+
frontend=frontend, xmods=[tmp_path]
1505+
)
1506+
1507+
transformation = TemporariesPoolAllocatorTransformation(
1508+
block_dim=block_dim, horizontal=horizontal, check_bounds=False,
1509+
cray_ptr_loc_rhs=False
1510+
)
1511+
scheduler.process(transformation=transformation)
1512+
driver = scheduler['#driver'].ir
1513+
1514+
# Stack infrastructure should still be created
1515+
assert 'istsz' in driver.variables
1516+
assert 'zstack(:,:)' in driver.variables
1517+
1518+
# Find all assignments to ISTSZ
1519+
assigns = [a for a in FindNodes(Assignment).visit(driver.body)
1520+
if a.lhs == 'istsz']
1521+
assert len(assigns) == 2, (
1522+
f'Expected 2 ISTSZ assignments (size + clamp), got {len(assigns)}'
1523+
)
1524+
1525+
# First assignment: ISTSZ = 0 (no temporaries)
1526+
assert assigns[0].rhs == 0
1527+
1528+
# Second assignment: ISTSZ = MAX(ISTSZ, 1)
1529+
clamp_rhs = assigns[1].rhs
1530+
assert isinstance(clamp_rhs, InlineCall)
1531+
assert clamp_rhs.function == 'MAX'
1532+
assert clamp_rhs.parameters[0] == 'istsz'
1533+
assert clamp_rhs.parameters[1] == 1
1534+
1535+
# ALLOCATE should reference ISTSZ
1536+
allocations = FindNodes(Allocation).visit(driver.body)
1537+
assert len(allocations) == 1
1538+
assert 'zstack(istsz,nb)' in allocations[0].variables

0 commit comments

Comments
 (0)