// rsa.wppl
/// First, load and clean up the data.
// TODO: loading/cleanup is not implemented yet; the values below are placeholders.
var data = [2, 2, 2];
///next the RSA model, with a parameter for alternative weights.
// Prior over world states. A world is the number of nice people out of 3,
// drawn with randomInteger(4), i.e. one of 0, 1, 2 or 3.
var worldPrior = function() {
  return randomInteger(4);
};
// Prior over utterances. `altWeights` is handed to discrete() to pick an index
// into the list of alternatives, so altWeights[i] is the weight of the i-th
// utterance in the order listed below (some / all / none).
var utterancePrior = function(altWeights) {
  var alternatives = [
    "some of the people are nice",
    "all of the people are nice",
    "none of the people are nice"
  ];
  return alternatives[discrete(altWeights)];
};
// Literal semantics: is utterance `utt` true in `world`, where `world` is the
// number of nice people? "some" needs at least one, "all" needs exactly 3,
// "none" needs exactly 0. Any other utterance is treated as true.
var meaning = function(utt, world) {
  if (utt == "some of the people are nice") {
    return world > 0;
  }
  if (utt == "all of the people are nice") {
    return world == 3;
  }
  if (utt == "none of the people are nice") {
    return world == 0;
  }
  return true;
};
// Literal listener: given an utterance, enumerates worlds from worldPrior and
// keeps only those in which the utterance is literally true.
// Returns the distribution produced by Enumerate; results are wrapped in
// cache() keyed on the utterance argument.
var literalListener = cache(function(utterance) {
  // The explicit `return` is required: without it this function evaluates to
  // undefined and callers cannot call .score(...) on the result.
  return Enumerate(function() {
    var world = worldPrior();
    var isTrue = meaning(utterance, world);
    // Hard condition: worlds where the utterance is false get -Infinity.
    factor(isTrue ? 0 : -Infinity);
    return world;
  });
});
// Speaker: given a world and the alternative weights, enumerates utterances
// from utterancePrior(altWeights) and weights each one by the score the
// literal listener assigns to `world` after hearing it.
// Returns the distribution produced by Enumerate; results are wrapped in
// cache() over the (world, altWeights) arguments.
var speaker = cache(function(world, altWeights) {
  // The explicit `return` is required: without it this function evaluates to
  // undefined and callers cannot call .score(...) on the result.
  return Enumerate(function() {
    var utterance = utterancePrior(altWeights);
    var L = literalListener(utterance);
    factor(L.score([], world));
    return utterance;
  });
});
// Pragmatic listener: given the alternative weights and an utterance,
// enumerates worlds from worldPrior and weights each one by the score the
// speaker assigns to `utterance` in that world.
// Side effect: prints the alternative weights via display() on every call.
// Returns the distribution produced by Enumerate.
var listener = function(altWeights, utterance) {
  display("run listener, altWeights: " + altWeights);
  // The explicit `return` is required: without it this function evaluates to
  // undefined and callers cannot call .score(...) on the result.
  return Enumerate(function() {
    var world = worldPrior();
    var S = speaker(world, altWeights);
    factor(S.score([], utterance));
    return world;
  });
};
// Score of `world` under the listener distribution for the given alternative
// weights and utterance, wrapped with gpcache.cache.
// NOTE(review): the trailing 1.0 is passed straight to gpcache.cache; its
// meaning is defined by gpcache and is not visible here — confirm.
var listenerScore = gpcache.cache(function(altWeights, utterance, world) {
  var posterior = listener(altWeights, utterance);
  return posterior.score([], world);
}, 1.0);
///now analyse the data against the RSA model. infer posterior on alternative weights.
// Posterior over the weight of the "all" alternative, inferred with MH
// (second argument 100). The weights of the other two alternatives are fixed at 1.
var allWeight = MH(function() {
  // Discrete prior over the "all" weight. A continuous alternative that is
  // currently not used: uniform(0,1).
  var allAltWeight = uniformDraw([0.5, 1, 2]);
  var altWeights = [1, allAltWeight, 1];
  var utterance = "some of the people are nice";
  //TODO: response linking function?
  // Each observed response contributes its listener score as a factor.
  map(function(response) {
    factor(listenerScore(altWeights, utterance, response));
  }, data);
  return allAltWeight;
}, 100);

// Expected value of the inferred weight under the distribution above.
expectation(allWeight, function(weight) { return weight; });