Skip to content

Commit 036f04b

Browse files
authored
Merge pull request #257 from markmfredrickson/issue-128-rank-mahal-index
Issue 128 rank mahalanobis
2 parents 64d290b + 4dea0d5 commit 036f04b

9 files changed

Lines changed: 173 additions & 40 deletions

File tree

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ Suggests:
3939
pander,
4040
xtable,
4141
rrelaxiv,
42-
magrittr
42+
magrittr,
43+
MASS
4344
Enhances:
4445
CBPS,
4546
haven

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ importFrom(stats,terms)
164164
importFrom(stats,terms.formula)
165165
importFrom(stats,update)
166166
importFrom(stats,update.formula)
167+
importFrom(stats,var)
167168
importFrom(tibble,as_tibble)
168169
importFrom(tibble,enframe)
169170
importFrom(tibble,tibble)

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
- `optmatch:::scoreCaliper()` gains an optional
44
`within=` argument (#245)
5+
- `match_on(z~x, method="rank_mahalanobis", within=foo)` makes
6+
better use of restrictions in foo to contain time/memory costs.
7+
- For rank Mahalanobis with underlying covariances pooled across
8+
treatment and control groups, you can now use
9+
`match_on(z~x, method="pooled_cov_rank_mahalanobis")`.
510
- Updates to internal C++ code
611

712
## Changes in **optmatch** Version 0.10.7

R/match_on.R

Lines changed: 65 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,14 @@ match_on.bigglm <- function(x,
274274
#' redundancies among the variables by scaling down variables contributions in
275275
#' proportion to their correlations with other included variables.)
276276
#'
277-
#' Euclidean distance is also available, via \code{method="euclidean"}, and
278-
#' ranked, Mahalanobis distance, via \code{method="rank_mahalanobis"}.
277+
#' Euclidean distance is also available, via \code{method="euclidean"}, as
278+
#' are two flavors of rank-based Mahalanobis distance, via
279+
#' \code{method="rank_mahalanobis"} or \code{method="pooled_cov_rank_mahalanobis"}.
280+
#' Both methods rank-transform the covariates first; they differ in whether
281+
#' the covariance of the rank-transformed covariates is then calculated
282+
#' on all subjects or by pooling within-group covariances across
283+
#' treatment and control. The \code{method=} argument can be abbreviated
284+
#' in the usual way (via [base::pmatch()]).
279285
#'
280286
#' The treatment indicator \code{Z} as noted above must either be numeric
281287
#' (1 representing treated units and 0 control units) or logical
@@ -404,15 +410,19 @@ match_on.formula <- function(x,
404410
methodname <- as.character(class(method))
405411
}
406412

407-
which.method <- pmatch(methodname, c("mahalanobis", "euclidean", "rank_mahalanobis", "function"), 4)
413+
which.method <- pmatch(methodname,
414+
c("mahalanobis", "euclidean",
415+
"rank_mahalanobis", "pooled_cov_rank_mahalanobis",
416+
"function"), 5)
408417
tmp <- switch(which.method,
409-
makedist(z, data, compute_mahalanobis, within),
410-
makedist(z, data, compute_euclidean, within),
411-
makedist(z, data, compute_rank_mahalanobis, within),
412-
{
413-
warning("Passing a user-defined `method` to `match_on.formula` is not supported and results are not guaranteed. User-defined distances should use `match_on.function` instead.")
414-
makedist(z, data, match.fun(method), within)
415-
}
418+
makedist(z, data, compute_mahalanobis, within),
419+
makedist(z, data, compute_euclidean, within),
420+
makedist(z, data, compute_rank_mahalanobis, within),
421+
makedist(z, data, compute_pooled_cov_rank_mahalanobis, within),
422+
{
423+
warning("Passing a user-defined `method` to `match_on.formula` is not supported and results are not guaranteed. User-defined distances should use `match_on.function` instead.")
424+
makedist(z, data, match.fun(method), within)
425+
}
416426
)
417427
rm(mf)
418428

@@ -509,19 +519,7 @@ compute_mahalanobis <- function(index, data, z) {
509519
cv <- mt + mc
510520
rm(mt, mc)
511521

512-
inv.scale.matrix <- try(solve(cv), silent = TRUE)
513-
514-
if (inherits(inv.scale.matrix,"try-error")) {
515-
dnx <- dimnames(cv)
516-
s <- svd(cv)
517-
nz <- (s$d > sqrt(.Machine$double.eps) * s$d[1])
518-
if (!any(nz)) stop("covariance has rank zero")
519-
520-
inv.scale.matrix <- s$v[, nz] %*% (t(s$u[, nz])/s$d[nz])
521-
dimnames(inv.scale.matrix) <- dnx[2:1]
522-
rm(dnx, s, nz)
523-
}
524-
522+
inv.scale.matrix <- safe_invert(cv)
525523
rm(cv)
526524

527525
return(mahalanobisHelper(data, index, inv.scale.matrix))
@@ -548,19 +546,50 @@ compute_rank_mahalanobis <- function(index, data, z) {
548546
if (is.null(rownames(data)) | !all(index %in% rownames(data)))
549547
stop("data must have row names matching index")
550548

551-
# begin workaround solution to #128
552-
all_treated <- rownames(data)[as.logical(z)]
553-
all_control <- rownames(data)[!z]
554-
all_indices <- expand.grid(all_treated, all_control,
555-
KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)
556-
all_indices <- paste(all_indices[[1]], all_indices[[2]], sep="%@%")
557-
short_indices <- paste(index[,1], index[,2], sep="%@%")
558-
indices <- match(short_indices, all_indices)
559-
if (any(is.na(indices))) stop("Unanticipated problem. (Make sure row names of data don't use the string '%@%'.)")
560-
# Now, since `r_smahal` is ignoring its `index` argument anyway:
561-
rankdists <- sqrt(r_smahal(NULL, data, z))
562-
rankdists <- rankdists[indices]
563-
return(rankdists)
549+
data <- apply(data, 2, rank)
550+
n <- nrow(data)
551+
m <- cov(data)
552+
cv <- scale_addressing_ties(nrow(data), cov(data))
553+
inv.scale.matrix <- safe_invert(cv)
554+
rm(cv)
555+
556+
return(mahalanobisHelper(data, index, inv.scale.matrix))
557+
}
558+
559+
compute_pooled_cov_rank_mahalanobis <- function(index, data, z) {
560+
if (!all(is.finite(data)))
561+
stop("Infinite or NA values detected in data for Mahalanobis computations.")
562+
563+
if (is.null(rownames(data)) | !all(index %in% rownames(data)))
564+
stop("data must have row names matching index")
565+
566+
data <- apply(data, 2, rank)
567+
568+
if (sum(z) == 1) {
569+
mt <- 0 # Addressing #168
570+
} else {
571+
treated <- data[z, ,drop = FALSE]
572+
nt <- nrow(treated)
573+
mt <- scale_addressing_ties(nt, cov(treated))
574+
mt <- mt * (sum(z) - 1) / (length(z) - 2)
575+
}
576+
577+
if (sum(!z) == 1) {
578+
mc <- 0 # Addressing #168
579+
} else {
580+
control <- data[!z, ,drop = FALSE]
581+
nc <- nrow(control)
582+
mc <- scale_addressing_ties(nc, cov(control))
583+
mc <- mc * (sum(!z) - 1) / (length(!z) - 2)
584+
}
585+
586+
cv <- mt + mc
587+
rm(mt, mc)
588+
589+
inv.scale.matrix <- safe_invert(cv)
590+
rm(cv)
591+
592+
return(mahalanobisHelper(data, index, inv.scale.matrix))
564593
}
565594

566595
#' @details \bold{First argument (\code{x}): \code{function}.} The passed function

R/utilities.R

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,30 @@ missing_x_msg <- function(x_str, data_str, ...) {
138138
paste(data_str, "$", x_str, sep=""),
139139
msg_tail)
140140
}
141+
142+
#' @importFrom stats var
143+
scale_addressing_ties <- function(n, cv) {
144+
vuntied <- var(1:n)
145+
rat <- sqrt(vuntied/diag(cv))
146+
if (length(rat) > 1) {
147+
diag_rat <- diag(rat)
148+
} else {
149+
diag_rat <- as.matrix(rat)
150+
}
151+
return(diag_rat %*% cv %*% diag_rat)
152+
}
153+
154+
safe_invert <- function(x) {
155+
inv.scale.matrix <- try(solve(x), silent = TRUE)
156+
157+
if (inherits(inv.scale.matrix,"try-error")) {
158+
dnx <- dimnames(x)
159+
s <- svd(x)
160+
nz <- (s$d > sqrt(.Machine$double.eps) * s$d[1])
161+
if (!any(nz)) stop("covariance has rank zero")
162+
163+
inv.scale.matrix <- s$v[, nz] %*% (t(s$u[, nz])/s$d[nz])
164+
dimnames(inv.scale.matrix) <- dnx[2:1]
165+
}
166+
return(inv.scale.matrix)
167+
}

man/match_on-methods.Rd

Lines changed: 8 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test.match_on.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ test_that("Issue 87: NA's in data => unmatchable, but retained, units in distanc
174174
expect_equivalent(f("mahalanobis"), expectedM)
175175
expect_equivalent(f("euclid"), expectedM)
176176
expect_equivalent(f("rank_mahal"), expectedM)
177+
expect_equivalent(f("pooled_cov"), expectedM)
177178

178179
cal1 <- caliper(match_on(z~x1, data=d), width=1e3)
179180
expect_equivalent(g(as.matrix(match_on(z ~ x1 + x2, data = d,

tests/testthat/test.rank.mahal.R

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ test_that("Fix for #128 (`compute_rank_mahalanobis` ignores index argument) hold
111111

112112
reference_rankmahal <- compute_smahal(z, X)
113113

114-
indices <- expand.grid(rownames(reference_rankmahal), colnames(reference_rankmahal))
114+
indices <- expand.grid(rownames(reference_rankmahal), colnames(reference_rankmahal))
115115
indices <- as.matrix(indices)
116116
expect_equivalent(optmatch:::compute_rank_mahalanobis(indices, X, as.logical(z)),
117117
reference_rankmahal[1L:numdists])
@@ -125,3 +125,42 @@ test_that("Fix for #128 (`compute_rank_mahalanobis` ignores index argument) hold
125125

126126

127127
})
128+
129+
is_scalar_multiple <- function(A, B) {
130+
# Ensure dimensions match
131+
if (!all(dim(A) == dim(B))) {
132+
return(FALSE)
133+
}
134+
135+
# Find positions where B is non-zero to avoid division by zero
136+
non_zero_positions <- B != 0
137+
138+
# Compute element-wise ratio where B != 0
139+
ratios <- A[non_zero_positions] / B[non_zero_positions]
140+
141+
# Check if all ratios are (approximately) equal
142+
return(all(abs(ratios - ratios[1]) < 1e-8))
143+
}
144+
145+
test_that("compute_pooled_cov_rank_mahalanobis results match ordinary Mahalanobis's", {
146+
## nr number of samples
147+
nr <- 10L
148+
z <- integer(nr)
149+
## two outcomes: 0 (from initialization), and 1 (assigned below randomly)
150+
z[sample(1:nr, nr / 2L)] <- 1L
151+
152+
## Goal: two groups with the same within-group variance and no rank ties
153+
df <- data.frame(z = z, X = integer(nr))
154+
df[df$z == 0, 'X'] <- seq(1, by=2, len = nr / 2) # odds
155+
df[df$z == 1, 'X'] <- seq(2, by=2, len = nr / 2) # evens
156+
157+
A <- match_on(z~., data=df, method="pooled_cov")
158+
B <- match_on(z~., data=df, method="mahalanobis")
159+
# Check if all ratios are (approximately) equal
160+
expect_true(is_scalar_multiple(A, B))
161+
162+
ez <- exactMatch(z~., data=df)
163+
A <- match_on(z~., data=df, method="pooled_cov", within=ez)
164+
B <- match_on(z~., data=df, method="mahalanobis", within=ez)
165+
expect_true(is_scalar_multiple(A, B))
166+
})

tests/testthat/test.utilities.R

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# Tests for utility functions
33
################################################################################
44

5+
library(MASS)
6+
57
context("Utility Functions")
68

79
test_that("toZ", {
@@ -60,3 +62,25 @@ test_that("#159 - toZ for labelled", {
6062
expect_true(TRUE) # avoiding empty test warning
6163
}
6264
})
65+
66+
test_that("scale_addressing_ties", {
67+
x <- cbind(sample(1:5, 10, replace=TRUE),
68+
sample(1:5, 10, replace=TRUE))
69+
y <- scale_addressing_ties(nrow(x), cov(x))
70+
dy <- diag(y)
71+
expect_equal(dy, rep(var(1:nrow(x)), length(dy)))
72+
})
73+
74+
test_that("safe_invert", {
75+
## full rank square matrix (note: A itself is inverted below, not symmetric_matrix)
76+
A <- matrix(runif(25), 5, 5)
77+
symmetric_matrix <- A %*% t(A)
78+
inv_A <- safe_invert(A)
79+
expect_equal(inv_A, solve(A))
80+
81+
## non-square 5x3 matrix (note: B itself is inverted below, not symmetric_matrix); solve() fails on it, so this exercises the SVD fallback
82+
B <- matrix(runif(15), 5, 3)
83+
symmetric_matrix <- B %*% t(B)
84+
inv_B <- safe_invert(B)
85+
expect_equal(inv_B, ginv(B))
86+
})

0 commit comments

Comments
 (0)