Skip to content

Commit

Permalink
Support proxy's diag parameter where appropriate
Browse files Browse the repository at this point in the history
  • Loading branch information
asardaes committed Jul 1, 2024
1 parent 9b007ef commit 6d47400
Show file tree
Hide file tree
Showing 13 changed files with 69 additions and 46 deletions.
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ importFrom(rlang,"!!!")
importFrom(rlang,.data)
importFrom(rlang,as_environment)
importFrom(rlang,as_string)
importFrom(rlang,caller_env)
importFrom(rlang,enexpr)
importFrom(rlang,enexprs)
importFrom(rlang,env_bind)
Expand Down
7 changes: 6 additions & 1 deletion R/CLUSTERING-tsclust.R
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,9 @@ tsclust <- function(series = NULL, type = "partitional", k = 2L, ...,

if (!inherits(control, "PtCtrl") && !inherits(control, "FzCtrl"))
stop("Invalid control provided") # nocov

nrep <- if (is.null(control$nrep)) 1L else control$nrep

if (!is.character(centroid) || !(cent_char %in% c("pam", "fcmdd")))
control$distmat <- NULL

Expand All @@ -425,6 +427,9 @@ tsclust <- function(series = NULL, type = "partitional", k = 2L, ...,

# precompute distance matrix?
if (cent_char %in% c("pam", "fcmdd")) {
if (distance == "sdtw" && control$pam.precompute) {
args$dist$diag <- TRUE
}
dm <- pam_distmat(series, control, distance, cent_char, family, args, trace)
distmat <- dm$distmat
distmat_provided <- dm$distmat_provided
Expand Down Expand Up @@ -623,7 +628,7 @@ tsclust <- function(series = NULL, type = "partitional", k = 2L, ...,
# Which can do calculations in parallel if appropriate
distfun <- ddist2(distance = distance, control = control, control$symmetric)
dist_dots <- if ("sdtw" %in% proxy::pr_DB$get_entry(distance)$names) {
c(args$dist, list(diagonal = FALSE))
c(args$dist, list(diagonal = FALSE, .internal_ = TRUE))
} else {
args$dist
}
Expand Down
8 changes: 4 additions & 4 deletions R/DISTANCES-dtw-basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ dtw_basic <- function(x, y, window.size = NULL, norm = "L1",
dtw_basic_proxy <- function(x, y = NULL, window.size = NULL, norm = "L1",
step.pattern = dtw::symmetric2,
normalize = FALSE, sqrt.dist = TRUE, ...,
error.check = TRUE, pairwise = FALSE, lower_triangular_only = FALSE)
error.check = TRUE, pairwise = FALSE,
lower_triangular_only = FALSE, diagonal = TRUE)
{
x <- tslist(x)
if (error.check) check_consistency(x, "vltslist")
Expand All @@ -130,7 +131,6 @@ dtw_basic_proxy <- function(x, y = NULL, window.size = NULL, norm = "L1",
}

fill_type <- mat_type <- dim_names <- NULL # avoid warning about undefined globals
diagonal <- FALSE
eval(prepare_expr) # UTILS-expressions.R

# adjust parameters for this distance
Expand Down Expand Up @@ -179,9 +179,9 @@ dtw_basic_proxy <- function(x, y = NULL, window.size = NULL, norm = "L1",
}
else if (lower_triangular_only) {
dim(D) <- NULL
class(D) <- "dist"
class(D) <- c("distdiag", "dist")
attr(D, "Size") <- length(x)
attr(D, "Diag") <- FALSE
attr(D, "Diag") <- diagonal
attr(D, "Upper") <- FALSE
attr(D, "Labels") <- names(x)
}
Expand Down
17 changes: 10 additions & 7 deletions R/DISTANCES-gak.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ gak <- GAK

gak_proxy <- function(x, y = NULL, ..., sigma = NULL, window.size = NULL, normalize = TRUE,
error.check = TRUE, pairwise = FALSE, .internal_ = FALSE,
lower_triangular_only = FALSE)
lower_triangular_only = FALSE, diagonal = TRUE)
{
# normalization will be done manually to avoid multiple calculations of gak_x and gak_y
if (!.internal_ && !normalize) { # nocov start
Expand All @@ -186,7 +186,6 @@ gak_proxy <- function(x, y = NULL, ..., sigma = NULL, window.size = NULL, normal
stop("Parameter 'sigma' must be positive.")

fill_type <- mat_type <- dim_names <- NULL # avoid warning about undefined globals
diagonal <- FALSE
eval(prepare_expr) # UTILS-expressions.R

# adjust parameters for this distance
Expand Down Expand Up @@ -234,18 +233,22 @@ gak_proxy <- function(x, y = NULL, ..., sigma = NULL, window.size = NULL, normal
else if (lower_triangular_only) {
dim(D) <- NULL
if (normalize) {
j_upper <- if (diagonal) length(x) else length(x) - 1L
i_lower <- if (diagonal) 0L else 1L
k <- 1L
for (j in 1L:(length(x) - 1L)) {
for (i in (j+1L):length(x)) {
D[k] <- 1 - exp(D[k] - (gak_x[i] + gak_x[j]) / 2)
for (j in 1L:j_upper) {
for (i in (j + i_lower):length(x)) {
if (i != j) {
D[k] <- 1 - exp(D[k] - (gak_x[i] + gak_x[j]) / 2)
}
k <- k + 1L
}
}
}

class(D) <- "dist"
class(D) <- c("distdiag", "dist")
attr(D, "Size") <- length(x)
attr(D, "Diag") <- FALSE
attr(D, "Diag") <- diagonal
attr(D, "Upper") <- FALSE
attr(D, "Labels") <- names(x)
}
Expand Down
8 changes: 4 additions & 4 deletions R/DISTANCES-sbd.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ sbd <- SBD
#' @importFrom stats nextn
#'
sbd_proxy <- function(x, y = NULL, znorm = FALSE, ...,
error.check = TRUE, pairwise = FALSE, lower_triangular_only = FALSE)
error.check = TRUE, pairwise = FALSE,
lower_triangular_only = FALSE, diagonal = TRUE)
{
x <- tslist(x)

Expand Down Expand Up @@ -165,7 +166,6 @@ sbd_proxy <- function(x, y = NULL, znorm = FALSE, ...,

if (is_multivariate(c(x,y))) stop("SBD does not support multivariate series.") # nocov
fill_type <- mat_type <- dim_names <- NULL # avoid warning about undefined globals
diagonal <- FALSE
eval(prepare_expr) # UTILS-expressions.R

# calculate distance matrix
Expand All @@ -187,9 +187,9 @@ sbd_proxy <- function(x, y = NULL, znorm = FALSE, ...,
}
else if (lower_triangular_only) {
dim(D) <- NULL
class(D) <- "dist"
class(D) <- c("distdiag", "dist")
attr(D, "Size") <- length(x)
attr(D, "Diag") <- FALSE
attr(D, "Diag") <- diagonal
attr(D, "Upper") <- FALSE
attr(D, "Labels") <- names(x)
}
Expand Down
10 changes: 8 additions & 2 deletions R/DISTANCES-sdtw.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,16 @@ sdtw <- function(x, y, gamma = 0.01, ..., error.check = TRUE)

sdtw_proxy <- function(x, y = NULL, gamma = 0.01, ...,
error.check = TRUE, pairwise = FALSE, lower_triangular_only = FALSE,
diagonal = TRUE)
diagonal = TRUE, .internal_ = FALSE)
{
x <- tslist(x)
if (error.check) check_consistency(x, "vltslist")
if (error.check) {
check_consistency(x, "vltslist")
}
if (lower_triangular_only && !diagonal && !.internal_) {
warning("proxy calls using 'sdtw' should specify diag = TRUE")
}

if (is.null(y)) {
y <- x
symmetric <- TRUE
Expand Down
19 changes: 7 additions & 12 deletions R/S4-DistmatLowerTriangular.R
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
lower_triangular_index <- function(i, j, n, diagonal) {
stopifnot(i > 0L, i <= n, j > 0L, j <= i)
if (!diagonal) stopifnot(i != j)

i <- i - 1L
j <- j - 1L

adjustment <- if (diagonal) {
0L
}
else {
Reduce(x = 0L:j, init = 0L, f = function(a, b) { a + b + 1L })
}
diagonal <- as.integer(diagonal)
adjustment <- Reduce(x = 1L:j, init = 0L, f = function(a, b) {
a + b - diagonal
})

i + j * n - adjustment + 1L
i + (j - 1L) * n - adjustment
}

#' Distance matrix's lower triangular
Expand Down Expand Up @@ -97,7 +93,6 @@ setMethod(`[`, "DistmatLowerTriangular", function(x, i, j, ...) {
i <- combinations$i
j <- combinations$j
drop <- FALSE

}

n <- attr(x$distmat, "Size")
Expand Down Expand Up @@ -138,7 +133,7 @@ setAs("dist", "Distmat", function(from) { DistmatLowerTriangular$new(distmat = f
as.matrix.distdiag <- function(x) {
n <- attr(x, "Size")
m <- matrix(0, n, n)
m[lower.tri(m, diag = TRUE)] <- x
m[lower.tri(m, diag = attr(x, "Diag"))] <- x

lbls <- attr(x, "Labels")
if (!is.null(lbls)) {
Expand Down
6 changes: 1 addition & 5 deletions R/pkg.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,9 @@
"_PACKAGE"

# PREFUN for some of my proxy distances so that they support 'pairwise' directly
#' @importFrom rlang caller_env
#' @importFrom rlang env_bind
proxy_prefun <- function(x, y, pairwise, params, reg_entry) {
if (!is.null(reg_entry) && "sdtw" %in% reg_entry$names) {
rlang::env_bind(rlang::caller_env(), diag = TRUE)
}
params$pairwise <- pairwise
params$diagonal <- get_from_callers("diag", "logical")
list(x = x, y = y, pairwise = pairwise, p = params, reg_entry = reg_entry)
}

Expand Down
3 changes: 2 additions & 1 deletion src/centroids/dba.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class DtwBacktrackCalculator : public DistanceCalculator
public:
// constructor
DtwBacktrackCalculator(const Rcpp::List& dist_args, const Rcpp::List& x, const Rcpp::List& y)
: x_(x)
: DistanceCalculator("DTW_BACTRACK")
, x_(x)
, y_(y)
{
window_ = Rcpp::as<int>(dist_args["window.size"]);
Expand Down
3 changes: 2 additions & 1 deletion src/centroids/sdtw-cent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ class SdtwCentCalculator : public DistanceCalculator
public:
// constructor
SdtwCentCalculator(const Rcpp::List& x, const Rcpp::List& y, const double gamma)
: gamma_(gamma)
: DistanceCalculator("SDTW_CENT")
, gamma_(gamma)
, x_(x)
, y_(y)
{
Expand Down
23 changes: 17 additions & 6 deletions src/distances/calculators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,20 @@ DistanceCalculatorFactory::create(const std::string& dist, const SEXP& DIST_ARGS
Rcpp::stop("Unknown distance measure"); // nocov
}

// base constuctor
DistanceCalculator::DistanceCalculator(std::string&& distance)
: distance(distance)
{ }

// =================================================================================================
/* DtwBasic */
// =================================================================================================

// -------------------------------------------------------------------------------------------------
/* constructor */
DtwBasicCalculator::DtwBasicCalculator(const SEXP& DIST_ARGS, const SEXP& X, const SEXP& Y)
: x_(X)
: DistanceCalculator("DTW_BASIC")
, x_(X)
, y_(Y)
{
Rcpp::List dist_args(DIST_ARGS);
Expand Down Expand Up @@ -95,7 +101,8 @@ double DtwBasicCalculator::calculate(const arma::mat& x, const arma::mat& y) {
// -------------------------------------------------------------------------------------------------
/* constructor */
GakCalculator::GakCalculator(const SEXP& DIST_ARGS, const SEXP& X, const SEXP& Y)
: x_(X)
: DistanceCalculator("GAK")
, x_(X)
, y_(Y)
{
Rcpp::List dist_args(DIST_ARGS);
Expand Down Expand Up @@ -137,7 +144,8 @@ double GakCalculator::calculate(const arma::mat& x, const arma::mat& y) {
// -------------------------------------------------------------------------------------------------
/* constructor */
LbiCalculator::LbiCalculator(const SEXP& DIST_ARGS, const SEXP& X, const SEXP& Y)
: x_(X)
: DistanceCalculator("LBI")
, x_(X)
, y_(Y)
{
Rcpp::List dist_args(DIST_ARGS);
Expand Down Expand Up @@ -189,7 +197,8 @@ double LbiCalculator::calculate(const arma::mat& x, const arma::mat& y,
// -------------------------------------------------------------------------------------------------
/* constructor */
LbkCalculator::LbkCalculator(const SEXP& DIST_ARGS, const SEXP& X, const SEXP& Y)
: x_(X)
: DistanceCalculator("LBK")
, x_(X)
{
Rcpp::List dist_args(DIST_ARGS);
p_ = Rcpp::as<int>(dist_args["p"]);
Expand Down Expand Up @@ -235,7 +244,8 @@ double LbkCalculator::calculate(const arma::mat& x,
// -------------------------------------------------------------------------------------------------
/* constructor */
SbdCalculator::SbdCalculator(const SEXP& DIST_ARGS, const SEXP& X, const SEXP& Y)
: x_(X)
: DistanceCalculator("SBD")
, x_(X)
, y_(Y)
{
// note cc_seq_truncated_ is not set here, it is allocated for each clone
Expand Down Expand Up @@ -296,7 +306,8 @@ double SbdCalculator::calculate(const arma::mat& x, const arma::mat& y,
// -------------------------------------------------------------------------------------------------
/* constructor */
SdtwCalculator::SdtwCalculator(const SEXP& DIST_ARGS, const SEXP& X, const SEXP& Y)
: x_(X)
: DistanceCalculator("SDTW")
, x_(X)
, y_(Y)
{
Rcpp::List dist_args(DIST_ARGS);
Expand Down
4 changes: 4 additions & 0 deletions src/distances/calculators.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ class DistanceCalculator
// a clone method to make life easier when copying objects in each thread
virtual DistanceCalculator* clone() const = 0;

std::string distance;

protected:
DistanceCalculator(std::string&& distance);

int maxLength(const TSTSList<arma::mat>& list) const {
unsigned int max_len = 0;
for (const arma::mat& x : list) {
Expand Down
6 changes: 4 additions & 2 deletions src/distmat/fillers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,10 @@ class LowerTriangularDiagonalFillWorker : public ParallelWorker {
for (id_t id = begin; id < end; id++) {
if (is_interrupted(id)) break; // nocov

double dist = dist_calculator->calculate(i,j);
(*distmat_)(id,0) = dist;
if (dist_calculator->distance == "SDTW" || i != j) {
double dist = dist_calculator->calculate(i,j);
(*distmat_)(id,0) = dist;
}

i++;
if (i >= nrows_) {
Expand Down

0 comments on commit 6d47400

Please sign in to comment.