@@ -2695,12 +2695,12 @@ void FunctionCallRule::runRule(const MatchFinder::MatchResult &Result) {
26952695EventAPICallRule *EventAPICallRule::CurrentRule = nullptr ;
26962696void EventAPICallRule::registerMatcher (MatchFinder &MF ) {
26972697 auto eventAPIName = [&]() {
2698- return hasAnyName (
2699- " cudaEventCreate " , " cudaEventCreateWithFlags " , " cudaEventDestroy " ,
2700- " cudaEventRecord " , " cudaEventElapsedTime " , " cudaEventSynchronize " ,
2701- " cudaEventQuery " , " cuEventCreate " , " cuEventRecord " ,
2702- " cuEventSynchronize " , " cuEventQuery " , " cuEventElapsedTime " ,
2703- " cuEventDestroy_v2" );
2698+ return hasAnyName (" cudaEventCreate " , " cudaEventCreateWithFlags " ,
2699+ " cudaEventDestroy " , " cudaEventRecord " ,
2700+ " cudaEventRecordWithFlags " , " cudaEventElapsedTime " ,
2701+ " cudaEventSynchronize " , " cudaEventQuery " , " cuEventCreate " ,
2702+ " cuEventRecord " , " cuEventSynchronize " , " cuEventQuery " ,
2703+ " cuEventElapsedTime " , " cuEventDestroy_v2" );
27042704 };
27052705
27062706 MF .addMatcher (
@@ -3095,7 +3095,8 @@ void EventAPICallRule::runRule(const MatchFinder::MatchResult &Result) {
30953095 }
30963096 std::string ReplStr = MapNames::getDpctNamespace () + " sycl_event_query" ;
30973097 emplaceTransformation (new ReplaceCalleeName (CE , std::move (ReplStr)));
3098- } else if (FuncName == " cudaEventRecord" || FuncName == " cuEventRecord" ) {
3098+ } else if (FuncName == " cudaEventRecord" || FuncName == " cuEventRecord" ||
3099+ FuncName == " cudaEventRecordWithFlags" ) {
30993100 handleEventRecord (CE , Result, IsAssigned);
31003101 } else if (FuncName == " cudaEventElapsedTime" ||
31013102 FuncName == " cuEventElapsedTime" ) {
@@ -3260,7 +3261,22 @@ void EventAPICallRule::findEventAPI(const Stmt *Node, const CallExpr *&Call,
32603261void EventAPICallRule::handleEventRecordWithProfilingEnabled (
32613262 const CallExpr *CE , const MatchFinder::MatchResult &Result,
32623263 bool IsAssigned) {
3263- auto StreamArg = CE ->getArg (CE ->getNumArgs () - 1 );
3264+ int NumArgs = CE ->getNumArgs ();
3265+ const Expr *StreamArg = CE ->getArg (NumArgs - 1 );
3266+ if (NumArgs == 3 ) { // Special process for cudaEventRecordWithFlags().
3267+ StreamArg = CE ->getArg (1 );
3268+ auto APIName = CE ->getDirectCallee ()->getNameInfo ().getName ().getAsString ();
3269+ const Expr *SecArg = CE ->getArg (2 );
3270+ ExprAnalysis Arg2EA (SecArg);
3271+ auto Arg2Name = Arg2EA.getReplacedString ();
3272+ if (Arg2Name != " cudaEventRecordDefault" ) {
3273+ report (CE ->getBeginLoc (), Diagnostics::NOT_SUPPORTED_PARAMETER , false ,
3274+ APIName, " parameter " + Arg2Name + " is unsupported" );
3275+ return ;
3276+ }
3277+ emplaceTransformation (removeArg (CE , 2 , *Result.SourceManager ));
3278+ }
3279+
32643280 auto EventArg = CE ->getArg (0 );
32653281 ExprAnalysis StreamEA (StreamArg);
32663282 ExprAnalysis Arg0EA (EventArg);
0 commit comments