Skip to content

Commit

Permalink
Fix distfun for hierarchical TS clusters
Browse files Browse the repository at this point in the history
  • Loading branch information
asardaes committed Jul 1, 2024
1 parent a971b91 commit f2e8f87
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 19 deletions.
19 changes: 12 additions & 7 deletions R/CLUSTERING-tsclust.R
Original file line number Diff line number Diff line change
Expand Up @@ -606,10 +606,6 @@ tsclust <- function(series = NULL, type = "partitional", k = 2L, ...,
# Calculate distance matrix
# --------------------------------------------------------------------------------------

# Take advantage of the function I defined for the partitional methods
# Which can do calculations in parallel if appropriate
distfun <- ddist2(distance = distance, control = control, control$symmetric)

if (!is.null(distmat)) {
if (inherits(distmat, "matrix") && nrow(distmat) != length(series) || ncol(distmat) != length(series))
stop("Dimensions of provided cross-distance matrix don't correspond to ",
Expand All @@ -623,15 +619,23 @@ tsclust <- function(series = NULL, type = "partitional", k = 2L, ...,
}
else {
if (trace) cat("\nCalculating distance matrix...\n")
distmat <- quoted_call(distfun, x = series, centroids = NULL, dots = args$dist)
# Take advantage of the function I defined for the partitional methods
# Which can do calculations in parallel if appropriate
distfun <- ddist2(distance = distance, control = control, control$symmetric)
dist_dots <- if ("sdtw" %in% proxy::pr_DB$get_entry(distance)$names) {
c(args$dist, list(diagonal = FALSE))
} else {
args$dist
}
distmat <- quoted_call(distfun, x = series, centroids = NULL, dots = dist_dots)
}

# --------------------------------------------------------------------------------------
# Cluster
# --------------------------------------------------------------------------------------

if (trace) cat("Performing hierarchical clustering...\n")
if (inherits(distmat, "matrix") && !base::isSymmetric(base::as.matrix(distmat)))
if (!inherits(distmat, "dist") && !base::isSymmetric(base::as.matrix(distmat)))
warning("Distance matrix is not symmetric, ",
"and hierarchical clustering assumes it is ",
"(it ignores the upper triangular).")
Expand Down Expand Up @@ -686,7 +690,8 @@ tsclust <- function(series = NULL, type = "partitional", k = 2L, ...,
stats::as.hclust(hc),
call = MYCALL,
family = methods::new("tsclustFamily",
dist = distfun,
dist = ddist2(distance = distance,
control = control),
allcent = allcent,
preproc = preproc),
control = control,
Expand Down
6 changes: 3 additions & 3 deletions R/DISTANCES-sdtw.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ sdtw <- function(x, y, gamma = 0.01, ..., error.check = TRUE)
# ==================================================================================================

sdtw_proxy <- function(x, y = NULL, gamma = 0.01, ...,
error.check = TRUE, pairwise = FALSE, lower_triangular_only = FALSE)
error.check = TRUE, pairwise = FALSE, lower_triangular_only = FALSE,
diagonal = TRUE)
{
x <- tslist(x)
if (error.check) check_consistency(x, "vltslist")
Expand All @@ -58,7 +59,6 @@ sdtw_proxy <- function(x, y = NULL, gamma = 0.01, ...,
}

fill_type <- mat_type <- dim_names <- NULL # avoid warning about undefined globals
diagonal <- TRUE
eval(prepare_expr) # UTILS-expressions.R

# adjust parameters for this distance
Expand All @@ -85,7 +85,7 @@ sdtw_proxy <- function(x, y = NULL, gamma = 0.01, ...,
dim(D) <- NULL
class(D) <- c("distdiag", "dist")
attr(D, "Size") <- length(x)
attr(D, "Diag") <- TRUE
attr(D, "Diag") <- diagonal
attr(D, "Upper") <- FALSE
attr(D, "Labels") <- names(x)
}
Expand Down
8 changes: 2 additions & 6 deletions R/S4-tsclustFamily.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ f_cluster <- function(distmat, m) {
# Custom initialize
# ==================================================================================================

#' @importFrom methods as
#' @importFrom methods callNextMethod
#' @importFrom methods initialize
#' @importFrom methods setMethod
Expand All @@ -144,12 +145,7 @@ setMethod("initialize", "tsclustFamily",
if (is.character(allcent)) {
if (allcent %in% c("pam", "fcmdd")) {
if (!is.null(control$distmat) && !inherits(control$distmat, "Distmat")) {
control$distmat <- if (inherits(control$distmat, "dist")) {
DistmatLowerTriangular$new(distmat = control$distmat)
}
else {
Distmat$new(distmat = control$distmat)
}
control$distmat <- methods::as(control$distmat, "Distmat")
}
}
allcent <- all_cent2(allcent, control)
Expand Down
6 changes: 4 additions & 2 deletions src/distmat/fillers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ class SymmetricFillWorker : public ParallelWorker {
const int grain,
const int nrows)
: ParallelWorker(grain, 10, 1000)
, distmat_(distmat)
, dist_calculator_(dist_calculator)
, distmat_(distmat)
, nrows_(nrows)
{ }

Expand Down Expand Up @@ -233,6 +233,9 @@ class SymmetricFillWorker : public ParallelWorker {
mutex_.unlock();
}

private:
const std::shared_ptr<DistanceCalculator> dist_calculator_;

protected:
std::shared_ptr<Distmat> distmat_;

Expand All @@ -242,7 +245,6 @@ class SymmetricFillWorker : public ParallelWorker {
}

private:
const std::shared_ptr<DistanceCalculator> dist_calculator_;
const id_t nrows_;
};

Expand Down
Binary file modified tests/testthat/rds/hc_all.rds
Binary file not shown.
Binary file modified tests/testthat/rds/hc_cent.rds
Binary file not shown.
Binary file modified tests/testthat/rds/hc_cent2.rds
Binary file not shown.
Binary file modified tests/testthat/rds/hc_diana.rds
Binary file not shown.
3 changes: 2 additions & 1 deletion tests/testthat/system/hierarchical.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ test_that("Hierarchical clustering works as expected.", {

## ---------------------------------------------------------- with provided distmat
id_avg <- which(sapply(hc_all, slot, "method") == "average")
distmat <- hc_all[[1L]]@distmat
distmat <- as.matrix(hc_all[[1L]]@distmat)
attr(distmat, "method") <- attr(hc_all[[1L]]@distmat, "method")
expect_output(
hc_avg <- tsclust(data, type = "hierarchical", k = 20L,
distance = "sbd", trace = TRUE,
Expand Down

0 comments on commit f2e8f87

Please sign in to comment.