Skip to content

Commit efa9ba1

Browse files
bucknerjclaude
andcommitted
Use meaningful group names in dbind()
When a named list of ISMs is passed to dbind(), use the list names as group levels in the resulting BISM. When BISMs are in the mix, preserve their existing @groups levels. Unnamed entries fall back to position-based numeric indices, skipping any that collide with existing labels. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1972f02 commit efa9ba1

2 files changed

Lines changed: 80 additions & 13 deletions

File tree

R/InfinitySparseMatrix.R

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -973,34 +973,61 @@ dbind <- function(..., force_unique_names = FALSE) {
973973

974974

975975
# Convert all matrices to ISMs if they aren't already.
976-
mats <- lapply(mats, function(x) {
976+
# Also build a parallel vector of group labels for each entry.
977+
input_names <- names(mats)
978+
converted <- lapply(seq_along(mats), function(i) {
979+
x <- mats[[i]]
980+
nm <- if (!is.null(input_names) && nzchar(input_names[i])) input_names[i] else NA_character_
981+
977982
if (is(x, "BlockedInfinitySparseMatrix")) {
978-
# Replace BISM with list of ISMs
979-
findSubproblems(x)
983+
# Replace BISM with list of ISMs; use its existing group level names
984+
sp <- findSubproblems(x)
985+
list(mats = sp, labels = names(sp))
980986
} else if (inherits(x, "list")) {
981987
# If any entry in ... is a list,
988+
inner_names <- names(x)
982989
# 1) Convert all entries in that list to ISM while keeping BISM as BISM
983990
x <- lapply(x, .as.ism_or_bism)
984-
# 2) If we have any BISMs, split into list of ISMS
985-
x <- lapply(x, function(y) {
991+
# 2) If we have any BISMs, split into list of ISMs, preserving labels
992+
inner_converted <- lapply(seq_along(x), function(j) {
993+
y <- x[[j]]
994+
inner_nm <- if (!is.null(inner_names) && nzchar(inner_names[j])) inner_names[j] else NA_character_
986995
if (is(y, "BlockedInfinitySparseMatrix")) {
987-
findSubproblems(y)
996+
sp <- findSubproblems(y)
997+
list(mats = sp, labels = names(sp))
988998
} else {
989-
y
999+
list(mats = y, labels = inner_nm)
9901000
}
9911001
})
9921002
# 3) pull list of lists into list
993-
flatten_list(x)
1003+
list(mats = flatten_list(lapply(inner_converted, `[[`, "mats")),
1004+
labels = unlist(lapply(inner_converted, `[[`, "labels")))
9941005
} else {
9951006
# This will error appropriately if some element in `mats` cannot be
9961007
# converted to an ISM.
997-
.as.ism_or_bism(x)
1008+
list(mats = .as.ism_or_bism(x), labels = nm)
9981009
}
9991010
})
10001011

10011012
# If we were passed any BISMs, we have a list of lists of ISM, so flatten to a
10021013
# single list.
1003-
mats <- flatten_list(mats)
1014+
mats <- flatten_list(lapply(converted, `[[`, "mats"))
1015+
group_labels <- unlist(lapply(converted, `[[`, "labels"))
1016+
1017+
# Replace NA labels (from unnamed entries) with numeric indices based on
1018+
# their position, incrementing to avoid collisions with existing labels.
1019+
na_idx <- which(is.na(group_labels))
1020+
if (length(na_idx) > 0) {
1021+
existing <- group_labels[!is.na(group_labels)]
1022+
for (i in na_idx) {
1023+
candidate <- i
1024+
while (as.character(candidate) %in% existing) {
1025+
candidate <- candidate + 1L
1026+
}
1027+
group_labels[i] <- as.character(candidate)
1028+
existing <- c(existing, as.character(candidate))
1029+
}
1030+
}
10041031

10051032
# new row and column positions are based on current, incrementing by number of
10061033
# rows/columns in all previous matrices.
@@ -1052,10 +1079,10 @@ dbind <- function(..., force_unique_names = FALSE) {
10521079
newdim <- as.integer(c(sum(vapply(lapply(mats, methods::slot, "dimension"), "[", 1, 1)),
10531080
sum(vapply(lapply(mats, methods::slot, "dimension"), "[", 1, 2))))
10541081

1055-
# This needs to be much smarter, especially if any element is already a BISM
1056-
groups <- as.factor(rep(seq_along(mats), times =
1082+
groups <- factor(rep(group_labels, times =
10571083
vapply(lapply(mats, slot, "colnames"), length, 1) +
1058-
vapply(lapply(mats, slot, "rownames"), length, 1)))
1084+
vapply(lapply(mats, slot, "rownames"), length, 1)),
1085+
levels = unique(group_labels))
10591086
names(groups) <- do.call(c, Map(c, cnameslist, rnameslist))
10601087

10611088
newdata <- do.call(c, mats)

tests/testthat/test.InfinitySparseMatrix.R

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,46 @@ test_that("dbind", {
955955
expect_identical(bmix1, bmix3)
956956
})
957957

958+
test_that("dbind uses meaningful group names", {
959+
data(nuclearplants)
960+
np <- nuclearplants
961+
np$group <- as.numeric(cut(np$cap, breaks = c(0, 600, 825, 1000, 2000)))
962+
963+
m1 <- match_on(pr ~ cost, data = np[np$group == 1,])
964+
m2 <- match_on(pr ~ cost, data = np[np$group == 2,])
965+
m3 <- match_on(pr ~ cost, data = np[np$group == 3,])
966+
m4 <- match_on(pr ~ cost, data = np[np$group == 4,])
967+
968+
# Named list of ISMs -> group levels match the list names
969+
bm_named <- dbind(first = m1, second = m2)
970+
expect_identical(levels(bm_named@groups), c("first", "second"))
971+
972+
# Unnamed list -> levels are "1", "2", ... (backward compat)
973+
bm_unnamed <- dbind(m1, m2)
974+
expect_identical(levels(bm_unnamed@groups), c("1", "2"))
975+
976+
# Mixed named list with a BISM entry -> ISM entries get list names,
977+
# BISM entries get their original @groups levels
978+
b1 <- match_on(pr ~ cost + strata(group), data = np[np$group < 3,])
979+
bm_mixed <- dbind(b1, extra = m3)
980+
bism_levels <- levels(b1@groups)
981+
expect_identical(levels(bm_mixed@groups), c(bism_levels, "extra"))
982+
983+
# Partially named list -> named entries use names, unnamed fall back to index
984+
bm_partial <- dbind(a = m1, m2, m3)
985+
expect_identical(levels(bm_partial@groups), c("a", "2", "3"))
986+
987+
# Unnamed fallback indices skip names that collide with existing labels
988+
bm_collision <- dbind(m1, b1)
989+
# m1 at position 1 collides with b1's groups "1"/"2", so gets "3"
990+
# Levels preserve input order: m1's group first, then b1's groups
991+
expect_identical(levels(bm_collision@groups), c("3", bism_levels))
992+
993+
# Named list passed as single argument
994+
bm_named_list <- dbind(list(x = m1, y = m2))
995+
expect_identical(levels(bm_named_list@groups), c("x", "y"))
996+
})
997+
958998
test_that("dbind'ing a very large number of matrices", {
959999
data(nuclearplants)
9601000

0 commit comments

Comments
 (0)