1212using namespace clang ::dpct;
1313using namespace clang ::ast_matchers;
1414
// STRINGIZE(x): turns the spelling of `x` into a string literal. The `#`
// operator suppresses macro expansion of its operand, so a macro name passed
// here is stringized as written.
#define STRINGIZE(x) #x
// EXPAND_AND_STRINGIZE(x): one extra level of indirection so that `x` is fully
// macro-expanded before STRINGIZE sees it.
#define EXPAND_AND_STRINGIZE(x) STRINGIZE(x)
17+
// FOR_ALL_STANDARD_RMA_TYPES(PREFIX, POSTFIX): expands to a comma-separated
// list of 23 string literals of the form "nvshmem<PREFIX>_<TYPE>_<POSTFIX>",
// one per type name listed below (float ... ptrdiff).
//   PREFIX  - pasted directly after "nvshmem"; may be left empty, or e.g. `x`
//             to produce "nvshmemx_...".
//   POSTFIX - operation suffix pasted after the type name, e.g. `put_block`.
// The name is assembled with `##` inside this macro, so the arguments are not
// macro-expanded themselves. EXPAND_AND_STRINGIZE must be defined wherever
// this macro is used. No trailing comma is emitted after the last entry.
#define FOR_ALL_STANDARD_RMA_TYPES(PREFIX, POSTFIX)                            \
  EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##float##_##POSTFIX),                 \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##double##_##POSTFIX),            \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##char##_##POSTFIX),              \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##schar##_##POSTFIX),             \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##short##_##POSTFIX),             \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##int##_##POSTFIX),               \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##long##_##POSTFIX),              \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##longlong##_##POSTFIX),          \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##uchar##_##POSTFIX),             \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##ushort##_##POSTFIX),            \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##uint##_##POSTFIX),              \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##ulong##_##POSTFIX),             \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##ulonglong##_##POSTFIX),         \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##int8##_##POSTFIX),              \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##int16##_##POSTFIX),             \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##int32##_##POSTFIX),             \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##int64##_##POSTFIX),             \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##uint8##_##POSTFIX),             \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##uint16##_##POSTFIX),            \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##uint32##_##POSTFIX),            \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##uint64##_##POSTFIX),            \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##size##_##POSTFIX),              \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##ptrdiff##_##POSTFIX)
42+
// FOR_ALL_SIZES(PREFIX, OP, POSTFIX): expands to a comma-separated list of
// five string literals of the form "nvshmem<PREFIX>_<OP><SIZE><POSTFIX>" for
// SIZE in {8, 16, 32, 64, 128}.
//   PREFIX  - pasted directly after "nvshmem"; may be left empty, or e.g. `x`.
//   OP      - operation name placed before the size, e.g. `put`.
//   POSTFIX - pasted directly after the size with no separator added, so it
//             must carry its own leading underscore (e.g. `_block`); may be
//             left empty.
// EXPAND_AND_STRINGIZE must be defined wherever this macro is used. No
// trailing comma is emitted after the last entry.
#define FOR_ALL_SIZES(PREFIX, OP, POSTFIX)                                     \
  EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##OP##8##POSTFIX),                    \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##OP##16##POSTFIX),               \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##OP##32##POSTFIX),               \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##OP##64##POSTFIX),               \
      EXPAND_AND_STRINGIZE(nvshmem##PREFIX##_##OP##128##POSTFIX)
49+
1550void clang::dpct::NVSHMEMRule::registerMatcher(ast_matchers::MatchFinder &MF ) {
1651 auto NvshmemAPI = [&]() {
1752 return hasAnyName (
@@ -25,8 +60,88 @@ void clang::dpct::NVSHMEMRule::registerMatcher(ast_matchers::MatchFinder &MF) {
2560 " nvshmem_team_my_pe" , " nvshmem_team_n_pes" , " nvshmem_team_get_config" ,
2661 " nvshmem_team_translate_pe" , " nvshmem_team_split_strided" ,
2762 " nvshmem_team_split_2d" , " nvshmem_team_destroy" ,
63+ // RMA
64+ FOR_ALL_STANDARD_RMA_TYPES (, put) /* nvshmem_TYPENAME_put*/ ,
65+ FOR_ALL_STANDARD_RMA_TYPES (
66+ x, put_on_stream) /* nvshmemx_TYPENAME_put_on_stream*/ ,
67+ FOR_ALL_STANDARD_RMA_TYPES (x,
68+ put_block) /* nvshmemx_TYPENAME_put_block*/ ,
69+ FOR_ALL_STANDARD_RMA_TYPES (x, put_warp) /* nvshmemx_TYPENAME_put_warp*/ ,
70+ FOR_ALL_SIZES (, put, ) /* nvshmem_putSIZE*/ ,
71+ FOR_ALL_SIZES (x, put, _on_stream) /* nvshmemx_putSIZE_on_stream*/ ,
72+ FOR_ALL_SIZES (x, put, _block) /* nvshmemx_putSIZE_block*/ ,
73+ FOR_ALL_SIZES (x, put, _warp) /* nvshmemx_putSIZE_warp*/ ,
74+ FOR_ALL_STANDARD_RMA_TYPES (, iput) /* nvshmem_TYPENAME_iput*/ ,
75+ FOR_ALL_STANDARD_RMA_TYPES (
76+ x, iput_on_stream) /* nvshmemx_TYPENAME_iput_on_stream*/ ,
77+ FOR_ALL_STANDARD_RMA_TYPES (x,
78+ iput_block) /* nvshmemx_TYPENAME_iput_block*/ ,
79+ FOR_ALL_STANDARD_RMA_TYPES (x,
80+ iput_warp) /* nvshmemx_TYPENAME_iput_warp*/ ,
81+ FOR_ALL_SIZES (, iput, ) /* nvshmem_iputSIZE*/ ,
82+ FOR_ALL_SIZES (x, iput, _on_stream) /* nvshmemx_iputSIZE_on_stream*/ ,
83+ FOR_ALL_SIZES (x, iput, _block) /* nvshmemx_iputSIZE_block*/ ,
84+ FOR_ALL_SIZES (x, iput, _warp) /* nvshmemx_iputSIZE_warp*/ ,
85+ " nvshmem_putmem" , " nvshmemx_putmem_on_stream" , " nvshmemx_putmem_block" ,
86+ " nvshmemx_putmem_warp" ,
87+ FOR_ALL_STANDARD_RMA_TYPES (, p) /* nvshmem_TYPENAME_p*/ ,
88+ FOR_ALL_STANDARD_RMA_TYPES (, get) /* nvshmem_TYPENAME_get*/ ,
89+ FOR_ALL_STANDARD_RMA_TYPES (
90+ x, get_on_stream) /* nvshmemx_TYPENAME_get_on_stream*/ ,
91+ FOR_ALL_STANDARD_RMA_TYPES (x,
92+ get_block) /* nvshmemx_TYPENAME_get_block*/ ,
93+ FOR_ALL_STANDARD_RMA_TYPES (x, get_warp) /* nvshmemx_TYPENAME_get_warp*/ ,
94+ FOR_ALL_SIZES (, get, ) /* nvshmem_getSIZE*/ ,
95+ FOR_ALL_SIZES (x, get, _on_stream) /* nvshmemx_getSIZE_on_stream*/ ,
96+ FOR_ALL_SIZES (x, get, _block) /* nvshmemx_getSIZE_block*/ ,
97+ FOR_ALL_SIZES (x, get, _warp) /* nvshmemx_getSIZE_warp*/ ,
98+ FOR_ALL_STANDARD_RMA_TYPES (, iget) /* nvshmem_TYPENAME_iget*/ ,
99+ FOR_ALL_STANDARD_RMA_TYPES (
100+ x, iget_on_stream) /* nvshmemx_TYPENAME_iget_on_stream*/ ,
101+ FOR_ALL_STANDARD_RMA_TYPES (x,
102+ iget_block) /* nvshmemx_TYPENAME_iget_block*/ ,
103+ FOR_ALL_STANDARD_RMA_TYPES (x,
104+ iget_warp) /* nvshmemx_TYPENAME_iget_warp*/ ,
105+ FOR_ALL_SIZES (, iget, ) /* nvshmem_igetSIZE*/ ,
106+ FOR_ALL_SIZES (x, iget, _on_stream) /* nvshmemx_igetSIZE_on_stream*/ ,
107+ FOR_ALL_SIZES (x, iget, _block) /* nvshmemx_igetSIZE_block*/ ,
108+ FOR_ALL_SIZES (x, iget, _warp) /* nvshmemx_igetSIZE_warp*/ ,
109+ " nvshmem_getmem" , " nvshmemx_getmem_on_stream" , " nvshmemx_getmem_block" ,
110+ " nvshmemx_getmem_warp" ,
111+ FOR_ALL_STANDARD_RMA_TYPES (, g) /* nvshmem_TYPENAME_g*/ ,
28112 // Nonblocking RMA
29- " nvshmem_putmem_nbi" ,
113+ FOR_ALL_STANDARD_RMA_TYPES (, put_nbi) /* nvshmem_TYPENAME_put_nbi*/ ,
114+ FOR_ALL_STANDARD_RMA_TYPES (
115+ x, put_nbi_on_stream) /* nvshmemx_TYPENAME_put_nbi_on_stream*/ ,
116+ FOR_ALL_STANDARD_RMA_TYPES (
117+ x, put_nbi_block) /* nvshmemx_TYPENAME_put_nbi_block*/ ,
118+ FOR_ALL_STANDARD_RMA_TYPES (
119+ x, put_nbi_warp) /* nvshmemx_TYPENAME_put_nbi_warp*/ ,
120+ FOR_ALL_SIZES (, put, _nbi) /* nvshmem_putSIZE_nbi*/ ,
121+ FOR_ALL_SIZES (x, put,
122+ _nbi_on_stream) /* nvshmemx_putSIZE_nbi_on_stream*/ ,
123+ FOR_ALL_SIZES (x, put, _nbi_block) /* nvshmemx_putSIZE_nbi_block*/ ,
124+ FOR_ALL_SIZES (x, put, _nbi_warp) /* nvshmemx_putSIZE_nbi_warp*/ ,
125+ " nvshmem_putmem_nbi" , " nvshmemx_putmem_nbi_on_stream" ,
126+ " nvshmemx_putmem_nbi_block" , " nvshmemx_putmem_nbi_warp" ,
127+ FOR_ALL_STANDARD_RMA_TYPES (, get_nbi) /* nvshmem_TYPENAME_get_nbi*/ ,
128+ FOR_ALL_STANDARD_RMA_TYPES (
129+ x, get_nbi_on_stream) /* nvshmemx_TYPENAME_get_nbi_on_stream*/ ,
130+ FOR_ALL_STANDARD_RMA_TYPES (
131+ x, get_nbi_block) /* nvshmemx_TYPENAME_get_nbi_block*/ ,
132+ FOR_ALL_STANDARD_RMA_TYPES (
133+ x, get_nbi_warp) /* nvshmemx_TYPENAME_get_nbi_warp*/ ,
134+ FOR_ALL_SIZES (, get, _nbi) /* nvshmem_getSIZE_nbi*/ ,
135+ FOR_ALL_SIZES (x, get,
136+ _nbi_on_stream) /* nvshmemx_getSIZE_nbi_on_stream*/ ,
137+ FOR_ALL_SIZES (x, get, _nbi_block) /* nvshmemx_getSIZE_nbi_block*/ ,
138+ FOR_ALL_SIZES (x, get, _nbi_warp) /* nvshmemx_getSIZE_nbi_warp*/ ,
139+ " nvshmem_getmem_nbi" , " nvshmemx_getmem_nbi_on_stream" ,
140+ " nvshmemx_getmem_nbi_block" , " nvshmemx_getmem_nbi_warp" ,
141+ // Memory Ordering
142+ " nvshmem_fence" , " nvshmem_quiet" , " nvshmemx_quiet_on_stream" ,
143+ // Collective Operations
144+ " nvshmemx_barrier_all_on_stream" ,
30145 // Signalling Operations
31146 " nvshmemx_signal_op" , " nvshmem_signal_wait_until" ,
32147 " nvshmem_putmem_signal_nbi" );
@@ -112,3 +227,8 @@ void clang::dpct::NVSHMEMRule::runRule(
112227 emplaceTransformation (EA .getReplacement ());
113228 EA .applyAllSubExprRepl ();
114229}
230+
// Drop the file-local helper macros so they do not remain defined past this
// point in the translation unit.
#undef STRINGIZE
#undef EXPAND_AND_STRINGIZE
#undef FOR_ALL_STANDARD_RMA_TYPES
#undef FOR_ALL_SIZES
0 commit comments