From 30aae07921b40ddfe6c87a00ba73aaf6c6658ab5 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Wed, 3 Jun 2026 12:33:16 +0000 Subject: [PATCH] [DeepcopyDataflowAnalysis] make sure disabled calls do not cause issues --- .../data_offload/offload_deepcopy.py | 21 ++++- .../tests/test_offload_deepcopy.py | 93 ++++++++++++++++++- 2 files changed, 111 insertions(+), 3 deletions(-) diff --git a/loki/transformations/data_offload/offload_deepcopy.py b/loki/transformations/data_offload/offload_deepcopy.py index 809e9fc17..6212f4306 100644 --- a/loki/transformations/data_offload/offload_deepcopy.py +++ b/loki/transformations/data_offload/offload_deepcopy.py @@ -7,6 +7,7 @@ from collections import defaultdict from pathlib import Path +import fnmatch import re try: @@ -125,12 +126,23 @@ class DeepcopyDataflowAnalysis(DataflowAnalysis): the dataflow analysis of the child :any:`Subroutine` and ignoring the intents altogether. """ - def __init__(self, successor_map, include_literal_kinds=False): + def __init__(self, successor_map, disabled_calls=(), include_literal_kinds=False): super().__init__(include_literal_kinds=include_literal_kinds) self.successor_map = successor_map + self.disabled_calls = tuple(name.lower() for name in disabled_calls) + + def is_disabled_call(self, call): + """Return ``True`` if the call matches the configured disabled-call patterns.""" + + name = getattr(getattr(call, 'name', None), 'name', '').lower() + if not name: + return False + return any(fnmatch.fnmatch(name, pattern) for pattern in self.disabled_calls) def resolve_call_effects(self, call, *, attacher, **kwargs): if not call.routine: + if self.is_disabled_call(call): + return None msg = f'[Loki::DataOffloadDeepcopyAnalysis] Cannot apply transformation without enriching calls: {call}.' raise RuntimeError(msg) @@ -209,9 +221,13 @@ class DataOffloadDeepcopyAnalysis(Transformation): def __init__(self, output_analysis=False): self.output_analysis = output_analysis + self.disabled_calls = () def transform_subroutine(self, routine, **kwargs): + item = kwargs.get('item') + self.disabled_calls = as_tuple(item.config.get('disable')) if item and item.config else () + if not (item := kwargs.pop('item', None)): msg = f'[Loki::DataOffloadDeepcopyAnalysis] Cannot apply transformation without item: {routine}.' raise RuntimeError(msg) @@ -332,7 +348,8 @@ def process_body(self, routine_name, item, successors, successor_map, scope_node warning(f'[Loki::DataOffloadDeepcopyAnalysis] Pointer associations found in {routine_name}') dataflow_analysis = DeepcopyDataflowAnalysis( - successor_map=successor_map, include_literal_kinds=False + successor_map=successor_map, disabled_calls=self.disabled_calls, + include_literal_kinds=False ) with dfa_attached(scope_node, dfa=dataflow_analysis): diff --git a/loki/transformations/data_offload/tests/test_offload_deepcopy.py b/loki/transformations/data_offload/tests/test_offload_deepcopy.py index 9a80ff347..28f9a52bb 100644 --- a/loki/transformations/data_offload/tests/test_offload_deepcopy.py +++ b/loki/transformations/data_offload/tests/test_offload_deepcopy.py @@ -21,7 +21,8 @@ from loki.subroutine import Subroutine from loki.tools import gettempdir, flatten, as_tuple from loki.transformations import ( - DataOffloadDeepcopyAnalysis, DataOffloadDeepcopyTransformation, find_driver_loops + DataOffloadDeepcopyAnalysis, DataOffloadDeepcopyTransformation, RemoveCodeTransformation, + find_driver_loops ) from loki.types import BasicType, DerivedType, SymbolAttributes, Scope @@ -356,6 +357,62 @@ def fixture_expected_analysis(): } +@pytest.fixture(scope='module', name='deepcopy_abort_code') +def fixture_deepcopy_abort_code(): + fcode = { + 'kernel': ( + """ +module kernel_mod +contains +subroutine kernel(flag, a) + logical, intent(in) :: flag + real, intent(inout) :: a(:) + + if (flag) then + !$loki remove + a(:) = 0. + !$loki end remove + endif +end subroutine kernel +end module kernel_mod + """.strip() + ), + 'driver': ( + """ +module driver_mod +contains +subroutine driver(ngpblks, flag, a) + use kernel_mod, only : kernel + integer, intent(in) :: ngpblks + logical, intent(in) :: flag + real, intent(inout) :: a(:,:) + integer :: ibl + +!$loki data +!$loki driver-loop + do ibl = 1, ngpblks + call kernel(flag, a(:, ibl)) + enddo +!$loki end data + +end subroutine driver +end module driver_mod + """.strip() + ) + } + + workdir = gettempdir()/'test_offload_deepcopy_abort' + if workdir.exists(): + rmtree(workdir) + workdir.mkdir() + for name, code in fcode.items(): + (workdir/f'{name}.F90').write_text(code) + + yield workdir + + rmtree(workdir) + + @pytest.fixture(scope='function', name='config') def fixture_config(): """ @@ -954,3 +1011,37 @@ def test_offload_deepcopy_simple_driver(frontend, config, deepcopy_code, tmp_pat # filter out target calls, as we only need to check generated boilerplate calls = [call for call in calls if not call.name.name.lower() in driver_item.targets] check_other_variable_type('offload', conds, calls, pragmas, driver) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_offload_deepcopy_analysis_ignores_unresolved_abor1_replacement( + frontend, config, deepcopy_abort_code, tmp_path +): + """Ensure disabled ABOR1-style replacement calls are ignored by deepcopy analysis.""" + + config['default']['disable'] += ['ABOR1*'] + config['routines'] = { + 'driver': {'role': 'driver'}, + } + + scheduler = Scheduler( + paths=deepcopy_abort_code, config=config, frontend=frontend, xmods=[tmp_path], + output_dir=tmp_path, preprocess=True + ) + + scheduler.process(transformation=RemoveCodeTransformation( + remove_marked_regions=True, + replacement_call='ABOR1_ACC', + replacement_msg='Reached removed GPU-unsupported-path in {}' + )) + + kernel = scheduler['kernel_mod#kernel'].ir + kernel_calls = FindNodes(ir.CallStatement).visit(kernel.body) + assert any(call.name.name.lower() == 'abor1_acc' for call in kernel_calls) + + transformation = DataOffloadDeepcopyAnalysis() + scheduler.process(transformation=transformation) + + driver_item = scheduler['driver_mod#driver'] + analysis = driver_item.trafo_data[transformation._key]['analysis'] + assert analysis