-
Notifications
You must be signed in to change notification settings - Fork 2
/
distributions.scm
95 lines (87 loc) · 2.67 KB
/
distributions.scm
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
;;; Discrete distributions
;; Fair coin flip.
;; Sampler: returns 0 or 1, each with probability 1/2.
;; Scorer: takes no parameters and returns a procedure mapping an
;; outcome x to its log-mass -- that of 1/2 for x = 0 or x = 1, and
;; that of 0 for any other value.
(define flip
  (make-operator
   ;; sampling procedure
   (lambda ()
     (random 2))
   ;; scoring procedure
   (lambda ()
     (lambda (x)
       (mass->logmass
        (if (or (= x 1) (= x 0))
            1/2
            0))))))
;; Bernoulli distribution: returns 1 with probability p and 0 with
;; probability 1-p.  Assumes 0 <= p <= 1.
(define bernoulli
  (make-operator
   (lambda (p)
     ;; (random 1.0) draws r from [0, 1), so r < p holds with probability
     ;; exactly p.  A non-strict comparison would make p = 0 still return
     ;; 1 whenever r happens to be exactly 0.0.
     (if (< (random 1.0) p)
         1
         0))
   (lambda (p)
     (lambda (x)
       (mass->logmass
        (cond ((= x 1) p)
              ((= x 0) (- 1 p))
              (else 0)))))))
;;; "Continuous" distributions
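;; These distributions are in fact implemented as discrete
;; distributions over the machine's floating-point numbers: their
;; scorers return log(density(x)) plus the log of the float spacing
;; around x, i.e. a log-mass (see logdensity->logmass below).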
;; Convert a log-density at x into a log-mass by adding the log of
;; epsilon.  Epsilon is the smallest power of two that still changes x
;; when added to it, i.e. the width of the floating-point "cell" at x.
;;
;; logdensity: log of the density at x.
;; x: the point being scored; must be finite.
;;
;; Returns logdensity + log(epsilon).  Signals an error for infinite or
;; NaN x, for which no such epsilon exists.
(define (logdensity->logmass logdensity x)
  ;; Work in flonums.  With an exact, non-dyadic x (e.g. 1/3), x can
  ;; never equal a flonum sum, so the halving search would not stop.
  (let ((x (exact->inexact x)))
    ;; (- x x) is 0 only for finite x; for +/-inf or NaN both searches
    ;; below would loop forever.
    (if (not (= (- x x) 0))
        (error "logdensity->logmass: x must be finite:" x))
    (let ((epsilon
           ;; This calculation adapted from gjs's scmutils: find the
           ;; smallest number that makes a difference when added to x
           (if (= x (+ 1. x))
               ;; |x| is so large that 1 is below its spacing: grow e.
               (let loop ((e 1.0))
                 (if (= x (+ e x))
                     (loop (* e 2))
                     e))
               ;; 1 already changes x: shrink e until e/2 no longer does.
               (let loop ((e 1.0))
                 (if (= x (+ (/ e 2) x))
                     e
                     (loop (/ e 2)))))))
      (+ logdensity (log epsilon)))))
;; Alias: a density value is turned into a log-density with the very same
;; procedure used for masses (mass->logmass, defined elsewhere).  Bound to
;; mass->logmass's value at the time this definition is evaluated.
(define density->logdensity mass->logmass)
;; Uniform distribution: returns a value between a inclusive and b
;; exclusive. Assumes a < b.
;; The scorer gives every x in [a, b) log-density -log(b - a), converted
;; to a log-mass at x, and returns -inf for x outside [a, b).
(define cont-uniform
  (make-operator
   (lambda (a b)
     (let draw ()
       (let ((x (+ (* (random 1.0)
                      (- b a))
                   a)))
         ;; Even with r in [0, 1), a + r*(b - a) can round up onto b
         ;; (e.g. a = 1, b = 2, r = 1 - 2^-53).  That is outside [a, b)
         ;; and has zero mass under the scorer, so redraw.  The redraw is
         ;; limited to a < b (otherwise [a, b) is empty) and finite x
         ;; (an infinite width cannot be fixed by redrawing); in those
         ;; degenerate cases x is returned as before.
         (if (and (< a b)
                  (>= x b)
                  (= (- x x) 0))
             (draw)
             x))))
   (lambda (a b)
     (lambda (x)
       (if (and (>= x a) (< x b))
           (logdensity->logmass
            (- (log (- b a)))
            x)
           -inf)))))
;; Normal distribution with mean of mean and standard deviation of
;; stdev.  Assumes stdev > 0.
(define normal
  (make-operator
   (lambda (mean stdev)
     ;; draw from standard normal distribution using marsaglia polar
     ;; method: pick (u, v) in the square (-1, 1]^2 and keep it only if
     ;; it lies strictly inside the unit circle and is not the origin.
     ;; note: if it is desired to draw as few random numbers as
     ;; possible, this can be rewritten to generate two standard
     ;; normal deviates instead of one, the other being
     ;; (* v (sqrt (/ (* -2 (log s)) s)))
     (let lp ((u (- 1 (random 2.0)))
              (v (- 1 (random 2.0))))
       (let ((s (+ (square u) (square v))))
         ;; s = 0 must be rejected as well: (log 0) and the division by
         ;; s would otherwise give an infinity/NaN (or an error) instead
         ;; of a sample.
         (if (or (>= s 1.) (= s 0))
             (lp (- 1 (random 2.0)) (- 1 (random 2.0)))
             (let ((z (* u (sqrt (/ (* -2 (log s)) s)))))
               (+ mean (* z stdev)))))))
   (lambda (mean stdev)
     (lambda (x)
       ;; Gaussian log-density evaluated directly in log space:
       ;;   -(x - mean)^2 / (2 stdev^2) - log(stdev * sqrt(2 pi))
       ;; Taking the log of exp(...) instead would underflow to the log
       ;; of 0 once |x - mean| exceeds roughly 38.6 stdev, collapsing all
       ;; far-tail points to the same value.
       (logdensity->logmass
        (- (- (/ (square (- x mean))
                 (* 2 (square stdev))))
           (log (* stdev (sqrt (* 2 *pi*)))))
        x)))))