@@ -216,9 +216,17 @@ void MemVarRefMigrationRule::runRule(const MatchFinder::MatchResult &Result) {
216216 }
217217 }
218218 }
219- if (!HasTypeCasted && Decl->hasAttr <CUDAConstantAttr>() &&
220- (MemVarRef->getType ()->getTypeClass () ==
221- Type::TypeClass::ConstantArray)) {
219+ auto FD = dpct::DpctGlobalInfo::findAncestor<FunctionDecl>(MemVarRef);
220+ auto CE = dpct::DpctGlobalInfo::findAncestor<CallExpr>(MemVarRef);
221+ if (auto VD =dyn_cast<VarDecl>(MemVarRef->getDecl ()); FD && VD &&
222+ !VD ->isLocalVarDeclOrParm () &&
223+ !isGlobalOrDeviceFuncDecl (FD )) {
224+ if (CE &&
225+ !DpctGlobalInfo::isInCudaPath (CE ->getCalleeDecl ()->getBeginLoc ()))
226+ emplaceTransformation (new InsertAfterStmt (MemVarRef, " .get_ptr()" ));
227+ } else if (!HasTypeCasted && Decl->hasAttr <CUDAConstantAttr>() &&
228+ (MemVarRef->getType ()->getTypeClass () ==
229+ Type::TypeClass::ConstantArray)) {
222230 const Expr *RHS = getRHSOfTheNonConstAssignedVar (MemVarRef);
223231 if (RHS ) {
224232 auto Range = GetReplRange (RHS );
@@ -235,7 +243,7 @@ void MemVarRefMigrationRule::runRule(const MatchFinder::MatchResult &Result) {
235243 if (VD == nullptr )
236244 return ;
237245 auto Var = Global.findMemVarInfo (VD );
238- if (Func-> hasAttr <CUDAGlobalAttr>() || Func-> hasAttr <CUDADeviceAttr>( )) {
246+ if (isGlobalOrDeviceFuncDecl ( Func)) {
239247 if (DpctGlobalInfo::useGroupLocalMemory () &&
240248 VD ->hasAttr <CUDASharedAttr>() && VD ->getStorageClass () != SC_Extern) {
241249 if (!Var)
@@ -829,7 +837,7 @@ void MemVarAnalysisRule::runRule(const MatchFinder::MatchResult &Result) {
829837 return ;
830838 }
831839 auto Var = MemVarInfo::buildMemVarInfo (VD );
832- if (Func-> hasAttr <CUDAGlobalAttr>() || Func-> hasAttr <CUDADeviceAttr>( )) {
840+ if (isGlobalOrDeviceFuncDecl ( Func)) {
833841 if (!(DpctGlobalInfo::useGroupLocalMemory () &&
834842 VD ->hasAttr <CUDASharedAttr>() &&
835843 VD ->getStorageClass () != SC_Extern)) {
@@ -1025,7 +1033,7 @@ void ZeroLengthArrayRule::runRule(const MatchFinder::MatchResult &Result) {
10251033 const clang::FunctionDecl *FD = DpctGlobalInfo::getParentFunction (TL );
10261034 if (FD ) {
10271035 // Check if the array is in device code
1028- if (!(FD -> getAttr <CUDADeviceAttr>()) && !( FD -> getAttr <CUDAGlobalAttr>() ))
1036+ if (!isGlobalOrDeviceFuncDecl (FD ))
10291037 return ;
10301038 }
10311039 }
0 commit comments