-
Notifications
You must be signed in to change notification settings - Fork 0
/
PointProcess.py
135 lines (96 loc) · 4.21 KB
/
PointProcess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
from Utils.loss import MaxLogLike, fill_triu
import torch.nn.functional as F
class PointProcessModel(torch.nn.Module):
    """
    Base point-process model: a homogeneous Poisson process.

    Holds a learnable pre-softplus background-rate vector ``mu`` — one entry
    per event type (subclasses map it through softplus before use) — plus the
    shared maximum-log-likelihood loss object.
    """

    def __init__(self, num_type: int):
        """
        Build the Poisson baseline.

        :param num_type: int, the number of event types.
        """
        super().__init__()
        self.model_name = 'A Poisson Process'
        self.num_type = num_type
        # Background-rate logits, initialized small and negative so the
        # post-softplus rates start close to zero.
        self.mu = torch.nn.Parameter(0.5 * torch.randn(num_type) - 2.0)
        self.loss_function = MaxLogLike()
class MultiVariateHawkesProcessModel(PointProcessModel):
    """
    Multivariate Hawkes process with exponential decay kernels.

    With positive parameters obtained via softplus in ``forward``, the
    intensity of event type u at time t is
        lambda_u(t) = mu_u + sum_{t_i < t} A[u, u_i] * omega * exp(-omega * (t - t_i))
    """
    def __init__(self, num_type: int, num_decay : int):
        """
        Initialize generalized Hawkes process
        :param num_type: int, the number of event types.
        :param num_decay: int, the number of decay functions
        """
        super().__init__(num_type)
        self.model_name = 'A MultiVariate Hawkes Process'
        self.num_decay = num_decay
        # Excitation logits alpha[u, v, d]: effect of a type-v event on the
        # type-u intensity through decay function d (softplus-ed before use).
        self.alpha = torch.nn.Parameter(torch.randn(self.num_type, self.num_type, self.num_decay) * 0.5 - 3.0)
        # Decay-rate logits, one per decay function.
        self.beta = torch.nn.Parameter(torch.randn(self.num_decay) * 0.5)
    def forward(self, event_times, event_types, input_mask, t0, t1):
        """
        Compute the two terms of the Hawkes log-likelihood per sequence.

        :param event_times: B x N event timestamps (padded sequences).
        :param event_types: B x N integer event-type indices.
        :param input_mask: B x N {0,1} mask; 1 marks a real (non-padded) event.
        :param t0: starting time of the observation window.
        :param t1: per-sequence ending times (B-dim; indexed as t1[:, None] below).
        :return: tuple (loglik, compensator); log p(sequence) = loglik - compensator.

        NOTE(review): ``omega`` has shape (num_decay,) and is broadcast against
        the B x N x N tensor ``dt``; together with ``squeeze(dim=3)`` below this
        only lines up when num_decay == 1 — confirm before using num_decay > 1.
        """
        # Map unconstrained parameters to positive values.
        mhat = F.softplus(self.mu)
        Ahat = F.softplus(self.alpha)
        omega = F.softplus(self.beta)
        B, N = event_times.shape
        dim = mhat.shape[0]  # number of event types (not used below)
        # Background rate of each event's own type: m_{u_i}.
        mu = mhat[event_types]  # B x N
        # dt[b, i, j] = t_i - t_j for j < i (otherwise zeroed by fill_triu).
        dt = event_times[:, :, None] - event_times[:, None]  # B x N x N
        dt = fill_triu(dt, 0)
        # Exponential kernel value for every event pair: omega * exp(-omega * dt).
        kern = omega * torch.exp(-omega * dt)
        # Gather A[u_i, u_j] for every (i, j) pair via advanced indexing.
        colidx = event_types.unsqueeze(1).repeat(1, N, 1)
        rowidx = event_types.unsqueeze(2).repeat(1, 1, N)
        Auu = Ahat[rowidx, colidx].squeeze(dim=3)  # squeeze assumes num_decay == 1
        ag = Auu * kern
        # Keep only contributions from strictly earlier events.
        ag = fill_triu(ag, 0)
        # Total intensity at each event: background + accumulated excitation.
        rates = mu + torch.sum(ag, dim=2)
        # Baseline compensator: sum_u \int_{t0}^{t1} mu_u dt = (t1 - t0) * sum_u mu_u.
        compensator_baseline = (t1 - t0) * torch.sum(mhat)
        # \int_{t_i}^{t1} omega exp(-omega (t - t_i)) dt = 1 - exp(-omega (t1 - t_i)).
        log_kernel = -omega * (t1[:, None] - event_times)
        Int_kernel = (1 - torch.exp(log_kernel)).unsqueeze(1)
        # Au[b, u, i] = A[u, u_i]: excitation of every target type u by event i.
        Au = Ahat[:, event_types].permute(1, 0, 2, 3).squeeze(3)
        # Sum over target types, then mask out padded events.
        Au_Int_kernel = (Au * Int_kernel).sum(dim=1) * input_mask
        compensator = compensator_baseline + Au_Int_kernel.sum(dim=1)
        # Event term, masked to real events; 1e-8 guards log(0).
        loglik = torch.log(rates + 1e-8).mul(input_mask).sum(-1)
        return (loglik, compensator)
class HawkesPointProcess(torch.nn.Module):
    """
    Univariate Hawkes process with a single exponential kernel.

    Learns three pre-softplus scalars — ``mu`` (background rate), ``alpha``
    (excitation weight) and ``beta`` (decay rate) — and evaluates the
    sequence log-probability in ``logprob``.
    """

    def __init__(self):
        super().__init__()
        # Zero-initialized logits; registration order (mu, alpha, beta) is
        # preserved by assigning them one at a time.
        for pname in ('mu', 'alpha', 'beta'):
            setattr(self, pname, torch.nn.Parameter(torch.zeros(1)))

    def logprob(self, event_times, input_mask, t0, t1):
        """
        Return the log-likelihood of each padded event sequence.

        :param event_times: B x N event timestamps.
        :param input_mask: B x N {0,1} mask; 1 marks a real (non-padded) event.
        :param t0: observation-window start time.
        :param t1: per-sequence window end times (indexed as t1[:, None]).
        :return: B-dim tensor of log-probabilities (event term minus compensator).
        """
        base = F.softplus(self.mu)
        excite = F.softplus(self.alpha)
        decay = F.softplus(self.beta)
        # Pairwise gaps t_i - t_j; fill_triu pushes the j >= i entries to
        # -1e20 so they vanish under logsumexp.
        gaps = event_times[:, :, None] - event_times[:, None]  # B x N x N
        masked = fill_triu(-gaps * decay, -1e20)
        # Intensity at each event: sum_j exp(-beta (t_i - t_j)) * alpha * beta + mu.
        intensity = torch.exp(torch.logsumexp(masked, dim=-1)) * excite * decay + base
        # Event term, restricted to real events; 1e-8 guards log(0).
        event_term = torch.log(intensity + 1e-8).mul(input_mask).sum(-1)
        # Padded slots get -1e20 so their exp contribution is ~0; subtracting
        # input_mask.sum(-1) turns sum_i exp(-beta (t1 - t_i)) into
        # sum_i (exp(...) - 1) over real events only.
        tail = -decay * (t1[:, None] - event_times) * input_mask + (1.0 - input_mask) * -1e20
        # Compensator: (t1 - t0) mu + alpha * sum_i (1 - exp(-beta (t1 - t_i))).
        comp = (t1 - t0) * base
        comp = comp - excite * (torch.exp(torch.logsumexp(tail, dim=-1)) - input_mask.sum(-1))
        return (event_term - comp)