From 42f026ed95c47780a6c8bcae52e2b8a95bd952a4 Mon Sep 17 00:00:00 2001 From: perrydv Date: Wed, 10 Sep 2025 11:56:45 +0200 Subject: [PATCH 1/3] Establish cachedOpInfo in code$auxEnv so that user-defined opDefs can occur at various levels. --- nCompiler/R/NF_InternalsClass.R | 11 +- nCompiler/R/NF_Utils.R | 9 +- nCompiler/R/compile_aaa_operatorLists.R | 41 ++- nCompiler/R/compile_eigenization.R | 4 +- nCompiler/R/compile_generateCpp.R | 18 +- nCompiler/R/compile_labelAbstractTypes.R | 44 +-- nCompiler/R/compile_normalizeCalls.R | 308 +++++++++++------- nCompiler/R/compile_processAD.R | 4 +- nCompiler/R/compile_simpleTransformations.R | 4 +- .../testthat/nCompile_tests/test-userOps.R | 12 + 10 files changed, 263 insertions(+), 192 deletions(-) create mode 100644 nCompiler/tests/testthat/nCompile_tests/test-userOps.R diff --git a/nCompiler/R/NF_InternalsClass.R b/nCompiler/R/NF_InternalsClass.R index 3667b91e..5871f4d1 100644 --- a/nCompiler/R/NF_InternalsClass.R +++ b/nCompiler/R/NF_InternalsClass.R @@ -15,7 +15,8 @@ NF_InternalsClass <- R6::R6Class( isMethod = FALSE, uniqueName = character(), cpp_code_name = character(), - ## template = NULL, replaced with compileInfo$matchDef + ## template = NULL, replaced with default_matchDef + default_matchDef = NULL, code = NULL, RcppPacket = NULL, Rwrapper = NULL, @@ -96,13 +97,7 @@ NF_InternalsClass <- R6::R6Class( ## in the "decoration" system. They could be put in "value" argument. ## Either a named "value" or a ... is in all types. - ## not used until much later - if(is.null(self$compileInfo$opDef)) - self$compileInfo$opDef <- list() - if(is.null(self$compileInfo$opDef$matchDef)) { - self$compileInfo$opDef$matchDef <- Rarguments_2_function(arguments, body = quote({})) - } - # self$template <- Rarguments_2_function(arguments, body = quote({})) ## generateTemplate() + self$default_matchDef <- Rarguments_2_function(arguments, body = quote({})) ## generateTemplate() returnTypeInfo <- nf_extractReturnType(code) returnTypeDecl <- returnTypeInfo$returnType if(is.null(returnTypeDecl)) { diff --git a/nCompiler/R/NF_Utils.R b/nCompiler/R/NF_Utils.R index e3a5a886..e01b2e32 100644 --- a/nCompiler/R/NF_Utils.R +++ b/nCompiler/R/NF_Utils.R @@ -50,8 +50,9 @@ nGet <- function(name, where) { ## switch where to be the generator's parent_env where <- where$parent_env } - if(exists(name, envir = where, inherits = TRUE)) - get(name, envir = where, inherits = TRUE) - else - NULL + get0(name, envir = where) + # if(exists(name, envir = where, inherits = TRUE)) + # get(name, envir = where, inherits = TRUE) + # else + # NULL } diff --git a/nCompiler/R/compile_aaa_operatorLists.R b/nCompiler/R/compile_aaa_operatorLists.R index 610d2a8e..49cb1315 100644 --- a/nCompiler/R/compile_aaa_operatorLists.R +++ b/nCompiler/R/compile_aaa_operatorLists.R @@ -94,12 +94,19 @@ updateOperatorDef <- function(ops, field, subfield = NULL, val) { } } +getOperatorField <- function(opDef, field = NULL, subfield = NULL) { + if (is.null(opDef)) return(NULL) + if (is.null(field)) return(opDef) + if (is.null(opDef[[field]])) return(NULL) + if (is.null(subfield)) return(opDef[[field]]) + opDef[[field]][[subfield]] +} + getOperatorDef <- function(op, field = NULL, subfield = NULL) { - opInfo <- get0(op, envir=operatorDefUserEnv) -# opInfo <- operatorDefEnv[[op]] - if (is.null(opInfo) || is.null(field)) return(opInfo) - if (is.null(opInfo[[field]]) || is.null(subfield)) return(opInfo[[field]]) - return(opInfo[[field]][[subfield]]) + opDef <- get0(op, envir=operatorDefUserEnv) +# opDef <- operatorDefEnv[[op]] + if(is.null(opDef)) return(NULL) + getOperatorField(opDef, field, subfield) } assignOperatorDef( @@ -122,7 +129,7 @@ assignOperatorDef( ) assignOperatorDef( - 'NFCALL_', # This is used for non-method nFunctions in normalizeCalls and for any (including method) nFunctions after normalizeCalls + 'nFunction_default', # This is used for non-method nFunctions in normalizeCalls and for any (including method) nFunctions after normalizeCalls list( labelAbstractTypes = list( handler = 'nFunction_or_method_call'), @@ -133,17 +140,17 @@ assignOperatorDef( ) ) -assignOperatorDef( - 'NCMETHOD_', # This is a transient label that only exists within normalizeCalls - list( - ## labelAbstractTypes = list( - ## handler = 'nFunction_or_method_call'), - normalizeCalls = list( - handler = 'nFunction_or_method_call')#, # becomes NFCALL_ - ## cppOutput = list( - ## handler = 'Generic_nFunction') - ) -) +# assignOperatorDef( +# 'NCMETHOD_', # This is a transient label that only exists within normalizeCalls +# list( +# ## labelAbstractTypes = list( +# ## handler = 'nFunction_or_method_call'), +# normalizeCalls = list( +# handler = 'nFunction_or_method_call')#, # becomes NFCALL_ +# ## cppOutput = list( +# ## handler = 'Generic_nFunction') +# ) +# ) assignOperatorDef( c('dim'), diff --git a/nCompiler/R/compile_eigenization.R b/nCompiler/R/compile_eigenization.R index bdc3b19f..9d6bbb6e 100644 --- a/nCompiler/R/compile_eigenization.R +++ b/nCompiler/R/compile_eigenization.R @@ -73,7 +73,9 @@ compile_eigenize <- function(code, return(invisible(NULL)) } - handlingInfo <- getOperatorDef(code$name, "eigenImpl") + opInfo <- check_cachedOpInfo(code, where=auxEnv$where, update=TRUE) + handlingInfo <- getOperatorField(opInfo$opDef, "eigenImpl") +# handlingInfo <- getOperatorDef(code$name, "eigenImpl") # operatorDefEnv[[code$name]] # if(!is.null(opInfo)) { # handlingInfo <- opInfo[["eigenImpl"]] diff --git a/nCompiler/R/compile_generateCpp.R b/nCompiler/R/compile_generateCpp.R index 69b22655..e33d8287 100644 --- a/nCompiler/R/compile_generateCpp.R +++ b/nCompiler/R/compile_generateCpp.R @@ -85,7 +85,18 @@ compile_generateCpp <- function(code, ans[[length(code$args) + 2]] <- paste0(indent, '}') return(ans) } - handler <- getOperatorDef(code$name, "cppOutput", "handler") + # All calls must have valid opInfo or be a core DSL operator + # This compiler stage is called from cppDefs' generate() methods, + # so where$auxEnv is not available. + # This means that any penultimate compiler stages that changed the + # call name to something non-core (and hence potentially the handler) + # must update the cachedOpInfo. + # or we must add a final pass to do so. + # An example is changes made in eigenization, such as inserting `index[`. + # This is a core operator so it will be found in the check_cachedOpInfo with update=TRUE. + opInfo <- check_cachedOpInfo(code, where=baseenv(), update=TRUE, allowFail = TRUE) + handler <- getOperatorField(opInfo$opDef, "cppOutput", "handler") + # handler <- getOperatorDef(code$name, "cppOutput", "handler") # opInfo <- operatorDefEnv[[code$name]] # if(!is.null(opInfo)) { # handlingInfo <- opInfo[["cppOutput"]] @@ -165,10 +176,9 @@ inGenCppEnv( inGenCppEnv( Generic_nFunction <- function(code, symTab) { - innerCode <- code$args[['call']] - cpp_code_name <- code$aux$cpp_code_name + cpp_code_name <- code$aux$cachedOpInfo$obj_internals$cpp_code_name paste0(cpp_code_name, - '(', paste0(unlist(lapply(innerCode$args, + '(', paste0(unlist(lapply(code$args, compile_generateCpp, symTab, asArg = TRUE) ), diff --git a/nCompiler/R/compile_labelAbstractTypes.R b/nCompiler/R/compile_labelAbstractTypes.R index 647dd88a..b28ce0ad 100644 --- a/nCompiler/R/compile_labelAbstractTypes.R +++ b/nCompiler/R/compile_labelAbstractTypes.R @@ -100,7 +100,10 @@ compile_labelAbstractTypes <- function(code, return(invisible(NULL)) } - handlingInfo <- getOperatorDef(code$name, "labelAbstractTypes") + opInfo <- check_cachedOpInfo(code, where=auxEnv$where, update=TRUE) + handlingInfo <- getOperatorField(opInfo$opDef, "labelAbstractTypes") + +# handlingInfo <- getOperatorDef(code$name, "labelAbstractTypes") # opInfo <- operatorDefEnv[[code$name]] # if(!is.null(opInfo)) { # handlingInfo <- opInfo[["labelAbstractTypes"]] @@ -406,24 +409,11 @@ inLabelAbstractTypesEnv( inLabelAbstractTypesEnv( nFunction_or_method_call <- function(code, symTab, auxEnv, handlingInfo) { - # We have code = NFCALL_(foo(x, y)) - # innerCall if foo(x,y) - # We'll set innerCall$type to symbolNF - # and we'll set code$type to the returnType of foo(x, y) - innerCall <- code$args[['call']] - if(is.null(innerCall)) - stop( - exprClassProcessingErrorMsg( - code, paste('In nFunction_or_method_call: the nFunction (or method) ', - code$name, - ' has NULL content.') - ), call. = FALSE - ) - inserts <- recurse_labelAbstractTypes(innerCall, symTab, auxEnv, + inserts <- recurse_labelAbstractTypes(code, symTab, auxEnv, handlingInfo) - obj_internals <- code$aux$obj_internals - nFunctionName <- code$aux$nFunctionName - innerCall$type <- symbolNF$new(name = nFunctionName) + obj_internals <- code$aux$cachedOpInfo$obj_internals +# nFunctionName <- obj_internals$nFunctionName +# code$aux$symbolNF <- symbolNF$new(name = nFunctionName) returnSym <- obj_internals$returnSym if(is.null(returnSym)) stop( @@ -435,24 +425,6 @@ inLabelAbstractTypesEnv( ) code$type <- returnSym$clone() ## Not sure if a clone is needed, but it seems safer to make one. inserts - - # useArgs <- c(FALSE, rep(TRUE, length(code$args)-1)) - # inserts <- recurse_labelAbstractTypes(code, symTab, auxEnv, - # handlingInfo, useArgs) - # obj_internals <- code$args[[1]]$aux$obj_internals - # nFunctionName <- code$args[[1]]$aux$nFunctionName - # code$args[[1]]$type <- symbolNF$new(name = nFunctionName) - # returnSym <- obj_internals$returnSym - # if(is.null(returnSym)) - # stop( - # exprClassProcessingErrorMsg( - # code, paste('In nFunction_or_method_call: the nFunction (or method) ', - # code$name, - # ' does not have a valid returnType.') - # ), call. = FALSE - # ) - # code$type <- returnSym$clone() ## Not sure if a clone is needed, but it seems safer to make one. - # invisible(NULL) } ) diff --git a/nCompiler/R/compile_normalizeCalls.R b/nCompiler/R/compile_normalizeCalls.R index 14576642..25b367e4 100644 --- a/nCompiler/R/compile_normalizeCalls.R +++ b/nCompiler/R/compile_normalizeCalls.R @@ -6,6 +6,14 @@ inNormalizeCallsEnv <- function(expr) { eval(expr, envir = normalizeCallsEnv) } +# This stage does two things: +# 1. It normalizes the arguments of calls according to a matchDef, if found via an opDef. +#. This includes separating the compileTime arguments into the aux list of the call. +# 2. It caches information about the nFunction or nClass method being called, +#. if relevant. Future stages should check if the call name still matches what was cached, +#. they should look up from scratch. This is because handlers can change a name and then +#. then opDef should be found in future steps for the new name. +#. The caching also allows future stages to be more efficient in determining an nFunction or nClass method case. compile_normalizeCalls <- function(code, symTab, auxEnv) { @@ -26,89 +34,52 @@ compile_normalizeCalls <- function(code, } return(NULL) } - opInfo <- getOperatorDef(code$name) - ## Check for nFunctions or nClass methods (also nFunctions) - if(is.null(opInfo)) { - obj <- NULL - if(isNCgenerator(auxEnv$where)) {## We are in a class method (by direct call within another class method, no `$` involved) - obj <- NC_find_method(auxEnv$where, code$name, inherits=TRUE) - if(!is.null(obj)) { - code$aux$obj_internals <- NFinternals(obj) - if(isNF(obj)) { - opInfo <- operatorDefEnv[['NCMETHOD_']] - } else { - stop(exprClassProcessingErrorMsg(code, - paste0('method ', code$name, 'is being called, but it is not a nFunction.')), - call. = FALSE) - } - } - } - ## Next we check if code$name exists in the where. - ## Note that if we are in a method, auxEnv$where will be the - ## generator, which is an environment. However, we need - ## to use nGet instead of exists and get. - if(is.null(obj)) { - obj <- nGet(code$name, where = auxEnv$where) - ## An nFunction will be transformed to - ## have code$name 'NFCALL_' during simpleTransformations, - ## but that hasn't happened yet, so we manually use it here. - ## To-do: This could be cleaned up by either making that change here - ## when first detected or making a separate compiler stage just for that. - if(!is.null(obj)) { - if(isNF(obj)) { - code$aux$obj_internals <- NFinternals(obj) - opInfo <- operatorDefEnv[['NFCALL_']] - uniqueName <- NFinternals(obj)$uniqueName - if(length(uniqueName)==0) - stop( - exprClassProcessingErrorMsg( - code, - paste0('nFunction ', code$name, 'is being called, ', - 'but it is malformed because it has no internal name.')), - call. = FALSE) - if(is.null(auxEnv$needed_nFunctions[[uniqueName]])) { - ## We could put the nFunction itself in the needed_nFunctions list, - ## but we do not as a way to avoid having many references to R6 objects - ## in a blind attempt to facilitate garbage collection based on past experience. - ## Instead, we provide what is needed to look up the nFunction again later. - auxEnv$needed_nFunctions[[uniqueName]] <- list(code$name, auxEnv$where) - } - } else - stop(exprClassProcessingErrorMsg( + # If we are in an nClass: + # look for opDef in the nFunction method; + # next look for an opDef in the nClass, which might not even correspond to a method. + # If we are in an nFunction (not a method): + # look for an opDef in the nFunction. + # If we are in an nFunction or nClass method and no opDef has been found, + #. assign nFunction_default to be the opDef. + # Next look for a global opDef, which will look first for user-defined opDefs. + # If no opDef is defined anywhere, it is an error. + # If an opDef is found, check for a handler. + # If a handler is found, call it. + # If no handler is found, look for a matchDef and normalize the arguments by default. + # + # What gets cached in the aux of the exprClass for the call: + # cachedOpInfo = list(opDef, name, obj_internals, case) + # We defer: uniqueName, cpp_code_name + cachedOpInfo <- update_cachedOpInfo(code, auxEnv$where) + if(cachedOpInfo$case == "nFunction") { + uniqueName <- cachedOpInfo$obj_internals$uniqueName + if(length(uniqueName)==0) + stop( + exprClassProcessingErrorMsg( code, - paste0(code$name, ' is being used as a function, but it is not a nFunction.')), - call. = FALSE) - } + paste0('nFunction ', code$name, 'is being called, ', + 'but it is malformed because it has no internal name.')), + call. = FALSE) + if(is.null(auxEnv$needed_nFunctions[[uniqueName]])) { + ## We could put the nFunction itself in the needed_nFunctions list, + ## but we do not as a way to avoid having many references to R6 objects + ## in a blind attempt to facilitate garbage collection based on past experience. + ## Instead, we provide what is needed to look up the nFunction again later. + auxEnv$needed_nFunctions[[uniqueName]] <- list(code$name, auxEnv$where) } } - # There is also the nFunctionRef to think about. That is a bit more of a type however, not a call. - if(!is.null(opInfo)) { - matchDef <- opInfo[["matchDef"]] - if(!is.null(matchDef)) { - matched_code <- exprClass_put_args_in_order(matchDef, code, opInfo$compileArgs) - code <- replaceArgInCaller(code, matched_code) - } - handlingInfo <- opInfo[["normalizeCalls"]] - if(!is.null(handlingInfo)) { - handler <- handlingInfo[['handler']] - if(!is.null(handler)) { - if (logging) - appendToLog(paste('Calling handler', handler, 'for', code$name)) - ans <- eval(call(handler, code, symTab, auxEnv, handlingInfo), - envir = normalizeCallsEnv) - nErrorEnv$stateInfo <- character() - if (logging) { - appendToLog(paste('Finished handling', handler, 'for', code$name)) - logAST(code, paste('Resulting AST for', code$name), showImpl = FALSE) - } - return(ans) - } - } - } - # default behavior if there is no handler (which will be for many or most calls) + normalizeCallsEnv$recurse_normalizeCalls(code, symTab, auxEnv, handlingInfo) + + opDef <- cachedOpInfo$opDef + matchDef <- opDef[["matchDef"]] + if(is.null(matchDef)) + matchDef <- cachedOpInfo$obj_internals$default_matchDef + if(!is.null(matchDef)) { + exprClass_put_args_in_order(matchDef, code, opDef$compileArgs) + # code <- replaceArgInCaller(code, matched_code) + } } - # Where to put a generic recursion call? nErrorEnv$stateInfo <- character() invisible(NULL) } @@ -131,48 +102,145 @@ inNormalizeCallsEnv( } ) -## inNormalizeCallsEnv( -## skip <- -## function(code, symTab, auxEnv, handlingInfo, -## useArgs = rep(TRUE, length(code$args))) { -## NULL -## } -## ) +check_cachedOpInfo <- function(code, where=NULL, update=TRUE, allowFail=FALSE) { + cachedOpInfo <- code$aux$cachedOpInfo + up_to_date <- !is.null(cachedOpInfo) && cachedOpInfo$name == code$name + if(up_to_date) return(cachedOpInfo) + if(update) { + if(!is.environment(where)) + stop( + exprClassProcessingErrorMsg( + code, + "Internal error: check_cachedOpInfo called with update=TRUE but non-environment 'where' argument.") + , call. = FALSE) + update_cachedOpInfo(code, where, allowFail=allowFail) + } else { + NULL + } +} -inNormalizeCallsEnv( - convert_nFunction_or_method_AST <- - function(code) { - nFunctionName <- code$name - obj_internals <- code$aux$obj_internals - code$aux$obj_internals <- NULL - opDef <- obj_internals$compileInfo$opDef - matched_code <- exprClass_put_args_in_order(def=opDef$matchDef, expr=code, compileArgs = opDef$compileArgs) - code <- replaceArgInCaller(code, matched_code) - ## Note that the string `NFCALL_` matches the operatorDef entry. - ## Therefore the change-of-name here will automatically trigger use of - ## the 'NFCALL_' operatorDef in later stages. - newExpr <- wrapInExprClass(code, 'NFCALL_', "call") - # code$name <- 'NFCALL_' - cpp_code_name <- obj_internals$cpp_code_name - # fxnNameExpr <- exprClass$new(name = cpp_code_name, isName = TRUE, - # isCall = FALSE, isLiteral = FALSE, isAssign = FALSE) - newExpr$aux$obj_internals <- obj_internals - # newExpr$aux$nFunctionName <- nFunctionName - newExpr$aux$cpp_code_name <- cpp_code_name - ## We may need to add content to this symbol if - ## necessary for later processing steps. - ## insertArg(code, 1, fxnNameExpr, "FUN_") - obj_internals <- NULL - invisible(NULL) +update_cachedOpInfo <- function(code, where, allowFail=FALSE) { + opDef <- NULL + cachedOpInfo <- list(name = code$name) + ## Note that if we are in a method, auxEnv$where will be the + ## generator, which is an environment. + is_NCgenerator <- isNCgenerator(where) + if(is.null(opDef)) { # for good form, in case there a prior case is added later + # First look for an nClass method. + obj <- NULL + if(is.null(obj)) { + if(is_NCgenerator) {## We are in a class method (by direct call within another class method, no `$` involved) + # NC_find_method checks the class and its parents + obj <- NC_find_method(where, code$name, inherits=TRUE) + if(!is.null(obj)) { + if(isNF(obj)) { + cachedOpInfo$case <- "nClass method" # possibly disambiguate method from keyword + } else { + stop(exprClassProcessingErrorMsg(code, + paste0('method ', code$name, 'is being called, but it is not a nFunction.')), + call. = FALSE) + } + } + } } -) - -inNormalizeCallsEnv( - nFunction_or_method_call <- - function(code, symTab, auxEnv, handlingInfo) { - recurse_normalizeCalls(code, symTab, auxEnv, - handlingInfo) - convert_nFunction_or_method_AST(code) - NULL + if(is.null(obj)) { + # Next look for an nFunction (that is not a method) + # + # N.B. nGet follows lexical scoping from either an nFunction or nClass + # whatever "where" is. (In the nClass case, + # this is different than finding a method as above. + # Instead it is following scoping to find nFunctions.) + obj <- nGet(code$name, where = where) + if(!is.null(obj)) { + if(isNF(obj)) { + # There is no error trapping if obj is not an nFunction, because + # it could be simply an R function, since nGet (via get0) may traverse up to R_GlobalEnv. + cachedOpInfo$case <- "nFunction" + } else { + obj <- NULL # reset to NULL if not an nFunction + } + } } -) + if(!is.null(obj)) { + # We found an nFunction object that is either a method or not. + cachedOpInfo$obj_internals <- NFinternals(obj) + opDef <- cachedOpInfo$obj_internals$compileInfo$opDef # might be NULL + } + } + if(is.null(opDef)) { + ## At this point, we have not found an nFunction or nClass method. + if(is_NCgenerator) { + opDef <- where$compileInfo$opDefs[[code$name]] + if(!is.null(opDef)) { + cachedOpInfo$case <- "nClass method" # this could be a pure keyword or an nFunction with opDef provided at the nClass level + # a pure keyword will have obj_internals == NULL, providing a way to + # tell these cases apart later if necessary. + } + } + } + if(is.null(opDef)) { + if(!is.null(cachedOpInfo$case)) { + # We found an nFunction (method or not) but it did not have an opDef + # so here we insert the default opDef + if(cachedOpInfo$case == "nFunction" || cachedOpInfo$case == "nClass method") { + opDef <- getOperatorDef("nFunction_default") + } + } + } + if(is.null(opDef)) { + # We haven't found any form of nFunction or custom nClass keyword, + # so look for globally defined operators (e.g. "+" and all the basic DSL keywords) + opDef <- getOperatorDef(code$name) + if(!is.null(opDef)) { + cachedOpInfo$case <- "global" + } + } + if(is.null(opDef)) { + if(!allowFail) + stop(exprClassProcessingErrorMsg( + code, + paste0('No operator definition found for ', code$name, '.')), + call. = FALSE) + } + cachedOpInfo$opDef <- opDef + code$aux$cachedOpInfo <- cachedOpInfo + cachedOpInfo +} + +# inNormalizeCallsEnv( +# convert_nFunction_or_method_AST <- +# function(code) { +# nFunctionName <- code$name +# obj_internals <- code$aux$obj_internals +# code$aux$obj_internals <- NULL +# opDef <- obj_internals$compileInfo$opDef +# matched_code <- exprClass_put_args_in_order(def=opDef$matchDef, expr=code, compileArgs = opDef$compileArgs) +# code <- replaceArgInCaller(code, matched_code) +# ## Note that the string `NFCALL_` matches the operatorDef entry. +# ## Therefore the change-of-name here will automatically trigger use of +# ## the 'NFCALL_' operatorDef in later stages. +# newExpr <- wrapInExprClass(code, 'NFCALL_', "call") +# # code$name <- 'NFCALL_' +# cpp_code_name <- obj_internals$cpp_code_name +# # fxnNameExpr <- exprClass$new(name = cpp_code_name, isName = TRUE, +# # isCall = FALSE, isLiteral = FALSE, isAssign = FALSE) +# newExpr$aux$obj_internals <- obj_internals +# # newExpr$aux$nFunctionName <- nFunctionName +# newExpr$aux$cpp_code_name <- cpp_code_name +# ## We may need to add content to this symbol if +# ## necessary for later processing steps. +# ## insertArg(code, 1, fxnNameExpr, "FUN_") +# obj_internals <- NULL +# invisible(NULL) +# } +# ) + +# inNormalizeCallsEnv( +# nFunction_or_method_call <- +# function(code, symTab, auxEnv, handlingInfo) { +# recurse_normalizeCalls(code, symTab, auxEnv, +# handlingInfo) +# convert_nFunction_or_method_AST(code) +# NULL +# } +# ) diff --git a/nCompiler/R/compile_processAD.R b/nCompiler/R/compile_processAD.R index dbc78753..c85c43cb 100644 --- a/nCompiler/R/compile_processAD.R +++ b/nCompiler/R/compile_processAD.R @@ -52,7 +52,9 @@ compile_processAD <- function(code, return(invisible(NULL)) } - handlingInfo <- getOperatorDef(code$name, "processAD") + opInfo <- check_cachedOpInfo(code, where=auxEnv$where, update=TRUE) + handlingInfo <- getOperatorField(opInfo$opDef, "processAD") +# handlingInfo <- getOperatorDef(code$name, "processAD") # opInfo <- operatorDefEnv[[code$name]] # if(!is.null(opInfo)) { # handlingInfo <- opInfo[["processAD"]] diff --git a/nCompiler/R/compile_simpleTransformations.R b/nCompiler/R/compile_simpleTransformations.R index 67178014..2598bfa6 100644 --- a/nCompiler/R/compile_simpleTransformations.R +++ b/nCompiler/R/compile_simpleTransformations.R @@ -24,7 +24,9 @@ compile_simpleTransformations <- function(code, } } - handlingInfo <- getOperatorDef(code$name, opInfoName) + opInfo <- check_cachedOpInfo(code, where=auxEnv$where, update=TRUE, allowFail = TRUE) + handlingInfo <- getOperatorField(opInfo$opDef, opInfoName) + # handlingInfo <- getOperatorDef(code$name, opInfoName) # opInfo <- getOperatorDef(code$name) #operatorDefEnv[[code$name]] # if(!is.null(opInfo)) { # handlingInfo <- opInfo[[opInfoName]] diff --git a/nCompiler/tests/testthat/nCompile_tests/test-userOps.R b/nCompiler/tests/testthat/nCompile_tests/test-userOps.R new file mode 100644 index 00000000..55eac225 --- /dev/null +++ b/nCompiler/tests/testthat/nCompile_tests/test-userOps.R @@ -0,0 +1,12 @@ +# test user-defined operator definitions +# These can be: +# 1. provided globally to define (or take over) a new keyword via registerOpDef +# - precedence goes to a user-defined opDef over a built-in one +# 2. provided as an opDef list to an nFunction +# 3. provided as an opDef list (within compileInfo) to an nFunction that is an nClass method +# 4. provided as a list of opDef lists (within compileInfo) to an nClass + +# The definition "closest" to the nFunction takes precedence. +# e.g. if an nFunction has an opDef list, that takes precedence over +# an opDef list provided to the nClass that contains the nFunction. + From f4a47b9fc233c657bfd775137a83fed41da802a0 Mon Sep 17 00:00:00 2001 From: perrydv Date: Wed, 10 Sep 2025 12:32:46 +0200 Subject: [PATCH 2/3] move recurse_normalizeCalls after arg ordering and extracting compile-time args --- nCompiler/R/compile_normalizeCalls.R | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nCompiler/R/compile_normalizeCalls.R b/nCompiler/R/compile_normalizeCalls.R index 25b367e4..2725c64f 100644 --- a/nCompiler/R/compile_normalizeCalls.R +++ b/nCompiler/R/compile_normalizeCalls.R @@ -69,8 +69,6 @@ compile_normalizeCalls <- function(code, } } - normalizeCallsEnv$recurse_normalizeCalls(code, symTab, auxEnv, handlingInfo) - opDef <- cachedOpInfo$opDef matchDef <- opDef[["matchDef"]] if(is.null(matchDef)) @@ -79,6 +77,7 @@ compile_normalizeCalls <- function(code, exprClass_put_args_in_order(matchDef, code, opDef$compileArgs) # code <- replaceArgInCaller(code, matched_code) } + normalizeCallsEnv$recurse_normalizeCalls(code, symTab, auxEnv, handlingInfo) } nErrorEnv$stateInfo <- character() invisible(NULL) From f0a3b871519cb57051c539b58af4eb6dfac868b0 Mon Sep 17 00:00:00 2001 From: perrydv Date: Wed, 10 Sep 2025 14:52:34 +0200 Subject: [PATCH 3/3] tests and cleanup for user-defined opDefs --- nCompiler/R/compile_eigenization.R | 40 ++- nCompiler/R/compile_generateCpp.R | 11 +- nCompiler/R/compile_labelAbstractTypes.R | 7 +- nCompiler/R/compile_normalizeCalls.R | 2 +- .../nCompile_tests/test-compileNimble.R | 58 ---- .../testthat/nCompile_tests/test-userOps.R | 303 ++++++++++++++++++ 6 files changed, 340 insertions(+), 81 deletions(-) diff --git a/nCompiler/R/compile_eigenization.R b/nCompiler/R/compile_eigenization.R index 9d6bbb6e..adf981f7 100644 --- a/nCompiler/R/compile_eigenization.R +++ b/nCompiler/R/compile_eigenization.R @@ -82,14 +82,18 @@ compile_eigenize <- function(code, if(!is.null(handlingInfo)) { beforeHandler <- handlingInfo[['beforeHandler']] if(!is.null(beforeHandler)) { - setupExprs <- c(setupExprs, - eval(call(beforeHandler, - code, - symTab, - auxEnv, - workEnv, - handlingInfo), - envir = eigenizeEnv)) + if(is.function(beforeHandler)) + setupExprs <- c(setupExprs, + beforeHandler(code, symTab, auxEnv, workEnv, handlingInfo)) + else + setupExprs <- c(setupExprs, + eval(call(beforeHandler, + code, + symTab, + auxEnv, + workEnv, + handlingInfo), + envir = eigenizeEnv)) # return(if(length(setupExprs) == 0) NULL else setupExprs) } } @@ -112,14 +116,18 @@ compile_eigenize <- function(code, if(!is.null(handlingInfo)) { handler <- handlingInfo[['handler']] if(!is.null(handler)) { - setupExprs <- c(setupExprs, - eval(call(handler, - code, - symTab, - auxEnv, - workEnv, - handlingInfo), - envir = eigenizeEnv)) + if(is.function(handler)) + setupExprs <- c(setupExprs, + handler(code, symTab, auxEnv, workEnv, handlingInfo)) + else + setupExprs <- c(setupExprs, + eval(call(handler, + code, + symTab, + auxEnv, + workEnv, + handlingInfo), + envir = eigenizeEnv)) } } # } diff --git a/nCompiler/R/compile_generateCpp.R b/nCompiler/R/compile_generateCpp.R index e33d8287..ce0a932e 100644 --- a/nCompiler/R/compile_generateCpp.R +++ b/nCompiler/R/compile_generateCpp.R @@ -105,10 +105,13 @@ compile_generateCpp <- function(code, if(!is.null(handler)) { if (logging) appendToLog(paste('Calling handler', handler, 'for', code$name)) - res <- eval(call(handler, - code, - symTab), - envir = genCppEnv) + if(is.function(handler)) + res <- handler(code, symTab) + else + res <- eval(call(handler, + code, + symTab), + envir = genCppEnv) if (logging) { appendToLog(paste('Finished handling', handler, 'for', code$name, 'with result:')) diff --git a/nCompiler/R/compile_labelAbstractTypes.R b/nCompiler/R/compile_labelAbstractTypes.R index b28ce0ad..13ecced5 100644 --- a/nCompiler/R/compile_labelAbstractTypes.R +++ b/nCompiler/R/compile_labelAbstractTypes.R @@ -112,8 +112,11 @@ compile_labelAbstractTypes <- function(code, if(!is.null(handler)) { if (logging) appendToLog(paste('Calling handler', handler, 'for', code$name)) - ans <- eval(call(handler, code, symTab, auxEnv, handlingInfo), - envir = labelAbstractTypesEnv) + if(is.function(handler)) + ans <- handler(code, symTab, auxEnv, handlingInfo) + else + ans <- eval(call(handler, code, symTab, auxEnv, handlingInfo), + envir = labelAbstractTypesEnv) nErrorEnv$stateInfo <- character() if (logging) { appendToLog(paste('Finished handling', handler, 'for', code$name)) diff --git a/nCompiler/R/compile_normalizeCalls.R b/nCompiler/R/compile_normalizeCalls.R index 2725c64f..020b7f43 100644 --- a/nCompiler/R/compile_normalizeCalls.R +++ b/nCompiler/R/compile_normalizeCalls.R @@ -169,7 +169,7 @@ update_cachedOpInfo <- function(code, where, allowFail=FALSE) { if(is.null(opDef)) { ## At this point, we have not found an nFunction or nClass method. if(is_NCgenerator) { - opDef <- where$compileInfo$opDefs[[code$name]] + opDef <- NCinternals(where)$compileInfo$opDefs[[code$name]] if(!is.null(opDef)) { cachedOpInfo$case <- "nClass method" # this could be a pure keyword or an nFunction with opDef provided at the nClass level # a pure keyword will have obj_internals == NULL, providing a way to diff --git a/nCompiler/tests/testthat/nCompile_tests/test-compileNimble.R b/nCompiler/tests/testthat/nCompile_tests/test-compileNimble.R index 04a94ee8..b90c4cd1 100644 --- a/nCompiler/tests/testthat/nCompile_tests/test-compileNimble.R +++ b/nCompiler/tests/testthat/nCompile_tests/test-compileNimble.R @@ -33,61 +33,3 @@ test_that("compileNimble bridge works for one nimbleFunction object", { ## add nClass to nCompiler:::compileNimble ## ## document, document, document - -test_that("registering a user-defined operator definition (opDef) works", { - ## first version: provide a function - nimArrayHandler <- function(code,...) { - code$name <- 'nArray' - NULL - } - # This test works by: - # providing a handler to relpace "nimArray" with "nArray" - # and a handler to replace "nimArray2" with "nArray" to - # check on handling multiple cases. - registerOpDef( - list(nimArray = - list( - matchDef = function(value=0, dim=c(1,1), init=TRUE, - fillZeros=TRUE, recycle=TRUE, nDim, - type="double") {}, - # normalizeCalls=list(handler='skip'), - simpleTransformations=list(handler = nimArrayHandler)))) - expect_equal(ls(`:::`("nCompiler", "operatorDefUserEnv")), "nimArray") - - registerOpDef( - list(nimArray2 = - list( - matchDef = function(value=0, dim=c(1,1), init=TRUE, - fillZeros=TRUE, recycle=TRUE, nDim, - type="double") {}, - simpleTransformations=list(handler = 'replace', - replacement = 'nArray')))) - expect_equal(ls(`:::`("nCompiler", "operatorDefUserEnv")), c("nimArray", "nimArray2")) - - nc <- nClass( - Cpublic = list( - foo = nFunction( - function() { - ans <- nimArray( 6, dim = 2) - ans2 <- nArray(value = 5, dim = 2) - return(ans) - returnType('double(1)') - } - ), - foo2 = nFunction( - function() { - ans <- nimArray2(3,dim = 2) - return(ans) - returnType('double(1)') - }) - )) - Cnc <- nCompile(nc) - obj <- Cnc$new() - expect_identical(obj$foo(), c(6, 6)) - expect_identical(obj$foo2(), c(3, 3)) - rm(obj); gc() - # - deregisterOpDef("nimArray") - deregisterOpDef("nimArray2") - expect_equal(length(ls(`:::`("nCompiler", "operatorDefUserEnv"))), 0) -}) diff --git a/nCompiler/tests/testthat/nCompile_tests/test-userOps.R b/nCompiler/tests/testthat/nCompile_tests/test-userOps.R index 55eac225..3c3274d5 100644 --- a/nCompiler/tests/testthat/nCompile_tests/test-userOps.R +++ b/nCompiler/tests/testthat/nCompile_tests/test-userOps.R @@ -10,3 +10,306 @@ # e.g. if an nFunction has an opDef list, that takes precedence over # an opDef list provided to the nClass that contains the nFunction. +test_that("registering a global user-defined operator definition (opDef) works", { + ## first version: provide a function + nimArrayHandler <- function(code,...) { + code$name <- 'nArray' + NULL + } + # This test works by: + # providing a handler to relpace "nimArray" with "nArray" + # and a handler to replace "nimArray2" with "nArray" to + # check on handling multiple cases. + registerOpDef( + list(nimArray = + list( + matchDef = function(value=0, dim=c(1,1), init=TRUE, + fillZeros=TRUE, recycle=TRUE, nDim, + type="double") {}, + simpleTransformations=list(handler = nimArrayHandler)))) + expect_equal(ls(`:::`("nCompiler", "operatorDefUserEnv")), "nimArray") + + registerOpDef( + list(nimArray2 = + list( + matchDef = function(value=0, dim=c(1,1), init=TRUE, + fillZeros=TRUE, recycle=TRUE, nDim, + type="double") {}, + simpleTransformations=list(handler = 'replace', + replacement = 'nArray')))) + expect_equal(ls(`:::`("nCompiler", "operatorDefUserEnv")), c("nimArray", "nimArray2")) + + nc <- nClass( + Cpublic = list( + foo = nFunction( + function() { + ans <- nimArray( 6, dim = 2) + ans2 <- nArray(value = 5, dim = 2) + return(ans) + returnType('double(1)') + } + ), + foo2 = nFunction( + function() { + ans <- nimArray2(3,dim = 2) + return(ans) + returnType('double(1)') + }) + )) + Cnc <- nCompile(nc) + obj <- Cnc$new() + expect_identical(obj$foo(), c(6, 6)) + expect_identical(obj$foo2(), c(3, 3)) + rm(obj); gc() + # + deregisterOpDef("nimArray") + deregisterOpDef("nimArray2") + expect_equal(length(ls(`:::`("nCompiler", "operatorDefUserEnv"))), 0) +}) + +cat("User opDef could be dangerous prior to genCpp because it won't update cachedOpDef\n") + +test_that("nFunction custom opDef works through a sequence of changes and handlers", { + # We set up a series of messages and function renamings to track + # the custom opDef through each one and also + # show that it is being updated at each step. + # Note the manual updating during eigenImpl. + # + # foo is an nFunction with compileArgs and a simpleTrans handler that + # renames it foo2 and emits a msg. + check_V <- function(code, ...) { + cat("MSG1: check_V was called. ") + if(code$aux$compileArgs$V == "W") + cat("MSG2: compile arg V was found. ") + code$name <- "foo2" + } + custom_opDef <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef$matchDef <- function(V) {} + custom_opDef$compileArgs <- "V" + custom_opDef$simpleTransformations <- list( + handler = check_V + ) + foo <- nFunction( + fun = function() { + return(1.2) + }, argTypes=list(), returnType=quote(double()), + compileInfo=list(opDef=custom_opDef) + ) + + # foo2 is an nFunction with custom labelAbstractTypes handler + # that emits a msg and renames it to foo3. + check_LAT <- function(code, symTab, auxEnv, handlingInfo) { + cat("MSG3: LAT for foo2 was used. ") + handler <- nCompiler:::getOperatorDef("nFunction_default")$labelAbstractTypes$handler + ans <- eval(call(handler, code,symTab,auxEnv,handlingInfo), + envir=nCompiler:::labelAbstractTypesEnv) + code$name <- "foo3" + ans + } + custom_opDef2 <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef2$labelAbstractTypes <- list( + handler = check_LAT + ) + foo2 <- nFunction( + fun = function() { + return(2.3) + }, argTypes=list(), returnType=quote(double()), + compileInfo=list(opDef=custom_opDef2) + ) + + # foo3 is an nFunction with a custom eigenImpl handler + # that emits, a msg, renames it, and updated the cachedOpInfo + check_EIG <- function(code, symTab, auxEnv, workEnv, handlingInfo) { + cat("MSG4: LAT for foo3 was used. ") + code$name <- "foo4" + nCompiler:::update_cachedOpInfo(code, auxEnv$where) + NULL + } + custom_opDef3 <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef3$eigenImpl <- list( + handler = check_EIG + ) + foo3 <- nFunction( + fun = function() { + return(3.4) + }, argTypes=list(), returnType=quote(double()), + compileInfo=list(opDef=custom_opDef3) + ) + + ## foo4 is an nFunction with a custom cppOutput handler + # that emits a msg and pastes "10.0*3.4 + " in front of the + # default output. + check_GenCpp <- function(code, symTab) { +# cat("MSG5: GenCpp for foo4 was used.") + handler <- nCompiler:::getOperatorDef("nFunction_default")$cppOutput$handler + base_out <- eval(call(handler, code, symTab), + envir=nCompiler:::genCppEnv) + paste0("10.0*3.4 + ", base_out, collapse = "") + } + custom_opDef4 <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef4$cppOutput <- list( + handler = check_GenCpp + ) + foo4 <- nFunction( + fun = function() { + return(4.5) + }, argTypes=list(), returnType=quote(double()), + compileInfo=list(opDef=custom_opDef4) + ) + + call_foo <- nFunction( + fun = function() { + ans <- foo(V = "W"); return(ans) + }, argTypes=list(), returnType=quote(double()) + ) + out <- capture_output( + cppDefs <- nCompile(call_foo, control=list(return_cppDefs=TRUE)) + ) + check <- grepl("^MSG1", out) + expect_true(check) + check <- grepl("MSG2", out) + expect_true(check) + check <- grepl("MSG3", out) + expect_true(check) + check <- grepl("MSG4", out) + expect_true(check) + + out_code <- cppDefs[[1]]$generate() |> unlist() + check <- grepl("10\\.0\\*3\\.4", out_code) |> sum() + expect_true(check==1) + + cat("#include needs are not cleaned up if an nFunction is renamed by a handler to another nFunction. All must be included in nCompile.") + comp <- nCompile(foo, foo2, foo3, foo4, call_foo) + expect_equal(comp$call_foo(), 10*3.4 + 4.5) + +}) + +test_that("nClass custom opDefs of 3 kinds works through a sequence of changes and handlers", { + + # foo is an nFunction with compileArgs and a simpleTrans handler that + # renames it foo2 and emits a msg. + check_V <- function(code, ...) { + cat("MSG1: check_V was called. ") + if(code$aux$compileArgs$V == "W") + cat("MSG2: compile arg V was found. ") + code$name <- "foo2" + } + custom_opDef <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef$matchDef <- function(V) {} + custom_opDef$compileArgs <- "V" + custom_opDef$simpleTransformations <- list( + handler = check_V + ) + foo <- nFunction( + fun = function() { + return(1.2) + }, argTypes=list(), returnType=quote(double()), + compileInfo=list(opDef=custom_opDef) + ) + + # foo2 is an nFunction with custom labelAbstractTypes handler + # that emits a msg and renames it to foo3. + # It's handler will be put at the nClass level + check_LAT <- function(code, symTab, auxEnv, handlingInfo) { + cat("MSG3: LAT for foo2 was used. ") + handler <- nCompiler:::getOperatorDef("nFunction_default")$labelAbstractTypes$handler + ans <- eval(call(handler, code,symTab,auxEnv,handlingInfo), + envir=nCompiler:::labelAbstractTypesEnv) + code$name <- "foo3" + ans + } + custom_opDef2 <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef2$labelAbstractTypes <- list( + handler = check_LAT + ) + foo2 <- nFunction( + fun = function() { + return(2.3) + }, argTypes=list(), returnType=quote(double()), +# compileInfo=list(opDef=custom_opDef2) + ) + + # foo3 is an nFunction with a custom eigenImpl handler + # that emits, a msg, renames it, and updated the cachedOpInfo + # This will be treated as a keyword opDef, only, without an + # actual nFunction. + check_EIG <- function(code, symTab, auxEnv, workEnv, handlingInfo) { + cat("MSG4: LAT for foo3 was used. ") + code$name <- "foo4" + nCompiler:::update_cachedOpInfo(code, auxEnv$where) + NULL + } + custom_opDef3 <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef3$eigenImpl <- list( + handler = check_EIG + ) +# foo3 <- nFunction( +# fun = function() { +# return(3.4) +# }, argTypes=list(), returnType=quote(double()), +# compileInfo=list(opDef=custom_opDef3) +# ) + + ## foo4 is an nFunction with a custom cppOutput handler + # that emits a msg and pastes "10.0*3.4 + " in front of the + # default output. + # This will remain an nFunction outside the nClass. + check_GenCpp <- function(code, symTab) { +# cat("MSG5: GenCpp for foo4 was used.") + handler <- nCompiler:::getOperatorDef("nFunction_default")$cppOutput$handler + base_out <- eval(call(handler, code, symTab), + envir=nCompiler:::genCppEnv) + paste0("10.0*3.4 + ", base_out, collapse = "") + } + custom_opDef4 <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef4$cppOutput <- list( + handler = check_GenCpp + ) + foo4 <- nFunction( + fun = function() { + return(4.5) + }, argTypes=list(), returnType=quote(double()), + compileInfo=list(opDef=custom_opDef4) + ) + + call_foo <- nFunction( + fun = function() { + ans <- foo(V = "W"); return(ans) + }, argTypes=list(), returnType=quote(double()) + ) + + foo_class <- nClass( + Cpublic = list( + foo = foo, + foo2 = foo2, + call_foo = call_foo + ), + compileInfo = list( + opDefs = list( + foo2 = custom_opDef2, + foo3 = custom_opDef3 + ) + ) + ) + + out <- capture_output( + cppDefs <- nCompile(foo_class, control=list(return_cppDefs=TRUE)) + ) + check <- grepl("^MSG1", out) + expect_true(check) + check <- grepl("MSG2", out) + expect_true(check) + check <- grepl("MSG3", out) + expect_true(check) + check <- grepl("MSG4", out) + expect_true(check) + + out_code <- cppDefs[[1]]$generate() |> unlist() + check <- grepl("10\\.0\\*3\\.4", out_code) |> sum() + expect_true(check==1) + + comp <- nCompile(foo4, foo_class) + obj <- comp$foo_class$new() + expect_equal(obj$call_foo(), 10*3.4 + 4.5) + rm(obj); gc() +})