Implement uncount() #352

Open
markfairbanks opened this issue Apr 22, 2022 · 8 comments
Labels
feature a feature request or enhancement

Comments

@markfairbanks
Collaborator

library(tidyr)

df <- tibble(x = c("a", "b"), n = c(1, 2))

df %>%
  uncount(n)
#> # A tibble: 3 × 1
#>   x    
#>   <chr>
#> 1 a    
#> 2 b    
#> 3 b
@markfairbanks
Collaborator Author

Looks like we would need to make uncount() a generic in tidyr first.

@mgirlich
Collaborator

There is already a tidyr issue tidyverse/tidyr#1071 for this. In dbplyr we simply added dbplyr_uncount() for now...

@markfairbanks
Collaborator Author

Should we add dtplyr_uncount()? Or maybe we should wait for tidyverse/tidyr#1101 to be merged?

@mgirlich
Collaborator

It might take some time for tidyverse/tidyr#1101 to be merged (the interface of most of these functions is not completely stable yet). Therefore, I created a separate PR tidyverse/tidyr#1358 for uncount(). This might be merged faster.

@markfairbanks
Collaborator Author

markfairbanks commented May 6, 2022

Here's an initial implementation that I think covers everything. Once tidyverse/tidyr#1358 is merged we can create this as an uncount.dtplyr_step method.

Edit: I'm not sure whether there's a better way to create the .id column. pull() forces the lazy pipeline up to that point to be evaluated just to get the weights, which would be pretty expensive in a long pipe chain.

library(data.table)
library(dtplyr)
library(dplyr, warn.conflicts = FALSE)
library(tidyr, warn.conflicts = FALSE)

dtplyr_uncount <- function(data, weights, ..., .remove = TRUE, .id = NULL) {
  weights <- enquo(weights)
  
  needs_id <- !is.null(.id)
  
  if (needs_id) {
    .reps <- pull(data, !!weights)
  }
  
  out <- slice(data, rep(1:.N, !!weights))
  
  if (needs_id) {
    out <- mutate(out, !!.id := sequence(!!.reps))
  }
  
  if (.remove) {
    out <- select(out, -!!weights)
  }
  
  out
}

df <- data.table(x = c("a", "b"), n = c(1, 2))

df %>%
  dtplyr_uncount(n)
#> Source: local data table [3 x 1]
#> Call:   `_DT1`[rep(1:.N, n)[between(rep(1:.N, n), -.N, .N)], .(x)]
#> 
#>   x    
#>   <chr>
#> 1 a    
#> 2 b    
#> 3 b    
#> 
#> # Use as.data.table()/as.data.frame()/as_tibble() to access results

df %>%
  dtplyr_uncount(n, .id = "id", .remove = FALSE)
#> Source: local data table [3 x 3]
#> Call:   `_DT2`[rep(1:.N, n)[between(rep(1:.N, n), -.N, .N)]][, `:=`(id = sequence(c(1, 
#> 2)))]
#> 
#>   x         n    id
#>   <chr> <dbl> <int>
#> 1 a         1     1
#> 2 b         2     1
#> 3 b         2     2
#> 
#> # Use as.data.table()/as.data.frame()/as_tibble() to access results

@markfairbanks
Collaborator Author

markfairbanks commented May 12, 2022

I think this might be a better way - it avoids the pull() issue I mentioned above.

library(data.table)
library(dtplyr)
library(dplyr, warn.conflicts = FALSE)
library(tidyr, warn.conflicts = FALSE)
library(purrr)
library(rlang, warn.conflicts = FALSE) # quo_squash(), call2(), exprs(), as_name()

step_subset_j <- dtplyr:::step_subset_j
step_subset <- dtplyr:::step_subset

dtplyr_uncount <- function(data, weights, ..., .remove = TRUE, .id = NULL) {
  weights <- quo_squash(enquo(weights))
  
  groups <- group_vars(data)
  has_groups <- length(groups) > 0
  if (has_groups) {
    data <- ungroup(data)
  }
  
  if (is.null(.id)) {
    if (.remove) {
      vars <- setdiff(data$vars, as_name(weights))
      j <- call2(".", !!!syms(vars))
    } else {
      vars <- data$vars
      j <- NULL
    }
    out <- step_subset(
      data,
      vars = vars,
      i = expr(rep(1:.N, !!weights)),
      j = j
    )
  } else {
    vars_names <- data$vars
    if (.remove) {
      vars_names <- setdiff(vars_names, as_name(weights))
    }
    vars <- map(syms(vars_names), ~ expr(rep(!!.x, !!weights)))
    names(vars) <- vars_names
    
    vars <- append(vars, exprs(!!.id := sequence(!!weights)))

    out <- step_subset_j(
      data,
      vars = c(vars_names, .id), # the new id column is part of the output too
      j = call2(".", !!!vars)
    )
  }
  
  if (has_groups) {
    out <- group_by(out, !!!syms(groups))
  }
  
  out
}

df <- lazy_dt(data.table(x = c("a", "b"), n = c(1, 2)))

df %>%
  dtplyr_uncount(n)
#> Source: local data table [3 x 1]
#> Call:   `_DT1`[rep(1:.N, n), .(x)]
#> 
#>   x    
#>   <chr>
#> 1 a    
#> 2 b    
#> 3 b    
#> 
#> # Use as.data.table()/as.data.frame()/as_tibble() to access results

df %>%
  dtplyr_uncount(n, .id = "id", .remove = FALSE)
#> Source: local data table [3 x 3]
#> Call:   `_DT1`[, .(x = rep(x, n), n = rep(n, n), id = sequence(n))]
#> 
#>   x         n    id
#>   <chr> <dbl> <int>
#> 1 a         1     1
#> 2 b         2     1
#> 3 b         2     2
#> 
#> # Use as.data.table()/as.data.frame()/as_tibble() to access results
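
For anyone less familiar with data.table, the two generated calls shown above can also be run directly, without dtplyr. This is only a sketch of what the translations evaluate to, not part of the proposed implementation; the names `dt`, `sliced`, and `repped` are made up for illustration, and `seq_len(.N)` is used in place of `1:.N`.

```r
library(data.table)

dt <- data.table(x = c("a", "b"), n = c(1, 2))

# Slice form: repeat each row index n times, then drop the weights column.
sliced <- dt[rep(seq_len(.N), n), .(x)]

# Rep form: repeat every column n times and build the id in the same j call,
# where n still holds the original (unexpanded) weights.
repped <- dt[, .(x = rep(x, n), n = rep(n, n), id = sequence(n))]

# Computing the id *after* slicing would not work: by then n has been
# expanded to c(1, 2, 2), so sequence(n) would return five values, not three.
```

That last point is why the first version had to pull() the weights before slicing, while the rep form can compute everything in a single step.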

@eutwt
Collaborator

eutwt commented May 13, 2022

Would it be faster to just always use the second branch of your function, where you rep each column? (obviously it is in the example below, but I haven't thought about the other possible scenarios)

devtools::load_all('/Users/mbp/Documents/GitHub/dtp')
#> ℹ Loading dtplyr
#> Warning: package 'dplyr' was built under R version 4.1.2


# dtplyr_uncount <- [as above]

dtplyr_uncount2 <- function(data, weights, ..., .remove = TRUE, .id = NULL) {
  weights <- quo_squash(enquo(weights))
  
  groups <- group_vars(data)
  has_groups <- length(groups) > 0
  if (has_groups) {
    data <- ungroup(data)
  }

  vars_names <- data$vars
  if (.remove) {
    vars_names <- setdiff(vars_names, as_name(weights))
  }
  vars <- map(syms(vars_names), ~ expr(rep(!!.x, !!weights)))
  names(vars) <- vars_names
  
  if (!is.null(.id)) {
    vars <- append(vars, exprs(!!.id := sequence(!!weights)))
  }

  out <- step_subset_j(
    data,
    vars = c(vars_names, .id), # c() drops a NULL .id, so this is safe either way
    j = call2(".", !!!vars)
  )
  
  if (has_groups) {
    out <- group_by(out, !!!syms(groups))
  }
  
  out
}

df <- data.table(x = c("a", "b"), n = c(1, 2)) %>% 
  mutate(n = c(1, 2)*1e6) %>% 
  lazy_dt()


library(bench)

mark(
  a =  df %>% dtplyr_uncount(n) %>% collect,
  b =  df %>% dtplyr_uncount2(n) %>% collect)
#> Warning: Some expressions had a GC in every iteration; so filtering is disabled.
#> # A tibble: 2 × 6
#>   expression      min   median `itr/sec` mem_alloc `gc/sec`
#>   <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
#> 1 a            55.1ms   60.4ms      13.9    69.4MB     21.8
#> 2 b            19.2ms   31.7ms      25.7    46.1MB     19.8

Created on 2022-05-12 by the reprex package (v2.0.1)

@markfairbanks
Collaborator Author

> Would it be faster to just always use the second branch of your function

Looks like it. I hadn’t gotten around to testing it yet. I’m a bit surprised to be honest - I would have assumed the simple slice/select would be much more efficient.
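
One way to check whether the gap comes from data.table itself rather than from dtplyr's translation would be to time the two generated calls directly. The following is a rough sketch under that assumption, with made-up names (`dt`, `slice_rows`, `rep_cols`, `res`); timings will vary by machine, so no expected numbers are shown.

```r
library(data.table)
library(bench)

# Same shape as the benchmark above: two rows with large weights.
dt <- data.table(x = c("a", "b"), n = c(1, 2) * 1e6)

res <- bench::mark(
  slice_rows = dt[rep(seq_len(.N), n), .(x)],  # index rows, then select
  rep_cols   = dt[, .(x = rep(x, n))],         # repeat each kept column
  check = FALSE,        # result equality is not compared in this sketch
  min_iterations = 5
)

res[, c("expression", "median", "mem_alloc")]
```

If the difference persists here, it points at the row-index subset being the slower path in data.table for this kind of input, rather than anything dtplyr adds.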

@markfairbanks markfairbanks added the feature a feature request or enhancement label Jun 21, 2022