Skip to content

Commit 2a67ced

Browse files
committed
prototype nodeFxn and model class have no more baked-in toy pieces
1 parent 07d0a79 commit 2a67ced

6 files changed

Lines changed: 171 additions & 84 deletions

File tree

nCompiler/R/all_utils.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ resetLabelFunctionCreators <- function() {
4343
}
4444

4545
ADtapeMgrLabelCreator <- labelFunctionCreator("ADtapeMgr")
46+
nodeFxnLabelCreator <- labelFunctionCreator("nodeFxn")
47+
modelLabelCreator <- labelFunctionCreator("model")
4648

4749
# no longer documented in Rd
4850
# Generates a valid C++ name from an R Name

nCompiler/R/nimbleModels.R

Lines changed: 78 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ nodeFxnBase_nClass <- nClass(
3030
),
3131
# We haven't dealt with ensuring a virtual destructor when any method is virtual
3232
# For now I did it manually by editing the .h and .cpp
33-
predefined = quote(system.file(file.path("include","nCompiler", "predefined_nClasses"), package="nCompiler") |>
33+
predefined = quote(system.file(file.path("include","nCompiler", "predefined_nClasses"), package="nCompiler") |>
3434
file.path("nodeFxnBase_nClass")),
3535
compileInfo=list(interface="full",
3636
createFromR = FALSE)
@@ -76,10 +76,16 @@ modelBase_nClass <- nClass(
7676
## obj <- comp$test$new()
7777
## obj$calculate(NULL)
7878

79-
make_node_fun <- function(varInfo) {
79+
# Turn variables and methods into a nodeFxn nClass
80+
make_node_fun <- function(varInfo = list(),
81+
methods = list(),
82+
classname) {
8083
# varInfo will be a list (names not used) of name, nDim, sizes.
81-
varInfo_2_cppVar <- \(x) nCompiler:::symbolCppVar$new(baseType = nCompiler:::symbolBasic$new(type="double", nDim=x$nDim, name="")$genCppVar()$generate(),
82-
ref=TRUE, const=TRUE)
84+
varInfo_2_cppVar <- \(x) nCompiler:::symbolBasic$new(
85+
type="double", nDim=x$nDim, name="", isRef=TRUE, isConst=FALSE, interface=FALSE) # We could in future make some isConst=TRUE, but it might not matter much
86+
# varInfo_2_cppVar <- \(x) nCompiler:::symbolCppVar$new(
87+
# baseType = nCompiler:::symbolBasic$new(type="double", nDim=x$nDim, name="")$genCppVar()$generate(),
88+
# ref=TRUE, const=TRUE)
8389
typeList <- varInfo |> lapply(varInfo_2_cppVar)
8490
names(typeList) <- varInfo |> lapply(\(x) x$name) |> unlist()
8591

@@ -92,29 +98,59 @@ make_node_fun <- function(varInfo) {
9298
initFun <- function(){}
9399
formals(initFun) <- structure(as.pairlist(CpublicVars), names = ctorArgNames)
94100

101+
if(missing(classname))
102+
classname <- nodeFxnLabelCreator()
103+
104+
baseclass <- paste0("nodeFxnClass_<", classname, ">")
105+
95106
# This was a prototype
96-
node_dnorm <- substitute(
107+
node_nClass <- substitute(
97108
nClass(
98-
classname = "node_dnorm",
109+
classname = CLASSNAME,
99110
Cpublic = CPUBLIC,
100111
compileInfo = list(
101112
createFromR = FALSE, # Without a default constructor (which we've disabled here), createFromR is impossible
102-
nClass_inherit = list(base = "nodeFxnClass_<node_dnorm>")) # Ideally this line would be obtained from a base nClass, but we insert it directly for now
113+
nClass_inherit = list(base = BASECLASS)) # Ideally this line would be obtained from a base nClass, but we insert it directly for now
103114
),
104115
list(CPUBLIC = c(
105116
list(
106-
node_dnorm = nFunction(
107-
initFun,
108-
compileInfo = list(constructor=TRUE, initializers = initializersList)
109-
)
117+
nFunction(
118+
initFun,
119+
compileInfo = list(constructor=TRUE, initializers = initializersList)
120+
)
121+
) |> structure(names = classname),
122+
CpublicVars,
123+
methods
110124
),
111-
CpublicVars
112-
)))
113-
eval(node_dnorm)
125+
CLASSNAME = classname,
126+
BASECLASS = baseclass
127+
))
128+
eval(node_nClass)
114129
}
115130
#test <- nCompiler:::argType2symbol('CppVar(baseType = argType2Cpp("numericVector"), ref=TRUE, const=TRUE)')
116131

117-
makeModel_nClass <- function(varInfo) {
132+
# Make all the info needed to include a node in a model class.
133+
# The nodeFxn_nClass should be created first.
134+
# Currently it needs to have a name to include in nCompile(). Later we might be able to pass the object itself
135+
# At first drafting this is fairly trivial but could grow in complexity.
136+
137+
make_node_info <- function(membername,
138+
nodeFxnName,
139+
classname,
140+
varInfo = list()
141+
) {
142+
ctorArgs <- varInfo |> lapply(\(x) x$name) |> unlist()
143+
144+
list(nodeFxnName = nodeFxnName,
145+
membername = membername,
146+
classname = classname,
147+
ctorArgs = ctorArgs)
148+
}
149+
150+
makeModel_nClass <- function(varInfo,
151+
nodes = list(),
152+
classname
153+
) {
118154
# varInfo will be a list (names not used) of name, nDim, sizes.
119155
CpublicModelVars <- varInfo |> lapply(\(x) paste0("numericArray(nDim=",x$nDim,")"))
120156
names(CpublicModelVars) <- varInfo |> lapply(\(x) x$name) |> unlist()
@@ -127,6 +163,9 @@ makeModel_nClass <- function(varInfo) {
127163
opDefs$setup_node_mgmt$returnType <- nCompiler:::argType2symbol(quote(void()))
128164
opDefs$setup_node_mgmt$labelAbstractTypes$recurse <- FALSE
129165

166+
if(missing(classname))
167+
classname <- modelLabelCreator()
168+
130169
CpublicMethods <- list(
131170
do_setup_node_mgmt = nFunction(
132171
name = "call_setup_node_mgmt",
@@ -141,29 +180,44 @@ makeModel_nClass <- function(varInfo) {
141180
function(Rlist = 'RcppList') {cppLiteral('modelClass_::resize_from_list(Rlist);')}
142181
)
143182
)
144-
CpublicNodeFuns <- list(
145-
beta_node = 'node_dnorm()'
146-
)
183+
# nodes will be a list of membername, nodeFxnName, (node) classname, ctorArgs (list)
184+
node_pieces <- nodes |> lapply(\(x) {
185+
nClass_type <- paste0(x$nodeFxnName, "()")
186+
init_string <- paste0('nCpp("', x$membername, '( new ', x$classname, '(',
187+
paste0(x$ctorArgs, collapse=","), '))")')
188+
list(nClass_type = nClass_type,
189+
init_string = init_string,
190+
membername = x$membername)
191+
})
192+
membernames <- node_pieces |> lapply(\(x) x$membername) |> unlist()
193+
CpublicNodeFuns <- node_pieces |> lapply(\(x) x$nClass_type) |> setNames(membernames)
194+
# CpublicNodeFuns <- list(
195+
# beta_node = 'node_dnorm()'
196+
# )
147197
CpublicCtor <- list(
148-
mymodel = nFunction(
198+
nFunction(
149199
function(){},
150200
compileInfo = list(constructor=TRUE,
151-
initializers = c('nCpp("beta_node(new node_dnorm(mu, beta, 1))")'))
201+
#initializers = c('nCpp("beta_node(new node_dnorm(mu, beta, 1))")'))
202+
initializers = node_pieces |> lapply(\(x) x$init_string) |> unlist())
152203
)
153-
)
204+
) |> structure(names = classname)
205+
baseclass <- paste0("modelClass_<", classname, ">")
154206
ans <- substitute(
155207
nClass(
156-
classname = "mymodel",
208+
classname = CLASSNAME,
157209
inherit = modelBase_nClass,
158210
compileInfo = list(opDefs = OPDEFS,
159-
nClass_inherit = list(base="modelClass_<mymodel>")
211+
nClass_inherit = list(base=BASECLASS)
160212
#inherit = list(base = "public modelClass_<mymodel>"),
161213
#Hincludes = "<nCompiler/nClass_interface/post_Rcpp/nCompiler_model_base_devel.h>"
162214
),
163215
Cpublic = CPUBLIC
164216
),
165217
list(OPDEFS = opDefs,
166-
CPUBLIC = c(CpublicNodeFuns, CpublicModelVars, CpublicCtor, CpublicMethods))
218+
CPUBLIC = c(CpublicNodeFuns, CpublicModelVars, CpublicCtor, CpublicMethods),
219+
CLASSNAME = classname,
220+
BASECLASS = baseclass)
167221
)
168222
eval(ans, envir = parent.frame())
169223
}

nCompiler/R/symbolTable.R

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,17 @@ symbolBasic <-
4646
size = NULL,
4747
knownSize = NULL,
4848
isBlockRef = FALSE,
49+
isConst = FALSE,
4950
initialize = function(...,
5051
nDim = 0,
5152
size = if(nDim == 0) 1 else NA,
52-
isBlockRef = FALSE) {
53+
isBlockRef = FALSE,
54+
isConst = FALSE) {
5355
super$initialize(...)
5456
self$nDim <- nDim
5557
self$size <- size
5658
self$isBlockRef <- isBlockRef
59+
self$isConst <- isConst
5760
self
5861
},
5962
shortPrint = function() {
@@ -93,32 +96,40 @@ symbolBasic <-
9396
"type", type,"unrecognized\n"),
9497
FALSE)
9598
if(self$nDim == 0) {
96-
return(if(!(identical(self$name, "pi")))
97-
cppVarClass$new(baseType = cType,
98-
name = self$name,
99-
ptr = 0,
100-
ref = FALSE)
101-
else
102-
cppVarFullClass$new(baseType = cType,
99+
if(identical(self$name, "pi"))
100+
return(cppVarFullClass$new(baseType = cType,
103101
name = self$name,
104102
ptr = 0,
105103
ref = FALSE,
106-
constructor = "(M_PI)")
107-
)
104+
constructor = "(M_PI)"))
105+
if(isTRUE(self$isConst))
106+
return(cppVarFullClass$new(baseType = cType,
107+
name = self$name,
108+
ptr = FALSE,
109+
ref = self$isRef,
110+
const = self$isConst))
111+
return(cppVarClass$new(baseType = cType,
112+
name = self$name,
113+
ptr = 0,
114+
ref = self$isRef))
108115
}
109116
if(self$isBlockRef) {
110-
return(cppStridedTensorMapRef(name = self$name,
111-
nDim = self$nDim,
112-
scalarType = cType))
117+
ans <- cppStridedTensorMapRef(name = self$name,
118+
nDim = self$nDim,
119+
scalarType = cType)
113120
} else if(self$isRef) {
114-
return(cppEigenTensorRef(name = self$name,
121+
ans <- cppEigenTensorRef(name = self$name,
115122
nDim = self$nDim,
116-
scalarType = cType))
123+
scalarType = cType)
117124
} else {
118-
return(cppEigenTensor(name = self$name,
125+
ans <- cppEigenTensor(name = self$name,
119126
nDim = self$nDim,
120-
scalarType = cType))
127+
scalarType = cType)
128+
}
129+
if(self$isConst) {
130+
ans$const <- TRUE
121131
}
132+
ans
122133
}
123134
)
124135
)

nCompiler/inst/include/nCompiler/nClass_interface/post_Rcpp/generic_class_interface_Rcpp_steps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ class genericInterfaceC : virtual public genericInterfaceBaseC {
411411
name2access_type::iterator access = name2access.find(name);
412412
if(access == name2access.end()) {
413413
std::cout<<"Problem: \""<<name<<"\" is not a field in this nClass."<<std::endl;
414+
return nullptr;
414415
}
415416
return (access->second->ETaccess(this));
416417
}

nCompiler/inst/include/nCompiler/predefined_nClasses/modelBase_nClass/modelBase_nClass_hContent.h

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class modelClass_ : public modelBase_nClass {
6060
for(size_t i = 0; i < len; ++i) {
6161
// explicit cast is needed because even though Rnames[i] can cast to a string,
6262
// set_value takes a const string& so we need an object in place here.
63+
// set_value fails safely if a name is not found.
6364
static_cast<Derived*>(this)->set_value(std::string(Rnames[i]), Rlist[i]);
6465
}
6566
}
@@ -74,27 +75,30 @@ class modelClass_ : public modelBase_nClass {
7475
vs = Rlist[i];
7576
vec_len = vs.length();
7677
std::unique_ptr<ETaccessorBase> ETA = static_cast<Derived*>(this)->access(std::string(Rnames[i]));
77-
switch(vec_len) {
78-
case 0 :
79-
break;
80-
case 1 :
81-
ETA->template ref<1>().resize(vs[0]);
82-
break;
83-
case 2 :
84-
ETA->template ref<2>().resize(vs[0], vs[1]);
85-
break;
86-
case 3 :
87-
ETA->template ref<3>().resize(vs[0], vs[1], vs[2]);
88-
break;
89-
case 4 :
90-
ETA->template ref<4>().resize(vs[0], vs[1], vs[2], vs[3]);
91-
break;
92-
case 5 :
93-
ETA->template ref<5>().resize(vs[0], vs[1], vs[2], vs[3], vs[4]);
94-
break;
95-
case 6 :
96-
ETA->template ref<6>().resize(vs[0], vs[1], vs[2], vs[3], vs[4], vs[5]);
97-
break;
78+
// if the name was not found, a "Problem:" message was emitted, and we skip using it here.
79+
if(ETA) {
80+
switch(vec_len) {
81+
case 0 :
82+
break;
83+
case 1 :
84+
ETA->template ref<1>().resize(vs[0]);
85+
break;
86+
case 2 :
87+
ETA->template ref<2>().resize(vs[0], vs[1]);
88+
break;
89+
case 3 :
90+
ETA->template ref<3>().resize(vs[0], vs[1], vs[2]);
91+
break;
92+
case 4 :
93+
ETA->template ref<4>().resize(vs[0], vs[1], vs[2], vs[3]);
94+
break;
95+
case 5 :
96+
ETA->template ref<5>().resize(vs[0], vs[1], vs[2], vs[3], vs[4]);
97+
break;
98+
case 6 :
99+
ETA->template ref<6>().resize(vs[0], vs[1], vs[2], vs[3], vs[4], vs[5]);
100+
break;
101+
}
98102
}
99103
}
100104
}

nCompiler/tests/testthat/nimble_tests/test-nimbleModel.R

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,48 @@
44
library(nCompiler)
55
library(testthat)
66

7-
test_that("toy nimble model prototype works", {
8-
varInfoM <- list(list(name = "beta", nDim = 1), list(name = "mu", nDim = 0),
9-
list(name = "gamma", nDim = 2))
10-
11-
#debug(makeModel_nClass)
12-
ncm1 <- makeModel_nClass(varInfoM)
13-
14-
varInfo <- list(list(name = "x", nDim = 0), list(name = "mu", nDim = 1),
7+
test_that("nimble model prototype works", {
8+
nodeVarInfo <- list(list(name = "x", nDim = 1), list(name = "mu", nDim = 1),
159
list(name = "sd", nDim = 0))
16-
node_dnorm <- make_node_fun(varInfo)
17-
18-
Cncm1 <- nCompile(modelBase_nClass, nodeFxnBase_nClass, ncm1, node_dnorm)
19-
10+
calc_one <- nFunction(
11+
fun = function(inds = 'integerVector') {
12+
returnType('numericScalar')
13+
ans <- x[inds[1]]
14+
return(ans)
15+
}
16+
)
17+
my_nodeFxn <- make_node_fun(nodeVarInfo, list(calc_one=calc_one), "test_node")
18+
my_nodeInfo <- nCompiler:::make_node_info("beta_NF1", "my_nodeFxn", "test_node", nodeVarInfo)
19+
20+
modelVarInfo <- list(list(name="x", nDim = 1),
21+
list(name = "mu", nDim = 1),
22+
list(name = "sd", nDim = 0),
23+
list(name = "gamma", nDim = 2))
24+
#debug(makeModel_nClass)
25+
ncm1 <- makeModel_nClass(modelVarInfo, list(my_nodeInfo), classname = "my_model")
26+
#undebug(nCompiler:::addGenericInterface_impl)
27+
Cncm1 <- nCompile(modelBase_nClass, nodeFxnBase_nClass, ncm1, my_nodeFxn)
2028
obj <- Cncm1$ncm1$new()
29+
2130
obj$do_setup_node_mgmt()
22-
nodeObj <- obj$beta_node
23-
obj$beta <- 1:3
24-
expect_equal(obj$beta, 1:3)
31+
nodeObj <- obj$beta_NF1
32+
obj$x <- 1:3
33+
expect_equal(obj$x, 1:3)
2534

26-
obj$set_from_list(list(beta = 10:11))
27-
# expect Problem msg:
28-
obj$set_from_list(list(mu = 110, beta = 11:20, alpha = 101))
35+
obj$set_from_list(list(x = 10:11))
36+
# expect Problem msg: (alpha is not a field in the class)
37+
obj$set_from_list(list(mu = 110, x = 11:20, alpha = 101))
2938
obj$mu
3039

31-
obj$resize_from_list(list(beta = 7))
32-
expect_error(obj$resize_from_list(list(beta = 5, mu = 3, gamma = c(2, 4))))
33-
obj$resize_from_list(list(beta = 5, gamma = c(2, 4)))
34-
expect_equal(length(obj$beta), 5)
40+
obj$resize_from_list(list(x = 7))
41+
# expect Problem msg:
42+
obj$resize_from_list(list(alpha = 5, mu = 3, gamma = c(2, 4)))
43+
expect_equal(length(obj$mu), 3)
3544
expect_equal(dim(obj$gamma), c(2, 4))
45+
obj$resize_from_list(list(x = 5, gamma = c(3, 5)))
46+
expect_equal(length(obj$x), 5)
47+
expect_equal(dim(obj$gamma), c(3, 5))
48+
49+
obj$x <- 11:15
50+
expect_equal(nodeObj$calc_one(c(3)), 13)
3651
})

0 commit comments

Comments
 (0)