Built-in caching for functions #1604

Open
samuelchassot opened this issue Nov 27, 2024 · 2 comments

@samuelchassot
Collaborator

No description provided.

@vkuncak
Collaborator

vkuncak commented Nov 29, 2024

After discussion with @mbovel: consider adding this to the Stainless library and introducing a rewriting, as sketched below, that enables specification and verification of such functions when they are recursive.

import stainless.annotation.*
import stainless.lang.*
import stainless.lang.StaticChecks.*
import stainless.proof.check

@library
object Memo:
  case class MemoFun[A,B](f: A => B):
    @ignore @extern
    val cache: scala.collection.mutable.Map[A, B] = scala.collection.mutable.Map.empty
    @extern @pure
    def apply(a: A): B = {
      cache.getOrElseUpdate(a, {
        val b = f(a) // evaluate f only once per cache miss
        println(s"Cache miss: f($a) = $b, cache.size = ${cache.size}")
        b
      })
    }.ensuring(_ == f(a))
    @extern @pure
    def evict(a: A): Unit = cache.remove(a)
    @extern @pure
    def evictWhere(p: A => Boolean): Unit = 
      cache.filterInPlace((k, _) => !p(k))
    @extern @pure
    def evictAll(): Unit = cache.clear()

  val f: MemoFun[BigInt, BigInt] = MemoFun((n:BigInt) => {
    // require(0 <= n) // should be allowed
    // decreases(n)    // should be allowed
    if n <= 1 then n
              else f(n - 1) + f(n - 2)
  }) // .ensuring(_ >= 0) // should be allowed

  // the above should become, where fdef(x) should be used in the rest of code instead of `f(x)`:
  def fdef(n:BigInt): BigInt = {
    require(0 <= n)
    decreases(n)
    if n <= 1 then n
    else fdef(n - 1) + fdef(n - 2)
  }.ensuring(_ >= 0)

  // uses of f.evict, f.evictWhere, f.evictAll should be erased for verification

  @main @extern
  def main =
    val res1 = f(40)

    f.evict(6)
    f.evictWhere(_ > 4)
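
To make the intended run-time behaviour concrete, here is a minimal plain-Scala sketch of the same idea, without any Stainless annotations. It is only an illustration under assumptions, not the proposed library code: the `misses` counter, the `size` accessor, and the `FibDemo` and `memoDemo` names are hypothetical additions that make caching observable. `fdef` plays the role of the uncached definition that the rewriting above would expose to the verifier.

```scala
import scala.collection.mutable

// Hypothetical plain-Scala stand-in for MemoFun (no Stainless annotations).
final class MemoFun[A, B](f: A => B):
  private val cache = mutable.Map.empty[A, B]
  // Miss counter: not in the sketch above, added to make caching observable.
  var misses: Int = 0

  def apply(a: A): B = cache.get(a) match
    case Some(b) => b
    case None =>
      misses += 1
      val b = f(a) // may recursively fill the cache before b is stored
      cache.update(a, b)
      b

  def evict(a: A): Unit =
    cache.remove(a)
    ()

  def evictWhere(p: A => Boolean): Unit =
    cache.filterInPlace((k, _) => !p(k))
    ()

  def evictAll(): Unit = cache.clear()

  // Number of cached entries (hypothetical helper for inspection).
  def size: Int = cache.size

object FibDemo:
  // Memoized definition: the lambda recurses through the memoized value `f`.
  val f: MemoFun[BigInt, BigInt] = new MemoFun[BigInt, BigInt](n =>
    if n <= 1 then n else f(n - 1) + f(n - 2))

  // Uncached reference definition, in the role of `fdef` above.
  def fdef(n: BigInt): BigInt =
    require(n >= 0)
    if n <= 1 then n else fdef(n - 1) + fdef(n - 2)

@main def memoDemo(): Unit =
  import FibDemo.*
  val res1 = f(40)
  println(s"f(40) = $res1, misses = ${f.misses}")
  f.evict(6)
  f.evictWhere(_ > 4)
  println(s"entries left after eviction: ${f.size}")
  // Memoization and eviction must not change results.
  assert((0 to 20).forall(i => f(BigInt(i)) == fdef(BigInt(i))))
```

The final assertion mirrors the `.ensuring(_ == f(a))` postcondition: whatever the cache contains, the memoized function agrees with the uncached one. That is why erasing the eviction calls for verification should be harmless.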

@vkuncak
Collaborator

vkuncak commented Dec 6, 2024

@samuelchassot @mbovel did we conclude that using `inline def` could introduce memoization while appearing to Stainless as recursion through function values, as in

val f : A => B = (x:A) => E(x, f(x))

?
