diff --git a/.github/workflows/test-all.yaml b/.github/workflows/test-all.yaml index edd5c2ac..596b3554 100644 --- a/.github/workflows/test-all.yaml +++ b/.github/workflows/test-all.yaml @@ -42,7 +42,7 @@ jobs: - name: Package Dependencies run: R -q -e 'remotes::install_deps("nCompiler", dependencies=TRUE)' - name: Install inline - run: R -q -e 'remotes::install_cran("inline")' + run: R -q -e 'remotes::install_cran(c("inline", "nimble"))' - name: Build Package run: | R CMD build nCompiler @@ -53,7 +53,6 @@ jobs: testthat::test_dir("nCompiler/tests/testthat/uncompiled_tests", reporter = "summary") testthat::test_dir("nCompiler/tests/testthat/nCompile_tests", reporter = "summary") testthat::test_dir("nCompiler/tests/testthat/cpp_tests", reporter = "summary") - testthat::test_dir("nCompiler/tests/testthat/specificOp_tests", reporter = "summary") shell: Rscript {0} test-nCompile-features: @@ -77,6 +76,7 @@ jobs: run: | library(nCompiler) testthat::test_dir("nCompiler/tests/testthat/predefined_tests", reporter = "summary") + testthat::test_dir("nCompiler/tests/testthat/specificOp_tests", reporter = "summary") shell: Rscript {0} test-nClass: diff --git a/nCompiler/R/NC.R b/nCompiler/R/NC.R index ccf0c75a..bfcf5670 100644 --- a/nCompiler/R/NC.R +++ b/nCompiler/R/NC.R @@ -117,7 +117,8 @@ nClass <- function(classname, list(exportName = NULL, interface = "full", interfaceMembers = NULL, depends = list(), - inherit = list()), + inherit = list(), + nClass_inherit = list()), compileInfo ) if(missing(classname)) @@ -163,8 +164,6 @@ nClass <- function(classname, # so if provided in the nClass call, we stick it in new_env. # (That is not the only reason for new_env.) # Also note that the inherit arg is for nClass inheritance. The compileInfo$inherit element is for hard-coded C++ inheritance statements. - inheritQ <- substitute(inherit) - inherit_provided <- !is.null(inheritQ) #if(!is.null(inherit)) new_env$.inherit_obj <- inherit new_env$.NCinternals <- internals # Uncompiled behavior for Cpublic fields needs to be handled. diff --git a/nCompiler/R/NC_CompilerClass.R b/nCompiler/R/NC_CompilerClass.R index 88a1f7ed..b491f6c0 100644 --- a/nCompiler/R/NC_CompilerClass.R +++ b/nCompiler/R/NC_CompilerClass.R @@ -51,10 +51,15 @@ NC_CompilerClass <- R6::R6Class( methodNames <- myNCinternals$methodNames for(m in methodNames) { thisMethod <- NCgenerator$public_methods[[m]] + thisName <- NULL if(isConstructor(thisMethod)) { + #NFinternals(thisMethod)$cpp_code_name <- self$name NFinternals(thisMethod)$cpp_code_name <- self$name + } else { + thisName <- myNCinternals$all_methodName_to_cpp_code_name[[m]] } - NFcompilers[[m]] <<- NF_CompilerClass$new(f = thisMethod) + NFcompilers[[m]] <<- NF_CompilerClass$new(f = thisMethod, + name = thisName) } }, setupMethodSymbolTables = function() { diff --git a/nCompiler/R/NC_InternalsClass.R b/nCompiler/R/NC_InternalsClass.R index 9820faa0..885d22c8 100644 --- a/nCompiler/R/NC_InternalsClass.R +++ b/nCompiler/R/NC_InternalsClass.R @@ -13,6 +13,8 @@ NC_InternalsClass <- R6::R6Class( allFieldNames_self = character(), # not including inherited methods classname = character(), cpp_classname = character(), + all_methodName_to_cpp_code_name = list(), + orig_methodName_to_cpp_code_name = list(), compileInfo = list(), inherit_base_provided = FALSE, # compileInfo will include interface ("full", "generic", or "none"), @@ -27,6 +29,9 @@ NC_InternalsClass <- R6::R6Class( env = NULL, inheritQ = NULL, process_inherit_done = FALSE, + virtualMethodNames_self = character(), # will be used when checking inherited method validity, only for locally implemented methods + virtualMethodNames = character(), + check_inherit_done = FALSE, initialize = function(classname, Cpublic, isOnlyC = FALSE, @@ -45,10 +50,12 @@ NC_InternalsClass <- R6::R6Class( numEntries <- length(Cpublic) if(numEntries) { isMethod <- rep(FALSE, numEntries) + isVirtual <- rep(FALSE, numEntries) for(i in seq_along(Cpublic)) { if(isNF(Cpublic[[i]])) { isMethod[i] <- TRUE - NFinternals(Cpublic[[i]])$isMethod <- TRUE + isVirtual[i] <- isTRUE(NFinternals(Cpublic[[i]])$compileInfo$virtual) + # NFinternals(Cpublic[[i]])$isMethod <- TRUE next; } if(is.function(Cpublic[[i]])) { @@ -57,17 +64,25 @@ NC_InternalsClass <- R6::R6Class( call. = FALSE) } } + self$virtualMethodNames <- names(Cpublic)[isVirtual] self$symbolTable <- argTypeList2symbolTable(Cpublic[!isMethod], evalEnv = env) self$cppSymbolNames <- Rname2CppName(symbolTable$getSymbolNames()) self$methodNames <- names(Cpublic)[isMethod] self$allMethodNames_self <- methodNames + self$virtualMethodNames_self <- names(Cpublic)[isVirtual] self$allMethodNames <- methodNames self$fieldNames <- names(Cpublic)[!isMethod] self$allFieldNames_self <- fieldNames self$allFieldNames <- fieldNames - if(!is.null(self$compileInfo$inherit$base)) - self$inherit_base_provided <- TRUE + self$orig_methodName_to_cpp_code_name <- structure(vector("list", length=length(methodNames)), + names = methodNames) + for(mN in methodNames) { + self$orig_methodName_to_cpp_code_name[[mN]] <- NFinternals(Cpublic[[mN]])$cpp_code_name + } } + # An over-riding base class can be provided either through inherit or nClass_inherit. + if(!is.null(self$compileInfo$inherit$base) || !is.null(self$compileInfo$nClass_inherit$base)) + self$inherit_base_provided <- TRUE if(!is.null(enableDerivs)) { if(!is.list(enableDerivs)) enableDerivs <- as.list(enableDerivs) @@ -84,29 +99,38 @@ NC_InternalsClass <- R6::R6Class( # These are steps that need to be done after all classes are defined # and do not require recursion up the inheritance tree. if(!is.null(self$inheritQ)) { - inherit_obj <- eval(self$inheritQ, envir = self$env) + inherit_obj <- eval(self$inheritQ, envir = self$env) #inheritQ can be an expression but it must always return the same generator object if(!isNCgenerator(inherit_obj)) stop("An inherit argument that was provided to nClass is not nClass generator.") self$inheritNCinternals <- NCinternals(inherit_obj) message("add check that base class has interface 'none'") - if(!self$inherit_base_provided) - self$compileInfo$inherit$base <- paste("public", - self$inheritNCinternals$cpp_classname) - process_inherit_done <- FALSE - } else { - process_inherit_done <- TRUE + if(!self$inherit_base_provided) { + self$compileInfo$nClass_inherit$base <- self$inheritNCinternals$cpp_classname # don't paste "public" because it will go in interface_resolver< + } } + self$process_inherit_done <- FALSE + self$check_inherit_done <- FALSE }, process_inherit = function() { # These are steps that need to be done after connect_inherit # and require recursion up the inheritance tree, using flags. + # TO-DO: Error trap in methods of same name but different argument signatures. if(self$process_inherit_done) return() if(!is.null(self$inheritQ)) { self$inheritNCinternals$process_inherit() self$symbolTable$setParentST(self$inheritNCinternals$symbolTable) - self$allMethodNames <- c(self$allMethodNames_self, self$inheritNCinternals$allMethodNames) + newMethodNames <- setdiff(self$allMethodNames_self, + self$inheritNCinternals$allMethodNames) + self$allMethodNames <- c(newMethodNames, self$inheritNCinternals$allMethodNames) + self$all_methodName_to_cpp_code_name <- c(self$orig_methodName_to_cpp_code_name[newMethodNames], + self$inheritNCinternals$all_methodName_to_cpp_code_name) self$allFieldNames <- c(self$allFieldNames_self, self$inheritNCinternals$allFieldNames) - } + } else { + self$allMethodNames <- self$allMethodNames_self + self$all_methodName_to_cpp_code_name <- self$orig_methodName_to_cpp_code_name + self$allFieldNames <- self$allFieldNames_self + self$symbolTable$setParentST(NULL) + } self$process_inherit_done <- TRUE } ) diff --git a/nCompiler/R/NC_Utils.R b/nCompiler/R/NC_Utils.R index 2431786f..cd45d975 100644 --- a/nCompiler/R/NC_Utils.R +++ b/nCompiler/R/NC_Utils.R @@ -103,4 +103,94 @@ NC_find_method <- function(NCgenerator, name, inherits=TRUE) { } } method +} + +# This function will be called from nCompile after going through the +# NCinternals for all units and calling connect_inherit and then process_inherit +# (with all connect_inherits called before all process_inherits) +# At that point we are ready to check for disallowed method overloading +# (we don't allow the same method name in different levels of the hierarchy unless it is virtual +# and all signatures match, i.e. we don't allow C-style overloading because it wouldn't work in +# uncompiled (R) execution. This can be changed by an option, indicating one wants only the +# compiled behavior and doesn't care about uncompiled inconsistency.) +# and disallowed duplicate member variable names (for a similar reason: In C++ +# different levels of a hierarchy could each have their own "x", but that is not +# the case in an R6 class hierarchy, so we disallow it unless a user allows it by option). +# +# The previous calls will have initialized NCint$check_inherit_done to FALSE +NC_check_inheritance <- function(NCgenerator) { + allow_method_overloading <- isTRUE(get_nOption('allow_method_overloading')) + allow_inherited_field_duplicates <- isTRUE(get_nOption('allow_inherited_field_duplicates')) + if(allow_method_overloading && allow_inherited_field_duplicates) return(invisible(NULL)) + + if(!isNCgenerator(NCgenerator)) + stop("Input to NC_check_inheritance must be a nClass generator.") + NCint <- NCinternals(NCgenerator) + + if(is.null(NCint$inheritQ)) { + NCint$check_inherit_done <- TRUE + NCint$virtualMethodNames <- NCint$virtualMethodNames_self + return(NCint$virtualMethodNames_self) + } + if(NCint$check_inherit_done) return(NCint$virtualMethodNames) + # At this point, we have inheritance and have checked this NCgenerator yet. + inheritNCinternals <- NCint$inheritNCinternals + inheritNCgenerator <- eval(NCint$inheritQ, envir = NCint$env) + # Recurse up the inheritance ladder + # A design dilemma here was that the virtual marker is in the NFinternals, + # which can be accessed from the NCgenerator but not the NCinternals. + # That is why this function is not a method of NCinternals. + inherit_virtualMethodNames <- NC_check_inheritance(inheritNCgenerator) + new_virtualMethodNames <- character() + + if(!allow_method_overloading) { + local_virtualMethodNames <- NCint$virtualMethodNames_self + # default: check for disallowed method overloading + allMethodNames <- NCint$allMethodNames + for(mN in allMethodNames) { + # if a method is not in the self method names, it was inherited, so there is nothing to check + if(!(mN %in% NCint$allMethodNames_self)) next + if(!(mN %in% inheritNCinternals$allMethodNames)) { + # current level is the first one with this method name, so here we tag its virtual status + new_virtualMethodNames <- c(new_virtualMethodNames, mN) + next + } + # At this point the current level has the method and it is inherited + localMethod <- NCgenerator$public_methods[[mN]] + inheritMethod <- NC_find_method(inheritNCgenerator, mN) + if(is.null(inheritMethod)) + stop("Problem finding inherited method ", mN, " in NC_check_inheritance.", call. = FALSE) + if(!NF_types_match(localMethod, inheritMethod)) + stop(paste0("Method ", mN, " does not have the same arguments names,", + " and/or argument types, and/or returnType as a base class method of the same name.", + " Methods of the same name in an nClass hierarchy must have all of these the same", + " and the top-level one must be marked with compileInfo(virtual=TRUE).", + " (If you want to allow method overloading in C++ by turning off these requirements,", + " set nOptions(allow_method_overloading=TRUE)"), + call. = FALSE) + if(!(mN %in% inherit_virtualMethodNames)) + stop(paste0("Method ", mN, " is inherited, so", + " it must be marked with compileInfo(virtual=TRUE) in the top-level nClass that includes it.", + " That does not appear to be the case.", + " (If you want to allow method over-loading in C++ by turning off this requirement,", + " set nOptions(allow_method_overloading=TRUE)"), + call. = FALSE) + } + } + if(!allow_inherited_field_duplicates) { + # This would be slightly more efficient to do in NC_InternalsClass::process_inherit + # but we keep it here so all the checking is together here. + # + # If any of my own field names already existed from my inherited classes, + # that's not allowed + badFields <- NCint$allFieldNames_self %in% inheritNCinternals$allFieldNames + if(any(badFields)) + stop(paste0("Problem with field(s): ", paste(NCint$allFieldNames_self[badFields], collapse = ", "), + ". Fields with the same name are not allowed in base and inherited classes.", + " (If you want to allow local fields of the same name in C++ by turning off this requirement,", + " set nOptions(allow_inherited_field_duplicates=TRUE)"), + call. = FALSE ) + } + NCint$check_inherit_done <- TRUE + c(new_virtualMethodNames, inherit_virtualMethodNames) } \ No newline at end of file diff --git a/nCompiler/R/NF_CompilerClass.R b/nCompiler/R/NF_CompilerClass.R index fc8b3818..9cd89dc7 100644 --- a/nCompiler/R/NF_CompilerClass.R +++ b/nCompiler/R/NF_CompilerClass.R @@ -27,8 +27,7 @@ NF_CompilerClass <- R6::R6Class( derivsContent = list(), initialTypeInferenceDone = FALSE, initialize = function(f = NULL, - ## funName, - # const = FALSE, + name = NULL, # Allow an nClass to set the name of its method. useUniqueNameInCpp = FALSE, compileInfo = NULL) { self$auxEnv <- new.env() @@ -45,9 +44,14 @@ NF_CompilerClass <- R6::R6Class( } else { self$NFinternals <- NFinternals(f) } - self$origName <- NFinternals$uniqueName - if (useUniqueNameInCpp) self$name <- NFinternals$uniqueName - else self$name <- NFinternals$cpp_code_name + self$origName <- NFinternals$uniqueName2 + if(!is.null(name)) { + self$name <- name + } else { + if (useUniqueNameInCpp) self$name <- NFinternals$uniqueName2 + # NB If this is a method of a nClass, its cpp_code_name may be intercepted later but will not be changed here. + else self$name <- NFinternals$cpp_code_name + } self$origRcode <- NFinternals$code self$newRcode <- NFinternals$code self$isAD <- NFinternals$isAD diff --git a/nCompiler/R/NF_InternalsClass.R b/nCompiler/R/NF_InternalsClass.R index 5871f4d1..b6f8ef5a 100644 --- a/nCompiler/R/NF_InternalsClass.R +++ b/nCompiler/R/NF_InternalsClass.R @@ -12,8 +12,10 @@ NF_InternalsClass <- R6::R6Class( returnSym = NULL, control = list(), where = NULL, - isMethod = FALSE, + #isMethod = FALSE, uniqueName = character(), + uniqueName2 = character(), + #cpp_code_name = character(), cpp_code_name = character(), ## template = NULL, replaced with default_matchDef default_matchDef = NULL, @@ -46,9 +48,17 @@ NF_InternalsClass <- R6::R6Class( ## setupVarNames = NULL, ## Ditto where = parent.frame() ) { + ## name is required and is generated by NF() if not provided. ## uniqueName is only needed if this is not a method of a nClass. - if(!missing(name)) - self$uniqueName <- name + if(!missing(name)) { + self$uniqueName <- name + self$uniqueName2 <- paste(name, + nFunctionIDMaker(), + sep = "_") + } else { + stop("NF_InternalsClass needs a name argument.", call. = FALSE) + } + ## uniqueName2 is needed even for methods, to serve as unique keys. if(is.null(compileInfo$C_fun)) { fun_to_use <- fun } else { @@ -85,8 +95,6 @@ NF_InternalsClass <- R6::R6Class( ## e.g. 'print' to 'nPrint'; see 'nKeyWords' list in ## changeKeywords.R self$code <- body(fun_to_use) - if(isTRUE(control$changeKeywords)) - self$code <- nf_changeKeywords(self$code) if(code[[1]] != '{') self$code <- substitute({CODE}, list(CODE=code)) ## check all code except.nCompiler package nFunctions @@ -115,17 +123,33 @@ NF_InternalsClass <- R6::R6Class( self$returnSym <- argType2symbol(returnTypeDecl, origName = "returnType", evalEnv = where) + + # It is important to do this after getting the returnType info + # because this will change integer to nInteger, even in returnType + if(isTRUE(control$changeKeywords)) + self$code <- nf_changeKeywords(self$code) + ## We set the cpp_code_name here so that other nFunctions ## that call this one can determine, during compilation, - ## what this one's cpp function name will be: - if(!is.null(compileInfo$cpp_code_name)) + ## what this one's cpp function name will be. + ## However, if this nFunction is used as a method of a nClass, + ## and if that nClass inherits from another nClass and this + ## nFunction is actually virtual, then the base class's + ## cpp_code_name will be used instead. + ## However, we do not modify it in this NFinternals because + ## it is permitted to use this elsewhere, e.g. on its own + ## or to provide a method to a different nClass. + if(!is.null(compileInfo$cpp_code_name)) { + #self$cpp_code_name <- compileInfo$cpp_code_name self$cpp_code_name <- compileInfo$cpp_code_name - else { + } else { + #self$cpp_code_name <- Rname2CppName(name) self$cpp_code_name <- Rname2CppName(name) - if(isFALSE(predefined)) - self$cpp_code_name <- paste(self$cpp_code_name, - nFunctionIDMaker(), - sep = "_") + # do not uniquify cpp_code_name + # if(isFALSE(predefined)) + # self$cpp_code_name <- paste(self$cpp_code_name, + # nFunctionIDMaker(), + # sep = "_") } ## Unpack enableDerivs into AD self$isAD <- FALSE diff --git a/nCompiler/R/NF_Utils.R b/nCompiler/R/NF_Utils.R index e01b2e32..dc50dca8 100644 --- a/nCompiler/R/NF_Utils.R +++ b/nCompiler/R/NF_Utils.R @@ -56,3 +56,27 @@ nGet <- function(name, where) { # else # NULL } + +NF_types_match <- function(f1, f2) { + if(!isNF(f1) || !isNF(f2)) stop("Arguments to NF_types_match must be nFunctions") + match <- TRUE + NFint1 <- NFinternals(f1) + NFint2 <- NFinternals(f2) + + if(match) { + match <- isTRUE(all.equal(NFint1$returnSym, NFint2$returnSym)) + } + if(match) { + NF1args <- NFint1$argSymTab$getSymbolNames() + NF2args <- NFint2$argSymTab$getSymbolNames() + match <- isTRUE(all.equal(NF1args, NF2args)) + } + if(match) { + for(symName in NF1args) { + if(match) + match <- isTRUE(all.equal(NFint1$argSymTab$getSymbol(symName), + NFint2$argSymTab$getSymbol(symName))) + } + } + match +} diff --git a/nCompiler/R/compile_labelAbstractTypes.R b/nCompiler/R/compile_labelAbstractTypes.R index d8cf8eea..dd7e7ddb 100644 --- a/nCompiler/R/compile_labelAbstractTypes.R +++ b/nCompiler/R/compile_labelAbstractTypes.R @@ -297,7 +297,9 @@ inLabelAbstractTypesEnv( code$name <- '->member' code$args[[2]]$aux$obj_internals <- obj_internals code$args[[2]]$aux$nFunctionName <- innerName - code$args[[2]]$name <- NFinternals(method)$cpp_code_name + #code$args[[2]]$name <- NFinternals(method)$cpp_code_name + code$args[[2]]$name <- NCinternals(code$args[[1]]$type$NCgenerator)$all_methodName_to_cpp_code_name[[innerName]] + obj_internals <- NULL } else { ## Is RHS a field? symbol <- NCinternals(code$args[[1]]$type$NCgenerator)$symbolTable$getSymbol(innerName, inherits=TRUE) @@ -380,7 +382,8 @@ inLabelAbstractTypesEnv( wrapExprClassOperator(code = code, funName = 'nFunctionRef') # name substitution - cpp_code_name <- tgt$cpp_code_name + # cpp_code_name <- tgt$cpp_code_name + cpp_code_name <- NCinternals(obj)$all_methodName_to_cpp_code_name[[code$name]] code$name <- cpp_code_name # class in which function is defined diff --git a/nCompiler/R/compile_normalizeCalls.R b/nCompiler/R/compile_normalizeCalls.R index 020b7f43..ddb99426 100644 --- a/nCompiler/R/compile_normalizeCalls.R +++ b/nCompiler/R/compile_normalizeCalls.R @@ -52,7 +52,7 @@ compile_normalizeCalls <- function(code, # We defer: uniqueName, cpp_code_name cachedOpInfo <- update_cachedOpInfo(code, auxEnv$where) if(cachedOpInfo$case == "nFunction") { - uniqueName <- cachedOpInfo$obj_internals$uniqueName + uniqueName <- cachedOpInfo$obj_internals$uniqueName2 if(length(uniqueName)==0) stop( exprClassProcessingErrorMsg( diff --git a/nCompiler/R/cppDefs_TBB.R b/nCompiler/R/cppDefs_TBB.R index 65929220..7ec9d27c 100644 --- a/nCompiler/R/cppDefs_TBB.R +++ b/nCompiler/R/cppDefs_TBB.R @@ -28,7 +28,7 @@ cppParallelBodyClass <- R6::R6Class( else list() - output <- c(generateClassHeader(name, inheritance), + output <- c(generateClassHeader(name, inheritance, nClass_inheritance), list('public:'), ## In the future we can separate public and private lapply(generateObjectDefs(symbolsToUse), function(x) @@ -176,7 +176,7 @@ cppParallelReduceBodyClass <- R6::R6Class( else list() - output <- c(generateClassHeader(name, inheritance), + output <- c(generateClassHeader(name, inheritance, nClass_inheritance), list('public:'), ## In the future we can separate public and private lapply(generateObjectDefs(symbolsToUse), function(x) diff --git a/nCompiler/R/cppDefs_core.R b/nCompiler/R/cppDefs_core.R index 0c508328..44197c07 100644 --- a/nCompiler/R/cppDefs_core.R +++ b/nCompiler/R/cppDefs_core.R @@ -311,14 +311,14 @@ add_obj_hooks_impl <- function(self) { addGenericInterface_impl <- function(self) { name <- self$name - self$addInheritance(paste0("public genericInterfaceC<", + self$add_nClass_inheritance(paste0("genericInterfaceC<", name, - ">")) + ">"), first=TRUE) # It is ok to have multiple virtual inheritance from genericInterfaceBaseC, # but we clean it up here for slightly simpler code. - if("virtual public genericInterfaceBaseC" %in% self$inheritance) { - self$inheritance <- self$inheritance[-which(self$inheritance == "virtual public genericInterfaceBaseC")] - } + # if("virtual public genericInterfaceBaseC" %in% self$inheritance) { + # self$inheritance <- self$inheritance[-which(self$inheritance == "virtual public genericInterfaceBaseC")] + # } # self$Hincludes <- c(self$Hincludes, # nCompilerIncludeFile("nCompiler_class_interface.h")) self$Hpreamble <- c(self$Hpreamble, @@ -338,6 +338,7 @@ addGenericInterface_impl <- function(self) { cpp_fieldNames <- character() done <- FALSE current_NCgen <- self$Compiler$NCgenerator + my_NCgen <- current_NCgen while(!done) { NCint <- NCinternals(current_NCgen) NCcompInfo <- NCint$compileInfo @@ -345,7 +346,7 @@ addGenericInterface_impl <- function(self) { useIM <- !is.null(interfaceMembers) methodNames <- NCint$methodNames for(mName in methodNames) { - if(mName %in% names(cppArgInfos)) next + if(mName %in% outputMethodNames) next if(useIM && !(mName %in% interfaceMembers)) next NFint <- NFinternals(current_NCgen$public_methods[[mName]]) NFcompInfo <- NFint$compileInfo @@ -369,7 +370,9 @@ addGenericInterface_impl <- function(self) { step4 <- paste0('args({', step3, '})') cppArgInfos[iOut] <- step4 outputMethodNames[iOut] <- mName - outputCppMethodNames[iOut] <- NFint$cpp_code_name + # This line should give the same result as the next line. + # outputCppMethodNames[iOut] <- NFint$cpp_code_name + outputCppMethodNames[iOut] <- NCint$all_methodName_to_cpp_code_name[[mName]] outputMethodClassNames[iOut] <- NCint$cpp_classname iOut <- iOut + 1 } @@ -508,7 +511,13 @@ cppClassClass <- R6::R6Class( name = character(), symbolTable = NULL, ## list or symbolTable memberCppDefs = list(), # formerly cppFunctionDefs - inheritance = list(), ## classes to be declared as public + inheritance = list(), ## direct inheritance code, e.g. "public baseClass" + nClass_inheritance = list(), ## classes to be inherited via interface_resolver<>, which resolves + ## final implementations of "diamond" issues by setting the first as the implementer, + ## also adds the "virtual public genericInterfaceBaseC" inheritance if needed. + ## It is harmless to include an arbitrary class in here, if it is not first + ## and should be inherited as "public". + ## Entries here should be "genericInterfaceBaseC" or just "B", but omit the "public" etc. ## ancestors = 'list', ## classes inherited by inherited classes, needed to make all cast pointers ##extPtrTypes = 'ANY', ##private = 'list', # 'list'. This field is a placeholder for future functionality. Currently everything is generated as public @@ -582,6 +591,12 @@ cppClassClass <- R6::R6Class( addInheritance = function(newI) { inheritance <<- c(inheritance, newI) }, + add_nClass_inheritance = function(newI, first = FALSE) { + if(first) + nClass_inheritance <<- c(newI, nClass_inheritance) + else + nClass_inheritance <<- c(nClass_inheritance, newI) + }, ##addAncestors = function(newI) ancestors <<- c(ancestors, newI), ##setPrivate = function(name) private[[name]] <<- TRUE, generate = function(declaration = FALSE, ...) { @@ -591,7 +606,7 @@ cppClassClass <- R6::R6Class( else { list() } - output <- c(generateClassHeader(name, inheritance), + output <- c(generateClassHeader(name, inheritance, nClass_inheritance), list('public:'), ## In the future we can separate public and private generateAll(memberCppDefs, declaration = TRUE), # it is important to declare methods before variables @@ -632,9 +647,10 @@ cppClassClass <- R6::R6Class( } else { # Ensure inheritance from genericInterfaceBaseC so our custom Exporter in C++ # can always dynamic_pointer_cast to shared_ptr. - if(!("virtual public genericInterfaceBaseC" %in% self$inheritance)) { - self$addInheritance("virtual public genericInterfaceBaseC") - } + # if(!("virtual public genericInterfaceBaseC" %in% self$inheritance)) { + # if(length(self$genericInterfaceInheritance) > 0) + # self$addInheritance("virtual public genericInterfaceBaseC") + # } # These will always end up included and possibly multiple times, # so it's a bit sloppy but not worth cleaning up for now. self$Hpreamble <- c(self$Hpreamble, @@ -643,7 +659,6 @@ cppClassClass <- R6::R6Class( self$CPPpreamble <- c(self$CPPpreamble, "#define NCOMPILER_USES_NCLASS_INTERFACE", "#define USES_NCOMPILER") - } # The only case that would omit interface calls is generated predefined code. if(interfaceCalls) @@ -1059,7 +1074,14 @@ generateFunctionHeader <- function(self, ## ) ## } -generateClassHeader <- function(ns, inheritance) { +generateClassHeader <- function(ns, inheritance, nClass_inheritance=character()) { + # We do want an empty public interface_resolver<> if there is no nClass_inheritance. + # It will ensure virtual public genericInterfaceBaseC inheritance. + resolver_inheritance <- paste('public interface_resolver<', + paste(nClass_inheritance, + collapse = ', '), + '>') + inheritance <- c(resolver_inheritance, inheritance) inheritancePart <- if(length(inheritance) > 0) { paste(':', diff --git a/nCompiler/R/cppDefs_nClass.R b/nCompiler/R/cppDefs_nClass.R index 1e967cee..84929f01 100644 --- a/nCompiler/R/cppDefs_nClass.R +++ b/nCompiler/R/cppDefs_nClass.R @@ -204,6 +204,10 @@ cpp_nClassClass <- R6::R6Class( for(oneInheritance in Compiler$compileInfo$inherit) { self$addInheritance(oneInheritance) } + # This may not get used much or at all but here it is if needed. + for(oneInheritance in Compiler$compileInfo$nClass_inherit) { + self$add_nClass_inheritance(oneInheritance) + } inheritNCinternals <- NCinternals(self$Compiler$NCgenerator)$inheritNCinternals if(!is.null(inheritNCinternals)) { include_filebase <- make_cpp_filebase(inheritNCinternals$cpp_classname) diff --git a/nCompiler/R/nCompile.R b/nCompiler/R/nCompile.R index df8ae729..be071fa7 100644 --- a/nCompiler/R/nCompile.R +++ b/nCompiler/R/nCompile.R @@ -13,6 +13,8 @@ cppFileLabelFunction <- labelFunctionCreator('nCompiler_units') # - The SEXPgenerator C++ function in the cppDef is named paste0("new_", name) # - This then is used to create an RcppPacket. # - In nCompile, the cpp_name for that unitResult is the class name for the nClass generator +# - Note that cpp_code_name of methods may be intercepted later to match base class +# cpp_code_name if the method is virtual and inherited. # # compileCpp_nCompiler calls sourceCpp_nCompiler, which calls Rcpp::sourceCpp # compileCpp_nCompiler arranges the results into a named list of the [[Rcpp::export]] functions @@ -317,7 +319,10 @@ nCompile <- function(..., if(unitTypes[i] == "nCgen") NCinternals(units[[i]])$process_inherit() } - + for(i in seq_along(units)) { + if(unitTypes[i] == "nCgen") + NC_check_inheritance(units[[i]]) + } # set up exportNames and returnNames # exportNames will be from names(units) if named in the call or there is no exportName in the NF or NC compileInfo # Otherwise (i.e. no name provided in call and there is an exportName in the object def), use the exportName in the object def (compileInfo) diff --git a/nCompiler/R/nimbleModels.R b/nCompiler/R/nimbleModels.R index 8c7ee38a..f207d62f 100644 --- a/nCompiler/R/nimbleModels.R +++ b/nCompiler/R/nimbleModels.R @@ -3,6 +3,17 @@ # # see test-nimbleModel too. +## modelBase_nClass will be a base class with methods that +## have separate Rfun and Cfun contents and are predefined. +## +## model_nClass will inherit from modelBase_nClass and in C++ will +## use CRTP for a derived model. +## It will also split Rfun and Cfun and provide a custom inheritance statement +## It may provide different sets of calculate modes. +## It will also be predefined (will it get an interface?) + +## a model will inherit from model_nClass + modelBase_nClass <- nClass( classname = "modelBase_nClass", Cpublic = list( @@ -17,7 +28,7 @@ modelBase_nClass <- nClass( compileInfo = list(virtual=TRUE) ) ), - compileInfo=list(interface="none", + compileInfo=list(interface="full", createFromR = FALSE) ) @@ -48,7 +59,7 @@ makeModel_nClass <- function(varInfo) { function() {base_hw()} ), call_setup_node_mgmt = nFunction( - name = "setup_node_mgmt", + name = "call_setup_node_mgmt", function() {setup_node_mgmt()} ), set_from_list = nFunction( diff --git a/nCompiler/R/options.R b/nCompiler/R/options.R index 46e7a38a..c5cbb15d 100644 --- a/nCompiler/R/options.R +++ b/nCompiler/R/options.R @@ -15,6 +15,8 @@ updateDefaults <- function(defaults, control) { debugSizeProcessing = FALSE, serialize = FALSE, # if TRUE, include serialization code in generated C++ enableDerivs = FALSE, + allow_method_overloading = FALSE, + allow_inherited_field_duplicates = FALSE, compilerOptions = list( use_nCompiler_error_handling = TRUE, rebuild = FALSE, diff --git a/nCompiler/inst/include/nCompiler/nClass_interface/generic_class_interface.h b/nCompiler/inst/include/nCompiler/nClass_interface/generic_class_interface.h index f8fcde5e..ef91ce94 100644 --- a/nCompiler/inst/include/nCompiler/nClass_interface/generic_class_interface.h +++ b/nCompiler/inst/include/nCompiler/nClass_interface/generic_class_interface.h @@ -107,10 +107,78 @@ class genericInterfaceBaseC { }; }; +// FirstDerived and interface_resolver<> designed with help from Google Gemini +// Helper template to find the first type that inherits from Base +template +struct FirstGenericDerived { + using type = std::conditional_t< + std::is_base_of_v, + T, + typename FirstGenericDerived::type + >; +}; + +// Base case for the recursive helper template +template +struct FirstGenericDerived { + using type = std::conditional_t< + std::is_base_of_v, + T, + genericInterfaceBaseC + >; +}; + +template +class interface_resolver : public Bases..., virtual public genericInterfaceBaseC +{ +private: + using FirstFound = typename FirstGenericDerived::type; + +public: + const name2access_type& get_name2access() const override { + return FirstFound::get_name2access(); + } + SEXP get_value(const std::string &name) const override { + return FirstFound::get_value(name); + } + void set_value(const std::string &name, SEXP Svalue) override { + FirstFound::set_value(name, Svalue); + } + SEXP call_method(const std::string &name, SEXP Sargs) override { + return FirstFound::call_method(name, Sargs); + } + SEXP make_deserialized_return_SEXP() override { + return FirstFound::make_deserialized_return_SEXP(); + } +}; + +template<> +class interface_resolver<> : virtual public genericInterfaceBaseC +{ +private: + using FirstFound = genericInterfaceBaseC; + +public: + const name2access_type& get_name2access() const override { + return FirstFound::get_name2access(); + } + SEXP get_value(const std::string &name) const override { + return FirstFound::get_value(name); + } + void set_value(const std::string &name, SEXP Svalue) override { + FirstFound::set_value(name, Svalue); + } + SEXP call_method(const std::string &name, SEXP Sargs) override { + return FirstFound::call_method(name, Sargs); + } + SEXP make_deserialized_return_SEXP() override { + return FirstFound::make_deserialized_return_SEXP(); + } +}; + // A forward declaration. (This is being disabled and a new approach is being used.) //SEXP process_call_args(const genericInterfaceBaseC::args::argVectorT &argVector, // SEXP Sargs); - // Base class for accessing a single member from a nimble class, // converted to SEXP. // diff --git a/nCompiler/inst/include/nCompiler/nClass_interface_Rcpp_extensions/post_Rcpp/shared_ptr_as_wrap.h b/nCompiler/inst/include/nCompiler/nClass_interface_Rcpp_extensions/post_Rcpp/shared_ptr_as_wrap.h index bc2b07d0..75f4cc2e 100644 --- a/nCompiler/inst/include/nCompiler/nClass_interface_Rcpp_extensions/post_Rcpp/shared_ptr_as_wrap.h +++ b/nCompiler/inst/include/nCompiler/nClass_interface_Rcpp_extensions/post_Rcpp/shared_ptr_as_wrap.h @@ -35,7 +35,7 @@ namespace Rcpp { public: static constexpr bool T_is_polymorphic = std::is_polymorphic_v; - std::shared_ptr sp_; + std::shared_ptr sp_, spnew_; Exporter(SEXP Sx) { Rcpp::Environment Sx_env(Sx); // Sx is an environment, so initialize an Rcpp:Environment from it. SEXP Xptr = PROTECT(Sx_env["extptr"]); // Get the extptr element of it. @@ -55,7 +55,12 @@ namespace Rcpp { if(!ok) {stop("An argument that should be an nClass object is not valid.");} std::shared_ptr spbase = static_cast(R_ExternalPtrAddr(Xptr))->get_interfaceBase_shared_ptr(); if constexpr (T_is_polymorphic) { - sp_ = std::dynamic_pointer_cast(spbase); + spnew_ = std::dynamic_pointer_cast(spbase); + if(!spnew_) { + UNPROTECT(1); + stop("Invalid nClass assignment: check that the assigned object is of the expected class (or derived from it)."); + } + sp_ = spnew_; } else { sp_ = std::static_pointer_cast(spbase); } @@ -93,7 +98,7 @@ template struct wrap_shared_ptr_to_R< T, typename std::enable_if, T >::value>::type > { static SEXP go(std::shared_ptr< T > obj) { - SEXP Sans = PROTECT(T::setup_R_return_object_full( PROTECT(return_nCompiler_object< T >(obj) ) ) ); + SEXP Sans = PROTECT(loadedObjectHookC::setup_R_return_object_full( PROTECT(return_nCompiler_object< T >(obj) ) ) ); UNPROTECT(2); return Sans; } diff --git a/nCompiler/tests/testthat/nClass_tests/test-nClass_inherit.R b/nCompiler/tests/testthat/nClass_tests/test-nClass_inherit.R index 82826cc0..d8c83782 100644 --- a/nCompiler/tests/testthat/nClass_tests/test-nClass_inherit.R +++ b/nCompiler/tests/testthat/nClass_tests/test-nClass_inherit.R @@ -6,61 +6,189 @@ message("See comments in test-nClass_inherit.R for more notes.") ## See also test-nClass_nested -# With inheritcance, we DO NOT support interfacing to both base class and derived class. -# Only the most derived class should have interface = "generic" or "base". -# Any class to be used as a base class should have interface = "none". -# If one wants a pure object of that class, make an inherited class solely -# for the purpose of having an interface. -# This limitation would appear to be quite tricky to work around in C++, -# so there are no immediate plans to do so. - -# Making R6 and C++ inheritance behavior match comes reasonably close but is -# not perfect. - -# We support a compileInfo element for nFunction methods of nClasses that is +# The `inherit` argument to nClass can take a single argument, similar to R6 +# It is captured as an expression that returns a single nClass generator. +# (This must always be the same object, so the expression can't generate a new one each time it is evaluated.) + +# We use the inheritance semantics of R6 classes to set the default rules for +# nClasses. +# +# For fields: If two R6 classes have fields of the same name, they seem to +#. become one field. Therefore we disallow this in nClasses in order to +# avoid generating C++ classes that actually have two distinct members +#. of the same name and then getting different compiled vs. uncompiled behavior. +# This is checked in NC_check_inheritance. +#. nOptions(allow_inherited_field_duplicates=TRUE) disables this rule and +#. checking, and allows nCompile to happily generate C++ classes with +# distinct members of the same name. This is fine if a user doesn't care +#. about uncompiled behavior or discrepancies. +# +# For methods: In two R6 classes have methods of the same name, that works +# fine and a base class method can be accessed by super$foo(). +#. However, R6 has no notion of virtual vs. non-virtual inheritance, no +#. notion of signatures (argument and return types) being required to match +#. for virtual inheritance, and no notion of base class pointers. In effect, +# R6 objects are just passed as objects and a method call will always use +# the most derived version. To match that, we require nClass inherited methods +#. of the same name to have exactly matching argument names, types, and return type. +# And we require that the first base class with a method must mark it with +#. compileInfo=list(virtual=TRUE) (in the nFunction call). The last requirement +# is a bit like the use of "override" in C++ in that it shouldn't be strictly +# necessary but can allow us during compilation to catch potentially nasty bugs +# by giving the programmer a way to declare their intention. We require that +#. (whereas C++ override is optional). Error-trapping happens in NC_check_inheritance. +#. nOptions(allow_method_overloading=TRUE) removes these rules and allows +#. the compiler to generate C++ classes with overloaded versions of the same +# name and to have even the same name and signature be not virtual. This +# only makes sense if the user doesn't care about uncompiled behavior matching. + +# As just noted, we support a compileInfo element for nFunction methods of nClasses that is # `virtual` set to TRUE or FALSE. This is what is sounds like: whether to make # the C++ method virtual. -# R6 semantics are natively like "virtual": There is no notion of having a -# pointer to either base or derived. You just have an object, so if there -# is a method of the same name as in a base class, the derived method will be called. -# One can access base class methods as super$method(). We currently do not -# support that syntax but could potentially consider it. +# Finding inheritance in R6 was tricky because +# the generator retains an unevaluated expression for inherits. +# We now keep it that way as `inheritQ` (for "quoted") +# This allows an nClass call to inherit from a method that isn't defined yet. -# In our C++, only the most derived class should have an interface, so -# in effect we have the same system: the most derived version will be called. +# We do not currently support "super$" in compilation, so there is no +# way to call a base class method (yet). -# It appears that in R6, if a base and derived class have a member variable (not method) -# of the same name, there is only ever one copy of it, not one for each level -# of the class hierarchy. +test_that("nClass hierarchy traps lack of virtual declaration", { + ncA <- nClass( + Cpublic = list( + foo = nFunction( + function(x=double(1)) {returnType(integer(1))}, + compileInfo=list(virtual=FALSE) + ) + ) + ) + ncB <- nClass( + inherit = ncA, + Cpublic = list( + foo = nFunction( + function(x=double(1)) {returnType(integer(1))}, + compileInfo=list(virtual=FALSE) + ) + ) + ) + expect_error( + comp <- nCompile(ncA, ncB) + ) +}) -# Finding inheritance in R6 was tricky because -# the generator retains an unevaluated expression for inherits. -# We resolve that once an put it in NCinternals. -# Here is a summary of cases where compiled and uncompiled behavior will differ: -# -# 1. Base class and derived class both have member variable ("x"): -# - In uncompiled, there is only ever one "x". -# - In compiled, only the derived "x" is accessed by the interface. -# - If one provided get/set methods for changing "x" in the base class -# and if there are base class methods that use "x", then uncompiled and -# compiled could use different values of "x". -# -# 2. No "super" in compiled code. Currently there is no compilation support for using -# "self$super$method" to access base class methods. -# -# 3. cppLiteral coding of base class methods or use of inheritance: -# Well, anything in cppLiteral is not supported for uncompiled execution. -# Here in particular it stands out that harnessing virtual method dispatch is -# not something that can be mimicked in uncompiled R6. -# -# Hence the following recommendations if one wants uncompiled and compiled to have -# the same behavior: -# - Do not use the same variable name in base and derived classes. -# - Do not use "super". +test_that("nClass hierarchy traps mismatched argument names", { + ncA <- nClass( + Cpublic = list( + foo = nFunction( + function(z=double(1)) {returnType(integer(1))}, + compileInfo=list(virtual=TRUE) + ) + ) + ) + ncB <- nClass( + inherit = ncA, + Cpublic = list( + foo = nFunction( + function(x=double(1)) {returnType(integer(1))}, + compileInfo=list(virtual=FALSE) + ) + ) + ) + expect_error( + comp <- nCompile(ncA, ncB) + ) +}) + +test_that("nClass hierarchy traps mismatched argument types", { + ncA <- nClass( + Cpublic = list( + foo = nFunction( + function(x=double(0)) {returnType(integer(1))}, + compileInfo=list(virtual=TRUE) + ) + ) + ) + ncB <- nClass( + inherit = ncA, + Cpublic = list( + foo = nFunction( + function(x=double(1)) {returnType(integer(1))}, + compileInfo=list(virtual=FALSE) + ) + ) + ) + expect_error( + comp <- nCompile(ncA, ncB) + ) +}) -test_that("nClass hierarchies work as expected (including uncompiled vs compiled discrepancies)", { +test_that("nClass hierarchy traps mismatched return types", { + ncA <- nClass( + Cpublic = list( + foo = nFunction( + function(x=double(1)) {returnType(integer(1))}, + compileInfo=list(virtual=TRUE) + ) + ) + ) + ncB <- nClass( + inherit = ncA, + Cpublic = list( + foo = nFunction( + function(x=double(1)) {returnType(integer(0))}, + compileInfo=list(virtual=FALSE) + ) + ) + ) + expect_error( + comp <- nCompile(ncA, ncB) + ) +}) + +test_that("nClass hierarchy traps inherited field duplicate names", { + ncA <- nClass( + Cpublic = list( + x = 'numericVector', + y = 'numericVector', + foo = nFunction( + function(x=double(1)) {returnType(integer(1))}, + compileInfo=list(virtual=TRUE) + ) + ) + ) + ncB <- nClass( + inherit = ncA, + Cpublic = list( + x = 'numericVector', + z = 'numericVector', + foo = nFunction( + function(x=double(1)) {returnType(integer(1))}, + compileInfo=list(virtual=FALSE) + ) + ) + ) + expect_error( + comp <- nCompile(ncA, ncB) + ) +}) + + +test_that("nClass hierarchies work as expected (including uncompiled vs compiled discrepancies)", +{ + # This was written before all the error-trapping above. + # I am going to disable the error-trapping. I think this is good + # because now we also test the more general compilation, but + # I may not be thinking about cases we're missing. + oldOpt1 <- nOptions("allow_method_overloading") + oldOpt2 <- nOptions("allow_inherited_field_duplicates") + nOptions(allow_method_overloading = TRUE) + nOptions(allow_inherited_field_duplicates = TRUE) + on.exit({ + nOptions(allow_method_overloading = oldOpt1) + nOptions(allow_inherited_field_duplicates = oldOpt2) + }) ncA <- nClass( Rpublic = list( fooRA = function() v.A @@ -227,14 +355,25 @@ test_that("nClass hierarchies work as expected (including uncompiled vs compiled rm(objB, CobjB) gc() } - }) ############## -cat("With inheritance, we may now be able to interface at multiple levels, but it is untested.\n") +# cat("With inheritance, we may now be able to interface at multiple levels, but it is untested.\n") test_that("inheriting-only classes in 3-level hierarchy works", { + # This was written before all the error-trapping above. + # I am going to disable the error-trapping. I think this is good + # because now we also test the more general compilation, but + # I may not be thinking about cases we're missing. + oldOpt1 <- nOptions("allow_method_overloading") + oldOpt2 <- nOptions("allow_inherited_field_duplicates") + nOptions(allow_method_overloading = TRUE) + nOptions(allow_inherited_field_duplicates = TRUE) + on.exit({ + nOptions(allow_method_overloading = oldOpt1) + nOptions(allow_inherited_field_duplicates = oldOpt2) + }) ncBase <- nClass( classname = "ncBase", Cpublic = list( @@ -288,3 +427,280 @@ test_that("inheriting-only classes in 3-level hierarchy works", { rm(Cobj, Cobj2); gc() }) + +cat("Add inline checking of validity of shared_ptr's in generated code.\n") + +test_that("inheritance with interfaces at multiple levels", { + # This was written before all the error-trapping above. + # I am going to disable the error-trapping. I think this is good + # because now we also test the more general compilation, but + # I may not be thinking about cases we're missing. + oldOpt1 <- nOptions("allow_method_overloading") + oldOpt2 <- nOptions("allow_inherited_field_duplicates") + nOptions(allow_method_overloading = TRUE) + nOptions(allow_inherited_field_duplicates = TRUE) + on.exit({ + nOptions(allow_method_overloading = oldOpt1) + nOptions(allow_inherited_field_duplicates = oldOpt2) + }) + ncBase <- nClass( + classname = "ncBase", + Cpublic = list( + base_x = 'numericScalar', + # get_base_x will be non-virtual and uniquely named + get_base_x = nFunction( + function() { + return(base_x); returnType('numericScalar') + }, + name = "get_base_x"), + # get_x will be non-virtual and non uniquely named + get_x = nFunction( + function() { + return(base_x); returnType('numericScalar') + }, + name = "get_x"), + # get_x_virt will be virtual + get_x_virt = nFunction( + function() { + return(base_x); returnType('numericScalar') + }, + name = "get_x_virt", + compileInfo=list(virtual=TRUE)) + ), + compileInfo = list(interface = "full",createFromR=TRUE) + ) + + ncMid <- nClass( + inherit = ncBase, + classname = "ncMid", + Cpublic = list( + mid_x = 'numericScalar', + # get_base_x will be non-virtual and uniquely named + get_mid_x = nFunction( + function() { + return(mid_x); returnType('numericScalar') + }, + name = "get_mid_x"), + # get_x will be non-virtual and non uniquely named + get_x = nFunction( + function() { + return(mid_x); returnType('numericScalar') + }, + name = "get_x"), + # get_base_x_from_mid will be non-virtual and access base class member + get_base_x_from_mid = nFunction( + function() { + return(base_x); returnType('numericScalar') + }, + name = "get_base_x_from_mid"), + + # get_x_virt will be virtual + get_x_virt = nFunction( + function() { + return(mid_x); returnType('numericScalar') + }, + name = "get_x_virt", + compileInfo=list(virtual=TRUE)) + ), + compileInfo = list(interface = "full",createFromR=FALSE) + ) + + ncDer <- nClass( + inherit = ncMid, + Cpublic = list( + make_mid = nFunction( + function() {return(ncMid$new()); returnType('ncMid')} + ), + der_x = 'numericScalar', + # get_base_x will be non-virtual and uniquely named + get_der_x = nFunction( + function() { + return(der_x); returnType('numericScalar') + }, + name = "get_der_x"), + # get_x will be non-virtual and non uniquely named + get_x = nFunction( + function() { + return(der_x); returnType('numericScalar') + }, + name = "get_x"), + # get_base_x_from_mid will be non-virtual and access base class member + get_base_x_from_der = nFunction( + function() { + return(base_x); returnType('numericScalar') + }, + name = "get_base_x_from_der"), + get_mid_x_from_der = nFunction( + function() { + return(mid_x); returnType('numericScalar') + }, + name = "get_mid_x_from_der"), + + # get_x_virt will be virtual + get_x_virt = nFunction( + function() { + return(der_x); returnType('numericScalar') + }, + name = "get_x_virt", + compileInfo=list(virtual=TRUE)) + ), + compileInfo = list(interface = "full",createFromR=TRUE) + ) + + useClasses <- nClass( + classname = "useClasses", + Cpublic = list( + myBase = 'ncBase', + myMid = 'ncMid', + myDer = 'ncDer', + useBase = nFunction( + function(i = integer()) { + returnType(double()) + if(i == 1) return(myBase$get_x_virt()) + if(i == 2) return(myBase$get_base_x()) + if(i == 3) return(myBase$get_x()) + if(i == 4) return(myBase$base_x) + return(-1) + }, + name = "useBase"), + useMid = nFunction( + function(i = integer()) { + returnType(double()) + if(i == 1) return(myMid$get_x_virt()) + if(i == 2) return(myMid$get_base_x()) + if(i == 3) return(myMid$get_x()) + if(i == 4) return(myMid$base_x) + + if(i == 5) return(myMid$get_base_x_from_mid()) + if(i == 6) return(myMid$get_mid_x()) + if(i == 7) return(myMid$mid_x) + return(-1) + } + ), + useDer = nFunction( + function(i = integer()) { + returnType(double()) + if(i == 1) return(myDer$get_x_virt()) + if(i == 2) return(myDer$get_base_x()) + if(i == 3) return(myDer$get_x()) + if(i == 4) return(myDer$base_x) + + if(i == 5) return(myDer$get_base_x_from_mid()) + if(i == 6) return(myDer$get_mid_x()) + if(i == 7) return(myDer$mid_x) + + if(i == 8) return(myDer$get_base_x_from_der()) + if(i == 9) return(myDer$get_mid_x_from_der()) + if(i == 10) return(myDer$get_der_x()) + if(i == 11) return(myDer$der_x) + return(-1) + } + ) + ) + ) + + comp <- nCompile(ncBase, ncMid, ncDer, useClasses) + + # der obj works on its own + Cder <- comp$ncDer$new() + Cder$base_x <- 1 + expect_equal(Cder$base_x, 1) + expect_equal(Cder$get_base_x(), 1) + expect_equal( Cder$get_base_x_from_mid(),1) + expect_equal( Cder$get_base_x_from_der(),1) + + Cder$mid_x <- 2 + expect_equal( Cder$mid_x, 2) + expect_equal( Cder$get_mid_x(), 2) + expect_equal( Cder$get_mid_x_from_der(), 2) + + Cder$der_x <- 3 + expect_equal (Cder$der_x, 3) + expect_equal (Cder$get_der_x(), 3) + + expect_equal (Cder$get_x(), 3) + expect_equal (Cder$get_x_virt(), 3) + + expect_error(Cmid <- comp$ncMid$new()) + + # mid object works on its own (even though can't be created from R) + Cmid <- Cder$make_mid() + Cmid$base_x <- 111 + Cmid$mid_x <- 222 + expect_equal(c( + Cmid$base_x + , Cmid$get_base_x() + , Cmid$get_base_x_from_mid()), rep(111, 3)) + + expect_equal(c( + Cmid$mid_x + , Cmid$get_mid_x() + , Cmid$get_x()), rep(222, 3)) + + # base object works on its own + Cbase <- comp$ncBase$new() + Cbase$base_x <- 11 + expect_equal(Cbase$get_x_virt(), 11) + expect_error(Cbase$get_der_x()) + + obj <- comp$useClasses$new() + obj$myBase <- Cbase + obj$myDer <- Cder + + # base accessing an actual base + expect_equal(c( + obj$useBase(1) + ,obj$useBase(2) + ,obj$useBase(3) + ,obj$useBase(4)), rep(11, 4)) + + # der accessing an actual der + expect_equal(c( + obj$useDer(1) + ,obj$useDer(2) + ,obj$useDer(3) + ,obj$useDer(4)), c(3, 1, 3, 1)) + + expect_equal(c( + obj$useDer(5) + ,obj$useDer(6) + ,obj$useDer(7)), c(1, 2, 2)) + + expect_equal(c( + obj$useDer(8) + ,obj$useDer(9) + ,obj$useDer(10) + ,obj$useDer(11)), c(1, 2, 3, 3)) + + # base pointing to a der + obj$myBase <- Cder + expect_equal(c( + obj$useBase(1) + ,obj$useBase(2) + ,obj$useBase(3) + ,obj$useBase(4)), c(3,1,1,1)) + + + # base pointing to a mid + obj$myBase <- Cmid + expect_equal(c( + obj$useBase(1) + ,obj$useBase(2) + ,obj$useBase(3) + ,obj$useBase(4)), c(222,111,111,111)) + + # mid pointing to a der + obj$myMid <- Cder + expect_equal(c( + obj$useMid(1) + ,obj$useMid(2) + ,obj$useMid(3) + ,obj$useMid(4)), c(3,1,2,1)) + + expect_equal(c( + obj$useMid(5) + ,obj$useMid(6) + ,obj$useMid(7)), c(1, 2, 2)) + + rm(Cder, Cmid, Cbase); gc() +}) diff --git a/nCompiler/tests/testthat/nCompile_tests/test-nCompile.R b/nCompiler/tests/testthat/nCompile_tests/test-nCompile.R index f9723098..38f2945b 100644 --- a/nCompiler/tests/testthat/nCompile_tests/test-nCompile.R +++ b/nCompiler/tests/testthat/nCompile_tests/test-nCompile.R @@ -67,7 +67,6 @@ test_that("nCompile direct, package, and writePackage work with Eigen::Tensors", pkgload::unload("testpackage") }) - test_that("nCompile direct, package, and writePackage work with nClass interfaces", { nc <- nClass( Cpublic = list(