Skip to content

Commit

Permalink
Code review and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
plietar committed Mar 26, 2024
1 parent a5fc679 commit ec584db
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 18 deletions.
9 changes: 5 additions & 4 deletions R/event.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,16 @@ EventBase <- R6Class(
Event <- R6Class(
'Event',
inherit = EventBase,
private = list(
should_restore = FALSE
),
public = list(
.should_restore = FALSE,

#' @description Initialise an Event.
#' @param restore if true, the schedule of this event is restored when
#' restoring from a saved simulation.
initialize = function(restore = TRUE) {
self$.event <- create_event()
self$.should_restore = restore
private$should_restore = restore
},

#' @description Schedule this event to occur in the future.
Expand Down Expand Up @@ -91,7 +92,7 @@ Event <- R6Class(
#' simulation in which this variable did not exist.
restore_state = function(timestep, state) {
event_base_set_timestep(self$.event, timestep)
if (self$.should_restore && !is.null(state)) {
if (private$should_restore && !is.null(state)) {
event_restore(self$.event, state)
}
}
Expand Down
25 changes: 18 additions & 7 deletions R/simulation.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ simulation_loop <- function(
#' @param timesteps the number of time steps that have already been simulated
#' @param variables the list of Variables
#' @param events the list of Events
#' @return the saved simulation state
#' @return the saved simulation state.
save_simulation_state <- function(timesteps, variables, events) {
random_state <- .GlobalEnv$.Random.seed
list(
Expand All @@ -99,9 +99,12 @@ save_simulation_state <- function(timesteps, variables, events) {
}

#' @title Save the state of a simulation object or set of objects.
#' @param objects a simulation object (ie. a variable or event), or list
#' thereof.
#' @return the saved states of the objects
#' @param objects a simulation object (eg. a variable or event) or an
#' arbitrarily nested list structure of such objects.
#' @return the saved states of the objects. This has the same shape as the given
#' \code{objects}: if a list was passed as an argument, this returns the
#' corresponding list of saved states. If a singular object was passed, this
#' returns just that particular object's state.
#' @export
save_object_state <- function(objects) {
if (is.list(objects)) {
Expand All @@ -116,7 +119,7 @@ save_object_state <- function(objects) {
#' The state of passed events and variables is overwritten to match the state
#' they had when the simulation was checkpointed.
#' @param state the simulation state to restore, as returned by
#' \code{\link[individual]{restore_simulation_state}}.
#' \code{\link[individual]{save_simulation_state}}.
#' @param variables the list of Variables
#' @param events the list of Events
#' @param restore_random_state if TRUE, restore R's global random number
Expand Down Expand Up @@ -153,8 +156,15 @@ is_uniquely_named <- function(x) {
#' extended with more features upon resuming. In this case, the
#' \code{restore_state} method is called with a \code{NULL} argument.
#'
#' @param objects a simulation object (ie. a variable or event), or list
#' thereof.
#' @param timesteps the number of time steps that have already been simulated
#' @param objects a simulation object (eg. a variable or event) or an
#' arbitrarily nested list structure of such objects.
#' @param state a saved simulation state for the given objects, as returned by
#' \code{\link[individual]{save_object_state}}. This should have the same shape
#' as the \code{objects} argument: if a list of objects is given, then
#' \code{state} should be a list of corresponding states. If NULL is passed,
#' then each object's \code{restore_state} method is called with NULL as
#' its argument.
#' @export
restore_object_state <- function(timesteps, objects, state) {
if (is.list(objects)) {
Expand All @@ -176,6 +186,7 @@ restore_object_state <- function(timesteps, objects, state) {
} else {
stop("Saved state does not match resumed objects")
}

for (k in keys) {
restore_object_state(timesteps, objects[[k]], state[[k]])
}
Expand Down
13 changes: 11 additions & 2 deletions man/restore_object_state.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/restore_simulation_state.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions man/save_object_state.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/save_simulation_state.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

166 changes: 166 additions & 0 deletions tests/testthat/test-checkpoint.R
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,169 @@ test_that("cannot resume with smaller timesteps", {
simulation_loop(timesteps = 10, state = state),
"Restored state is already longer than timesteps")
})

MockState <- R6Class(
'MockState',
private = list(value = NULL),
public = list(
save_state = function() private$value,
restore_state = NULL,
initialize = function(value = NULL) {
private$value <- value
self$restore_state <- mockery::mock()
}
)
)

test_that("saved object's state is returned", {
o <- MockState$new("foo")
state <- save_object_state(o)
expect_identical(state, "foo")
})

test_that("saved objects' state is returned as list", {
o1 <- MockState$new("foo")
o2 <- MockState$new("bar")
state <- save_object_state(list(o1, o2))
expect_identical(state, list("foo", "bar"))
})

test_that("saved objects' state preserves names", {
o1 <- MockState$new("foo")
o2 <- MockState$new("bar")
state <- save_object_state(list(x=o1, y=o2))
expect_identical(state, list(x="foo", y="bar"))
})

test_that("saved objects' state preserves nested structure", {
o1 <- MockState$new("foo")
o2 <- MockState$new("bar")
o3 <- MockState$new("baz")

state <- save_object_state(list(x=o1, y=list(o2, o3)))
expect_identical(state, list(x="foo", y=list("bar", "baz")))
})

test_that("empty list is returned for empty list of objects", {
state <- save_object_state(list())
expect_identical(state, list())
})

test_that("restore_state method is called on object", {
o <- MockState$new()
restore_object_state(123, o, "state")
mockery::expect_called(o$restore_state, 1)
mockery::expect_args(o$restore_state, 1, 123, "state")
})

test_that("restore_state method is called on object list", {
o1 <- MockState$new()
o2 <- MockState$new()
restore_object_state(123, list(o1, o2), list("hello", "world"))

mockery::expect_called(o1$restore_state, 1)
mockery::expect_args(o1$restore_state, 1, 123, "hello")

mockery::expect_called(o2$restore_state, 1)
mockery::expect_args(o2$restore_state, 1, 123, "world")
})

test_that("restore_state method is called on named object list", {
o1 <- MockState$new()
o2 <- MockState$new()

# Lists get paired up by name, even if the order is different
restore_object_state(123, list(x=o1, y=o2), list(y="world", x="hello"))

mockery::expect_called(o1$restore_state, 1)
mockery::expect_args(o1$restore_state, 1, 123, "hello")

mockery::expect_called(o2$restore_state, 1)
mockery::expect_args(o2$restore_state, 1, 123, "world")
})

test_that("restore_state method is called on nested object list", {
o1 <- MockState$new()
o2 <- MockState$new()
o3 <- MockState$new()

restore_object_state(
123,
list(list(o1, o2), o3),
list(list("foo", "bar"), "baz"))

mockery::expect_called(o1$restore_state, 1)
mockery::expect_args(o1$restore_state, 1, 123, "foo")

mockery::expect_called(o2$restore_state, 1)
mockery::expect_args(o2$restore_state, 1, 123, "bar")

mockery::expect_called(o3$restore_state, 1)
mockery::expect_args(o3$restore_state, 1, 123, "baz")
})

test_that("restore_state method is called with NULL for new objects", {
o1 <- MockState$new()
o2 <- MockState$new()
o3 <- MockState$new()

restore_object_state(
123,
list(x=o1, y=o2, z=o3),
list(x="foo", z="baz"))

mockery::expect_called(o1$restore_state, 1)
mockery::expect_args(o1$restore_state, 1, 123, "foo")

mockery::expect_called(o2$restore_state, 1)
mockery::expect_args(o2$restore_state, 1, 123, NULL)

mockery::expect_called(o3$restore_state, 1)
mockery::expect_args(o3$restore_state, 1, 123, "baz")
})

test_that("cannot restore objects with partial unnamed list", {
o1 <- MockState$new()
o2 <- MockState$new()

expect_error(
restore_object_state(123, list(o1, o2), list("foo")),
"Saved state does not match resumed objects")
})

test_that("restore_state method is called with NULL for new list of objects", {
o1 <- MockState$new()
o2 <- MockState$new()
o3 <- MockState$new()

restore_object_state(
123,
list(x=o1, y=list(o2, o3)),
list(x="foo"))

mockery::expect_called(o1$restore_state, 1)
mockery::expect_args(o1$restore_state, 1, 123, "foo")

mockery::expect_called(o2$restore_state, 1)
mockery::expect_args(o2$restore_state, 1, 123, NULL)

mockery::expect_called(o3$restore_state, 1)
mockery::expect_args(o3$restore_state, 1, 123, NULL)
})

test_that("restore_state method is called with NULL for all objects", {
o1 <- MockState$new()
o2 <- MockState$new()
o3 <- MockState$new()

restore_object_state(123, list(x=o1, y=o2, z=o3), NULL)

mockery::expect_called(o1$restore_state, 1)
mockery::expect_args(o1$restore_state, 1, 123, NULL)

mockery::expect_called(o2$restore_state, 1)
mockery::expect_args(o2$restore_state, 1, 123, NULL)

mockery::expect_called(o3$restore_state, 1)
mockery::expect_args(o3$restore_state, 1, 123, NULL)
})

0 comments on commit ec584db

Please sign in to comment.