1313from loki .transform import Transformation
1414from loki import (
1515 FindVariables , DerivedType , SymbolAttributes ,
16- Array , single_variable_declaration , Transformer
16+ Array , single_variable_declaration , Transformer ,
17+ BasicType
1718)
19+ import pickle
20+ import os
1821
1922__all__ = ['ParallelRoutineDispatchTransformation' ]
2023
@@ -26,11 +29,15 @@ def __init__(self):
2629 "KLON" , "YDCPG_OPTS%KLON" , "YDGEOMETRY%YRDIM%NPROMA" ,
2730 "KPROMA" , "YDDIM%NPROMA" , "NPROMA"
2831 ]
32+ #TODO : do smthg for opening field_index.pkl
33+ with open (os .getcwd ()+ "/transformations/transformations/field_index.pkl" , 'rb' ) as fp :
34+ self .map_index = pickle .load (fp )
2935 # CALL FIELD_NEW (YL_ZA, UBOUNDS=[KLON, KFLEVG, KGPBLKS], LBOUNDS=[1, 0, 1], PERSISTENT=.TRUE.)
3036 self .new_calls = []
3137 # IF (ASSOCIATED (YL_ZA)) CALL FIELD_DELETE (YL_ZA)
3238 self .delete_calls = []
3339 self .routine_map_temp = {}
40+ self .routine_map_derived = {}
3441
3542 def transform_subroutine (self , routine , ** kwargs ):
3643 with pragma_regions_attached (routine ):
@@ -40,6 +47,7 @@ def transform_subroutine(self, routine, **kwargs):
4047 single_variable_declaration (routine )
4148 self .add_temp (routine )
4249 self .add_field (routine )
50+ self .add_derived (routine )
4351 #call add_arrays etc...
4452
4553 def process_parallel_region (self , routine , region ):
@@ -61,11 +69,15 @@ def process_parallel_region(self, routine, region):
6169 region .append (dr_hook_calls [1 ])
6270
6371 region_map_temp = self .decl_local_array (routine , region )
72+ region_map_derived = self .decl_derived_types (routine , region )
6473
6574 for var_name in region_map_temp :
6675 if var_name not in self .routine_map_temp :
6776 self .routine_map_temp [var_name ]= region_map_temp [var_name ]
6877
78+ for var_name in region_map_derived :
79+ if var_name not in self .routine_map_derived :
80+ self .routine_map_derived [var_name ]= region_map_derived [var_name ]
6981
7082
7183 @staticmethod
@@ -178,4 +190,51 @@ def add_field(self, routine):
178190 routine , cdname = 'DELETE_TEMPORARIES' ,
179191 handle = sym .Variable (name = 'ZHOOK_HANDLE_FIELD_API' , scope = routine )
180192 )
181- routine .body .insert (- 2 ,(dr_hook_calls [0 ], ir .Comment (text = '' ), * self .delete_calls , dr_hook_calls [1 ]))
193+ routine .body .insert (- 2 ,(dr_hook_calls [0 ], ir .Comment (text = '' ), * self .delete_calls , dr_hook_calls [1 ]))
194+
195+ def decl_derived_types (self , routine , region ):
196+ region_map_derived = {}
197+ derived = [var for var in FindVariables ().visit (region ) if var .name_parts [0 ] in routine .arguments ]
198+ for var in derived :
199+
200+ key = f"{ routine .variable_map [var .name_parts [0 ]].type .dtype .name } %{ '%' .join (var .name_parts [1 :])} "
201+ if key in self .map_index :
202+ value = self .map_index [key ]
203+ # Creating the pointer on the data : YL_A
204+ data_name = f"Z_{ var .name .replace ('%' , '_' )} "
205+ if "REAL" and "JPRB" in value [0 ]:
206+ data_type = SymbolAttributes (
207+ dtype = BasicType .REAL , kind = routine .symbol_map ['JPRB' ],
208+ pointer = True
209+ )
210+ data_dim = value [2 ] + 1
211+ data_shape = (sym .RangeIndex ((None , None )),) * data_dim
212+ ptr_var = sym .Variable (name = data_name , type = data_type , dimensions = data_shape , scope = routine )
213+
214+ else :
215+ raise NotImplementedError ("This type isn't implemented yet" )
216+
217+ # Creating the pointer on the field api object : YL%FA, YL%F_A...
218+ if routine .variable_map [var .name_parts [0 ]].type .dtype .name == "MF_PHYS_SURF_TYPE" :
219+ # YL%PA becomes YL%F_A
220+ field_name = f"{ '%' .join (var .name_parts [:- 1 ])} %F_{ var .name_parts [- 1 ][1 :]} "
221+ elif routine .variable_map [var .name_parts [0 ]].type .dtype .name == "FIELD_VARIABLES" :
222+ # YL%A becomes YL%FA
223+ field_name = f"{ '%' .join (var .name_parts [:- 1 ])} %F{ var .name_parts [- 1 ]} "
224+ if var .name_parts [- 1 ]== "P" : #YL%FP = YL%FT0
225+ field_name = f"{ field_name [- 1 ]} T0"
226+ else :
227+ # YL%A becomes YL%F_A
228+ field_name = f"{ '%' .join (var .name_parts [:- 1 ])} %F_{ var .name_parts [- 1 ]} "
229+ field_ptr_var = var .clone (name = field_name )
230+ region_map_derived [var .name ] = [field_ptr_var , ptr_var ]
231+ return (region_map_derived )
232+
233+ def add_derived (self , routine ):
234+ ptr_var = ()
235+ for value in self .routine_map_derived .values ():
236+ dcl = ir .VariableDeclaration (
237+ symbols = (value [1 ],)
238+ )
239+ ptr_var += (dcl ,)
240+ routine .spec .append (ptr_var )
0 commit comments