Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added toy dataset and linear cde with linear regression #1

Merged
merged 5 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Project Specific
data/*.npy
outputs/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ Theoretical Foundations of Deep Selective State-Space Models: The experiments
```angular2html
conda create -n jax_cde python=3.11
conda activate jax_cde
conda install pre-commit
conda install pre-commit numpy scikit-learn matplotlib
pip install -U "jax[cuda12]"
pip install diffrax signax==0.1.1
pre-commit install
```

Expand Down
34 changes: 34 additions & 0 deletions data/generate_toy_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
from signax.signature import signature


# Generate the toy regression datasets used by the linear CDE experiments.
#
# For dim in {2, 3}: sample 100k random-walk paths of length 100, normalise
# them globally, and label each path with a single depth-4 signature
# coefficient (term [0, 1] for dim 2, term [0, 1, 2] for dim 3), also
# normalised to [-1, 1].
key = jr.PRNGKey(1234)
depth = 4

for dim in [2, 3]:

    # Random walk: rounded Gaussian increments, cumulatively summed in time.
    paths = jr.normal(key, shape=(100000, 100, dim))
    paths = jnp.cumsum(jnp.round(paths), axis=1)
    paths = paths / jnp.max(jnp.abs(paths))

    # Depth-`depth` path signatures, computed independently per sample.
    sigs = jax.vmap(signature, in_axes=(0, None))(paths, depth)

    # Pick one scalar signature coefficient as the regression target.
    targets = sigs[1][:, 0, 1] if dim == 2 else sigs[2][:, 0, 1, 2]
    targets = targets / jnp.max(jnp.abs(targets))

    with open(f"data/data_{dim}.npy", "wb") as f:
        np.save(f, np.array(paths))
    with open(f"data/labels_{dim}.npy", "wb") as f:
        np.save(f, np.array(targets))
69 changes: 69 additions & 0 deletions data/jax_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import jax.numpy as jnp
import jax.random as jr
import numpy as np


class InMemoryDataloader:
    """Serve (data, label) minibatches from arrays held fully in memory.

    Two iteration modes are provided: a single in-order pass over the data
    (``loop_epoch``) and an endless stream of freshly reshuffled epochs
    (``loop``).
    """

    def __init__(self, num_train, num_test, train_x, train_y, test_x, test_y):
        self.num_train = num_train
        self.num_test = num_test
        self.train_x = train_x
        self.train_y = train_y
        self.test_x = test_x
        self.test_y = test_y

    def loop(self, batch_size, data, labels, *, key):
        """Yield minibatches forever, reshuffling at every epoch boundary."""
        order = jnp.arange(data.shape[0])
        while True:
            # First half of the split drives this epoch's shuffle; the
            # second half carries the randomness forward.
            shuffle_key, key = jr.split(key, 2)
            perm = jr.permutation(shuffle_key, order)
            yield from self.loop_epoch(batch_size, data[perm], labels[perm])

    def loop_epoch(self, batch_size, data, labels):
        """Yield minibatches covering the data exactly once, in order.

        Every batch has ``batch_size`` elements except possibly the final
        one, which carries the remainder.
        """
        size = data.shape[0]
        order = jnp.arange(size)
        lo = 0
        while lo + batch_size < size:
            sel = order[lo : lo + batch_size]
            yield data[sel], labels[sel]
            lo += batch_size
        tail = order[lo:]
        yield data[tail], labels[tail]

    def train_loop(self, batch_size, epoch=False, *, key):
        """Iterate the training split: one pass if ``epoch``, else endless."""
        if epoch:
            return self.loop_epoch(batch_size, self.train_x, self.train_y)
        return self.loop(batch_size, self.train_x, self.train_y, key=key)

    def test_loop(self, batch_size, epoch=False, *, key):
        """Iterate the test split: one pass if ``epoch``, else endless."""
        if epoch:
            return self.loop_epoch(batch_size, self.test_x, self.test_y)
        return self.loop(batch_size, self.test_x, self.test_y, key=key)


class ToyDataloader(InMemoryDataloader):
    """Dataloader for the pre-generated toy signature-regression dataset.

    Loads ``data/data_{num}.npy`` / ``data/labels_{num}.npy`` from disk and
    splits them 80/20 into train and test portions, in file order.
    """

    def __init__(self, num):
        with open(f"data/data_{num}.npy", "rb") as f:
            data = jnp.array(np.load(f))
        with open(f"data/labels_{num}.npy", "rb") as f:
            labels = jnp.array(np.load(f))
        # First 80% of samples are train, the remaining 20% are test.
        split = int(0.8 * data.shape[0])
        super().__init__(
            split,
            data.shape[0] - split,
            data[:split],
            labels[:split],
            data[split:],
            labels[split:],
        )
150 changes: 150 additions & 0 deletions linear_cde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import os

import diffrax as dfx
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import sklearn.linear_model

from data.jax_dataloader import ToyDataloader


class LinearCDE:
    """A randomly initialised linear controlled differential equation.

    The hidden state h solves
        dh(t) = A h(t) d omega(t) + B d xi(t),   h(0) = W x0 + b,
    where A, B, W and b are fixed Gaussian parameters (A scaled by
    1/sqrt(hidden_dim)). Calling the model integrates the CDE over [0, 1]
    and returns the terminal hidden state.
    """

    hidden_dim: int
    data_dim: int
    omega_dim: int
    xi_dim: int
    label_dim: int
    vf_A: jnp.array
    vf_B: jnp.array

    def __init__(
        self,
        hidden_dim,
        data_dim,
        omega_dim,
        xi_dim,
        label_dim,
        *,
        key,
    ):
        matrix_key, bias_key, a_key, b_key = jr.split(key, 4)
        self.hidden_dim = hidden_dim
        self.data_dim = data_dim
        self.label_dim = label_dim
        self.omega_dim = omega_dim
        self.xi_dim = xi_dim
        # Affine map producing the initial hidden state from x0.
        self.init_matrix = jr.normal(matrix_key, (hidden_dim, data_dim))
        self.init_bias = jr.normal(bias_key, (hidden_dim,))
        # Vector-field tensors: A is scaled to keep the state magnitude stable.
        self.vf_A = jr.normal(a_key, (hidden_dim, omega_dim, hidden_dim)) / (
            hidden_dim**0.5
        )
        self.vf_B = jr.normal(b_key, (hidden_dim, xi_dim))

    def __call__(self, ts, omega_path, xi_path, x0):
        """Integrate the CDE driven by (omega, xi) and return the final state."""
        driver = dfx.LinearInterpolation(
            ts=ts, ys=jnp.concatenate((omega_path, xi_path), axis=-1)
        )
        h0 = self.init_matrix @ x0 + self.init_bias

        def vector_field(t, h, args):
            # [A h | B] has shape (hidden, omega + xi), matching the control.
            return jnp.concatenate((jnp.dot(self.vf_A, h), self.vf_B), axis=-1)

        solution = dfx.diffeqsolve(
            dfx.ControlTerm(vector_field, driver).to_ode(),
            dfx.Tsit5(),
            0,
            1,
            0.01,
            h0,
            stepsize_controller=dfx.PIDController(atol=1e-2, rtol=1e-2, jump_ts=ts),
            saveat=dfx.SaveAt(t1=True),
        )
        return solution.ys[-1]


def obtain_features_from_model(model, dataloader, batch_size, num_samples, label_dim):
    """Run ``model`` over every batch from ``dataloader`` and collect outputs.

    Args:
        model: callable ``(ts, omega_path, xi_path, x0) -> (hidden_dim,)``
            with a ``hidden_dim`` attribute; it is vmapped over the batch axis.
        dataloader: iterable of ``(X, y)`` batches with ``X`` of shape
            ``(batch, time, data_dim)``.
        batch_size: nominal batch size. Kept for interface compatibility; the
            true size of each batch is read from ``X`` itself.
        num_samples: total number of samples across all batches.
        label_dim: width of the returned label matrix.

    Returns:
        ``(features, labels)`` numpy arrays of shapes
        ``(num_samples, hidden_dim)`` and ``(num_samples, label_dim)``.
    """
    features = np.zeros((num_samples, model.hidden_dim))
    labels = np.zeros((num_samples, label_dim))
    start = 0
    end = start
    vmap_model = jax.vmap(model)
    for i, (X, y) in enumerate(dataloader, start=1):
        print(f"Batch {i}")
        # Bug fix: build ts with the actual batch length (X.shape[0]) rather
        # than batch_size, so a short final batch no longer causes a shape
        # mismatch in the concatenation below.
        ts = jnp.repeat(
            jnp.linspace(0.0, 1.0, X.shape[1])[None, :], X.shape[0], axis=0
        )
        # Time-augmented path drives both the omega and xi channels.
        inputs = jnp.concatenate((ts[..., None], X), axis=-1)
        out = vmap_model(ts, inputs, inputs, X[:, 0, :])
        end += len(out)
        features[start:end] = out
        # Reshape handles both (batch,) and (batch, label_dim) label batches.
        labels[start:end] = np.reshape(np.asarray(y), (len(y), -1))
        start = end
    return features, labels


def train_linear(
    model,
    dataloader,
    num_train,
    num_test,
    label_dim,
    batch_size,
    *,
    key,
):
    """Fit a linear read-out on frozen model features and return the test MSE.

    Features are extracted with a single in-order pass (epoch mode) over the
    train and test splits, an ordinary least-squares regression is fit on the
    training features, and the mean squared error of its predictions on the
    test split is returned as a scalar jnp value.

    The ``key`` argument is kept for interface compatibility; epoch-mode
    iteration is deterministic, so no randomness is consumed here. (The
    original unused ``jr.split`` of the key has been removed as dead code.)
    """
    features_train, labels_train = obtain_features_from_model(
        model,
        dataloader.train_loop(batch_size, epoch=True, key=key),
        batch_size,
        num_train,
        label_dim,
    )
    features_test, labels_test = obtain_features_from_model(
        model,
        dataloader.test_loop(batch_size, epoch=True, key=key),
        batch_size,
        num_test,
        label_dim,
    )

    # Closed-form linear regression on the frozen features.
    clf = sklearn.linear_model.LinearRegression()
    clf.fit(features_train, labels_train)
    predictions = clf.predict(features_test)
    mse = jnp.mean((labels_test - predictions) ** 2)
    return mse


if __name__ == "__main__":
    # Train a linear read-out on top of a random linear CDE for the toy
    # signature-regression datasets (data dims 2 and 3) and save the test MSE.
    key = jr.PRNGKey(2345)
    # Race-free directory creation (replaces the isdir + mkdir pair).
    os.makedirs("outputs", exist_ok=True)
    model_key, train_key = jr.split(key)
    hidden_dim = 256
    label_dim = 1
    batch_size = 4000
    for data_dim in [2, 3]:
        # Both driving paths are the time-augmented data path, so each has
        # data_dim + 1 channels.
        omega_dim = data_dim + 1
        xi_dim = data_dim + 1
        model = LinearCDE(
            hidden_dim, data_dim, omega_dim, xi_dim, label_dim, key=model_key
        )
        dataset = ToyDataloader(num=data_dim)
        mse = train_linear(
            model,
            dataset,
            num_train=dataset.num_train,
            num_test=dataset.num_test,
            label_dim=label_dim,
            batch_size=batch_size,
            key=train_key,
        )
        mse = np.array(mse)
        np.save(f"outputs/lin_cde_mse_{data_dim}.npy", mse)
        print(f"Data dim: {data_dim}, MSE: {mse}")
Loading