Skip to content

Commit

Permalink
Rework logic to return dist/crossdist structures
Browse files Browse the repository at this point in the history
  • Loading branch information
asardaes committed Jul 2, 2024
1 parent 0eaed23 commit f5da272
Show file tree
Hide file tree
Showing 39 changed files with 202 additions and 203 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Changelog

## Version 5.6.0
## Version 6.0.0
* Update Makevars for ARM version of Windows.
* Sanitize internal usage of `do.call` to avoid huge backtraces.
* Support lower triangular `distmat` objects for symmetric distances (#77).
Expand Down
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ S3method(as.data.frame,crossdist)
S3method(as.data.frame,pairdist)
S3method(as.matrix,crossdist)
S3method(as.matrix,pairdist)
S3method(base::as.matrix,distdiag)
S3method(base::dim,Distmat)
S3method(base::dim,DistmatLowerTriangular)
S3method(base::dim,SparseDistmat)
Expand Down
24 changes: 11 additions & 13 deletions R/CENTROIDS-pam.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#' Extract the medoid time series based on a distance measure.
#'
#' @export
#' @importFrom methods as
#' @importFrom rlang exprs
#' @importFrom Matrix rowSums
#'
Expand Down Expand Up @@ -43,24 +44,21 @@ pam_cent <- function(series, distance, ids = seq_along(series), distmat = NULL,
if (missing(distance))
distance <- attr(distmat, "method")

args <- rlang::exprs(
distmat = distmat,
series = series,
dist_args = dots,
distance = distance,
control = partitional_control(),
error.check = error.check
)

if (is.null(distmat)) {
if (is.null(distance))
stop("If 'distmat' is missing, 'distance' must be provided.")

args$distmat <- NULL
distmat <- Distmat$new(
series = series,
dist_args = dots,
distance = distance,
control = partitional_control(),
error.check = error.check
)
}
else {
distmat <- methods::as(distmat, "Distmat")
}

# S4-Distmat.R
distmat <- do.call(Distmat$new, args)
}

d <- distmat[ids, ids, drop = FALSE]
Expand Down
12 changes: 6 additions & 6 deletions R/CLUSTERING-ddist2.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ parallel_symmetric <- function(d_desc, ids, x, distance, dots) {
#' @importFrom proxy dist
#' @importFrom proxy pr_DB
#'
ddist2 <- function(distance, control, lower_triangular_only = FALSE) {
ddist2 <- function(distance, control) {
# I need to re-register any custom distances in each parallel worker
dist_entry <- proxy::pr_DB$get_entry(distance)
symmetric <- isTRUE(control$symmetric)
Expand Down Expand Up @@ -175,7 +175,7 @@ ddist2 <- function(distance, control, lower_triangular_only = FALSE) {
return(ret(use_distmat(control$distmat, x, centroids)))
}

dots <- get_dots(dist_entry, x, centroids, ..., lower_triangular_only = lower_triangular_only)
dots <- get_dots(dist_entry, x, centroids, ...)

if (!dist_entry$loop) {
# CUSTOM LOOP, LET THEM HANDLE OPTIMIZATIONS
Expand All @@ -187,8 +187,8 @@ ddist2 <- function(distance, control, lower_triangular_only = FALSE) {
dim(dm) <- NULL
return(ret(dm, class = "pairdist"))
}
else if (lower_triangular_only && inherits(dm, "dist")) {
return(ret(dm, class = "dist"))
else if (inherits(dm, "dist")) {
return(ret(dm))
}
else {
return(ret(base::as.matrix(dm), class = "crossdist"))
Expand Down Expand Up @@ -244,8 +244,8 @@ ddist2 <- function(distance, control, lower_triangular_only = FALSE) {
proxy::dist, x = x, y = NULL, method = distance, dots = dots
)

if (lower_triangular_only && inherits(dm, "dist")) {
return(ret(dm, class = "dist"))
if (inherits(dm, "dist")) {
return(ret(dm))
}
else {
return(ret(base::as.matrix(dm), class = "crossdist"))
Expand Down
42 changes: 13 additions & 29 deletions R/CLUSTERING-tsclust.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,25 @@ pam_distmat <- function(series, control, distance, cent_char, family, args, trac
distmat_provided <- FALSE

if (!is.null(distmat)) {
if (inherits(distmat, "dist")) {
if (inherits(distmat, "Distmat")) {
stop("Can this happen?")
}
else if (inherits(distmat, "dist")) {
n <- attr(distmat, "Size")
if (n != length(series))
if (is.null(n) || n != length(series))
stop("Dimensions of provided cross-distance matrix don't correspond ",
"to length of provided data")

# see S4-Distmat.R
if (!inherits(distmat, "Distmat"))
distmat <- DistmatLowerTriangular$new(distmat = distmat)
distmat <- DistmatLowerTriangular$new(distmat = distmat)
}
else {
if (nrow(distmat) != length(series) || ncol(distmat) != length(series))
stop("Dimensions of provided cross-distance matrix don't correspond ",
"to length of provided data")

# see S4-Distmat.R
if (!inherits(distmat, "Distmat")) distmat <- Distmat$new(distmat = distmat)
distmat <- Distmat$new(distmat = distmat)
}

distmat_provided <- TRUE
Expand All @@ -36,20 +38,10 @@ pam_distmat <- function(series, control, distance, cent_char, family, args, trac
warning("Using dtw_lb with control$pam.precompute = TRUE is not advised.") # nocov
if (trace) cat("\n\tPrecomputing distance matrix...\n\n")

if (control$symmetric) {
distfun <- ddist2(distance, control, lower_triangular_only = cent_char != "fcmdd")
distmat <- methods::as(quoted_call(distfun, x = series, centroids = NULL, dots = args$dist),
"Distmat")
}
else {
# see S4-Distmat.R
distmat <- Distmat$new(distmat = quoted_call(
family@dist,
x = series,
centroids = NULL,
dots = args$dist
))
}
distfun <- if (distance == "sdtw") sdtw_wrapper else ddist2(distance, control)
centroids <- if (cent_char == "fcmdd") series else NULL
distmat <- methods::as(quoted_call(distfun, x = series, centroids = centroids, dots = args$dist),
"Distmat")
}
else {
if (isTRUE(control$pam.sparse) && distance != "dtw_lb") {
Expand Down Expand Up @@ -427,9 +419,6 @@ tsclust <- function(series = NULL, type = "partitional", k = 2L, ...,

# precompute distance matrix?
if (cent_char %in% c("pam", "fcmdd")) {
if (distance == "sdtw" && control$pam.precompute) {
args$dist$diag <- TRUE
}
dm <- pam_distmat(series, control, distance, cent_char, family, args, trace)
distmat <- dm$distmat
distmat_provided <- dm$distmat_provided
Expand Down Expand Up @@ -626,13 +615,8 @@ tsclust <- function(series = NULL, type = "partitional", k = 2L, ...,
if (trace) cat("\nCalculating distance matrix...\n")
# Take advantage of the function I defined for the partitional methods,
# which can do calculations in parallel if appropriate
distfun <- ddist2(distance = distance, control = control, control$symmetric)
dist_dots <- if ("sdtw" %in% proxy::pr_DB$get_entry(distance)$names) {
c(args$dist, list(diagonal = FALSE, .internal_ = TRUE))
} else {
args$dist
}
distmat <- quoted_call(distfun, x = series, centroids = NULL, dots = dist_dots)
distfun <- ddist2(distance = distance, control = control)
distmat <- quoted_call(distfun, x = series, centroids = NULL, dots = args$dist)
}

# --------------------------------------------------------------------------------------
Expand Down
11 changes: 4 additions & 7 deletions R/DISTANCES-dtw-basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ dtw_basic <- function(x, y, window.size = NULL, norm = "L1",
dtw_basic_proxy <- function(x, y = NULL, window.size = NULL, norm = "L1",
step.pattern = dtw::symmetric2,
normalize = FALSE, sqrt.dist = TRUE, ...,
error.check = TRUE, pairwise = FALSE,
lower_triangular_only = FALSE, diagonal = TRUE)
error.check = TRUE, pairwise = FALSE)
{
x <- tslist(x)
if (error.check) check_consistency(x, "vltslist")
Expand Down Expand Up @@ -177,19 +176,17 @@ dtw_basic_proxy <- function(x, y = NULL, window.size = NULL, norm = "L1",
dim(D) <- NULL
class(D) <- "pairdist"
}
else if (lower_triangular_only) {
else if (symmetric) {
dim(D) <- NULL
class(D) <- c("distdiag", "dist")
class(D) <- "dist"
attr(D, "Size") <- length(x)
attr(D, "Diag") <- diagonal
attr(D, "Upper") <- FALSE
attr(D, "Labels") <- names(x)
}
else {
dimnames(D) <- dim_names
class(D) <- "crossdist"
}

attr(D, "method") <- "DTW_BASIC"
# return
D
}
32 changes: 15 additions & 17 deletions R/DISTANCES-gak.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,7 @@ gak <- GAK
# ==================================================================================================

gak_proxy <- function(x, y = NULL, ..., sigma = NULL, window.size = NULL, normalize = TRUE,
error.check = TRUE, pairwise = FALSE, .internal_ = FALSE,
lower_triangular_only = FALSE, diagonal = TRUE)
error.check = TRUE, pairwise = FALSE, .internal_ = FALSE)
{
# normalization will be done manually to avoid multiple calculations of gak_x and gak_y
if (!.internal_ && !normalize) { # nocov start
Expand All @@ -171,6 +170,7 @@ gak_proxy <- function(x, y = NULL, ..., sigma = NULL, window.size = NULL, normal

x <- tslist(x)
if (error.check) check_consistency(x, "vltslist")

if (is.null(y)) {
symmetric <- normalize
y <- x
Expand All @@ -180,6 +180,7 @@ gak_proxy <- function(x, y = NULL, ..., sigma = NULL, window.size = NULL, normal
y <- tslist(y)
if (error.check) check_consistency(y, "vltslist")
}

if (is.null(sigma))
sigma <- estimate_sigma(x, y, TRUE)
else if (sigma <= 0)
Expand Down Expand Up @@ -230,26 +231,24 @@ gak_proxy <- function(x, y = NULL, ..., sigma = NULL, window.size = NULL, normal
if (normalize) D <- 1 - exp(D - 0.5 * (gak_x + gak_y))
class(D) <- "pairdist"
}
else if (lower_triangular_only) {
else if (symmetric) {
dim(D) <- NULL
if (normalize) {
j_upper <- if (diagonal) length(x) else length(x) - 1L
i_lower <- if (diagonal) 0L else 1L
k <- 1L
for (j in 1L:j_upper) {
for (i in (j + i_lower):length(x)) {
if (i != j) {
D[k] <- 1 - exp(D[k] - (gak_x[i] + gak_x[j]) / 2)
}
k <- k + 1L

# normalize
j_upper <- length(x) - 1L
i_lower <- 1L
k <- 1L
for (j in 1L:j_upper) {
for (i in (j + i_lower):length(x)) {
if (i != j) {
D[k] <- 1 - exp(D[k] - (gak_x[i] + gak_x[j]) / 2)
}
k <- k + 1L
}
}

class(D) <- c("distdiag", "dist")
class(D) <- "dist"
attr(D, "Size") <- length(x)
attr(D, "Diag") <- diagonal
attr(D, "Upper") <- FALSE
attr(D, "Labels") <- names(x)
}
else {
Expand All @@ -258,7 +257,6 @@ gak_proxy <- function(x, y = NULL, ..., sigma = NULL, window.size = NULL, normal
class(D) <- "crossdist"
}

if (!pairwise && symmetric && !lower_triangular_only) diag(D) <- 0

attr(D, "method") <- "GAK"
attr(D, "sigma") <- sigma
Expand Down
4 changes: 2 additions & 2 deletions R/DISTANCES-lb-improved.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ lb_improved_proxy <- function(x, y = NULL, window.size = NULL, norm = "L1", ...,
if (error.check) check_consistency(c(x,y), "tslist")
if (is_multivariate(c(x,y))) stop("lb_improved does not support multivariate series.") # nocov end

lower_triangular_only <- symmetric <- FALSE
fill_type <- mat_type <- dim_names <- NULL # avoid warning about undefined globals
symmetric <- FALSE
eval(prepare_expr) # UTILS-expressions.R

# adjust parameters for this distance
Expand Down Expand Up @@ -169,7 +169,7 @@ lb_improved_proxy <- function(x, y = NULL, window.size = NULL, norm = "L1", ...,
else
.Call(C_force_lb_symmetry, D, PACKAGE = "dtwclust")
}

attr(D, "method") <- "LB_Improved"
# return
D
}
4 changes: 2 additions & 2 deletions R/DISTANCES-lb-keogh.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ lb_keogh_proxy <- function(x, y = NULL, window.size = NULL, norm = "L1", ...,
if (error.check) check_consistency(c(x,y), "tslist")
if (is_multivariate(c(x,y))) stop("lb_keogh does not support multivariate series.") # nocov end

lower_triangular_only <- symmetric <- FALSE
fill_type <- mat_type <- dim_names <- NULL # avoid warning about undefined globals
symmetric <- FALSE
eval(prepare_expr) # UTILS-expressions.R

# adjust parameters for this distance
Expand Down Expand Up @@ -142,7 +142,7 @@ lb_keogh_proxy <- function(x, y = NULL, window.size = NULL, norm = "L1", ...,
else
.Call(C_force_lb_symmetry, D, PACKAGE = "dtwclust")
}

attr(D, "method") <- "LB_Keogh"
# return
D
}
14 changes: 5 additions & 9 deletions R/DISTANCES-sbd.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,7 @@ sbd <- SBD
#' @importFrom stats fft
#' @importFrom stats nextn
#'
sbd_proxy <- function(x, y = NULL, znorm = FALSE, ...,
error.check = TRUE, pairwise = FALSE,
lower_triangular_only = FALSE, diagonal = TRUE)
{
sbd_proxy <- function(x, y = NULL, znorm = FALSE, ..., error.check = TRUE, pairwise = FALSE) {
x <- tslist(x)

if (error.check) check_consistency(x, "vltslist")
Expand Down Expand Up @@ -165,6 +162,7 @@ sbd_proxy <- function(x, y = NULL, znorm = FALSE, ...,
}

if (is_multivariate(c(x,y))) stop("SBD does not support multivariate series.") # nocov

fill_type <- mat_type <- dim_names <- NULL # avoid warning about undefined globals
eval(prepare_expr) # UTILS-expressions.R

Expand All @@ -185,19 +183,17 @@ sbd_proxy <- function(x, y = NULL, znorm = FALSE, ...,
dim(D) <- NULL
class(D) <- "pairdist"
}
else if (lower_triangular_only) {
else if (symmetric) {
dim(D) <- NULL
class(D) <- c("distdiag", "dist")
class(D) <- "dist"
attr(D, "Size") <- length(x)
attr(D, "Diag") <- diagonal
attr(D, "Upper") <- FALSE
attr(D, "Labels") <- names(x)
}
else {
dimnames(D) <- dim_names
class(D) <- "crossdist"
}

attr(D, "method") <- "SBD"
# return
D
}
Loading

0 comments on commit f5da272

Please sign in to comment.