@@ -5793,6 +5793,10 @@ void SemaSYCL::MarkDevices() {
57935793
57945794void SemaSYCL::ProcessFreeFunction (FunctionDecl *FD ) {
57955795 if (isFreeFunction (FD )) {
5796+ if (FD ->isVariadic ()) {
5797+ Diag (FD ->getLocation (), diag::err_free_function_variadic_args);
5798+ return ;
5799+ }
57965800 SyclKernelDecompMarker DecompMarker (*this );
57975801 SyclKernelFieldChecker FieldChecker (*this );
57985802 SyclKernelUnionChecker UnionChecker (*this );
@@ -6491,16 +6495,36 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
64916495
64926496class FreeFunctionPrinter {
64936497 raw_ostream &O;
6498+ PrintingPolicy &Policy;
64946499 bool NSInserted = false ;
64956500
64966501public:
6497- FreeFunctionPrinter (raw_ostream &O) : O(O) {}
6502+ FreeFunctionPrinter (raw_ostream &O, PrintingPolicy &PrintPolicy)
6503+ : O(O), Policy(PrintPolicy) {}
6504+
6505+ // / Emits the function declaration of template free function.
6506+ // / \param FTD The function declaration to print.
6507+ // / \param S Sema object.
6508+ void printFreeFunctionDeclaration (FunctionTemplateDecl *FTD ,
6509+ clang::SemaSYCL &S) {
6510+ const FunctionDecl *TemplatedDecl = FTD ->getTemplatedDecl ();
6511+ if (!TemplatedDecl)
6512+ return ;
6513+ const std::string TemplatedDeclParams =
6514+ getTemplatedParamList (TemplatedDecl->parameters (), Policy);
6515+ const std::string TemplateParams =
6516+ getTemplateParameters (FTD ->getTemplateParameters (), S);
6517+ printFreeFunctionDeclaration (TemplatedDecl, TemplatedDeclParams,
6518+ TemplateParams);
6519+ }
64986520
64996521 // / Emits the function declaration of a free function.
65006522 // / \param FD The function declaration to print.
65016523 // / \param Args The arguments of the function.
6524+ // / \param TemplateParameters The template parameters of the function.
65026525 void printFreeFunctionDeclaration (const FunctionDecl *FD ,
6503- const std::string &Args) {
6526+ const std::string &Args,
6527+ std::string_view TemplateParameters = " " ) {
65046528 const DeclContext *DC = FD ->getDeclContext ();
65056529 if (DC ) {
65066530 // if function in namespace, print namespace
@@ -6510,6 +6534,7 @@ class FreeFunctionPrinter {
65106534 // function
65116535 NSInserted = true ;
65126536 }
6537+ O << TemplateParameters;
65136538 O << FD ->getReturnType ().getAsString () << " " ;
65146539 O << FD ->getNameAsString () << " (" << Args << " );" ;
65156540 if (NSInserted) {
@@ -6533,6 +6558,95 @@ class FreeFunctionPrinter {
65336558 if (NSInserted)
65346559 PrintNamespaces (O, FD , /* isPrintNamesOnly=*/ true );
65356560 O << FD ->getIdentifier ()->getName ().data ();
6561+ if (FD ->getPrimaryTemplate ()) {
6562+ std::string Buffer;
6563+ llvm::raw_string_ostream StringStream (Buffer);
6564+ const TemplateArgumentList *TAL = FD ->getTemplateSpecializationArgs ();
6565+ ArrayRef<TemplateArgument> A = TAL ->asArray ();
6566+ bool FirstParam = true ;
6567+ for (const auto &X : A) {
6568+ if (FirstParam)
6569+ FirstParam = false ;
6570+ else if (X.getKind () == TemplateArgument::Pack) {
6571+ for (const auto &PackArg : X.pack_elements ()) {
6572+ StringStream << " , " ;
6573+ PackArg.print (Policy, StringStream, true );
6574+ }
6575+ continue ;
6576+ } else {
6577+ StringStream << " , " ;
6578+ }
6579+
6580+ X.print (Policy, StringStream, true );
6581+ }
6582+ StringStream.flush ();
6583+ if (Buffer.front () != ' <' )
6584+ Buffer = " <" + Buffer + " >" ;
6585+ O << Buffer;
6586+ }
6587+ }
6588+
6589+ private:
6590+ // / Helper method to get arguments of templated function as a string
6591+ // / \param Parameters Array of parameters of the function.
6592+ // / \param Policy Printing policy.
6593+ // / returned string Example:
6594+ // / \code
6595+ // / template <typename T1, typename T2>
6596+ // / void foo(T1 a, T2 b);
6597+ // / \endcode
6598+ // / returns string "T1 a, T2 b"
6599+ std::string
6600+ getTemplatedParamList (const llvm::ArrayRef<clang::ParmVarDecl *> Parameters,
6601+ PrintingPolicy Policy) {
6602+ bool FirstParam = true ;
6603+ llvm::SmallString<128 > ParamList;
6604+ llvm::raw_svector_ostream ParmListOstream{ParamList};
6605+ Policy.SuppressTagKeyword = true ;
6606+ for (ParmVarDecl *Param : Parameters) {
6607+ if (FirstParam)
6608+ FirstParam = false ;
6609+ else
6610+ ParmListOstream << " , " ;
6611+ ParmListOstream << Param->getType ().getAsString (Policy);
6612+ ParmListOstream << " " << Param->getNameAsString ();
6613+ }
6614+ return ParamList.str ().str ();
6615+ }
6616+
6617+ // / Helper method to get text representation of the template parameters.
6618+ // / Throws an error if the last parameter is a pack.
6619+ // / \param TPL The template parameter list.
6620+ // / \param S The SemaSYCL object.
6621+ // / Example:
6622+ // / \code
6623+ // / template <typename T1, class T2>
6624+ // / void foo(T1 a, T2 b);
6625+ // / \endcode
6626+ // / returns string "template <typename T1, class T2> "
6627+ std::string getTemplateParameters (const clang::TemplateParameterList *TPL ,
6628+ SemaSYCL &S) {
6629+ std::string TemplateParams{" template <" };
6630+ bool FirstParam{true };
6631+ for (NamedDecl *Param : *TPL ) {
6632+ if (!FirstParam)
6633+ TemplateParams += " , " ;
6634+ FirstParam = false ;
6635+ if (const auto *TemplateParam = dyn_cast<TemplateTypeParmDecl>(Param)) {
6636+ TemplateParams +=
6637+ TemplateParam->wasDeclaredWithTypename () ? " typename " : " class " ;
6638+ if (TemplateParam->isParameterPack ())
6639+ TemplateParams += " ... " ;
6640+ TemplateParams += TemplateParam->getNameAsString ();
6641+ } else if (const auto *NonTypeParam =
6642+ dyn_cast<NonTypeTemplateParmDecl>(Param)) {
6643+ TemplateParams += NonTypeParam->getType ().getAsString ();
6644+ TemplateParams += " " ;
6645+ TemplateParams += NonTypeParam->getNameAsString ();
6646+ }
6647+ }
6648+ TemplateParams += " > " ;
6649+ return TemplateParams;
65366650 }
65376651};
65386652
@@ -6836,11 +6950,16 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
68366950 ParmList += " , " ;
68376951 ParmListWithNamesOstream << " , " ;
68386952 }
6839- Policy.SuppressTagKeyword = true ;
6840- Param->getType ().print (ParmListWithNamesOstream, Policy);
6841- Policy.SuppressTagKeyword = false ;
6842- ParmListWithNamesOstream << " " << Param->getNameAsString ();
6843- ParmList += Param->getType ().getCanonicalType ().getAsString (Policy);
6953+ if (Param->isParameterPack ()) {
6954+ ParmListWithNamesOstream << " Args... args" ;
6955+ ParmList += " Args ..." ;
6956+ } else {
6957+ Policy.SuppressTagKeyword = true ;
6958+ Param->getType ().print (ParmListWithNamesOstream, Policy);
6959+ Policy.SuppressTagKeyword = false ;
6960+ ParmListWithNamesOstream << " " << Param->getNameAsString ();
6961+ ParmList += Param->getType ().getCanonicalType ().getAsString (Policy);
6962+ }
68446963 }
68456964 ParmListWithNamesOstream.flush ();
68466965 FunctionTemplateDecl *FTD = K.SyclKernel ->getPrimaryTemplate ();
@@ -6876,30 +6995,14 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
68766995 // template arguments that match default template arguments while printing
68776996 // template-ids, even if the source code doesn't reference them.
68786997 Policy.EnforceDefaultTemplateArgs = true ;
6879- FreeFunctionPrinter FFPrinter (O);
6998+ FreeFunctionPrinter FFPrinter (O, Policy );
68806999 if (FTD ) {
6881- FTD ->print (O, Policy);
6882- O << " ;\n " ;
7000+ FFPrinter.printFreeFunctionDeclaration (FTD , S);
68837001 } else {
68847002 FFPrinter.printFreeFunctionDeclaration (K.SyclKernel , ParmListWithNames);
68857003 }
68867004
68877005 FFPrinter.printFreeFunctionShim (K.SyclKernel , ShimCounter, ParmList);
6888- if (FTD ) {
6889- const TemplateArgumentList *TAL =
6890- K.SyclKernel ->getTemplateSpecializationArgs ();
6891- ArrayRef<TemplateArgument> A = TAL ->asArray ();
6892- bool FirstParam = true ;
6893- O << " <" ;
6894- for (const auto &X : A) {
6895- if (FirstParam)
6896- FirstParam = false ;
6897- else
6898- O << " , " ;
6899- X.print (Policy, O, true );
6900- }
6901- O << " >" ;
6902- }
69037006 O << " ;\n " ;
69047007 O << " }\n " ;
69057008 Policy.SuppressDefaultTemplateArgs = true ;
0 commit comments