-
Notifications
You must be signed in to change notification settings - Fork 18
/
place_cells.py
122 lines (95 loc) · 4.51 KB
/
place_cells.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# -*- coding: utf-8 -*-
import numpy as np
import scipy
import scipy.interpolate  # `import scipy` alone does not expose the interpolate submodule
import torch
class PlaceCells(object):
    """Simulated population of place cells encoding 2D position.

    Cell centers are tiled uniformly at random over the environment
    (deterministically, via a fixed numpy seed). Activations are
    softmax-normalized Gaussian bumps around each center, optionally
    with a difference-of-Gaussians (DoG) surround.
    """

    def __init__(self, options, us=None):
        """
        Args:
            options: object with attributes Np, place_cell_rf, surround_scale,
                box_width, box_height, periodic, DoG, device.
            us: unused; kept for interface compatibility with existing callers.
        """
        self.Np = options.Np                          # number of place cells
        self.sigma = options.place_cell_rf            # receptive-field width
        self.surround_scale = options.surround_scale  # DoG surround/center width ratio
        self.box_width = options.box_width
        self.box_height = options.box_height
        self.is_periodic = options.periodic           # toroidal boundary conditions
        self.DoG = options.DoG                        # use difference-of-Gaussians tuning
        self.device = options.device
        self.softmax = torch.nn.Softmax(dim=-1)

        # Randomly tile place cell centers across environment.
        # Seeded so the same centers are produced on every run.
        np.random.seed(0)
        usx = np.random.uniform(-self.box_width / 2, self.box_width / 2, (self.Np,))
        # BUG FIX: y-centers were previously drawn from [-box_width/2, box_width/2];
        # use box_height so non-square environments are tiled correctly.
        # (Identical values to before whenever box_height == box_width.)
        usy = np.random.uniform(-self.box_height / 2, self.box_height / 2, (self.Np,))
        self.us = torch.tensor(np.vstack([usx, usy]).T)
        # If using a GPU, put on GPU
        self.us = self.us.to(self.device)
        # self.us = torch.tensor(np.load('models/example_pc_centers.npy')).cuda()

    def get_activation(self, pos):
        '''
        Get place cell activations for a given position.

        Args:
            pos: 2d position of shape [batch_size, sequence_length, 2].

        Returns:
            outputs: Place cell activations with shape [batch_size, sequence_length, Np].
        '''
        # Per-axis |distance| from every position to every place-cell center:
        # shape [batch, seq, Np, 2].
        d = torch.abs(pos[:, :, None, :] - self.us[None, None, ...]).float()

        if self.is_periodic:
            # On a torus the distance along each axis is the shorter of the
            # direct and the wrapped-around path.
            dx = d[:, :, :, 0]
            dy = d[:, :, :, 1]
            dx = torch.minimum(dx, self.box_width - dx)
            dy = torch.minimum(dy, self.box_height - dy)
            d = torch.stack([dx, dy], dim=-1)

        norm2 = (d ** 2).sum(-1)

        # Normalize place cell outputs with prefactor alpha=1/2/np.pi/self.sigma**2,
        # or, simply normalize with softmax, which yields same normalization on
        # average and seems to speed up training.
        outputs = self.softmax(-norm2 / (2 * self.sigma ** 2))

        if self.DoG:
            # Again, normalize with prefactor
            # beta=1/2/np.pi/self.sigma**2/self.surround_scale, or use softmax.
            outputs -= self.softmax(-norm2 / (2 * self.surround_scale * self.sigma ** 2))

            # Shift and scale outputs so that they lie in [0,1] and sum to 1.
            min_output, _ = outputs.min(-1, keepdim=True)
            outputs += torch.abs(min_output)
            outputs /= outputs.sum(-1, keepdim=True)

        return outputs

    def get_nearest_cell_pos(self, activation, k=3):
        '''
        Decode position using centers of k maximally active place cells.

        Args:
            activation: Place cell activations of shape [batch_size, sequence_length, Np].
            k: Number of maximally active place cells with which to decode position.

        Returns:
            pred_pos: Predicted 2d position with shape [batch_size, sequence_length, 2].
        '''
        _, idxs = torch.topk(activation, k=k)
        # Average the centers of the top-k cells to get the position estimate.
        pred_pos = self.us[idxs].mean(-2)
        return pred_pos

    def grid_pc(self, pc_outputs, res=32):
        '''Interpolate place cell outputs onto a res x res spatial grid.

        Args:
            pc_outputs: activations of shape [..., Np] (torch tensor or ndarray).
            res: grid resolution per side.

        Returns:
            pc: ndarray of shape [T, res, res], one rate map per flattened sample.
        '''
        coordsx = np.linspace(-self.box_width / 2, self.box_width / 2, res)
        coordsy = np.linspace(-self.box_height / 2, self.box_height / 2, res)
        grid_x, grid_y = np.meshgrid(coordsx, coordsy)
        grid = np.stack([grid_x.ravel(), grid_y.ravel()]).T

        # Convert to numpy — BUG FIX: the conversion was commented but never
        # done; scipy.interpolate.griddata cannot consume torch tensors.
        pc_outputs = pc_outputs.reshape(-1, self.Np)
        if torch.is_tensor(pc_outputs):
            pc_outputs = pc_outputs.detach().cpu().numpy()
        T = pc_outputs.shape[0]  # number of flattened (batch*time) samples

        pc = np.zeros([T, res, res])
        for i in range(T):
            gridval = scipy.interpolate.griddata(self.us.cpu(), pc_outputs[i], grid)
            pc[i] = gridval.reshape([res, res])
        return pc

    def compute_covariance(self, res=30):
        '''Compute spatial covariance matrix of place cell outputs.

        Returns:
            Cmean: [res, res] translation-averaged spatial covariance, centered.
        '''
        pos = np.array(np.meshgrid(np.linspace(-self.box_width / 2, self.box_width / 2, res),
                                   np.linspace(-self.box_height / 2, self.box_height / 2, res))).T
        pos = torch.tensor(pos)

        # Put on GPU if available
        pos = pos.to(self.device)

        pc_outputs = self.get_activation(pos).reshape(-1, self.Np).cpu()

        # Gram matrix between all grid locations, reshaped to 4D so each
        # [i, j] slice is the covariance map relative to location (i, j).
        C = pc_outputs @ pc_outputs.T
        Csquare = C.reshape(res, res, res, res)

        # Average over reference locations by rolling each slice so its own
        # position sits at the origin, then recenter for display.
        Cmean = np.zeros([res, res])
        for i in range(res):
            for j in range(res):
                Cmean += np.roll(np.roll(Csquare[i, j], -i, axis=0), -j, axis=1)
        Cmean = np.roll(np.roll(Cmean, res // 2, axis=0), res // 2, axis=1)

        return Cmean