Skip to content

Commit

Permalink
Support distmat as lower triangular for distances with proxy loops
Browse files Browse the repository at this point in the history
  • Loading branch information
asardaes committed Jul 2, 2024
1 parent baeaefe commit 8c04727
Show file tree
Hide file tree
Showing 11 changed files with 246 additions and 24 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ Collate:
'DISTANCES-sdtw.R'
'GENERICS-cvi.R'
'S4-Distmat.R'
'S4-DistmatLowerTriangular.R'
'S4-PairTracker.R'
'S4-SparseDistmat.R'
'S4-tsclustFamily.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ S3method(as.data.frame,pairdist)
S3method(as.matrix,crossdist)
S3method(as.matrix,pairdist)
S3method(base::dim,Distmat)
S3method(base::dim,DistmatLowerTriangular)
S3method(base::dim,SparseDistmat)
S3method(cl_class_ids,TSClusters)
S3method(cl_membership,TSClusters)
Expand Down Expand Up @@ -105,6 +106,7 @@ importFrom(ggplot2,theme_bw)
importFrom(ggrepel,geom_label_repel)
importFrom(graphics,plot)
importFrom(methods,S3Part)
importFrom(methods,as)
importFrom(methods,callNextMethod)
importFrom(methods,initialize)
importFrom(methods,is)
Expand Down
24 changes: 16 additions & 8 deletions R/CLUSTERING-ddist2.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ parallel_symmetric <- function(d_desc, ids, x, distance, dots) {
#' @importFrom proxy dist
#' @importFrom proxy pr_DB
#'
ddist2 <- function(distance, control) {
ddist2 <- function(distance, control, lower_triangular_only = FALSE) {
# I need to re-register any custom distances in each parallel worker
dist_entry <- proxy::pr_DB$get_entry(distance)
symmetric <- isTRUE(control$symmetric)
Expand Down Expand Up @@ -175,20 +175,23 @@ ddist2 <- function(distance, control) {
return(ret(use_distmat(control$distmat, x, centroids)))
}

dots <- get_dots(dist_entry, x, centroids, ...)
dots <- get_dots(dist_entry, x, centroids, ..., lower_triangular_only = lower_triangular_only)

if (!dist_entry$loop) {
# CUSTOM LOOP, LET THEM HANDLE OPTIMIZATIONS
dm <- base::as.matrix(quoted_call(
dm <- quoted_call(
proxy::dist, x = x, y = centroids, method = distance, dots = dots
))
)

if (isTRUE(dots$pairwise)) {
dim(dm) <- NULL
return(ret(dm, class = "pairdist"))
}
else if (lower_triangular_only && inherits(dm, "dist")) {
return(ret(dm, class = "dist", Size = length(x)))
}
else {
return(ret(dm, class = "crossdist"))
return(ret(base::as.matrix(dm), class = "crossdist"))
}
}

Expand Down Expand Up @@ -237,11 +240,16 @@ ddist2 <- function(distance, control) {
}
else if (!multiple_workers) {
# WHOLE SYMMETRIC DISTMAT WITHOUT CUSTOM LOOP OR USING SEQUENTIAL proxy LOOP
dm <- base::as.matrix(quoted_call(
dm <- quoted_call(
proxy::dist, x = x, y = NULL, method = distance, dots = dots
))
)

return(ret(dm, class = "crossdist"))
if (lower_triangular_only && inherits(dm, "dist")) {
return(ret(dm, class = "dist", Size = length(x)))
}
else {
return(ret(base::as.matrix(dm), class = "crossdist"))
}
}
}

Expand Down
47 changes: 35 additions & 12 deletions R/CLUSTERING-tsclust.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,53 @@
# ==================================================================================================

# Get an appropriate distance matrix object for internal use with PAM/FCMdd centroids
#' @importFrom methods as
pam_distmat <- function(series, control, distance, cent_char, family, args, trace) {
distmat <- control$distmat
distmat_provided <- FALSE

if (!is.null(distmat)) {
if (nrow(distmat) != length(series) || ncol(distmat) != length(series))
stop("Dimensions of provided cross-distance matrix don't correspond ",
"to length of provided data")
# see S4-Distmat.R
if (!inherits(distmat, "Distmat")) distmat <- Distmat$new(distmat = distmat)
if (inherits(distmat, "dist")) {
n <- attr(distmat, "Size")
if (n != length(series))
stop("Dimensions of provided cross-distance matrix don't correspond ",
"to length of provided data")

# see S4-Distmat.R
if (!inherits(distmat, "Distmat"))
distmat <- DistmatLowerTriangular$new(distmat = distmat)
}
else {
if (nrow(distmat) != length(series) || ncol(distmat) != length(series))
stop("Dimensions of provided cross-distance matrix don't correspond ",
"to length of provided data")

# see S4-Distmat.R
if (!inherits(distmat, "Distmat")) distmat <- Distmat$new(distmat = distmat)
}

distmat_provided <- TRUE
if (trace) cat("\n\tDistance matrix provided...\n\n") # nocov
}
else if (isTRUE(control$pam.precompute) || cent_char == "fcmdd") {
if (distance == "dtw_lb")
warning("Using dtw_lb with control$pam.precompute = TRUE is not advised.") # nocov
if (trace) cat("\n\tPrecomputing distance matrix...\n\n")
# see S4-Distmat.R
distmat <- Distmat$new(distmat = quoted_call(
family@dist,
x = series,
centroids = NULL,
dots = args$dist
))

if (control$symmetric) {
distfun <- ddist2(distance, control, lower_triangular_only = TRUE)
distmat <- methods::as(quoted_call(distfun, x = series, centroids = NULL, dots = args$dist),
"Distmat")
}
else {
# see S4-Distmat.R
distmat <- Distmat$new(distmat = quoted_call(
family@dist,
x = series,
centroids = NULL,
dots = args$dist
))
}
}
else {
if (isTRUE(control$pam.sparse) && distance != "dtw_lb") {
Expand Down
129 changes: 129 additions & 0 deletions R/S4-DistmatLowerTriangular.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Map a 1-based (row, column) position to a position in a column-major vector.
#
# i: row index, must satisfy 0 < i <= n.
# j: column index, must satisfy 0 < j <= i.
# n: number of rows/columns of the square matrix.
# diagonal: if TRUE no entries are skipped (the result is the plain column-major index
#   i + (j - 1) * n); if FALSE, the 1 + 2 + ... + j entries that a strict lower triangular
#   omits up to and including column j are subtracted.
#
# Returns the computed 1-based index.
lower_triangular_index <- function(i, j, n, diagonal) {
    stopifnot(i > 0L, i <= n, j > 0L, j <= i)

    # switch to 0-based offsets
    row_offset <- i - 1L
    col_offset <- j - 1L

    # entries on or above the diagonal in columns 0..col_offset
    skipped <- if (diagonal) 0L else sum(seq_len(col_offset + 1L))

    row_offset + col_offset * n - skipped + 1L
}

#' Distance matrix's lower triangular
#'
#' Reference class that is used internally for PAM centroids when `pam.precompute = TRUE` and
#' `pam.sparse = FALSE`. It contains [Distmat-class] and only accepts `dist` objects.
#'
#' @include S4-Distmat.R
#' @importFrom methods setRefClass
#'
#' @field distmat The lower triangular, given as an object that inherits from `dist`.
#'
#' @keywords internal
#'
DistmatLowerTriangular <- methods::setRefClass(
    "DistmatLowerTriangular",
    contains = "Distmat",
    fields = list(distmat = "ANY"),
    methods = list(
        initialize = function(..., distmat) {
            "Initialization based on needed parameters"

            # guard clauses: a 'dist' object is mandatory for this class
            if (missing(distmat)) stop("distmat must be provided for this class.")
            if (!inherits(distmat, "dist")) stop("distmat must be a 'dist' object.")

            # everything else is delegated to the parent class' initializer
            callSuper(..., distmat = distmat)
            # return
            invisible(NULL)
        }
    )
)

#' Generics for `DistmatLowerTriangular`
#'
#' Generics with methods for [DistmatLowerTriangular-class].
#'
#' @name DistmatLowerTriangular-generics
#' @rdname DistmatLowerTriangular-generics
#' @keywords internal
#' @importFrom methods setMethod
#'
NULL

# nocov start
#' @rdname DistmatLowerTriangular-generics
#' @aliases show,DistmatLowerTriangular
#' @importFrom methods show
#'
#' @param object A [DistmatLowerTriangular-class] object.
#'
setMethod("show", "DistmatLowerTriangular", function(object) {
    # printing is delegated to whatever is stored in the distmat field
    methods::show(object$distmat)
})
# nocov end

#' @rdname DistmatLowerTriangular-generics
#' @aliases [,DistmatLowerTriangular,ANY,ANY,ANY
#'
#' @param x A [DistmatLowerTriangular-class] object.
#' @param i Row indices, or a two-column matrix with (row, column) pairs if `j` is missing.
#' @param j Column indices.
#' @param ... Ignored.
#'
setMethod(`[`, "DistmatLowerTriangular", function(x, i, j, ...) {
    if (missing(j)) {
        # x[m]: each row of m is one (row, column) pair, the result is a plain vector
        stopifnot(inherits(i, "matrix"), ncol(i) == 2L)
        j <- i[, 2L]
        i <- i[, 1L]
        drop <- TRUE
    }
    else {
        # x[i, j]: every combination of i and j, the result is a length(i) x length(j) matrix
        out_dim <- c(length(i), length(j))
        out_dimnames <- list(i, j)
        combinations <- expand.grid(i = i, j = j)
        i <- combinations$i
        j <- combinations$j
        drop <- FALSE
    }

    n <- attr(x$distmat, "Size")

    # A 'dist' object always stores only the n * (n - 1) / 2 entries strictly below the diagonal;
    # its "Diag" attribute merely controls printing. Hence the diagonal is never looked up in the
    # underlying vector (it is 0 by definition), and the index is always computed for a strict
    # lower triangular.
    entries <- mapply(i, j, FUN = function(i, j) {
        if (i == j) {
            0
        }
        else if (j > i) {
            # upper triangular request: use the symmetric entry
            x$distmat[lower_triangular_index(j, i, n, FALSE)]
        }
        else {
            x$distmat[lower_triangular_index(i, j, n, FALSE)]
        }
    })

    # mapply returns an empty list for empty inputs, keep the result numeric
    if (length(entries) == 0L) entries <- numeric(0L)

    if (drop) {
        entries
    }
    else {
        dim(entries) <- out_dim
        dimnames(entries) <- out_dimnames
        entries
    }
})

# nocov start
#' @exportS3Method base::dim
dim.DistmatLowerTriangular <- function(x) {
    # the represented matrix is square, its extent is the Size attribute of the stored object
    size <- attr(x$distmat, "Size")
    c(size, size)
}
# nocov end

# Register the S3 classes so that they can appear in S4 signatures
methods::setOldClass("dist")
methods::setOldClass("crossdist")

# Coercions to Distmat: matrix-like inputs are wrapped in a Distmat object,
# whereas 'dist' objects are wrapped in a DistmatLowerTriangular object
methods::setAs("matrix", "Distmat", function(from) {
    Distmat$new(distmat = from)
})
methods::setAs("crossdist", "Distmat", function(from) {
    Distmat$new(distmat = from)
})
methods::setAs("dist", "Distmat", function(from) {
    DistmatLowerTriangular$new(distmat = from)
})
12 changes: 8 additions & 4 deletions R/S4-tsclustFamily.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,14 @@ setMethod("initialize", "tsclustFamily",
if (!missing(allcent)) {
if (is.character(allcent)) {
if (allcent %in% c("pam", "fcmdd")) {
if (!is.null(control$distmat) && !inherits(control$distmat, "Distmat"))
control$distmat <- Distmat$new( # see S4-Distmat.R
distmat = base::as.matrix(control$distmat)
)
if (!is.null(control$distmat) && !inherits(control$distmat, "Distmat")) {
control$distmat <- if (inherits(control$distmat, "dist")) {
DistmatLowerTriangular$new(distmat = control$distmat)
}
else {
Distmat$new(distmat = control$distmat)
}
}
}
allcent <- all_cent2(allcent, control)
}
Expand Down
24 changes: 24 additions & 0 deletions man/DistmatLowerTriangular-class.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions man/DistmatLowerTriangular-generics.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions tests/testthat/integration/custom-dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ test_that("Calling tsclust after registering a custom distance works as expected
## just for expect below
pc_ndtw@control$symmetric <- TRUE
pc_ndtw@call <- pc_ndtw_sym@call <- as.call(list("foo", bar = 1))
pc_ndtw@control$distmat <- pc_ndtw@distmat <- as.matrix(pc_ndtw@distmat)
pc_ndtw_sym@control$distmat <- pc_ndtw_sym@distmat <- as.matrix(pc_ndtw_sym@distmat)

expect_identical(pc_ndtw, pc_ndtw_sym)

Expand Down
Binary file modified tests/testthat/rds/pc_ndtw.rds
Binary file not shown.
Binary file modified tests/testthat/rds/pc_ndtw_sym.rds
Binary file not shown.

0 comments on commit 8c04727

Please sign in to comment.