diff --git a/.github/workflows/test-all.yaml b/.github/workflows/test-all.yaml index 5b299222..c0428946 100644 --- a/.github/workflows/test-all.yaml +++ b/.github/workflows/test-all.yaml @@ -7,18 +7,30 @@ on: run_tests: description: 'Run all tests' required: true + default: 'no' + run_nCompile: + description: 'Run nCompile tests' + required: false + default: 'yes' + run_nClass: + description: 'Run nClass tests' + required: false default: 'yes' + run_math: + description: 'Run math tests' + required: false + default: 'no' run_tensorOps: description: 'Run tensorOps tests' required: false - default: 'yes' + default: 'no' jobs: test-nCompile: runs-on: ubuntu-latest container: image: rocker/r2u:latest - if: github.event.inputs.run_tests == 'yes' + if: github.event.inputs.run_tests == 'yes' || github.event.inputs.run_nCompile == 'yes' steps: - uses: actions/checkout@v3 - name: SessionInfo @@ -26,7 +38,7 @@ jobs: - name: Package Dependencies run: R -q -e 'remotes::install_deps("nCompiler", dependencies=TRUE)' - name: Install inline - run: R -q -e 'remotes::install_cran(c("inline", "nimble"))' + run: R -q -e 'remotes::install_cran("inline")' - name: Build Package run: | R CMD build nCompiler @@ -44,7 +56,7 @@ jobs: runs-on: ubuntu-latest container: image: rocker/r2u:latest - if: github.event.inputs.run_tests == 'yes' + if: github.event.inputs.run_tests == 'yes' || github.event.inputs.run_nClass == 'yes' steps: - uses: actions/checkout@v3 - name: SessionInfo @@ -55,7 +67,7 @@ jobs: - name: Package Dependencies run: R -q -e 'remotes::install_deps("nCompiler", dependencies=TRUE)' - name: Install inline - run: R -q -e 'remotes::install_cran("inline")' + run: R -q -e 'remotes::install_cran(c("inline", "nimble"))' - name: Build Package run: | R CMD build nCompiler @@ -63,6 +75,7 @@ jobs: - name: Run nCompile and other tests run: | library(nCompiler) + testthat::test_dir("nCompiler/tests/testthat/nimble_tests", reporter = "summary") testthat::test_dir("nCompiler/tests/testthat/nClass_tests", reporter = "summary") testthat::test_dir("nCompiler/tests/testthat/types_tests", reporter = "summary") testthat::test_dir("nCompiler/tests/testthat/serialization_tests", reporter = "summary") @@ -72,7 +85,7 @@ jobs: runs-on: ubuntu-latest container: image: rocker/r2u:latest - if: github.event.inputs.run_tests == 'yes' + if: github.event.inputs.run_tests == 'yes' || github.event.inputs.run_math == 'yes' steps: - uses: actions/checkout@v3 - name: SessionInfo diff --git a/.github/workflows/test-suites.yml b/.github/workflows/test-suites.yml deleted file mode 100644 index bac3a2a2..00000000 --- a/.github/workflows/test-suites.yml +++ /dev/null @@ -1,269 +0,0 @@ -# .github/workflows/test-suites.yml -name: Test Suites - -on: - workflow_dispatch: - inputs: - install_and_cache: - description: 'Run and cache installation steps only' - required: true - default: 'no' - run_tests: - description: 'Run all tests (assuming [or after] cache is up-to-date)' - required: true - default: 'yes' - -# env: -# RSPM: https://packagemanager.posit.co/cran/latest # Enables Linux binaries from Posit -# R_KEEP_PKG_SOURCE: yes # Keeps sources for debugging if needed -# R_CALLR_ENV: full - -jobs: - install-cache: - runs-on: ubuntu-latest - if: github.event.inputs.install_and_cache == 'yes' - - steps: - - uses: actions/checkout@v4 - - - uses: r-lib/actions/setup-pandoc@v2 - - - uses: r-lib/actions/setup-r@v2 - with: - r-version: 'release' - use-public-rspm: true - - # - name: Cache R packages - # uses: actions/cache@v3 - # with: - # path: ${{ env.R_LIBS_USER }} - # key: ${{ runner.os }}-r-${{ hashFiles('nCompiler/DESCRIPTION', 'ci-extra-packages.txt') }} - # restore-keys: ${{ runner.os }}-r- - -# setup-r-dependencies gave problems so we are doing the recommended packages manually - # - name: Install system dependencies - # run: | - # sudo apt-get update - # sudo apt-get install -y \ - # build-essential \ - # libcurl4-openssl-dev \ - # libssl-dev \ - # libgit2-dev \ - # libxml2-dev \ - # libfontconfig1-dev \ - # libfreetype6-dev \ - # libharfbuzz-dev \ - # libfribidi-dev \ - # libpng-dev \ - # libjpeg-dev \ - # libtiff5-dev \ - # libglib2.0-dev \ - # libpango1.0-dev \ - # libicu-dev - - - name: Install R dependencies - run: | - Rscript -e 'install.packages("remotes")' - Rscript -e 'remotes::install_deps("nCompiler", dependencies = TRUE)' - - - name: Install extra packages - run: | - Rscript -e ' - pkgs <- readLines("ci-extra-packages.txt") - new_pkgs <- pkgs[!pkgs %in% installed.packages()[,"Package"]] - if(length(new_pkgs)) { - install.packages(new_pkgs) - stopifnot(all(pkgs %in% rownames(installed.packages()))) - } - ' - - # Job 1: nCompile and specific operator tests - test-nCompile: - runs-on: ubuntu-latest - if: github.event.inputs.run_tests == 'yes' - name: nCompile and specific operator tests - - steps: - - uses: actions/checkout@v4 - - uses: r-lib/actions/setup-r@v2 - with: - r-version: 'release' - use-public-rspm: true - - - name: Restore cache - uses: actions/cache@v3 - with: - path: ${{ env.R_LIBS_USER }} - key: ${{ runner.os }}-r-${{ hashFiles('nCompiler/DESCRIPTION', 'ci-extra-packages.txt') }} - restore-keys: ${{ runner.os }}-r- - - - name: Ensure R binary path is first in PATH - run: echo "$(Rscript -e 'cat(R.home())')" >> $GITHUB_PATH - - # - name: Add R to PATH - # run: | - # echo "/usr/local/bin" >> $GITHUB_PATH - - - name: Check cache contents - run: | - Rscript -e ' - cat(Sys.which("R"), "\n") - print(.libPaths()); - installed <- installed.packages()[,1]; - cat("Installed packages:\n", paste(installed, collapse=", "), "\n")' - - - name: Debug R - run: | - Rscript -e 'cat("R.home():", R.home(), "\nR.home(bin):", R.home("bin"), "\n")' - echo "PATH=$PATH" - which R - R --version - - - name: Install package - run: R CMD INSTALL --install-tests nCompiler - - - name: Run nCompile and other tests - run: | - library(nCompiler) - library(devtools) - library(withr) - library(callr) - testthat::test_dir("nCompiler/tests/testthat/uncompiled_tests", reporter = "summary") - testthat::test_dir("nCompiler/tests/testthat/nCompile_tests", reporter = "summary") - testthat::test_dir("nCompiler/tests/testthat/cpp_tests", reporter = "summary") - testthat::test_dir("nCompiler/tests/testthat/specificOp_tests", reporter = "summary") - shell: Rscript {0} - - # Job 2: nClass, types and serialization tests - test-nClass: - runs-on: ubuntu-latest - if: github.event.inputs.run_tests == 'yes' - name: nClass, types and serialization tests - - steps: - - uses: actions/checkout@v4 - - uses: r-lib/actions/setup-r@v2 - with: - r-version: 'release' - use-public-rspm: true - - - name: Restore cache - uses: actions/cache@v3 - with: - path: ${{ env.R_LIBS_USER }} - key: ${{ runner.os }}-r-${{ hashFiles('nCompiler/DESCRIPTION', 'ci-extra-packages.txt') }} - restore-keys: ${{ runner.os }}-r- - - - name: Add R to PATH - run: | - echo "/usr/local/bin" >> $GITHUB_PATH - - - name: Check cache contents - run: | - Rscript -e ' - cat(Sys.which("R"), "\n") - print(.libPaths()); - installed <- installed.packages()[,1]; - cat("Installed packages:\n", paste(installed, collapse=", "), "\n")' - - - name: Install package - run: R CMD INSTALL --install-tests nCompiler - - - name: Run nClass, types and serialization tests - run: | - library(nCompiler) - library(devtools) - library(withr) - library(callr) - testthat::test_dir("nCompiler/tests/testthat/nClass_tests", reporter = "summary") - testthat::test_dir("nCompiler/tests/testthat/types_tests", reporter = "summary") - testthat::test_dir("nCompiler/tests/testthat/serialization_tests", reporter = "summary") - shell: Rscript {0} - - # Job 3: Math tests - test-math: - runs-on: ubuntu-latest - if: github.event.inputs.run_tests == 'yes' - name: Math Tests - - steps: - - uses: actions/checkout@v4 - - uses: r-lib/actions/setup-r@v2 - with: - r-version: 'release' - use-public-rspm: true - - - name: Restore cache - uses: actions/cache@v3 - with: - path: ${{ env.R_LIBS_USER }} - key: ${{ runner.os }}-r-${{ hashFiles('nCompiler/DESCRIPTION', 'ci-extra-packages.txt') }} - restore-keys: ${{ runner.os }}-r- - - - name: Add R to PATH - run: | - echo "/usr/local/bin" >> $GITHUB_PATH - - - name: Check cache contents - run: | - Rscript -e ' - cat(Sys.which("R"), "\n") - print(.libPaths()); - installed <- installed.packages()[,1]; - cat("Installed packages:\n", paste(installed, collapse=", "), "\n")' - - - name: Install package - run: R CMD INSTALL --install-tests nCompiler - - - name: Run Math tests - run: | - library(nCompiler) - library(devtools) - library(withr) - library(callr) - testthat::test_dir("nCompiler/tests/testthat/math_tests", reporter = "summary") - shell: Rscript {0} - - # Job 4: tensorOps - test-tensorOps: - runs-on: ubuntu-latest - if: github.event.inputs.run_tests == 'yes' - name: TensorOps Tests - - steps: - - uses: actions/checkout@v4 - - uses: r-lib/actions/setup-r@v2 - with: - r-version: 'release' - use-public-rspm: true - - - name: Restore cache - uses: actions/cache@v3 - with: - path: ${{ env.R_LIBS_USER }} - key: ${{ runner.os }}-r-${{ hashFiles('nCompiler/DESCRIPTION', 'ci-extra-packages.txt') }} - restore-keys: ${{ runner.os }}-r- - - - name: Add R to PATH - run: | - echo "/usr/local/bin" >> $GITHUB_PATH - - - name: Check cache contents - run: | - Rscript -e ' - cat(Sys.which("R"), "\n") - print(.libPaths()); - installed <- installed.packages()[,1]; - cat("Installed packages:\n", paste(installed, collapse=", "), "\n")' - - - name: Install package - run: R CMD INSTALL --install-tests nCompiler - - - name: Run TensorOps tests - run: | - library(nCompiler) - library(devtools) - library(withr) - library(callr) - testthat::test_dir("nCompiler/tests/testthat/tensorOps_tests", reporter = "summary") - shell: Rscript {0} diff --git a/nCompiler/DESCRIPTION b/nCompiler/DESCRIPTION index 712e3eb4..41aed328 100644 --- a/nCompiler/DESCRIPTION +++ b/nCompiler/DESCRIPTION @@ -6,7 +6,7 @@ Authors@R: c(person("Perry", "de Valpine", email = "pdevalpine@berkeley.edu", ro person("Christopher", "Paciorek", role = "ctb"), person("James", "Duncan", role = "ctr")) Description: Provides nFunction and nClass for function and class definitions for which C++ can be automatically generated, or used in combination with other C++. Supports linear algebra by code-generating C++ that uses the Eigen library. Supports automatic differentiation by code-generating C++ that uses the CppAD library. Support for parallelization is planned by code-generating C++ that uses Intel Threading Building Blocks (TBB). -Depends: R (>= 3.3.0) +Depends: R (>= 4.1.0) Imports: methods,R6,Rcpp,pkgKitten,roxygen2 Suggests: @@ -49,6 +49,10 @@ Collate: cppDefs_TBB.R cppDefs_variables.R developerTools.R + symbolTable.R + symbolTableClass.R + symbolTable_utils.R + typeDeclarations.R NC.R NC_Compile.R NC_CompilerClass.R @@ -70,6 +74,7 @@ Collate: NF_Utils.R nCompile.R nConstructor.R + nimbleModels.R nList.R nTry.R packaging.R @@ -78,11 +83,7 @@ Collate: Rcpp_nCompiler.R Rexecution.R Rhooks.R - symbolTable.R - symbolTableClass.R - symbolTable_utils.R testingTools.R - typeDeclarations.R compile_zzz_operatorLists.R zzz_NC_Predefined.R RoxygenNote: 7.3.2 diff --git a/nCompiler/NAMESPACE b/nCompiler/NAMESPACE index 3a322f8f..664cb3ec 100644 --- a/nCompiler/NAMESPACE +++ b/nCompiler/NAMESPACE @@ -55,7 +55,10 @@ export(isUserDefined) export(logfact) export(loggam) export(logit) +export(makeModel_nClass) +export(make_node_fun) export(method) +export(modelBase_nClass) export(new.loadedObjectEnv) ## needed for Rcpp::Function access in loadedObjectEnv.h export(new.loadedObjectEnv_full) ## ditto export(nBacksolve) diff --git a/nCompiler/R/NC.R b/nCompiler/R/NC.R index 72c8153c..ccf0c75a 100644 --- a/nCompiler/R/NC.R +++ b/nCompiler/R/NC.R @@ -142,12 +142,15 @@ nClass <- function(classname, stop("In nFunction 'initialize', use 'compileInfo = list(constructor=TRUE)'.") } + inheritQ <- substitute(inherit) + inherit_provided <- !is.null(inheritQ) + internals = NC_InternalsClass$new(classname = classname, Cpublic = Cpublic, isOnlyC = length(Rpublic) == 0, enableDerivs = enableDerivs, enableSaving = enableSaving, - inherit = inherit, + inheritQ = inheritQ, # control = control, compileInfo = compileInfo, predefined = predefined, @@ -159,7 +162,10 @@ nClass <- function(classname, # "captured as an unevaluated expression which is evaluated in parent_env each time an object is instantiated." # so if provided in the nClass call, we stick it in new_env. # (That is not the only reason for new_env.) - if(!is.null(inherit)) new_env$.inherit_obj <- inherit + # Also note that the inherit arg is for nClass inheritance. The compileInfo$inherit element is for hard-coded C++ inheritance statements. + inheritQ <- substitute(inherit) + inherit_provided <- !is.null(inheritQ) + #if(!is.null(inherit)) new_env$.inherit_obj <- inherit new_env$.NCinternals <- internals # Uncompiled behavior for Cpublic fields needs to be handled. # Right now a type string like 'numericScalar' just becomes a @@ -173,7 +179,7 @@ nClass <- function(classname, parent_env = new_env ), list(INHERIT = - if(!is.null(inherit)) quote(.inherit_obj) + if(inherit_provided) inheritQ else quote(nClassClass)) )) ## 2. in the generator diff --git a/nCompiler/R/NC_FullCompiledInterface.R b/nCompiler/R/NC_FullCompiledInterface.R index e7ced115..7ee78c1a 100644 --- a/nCompiler/R/NC_FullCompiledInterface.R +++ b/nCompiler/R/NC_FullCompiledInterface.R @@ -77,9 +77,10 @@ build_compiled_nClass <- function(NCgenerator, interfaceMethods <- mapply(buildMethod_for_compiled_nClass, NCgenerator$public_methods[CmethodNames], CmethodNames) - if(!is.null(NCgenerator$parent_env$.inherit_obj)) { + inherit_obj <- NCgenerator$get_inherit() + if(isNCgenerator(inherit_obj)) { derivedNames <- c(derivedNames, CmethodNames) - baseNCgen <- NCgenerator$parent_env$.inherit_obj + baseNCgen <- inherit_obj baseCmethodNames <- NCinternals(baseNCgen)$methodNames baseCmethodNames <- setdiff(baseCmethodNames, derivedNames) # Note: baseCmethodNames could be empty but we still need to @@ -98,9 +99,10 @@ build_compiled_nClass <- function(NCgenerator, recurse_make_Rmethods <- function(NCgenerator, RmethodNames, derivedNames = character()) { interfaceMethods <- NCgenerator$public_methods[RmethodNames] - if(!is.null(NCgenerator$parent_env$.inherit_obj)) { + inherit_obj <- NCgenerator$get_inherit() + if(isNCgenerator(inherit_obj)) { derivedNames <- c(derivedNames, RmethodNames) - baseNCgen <- NCgenerator$parent_env$.inherit_obj + baseNCgen <- inherit_obj baseCmethodNames <- NCinternals(baseNCgen)$methodNames baseRmethodNames <- setdiff(names(baseNCgen$public_methods), c(baseCmethodNames, 'clone')) @@ -142,9 +144,10 @@ build_compiled_nClass <- function(NCgenerator, NCint <- NCinternals(NCgenerator) activeBindingResults <- buildActiveBinding_for_compiled_nClass(NCint, CfieldNames) - if(!is.null(NCgenerator$parent_env$.inherit_obj)) { + inherit_obj <- NCgenerator$get_inherit() + if(isNCgenerator(inherit_obj)) { derivedNames <- c(derivedNames, CfieldNames) - baseNCgen <- NCgenerator$parent_env$.inherit_obj + baseNCgen <- inherit_obj baseCfieldNames <- NCinternals(baseNCgen)$fieldNames baseCfieldNames <- setdiff(baseCfieldNames, derivedNames) baseActiveBindingResults <- @@ -166,9 +169,10 @@ build_compiled_nClass <- function(NCgenerator, recurse_make_Rfields <- function(NCgenerator, RfieldNames, derivedNames = character()) { interfaceFields <- NCgenerator$public_fields[RfieldNames] - if(!is.null(NCgenerator$parent_env$.inherit_obj)) { + inherit_obj <- NCgenerator$get_inherit() + if(isNCgenerator(inherit_obj)) { derivedNames <- c(derivedNames, RfieldNames) - baseNCgen <- NCgenerator$parent_env$.inherit_obj + baseNCgen <- inherit_obj baseCfieldNames <- NCinternals(baseNCgen)$fieldNames baseRfieldNames <- setdiff(names(baseNCgen$public_fields), c(baseCfieldNames, 'clone')) @@ -388,9 +392,10 @@ build_generic_fns_for_compiled_nClass <- function(NCgenerator) { interfaceFns <- mapply(build_generic_fn_for_compiled_nClass_method, NCgenerator$public_methods[CmethodNames], CmethodNames) - if(!is.null(NCgenerator$parent_env$.inherit_obj)) { + inherit_obj <- NCgenerator$get_inherit() + if(isNCgenerator(inherit_obj)) { derivedNames <- c(derivedNames, CmethodNames) - baseNCgen <- NCgenerator$parent_env$.inherit_obj + baseNCgen <- inherit_obj baseCmethodNames <- NCinternals(baseNCgen)$methodNames baseCmethodNames <- setdiff(baseCmethodNames, derivedNames) baseInterfaceFns <- recurse_make_Cmethods(baseNCgen, diff --git a/nCompiler/R/NC_InternalsClass.R b/nCompiler/R/NC_InternalsClass.R index 2a8016af..9820faa0 100644 --- a/nCompiler/R/NC_InternalsClass.R +++ b/nCompiler/R/NC_InternalsClass.R @@ -7,11 +7,14 @@ NC_InternalsClass <- R6::R6Class( cppSymbolNames = NULL, methodNames = character(), allMethodNames = character(), # including inherited methods + allMethodNames_self = character(), # not including inherited methods fieldNames = character(), allFieldNames = character(), # including inherited methods + allFieldNames_self = character(), # not including inherited methods classname = character(), cpp_classname = character(), compileInfo = list(), + inherit_base_provided = FALSE, # compileInfo will include interface ("full", "generic", or "none"), # interfaceMembers, exportName, and depends depends = list(), @@ -21,22 +24,20 @@ NC_InternalsClass <- R6::R6Class( enableSaving = NULL, predefined = FALSE, inheritNCinternals = NULL, + env = NULL, + inheritQ = NULL, + process_inherit_done = FALSE, initialize = function(classname, Cpublic, isOnlyC = FALSE, enableDerivs = NULL, enableSaving = get_nOption("enableSaving"), - inherit = NULL, + inheritQ = NULL, compileInfo = list(), predefined = FALSE, env = parent.frame()) { - if(!is.null(inherit)) { - self$inheritNCinternals <- NCinternals(inherit) - message("add check that base class has interface 'none'") - if(is.null(compileInfo$inherit$base)) - compileInfo$inherit$base <- paste("public", - inheritNCinternals$cpp_classname) - } + self$env <- env + self$inheritQ <- inheritQ self$compileInfo <- compileInfo self$classname <- classname self$cpp_classname <- Rname2CppName(classname) @@ -59,14 +60,13 @@ NC_InternalsClass <- R6::R6Class( self$symbolTable <- argTypeList2symbolTable(Cpublic[!isMethod], evalEnv = env) self$cppSymbolNames <- Rname2CppName(symbolTable$getSymbolNames()) self$methodNames <- names(Cpublic)[isMethod] + self$allMethodNames_self <- methodNames self$allMethodNames <- methodNames self$fieldNames <- names(Cpublic)[!isMethod] + self$allFieldNames_self <- fieldNames self$allFieldNames <- fieldNames - if(!is.null(inherit)) { - self$symbolTable$setParentST(inheritNCinternals$symbolTable) - self$allMethodNames <- c(self$allMethodNames, inheritNCinternals$allMethodNames) - self$allFieldNames <- c(self$allFieldNames, inheritNCinternals$allFieldNames) - } + if(!is.null(self$compileInfo$inherit$base)) + self$inherit_base_provided <- TRUE } if(!is.null(enableDerivs)) { if(!is.list(enableDerivs)) @@ -79,6 +79,35 @@ NC_InternalsClass <- R6::R6Class( } self$predefined <- predefined self$enableSaving <- enableSaving + }, + connect_inherit = function() { + # These are steps that need to be done after all classes are defined + # and do not require recursion up the inheritance tree. + if(!is.null(self$inheritQ)) { + inherit_obj <- eval(self$inheritQ, envir = self$env) + if(!isNCgenerator(inherit_obj)) + stop("An inherit argument that was provided to nClass is not nClass generator.") + self$inheritNCinternals <- NCinternals(inherit_obj) + message("add check that base class has interface 'none'") + if(!self$inherit_base_provided) + self$compileInfo$inherit$base <- paste("public", + self$inheritNCinternals$cpp_classname) + process_inherit_done <- FALSE + } else { + process_inherit_done <- TRUE + } + }, + process_inherit = function() { + # These are steps that need to be done after connect_inherit + # and require recursion up the inheritance tree, using flags. + if(self$process_inherit_done) return() + if(!is.null(self$inheritQ)) { + self$inheritNCinternals$process_inherit() + self$symbolTable$setParentST(self$inheritNCinternals$symbolTable) + self$allMethodNames <- c(self$allMethodNames_self, self$inheritNCinternals$allMethodNames) + self$allFieldNames <- c(self$allFieldNames_self, self$inheritNCinternals$allFieldNames) + } + self$process_inherit_done <- TRUE } ) ) diff --git a/nCompiler/R/NC_Utils.R b/nCompiler/R/NC_Utils.R index 793b1396..2431786f 100644 --- a/nCompiler/R/NC_Utils.R +++ b/nCompiler/R/NC_Utils.R @@ -96,8 +96,8 @@ NC_find_method <- function(NCgenerator, name, inherits=TRUE) { done <- TRUE } else { if(inherits) { - current_NCgen <- current_NCgen$parent_env$.inherit_obj # same as current_NCgen$get_inherit() if there is inheritance, but get_inherit returns the base class at the top - done <- is.null(current_NCgen) + current_NCgen <- current_NCgen$get_inherit() #parent_env$.inherit_obj # same as current_NCgen$get_inherit() if there is inheritance, but get_inherit returns the base class at the top + done <- !isNCgenerator(current_NCgen) } else done <- TRUE } diff --git a/nCompiler/R/NF.R b/nCompiler/R/NF.R index 045db877..c9d699d1 100644 --- a/nCompiler/R/NF.R +++ b/nCompiler/R/NF.R @@ -119,7 +119,7 @@ nFunction <- function(fun, # in the R function. control <- updateDefaults( list(returnInternals=FALSE, check=get_nOption('check_nFunction'), - changeKeywords = is.null(compileInfo$C_fun), + changeKeywords = is.null(compileInfo$C_fun), # by default, if C_fun is provided, don't change it, but user can override. updateArgPassing = TRUE), control ) diff --git a/nCompiler/R/NF_CompilerClass.R b/nCompiler/R/NF_CompilerClass.R index 77fb9fa3..fc8b3818 100644 --- a/nCompiler/R/NF_CompilerClass.R +++ b/nCompiler/R/NF_CompilerClass.R @@ -54,7 +54,7 @@ NF_CompilerClass <- R6::R6Class( if(is.null(compileInfo)) self$compileInfo <- NFinternals$compileInfo } - if(length(compileInfo$exportName) == 0) + if(length(compileInfo$exportName) == 0) # should this be self$compileInfo$exportName? self$compileInfo$exportName <- name # possibly swap const and compileInfo$isConst, keeping the latter only ##self$const <- const || isTRUE(compileInfo$isConst) @@ -347,9 +347,9 @@ processNFstages <- function(NFcompiler, NFtry({ compilerStage_labelAbstractTypes(NFcompiler, debug) - + # This will only collect nClasses from classGenerator$new() - # Other nClasses will end up in the symbolTable and be + # Other nClasses will end up in the symbolTable and be # collected later. ## NFcompiler$needed_nClasses <- ## c(NFcompiler$needed_nClasses, @@ -392,7 +392,7 @@ processNFstages <- function(NFcompiler, logAfterStage(stageName) } } - + ## insert new lines created by size processing stageName <- 'addInsertions' if (logging) logBeforeStage(stageName) diff --git a/nCompiler/R/NF_InternalsClass.R b/nCompiler/R/NF_InternalsClass.R index 3667b91e..5871f4d1 100644 --- a/nCompiler/R/NF_InternalsClass.R +++ b/nCompiler/R/NF_InternalsClass.R @@ -15,7 +15,8 @@ NF_InternalsClass <- R6::R6Class( isMethod = FALSE, uniqueName = character(), cpp_code_name = character(), - ## template = NULL, replaced with compileInfo$matchDef + ## template = NULL, replaced with default_matchDef + default_matchDef = NULL, code = NULL, RcppPacket = NULL, Rwrapper = NULL, @@ -96,13 +97,7 @@ NF_InternalsClass <- R6::R6Class( ## in the "decoration" system. They could be put in "value" argument. ## Either a named "value" or a ... is in all types. - ## not used until much later - if(is.null(self$compileInfo$opDef)) - self$compileInfo$opDef <- list() - if(is.null(self$compileInfo$opDef$matchDef)) { - self$compileInfo$opDef$matchDef <- Rarguments_2_function(arguments, body = quote({})) - } - # self$template <- Rarguments_2_function(arguments, body = quote({})) ## generateTemplate() + self$default_matchDef <- Rarguments_2_function(arguments, body = quote({})) ## generateTemplate() returnTypeInfo <- nf_extractReturnType(code) returnTypeDecl <- returnTypeInfo$returnType if(is.null(returnTypeDecl)) { diff --git a/nCompiler/R/NF_Utils.R b/nCompiler/R/NF_Utils.R index e3a5a886..e01b2e32 100644 --- a/nCompiler/R/NF_Utils.R +++ b/nCompiler/R/NF_Utils.R @@ -50,8 +50,9 @@ nGet <- function(name, where) { ## switch where to be the generator's parent_env where <- where$parent_env } - if(exists(name, envir = where, inherits = TRUE)) - get(name, envir = where, inherits = TRUE) - else - NULL + get0(name, envir = where) + # if(exists(name, envir = where, inherits = TRUE)) + # get(name, envir = where, inherits = TRUE) + # else + # NULL } diff --git a/nCompiler/R/compile_aaa_operatorLists.R b/nCompiler/R/compile_aaa_operatorLists.R index 610d2a8e..6c2db3ed 100644 --- a/nCompiler/R/compile_aaa_operatorLists.R +++ b/nCompiler/R/compile_aaa_operatorLists.R @@ -94,12 +94,19 @@ updateOperatorDef <- function(ops, field, subfield = NULL, val) { } } +getOperatorField <- function(opDef, field = NULL, subfield = NULL) { + if (is.null(opDef)) return(NULL) + if (is.null(field)) return(opDef) + if (is.null(opDef[[field]])) return(NULL) + if (is.null(subfield)) return(opDef[[field]]) + opDef[[field]][[subfield]] +} + getOperatorDef <- function(op, field = NULL, subfield = NULL) { - opInfo <- get0(op, envir=operatorDefUserEnv) -# opInfo <- operatorDefEnv[[op]] - if (is.null(opInfo) || is.null(field)) return(opInfo) - if (is.null(opInfo[[field]]) || is.null(subfield)) return(opInfo[[field]]) - return(opInfo[[field]][[subfield]]) + opDef <- get0(op, envir=operatorDefUserEnv) +# opDef <- operatorDefEnv[[op]] + if(is.null(opDef)) return(NULL) + getOperatorField(opDef, field, subfield) } assignOperatorDef( @@ -122,7 +129,7 @@ assignOperatorDef( ) assignOperatorDef( - 'NFCALL_', # This is used for non-method nFunctions in normalizeCalls and for any (including method) nFunctions after normalizeCalls + 'nFunction_default', # This is used for non-method nFunctions in normalizeCalls and for any (including method) nFunctions after normalizeCalls list( labelAbstractTypes = list( handler = 'nFunction_or_method_call'), @@ -134,17 +141,29 @@ assignOperatorDef( ) assignOperatorDef( - 'NCMETHOD_', # This is a transient label that only exists within normalizeCalls + 'custom_default', list( - ## labelAbstractTypes = list( - ## handler = 'nFunction_or_method_call'), - normalizeCalls = list( - handler = 'nFunction_or_method_call')#, # becomes NFCALL_ - ## cppOutput = list( - ## handler = 'Generic_nFunction') + labelAbstractTypes = list( + handler = 'custom_call' + # May use nFunction field. + ), + cppOutput = list( + handler = 'AsIs') ) ) +# assignOperatorDef( +# 'NCMETHOD_', # This is a transient label that only exists within normalizeCalls +# list( +# ## labelAbstractTypes = list( +# ## handler = 'nFunction_or_method_call'), +# normalizeCalls = list( +# handler = 'nFunction_or_method_call')#, # becomes NFCALL_ +# ## cppOutput = list( +# ## handler = 'Generic_nFunction') +# ) +# ) + assignOperatorDef( c('dim'), list( diff --git a/nCompiler/R/compile_eigenization.R b/nCompiler/R/compile_eigenization.R index bdc3b19f..adf981f7 100644 --- a/nCompiler/R/compile_eigenization.R +++ b/nCompiler/R/compile_eigenization.R @@ -73,21 +73,27 @@ compile_eigenize <- function(code, return(invisible(NULL)) } - handlingInfo <- getOperatorDef(code$name, "eigenImpl") + opInfo <- check_cachedOpInfo(code, where=auxEnv$where, update=TRUE) + handlingInfo <- getOperatorField(opInfo$opDef, "eigenImpl") +# handlingInfo <- getOperatorDef(code$name, "eigenImpl") # operatorDefEnv[[code$name]] # if(!is.null(opInfo)) { # handlingInfo <- opInfo[["eigenImpl"]] if(!is.null(handlingInfo)) { beforeHandler <- handlingInfo[['beforeHandler']] if(!is.null(beforeHandler)) { - setupExprs <- c(setupExprs, - eval(call(beforeHandler, - code, - symTab, - auxEnv, - workEnv, - handlingInfo), - envir = eigenizeEnv)) + if(is.function(beforeHandler)) + setupExprs <- c(setupExprs, + beforeHandler(code, symTab, auxEnv, workEnv, handlingInfo)) + else + setupExprs <- c(setupExprs, + eval(call(beforeHandler, + code, + symTab, + auxEnv, + workEnv, + handlingInfo), + envir = eigenizeEnv)) # return(if(length(setupExprs) == 0) NULL else setupExprs) } } @@ -110,14 +116,18 @@ compile_eigenize <- function(code, if(!is.null(handlingInfo)) { handler <- handlingInfo[['handler']] if(!is.null(handler)) { - setupExprs <- c(setupExprs, - eval(call(handler, - code, - symTab, - auxEnv, - workEnv, - handlingInfo), - envir = eigenizeEnv)) + if(is.function(handler)) + setupExprs <- c(setupExprs, + handler(code, symTab, auxEnv, workEnv, handlingInfo)) + else + setupExprs <- c(setupExprs, + eval(call(handler, + code, + symTab, + auxEnv, + workEnv, + handlingInfo), + envir = eigenizeEnv)) } } # } diff --git a/nCompiler/R/compile_generateCpp.R b/nCompiler/R/compile_generateCpp.R index 69b22655..ce0a932e 100644 --- a/nCompiler/R/compile_generateCpp.R +++ b/nCompiler/R/compile_generateCpp.R @@ -85,7 +85,18 @@ compile_generateCpp <- function(code, ans[[length(code$args) + 2]] <- paste0(indent, '}') return(ans) } - handler <- getOperatorDef(code$name, "cppOutput", "handler") + # All calls must have valid opInfo or be a core DSL operator + # This compiler stage is called from cppDefs' generate() methods, + # so where$auxEnv is not available. + # This means that any penultimate compiler stages that changed the + # call name to something non-core (and hence potentially the handler) + # must update the cachedOpInfo. + # or we must add a final pass to do so. + # An example is changes made in eigenization, such as inserting `index[`. + # This is a core operator so it will be found in the check_cachedOpInfo with update=TRUE. + opInfo <- check_cachedOpInfo(code, where=baseenv(), update=TRUE, allowFail = TRUE) + handler <- getOperatorField(opInfo$opDef, "cppOutput", "handler") + # handler <- getOperatorDef(code$name, "cppOutput", "handler") # opInfo <- operatorDefEnv[[code$name]] # if(!is.null(opInfo)) { # handlingInfo <- opInfo[["cppOutput"]] @@ -94,10 +105,13 @@ compile_generateCpp <- function(code, if(!is.null(handler)) { if (logging) appendToLog(paste('Calling handler', handler, 'for', code$name)) - res <- eval(call(handler, - code, - symTab), - envir = genCppEnv) + if(is.function(handler)) + res <- handler(code, symTab) + else + res <- eval(call(handler, + code, + symTab), + envir = genCppEnv) if (logging) { appendToLog(paste('Finished handling', handler, 'for', code$name, 'with result:')) @@ -165,10 +179,9 @@ inGenCppEnv( inGenCppEnv( Generic_nFunction <- function(code, symTab) { - innerCode <- code$args[['call']] - cpp_code_name <- code$aux$cpp_code_name + cpp_code_name <- code$aux$cachedOpInfo$obj_internals$cpp_code_name paste0(cpp_code_name, - '(', paste0(unlist(lapply(innerCode$args, + '(', paste0(unlist(lapply(code$args, compile_generateCpp, symTab, asArg = TRUE) ), diff --git a/nCompiler/R/compile_labelAbstractTypes.R b/nCompiler/R/compile_labelAbstractTypes.R index 647dd88a..d8cf8eea 100644 --- a/nCompiler/R/compile_labelAbstractTypes.R +++ b/nCompiler/R/compile_labelAbstractTypes.R @@ -100,7 +100,10 @@ compile_labelAbstractTypes <- function(code, return(invisible(NULL)) } - handlingInfo <- getOperatorDef(code$name, "labelAbstractTypes") + opInfo <- check_cachedOpInfo(code, where=auxEnv$where, update=TRUE) + handlingInfo <- getOperatorField(opInfo$opDef, "labelAbstractTypes") + +# handlingInfo <- getOperatorDef(code$name, "labelAbstractTypes") # opInfo <- operatorDefEnv[[code$name]] # if(!is.null(opInfo)) { # handlingInfo <- opInfo[["labelAbstractTypes"]] @@ -109,8 +112,11 @@ compile_labelAbstractTypes <- function(code, if(!is.null(handler)) { if (logging) appendToLog(paste('Calling handler', handler, 'for', code$name)) - ans <- eval(call(handler, code, symTab, auxEnv, handlingInfo), - envir = labelAbstractTypesEnv) + if(is.function(handler)) + ans <- handler(code, symTab, auxEnv, handlingInfo) + else + ans <- eval(call(handler, code, symTab, auxEnv, handlingInfo), + envir = labelAbstractTypesEnv) nErrorEnv$stateInfo <- character() if (logging) { appendToLog(paste('Finished handling', handler, 'for', code$name)) @@ -403,27 +409,28 @@ inLabelAbstractTypesEnv( } ) +nCompiler:::inLabelAbstractTypesEnv( + custom_call <- + function(code, symTab, auxEnv, handlingInfo) { + recurse <- isTRUE(handlingInfo[['recurse']]) + if(recurse) { + inserts <- recurse_labelAbstractTypes(code, symTab, auxEnv, + handlingInfo) + } + returnType <- handlingInfo[['returnType']] + code$type <- returnType$clone(deep = TRUE) + if(length(inserts) == 0) NULL else inserts + } +) + inLabelAbstractTypesEnv( nFunction_or_method_call <- function(code, symTab, auxEnv, handlingInfo) { - # We have code = NFCALL_(foo(x, y)) - # innerCall if foo(x,y) - # We'll set innerCall$type to symbolNF - # and we'll set code$type to the returnType of foo(x, y) - innerCall <- code$args[['call']] - if(is.null(innerCall)) - stop( - exprClassProcessingErrorMsg( - code, paste('In nFunction_or_method_call: the nFunction (or method) ', - code$name, - ' has NULL content.') - ), call. = FALSE - ) - inserts <- recurse_labelAbstractTypes(innerCall, symTab, auxEnv, + inserts <- recurse_labelAbstractTypes(code, symTab, auxEnv, handlingInfo) - obj_internals <- code$aux$obj_internals - nFunctionName <- code$aux$nFunctionName - innerCall$type <- symbolNF$new(name = nFunctionName) + obj_internals <- code$aux$cachedOpInfo$obj_internals +# nFunctionName <- obj_internals$nFunctionName +# code$aux$symbolNF <- symbolNF$new(name = nFunctionName) returnSym <- obj_internals$returnSym if(is.null(returnSym)) stop( @@ -435,24 +442,6 @@ inLabelAbstractTypesEnv( ) code$type <- returnSym$clone() ## Not sure if a clone is needed, but it seems safer to make one. inserts - - # useArgs <- c(FALSE, rep(TRUE, length(code$args)-1)) - # inserts <- recurse_labelAbstractTypes(code, symTab, auxEnv, - # handlingInfo, useArgs) - # obj_internals <- code$args[[1]]$aux$obj_internals - # nFunctionName <- code$args[[1]]$aux$nFunctionName - # code$args[[1]]$type <- symbolNF$new(name = nFunctionName) - # returnSym <- obj_internals$returnSym - # if(is.null(returnSym)) - # stop( - # exprClassProcessingErrorMsg( - # code, paste('In nFunction_or_method_call: the nFunction (or method) ', - # code$name, - # ' does not have a valid returnType.') - # ), call. = FALSE - # ) - # code$type <- returnSym$clone() ## Not sure if a clone is needed, but it seems safer to make one. - # invisible(NULL) } ) diff --git a/nCompiler/R/compile_normalizeCalls.R b/nCompiler/R/compile_normalizeCalls.R index 14576642..020b7f43 100644 --- a/nCompiler/R/compile_normalizeCalls.R +++ b/nCompiler/R/compile_normalizeCalls.R @@ -6,6 +6,14 @@ inNormalizeCallsEnv <- function(expr) { eval(expr, envir = normalizeCallsEnv) } +# This stage does two things: +# 1. It normalizes the arguments of calls according to a matchDef, if found via an opDef. +#. This includes separating the compileTime arguments into the aux list of the call. +# 2. It caches information about the nFunction or nClass method being called, +#. if relevant. Future stages should check if the call name still matches what was cached, +#. they should look up from scratch. This is because handlers can change a name and then +#. then opDef should be found in future steps for the new name. +#. The caching also allows future stages to be more efficient in determining an nFunction or nClass method case. compile_normalizeCalls <- function(code, symTab, auxEnv) { @@ -26,89 +34,51 @@ compile_normalizeCalls <- function(code, } return(NULL) } - opInfo <- getOperatorDef(code$name) - ## Check for nFunctions or nClass methods (also nFunctions) - if(is.null(opInfo)) { - obj <- NULL - if(isNCgenerator(auxEnv$where)) {## We are in a class method (by direct call within another class method, no `$` involved) - obj <- NC_find_method(auxEnv$where, code$name, inherits=TRUE) - if(!is.null(obj)) { - code$aux$obj_internals <- NFinternals(obj) - if(isNF(obj)) { - opInfo <- operatorDefEnv[['NCMETHOD_']] - } else { - stop(exprClassProcessingErrorMsg(code, - paste0('method ', code$name, 'is being called, but it is not a nFunction.')), - call. = FALSE) - } - } - } - ## Next we check if code$name exists in the where. - ## Note that if we are in a method, auxEnv$where will be the - ## generator, which is an environment. However, we need - ## to use nGet instead of exists and get. - if(is.null(obj)) { - obj <- nGet(code$name, where = auxEnv$where) - ## An nFunction will be transformed to - ## have code$name 'NFCALL_' during simpleTransformations, - ## but that hasn't happened yet, so we manually use it here. - ## To-do: This could be cleaned up by either making that change here - ## when first detected or making a separate compiler stage just for that. - if(!is.null(obj)) { - if(isNF(obj)) { - code$aux$obj_internals <- NFinternals(obj) - opInfo <- operatorDefEnv[['NFCALL_']] - uniqueName <- NFinternals(obj)$uniqueName - if(length(uniqueName)==0) - stop( - exprClassProcessingErrorMsg( - code, - paste0('nFunction ', code$name, 'is being called, ', - 'but it is malformed because it has no internal name.')), - call. = FALSE) - if(is.null(auxEnv$needed_nFunctions[[uniqueName]])) { - ## We could put the nFunction itself in the needed_nFunctions list, - ## but we do not as a way to avoid having many references to R6 objects - ## in a blind attempt to facilitate garbage collection based on past experience. - ## Instead, we provide what is needed to look up the nFunction again later. - auxEnv$needed_nFunctions[[uniqueName]] <- list(code$name, auxEnv$where) - } - } else - stop(exprClassProcessingErrorMsg( + # If we are in an nClass: + # look for opDef in the nFunction method; + # next look for an opDef in the nClass, which might not even correspond to a method. + # If we are in an nFunction (not a method): + # look for an opDef in the nFunction. + # If we are in an nFunction or nClass method and no opDef has been found, + #. assign nFunction_default to be the opDef. + # Next look for a global opDef, which will look first for user-defined opDefs. + # If no opDef is defined anywhere, it is an error. + # If an opDef is found, check for a handler. + # If a handler is found, call it. + # If no handler is found, look for a matchDef and normalize the arguments by default. + # + # What gets cached in the aux of the exprClass for the call: + # cachedOpInfo = list(opDef, name, obj_internals, case) + # We defer: uniqueName, cpp_code_name + cachedOpInfo <- update_cachedOpInfo(code, auxEnv$where) + if(cachedOpInfo$case == "nFunction") { + uniqueName <- cachedOpInfo$obj_internals$uniqueName + if(length(uniqueName)==0) + stop( + exprClassProcessingErrorMsg( code, - paste0(code$name, ' is being used as a function, but it is not a nFunction.')), - call. = FALSE) - } + paste0('nFunction ', code$name, 'is being called, ', + 'but it is malformed because it has no internal name.')), + call. = FALSE) + if(is.null(auxEnv$needed_nFunctions[[uniqueName]])) { + ## We could put the nFunction itself in the needed_nFunctions list, + ## but we do not as a way to avoid having many references to R6 objects + ## in a blind attempt to facilitate garbage collection based on past experience. + ## Instead, we provide what is needed to look up the nFunction again later. + auxEnv$needed_nFunctions[[uniqueName]] <- list(code$name, auxEnv$where) } } - # There is also the nFunctionRef to think about. That is a bit more of a type however, not a call. - if(!is.null(opInfo)) { - matchDef <- opInfo[["matchDef"]] - if(!is.null(matchDef)) { - matched_code <- exprClass_put_args_in_order(matchDef, code, opInfo$compileArgs) - code <- replaceArgInCaller(code, matched_code) - } - handlingInfo <- opInfo[["normalizeCalls"]] - if(!is.null(handlingInfo)) { - handler <- handlingInfo[['handler']] - if(!is.null(handler)) { - if (logging) - appendToLog(paste('Calling handler', handler, 'for', code$name)) - ans <- eval(call(handler, code, symTab, auxEnv, handlingInfo), - envir = normalizeCallsEnv) - nErrorEnv$stateInfo <- character() - if (logging) { - appendToLog(paste('Finished handling', handler, 'for', code$name)) - logAST(code, paste('Resulting AST for', code$name), showImpl = FALSE) - } - return(ans) - } - } + + opDef <- cachedOpInfo$opDef + matchDef <- opDef[["matchDef"]] + if(is.null(matchDef)) + matchDef <- cachedOpInfo$obj_internals$default_matchDef + if(!is.null(matchDef)) { + exprClass_put_args_in_order(matchDef, code, opDef$compileArgs) + # code <- replaceArgInCaller(code, matched_code) } - # default behavior if there is no handler (which will be for many or most calls) normalizeCallsEnv$recurse_normalizeCalls(code, symTab, auxEnv, handlingInfo) } - # Where to put a generic recursion call? nErrorEnv$stateInfo <- character() invisible(NULL) } @@ -131,48 +101,145 @@ inNormalizeCallsEnv( } ) -## inNormalizeCallsEnv( -## skip <- -## function(code, symTab, auxEnv, handlingInfo, -## useArgs = rep(TRUE, length(code$args))) { -## NULL -## } -## ) +check_cachedOpInfo <- function(code, where=NULL, update=TRUE, allowFail=FALSE) { + cachedOpInfo <- code$aux$cachedOpInfo + up_to_date <- !is.null(cachedOpInfo) && cachedOpInfo$name == code$name + if(up_to_date) return(cachedOpInfo) + if(update) { + if(!is.environment(where)) + stop( + exprClassProcessingErrorMsg( + code, + "Internal error: check_cachedOpInfo called with update=TRUE but non-environment 'where' argument.") + , call. = FALSE) + update_cachedOpInfo(code, where, allowFail=allowFail) + } else { + NULL + } +} -inNormalizeCallsEnv( - convert_nFunction_or_method_AST <- - function(code) { - nFunctionName <- code$name - obj_internals <- code$aux$obj_internals - code$aux$obj_internals <- NULL - opDef <- obj_internals$compileInfo$opDef - matched_code <- exprClass_put_args_in_order(def=opDef$matchDef, expr=code, compileArgs = opDef$compileArgs) - code <- replaceArgInCaller(code, matched_code) - ## Note that the string `NFCALL_` matches the operatorDef entry. - ## Therefore the change-of-name here will automatically trigger use of - ## the 'NFCALL_' operatorDef in later stages. - newExpr <- wrapInExprClass(code, 'NFCALL_', "call") - # code$name <- 'NFCALL_' - cpp_code_name <- obj_internals$cpp_code_name - # fxnNameExpr <- exprClass$new(name = cpp_code_name, isName = TRUE, - # isCall = FALSE, isLiteral = FALSE, isAssign = FALSE) - newExpr$aux$obj_internals <- obj_internals - # newExpr$aux$nFunctionName <- nFunctionName - newExpr$aux$cpp_code_name <- cpp_code_name - ## We may need to add content to this symbol if - ## necessary for later processing steps. - ## insertArg(code, 1, fxnNameExpr, "FUN_") - obj_internals <- NULL - invisible(NULL) +update_cachedOpInfo <- function(code, where, allowFail=FALSE) { + opDef <- NULL + cachedOpInfo <- list(name = code$name) + ## Note that if we are in a method, auxEnv$where will be the + ## generator, which is an environment. + is_NCgenerator <- isNCgenerator(where) + if(is.null(opDef)) { # for good form, in case there a prior case is added later + # First look for an nClass method. + obj <- NULL + if(is.null(obj)) { + if(is_NCgenerator) {## We are in a class method (by direct call within another class method, no `$` involved) + # NC_find_method checks the class and its parents + obj <- NC_find_method(where, code$name, inherits=TRUE) + if(!is.null(obj)) { + if(isNF(obj)) { + cachedOpInfo$case <- "nClass method" # possibly disambiguate method from keyword + } else { + stop(exprClassProcessingErrorMsg(code, + paste0('method ', code$name, 'is being called, but it is not a nFunction.')), + call. = FALSE) + } + } + } } -) - -inNormalizeCallsEnv( - nFunction_or_method_call <- - function(code, symTab, auxEnv, handlingInfo) { - recurse_normalizeCalls(code, symTab, auxEnv, - handlingInfo) - convert_nFunction_or_method_AST(code) - NULL + if(is.null(obj)) { + # Next look for an nFunction (that is not a method) + # + # N.B. nGet follows lexical scoping from either an nFunction or nClass + # whatever "where" is. (In the nClass case, + # this is different than finding a method as above. + # Instead it is following scoping to find nFunctions.) + obj <- nGet(code$name, where = where) + if(!is.null(obj)) { + if(isNF(obj)) { + # There is no error trapping if obj is not an nFunction, because + # it could be simply an R function, since nGet (via get0) may traverse up to R_GlobalEnv. + cachedOpInfo$case <- "nFunction" + } else { + obj <- NULL # reset to NULL if not an nFunction + } + } } -) + if(!is.null(obj)) { + # We found an nFunction object that is either a method or not. + cachedOpInfo$obj_internals <- NFinternals(obj) + opDef <- cachedOpInfo$obj_internals$compileInfo$opDef # might be NULL + } + } + if(is.null(opDef)) { + ## At this point, we have not found an nFunction or nClass method. + if(is_NCgenerator) { + opDef <- NCinternals(where)$compileInfo$opDefs[[code$name]] + if(!is.null(opDef)) { + cachedOpInfo$case <- "nClass method" # this could be a pure keyword or an nFunction with opDef provided at the nClass level + # a pure keyword will have obj_internals == NULL, providing a way to + # tell these cases apart later if necessary. + } + } + } + if(is.null(opDef)) { + if(!is.null(cachedOpInfo$case)) { + # We found an nFunction (method or not) but it did not have an opDef + # so here we insert the default opDef + if(cachedOpInfo$case == "nFunction" || cachedOpInfo$case == "nClass method") { + opDef <- getOperatorDef("nFunction_default") + } + } + } + if(is.null(opDef)) { + # We haven't found any form of nFunction or custom nClass keyword, + # so look for globally defined operators (e.g. "+" and all the basic DSL keywords) + opDef <- getOperatorDef(code$name) + if(!is.null(opDef)) { + cachedOpInfo$case <- "global" + } + } + if(is.null(opDef)) { + if(!allowFail) + stop(exprClassProcessingErrorMsg( + code, + paste0('No operator definition found for ', code$name, '.')), + call. = FALSE) + } + cachedOpInfo$opDef <- opDef + code$aux$cachedOpInfo <- cachedOpInfo + cachedOpInfo +} + +# inNormalizeCallsEnv( +# convert_nFunction_or_method_AST <- +# function(code) { +# nFunctionName <- code$name +# obj_internals <- code$aux$obj_internals +# code$aux$obj_internals <- NULL +# opDef <- obj_internals$compileInfo$opDef +# matched_code <- exprClass_put_args_in_order(def=opDef$matchDef, expr=code, compileArgs = opDef$compileArgs) +# code <- replaceArgInCaller(code, matched_code) +# ## Note that the string `NFCALL_` matches the operatorDef entry. +# ## Therefore the change-of-name here will automatically trigger use of +# ## the 'NFCALL_' operatorDef in later stages. +# newExpr <- wrapInExprClass(code, 'NFCALL_', "call") +# # code$name <- 'NFCALL_' +# cpp_code_name <- obj_internals$cpp_code_name +# # fxnNameExpr <- exprClass$new(name = cpp_code_name, isName = TRUE, +# # isCall = FALSE, isLiteral = FALSE, isAssign = FALSE) +# newExpr$aux$obj_internals <- obj_internals +# # newExpr$aux$nFunctionName <- nFunctionName +# newExpr$aux$cpp_code_name <- cpp_code_name +# ## We may need to add content to this symbol if +# ## necessary for later processing steps. +# ## insertArg(code, 1, fxnNameExpr, "FUN_") +# obj_internals <- NULL +# invisible(NULL) +# } +# ) + +# inNormalizeCallsEnv( +# nFunction_or_method_call <- +# function(code, symTab, auxEnv, handlingInfo) { +# recurse_normalizeCalls(code, symTab, auxEnv, +# handlingInfo) +# convert_nFunction_or_method_AST(code) +# NULL +# } +# ) diff --git a/nCompiler/R/compile_processAD.R b/nCompiler/R/compile_processAD.R index dbc78753..c85c43cb 100644 --- a/nCompiler/R/compile_processAD.R +++ b/nCompiler/R/compile_processAD.R @@ -52,7 +52,9 @@ compile_processAD <- function(code, return(invisible(NULL)) } - handlingInfo <- getOperatorDef(code$name, "processAD") + opInfo <- check_cachedOpInfo(code, where=auxEnv$where, update=TRUE) + handlingInfo <- getOperatorField(opInfo$opDef, "processAD") +# handlingInfo <- getOperatorDef(code$name, "processAD") # opInfo <- operatorDefEnv[[code$name]] # if(!is.null(opInfo)) { # handlingInfo <- opInfo[["processAD"]] diff --git a/nCompiler/R/compile_simpleTransformations.R b/nCompiler/R/compile_simpleTransformations.R index 67178014..2598bfa6 100644 --- a/nCompiler/R/compile_simpleTransformations.R +++ b/nCompiler/R/compile_simpleTransformations.R @@ -24,7 +24,9 @@ compile_simpleTransformations <- function(code, } } - handlingInfo <- getOperatorDef(code$name, opInfoName) + opInfo <- check_cachedOpInfo(code, where=auxEnv$where, update=TRUE, allowFail = TRUE) + handlingInfo <- getOperatorField(opInfo$opDef, opInfoName) + # handlingInfo <- getOperatorDef(code$name, opInfoName) # opInfo <- getOperatorDef(code$name) #operatorDefEnv[[code$name]] # if(!is.null(opInfo)) { # handlingInfo <- opInfo[[opInfoName]] diff --git a/nCompiler/R/cppDefs_core.R b/nCompiler/R/cppDefs_core.R index dd703c31..83917b6f 100644 --- a/nCompiler/R/cppDefs_core.R +++ b/nCompiler/R/cppDefs_core.R @@ -14,7 +14,7 @@ # duplication. internalCppDefs are assumed to be unique.) # cppDefinitionClass has getter methods but no generate method. cppDefinitionClass <- R6::R6Class( - classname = 'cppDefinitionClass', + classname = 'cppDefinitionClass', portable = FALSE, public = list( filename = character(), @@ -132,10 +132,10 @@ cppGlobalObjectClass <- R6::R6Class( ## For globals for class static members, we want no declaration ## Otherwise we want a declaration and put extern in front if(staticMembers & declaration) return(character()) - output <- paste0(generateAll(symbolTable$getSymbols(), + output <- paste0(generateAll(symbolTable$getSymbols(), declaration = declaration), ';') - if(declaration) + if(declaration) output <- paste("extern ", output) output } @@ -186,7 +186,7 @@ cppNamespaceClass <- R6::R6Class( ## C++ class object. ## A class is like a namespace with inheritance -## At the moment everything is public. +## At the moment everything is public. ## This class can build cppFunction objects for a generator function and a finalizer function ## The generator can be called via .Call to return an external pointer to a new object of the class ## The finalizer is the finalizer assigned to the object when the external pointer is made @@ -206,7 +206,7 @@ buildSEXPgenerator_impl <- function(self) { substitute(cppLiteral(RETURNLINE), list(RETURNLINE = as.character(returnLine))) ) - + allCode <- putCodeLinesInBrackets(allCodeList) allCode <- nParse(allCode) self$internalCppDefs[["SEXPgenerator"]] <- @@ -236,7 +236,7 @@ build_set_nClass_env_impl <- function(self) { substitute(cppLiteral(SETTERLINE), list(SETTERLINE = as.character(setterLine))) ) - + allCode <- putCodeLinesInBrackets(allCodeList) allCode <- nParse(allCode) args <- symbolTableClass$new() @@ -314,8 +314,13 @@ addGenericInterface_impl <- function(self) { self$addInheritance(paste0("public genericInterfaceC<", name, ">")) -# self$Hincludes <- c(self$Hincludes, -# nCompilerIncludeFile("nCompiler_class_interface.h")) + # It is ok to have multiple virtual inheritance from genericInterfaceBaseC, + # but we clean it up here for slightly simpler code. + if("virtual public genericInterfaceBaseC" %in% self$inherit) { + self$inherit <- self$inherit[-which(self$inherit == "virtual public genericInterfaceBaseC")] + } + # self$Hincludes <- c(self$Hincludes, + # nCompilerIncludeFile("nCompiler_class_interface.h")) self$Hpreamble <- c(self$Hpreamble, "#define NCOMPILER_USES_NCLASS_INTERFACE", "#define USES_NCOMPILER") @@ -383,8 +388,8 @@ addGenericInterface_impl <- function(self) { fieldClassNames <- c(fieldClassNames, rep(NCint$cpp_classname, length(new_cpp_fieldNames))) # - current_NCgen <- current_NCgen$parent_env$.inherit_obj # same as current_NCgen$get_inherit() if there is inheritance, but get_inherit returns the base class at the top - done <- is.null(current_NCgen) + current_NCgen <- current_NCgen$get_inherit() #$parent_env$.inherit_obj # same as current_NCgen$get_inherit() if there is inheritance, but get_inherit returns the base class at the top + done <- !isNCgenerator(current_NCgen) } if(iOut > 1) { methodsContent <- paste0("method(\"", @@ -577,7 +582,7 @@ cppClassClass <- R6::R6Class( addInheritance = function(newI) { inheritance <<- c(inheritance, newI) }, - ##addAncestors = function(newI) ancestors <<- c(ancestors, newI), + ##addAncestors = function(newI) ancestors <<- c(ancestors, newI), ##setPrivate = function(name) private[[name]] <<- TRUE, generate = function(declaration = FALSE, ...) { if(declaration) { @@ -605,7 +610,7 @@ cppClassClass <- R6::R6Class( if(length(memberCppDefs) > 0) { output <- generateAll(memberCppDefs, scopes = name) } else { - output <- "" + output <- "" } } unlist(output) @@ -624,6 +629,21 @@ cppClassClass <- R6::R6Class( if(interface) { addGenericInterface_impl(self) add_obj_hooks_impl(self) + } else { + # Ensure inheritance from genericInterfaceBaseC so our custom Exporter in C++ + # can always dynamic_pointer_cast to shared_ptr. + if(!("virtual public genericInterfaceBaseC" %in% self$inherit)) { + self$addInheritance("virtual public genericInterfaceBaseC") + } + # These will always end up included and possibly multiple times, + # so it's a bit sloppy but not worth cleaning up for now. + self$Hpreamble <- c(self$Hpreamble, + "#define NCOMPILER_USES_NCLASS_INTERFACE", + "#define USES_NCOMPILER") + self$CPPpreamble <- c(self$CPPpreamble, + "#define NCOMPILER_USES_NCLASS_INTERFACE", + "#define USES_NCOMPILER") + } # The only case that would omit interface calls is generated predefined code. if(interfaceCalls) @@ -709,7 +729,7 @@ cppCodeBlockClass <- R6::R6Class( } else useSymTab <- self$symbolTable if(isTRUE(self$cppADCode)) - recurseSetCppADExprs(self$code, TRUE) + recurseSetCppADExprs(self$code, TRUE) outputCppCode <- c(outputCppCode, compile_generateCpp(self$code, useSymTab, @@ -717,7 +737,7 @@ cppCodeBlockClass <- R6::R6Class( showBracket = FALSE) ) if(isTRUE(self$cppADCode)) - recurseSetCppADExprs(self$code, FALSE) + recurseSetCppADExprs(self$code, FALSE) } else { stop('code in generate() is not of the right type.', call. = FALSE) @@ -943,8 +963,13 @@ generateFunctionHeader <- function(self, abstract_text <- character() externC_text <- character() if(declaration) { - virtual_text <- compileInfo$virtual - if(is.null(virtual_text)) virtual_text <- if(isTRUE(self$virtual)) 'virtual ' else character() + virtual_text <- character() + if(is.character(compileInfo$virtual)) + virtual_text <- compileInfo$virtual + else if(isTRUE(self$virtual)) + virtual_text <- 'virtual ' + # virtual_text <- compileInfo$virtual + # if(is.null(virtual_text)) virtual_text <- if(isTRUE(self$virtual)) 'virtual ' else character() isAbstract <- compileInfo$abstract if(is.null(isAbstract)) isAbstract <- self$abstract @@ -997,6 +1022,7 @@ generateFunctionHeader <- function(self, externC_text, template_text, static_text, + virtual_text, returnType_text, scopes_name_text, args_text, diff --git a/nCompiler/R/cppDefs_nClass.R b/nCompiler/R/cppDefs_nClass.R index 71851561..1e967cee 100644 --- a/nCompiler/R/cppDefs_nClass.R +++ b/nCompiler/R/cppDefs_nClass.R @@ -65,7 +65,7 @@ cpp_nClassBaseClass <- R6::R6Class( ## SEXPmemberInterfaceFuns = 'ANY', ## List of SEXP interface functions, one for each member function Compiler = NULL, ##nimCompProc = 'ANY', ## nfProcessing or nlProcessing object, needed to get the member data symbol table post-compilation - + ##Rgenerator = 'ANY' , ## function to generate and wrap a new object from an R object ##CmultiInterface = 'ANY', ## object for interfacing multiple C instances when a top-level interface is not needed built = NULL, @@ -140,7 +140,7 @@ cpp_nClassClass <- R6::R6Class( public = list( # ctor_names = character(), # dtor_names = character(), - ##NC_Compiler = NULL, + ##NC_Compiler = NULL, ##parentsSizeAndDims = 'ANY', getInternalDefs = function() { super$getInternalDefs() @@ -274,7 +274,7 @@ cpp_nClassClass <- R6::R6Class( #self$compileInfo$interface controls whether to inherit from base classes for interfacing #It would be wierd to do the former without the latter, # so unless/until we get a case where that behavior is needed - # we will prevent it. + # we will prevent it. # The option interface=FALSE will be called in the case of a predefined, # when building the predefined, when first we do buildAll(interfaceCalls=FALSE). # This might build the interface but will not build the calls. diff --git a/nCompiler/R/nCompile.R b/nCompiler/R/nCompile.R index ba153075..df8ae729 100644 --- a/nCompiler/R/nCompile.R +++ b/nCompiler/R/nCompile.R @@ -306,6 +306,18 @@ nCompile <- function(..., interfaces <- interfaces[names(units)] unitTypes <- get_nCompile_types(units) + + # We defer processing of nClass inheritance until compile time to allow nClass + # to be called with inherit = some_nClass before some_nClass is defined. + for(i in seq_along(units)) { + if(unitTypes[i] == "nCgen") + NCinternals(units[[i]])$connect_inherit() + } + for(i in seq_along(units)) { + if(unitTypes[i] == "nCgen") + NCinternals(units[[i]])$process_inherit() + } + # set up exportNames and returnNames # exportNames will be from names(units) if named in the call or there is no exportName in the NF or NC compileInfo # Otherwise (i.e. no name provided in call and there is an exportName in the object def), use the exportName in the object def (compileInfo) diff --git a/nCompiler/R/nimbleModels.R b/nCompiler/R/nimbleModels.R new file mode 100644 index 00000000..8c7ee38a --- /dev/null +++ b/nCompiler/R/nimbleModels.R @@ -0,0 +1,132 @@ +# Here we are drafting support for new implementation of nimbleModels. +# This should eventually live in a separate package, but for now it is easier to draft here. +# +# see test-nimbleModel too. + +modelBase_nClass <- nClass( + classname = "modelBase_nClass", + Cpublic = list( + hw = nFunction( + name = "hw", + function() {cppLiteral('Rprintf("modelBase_nClass hw (should not see this)\\n");')}, + compileInfo = list(virtual=TRUE) + ), + bye = nFunction( + name = "bye", + function() {cppLiteral('Rprintf("modelBase_nClass hw (should not see this)\\n");')}, + compileInfo = list(virtual=TRUE) + ) + ), + compileInfo=list(interface="none", + createFromR = FALSE) +) + +makeModel_nClass <- function(varInfo) { + # varInfo will be a list (names not used) of name, nDim, sizes. + CpublicModelVars <- varInfo |> lapply(\(x) paste0("numericArray(nDim=",x$nDim,")")) + names(CpublicModelVars) <- varInfo |> lapply(\(x) x$name) |> unlist() + opDefs <- list( + base_hw = getOperatorDef("custom_call"), + setup_node_mgmt = getOperatorDef("custom_call") + ) + opDefs$base_hw$returnType <- nCompiler:::argType2symbol(quote(void())) + opDefs$base_hw$labelAbstractTypes$recurse <- FALSE + opDefs$setup_node_mgmt$returnType <- nCompiler:::argType2symbol(quote(void())) + opDefs$setup_node_mgmt$labelAbstractTypes$recurse <- FALSE + + CpublicMethods <- list( + hw = nFunction( + name = "hw", + function() {cppLiteral('Rprintf("hw\\n");')} + ), + # base_hw = nFunction( + # name = "base_hw", + # function() {cppLiteral('modelBaseClass::base_hw();')} + # ), + call_base_hw = nFunction( + name = "call_base_hw", + function() {base_hw()} + ), + call_setup_node_mgmt = nFunction( + name = "setup_node_mgmt", + function() {setup_node_mgmt()} + ), + set_from_list = nFunction( + name = "set_from_list", + function(Rlist = 'RcppList') {cppLiteral('modelClass_::set_from_list(Rlist);')} + ), + resize_from_list = nFunction( + name = "resize_from_list", + function(Rlist = 'RcppList') {cppLiteral('modelClass_::resize_from_list(Rlist);')} + ) + ) + CpublicNodeFuns <- list( + beta_node = 'node_dnorm()' + ) + CpublicCtor <- list( + mymodel = nFunction( + function(){}, + compileInfo = list(constructor=TRUE, + initializers = c('nCpp("beta_node(new node_dnorm(mu, beta, 1))")')) + ) + ) + ans <- substitute( + nClass( + classname = "mymodel", + # inherit = modelBase_nClass, + compileInfo = list(opDefs = OPDEFS, + inherit = list(base = "public modelClass_"), + Hincludes = ""), + Cpublic = CPUBLIC + ), + list(OPDEFS = opDefs, + CPUBLIC = c(CpublicNodeFuns, CpublicModelVars, CpublicCtor, CpublicMethods)) + ) + eval(ans, envir = parent.frame()) +} + +make_node_fun <- function(varInfo) { + # varInfo will be a list (names not used) of name, nDim, sizes. + foo <- \(x) nCompiler:::symbolCppVar$new(baseType = nCompiler:::symbolBasic$new(type="double", nDim=x$nDim, name="")$genCppVar()$generate(), + ref=TRUE, const=TRUE) + typeList <- varInfo |> lapply(foo) + names(typeList) <- varInfo |> lapply(\(x) x$name) |> unlist() + +# baseTypeStrings <- varInfo |> lapply(\(x) paste0("numericArray(nDim=",x$nDim,")")) |> unlist() +# typeStrings <- paste0('CppVar(baseType=argType2Cpp("',baseTypeStrings,'"),ref=TRUE,const=TRUE)') +# typeList <- typeStrings |> lapply(nMakeType) +# names(typeList) <- names(varInfo) + #names(Cpublic) <- varInfo |> lapply(\(x) x$name) |> unlist() + CpublicVars <- names(typeList) |> lapply(\(x) eval(substitute(quote(T(typeList$NAME)), + list(NAME=as.name(x))))) + names(CpublicVars) <- names(typeList) + + ctorArgNames <- paste0(names(typeList), '_') + initializersList <- paste0(names(typeList), '(', ctorArgNames ,')') + initFun <- function(){} + formals(initFun) <- structure(as.pairlist(CpublicVars), names = ctorArgNames) + +# This was a prototype +# mu_type <- list(a = nMakeType('CppVar(baseType = argType2Cpp("numericVector"), ref=TRUE, const=TRUE)')) + node_dnorm <- substitute( + nClass( + classname = "node_dnorm", + Cpublic = CPUBLIC, + compileInfo = list(createFromR = FALSE, + inherit = list(base = "public nodeFunctionClass_"), + Hincludes = "") + ), + list(CPUBLIC = c( + list( +# mu2 = quote(T(mu_type$a)), +# mean = quote(ref('numericScalar')), + node_dnorm = nFunction( + initFun, #function(mu_ = T(mu_type$a)) {}, + compileInfo = list(constructor=TRUE, initializers = initializersList) #list('mu2(mu_)')) + ) + ), + CpublicVars + ))) + eval(node_dnorm) +} +#test <- nCompiler:::argType2symbol('CppVar(baseType = argType2Cpp("numericVector"), ref=TRUE, const=TRUE)') diff --git a/nCompiler/inst/include/nCompiler/nClass_interface/generic_class_interface.h b/nCompiler/inst/include/nCompiler/nClass_interface/generic_class_interface.h index 220f0d37..f8fcde5e 100644 --- a/nCompiler/inst/include/nCompiler/nClass_interface/generic_class_interface.h +++ b/nCompiler/inst/include/nCompiler/nClass_interface/generic_class_interface.h @@ -44,9 +44,22 @@ methods\ ; +class accessor_base; + // Base class for interfaces to nimble classes class genericInterfaceBaseC { public: + typedef std::map name2index_type; + typedef std::map > name2access_type; + typedef std::pair > name_access_pair; + + virtual const name2access_type& get_name2access() const{ + std::cout<<"Error: you should be in a derived genericInterfaceC class for get_name2access"< ETaccess(genericInterfaceBaseC *) { std::cout<<"Error: you should be in access for a derived accessor class"< getInterfacePtr(genericInterfaceBaseC *intBasePtr) { + return nullptr; + } virtual ~accessor_base(){} }; diff --git a/nCompiler/inst/include/nCompiler/nClass_interface/nClass_factory.h b/nCompiler/inst/include/nCompiler/nClass_interface/nClass_factory.h index 5d449e99..fc134204 100644 --- a/nCompiler/inst/include/nCompiler/nClass_interface/nClass_factory.h +++ b/nCompiler/inst/include/nCompiler/nClass_interface/nClass_factory.h @@ -14,6 +14,10 @@ class shared_ptr_holder_base { std::cout<<"Error: you should be in a derived shared_ptr_holder class get_ptr(). This is the base class."< get_interfaceBase_shared_ptr() const { + std::cout<<"Error: you should be in a derived shared_ptr_holder class get_interfaceBase_shared_ptr(). This is the base class."<()); + }; virtual shared_ptr_holder_base* make_shared_ptr_holder()=0; virtual ~shared_ptr_holder_base() { // std::cout<<"destructing shared_ptr_holder_base"<(dynamic_cast(sp_.get())); } + std::shared_ptr get_interfaceBase_shared_ptr() const { + return std::static_pointer_cast(sp_); + } shared_ptr_holder_base* make_shared_ptr_holder() { std::cout<<"making new shared_ptr_holder_base"<(x); // } +// maybe put these inside the class or namespace. +template +struct is_shared_ptr : std::false_type {}; +template +struct is_shared_ptr> : std::true_type {}; // // end ETaccess +template +struct shared_ptr_element_type {using type = void;}; + +template +struct shared_ptr_element_type> {using type = U;}; + // Interface to class T. template -class genericInterfaceC : public genericInterfaceBaseC { +class genericInterfaceC : virtual public genericInterfaceBaseC { public: ~genericInterfaceC() { #ifdef SHOW_DESTRUCTORS @@ -299,10 +310,13 @@ class genericInterfaceC : public genericInterfaceBaseC { template class accessor_class : public accessor_base { public: - typedef P T2::*ptrtype; + typedef P T2::*ptrtype; // T2 will only be T or a base class of T. ptrtype ptr; - accessor_class(ptrtype ptr) : ptr(ptr) {}; + static constexpr bool P_is_shared_ptr = is_shared_ptr

::value; + using shared_ptr_element = typename shared_ptr_element_type

::type; + static constexpr bool shared_ptr_element_is_polymorphic = std::is_polymorphic_v; + accessor_class(ptrtype ptr) : ptr(ptr) {}; SEXP get(const genericInterfaceBaseC *intBasePtr) const { #ifdef SHOW_FIELDS std::cout<<"in derived get"< ans( new ETaccessor

( dynamic_cast(intBasePtr)->*ptr ) ); return ans; } + std::shared_ptr getInterfacePtr(genericInterfaceBaseC *intBasePtr) { + if constexpr(P_is_shared_ptr) { + if constexpr (shared_ptr_element_is_polymorphic) { + return std::dynamic_pointer_cast(dynamic_cast(intBasePtr)->*ptr); + } else { + return std::static_pointer_cast(dynamic_cast(intBasePtr)->*ptr); + } + } + return nullptr; + } }; // static maps from character names static int name_count; - typedef std::map name2index_type; +// typedef std::map name2index_type; static name2index_type name2index; - typedef std::map > name2access_type; - typedef std::pair > name_access_pair; + // typedef std::map > name2access_type; + // typedef std::pair > name_access_pair; static name2access_type name2access; + const name2access_type& get_name2access() const{ + return name2access; + } + // Enter a new (name, member ptr) pair to static maps. template static name_access_pair field(std::string name, P T2::*ptr) { diff --git a/nCompiler/inst/include/nCompiler/nClass_interface/post_Rcpp/nCompiler_model_base_devel.h b/nCompiler/inst/include/nCompiler/nClass_interface/post_Rcpp/nCompiler_model_base_devel.h index bc3365b4..6c8f8938 100644 --- a/nCompiler/inst/include/nCompiler/nClass_interface/post_Rcpp/nCompiler_model_base_devel.h +++ b/nCompiler/inst/include/nCompiler/nClass_interface/post_Rcpp/nCompiler_model_base_devel.h @@ -5,13 +5,50 @@ // but the quickest way to get into the guts of development is to // put it here. +class modelBaseClass_ { + public: + double v; + virtual ~modelBaseClass_() {}; +}; + +class nodeFunctionClassBase_ { + public: + double v; + virtual ~nodeFunctionClassBase_() {}; +}; + template -class modelBaseClass { +class modelClass_ : public modelBaseClass_ { public: double v; void base_hw() { Rprintf("base hw\n"); } + std::vector< std::shared_ptr > nodeFunctionPtrs; + // NEXT STEPS: record the shared_ptrs and indices for future use. + // build up calculate at level of node and then model. + void setup_node_mgmt() { + Derived *self = static_cast(this); + const auto& name2access = self->get_name2access(); + size_t n = name2access.size(); + Rprintf("There are %d member variables indexed:\n", (int)n); + auto i_n2a = name2access.begin(); + auto end_n2a = name2access.end(); + for(; i_n2a != end_n2a; ++i_n2a) { + // This compiles and runs but does not successfully identify any genericInterfaceBaseC members. + std::shared_ptr ptr = i_n2a->second->getInterfacePtr(dynamic_cast(self)); + bool got_one = (ptr != nullptr); + if(got_one) + Rprintf("HOORAY: field %s is genericInterfaceBaseC\n", i_n2a->first.c_str()); + else + Rprintf("field %s is NOT a genericInterfaceBaseC\n", i_n2a->first.c_str()); + } + } + /* + mv name2access typedefs to the base class. + create virtual accessor function for name2access. + check on what "Derived" is here. + */ void set_from_list(Rcpp::List Rlist) { Rcpp::CharacterVector Rnames = Rlist.names(); size_t len = Rnames.length(); @@ -59,10 +96,10 @@ class modelBaseClass { }; template -class nodeFunctionBase { +class nodeFunctionClass_ : public nodeFunctionClassBase_ { public: double v; - + virtual ~nodeFunctionClass_() {}; }; #endif // NCOMPILER_MODEL_BASE_DEVEL_H_ diff --git a/nCompiler/inst/include/nCompiler/nClass_interface_Rcpp_extensions/post_Rcpp/shared_ptr_as_wrap.h b/nCompiler/inst/include/nCompiler/nClass_interface_Rcpp_extensions/post_Rcpp/shared_ptr_as_wrap.h index b471fc8f..bc2b07d0 100644 --- a/nCompiler/inst/include/nCompiler/nClass_interface_Rcpp_extensions/post_Rcpp/shared_ptr_as_wrap.h +++ b/nCompiler/inst/include/nCompiler/nClass_interface_Rcpp_extensions/post_Rcpp/shared_ptr_as_wrap.h @@ -33,6 +33,8 @@ namespace Rcpp { template class Exporter< std::shared_ptr< T > > { public: + static constexpr bool T_is_polymorphic = std::is_polymorphic_v; + std::shared_ptr sp_; Exporter(SEXP Sx) { Rcpp::Environment Sx_env(Sx); // Sx is an environment, so initialize an Rcpp:Environment from it. @@ -51,7 +53,13 @@ namespace Rcpp { ok=true;}}} } if(!ok) {stop("An argument that should be an nClass object is not valid.");} - sp_ = dynamic_cast* >(static_cast(R_ExternalPtrAddr(Xptr)))->sp(); + std::shared_ptr spbase = static_cast(R_ExternalPtrAddr(Xptr))->get_interfaceBase_shared_ptr(); + if constexpr (T_is_polymorphic) { + sp_ = std::dynamic_pointer_cast(spbase); + } else { + sp_ = std::static_pointer_cast(spbase); + } + // sp_ = dynamic_cast* >(static_cast(R_ExternalPtrAddr(Xptr)))->sp(); UNPROTECT(1); } inline std::shared_ptr< T > get(){ diff --git a/nCompiler/tests/testthat/nClass_tests/test-nClass_inherit.R b/nCompiler/tests/testthat/nClass_tests/test-nClass_inherit.R index 7c93f795..82826cc0 100644 --- a/nCompiler/tests/testthat/nClass_tests/test-nClass_inherit.R +++ b/nCompiler/tests/testthat/nClass_tests/test-nClass_inherit.R @@ -232,6 +232,8 @@ test_that("nClass hierarchies work as expected (including uncompiled vs compiled ############## +cat("With inheritance, we may now be able to interface at multiple levels, but it is untested.\n") + test_that("inheriting-only classes in 3-level hierarchy works", { ncBase <- nClass( classname = "ncBase", @@ -260,11 +262,29 @@ test_that("inheriting-only classes in 3-level hierarchy works", { Cpublic = list(x3 = 'numericScalar') ) - comp <- nCompile(ncBase, ncMid, ncDer) # check if order still matters + ncUseBase <- nClass( + classname = "ncUseBase", + Cpublic = list( + myBase = 'ncBase', + call_add_x = nFunction( + fun = function(v = 'numericScalar') { + return(myBase$add_x(v)); returnType('numericScalar') + } + ) + ) + ) + comp <- nCompile(ncUseBase, ncBase, ncMid, ncDer) # check if order still matters Cobj <- comp$ncDer$new() Cobj$x <- 10 expect_equal(Cobj$add_x(15), 25) expect_equal(method(Cobj$private$CppObj, "add_x")(15), 25) expect_equal(Cobj$add_2x_virt(15), 35) + + Cobj2 <- comp$ncUseBase$new() + expect_true(is.null(Cobj2$myBase)) + Cobj2$myBase <- Cobj + expect_equal(Cobj2$call_add_x(15), 25) + + rm(Cobj, Cobj2); gc() }) diff --git a/nCompiler/tests/testthat/nCompile_tests/test-compileNimble.R b/nCompiler/tests/testthat/nCompile_tests/test-compileNimble.R deleted file mode 100644 index 04a94ee8..00000000 --- a/nCompiler/tests/testthat/nCompile_tests/test-compileNimble.R +++ /dev/null @@ -1,93 +0,0 @@ -## Support for bridging nimble's compileNimble to nCompile -## Only basic tests will be here. -## The real tests will be running nimble's test suite. - -library(nimble) -#library(nCompiler) -library(testthat) - -test_that("compileNimble bridge works for simple nimbleFunction (RC function)",{ - RCF1 <- nimbleFunction( - run = function(x = double(1)) { - ans <- sum(x) - return(ans) - returnType(double()) - } - ) - CRCF1 <- `:::`("nCompiler", "compileNimble")(RCF1) - expect_equal(CRCF1(1:3), 6) -}) - -test_that("compileNimble bridge works for one nimbleFunction object", { - nf <- nimbleFunction( - setup = function() {x <- 1:2}, - run = function() {return(x[1]); returnType(double())} - ) - nf1 <- nf() - Cnf1 <- compileNimble(nf1) - expect_identical(Cnf1$x, 1:2) -}) -## NEXT STEPS: -## get a custom handler working -## try a test file from nimble using nClass -## add nClass to nCompiler:::compileNimble -## -## document, document, document - -test_that("registering a user-defined operator definition (opDef) works", { - ## first version: provide a function - nimArrayHandler <- function(code,...) { - code$name <- 'nArray' - NULL - } - # This test works by: - # providing a handler to relpace "nimArray" with "nArray" - # and a handler to replace "nimArray2" with "nArray" to - # check on handling multiple cases. - registerOpDef( - list(nimArray = - list( - matchDef = function(value=0, dim=c(1,1), init=TRUE, - fillZeros=TRUE, recycle=TRUE, nDim, - type="double") {}, - # normalizeCalls=list(handler='skip'), - simpleTransformations=list(handler = nimArrayHandler)))) - expect_equal(ls(`:::`("nCompiler", "operatorDefUserEnv")), "nimArray") - - registerOpDef( - list(nimArray2 = - list( - matchDef = function(value=0, dim=c(1,1), init=TRUE, - fillZeros=TRUE, recycle=TRUE, nDim, - type="double") {}, - simpleTransformations=list(handler = 'replace', - replacement = 'nArray')))) - expect_equal(ls(`:::`("nCompiler", "operatorDefUserEnv")), c("nimArray", "nimArray2")) - - nc <- nClass( - Cpublic = list( - foo = nFunction( - function() { - ans <- nimArray( 6, dim = 2) - ans2 <- nArray(value = 5, dim = 2) - return(ans) - returnType('double(1)') - } - ), - foo2 = nFunction( - function() { - ans <- nimArray2(3,dim = 2) - return(ans) - returnType('double(1)') - }) - )) - Cnc <- nCompile(nc) - obj <- Cnc$new() - expect_identical(obj$foo(), c(6, 6)) - expect_identical(obj$foo2(), c(3, 3)) - rm(obj); gc() - # - deregisterOpDef("nimArray") - deregisterOpDef("nimArray2") - expect_equal(length(ls(`:::`("nCompiler", "operatorDefUserEnv"))), 0) -}) diff --git a/nCompiler/tests/testthat/nCompile_tests/test-nCompile_C_fun.R b/nCompiler/tests/testthat/nCompile_tests/test-nCompile_C_fun.R new file mode 100644 index 00000000..53b2f757 --- /dev/null +++ b/nCompiler/tests/testthat/nCompile_tests/test-nCompile_C_fun.R @@ -0,0 +1,94 @@ +# Test providing separate R_fun and C_fun in an nFunction +# to fully separate compiled and uncompiled behavior. + +test_that("C_fun works in an nFunction", { + # Key points + # 1. argTypes and returnType refer to the C_fun is provided. + # 2. The args of fun are untyped and may even differ from C_fun, if C_fun is provided. + # 3. By default, changeKeywords does not happen, so "nStep" instead of "step" + foo <- nFunction( + fun = function(x) { + x+1 + }, + compileInfo = list( + C_fun = function(x, y, z) { + ans <- x + y + nStep(z) + return(ans) + } + ), + argTypes = list(x='numericVector',y='numericScalar', z='numericScalar'), + returnType = 'numericVector' + ) + + expect_equal(foo(3), 4) + cfoo <- nCompile(foo) + expect_equal(cfoo(3:4, 100, 1.2), (103:104) + 1) +}) + +test_that("C_fun errors out in an nFunction from not changing keywords ", { + # Now expect and error if "step" was used + foo <- nFunction( + fun = function(x) { + x+1 + }, + compileInfo = list( + C_fun = function(x, y, z) { + ans <- x + y + step(z) + return(ans) + } + ), + argTypes = list(x='numericVector',y='numericScalar', z='numericScalar'), + returnType = 'numericVector' + ) + + cat("Expected error message about no op def for step.\n") + expect_error(cfoo <- nCompile(foo)) +}) + +test_that("C_fun works in an nFunction with changeKeywords=TRUE", { + # Now step is ok b/c control$changeKeywords is TRUE + foo <- nFunction( + fun = function(x) { + x+1 + }, + compileInfo = list( + C_fun = function(x, y, z) { + ans <- x + y + step(z) + return(ans) + }), + control = list( + changeKeywords=TRUE + ), + argTypes = list(x='numericVector',y='numericScalar', z='numericScalar'), + returnType = 'numericVector' + ) + expect_equal(foo(3), 4) + cfoo <- nCompile(foo) + expect_equal(cfoo(3:4, 100, 1.2), (103:104) + 1) +}) + +test_that("C_fun works in an nClass method", { + foo <- nFunction( + fun = function(x) { + x+1 + }, + compileInfo = list( + C_fun = function(x, y, z) { + ans <- x + y + step(z) + return(ans) + }), + control = list( + changeKeywords=TRUE + ), + argTypes = list(x='numericVector',y='numericScalar', z='numericScalar'), + returnType = 'numericVector' + ) + fooClass <- nClass( + Cpublic=list(foo=foo) + ) + obj <- fooClass$new() + expect_equal(obj$foo(3), 4) + cfooClass <- nCompile(fooClass) + cobj <- cfooClass$new() + expect_equal(cobj$foo(3:4, 100, 1.2), (103:104) + 1) +}) diff --git a/nCompiler/tests/testthat/nCompile_tests/test-userOps.R b/nCompiler/tests/testthat/nCompile_tests/test-userOps.R new file mode 100644 index 00000000..e25c417e --- /dev/null +++ b/nCompiler/tests/testthat/nCompile_tests/test-userOps.R @@ -0,0 +1,317 @@ +# test user-defined operator definitions +# These can be: +# 1. provided globally to define (or take over) a new keyword via registerOpDef +# - precedence goes to a user-defined opDef over a built-in one +# 2. provided as an opDef list to an nFunction +# 3. provided as an opDef list (within compileInfo) to an nFunction that is an nClass method +# 4. provided as a list of opDef lists (within compileInfo) to an nClass + +# The definition "closest" to the nFunction takes precedence. +# e.g. if an nFunction has an opDef list, that takes precedence over +# an opDef list provided to the nClass that contains the nFunction. + +cat("add userOps test for custom_default opDef\n") + +test_that("registering a global user-defined operator definition (opDef) works", { + ## first version: provide a function + nimArrayHandler <- function(code,...) { + code$name <- 'nArray' + NULL + } + # This test works by: + # providing a handler to relpace "nimArray" with "nArray" + # and a handler to replace "nimArray2" with "nArray" to + # check on handling multiple cases. + registerOpDef( + list(nimArray = + list( + matchDef = function(value=0, dim=c(1,1), init=TRUE, + fillZeros=TRUE, recycle=TRUE, nDim, + type="double") {}, + simpleTransformations=list(handler = nimArrayHandler)))) + expect_equal(ls(`:::`("nCompiler", "operatorDefUserEnv")), "nimArray") + + registerOpDef( + list(nimArray2 = + list( + matchDef = function(value=0, dim=c(1,1), init=TRUE, + fillZeros=TRUE, recycle=TRUE, nDim, + type="double") {}, + simpleTransformations=list(handler = 'replace', + replacement = 'nArray')))) + expect_equal(ls(`:::`("nCompiler", "operatorDefUserEnv")), c("nimArray", "nimArray2")) + + nc <- nClass( + Cpublic = list( + foo = nFunction( + function() { + ans <- nimArray( 6, dim = 2) + ans2 <- nArray(value = 5, dim = 2) + return(ans) + returnType('double(1)') + } + ), + foo2 = nFunction( + function() { + ans <- nimArray2(3,dim = 2) + return(ans) + returnType('double(1)') + }) + )) + Cnc <- nCompile(nc) + obj <- Cnc$new() + expect_identical(obj$foo(), c(6, 6)) + expect_identical(obj$foo2(), c(3, 3)) + rm(obj); gc() + # + deregisterOpDef("nimArray") + deregisterOpDef("nimArray2") + expect_equal(length(ls(`:::`("nCompiler", "operatorDefUserEnv"))), 0) +}) + +cat("User opDef could be dangerous prior to genCpp because it won't update cachedOpDef\n") + +test_that("nFunction custom opDef works through a sequence of changes and handlers", { + # We set up a series of messages and function renamings to track + # the custom opDef through each one and also + # show that it is being updated at each step. + # Note the manual updating during eigenImpl. + # + # foo is an nFunction with compileArgs and a simpleTrans handler that + # renames it foo2 and emits a msg. + check_V <- function(code, ...) { + cat("MSG1: check_V was called. ") + if(code$aux$compileArgs$V == "W") + cat("MSG2: compile arg V was found. ") + code$name <- "foo2" + } + custom_opDef <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef$matchDef <- function(V) {} + custom_opDef$compileArgs <- "V" + custom_opDef$simpleTransformations <- list( + handler = check_V + ) + foo <- nFunction( + fun = function() { + return(1.2) + }, argTypes=list(), returnType=quote(double()), + compileInfo=list(opDef=custom_opDef) + ) + + # foo2 is an nFunction with custom labelAbstractTypes handler + # that emits a msg and renames it to foo3. + check_LAT <- function(code, symTab, auxEnv, handlingInfo) { + cat("MSG3: LAT for foo2 was used. ") + handler <- nCompiler:::getOperatorDef("nFunction_default")$labelAbstractTypes$handler + ans <- eval(call(handler, code,symTab,auxEnv,handlingInfo), + envir=nCompiler:::labelAbstractTypesEnv) + code$name <- "foo3" + ans + } + custom_opDef2 <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef2$labelAbstractTypes <- list( + handler = check_LAT + ) + foo2 <- nFunction( + fun = function() { + return(2.3) + }, argTypes=list(), returnType=quote(double()), + compileInfo=list(opDef=custom_opDef2) + ) + + # foo3 is an nFunction with a custom eigenImpl handler + # that emits, a msg, renames it, and updated the cachedOpInfo + check_EIG <- function(code, symTab, auxEnv, workEnv, handlingInfo) { + cat("MSG4: LAT for foo3 was used. ") + code$name <- "foo4" + nCompiler:::update_cachedOpInfo(code, auxEnv$where) + NULL + } + custom_opDef3 <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef3$eigenImpl <- list( + handler = check_EIG + ) + foo3 <- nFunction( + fun = function() { + return(3.4) + }, argTypes=list(), returnType=quote(double()), + compileInfo=list(opDef=custom_opDef3) + ) + + ## foo4 is an nFunction with a custom cppOutput handler + # that emits a msg and pastes "10.0*3.4 + " in front of the + # default output. + check_GenCpp <- function(code, symTab) { +# cat("MSG5: GenCpp for foo4 was used.") + handler <- nCompiler:::getOperatorDef("nFunction_default")$cppOutput$handler + base_out <- eval(call(handler, code, symTab), + envir=nCompiler:::genCppEnv) + paste0("10.0*3.4 + ", base_out, collapse = "") + } + custom_opDef4 <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef4$cppOutput <- list( + handler = check_GenCpp + ) + foo4 <- nFunction( + fun = function() { + return(4.5) + }, argTypes=list(), returnType=quote(double()), + compileInfo=list(opDef=custom_opDef4) + ) + + call_foo <- nFunction( + fun = function() { + ans <- foo(V = "W"); return(ans) + }, argTypes=list(), returnType=quote(double()) + ) + out <- capture_output( + cppDefs <- nCompile(call_foo, control=list(return_cppDefs=TRUE)) + ) + check <- grepl("^MSG1", out) + expect_true(check) + check <- grepl("MSG2", out) + expect_true(check) + check <- grepl("MSG3", out) + expect_true(check) + check <- grepl("MSG4", out) + expect_true(check) + + out_code <- cppDefs[[1]]$generate() |> unlist() + check <- grepl("10\\.0\\*3\\.4", out_code) |> sum() + expect_true(check==1) + + cat("#include needs are not cleaned up if an nFunction is renamed by a handler to another nFunction. All must be included in nCompile.") + comp <- nCompile(foo, foo2, foo3, foo4, call_foo) + expect_equal(comp$call_foo(), 10*3.4 + 4.5) + +}) + +test_that("nClass custom opDefs of 3 kinds works through a sequence of changes and handlers", { + + # foo is an nFunction with compileArgs and a simpleTrans handler that + # renames it foo2 and emits a msg. + check_V <- function(code, ...) { + cat("MSG1: check_V was called. ") + if(code$aux$compileArgs$V == "W") + cat("MSG2: compile arg V was found. ") + code$name <- "foo2" + } + custom_opDef <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef$matchDef <- function(V) {} + custom_opDef$compileArgs <- "V" + custom_opDef$simpleTransformations <- list( + handler = check_V + ) + foo <- nFunction( + fun = function() { + return(1.2) + }, argTypes=list(), returnType=quote(double()), + compileInfo=list(opDef=custom_opDef) + ) + + # foo2 is an nFunction with custom labelAbstractTypes handler + # that emits a msg and renames it to foo3. + # It's handler will be put at the nClass level + check_LAT <- function(code, symTab, auxEnv, handlingInfo) { + cat("MSG3: LAT for foo2 was used. ") + handler <- nCompiler:::getOperatorDef("nFunction_default")$labelAbstractTypes$handler + ans <- eval(call(handler, code,symTab,auxEnv,handlingInfo), + envir=nCompiler:::labelAbstractTypesEnv) + code$name <- "foo3" + ans + } + custom_opDef2 <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef2$labelAbstractTypes <- list( + handler = check_LAT + ) + foo2 <- nFunction( + fun = function() { + return(2.3) + }, argTypes=list(), returnType=quote(double()), +# compileInfo=list(opDef=custom_opDef2) + ) + + # foo3 is an nFunction with a custom eigenImpl handler + # that emits, a msg, renames it, and updated the cachedOpInfo + # This will be treated as a keyword opDef, only, without an + # actual nFunction. + check_EIG <- function(code, symTab, auxEnv, workEnv, handlingInfo) { + cat("MSG4: LAT for foo3 was used. ") + code$name <- "foo4" + nCompiler:::update_cachedOpInfo(code, auxEnv$where) + NULL + } + custom_opDef3 <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef3$eigenImpl <- list( + handler = check_EIG + ) +# foo3 <- nFunction( +# fun = function() { +# return(3.4) +# }, argTypes=list(), returnType=quote(double()), +# compileInfo=list(opDef=custom_opDef3) +# ) + + ## foo4 is an nFunction with a custom cppOutput handler + # that emits a msg and pastes "10.0*3.4 + " in front of the + # default output. + # This will remain an nFunction outside the nClass. + check_GenCpp <- function(code, symTab) { +# cat("MSG5: GenCpp for foo4 was used.") + handler <- nCompiler:::getOperatorDef("nFunction_default")$cppOutput$handler + base_out <- eval(call(handler, code, symTab), + envir=nCompiler:::genCppEnv) + paste0("10.0*3.4 + ", base_out, collapse = "") + } + custom_opDef4 <- nCompiler:::getOperatorDef("nFunction_default") + custom_opDef4$cppOutput <- list( + handler = check_GenCpp + ) + foo4 <- nFunction( + fun = function() { + return(4.5) + }, argTypes=list(), returnType=quote(double()), + compileInfo=list(opDef=custom_opDef4) + ) + + call_foo <- nFunction( + fun = function() { + ans <- foo(V = "W"); return(ans) + }, argTypes=list(), returnType=quote(double()) + ) + + foo_class <- nClass( + Cpublic = list( + foo = foo, + foo2 = foo2, + call_foo = call_foo + ), + compileInfo = list( + opDefs = list( + foo2 = custom_opDef2, + foo3 = custom_opDef3 + ) + ) + ) + + out <- capture_output( + cppDefs <- nCompile(foo_class, control=list(return_cppDefs=TRUE)) + ) + check <- grepl("^MSG1", out) + expect_true(check) + check <- grepl("MSG2", out) + expect_true(check) + check <- grepl("MSG3", out) + expect_true(check) + check <- grepl("MSG4", out) + expect_true(check) + + out_code <- cppDefs[[1]]$generate() |> unlist() + check <- grepl("10\\.0\\*3\\.4", out_code) |> sum() + expect_true(check==1) + + comp <- nCompile(foo4, foo_class) + obj <- comp$foo_class$new() + expect_equal(obj$call_foo(), 10*3.4 + 4.5) + rm(obj); gc() +}) diff --git a/nCompiler/tests/testthat/nimble_tests/test-compileNimble.R b/nCompiler/tests/testthat/nimble_tests/test-compileNimble.R new file mode 100644 index 00000000..b90c4cd1 --- /dev/null +++ b/nCompiler/tests/testthat/nimble_tests/test-compileNimble.R @@ -0,0 +1,35 @@ +## Support for bridging nimble's compileNimble to nCompile +## Only basic tests will be here. +## The real tests will be running nimble's test suite. + +library(nimble) +#library(nCompiler) +library(testthat) + +test_that("compileNimble bridge works for simple nimbleFunction (RC function)",{ + RCF1 <- nimbleFunction( + run = function(x = double(1)) { + ans <- sum(x) + return(ans) + returnType(double()) + } + ) + CRCF1 <- `:::`("nCompiler", "compileNimble")(RCF1) + expect_equal(CRCF1(1:3), 6) +}) + +test_that("compileNimble bridge works for one nimbleFunction object", { + nf <- nimbleFunction( + setup = function() {x <- 1:2}, + run = function() {return(x[1]); returnType(double())} + ) + nf1 <- nf() + Cnf1 <- compileNimble(nf1) + expect_identical(Cnf1$x, 1:2) +}) +## NEXT STEPS: +## get a custom handler working +## try a test file from nimble using nClass +## add nClass to nCompiler:::compileNimble +## +## document, document, document diff --git a/nCompiler/tests/testthat/nimble_tests/test-nimbleModel.R b/nCompiler/tests/testthat/nimble_tests/test-nimbleModel.R new file mode 100644 index 00000000..273f3451 --- /dev/null +++ b/nCompiler/tests/testthat/nimble_tests/test-nimbleModel.R @@ -0,0 +1,35 @@ +# Test code needed for new nimbleModel system. +# Some or all of this should eventually go in a separate package. + +library(nCompiler) +library(testthat) + +test_that("toy nimble model prototype works", { + varInfoM <- list(list(name = "beta", nDim = 1), list(name = "mu", nDim = 0), + list(name = "gamma", nDim = 2)) + + #debug(makeModel_nClass) + ncm1 <- makeModel_nClass(varInfoM) + + varInfo <- list(list(name = "x", nDim = 0), list(name = "mu", nDim = 1), + list(name = "sd", nDim = 0)) + node_dnorm <- make_node_fun(varInfo) + + Cncm1 <- nCompile(modelBase_nClass, ncm1, node_dnorm) + + obj <- Cncm1$ncm1$new() + obj$call_setup_node_mgmt() + nodeObj <- obj$beta_node + obj$beta <- 1:3 + expect_equal(obj$beta, 1:3) + + obj$set_from_list(list(beta = 10:11)) + obj$set_from_list(list(mu = 110, beta = 11:20, alpha = 101)) + obj$mu + + obj$resize_from_list(list(beta = 7)) + expect_error(obj$resize_from_list(list(beta = 5, mu = 3, gamma = c(2, 4)))) + obj$resize_from_list(list(beta = 5, gamma = c(2, 4))) + expect_equal(length(obj$beta), 5) + expect_equal(dim(obj$gamma), c(2, 4)) +})