Skip to content

Commit

Permalink
Merge pull request #284 from alan-turing-institute/fix_single_logical…
Browse files Browse the repository at this point in the history
…_bug

Fix_single_logical_bug
  • Loading branch information
RaphaelS1 authored Feb 23, 2023
2 parents c727639 + 866aa95 commit bb410f4
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 15 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: distr6
Title: The Complete R6 Probability Distributions Interface
Version: 1.6.13
Version: 1.6.14
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# distr6 1.6.14

* Fix bug when extracting a single distribution with a logical vector from `MatDist`

# distr6 1.6.13

* Fix reordering bug when extracting vector distributions
Expand Down
33 changes: 20 additions & 13 deletions R/SDistribution_Matdist.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Matdist <- R6Class("Matdist",
support = Set$new(1, class = "numeric")^"n",
type = Reals$new()^"n"
)
private$.ndists <- nrow(gprm(self, "pdf"))
invisible(self)
},

Expand All @@ -80,7 +81,7 @@ Matdist <- R6Class("Matdist",
#' @param n `(integer(1))` \cr
#' Ignored.
strprint = function(n = 2) {
"Matdist()"
sprintf("Matdist(%s)", private$.ndists)
},

# stats
Expand Down Expand Up @@ -128,7 +129,7 @@ Matdist <- R6Class("Matdist",
"*" %=% gprm(self, c("x", "pdf"))
mean <- self$mean()

vnapply(seq(nrow(pdf)), function(i) {
vnapply(seq_len(private$.ndists), function(i) {
if (mean[[i]] == Inf) {
Inf
} else {
Expand All @@ -149,7 +150,7 @@ Matdist <- R6Class("Matdist",
mean <- self$mean()
sd <- self$stdev()

vnapply(seq(nrow(pdf)), function(i) {
vnapply(seq_len(private$.ndists), function(i) {
if (mean[[i]] == Inf) {
Inf
} else {
Expand All @@ -171,7 +172,7 @@ Matdist <- R6Class("Matdist",
mean <- self$mean()
sd <- self$stdev()

kurt <- vnapply(seq(nrow(pdf)), function(i) {
kurt <- vnapply(seq_len(private$.ndists), function(i) {
if (mean[[i]] == Inf) {
Inf
} else {
Expand Down Expand Up @@ -209,8 +210,8 @@ Matdist <- R6Class("Matdist",
if (length(t) == 1) {
mgf <- apply(pdf, 1, function(.y) sum(exp(x * t) * .y))
} else {
stopifnot(length(z) == nrow(pdf))
mgf <- vnapply(seq(nrow(pdf)),
stopifnot(length(z) == private$.ndists)
mgf <- vnapply(seq_len(private$.ndists),
function(i) sum(exp(x * t[[i]]) * pdf[i, ]))
}

Expand All @@ -228,8 +229,8 @@ Matdist <- R6Class("Matdist",
if (length(t) == 1) {
cf <- apply(pdf, 1, function(.y) sum(exp(x * t * 1i) * .y))
} else {
stopifnot(length(z) == nrow(pdf))
cf <- vnapply(seq(nrow(pdf)),
stopifnot(length(z) == private$.ndists)
cf <- vnapply(seq_len(private$.ndists),
function(i) sum(exp(x * t[[i]] * 1i) * pdf[i, ]))
}

Expand All @@ -247,8 +248,8 @@ Matdist <- R6Class("Matdist",
if (length(z) == 1) {
pgf <- apply(pdf, 1, function(.y) sum((z^x) * .y))
} else {
stopifnot(length(z) == nrow(pdf))
pgf <- vnapply(seq(nrow(pdf)),
stopifnot(length(z) == private$.ndists)
pgf <- vnapply(seq_len(private$.ndists),
function(i) sum((z[[i]]^x) * pdf[i, ]))
}

Expand All @@ -271,7 +272,7 @@ Matdist <- R6Class("Matdist",
.pdf = function(x, log = FALSE) {
"pdf, data" %=% gprm(self, c("pdf", "x"))
out <- t(C_Vec_WeightedDiscretePdf(
x, matrix(data, ncol(pdf), nrow(pdf)),
x, matrix(data, ncol(pdf), private$.ndists),
t(pdf)
))
if (log) {
Expand Down Expand Up @@ -306,7 +307,8 @@ Matdist <- R6Class("Matdist",

# traits
.traits = list(valueSupport = "discrete", variateForm = "univariate"),
.improper = FALSE
.improper = FALSE,
.ndists = 0
)
)

Expand Down Expand Up @@ -392,7 +394,12 @@ c.Matdist <- function(...) {
#' m[1:2]
#' @export
"[.Matdist" <- function(md, i) {
if (length(i) == 1) {
if (is.logical(i)) {
i <- which(i)
}
if (length(i) == 0) {
stop("Can't create an empty distribution.")
} else if (length(i) == 1) {
pdf <- gprm(md, "pdf")[i, ]
dstr("WeightedDiscrete", x = as.numeric(names(pdf)), pdf = pdf,
decorators = md$decorators)
Expand Down
9 changes: 8 additions & 1 deletion tests/testthat/test-sdistribution-Matdist.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ test_that("c.Matdist", {
expect_true(all(r >= 1))
})

test_that("c.Matdist", {
test_that("[.Matdist", {
set.seed(1)
m <- as.Distribution(
t(apply(matrix(runif(200), 20, 10, FALSE,
Expand All @@ -95,9 +95,16 @@ test_that("c.Matdist", {
fun = "pdf"
)

expect_equal(m$strprint(), "Matdist(20)")

expect_error(m[logical(20)], "empty")

m1 <- m[1]
m12 <- m[1:2]

expect_distribution(m1, "WeightedDiscrete")
expect_distribution(m[!logical(20)], "Matdist")
expect_distribution(m[c(TRUE, logical(19))], "WeightedDiscrete")
expect_distribution(m12, "Matdist")

expect_equal(unname(m$cdf(0:25)[, 1]), unname(m1$cdf(0:25)))
Expand Down

0 comments on commit bb410f4

Please sign in to comment.