|
5 | 5 | # granted to it by virtue of its status as an intergovernmental organisation |
6 | 6 | # nor does it submit to any jurisdiction. |
7 | 7 |
|
| 8 | +# pylint: disable=too-many-lines |
8 | 9 | import pytest |
9 | 10 |
|
10 | 11 | from loki.expression.parser import parse_expr |
@@ -65,9 +66,15 @@ def check_stack_created_in_driver( |
65 | 66 |
|
66 | 67 | # # Check the stack size |
67 | 68 | assignments = FindNodes(Assignment).visit(driver.body) |
68 | | - for assignment in assignments: |
69 | | - if assignment.lhs == 'istsz': |
70 | | - assert simplify(assignment.rhs) == simplify(stack_size) |
| 69 | + istsz_assigns = [a for a in assignments if a.lhs == 'istsz'] |
| 70 | + assert len(istsz_assigns) >= 1 |
| 71 | + # First assignment is the size computation |
| 72 | + assert simplify(istsz_assigns[0].rhs) == simplify(stack_size) |
| 73 | + # When the transformation generates the stack (not pre-existing), |
| 74 | + # a MAX(..., 1) clamp statement follows to prevent zero-sized allocations |
| 75 | + if len(istsz_assigns) == 2: |
| 76 | + assert isinstance(istsz_assigns[1].rhs, InlineCall) |
| 77 | + assert istsz_assigns[1].rhs.function == 'MAX' |
71 | 78 |
|
72 | 79 | # # Check for stack assignment inside loop |
73 | 80 | loops = FindNodes(Loop).visit(driver.body) |
@@ -1429,3 +1436,103 @@ def test_pool_allocator_args_vs_kwargs(tmp_path, frontend, block_dim_alt, cray_p |
1429 | 1436 |
|
1430 | 1437 | # check that array size was imported to the driver |
1431 | 1438 | assert 'n' in driver.imported_symbols |
| 1439 | + |
| 1440 | + |
| 1441 | +@pytest.mark.parametrize('frontend', available_frontends()) |
| 1442 | +def test_pool_allocator_zero_size_allocation(tmp_path, frontend, block_dim, horizontal): |
| 1443 | + """ |
| 1444 | + When the kernel has no stack-allocatable temporaries (e.g. all arrays |
| 1445 | + have been demoted to scalars by a prior k-caching transformation), the |
| 1446 | + computed stack size is 0. A zero-sized ALLOCATE is itself valid Fortran,
| 1447 | + but the stack is subsequently accessed at index 1
| 1448 | + (``LOC(ZSTACK(1, IBL))``), which is out of bounds and causes a runtime error.
| 1449 | +
|
| 1450 | + Verify that the pool allocator inserts a clamp statement |
| 1451 | + ``ISTSZ = MAX(ISTSZ, 1)`` after the size assignment to guarantee a |
| 1452 | + minimum allocation of 1. |
| 1453 | + """ |
| 1454 | + fcode_driver = """ |
| 1455 | +subroutine driver(NLON, NZ, NB, FIELD1) |
| 1456 | + use kernel_mod, only: kernel |
| 1457 | + implicit none |
| 1458 | + INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13,300) |
| 1459 | + INTEGER, INTENT(IN) :: NLON, NZ, NB |
| 1460 | + real(kind=jprb), intent(inout) :: field1(nlon, nb) |
| 1461 | + integer :: b |
| 1462 | + do b=1,nb |
| 1463 | + call KERNEL(1, nlon, nlon, nz, field1(:,b)) |
| 1464 | + end do |
| 1465 | +end subroutine driver |
| 1466 | + """.strip() |
| 1467 | + fcode_kernel = """ |
| 1468 | +module kernel_mod |
| 1469 | + implicit none |
| 1470 | +contains |
| 1471 | + subroutine kernel(start, end, klon, klev, field1) |
| 1472 | + implicit none |
| 1473 | + integer, parameter :: jprb = selected_real_kind(13,300) |
| 1474 | + integer, intent(in) :: start, end, klon, klev |
| 1475 | + real(kind=jprb), intent(inout) :: field1(klon) |
| 1476 | + real(kind=jprb) :: scalar_tmp |
| 1477 | + integer :: jl |
| 1478 | +
|
| 1479 | + do jl=start,end |
| 1480 | + scalar_tmp = field1(jl) * 2.0_jprb |
| 1481 | + field1(jl) = scalar_tmp |
| 1482 | + end do |
| 1483 | + end subroutine kernel |
| 1484 | +end module kernel_mod |
| 1485 | + """.strip() |
| 1486 | + |
| 1487 | + config = { |
| 1488 | + 'default': { |
| 1489 | + 'mode': 'idem', |
| 1490 | + 'role': 'kernel', |
| 1491 | + 'expand': True, |
| 1492 | + 'strict': False, |
| 1493 | + 'enable_imports': True, |
| 1494 | + }, |
| 1495 | + 'routines': { |
| 1496 | + 'driver': {'role': 'driver'} |
| 1497 | + } |
| 1498 | + } |
| 1499 | + |
| 1500 | + (tmp_path/'driver.F90').write_text(fcode_driver) |
| 1501 | + (tmp_path/'kernel_mod.F90').write_text(fcode_kernel) |
| 1502 | + scheduler = Scheduler( |
| 1503 | + paths=[tmp_path], config=SchedulerConfig.from_dict(config), |
| 1504 | + frontend=frontend, xmods=[tmp_path] |
| 1505 | + ) |
| 1506 | + |
| 1507 | + transformation = TemporariesPoolAllocatorTransformation( |
| 1508 | + block_dim=block_dim, horizontal=horizontal, check_bounds=False, |
| 1509 | + cray_ptr_loc_rhs=False |
| 1510 | + ) |
| 1511 | + scheduler.process(transformation=transformation) |
| 1512 | + driver = scheduler['#driver'].ir |
| 1513 | + |
| 1514 | + # Stack infrastructure should still be created |
| 1515 | + assert 'istsz' in driver.variables |
| 1516 | + assert 'zstack(:,:)' in driver.variables |
| 1517 | + |
| 1518 | + # Find all assignments to ISTSZ |
| 1519 | + assigns = [a for a in FindNodes(Assignment).visit(driver.body) |
| 1520 | + if a.lhs == 'istsz'] |
| 1521 | + assert len(assigns) == 2, ( |
| 1522 | + f'Expected 2 ISTSZ assignments (size + clamp), got {len(assigns)}' |
| 1523 | + ) |
| 1524 | + |
| 1525 | + # First assignment: ISTSZ = 0 (no temporaries) |
| 1526 | + assert assigns[0].rhs == 0 |
| 1527 | + |
| 1528 | + # Second assignment: ISTSZ = MAX(ISTSZ, 1) |
| 1529 | + clamp_rhs = assigns[1].rhs |
| 1530 | + assert isinstance(clamp_rhs, InlineCall) |
| 1531 | + assert clamp_rhs.function == 'MAX' |
| 1532 | + assert clamp_rhs.parameters[0] == 'istsz' |
| 1533 | + assert clamp_rhs.parameters[1] == 1 |
| 1534 | + |
| 1535 | + # ALLOCATE should reference ISTSZ |
| 1536 | + allocations = FindNodes(Allocation).visit(driver.body) |
| 1537 | + assert len(allocations) == 1 |
| 1538 | + assert 'zstack(istsz,nb)' in allocations[0].variables |
0 commit comments