diff --git a/.gitignore b/.gitignore
index 68bc17f..c7efa6f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
+# Project Specific
+data/*.npy
+outputs/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/README.md b/README.md
index dc45504..84cc629 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,9 @@ Theoretical Foundations of Deep Selective State-Space Models: The experiments
 ```angular2html
 conda create -n jax_cde python=3.11
 conda activate jax_cde
-conda install pre-commit
+conda install pre-commit numpy scikit-learn matplotlib
+pip install -U "jax[cuda12]"
+pip install diffrax signax==0.1.1
 pre-commit install
 ```
 
diff --git a/data/generate_toy_dataset.py b/data/generate_toy_dataset.py
new file mode 100644
index 0000000..8f4dc35
--- /dev/null
+++ b/data/generate_toy_dataset.py
@@ -0,0 +1,34 @@
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+import numpy as np
+from signax.signature import signature
+
+
+key = jr.PRNGKey(1234)
+depth = 4
+
+for dim in [2, 3]:
+
+    data = jr.normal(key, shape=(100000, 100, dim))
+    data = jnp.round(data)
+    data = jnp.cumsum(data, axis=1)
+    data = data / jnp.max(jnp.abs(data))
+
+    vmap_calc_sig = jax.vmap(signature, in_axes=(0, None))
+    labels = vmap_calc_sig(data, depth)
+
+    if dim == 2:
+        labels = labels[1][:, 0, 1]
+    elif dim == 3:
+        labels = labels[2][:, 0, 1, 2]
+
+    labels = labels / jnp.max(jnp.abs(labels))
+
+    data = np.array(data)
+    labels = np.array(labels)
+
+    with open(f"data/data_{dim}.npy", "wb") as f:
+        np.save(f, data)
+    with open(f"data/labels_{dim}.npy", "wb") as f:
+        np.save(f, labels)
diff --git a/data/jax_dataloader.py b/data/jax_dataloader.py
new file mode 100644
index 0000000..4a719c9
--- /dev/null
+++ b/data/jax_dataloader.py
@@ -0,0 +1,69 @@
+import jax.numpy as jnp
+import jax.random as jr
+import numpy as np
+
+
+class InMemoryDataloader:
+    def __init__(self, num_train, num_test, train_x, train_y, test_x, test_y):
+        self.num_train = num_train
+        self.num_test = num_test
+        self.train_x = train_x
+        self.train_y = train_y
+        self.test_x = test_x
+        self.test_y = test_y
+
+    def loop(self, batch_size, data, labels, *, key):
+        size = data.shape[0]
+        indices = jnp.arange(size)
+        while True:
+            perm_key, key = jr.split(key, 2)
+            perm = jr.permutation(perm_key, indices)
+            for X in self.loop_epoch(batch_size, data[perm], labels[perm]):
+                yield X
+
+    def loop_epoch(self, batch_size, data, labels):
+        size = data.shape[0]
+        indices = jnp.arange(size)
+        start = 0
+        end = batch_size
+        while end < size:
+            batch_indices = indices[start:end]
+            yield data[batch_indices], labels[batch_indices]
+            start = end
+            end = start + batch_size
+        batch_indices = indices[start:]
+        yield data[batch_indices], labels[batch_indices]
+
+    def train_loop(self, batch_size, epoch=False, *, key):
+        if epoch:
+            return self.loop_epoch(batch_size, self.train_x, self.train_y)
+        else:
+            return self.loop(batch_size, self.train_x, self.train_y, key=key)
+
+    def test_loop(self, batch_size, epoch=False, *, key):
+        if epoch:
+            return self.loop_epoch(batch_size, self.test_x, self.test_y)
+        else:
+            return self.loop(batch_size, self.test_x, self.test_y, key=key)
+
+
+class ToyDataloader(InMemoryDataloader):
+    def __init__(self, num):
+        with open(f"data/data_{num}.npy", "rb") as f:
+            data = jnp.array(np.load(f))
+        with open(f"data/labels_{num}.npy", "rb") as f:
+            labels = jnp.array(np.load(f))
+        N = data.shape[0]
+        train_x = data[: int(0.8 * N)]
+        train_y = labels[: int(0.8 * N)]
+        test_x = data[int(0.8 * N) :]
+        test_y = labels[int(0.8 * N) :]
+
+        super().__init__(
+            train_x.shape[0],
+            test_x.shape[0],
+            train_x,
+            train_y,
+            test_x,
+            test_y,
+        )
diff --git a/linear_cde.py b/linear_cde.py
new file mode 100644
index 0000000..a801251
--- /dev/null
+++ b/linear_cde.py
@@ -0,0 +1,150 @@
+import os
+
+import diffrax as dfx
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+import numpy as np
+import sklearn.linear_model
+
+from data.jax_dataloader import ToyDataloader
+
+
+class LinearCDE:
+    """A randomly-initialised linear controlled differential equation.
+
+    The hidden state follows dy(t) = A y(t) dω(t) + B dξ(t), started from a
+    random linear read-in of x0; __call__ returns the terminal state y(1).
+    """
+
+    hidden_dim: int
+    data_dim: int
+    omega_dim: int
+    xi_dim: int
+    label_dim: int
+    vf_A: jnp.array
+    vf_B: jnp.array
+
+    def __init__(
+        self,
+        hidden_dim,
+        data_dim,
+        omega_dim,
+        xi_dim,
+        label_dim,
+        *,
+        key,
+    ):
+        init_matrix_key, init_bias_key, vf_A_key, vf_B_key = jr.split(key, 4)
+        self.hidden_dim = hidden_dim
+        self.data_dim = data_dim
+        self.label_dim = label_dim
+        self.init_matrix = jr.normal(init_matrix_key, (hidden_dim, data_dim))
+        self.init_bias = jr.normal(init_bias_key, (hidden_dim,))
+        self.omega_dim = omega_dim
+        self.xi_dim = xi_dim
+        # 1/sqrt(hidden_dim) scaling keeps the hidden-state magnitude stable.
+        self.vf_A = jr.normal(vf_A_key, (hidden_dim, omega_dim, hidden_dim)) / (
+            hidden_dim**0.5
+        )
+        self.vf_B = jr.normal(vf_B_key, (hidden_dim, xi_dim))
+
+    def __call__(self, ts, omega_path, xi_path, x0):
+        # Drive the CDE by the linear interpolation of (omega, xi) samples.
+        control_path = jnp.concatenate((omega_path, xi_path), axis=-1)
+        control = dfx.LinearInterpolation(ts=ts, ys=control_path)
+        y0 = self.init_matrix @ x0 + self.init_bias
+
+        def func(t, y, args):
+            # Block matrix [A y | B], matching the concatenated control.
+            return jnp.concatenate((jnp.dot(self.vf_A, y), self.vf_B), axis=-1)
+
+        term = dfx.ControlTerm(func, control).to_ode()
+        saveat = dfx.SaveAt(t1=True)
+        solution = dfx.diffeqsolve(
+            term,
+            dfx.Tsit5(),
+            0,
+            1,
+            0.01,
+            y0,
+            stepsize_controller=dfx.PIDController(atol=1e-2, rtol=1e-2, jump_ts=ts),
+            saveat=saveat,
+        )
+        return solution.ys[-1]
+
+
+def obtain_features_from_model(model, dataloader, batch_size, num_samples, label_dim):
+    """Run the model over one epoch of batches; collect features and labels."""
+    features = np.zeros((num_samples, model.hidden_dim))
+    labels = np.zeros((num_samples, label_dim))
+    start = 0
+    end = start
+    i = 0
+    vmap_model = jax.vmap(model)
+    for data in dataloader:
+        i += 1
+        print(f"Batch {i}")
+        X, y = data
+        # Use the actual batch length here: the final batch of an epoch may be
+        # smaller than batch_size, and repeating by batch_size would give vmap
+        # mismatched leading axes.
+        num_in_batch = X.shape[0]
+        ts = jnp.repeat(
+            jnp.linspace(0.0, 1.0, X.shape[1])[None, :], num_in_batch, axis=0
+        )
+        # Time-augmented path (t, X_t); used as both omega and xi controls.
+        inp = jnp.concatenate((ts[..., None], X), axis=-1)
+        out = vmap_model(ts, inp, inp, X[:, 0, :])
+        end += len(out)
+        features[start:end] = out
+        labels[start:end] = y[:, None]
+        start = end
+    return features, labels
+
+
+def train_linear(
+    model,
+    dataloader,
+    num_train,
+    num_test,
+    label_dim,
+    batch_size,
+    *,
+    key,
+):
+    """Fit linear regression on terminal CDE features; return test MSE."""
+
+    featkey_train, featkey_test, key = jr.split(key, 3)
+
+    # epoch=True iterates each split exactly once (key is unused in that mode).
+    features_train, labels_train = obtain_features_from_model(
+        model,
+        dataloader.train_loop(batch_size, epoch=True, key=featkey_train),
+        batch_size,
+        num_train,
+        label_dim,
+    )
+    features_test, labels_test = obtain_features_from_model(
+        model,
+        dataloader.test_loop(batch_size, epoch=True, key=featkey_test),
+        batch_size,
+        num_test,
+        label_dim,
+    )
+
+    clf = sklearn.linear_model.LinearRegression()
+    clf.fit(features_train, labels_train)
+    predictions = clf.predict(features_test)
+    mse = jnp.mean((labels_test - predictions) ** 2)
+    return mse
+
+
+if __name__ == "__main__":
+    key = jr.PRNGKey(2345)
+    if not os.path.isdir("outputs"):
+        os.mkdir("outputs")
+    model_key, train_key = jr.split(key)
+    hidden_dim = 256
+    label_dim = 1
+    batch_size = 4000
+    for data_dim in [2, 3]:
+        # Controls are the time-augmented path, hence data_dim + 1 channels.
+        omega_dim = data_dim + 1
+        xi_dim = data_dim + 1
+        model = LinearCDE(
+            hidden_dim, data_dim, omega_dim, xi_dim, label_dim, key=model_key
+        )
+        dataset = ToyDataloader(num=data_dim)
+        mse = train_linear(
+            model,
+            dataset,
+            num_train=dataset.num_train,
+            num_test=dataset.num_test,
+            label_dim=label_dim,
+            batch_size=batch_size,
+            key=train_key,
+        )
+        mse = np.array(mse)
+        np.save(f"outputs/lin_cde_mse_{data_dim}.npy", mse)
+        print(f"Data dim: {data_dim}, MSE: {mse}")