Skip to content

Commit

Permalink
use new stx in scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
kach committed Jul 3, 2024
1 parent 30ee4fa commit 55717ad
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions demo-scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,12 @@
U = [0, 1, 2] # utterance: {none, some, all} of the people are nice

@jax.jit
def meaning(n, u):
# (none) (some) (all)
def meaning(n, u): # (none) (some) (all)
return np.array([ n == 0, n > 0, n == 3 ])[u]

@memo
def scalar():
def scalar[n: N, u: U]():
cast: [speaker, listener]
forall: n in N
forall: u in U

listener: thinks[
speaker: given(n in N, wpp=1),
speaker: chooses(u in U, wpp=imagine[
Expand All @@ -26,7 +22,6 @@ def scalar():
]
listener: hears [speaker.u] is u
listener: chooses(n in N, wpp=E[speaker.n == n])

return E[listener.n == n]

print(scalar())

0 comments on commit 55717ad

Please sign in to comment.