From a62317c4308d89f7b8c130c091785411d8b9bdf8 Mon Sep 17 00:00:00 2001
From: Sergey Skorik
Date: Sun, 9 Jun 2024 15:15:15 +0000
Subject: [PATCH 1/4] Init commit in ecg generation pipeline

---
 setup.cfg                                             | 3 +++
 src/ecglib/models/architectures/model_types.py        | 5 ++++-
 src/ecglib/models/architectures/registred_models.py   | 4 +++-
 src/ecglib/models/architectures/sssd/__init__.py      | 0
 src/ecglib/models/architectures/sssd/s4.py            | 2 ++
 src/ecglib/models/architectures/sssd/sssd_ecg_nle.py  | 2 ++
 6 files changed, 14 insertions(+), 2 deletions(-)
 create mode 100644 src/ecglib/models/architectures/sssd/__init__.py
 create mode 100644 src/ecglib/models/architectures/sssd/s4.py
 create mode 100644 src/ecglib/models/architectures/sssd/sssd_ecg_nle.py

diff --git a/setup.cfg b/setup.cfg
index 8960de6..de10762 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -29,6 +29,9 @@ install_requires =
     PyWavelets>=1.3.0
     wfdb>=4.0.0
     omegaconf>=2.3.0
+    einops>=0.7.0
+    opt_einsum>=3.3.0
+    pytest>=7.4.2
 python_requires = >=3.6
 
 [options.packages.find]
diff --git a/src/ecglib/models/architectures/model_types.py b/src/ecglib/models/architectures/model_types.py
index cccf3c7..fad9843 100644
--- a/src/ecglib/models/architectures/model_types.py
+++ b/src/ecglib/models/architectures/model_types.py
@@ -8,7 +8,8 @@ class MType(IntEnum):
     DENSENET = 1
     TABULAR = 2
     CNN = 3
-    OTHER = 4  # use to sign custom models
+    SSSD = 4
+    OTHER = 5  # use to sign custom models
 
     @staticmethod
     def from_string(label: str) -> IntEnum:
@@ -21,6 +22,8 @@ def from_string(label: str) -> IntEnum:
             return MType.TABULAR
         elif "cnn1d" in label:
             return MType.CNN
+        elif "sssd" in label:
+            return MType.SSSD
         elif "other" in label:
             return MType.OTHER
         else:
diff --git a/src/ecglib/models/architectures/registred_models.py b/src/ecglib/models/architectures/registred_models.py
index 4c5c20e..1bc083b 100644
--- a/src/ecglib/models/architectures/registred_models.py
+++ b/src/ecglib/models/architectures/registred_models.py
@@ -6,6 +6,7 @@
 from .densenet1d import densenet121_1d, densenet201_1d
 from .tabular import tabular
 from .cnn1d import cnn1d
+from .sssd.sssd_ecg_nle import sssd_ecg
 
 __all__ = ["register_model", "registred_models", "get_builder", "is_model_registred"]
 
@@ -19,7 +20,8 @@
     "resnet1d50": resnet1d50,
     "resnet1d101": resnet1d101,
     "tabular": tabular,
-    "cnn1d": cnn1d
+    "cnn1d": cnn1d,
+    "sssd_ecg": sssd_ecg,
 }
 
 
diff --git a/src/ecglib/models/architectures/sssd/__init__.py b/src/ecglib/models/architectures/sssd/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/ecglib/models/architectures/sssd/s4.py b/src/ecglib/models/architectures/sssd/s4.py
new file mode 100644
index 0000000..46b695e
--- /dev/null
+++ b/src/ecglib/models/architectures/sssd/s4.py
@@ -0,0 +1,2 @@
+class S4:
+    raise NotImplementedError
\ No newline at end of file
diff --git a/src/ecglib/models/architectures/sssd/sssd_ecg_nle.py b/src/ecglib/models/architectures/sssd/sssd_ecg_nle.py
new file mode 100644
index 0000000..46ed491
--- /dev/null
+++ b/src/ecglib/models/architectures/sssd/sssd_ecg_nle.py
@@ -0,0 +1,2 @@
+class SSSD_ECG_nle:
+    raise NotImplementedError
\ No newline at end of file

From dc13d4aa5df44394809ceabd3a6f4b11416fbcdb Mon Sep 17 00:00:00 2001
From: Sergey Skorik
Date: Tue, 18 Jun 2024 13:53:35 +0000
Subject: [PATCH 2/4] #15: fix excessive lead selection

---
 src/ecglib/data/datasets.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/ecglib/data/datasets.py b/src/ecglib/data/datasets.py
index acb69f6..90d580d 100644
--- a/src/ecglib/data/datasets.py
+++ b/src/ecglib/data/datasets.py
@@ -84,12 +84,12 @@ def get_name(self, index: int) -> str:
         return str(Path(self.get_fpath(index)).stem)
 
     def read_ecg_record(
-        self, file_path, data_type, leads=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
-    ):
+        self, file_path, data_type
+    ):
         if data_type == "npz":
             ecg_record = np.load(file_path)["arr_0"].astype("float64")
         elif data_type == "wfdb":
-            ecg_record, _ = wfdb.rdsamp(file_path, channels=leads)
+            ecg_record, _ = wfdb.rdsamp(file_path)
             ecg_record = ecg_record.T
             ecg_record = ecg_record.astype("float64")
         else:
@@ -140,7 +140,7 @@ def __getitem__(self, index):
         file_path = self.ecg_data.iloc[index]["fpath"]
 
         # data standartization (scaling, resampling, cuts off, normalization and padding/truncation)
-        ecg_record = self.read_ecg_record(file_path, self.data_type, self.leads)
+        ecg_record = self.read_ecg_record(file_path, self.data_type)
         full_ecg_record_info = EcgRecord(
             signal=ecg_record[self.leads, :],
             frequency=ecg_frequency,

From 14e7d18417acac8df07f02204664d8f3e3bd2e43 Mon Sep 17 00:00:00 2001
From: Sergey Skorik
Date: Tue, 18 Jun 2024 15:14:37 +0000
Subject: [PATCH 3/4] #13 add sssd_ecg_nle paper code

---
 README.md                                             |    6 +-
 notebooks/sssd_ecg_nle.ipynb                          |  526 ++++++++
 .../models/architectures/registred_models.py          |    4 +-
 .../architectures/sssd/extensions/__init__.py         |    0
 .../sssd/extensions/cauchy/__init__.py                |    0
 .../sssd/extensions/cauchy/cauchy.cpp                 |  102 ++
 .../sssd/extensions/cauchy/cauchy.py                  |  102 ++
 .../sssd/extensions/cauchy/cauchy_cuda.cu             |  368 +++++
 .../sssd/extensions/cauchy/map.h                      |   72 +
 .../sssd/extensions/cauchy/setup.py                   |   25 +
 src/ecglib/models/architectures/sssd/s4.py            | 1190 ++++++++++++++++-
 .../models/architectures/sssd/sssd_ecg_nle.py         |  245 +++-
 src/ecglib/models/architectures/sssd/util.py          |   33 +
 src/ecglib/models/config/model_configs.py             |   28 +-
 src/ecglib/models/config/registred_configs.py         |    2 +
 src/ecglib/preprocessing/functional.py                |   13 +
 src/ecglib/preprocessing/preprocess.py                |    6 +-
 17 files changed, 2712 insertions(+), 10 deletions(-)
 create mode 100644 notebooks/sssd_ecg_nle.ipynb
 create mode 100644 src/ecglib/models/architectures/sssd/extensions/__init__.py
 create mode 100644 src/ecglib/models/architectures/sssd/extensions/cauchy/__init__.py
 create mode 100644 src/ecglib/models/architectures/sssd/extensions/cauchy/cauchy.cpp
 create mode 100644 src/ecglib/models/architectures/sssd/extensions/cauchy/cauchy.py
 create mode 100644 src/ecglib/models/architectures/sssd/extensions/cauchy/cauchy_cuda.cu
 create mode 100644 src/ecglib/models/architectures/sssd/extensions/cauchy/map.h
 create mode 100644 src/ecglib/models/architectures/sssd/extensions/cauchy/setup.py
 create mode 100644 src/ecglib/models/architectures/sssd/util.py
diff --git a/README.md b/README.md
index 5c64225..618b2b4 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,7 @@
 - [Models](#models)
 - [Preprocessing](#preprocessing)
 - [Predict](#predict)
+- [Generation](#generation)
 
 ### Introduction
 
@@ -69,7 +70,7 @@ Via `datasets.py` one can create class *EcgDataset* to store ECG datasets.
It st from ecglib.data import EcgDataset -targets = [[0.0] if 'AFIB' in eval(ptb_xl_info.iloc[i]['scp_codes']).keys() else [1.0] +targets = [[1.0] if 'AFIB' in eval(ptb_xl_info.iloc[i]['scp_codes']).keys() else [0.0] for i in range(ptb_xl_info.shape[0])] ecg_data = EcgDataset(ecg_data=ptb_xl_info, target=targets) ``` @@ -158,3 +159,6 @@ result_df = predict.predict_directory(directory="path/to/data_to_predict", file_type="wfdb") print(predict.predict(ecg_signal, channels_first=False)) ``` + +### Generation +`ecglib` contains the architecture of the diffusion model `SSSD_ECG_nle`, with which you can obtain synthetic signals. The training and generation pipeline is presented in `notebooks/sssd_ecg_nle.ipynb`. \ No newline at end of file diff --git a/notebooks/sssd_ecg_nle.ipynb b/notebooks/sssd_ecg_nle.ipynb new file mode 100644 index 0000000..3cbc5fb --- /dev/null +++ b/notebooks/sssd_ecg_nle.ipynb @@ -0,0 +1,526 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SSSD-ECG-nle methods" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Requirements installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%cd ../\n", + "# Install ecglib package\n", + "%pip install .\n", + "# Install the cauchy CUDA kernel (it's need to fast calculate s4 kernel)\n", + "%pip install src/ecglib/models/architectures/sssd/extensions/cauchy/.\n", + "%cd src/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SSSD-ECG-nle training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from dataclasses import dataclass, field\n", + "from typing import List, Callable\n", + "import signal\n", + "import sys\n", + "\n", + "\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "import pandas as pd\n", + "import wfdb\n", + "from tqdm import tqdm\n", + "import numpy as np\n", + "\n", + "\n", + "from ecglib.data import load_ptb_xl\n", + "from ecglib.data import EcgDataset\n", + "from ecglib.models.config.model_configs import SSSDConfig\n", + "from ecglib.models.model_builder import create_model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Prepare PTB-XL ECG dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "need to change a logic of downloading" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "path_to_unzip = '/home/ecg_data/physionet_data/raw_data/'\n", + "SAMPLE_FREQUENCY = 100\n", + "ptb_xl_info = load_ptb_xl(path_to_unzip=path_to_unzip, frequency=SAMPLE_FREQUENCY)\n", + "ptb_xl_info['frequency'] = SAMPLE_FREQUENCY\n", + "\n", + "# Split in accordance with PTB-XL Benchmarking\n", + "val_ptbxl_info = ptb_xl_info.loc[ptb_xl_info['strat_fold'] == 9].reset_index()\n", + "test_ptbxl_info = ptb_xl_info.loc[ptb_xl_info['strat_fold'] == 10].reset_index()\n", + "train_ptbxl_info = ptb_xl_info.loc[~ptb_xl_info['strat_fold'].isin([9, 10])].reset_index()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Configure parameters of dataset\n", + "@dataclass(repr=True)\n", + "class DataParams:\n", + " syndromes: List[str] = field(default_factory=lambda: ['AFIB']) # CRBBB, 1AVB, PVC or any code 
from `scp_statements.csv`. For multilabel-scenario append to list\n", + " normalization: str = 'identical'\n", + " leads: List[int] = field(default_factory=lambda: [0, 5, 6, 7, 8, 9, 10, 11]) # Standart sequence ['I', 'II', 'III', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']. \n", + " # We take only 8 LI leads ['I', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']\n", + " sample_frequency: int = SAMPLE_FREQUENCY # 500\n", + " augmentation: Callable = None # Check README.md Preprocessing for more details\n", + " batch_size: int = 8 # Train with batch_size=8 require 15Gb GPU memory\n", + "\n", + "data_config = DataParams()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# make sure the sequence of leads\n", + "wfdb.rdsamp(ptb_xl_info.iloc[0]['fpath'])[1]['sig_name']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create EcgDataset's\n", + "def dataset_from_map_file(map_file):\n", + " targets = [[1.0 if synd in eval(map_file.iloc[i]['scp_codes']).keys() else 0.0 for synd in data_config.syndromes] \n", + " for i in range(map_file.shape[0])]\n", + " return EcgDataset(ecg_data=map_file, \n", + " target=targets,\n", + " frequency=data_config.sample_frequency,\n", + " leads=data_config.leads,\n", + " norm_type=data_config.normalization,\n", + " classes=len(data_config.syndromes),\n", + " augmentation=data_config.augmentation\n", + " )\n", + "\n", + "train_dataset = dataset_from_map_file(train_ptbxl_info)\n", + "val_dataset = dataset_from_map_file(val_ptbxl_info)\n", + "test_dataset = dataset_from_map_file(test_ptbxl_info)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create torch dataloaders\n", + "def dataloader_from_dataset(dataset, train=True):\n", + " return DataLoader(\n", + " dataset,\n", + " batch_size=data_config.batch_size,\n", + " shuffle=train,\n", + " drop_last=train,\n", + " )\n", + "\n", + "train_loader = dataloader_from_dataset(train_dataset)\n", + "val_loader = dataloader_from_dataset(val_dataset, train=False)\n", + "test_loader = dataloader_from_dataset(test_dataset, train=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Create SSSD-ECG-nle" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sssd_ecg_nle_config = SSSDConfig(\n", + " in_channels=len(data_config.leads),\n", + " label_embed_classes=len(data_config.syndromes)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sssd_model = create_model(\n", + " model_name='sssd_ecg',\n", + " config=sssd_ecg_nle_config,\n", + " pathology=data_config.syndromes,\n", + " leads_count=len(data_config.leads),\n", + " num_classes=len(data_config.syndromes)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Training pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Configure parameters of dataset\n", + "@dataclass(repr=True)\n", + "class TrainParams:\n", + " num_viewed_samples: int = 10**6 # Number of viewed samples before model training will stop\n", + " batch_size: int = data_config.batch_size\n", + " # Diffusion Hyperparams\n", + " T: int = 200 # denoising num steps\n", + " beta_0: float = 0.0001 # first beta in markov chain\n", + " beta_T: 
float = 0.02 # last beta. Linear interpolation between beta_0 and beta_T\n", + " grad_norm: float = None # Maximum norm to gradient norm clipping\n", + " grad_val: float = None # Maximum norm to gradient value clipping\n", + " lr: float = 0.0002\n", + "\n", + "train_config = TrainParams()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def calc_diffusion_hyperparams(T, beta_0, beta_T):\n", + " \"\"\"\n", + " Compute diffusion process hyperparameters\n", + "\n", + " Parameters:\n", + " T (int): number of diffusion steps\n", + " beta_0 and beta_T (float): beta schedule start/end value, \n", + " where any beta_t in the middle is linearly interpolated\n", + " \n", + " Returns:\n", + " a dictionary of diffusion hyperparameters including:\n", + " T (int), Beta/Alpha/Alpha_bar/Sigma (torch.tensor on cpu, shape=(T, ))\n", + " These cpu tensors are changed to cuda tensors on each individual gpu\n", + " \"\"\"\n", + "\n", + " Beta = torch.linspace(beta_0, beta_T, T) # Linear schedule\n", + " Alpha = 1 - Beta\n", + " Alpha_bar = Alpha + 0\n", + " Beta_tilde = Beta + 0\n", + " for t in range(1, T):\n", + " Alpha_bar[t] *= Alpha_bar[t - 1] # \\bar{\\alpha}_t = \\prod_{s=1}^t \\alpha_s\n", + " Beta_tilde[t] *= (1 - Alpha_bar[t - 1]) / (\n", + " 1 - Alpha_bar[t]) # \\tilde{\\beta}_t = \\beta_t * (1-\\bar{\\alpha}_{t-1})\n", + " # / (1-\\bar{\\alpha}_t)\n", + " Sigma = torch.sqrt(Beta_tilde) # \\sigma_t^2 = \\tilde{\\beta}_t\n", + "\n", + " _dh = {}\n", + " _dh[\"T\"], _dh[\"Beta\"], _dh[\"Alpha\"], _dh[\"Alpha_bar\"], _dh[\"Sigma\"] = T, Beta, Alpha, Alpha_bar, Sigma\n", + " return _dh\n", + "\n", + "train_config.diffusion_hyperparams = calc_diffusion_hyperparams(train_config.T, train_config.beta_0, train_config.beta_T)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def train(model, optimizer, criterion, dataloader, train_config, device):\n", + " num_epochs = train_config.num_viewed_samples // (len(dataloader) * train_config.batch_size)\n", + " for epoch in range(num_epochs):\n", + " print(\"Epoch number:\", epoch)\n", + " model.train()\n", + " num_curr_samples = 0 # counter for the number of examples viewed\n", + " for _, batch in tqdm(enumerate(dataloader), total=len(dataloader)):\n", + " # Get batch\n", + " ids, (input, targets) = batch\n", + " ecg_signal = input[0].to(device)\n", + " targets = targets.long().to(device)\n", + " # Get model input\n", + " T, Alpha_bar = train_config.diffusion_hyperparams[\"T\"], train_config.diffusion_hyperparams[\"Alpha_bar\"]\n", + " Alpha_bar = Alpha_bar.to(device)\n", + " diffusion_steps = torch.randint(T, size=(ecg_signal.shape[0],1)).to(device) # randomly sample diffusion steps from 1~T\n", + " z = torch.normal(0, 1, size=ecg_signal.shape).to(device)\n", + " transformed_X = torch.sqrt(Alpha_bar[diffusion_steps]) * ecg_signal + torch.sqrt(1-Alpha_bar[diffusion_steps]) * z\n", + " optimizer.zero_grad()\n", + " # Loss propagation\n", + " epsilon_theta = model((transformed_X, targets, None, diffusion_steps,))\n", + " loss = criterion(epsilon_theta, z)\n", + " loss.backward()\n", + "\n", + " # Gradient Norm Clipping\n", + " if train_config.grad_norm is not None:\n", + " torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=train_config.grad_norm)\n", + "\n", + " # Gradient Val Clipping\n", + " if train_config.grad_val is not None:\n", + " torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=train_config.grad_val)\n", + 
" # Optimizer step\n", + " optimizer.step()\n", + " # Interrupt process\n", + " num_curr_samples += ecg_signal.shape[0]\n", + " if num_curr_samples > train_config.num_viewed_samples:\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize params\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "criterion = torch.nn.MSELoss()\n", + "sssd_model = sssd_model.to(device)\n", + "optimizer = torch.optim.Adam(params=sssd_model.parameters(), lr=train_config.lr)\n", + "trained_sssd_model = train(sssd_model, optimizer, criterion, train_loader, train_config, device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SSSD-ECG-nle Generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Configure parameters of dataset\n", + "@dataclass(repr=True)\n", + "class GenerationParams:\n", + " generation_path: str = 'afib_synthetic_ptbxl'\n", + " batch_size: int = 10\n", + " leads: List[int] = field(default_factory=lambda: data_config.leads)\n", + " ecg_length: int = 1000\n", + " T: int = train_config.T # denoising num steps\n", + " beta_0: float = train_config.beta_0 # first beta in markov chain\n", + " beta_T: float = train_config.beta_T # last beta. Linear interpolation between beta_0 and beta_T\n", + " syndromes: List[str] = field(default_factory=lambda: data_config.syndromes) # CRBBB, 1AVB, PVC or any code from `scp_statements.csv`. For multilabel-scenario append to list\n", + " pass\n", + "\n", + "gen_config = GenerationParams()\n", + "gen_config.diffusion_hyperparams = calc_diffusion_hyperparams(gen_config.T, gen_config.beta_0, gen_config.beta_T)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def sampling_label(net, size, diffusion_hyperparams, device, cond=None, metadata=None):\n", + " \"\"\"\n", + " Perform the complete sampling step according to p(x_0|x_T) = \\prod_{t=1}^T p_{\\theta}(x_{t-1}|x_t)\n", + "\n", + " Parameters:\n", + " net (torch network): the wavenet model\n", + " size (tuple): size of tensor to be generated, \n", + " usually is (number of audios to generate, channels=1, length of audio)\n", + " diffusion_hyperparams (dict): dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams\n", + " note, the tensors need to be cuda tensors \n", + " cond: conditioning as integer tensor\n", + " guidance_weight: weight for classifier-free guidance (if trained with conditioning_dropout>0)\n", + " \n", + " Returns:\n", + " the generated audio(s) in torch.tensor, shape=size\n", + " \"\"\"\n", + "\n", + " _dh = diffusion_hyperparams\n", + " T, Alpha, Alpha_bar, Sigma = _dh[\"T\"], _dh[\"Alpha\"], _dh[\"Alpha_bar\"], _dh[\"Sigma\"]\n", + " bs = size[0]\n", + " assert len(Alpha) == T\n", + " assert len(Alpha_bar) == T\n", + " assert len(Sigma) == T\n", + " assert len(size) == 3\n", + "\n", + " x = torch.normal(0, 1, size=size).to(device)\n", + " with torch.no_grad():\n", + " for t in range(T-1, -1, -1):\n", + " diffusion_steps = (t * torch.ones((bs, 1))).to(device) # use the corresponding reverse step\n", + " epsilon_theta = net((x, cond, metadata, diffusion_steps,)) # predict \\epsilon according to \\epsilon_\\theta \n", + " 
x = (x - (1-Alpha[t])/torch.sqrt(1-Alpha_bar[t]) * epsilon_theta) / torch.sqrt(Alpha[t]) # update x_{t-1} to \\mu_\\theta(x_t)\n", + " if t > 0:\n", + " x = x + Sigma[t] * torch.normal(0, 1, size=size).to(device) # add the variance term to x_{t-1}\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_four_leads(tensor):\n", + " leadI = tensor[:,0,:].unsqueeze(1) # I = LA - RA\n", + " leadavf = tensor[:,1,:].unsqueeze(1) # 3/2(LL - V_w), V_w = 1/3(RA + LA + LL)\n", + " leadschest = tensor[:,2:8,:] # Vi - V_w\n", + "\n", + " leadII = (0.5*leadI) + leadavf # II = LL - RA\n", + "\n", + " leadIII = -(0.5*leadI) + leadavf # III = LL - LA\n", + " leadavr = -(0.75*leadI) -(0.5*leadavf) # 3/2 (RA - V_w)\n", + " leadavl = (0.75*leadI) - (0.5*leadavf) # 3/2 (LA - V_w)\n", + "\n", + " leads12 = torch.cat([leadI, leadII, leadIII, leadavr, leadavl, leadavf, leadschest], dim=1)\n", + "\n", + " return leads12" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def generation(model, ecg_info, gen_config, device, mode='train'):\n", + " # Main pipeline\n", + " bar = tqdm(total=len(ecg_info))\n", + " bs = gen_config.batch_size\n", + " leads_num = len(gen_config.leads)\n", + " raw_df = ecg_info.copy()\n", + " if SAMPLE_FREQUENCY == 100:\n", + " postfix = 'lr' # low-rate\n", + " else:\n", + " postfix = 'hr' # high-rate\n", + " for i in range(0, len(ecg_info), bs):\n", + " # Get batch\n", + " if i + bs > len(ecg_info):\n", + " batch_df = ecg_info.iloc[i:]\n", + " else:\n", + " batch_df = ecg_info.iloc[i:i+bs]\n", + " # Get condition\n", + " targets = [[1.0 if synd in eval(ecg_info.iloc[i+j]['scp_codes']).keys() else 0.0 for synd in gen_config.syndromes] \n", + " for j in range(len(batch_df))]\n", + " cond = torch.tensor(targets).long().to(device)\n", + " # Sampling\n", + " generated_ecg = sampling_label(\n", + " model, \n", + " (len(batch_df), leads_num, gen_config.ecg_length),\n", + " gen_config.diffusion_hyperparams,\n", + " device,\n", + " cond,\n", + " metadata=None\n", + " )\n", + " if leads_num == 8:\n", + " generated_ecg12 = generate_four_leads(generated_ecg)\n", + " else:\n", + " assert leads_num == 12\n", + " generated_ecg12 = generated_ecg\n", + " # Saving\n", + " for j, (_, row) in enumerate(batch_df.iterrows()):\n", + " file_name = '_'.join(row[f'filename_{postfix}'].split('/')[1:]) + '.npz'\n", + " fpath = os.path.join(gen_config.generation_path, 'data', file_name)\n", + " np.savez(fpath, generated_ecg12[j].detach().cpu().numpy())\n", + " raw_df.loc[i+j, 'fpath'] = fpath\n", + " # update bar\n", + " bar.update(len(batch_df))\n", + " raw_df.to_csv(os.path.join(gen_config.generation_path, f\"{gen_config.generation_path}_{mode}_map_file.csv\"), index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "os.makedirs(os.path.join(gen_config.generation_path, 'data'), exist_ok=True)\n", + "sssd_model.to(device)\n", + "sssd_model.eval()\n", + "generation(sssd_model, train_ptbxl_info, gen_config, device)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + 
"nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/ecglib/models/architectures/registred_models.py b/src/ecglib/models/architectures/registred_models.py index 1bc083b..d3d776b 100644 --- a/src/ecglib/models/architectures/registred_models.py +++ b/src/ecglib/models/architectures/registred_models.py @@ -6,7 +6,7 @@ from .densenet1d import densenet121_1d, densenet201_1d from .tabular import tabular from .cnn1d import cnn1d -from .sssd.sssd_ecg_nle import sssd_ecg +from .sssd.sssd_ecg_nle import sssd_ecg_nle __all__ = ["register_model", "registred_models", "get_builder", "is_model_registred"] @@ -21,7 +21,7 @@ "resnet1d101": resnet1d101, "tabular": tabular, "cnn1d": cnn1d, - "sssd_ecg": sssd_ecg, + "sssd_ecg": sssd_ecg_nle, } diff --git a/src/ecglib/models/architectures/sssd/extensions/__init__.py b/src/ecglib/models/architectures/sssd/extensions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ecglib/models/architectures/sssd/extensions/cauchy/__init__.py b/src/ecglib/models/architectures/sssd/extensions/cauchy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ecglib/models/architectures/sssd/extensions/cauchy/cauchy.cpp b/src/ecglib/models/architectures/sssd/extensions/cauchy/cauchy.cpp new file mode 100644 index 0000000..0d3f02c --- /dev/null +++ b/src/ecglib/models/architectures/sssd/extensions/cauchy/cauchy.cpp @@ -0,0 +1,102 @@ +#include +#include +#include +#include +#include + +#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +torch::Tensor cauchy_mult_sym_fwd_cuda(torch::Tensor v, + torch::Tensor z, + torch::Tensor w); +std::tuple cauchy_mult_sym_bwd_cuda(torch::Tensor v, + torch::Tensor z, + torch::Tensor w, + torch::Tensor dout); + +namespace cauchy { + +torch::Tensor cauchy_mult_sym_fwd(torch::Tensor v, + torch::Tensor z, + torch::Tensor w) { + CHECK_DEVICE(v); CHECK_DEVICE(z); CHECK_DEVICE(w); + const auto batch_size = v.size(0); + const auto N = v.size(1); + const auto L = z.size(0); + CHECK_SHAPE(v, batch_size, N); + CHECK_SHAPE(z, L); + CHECK_SHAPE(w, batch_size, N); + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)v.get_device()}; + return cauchy_mult_sym_fwd_cuda(v, z, w); +} + +std::tuple +cauchy_mult_sym_bwd(torch::Tensor v, + torch::Tensor z, + torch::Tensor w, + torch::Tensor dout) { + CHECK_DEVICE(v); CHECK_DEVICE(z); CHECK_DEVICE(w); CHECK_DEVICE(dout); + const auto batch_size = v.size(0); + const auto N = v.size(1); + const auto L = z.size(0); + CHECK_SHAPE(v, batch_size, N); + CHECK_SHAPE(z, L); + CHECK_SHAPE(w, batch_size, N); + CHECK_SHAPE(dout, batch_size, L); + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)v.get_device()}; + return cauchy_mult_sym_bwd_cuda(v, z, w, dout); +} + +} // cauchy + +torch::Tensor vand_log_mult_sym_fwd_cuda(torch::Tensor v, torch::Tensor x, int L); + +std::tuple +vand_log_mult_sym_bwd_cuda(torch::Tensor v, torch::Tensor x, torch::Tensor dout); + +namespace vand { + +torch::Tensor vand_log_mult_sym_fwd(torch::Tensor v, torch::Tensor x, int L) { + CHECK_DEVICE(v); CHECK_DEVICE(x); + const auto batch_size = v.size(0); + const auto N = v.size(1); + CHECK_SHAPE(v, batch_size, N); + 
CHECK_SHAPE(x, batch_size, N); + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)v.get_device()}; + return vand_log_mult_sym_fwd_cuda(v, x, L); +} + +std::tuple +vand_log_mult_sym_bwd(torch::Tensor v, torch::Tensor x, torch::Tensor dout) { + CHECK_DEVICE(v); CHECK_DEVICE(x); CHECK_DEVICE(dout); + const auto batch_size = v.size(0); + const auto N = v.size(1); + const auto L = dout.size(1); + CHECK_SHAPE(v, batch_size, N); + CHECK_SHAPE(x, batch_size, N); + CHECK_SHAPE(dout, batch_size, L); + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)v.get_device()}; + return vand_log_mult_sym_bwd_cuda(v, x, dout); +} + +} // vand + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cauchy_mult_sym_fwd", &cauchy::cauchy_mult_sym_fwd, + "Cauchy multiply symmetric forward"); + m.def("cauchy_mult_sym_bwd", &cauchy::cauchy_mult_sym_bwd, + "Cauchy multiply symmetric backward"); + m.def("vand_log_mult_sym_fwd", &vand::vand_log_mult_sym_fwd, + "Log Vandermonde multiply symmetric forward"); + m.def("vand_log_mult_sym_bwd", &vand::vand_log_mult_sym_bwd, + "Log Vandermonde multiply symmetric backward"); +} \ No newline at end of file diff --git a/src/ecglib/models/architectures/sssd/extensions/cauchy/cauchy.py b/src/ecglib/models/architectures/sssd/extensions/cauchy/cauchy.py new file mode 100644 index 0000000..4229322 --- /dev/null +++ b/src/ecglib/models/architectures/sssd/extensions/cauchy/cauchy.py @@ -0,0 +1,102 @@ +from pathlib import Path +import torch + +from einops import rearrange + +from structured_kernels import cauchy_mult_sym_fwd, cauchy_mult_sym_bwd +# try: +# from cauchy_mult import cauchy_mult_sym_fwd, cauchy_mult_sym_bwd +# except ImportError: +# from torch.utils.cpp_extension import load +# current_dir = Path(__file__).parent.absolute() +# cauchy_mult_extension = load( +# name='cauchy_mult', +# sources=[str(current_dir / 'cauchy.cpp'), str(current_dir / 'cauchy_cuda.cu')], +# extra_cflags=['-g', '-march=native', '-funroll-loops'], +# extra_cuda_cflags=['-O3', '-lineinfo', '--use_fast_math'], +# extra_include_paths=str(current_dir), +# build_directory=str(current_dir), +# verbose=True +# ) +# cauchy_mult_sym_fwd = cauchy_mult_extension.cauchy_mult_sym_fwd +# cauchy_mult_sym_bwd = cauchy_mult_extension.cauchy_mult_sym_bwd + + +def cauchy_mult_torch(v: torch.Tensor, z: torch.Tensor, w: torch.Tensor, + symmetric=True) -> torch.Tensor: + """ + v: (B, N) + z: (L) + w: (B, N) + symmetric: whether to assume that v and w contain complex conjugate pairs, of the form + [v_half, v_half.conj()] and [w_half, w_half.conj()] + """ + if not symmetric: + return (rearrange(v, 'b n -> b 1 n') / (rearrange(z, 'l -> l 1') - rearrange(w, 'b n -> b 1 n'))).sum(dim=-1) + else: + N = v.shape[-1] + assert N % 2 == 0 + vv = rearrange(v[:, :N // 2], 'b n -> b 1 n') + zz = rearrange(z, 'l -> l 1') + ww = rearrange(w[:, :N // 2], 'b n -> b 1 n') + # return 2 * ((zz * vv.real - vv.real * ww.real - vv.imag * ww.imag) + # / (zz * zz - 2 * zz * ww.real + ww.abs().square())).sum(dim=-1) + return (vv / (zz - ww) + vv.conj() / (zz - ww.conj())).sum(dim=-1) + + +# def cauchy_mult_keops(v, z, w): +# from pykeops.torch import LazyTensor +# v_l = LazyTensor(rearrange(v, 'b N -> b 1 N 1')) +# z_l = LazyTensor(rearrange(z, 'L -> 1 L 1 1')) +# w_l = LazyTensor(rearrange(w, 'b N -> b 1 N 1')) +# sub = z_l - w_l # 
(b N L 1), for some reason it doesn't display the last dimension +# div = v_l / sub +# s = div.sum(dim=2, backend='GPU') +# return s.squeeze(-1) + + +def _cauchy_mult(v, z, w): + return CauchyMultiplySymmetric.apply(v, z, w) + + +def cauchy_mult(v, z, w): + """ Wrap the cuda method to deal with shapes """ + v, w = torch.broadcast_tensors(v, w) + shape = v.shape + # z_shape = z.shape + # z = z.squeeze() + assert len(z.shape) == 1 + + v = v.contiguous() + w = w.contiguous() + z = z.contiguous() + + N = v.size(-1) + assert w.size(-1) == N + y = _cauchy_mult(v.view(-1, N), z, w.view(-1, N)) + y = y.view(*shape[:-1], z.size(-1)) + return y + + +class CauchyMultiplySymmetric(torch.autograd.Function): + + @staticmethod + def forward(ctx, v, z, w): + batch, N = v.shape + supported_N_values = [1 << log_n for log_n in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] + L = z.shape[-1] + if not N in supported_N_values: + raise NotImplementedError(f'Only support N values in {supported_N_values}') + max_L_value = 32 * 1024 * 64 * 1024 + if L > max_L_value: + raise NotImplementedError(f'Only support L values <= {max_L_value}') + if not (v.is_cuda and z.is_cuda and w.is_cuda): + raise NotImplementedError(f'Only support CUDA tensors') + ctx.save_for_backward(v, z, w) + return cauchy_mult_sym_fwd(v, z, w) + + @staticmethod + def backward(ctx, dout): + v, z, w = ctx.saved_tensors + dv, dw = cauchy_mult_sym_bwd(v, z, w, dout) + return dv, None, dw \ No newline at end of file diff --git a/src/ecglib/models/architectures/sssd/extensions/cauchy/cauchy_cuda.cu b/src/ecglib/models/architectures/sssd/extensions/cauchy/cauchy_cuda.cu new file mode 100644 index 0000000..a6d13d5 --- /dev/null +++ b/src/ecglib/models/architectures/sssd/extensions/cauchy/cauchy_cuda.cu @@ -0,0 +1,368 @@ +#include +// On pytorch 1.10 and CUDA 10.2, I get compilation errors on torch/csrc/api/include/torch/nn/cloneable.h +// So we'll only include torch/python.h instead of torch/extension.h +// Similar to https://github.com/getkeops/keops/blob/3efd428b55c724b12f23982c06de00bc4d02d903/pykeops/torch_headers.h.in#L8 +// #include +#include +#include // For getCurrentCUDAStream +#include // For atomicAdd on complex +#include +#include // For scalar_value_type +#include "map.h" // For the MAP macro, i.e. 
for_each over the arguments + + +#ifndef ITEMS_PER_THREAD_SYM_FWD_VALUES + #define ITEMS_PER_THREAD_SYM_FWD_VALUES {2, 4, 8, 16, 32, 32, 32, 64, 64, 64} +#endif +#ifndef MAX_BLOCK_SIZE_VALUE + #define MAX_BLOCK_SIZE_VALUE 256 +#endif +#ifndef ITEMS_PER_THREAD_SYM_BWD_VALUE + #define ITEMS_PER_THREAD_SYM_BWD_VALUE 32 +#endif + +static constexpr int ITEMS_PER_THREAD_SYM_FWD[] = ITEMS_PER_THREAD_SYM_FWD_VALUES; +static constexpr int MAX_BLOCK_SIZE = MAX_BLOCK_SIZE_VALUE; +static constexpr int ITEMS_PER_THREAD_SYM_BWD = ITEMS_PER_THREAD_SYM_BWD_VALUE; + +template +using CudaAcsr = at::GenericPackedTensorAccessor; +constexpr __host__ __device__ int div_up_const(int a, int b) { return (a + b - 1) / b; } + +__host__ __device__ static inline int div_up(int a, int b) { return (a + b - 1) / b;} + +template +__global__ void cauchy_mult_sym_fwd_cuda_kernel(CudaAcsr v, + CudaAcsr z, + CudaAcsr w, + CudaAcsr out, + int L) { + // Get the float type from the complex type + // https://github.com/pytorch/pytorch/blob/bceb1db885cafa87fe8d037d8f22ae9649a1bba0/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp#L213 + using float_t = typename at::scalar_value_type::type; + constexpr int N = 1 << log_N; + constexpr int blockDimx = div_up_const(N, items_per_thread); + constexpr int blockDimy = MAX_BLOCK_SIZE / blockDimx; + // We just want a shared array: + // __shared__ scalar_t s_b[16]; + // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270 + // So we declare a char array and cast it. + // The casting is subtle: https://stackoverflow.com/questions/12692310/convert-array-to-two-dimensional-array-by-pointer + __shared__ char v_smem_char[N * sizeof(scalar_t)]; + scalar_t *v_smem = (scalar_t *)&v_smem_char; + __shared__ char w_smem_char[N * sizeof(scalar_t)]; + scalar_t *w_smem = (scalar_t *)&w_smem_char; + __shared__ char out_smem_char[blockDimy * sizeof(scalar_t)]; + scalar_t *out_smem = (scalar_t *)&out_smem_char; + int batch_idx = blockIdx.x; + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int L_idx = blockIdx.y * blockDim.y + threadIdx.y; + int L_block_start = blockIdx.y * blockDim.y; + scalar_t z_t = L_block_start + threadIdx.y < L ? z[L_block_start + threadIdx.y] : scalar_t(0.f); + for (int N_idx = threadIdx.x + threadIdx.y * blockDim.x; N_idx < N; N_idx += blockDim.x * blockDim.y) { + v_smem[N_idx] = v[batch_idx][N_idx]; + w_smem[N_idx] = w[batch_idx][N_idx]; + } + __syncthreads(); + scalar_t result = 0; + if (L_idx < L) { + // Combining the two terms (a/b + c/d = (ad + bc)/(bd)) seems to increase numerical errors. + // Using nvcc --use_fast_math yields the same speed between the two versions. + // So we don't combine the two terms. 
+ #pragma unroll + for (int item = 0; item < items_per_thread; ++item) { + int N_idx = item * blockDimx + threadIdx.x; + scalar_t v_t = v_smem[N_idx], w_t = w_smem[N_idx]; + result += v_t / (z_t - w_t) + std::conj(v_t) / (z_t - std::conj(w_t)); + } + } + // TODO: this only works for N a power of 2 + #pragma unroll + for (int offset = blockDimx / 2; offset > 0; offset /= 2) { + result += WARP_SHFL_DOWN(result, offset); + } + if ((threadIdx.x == 0) && (L_idx < L)) { + out_smem[threadIdx.y] = result; + } + __syncthreads(); + if (tid < blockDim.y && L_block_start + tid < L) { + out[batch_idx][L_block_start + tid] = out_smem[tid]; + } +} + +torch::Tensor cauchy_mult_sym_fwd_cuda(torch::Tensor v, + torch::Tensor z, + torch::Tensor w) { + const int batch_size = v.size(0); + const int N = v.size(1); + const int L = z.size(0); + auto out = torch::empty({batch_size, L}, torch::dtype(v.dtype()).device(v.device())); + auto stream = at::cuda::getCurrentCUDAStream(); + using scalar_t = c10::complex; + const auto v_a = v.packed_accessor32(); + const auto z_a = z.packed_accessor32(); + const auto w_a = w.packed_accessor32(); + auto out_a = out.packed_accessor32(); + int log_N = int(log2((double) N)); + int block_x = div_up(N, ITEMS_PER_THREAD_SYM_FWD[log_N - 1]); + dim3 block(block_x, MAX_BLOCK_SIZE / block_x); + dim3 grid(batch_size, div_up(L, block.y)); + switch (log_N) { + #define CASE_LOG_N(log_N_val) case log_N_val: \ + cauchy_mult_sym_fwd_cuda_kernel \ + <<>>(v_a, z_a, w_a, out_a, L); break; + MAP(CASE_LOG_N, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + #undef CASE_LOG_N + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return out; +} + +template +__global__ void cauchy_mult_sym_bwd_cuda_kernel(CudaAcsr v, + CudaAcsr z, + CudaAcsr w, + CudaAcsr dout, + CudaAcsr dv, + CudaAcsr dw, + int L, + int L_chunk_size) { + // We just want a shared array: + // __shared__ scalar_t s_b[16]; + // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270 + // So we declare a char array and cast it. + // The casting is subtle: https://stackoverflow.com/questions/12692310/convert-array-to-two-dimensional-array-by-pointer + __shared__ char dv_smem_char[C10_WARP_SIZE * sizeof(scalar_t)]; + scalar_t *dv_smem = (scalar_t *)&dv_smem_char; + __shared__ char dw_smem_char[C10_WARP_SIZE * sizeof(scalar_t)]; + scalar_t *dw_smem = (scalar_t *)&dw_smem_char; + int batch_idx = blockIdx.x; + int N_idx = blockIdx.y; + int L_chunk_idx = blockIdx.z; + int tid = threadIdx.x; + scalar_t w_conj_t = std::conj(w[batch_idx][N_idx]); + scalar_t dv_t = 0; + scalar_t dw_t = 0; + #pragma unroll + for (int item = 0; item < ITEMS_PER_THREAD_SYM_BWD; ++item) { + int l = L_chunk_idx * L_chunk_size + item * blockDim.x + threadIdx.x; + scalar_t dout_t, z_t; + if (check_L_boundary) { + dout_t = l < L ? dout[batch_idx][l] : 0; + z_t = l < L ? z[l] : 1; + } else { // Not checking boundary can speed it up quite a bit, around 30%. 
+ dout_t = dout[batch_idx][l]; + z_t = z[l]; + } + scalar_t denom_1 = std::conj(z_t) - w_conj_t; + scalar_t denom_2 = z_t - w_conj_t; + scalar_t term_1 = dout_t / denom_1; + scalar_t term_2 = std::conj(dout_t) / denom_2; + dv_t += term_1 + term_2; + dw_t += term_1 / denom_1 + term_2 / denom_2; + } + dv_t = at::native::cuda_utils::BlockReduceSum(dv_t, dv_smem); + dw_t = at::native::cuda_utils::BlockReduceSum(dw_t, dw_smem); + if (tid == 0) { + dw[batch_idx][N_idx][L_chunk_idx] = dw_t * std::conj(v[batch_idx][N_idx]); + dv[batch_idx][N_idx][L_chunk_idx] = dv_t; + } +} + +std::tuple +cauchy_mult_sym_bwd_cuda(torch::Tensor v, + torch::Tensor z, + torch::Tensor w, + torch::Tensor dout) { + const int batch_size = v.size(0); + const int N = v.size(1); + const int L = z.size(0); + constexpr int MAX_BLOCK_SIZE = 1024; + constexpr int MAX_L_CHUNK_SIZE = ITEMS_PER_THREAD_SYM_BWD * MAX_BLOCK_SIZE; + const int n_L_chunks = div_up(L, MAX_L_CHUNK_SIZE); + auto dv = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(v.dtype()).device(v.device())); + auto dw = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(w.dtype()).device(w.device())); + auto stream = at::cuda::getCurrentCUDAStream(); + using scalar_t = c10::complex; + const auto v_a = v.packed_accessor32(); + const auto z_a = z.packed_accessor32(); + const auto w_a = w.packed_accessor32(); + const auto dout_a = dout.packed_accessor32(); + auto dv_a = dv.packed_accessor32(); + auto dw_a = dw.packed_accessor32(); + // Each block need to have a multiple of 32 threads, otherwise + // at::native::cuda_utils::BlockReduceSum to produce wrong result. + // int block_x = max(div_up(L, ITEMS_PER_THREAD_SYM_BWD), C10_WARP_SIZE); + const int L_chunk_size = min(L, MAX_L_CHUNK_SIZE); + int block_x = div_up(L_chunk_size, ITEMS_PER_THREAD_SYM_BWD * C10_WARP_SIZE) * C10_WARP_SIZE; + bool check_L_boundary = L != block_x * ITEMS_PER_THREAD_SYM_BWD * n_L_chunks; + dim3 block(block_x); + dim3 grid(batch_size, N, n_L_chunks); + check_L_boundary + ? cauchy_mult_sym_bwd_cuda_kernel + <<>>(v_a, z_a, w_a, dout_a, dv_a, dw_a, L, L_chunk_size) + : cauchy_mult_sym_bwd_cuda_kernel + <<>>(v_a, z_a, w_a, dout_a, dv_a, dw_a, L, L_chunk_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return std::make_tuple(dv.sum(-1), dw.sum(-1)); +} + +template +__global__ void vand_log_mult_sym_fwd_cuda_kernel(CudaAcsr, 2> v, + CudaAcsr, 2> x, + CudaAcsr out, + int L) { + using cfloat_t = typename c10::complex; + constexpr int N = 1 << log_N; + constexpr int blockDimx = div_up_const(N, items_per_thread); + constexpr int blockDimy = MAX_BLOCK_SIZE / blockDimx; + // We just want a shared array: + // __shared__ cfloat_t s_b[16]; + // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270 + // So we declare a char array and cast it. 
+ // The casting is subtle: https://stackoverflow.com/questions/12692310/convert-array-to-two-dimensional-array-by-pointer + __shared__ char v_smem_char[N * sizeof(cfloat_t)]; + cfloat_t *v_smem = (cfloat_t *)&v_smem_char; + __shared__ char x_smem_char[N * sizeof(cfloat_t)]; + cfloat_t *x_smem = (cfloat_t *)&x_smem_char; + __shared__ float out_smem[blockDimy]; + int batch_idx = blockIdx.x; + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int L_idx = blockIdx.y * blockDim.y + threadIdx.y; + int L_block_start = blockIdx.y * blockDim.y; + for (int N_idx = threadIdx.x + threadIdx.y * blockDim.x; N_idx < N; N_idx += blockDim.x * blockDim.y) { + v_smem[N_idx] = v[batch_idx][N_idx]; + x_smem[N_idx] = x[batch_idx][N_idx]; + } + __syncthreads(); + float result = 0; + if (L_idx < L) { + #pragma unroll + for (int item = 0; item < items_per_thread; ++item) { + int N_idx = item * blockDimx + threadIdx.x; + cfloat_t v_t = v_smem[N_idx], x_t = x_smem[N_idx]; + result += (std::exp(x_t * L_idx) * v_t).real_; + } + } + // TODO: this only works for N a power of 2 + #pragma unroll + for (int offset = blockDimx / 2; offset > 0; offset /= 2) { + result += WARP_SHFL_DOWN(result, offset); + } + if ((threadIdx.x == 0) && (L_idx < L)) { + out_smem[threadIdx.y] = 2 * result; + } + __syncthreads(); + if (tid < blockDim.y && L_block_start + tid < L) { + out[batch_idx][L_block_start + tid] = out_smem[tid]; + } +} + +torch::Tensor vand_log_mult_sym_fwd_cuda(torch::Tensor v, torch::Tensor x, int L) { + const int batch_size = v.size(0); + const int N = v.size(1); + auto opts = v.options(); + auto out = torch::empty({batch_size, L}, opts.dtype(torch::kFloat32)); + auto stream = at::cuda::getCurrentCUDAStream(); + const auto v_a = v.packed_accessor32, 2, at::RestrictPtrTraits>(); + const auto x_a = x.packed_accessor32, 2, at::RestrictPtrTraits>(); + auto out_a = out.packed_accessor32(); + int log_N = int(log2((double) N)); + int block_x = div_up(N, ITEMS_PER_THREAD_SYM_FWD[log_N - 1]); + dim3 block(block_x, MAX_BLOCK_SIZE / block_x); + dim3 grid(batch_size, div_up(L, block.y)); + switch (log_N) { + #define CASE_LOG_N(log_N_val) case log_N_val: \ + vand_log_mult_sym_fwd_cuda_kernel \ + <<>>(v_a, x_a, out_a, L); break; + MAP(CASE_LOG_N, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + #undef CASE_LOG_N + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return out; +} + +template +__global__ void vand_log_mult_sym_bwd_cuda_kernel(CudaAcsr, 2> v, + CudaAcsr, 2> x, + CudaAcsr dout, + CudaAcsr, 3> dv, + CudaAcsr, 3> dx, + int L, + int L_chunk_size) { + using cfloat_t = typename c10::complex; + // We just want a shared array: + // __shared__ c10::complex s_b[16]; + // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270 + // So we declare a char array and cast it. + // The casting is subtle: https://stackoverflow.com/questions/12692310/convert-array-to-two-dimensional-array-by-pointer + __shared__ char dv_smem_char[C10_WARP_SIZE * sizeof(cfloat_t)]; + cfloat_t *dv_smem = (cfloat_t *)&dv_smem_char; + __shared__ char dx_smem_char[C10_WARP_SIZE * sizeof(cfloat_t)]; + cfloat_t *dx_smem = (cfloat_t *)&dx_smem_char; + int batch_idx = blockIdx.x; + int N_idx = blockIdx.y; + int L_chunk_idx = blockIdx.z; + int tid = threadIdx.x; + cfloat_t x_t = x[batch_idx][N_idx]; + cfloat_t dv_t = 0; + cfloat_t dx_t = 0; + #pragma unroll + for (int item = 0; item < ITEMS_PER_THREAD_SYM_BWD; ++item) { + int l = L_chunk_idx * L_chunk_size + item * blockDim.x + threadIdx.x; + float dout_t; + if (check_L_boundary) { + dout_t = l < L ? 
dout[batch_idx][l] : 0; + } else { // Not checking boundary can speed it up quite a bit. + dout_t = dout[batch_idx][l]; + } + // Need to conjugate as we're doing complex gradient. + cfloat_t do_exp_x_t = dout_t * std::conj(std::exp(x_t * l)); + dv_t += do_exp_x_t; + dx_t += do_exp_x_t * l; + } + dv_t = at::native::cuda_utils::BlockReduceSum(dv_t, dv_smem); + dx_t = at::native::cuda_utils::BlockReduceSum(dx_t, dx_smem); + if (tid == 0) { + dx[batch_idx][N_idx][L_chunk_idx] = 2 * dx_t * std::conj(v[batch_idx][N_idx]); + dv[batch_idx][N_idx][L_chunk_idx] = 2 * dv_t; + } +} + + +std::tuple +vand_log_mult_sym_bwd_cuda(torch::Tensor v, + torch::Tensor x, + torch::Tensor dout) { + const int batch_size = v.size(0); + const int N = v.size(1); + const int L = dout.size(1); + constexpr int MAX_BLOCK_SIZE = 1024; + constexpr int MAX_L_CHUNK_SIZE = ITEMS_PER_THREAD_SYM_BWD * MAX_BLOCK_SIZE; + const int n_L_chunks = div_up(L, MAX_L_CHUNK_SIZE); + auto dv = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(v.dtype()).device(v.device())); + auto dx = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(x.dtype()).device(x.device())); + auto stream = at::cuda::getCurrentCUDAStream(); + using cfloat_t = c10::complex; + const auto v_a = v.packed_accessor32(); + const auto x_a = x.packed_accessor32(); + const auto dout_a = dout.packed_accessor32(); + auto dv_a = dv.packed_accessor32(); + auto dx_a = dx.packed_accessor32(); + // Each block need to have a multiple of 32 threads, otherwise + // at::native::cuda_utils::BlockReduceSum to produce wrong result. + // int block_x = max(div_up(L, ITEMS_PER_THREAD_SYM_BWD), C10_WARP_SIZE); + const int L_chunk_size = min(L, MAX_L_CHUNK_SIZE); + int block_x = div_up(L_chunk_size, ITEMS_PER_THREAD_SYM_BWD * C10_WARP_SIZE) * C10_WARP_SIZE; + bool check_L_boundary = L != block_x * ITEMS_PER_THREAD_SYM_BWD * n_L_chunks; + dim3 block(block_x); + dim3 grid(batch_size, N, n_L_chunks); + check_L_boundary + ? vand_log_mult_sym_bwd_cuda_kernel + <<>>(v_a, x_a, dout_a, dv_a, dx_a, L, L_chunk_size) + : vand_log_mult_sym_bwd_cuda_kernel + <<>>(v_a, x_a, dout_a, dv_a, dx_a, L, L_chunk_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return std::make_tuple(dv.sum(-1), dx.sum(-1)); +} \ No newline at end of file diff --git a/src/ecglib/models/architectures/sssd/extensions/cauchy/map.h b/src/ecglib/models/architectures/sssd/extensions/cauchy/map.h new file mode 100644 index 0000000..a1d9c4e --- /dev/null +++ b/src/ecglib/models/architectures/sssd/extensions/cauchy/map.h @@ -0,0 +1,72 @@ +// Downloaded from https://github.com/swansontec/map-macro + +/* + * Copyright (C) 2012 William Swanson + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, copy, + * modify, merge, publish, distribute, sublicense, and/or sell copies + * of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF + * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + * Except as contained in this notice, the names of the authors or + * their institutions shall not be used in advertising or otherwise to + * promote the sale, use or other dealings in this Software without + * prior written authorization from the authors. + */ + +#ifndef MAP_H_INCLUDED +#define MAP_H_INCLUDED + +#define EVAL0(...) __VA_ARGS__ +#define EVAL1(...) EVAL0(EVAL0(EVAL0(__VA_ARGS__))) +#define EVAL2(...) EVAL1(EVAL1(EVAL1(__VA_ARGS__))) +#define EVAL3(...) EVAL2(EVAL2(EVAL2(__VA_ARGS__))) +#define EVAL4(...) EVAL3(EVAL3(EVAL3(__VA_ARGS__))) +#define EVAL(...) EVAL4(EVAL4(EVAL4(__VA_ARGS__))) + +#define MAP_END(...) +#define MAP_OUT +#define MAP_COMMA , + +#define MAP_GET_END2() 0, MAP_END +#define MAP_GET_END1(...) MAP_GET_END2 +#define MAP_GET_END(...) MAP_GET_END1 +#define MAP_NEXT0(test, next, ...) next MAP_OUT +#define MAP_NEXT1(test, next) MAP_NEXT0(test, next, 0) +#define MAP_NEXT(test, next) MAP_NEXT1(MAP_GET_END test, next) + +#define MAP0(f, x, peek, ...) f(x) MAP_NEXT(peek, MAP1)(f, peek, __VA_ARGS__) +#define MAP1(f, x, peek, ...) f(x) MAP_NEXT(peek, MAP0)(f, peek, __VA_ARGS__) + +#define MAP_LIST_NEXT1(test, next) MAP_NEXT0(test, MAP_COMMA next, 0) +#define MAP_LIST_NEXT(test, next) MAP_LIST_NEXT1(MAP_GET_END test, next) + +#define MAP_LIST0(f, x, peek, ...) f(x) MAP_LIST_NEXT(peek, MAP_LIST1)(f, peek, __VA_ARGS__) +#define MAP_LIST1(f, x, peek, ...) f(x) MAP_LIST_NEXT(peek, MAP_LIST0)(f, peek, __VA_ARGS__) + +/** + * Applies the function macro `f` to each of the remaining parameters. + */ +#define MAP(f, ...) EVAL(MAP1(f, __VA_ARGS__, ()()(), ()()(), ()()(), 0)) + +/** + * Applies the function macro `f` to each of the remaining parameters and + * inserts commas between the results. + */ +#define MAP_LIST(f, ...) 
EVAL(MAP_LIST1(f, __VA_ARGS__, ()()(), ()()(), ()()(), 0)) + +#endif \ No newline at end of file diff --git a/src/ecglib/models/architectures/sssd/extensions/cauchy/setup.py b/src/ecglib/models/architectures/sssd/extensions/cauchy/setup.py new file mode 100644 index 0000000..a126d1a --- /dev/null +++ b/src/ecglib/models/architectures/sssd/extensions/cauchy/setup.py @@ -0,0 +1,25 @@ +from setuptools import setup +import torch.cuda +from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension +from torch.utils.cpp_extension import CUDA_HOME + +ext_modules = [] +if torch.cuda.is_available() and CUDA_HOME is not None: + extension = CUDAExtension( + 'structured_kernels', [ + 'cauchy.cpp', + 'cauchy_cuda.cu', + ], + extra_compile_args={'cxx': ['-g', '-march=native', '-funroll-loops'], + # 'nvcc': ['-O2', '-lineinfo'] + 'nvcc': ['-O2', '-lineinfo', '--use_fast_math'] + } + ) + ext_modules.append(extension) + +setup( + name='structured_kernels', + version="0.1.0", + ext_modules=ext_modules, + # cmdclass={'build_ext': BuildExtension.with_options(use_ninja=False)}) + cmdclass={'build_ext': BuildExtension}) \ No newline at end of file diff --git a/src/ecglib/models/architectures/sssd/s4.py b/src/ecglib/models/architectures/sssd/s4.py index 46b695e..4d06260 100644 --- a/src/ecglib/models/architectures/sssd/s4.py +++ b/src/ecglib/models/architectures/sssd/s4.py @@ -1,2 +1,1188 @@ -class S4: - raise NotImplementedError \ No newline at end of file +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import logging +from functools import partial +from scipy import special as ss +from einops import rearrange, repeat +import opt_einsum as oe + +contract = oe.contract +contract_expression = oe.contract_expression + + + +''' Standalone CSDI + S4 imputer for random missing, non-random missing and black-out missing. +The notebook contains CSDI and S4 functions and utilities. However the imputer is located in the last Class of +the notebook, please see more documentation of use there. Additional at this file can be added for CUDA multiplication +the cauchy kernel.''' + + + + +def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: + """Initializes multi-GPU-friendly python logger.""" + + logger = logging.getLogger(name) + logger.setLevel(level) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + # for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): + # setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger +log = get_logger(__name__) + +""" Cauchy kernel """ + +try: # Try CUDA extension + from .extensions.cauchy.cauchy import cauchy_mult + has_cauchy_extension = True +except Exception as err: + log.warn( + "CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. 
This should speed up end-to-end training by 10-50%" + ) + log.warn(f"An exception occurred: {err}") + has_cauchy_extension = False + +try: # Try pykeops + import pykeops + from pykeops.torch import Genred + has_pykeops = True + def cauchy_conj(v, z, w): + """ Pykeops version """ + expr_num = 'z * ComplexReal(v) - Real2Complex(Sum(v * w))' + expr_denom = 'ComplexMult(z-w, z-Conj(w))' + + cauchy_mult = Genred( + f'ComplexDivide({expr_num}, {expr_denom})', + # expr_num, + # expr_denom, + [ + 'v = Vj(2)', + 'z = Vi(2)', + 'w = Vj(2)', + ], + reduction_op='Sum', + axis=1, + dtype='float32' if v.dtype == torch.cfloat else 'float64', + ) + + v, z, w = _broadcast_dims(v, z, w) + v = _c2r(v) + z = _c2r(z) + w = _c2r(w) + + r = 2*cauchy_mult(v, z, w, backend='GPU') + return _r2c(r) + +except ImportError: + has_pykeops = False + if not has_cauchy_extension: + log.error( + "Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency." + ) + def cauchy_slow(v, z, w): + """ + v, w: (..., N) + z: (..., L) + returns: (..., L) + """ + cauchy_matrix = v.unsqueeze(-1) / (z.unsqueeze(-2) - w.unsqueeze(-1)) # (... N L) + return torch.sum(cauchy_matrix, dim=-2) + +def _broadcast_dims(*tensors): + max_dim = max([len(tensor.shape) for tensor in tensors]) + tensors = [tensor.view((1,)*(max_dim-len(tensor.shape))+tensor.shape) for tensor in tensors] + return tensors + +_c2r = torch.view_as_real +_r2c = torch.view_as_complex +_conj = lambda x: torch.cat([x, x.conj()], dim=-1) +if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 10): + _resolve_conj = lambda x: x.conj().resolve_conj() +else: + _resolve_conj = lambda x: x.conj() + + + +""" simple nn.Module components """ + +def Activation(activation=None, dim=-1): + if activation in [ None, 'id', 'identity', 'linear' ]: + return nn.Identity() + elif activation == 'tanh': + return nn.Tanh() + elif activation == 'relu': + return nn.ReLU() + elif activation == 'gelu': + return nn.GELU() + elif activation in ['swish', 'silu']: + return nn.SiLU() + elif activation == 'glu': + return nn.GLU(dim=dim) + elif activation == 'sigmoid': + return nn.Sigmoid() + else: + raise NotImplementedError("hidden activation '{}' is not implemented".format(activation)) + +def get_initializer(name, activation=None): + if activation in [ None, 'id', 'identity', 'linear', 'modrelu' ]: + nonlinearity = 'linear' + elif activation in ['relu', 'tanh', 'sigmoid']: + nonlinearity = activation + elif activation in ['gelu', 'swish', 'silu']: + nonlinearity = 'relu' # Close to ReLU so approximate with ReLU's gain + else: + raise NotImplementedError(f"get_initializer: activation {activation} not supported") + + if name == 'uniform': + initializer = partial(torch.nn.init.kaiming_uniform_, nonlinearity=nonlinearity) + elif name == 'normal': + initializer = partial(torch.nn.init.kaiming_normal_, nonlinearity=nonlinearity) + elif name == 'xavier': + initializer = torch.nn.init.xavier_normal_ + elif name == 'zero': + initializer = partial(torch.nn.init.constant_, val=0) + elif name == 'one': + initializer = partial(torch.nn.init.constant_, val=1) + else: + raise NotImplementedError(f"get_initializer: initializer type {name} not supported") + + return initializer + +class TransposedLinear(nn.Module): + """ Linear module on the second-to-last dimension """ + + def __init__(self, d_input, d_output, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.empty(d_output, d_input)) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # nn.Linear 
default init + # nn.init.kaiming_uniform_(self.weight, nonlinearity='linear') # should be equivalent + + if bias: + self.bias = nn.Parameter(torch.empty(d_output, 1)) + bound = 1 / math.sqrt(d_input) + nn.init.uniform_(self.bias, -bound, bound) + else: + self.bias = 0.0 + + def forward(self, x): + return contract('... u l, v u -> ... v l', x, self.weight) + self.bias + +def LinearActivation( + d_input, d_output, bias=True, + zero_bias_init=False, + transposed=False, + initializer=None, + activation=None, + activate=False, # Apply activation as part of this module + weight_norm=False, + **kwargs, + ): + """ Returns a linear nn.Module with control over axes order, initialization, and activation """ + + # Construct core module + linear_cls = TransposedLinear if transposed else nn.Linear + if activation == 'glu': d_output *= 2 + linear = linear_cls(d_input, d_output, bias=bias, **kwargs) + + # Initialize weight + if initializer is not None: + get_initializer(initializer, activation)(linear.weight) + + # Initialize bias + if bias and zero_bias_init: + nn.init.zeros_(linear.bias) + + # Weight norm + if weight_norm: + linear = nn.utils.weight_norm(linear) + + if activate and activation is not None: + activation = Activation(activation, dim=-2 if transposed else -1) + linear = nn.Sequential(linear, activation) + return linear + + + +""" Misc functional utilities """ + +def krylov(L, A, b, c=None, return_power=False): + """ + Compute the Krylov matrix (b, Ab, A^2b, ...) using the squaring trick. + + If return_power=True, return A^{L-1} as well + """ + # TODO There is an edge case if L=1 where output doesn't get broadcasted, which might be an issue if caller is expecting broadcasting semantics... can deal with it if it arises + + x = b.unsqueeze(-1) # (..., N, 1) + A_ = A + + AL = None + if return_power: + AL = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device) + _L = L-1 + + done = L == 1 + # loop invariant: _L represents how many indices left to compute + while not done: + if return_power: + if _L % 2 == 1: AL = A_ @ AL + _L //= 2 + + # Save memory on last iteration + l = x.shape[-1] + if L - l <= l: + done = True + _x = x[..., :L-l] + else: _x = x + + _x = A_ @ _x + x = torch.cat([x, _x], dim=-1) # there might be a more efficient way of ordering axes + if not done: A_ = A_ @ A_ + + assert x.shape[-1] == L + + if c is not None: + x = torch.einsum('...nl, ...n -> ...l', x, c) + x = x.contiguous() # WOW!! 
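+    # The loop above is the "squaring trick": each pass doubles the number of computed
+    # Krylov columns while A_ is squared, so (b, Ab, ..., A^{L-1}b) is assembled in
+    # O(log L) matrix products. If c was given, x has already been contracted over the
+    # state dimension to shape (..., L); AL holds A^{L-1} when return_power=True.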
+ if return_power: + return x, AL + else: + return x + +def power(L, A, v=None): + """ Compute A^L and the scan sum_i A^i v_i + + A: (..., N, N) + v: (..., N, L) + """ + + I = torch.eye(A.shape[-1]).to(A) # , dtype=A.dtype, device=A.device) + + powers = [A] + l = 1 + while True: + if L % 2 == 1: I = powers[-1] @ I + L //= 2 + if L == 0: break + l *= 2 + powers.append(powers[-1] @ powers[-1]) + + if v is None: return I + + # Invariants: + # powers[-1] := A^l + # l := largest po2 at most L + + # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A + # We do this reverse divide-and-conquer for efficiency reasons: + # 1) it involves fewer padding steps for non-po2 L + # 2) it involves more contiguous arrays + + # Take care of edge case for non-po2 arrays + # Note that this initial step is a no-op for the case of power of 2 (l == L) + k = v.size(-1) - l + v_ = powers.pop() @ v[..., l:] + v = v[..., :l] + v[..., :k] = v[..., :k] + v_ + + # Handle reduction for power of 2 + while v.size(-1) > 1: + v = rearrange(v, '... (z l) -> ... z l', z=2) + v = v[..., 0, :] + powers.pop() @ v[..., 1, :] + return I, v.squeeze(-1) + + +""" HiPPO utilities """ + +def embed_c2r(A): + A = rearrange(A, '... m n -> ... m () n ()') + A = np.pad(A, ((0, 0), (0, 1), (0, 0), (0, 1))) + \ + np.pad(A, ((0, 0), (1, 0), (0, 0), (1,0))) + return rearrange(A, 'm x n y -> (m x) (n y)') + +def transition(measure, N, **measure_args): + """ A, B transition matrices for different measures + + measure: the type of measure + legt - Legendre (translated) + legs - Legendre (scaled) + glagt - generalized Laguerre (translated) + lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization + """ + # Laguerre (translated) + if measure == 'lagt': + b = measure_args.get('beta', 1.0) + A = np.eye(N) / 2 - np.tril(np.ones((N, N))) + B = b * np.ones((N, 1)) + # Generalized Laguerre + # alpha 0, beta small is most stable (limits to the 'lagt' measure) + # alpha 0, beta 1 has transition matrix A = [lower triangular 1] + elif measure == 'glagt': + alpha = measure_args.get('alpha', 0.0) + beta = measure_args.get('beta', 0.01) + A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1) + B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None] + + L = np.exp(.5 * (ss.gammaln(np.arange(N)+alpha+1) - ss.gammaln(np.arange(N)+1))) + A = (1./L[:, None]) * A * L[None, :] + B = (1./L[:, None]) * B * np.exp(-.5 * ss.gammaln(1-alpha)) * beta**((1-alpha)/2) + # Legendre (translated) + elif measure == 'legt': + Q = np.arange(N, dtype=np.float64) + R = (2*Q + 1) ** .5 + j, i = np.meshgrid(Q, Q) + A = R[:, None] * np.where(i < j, (-1.)**(i-j), 1) * R[None, :] + B = R[:, None] + A = -A + # Legendre (scaled) + elif measure == 'legs': + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + B = B.copy() # Otherwise "UserWarning: given NumPY array is not writeable..." 
after torch.as_tensor(B) + elif measure == 'fourier': + freqs = np.arange(N//2) + d = np.stack([freqs, np.zeros(N//2)], axis=-1).reshape(-1)[:-1] + A = 2*np.pi*(np.diag(d, 1) - np.diag(d, -1)) + A = A - embed_c2r(np.ones((N//2, N//2))) + B = embed_c2r(np.ones((N//2, 1)))[..., :1] + elif measure == 'random': + A = np.random.randn(N, N) / N + B = np.random.randn(N, 1) + elif measure == 'diagonal': + A = -np.diag(np.exp(np.random.randn(N))) + B = np.random.randn(N, 1) + else: + raise NotImplementedError + + return A, B + +def rank_correction(measure, N, rank=1, dtype=torch.float): + """ Return low-rank matrix L such that A + L is normal """ + + if measure == 'legs': + assert rank >= 1 + P = torch.sqrt(.5+torch.arange(N, dtype=dtype)).unsqueeze(0) # (1 N) + elif measure == 'legt': + assert rank >= 2 + P = torch.sqrt(1+2*torch.arange(N, dtype=dtype)) # (N) + P0 = P.clone() + P0[0::2] = 0. + P1 = P.clone() + P1[1::2] = 0. + P = torch.stack([P0, P1], dim=0) # (2 N) + elif measure == 'lagt': + assert rank >= 1 + P = .5**.5 * torch.ones(1, N, dtype=dtype) + elif measure == 'fourier': + P = torch.ones(N, dtype=dtype) # (N) + P0 = P.clone() + P0[0::2] = 0. + P1 = P.clone() + P1[1::2] = 0. + P = torch.stack([P0, P1], dim=0) # (2 N) + else: raise NotImplementedError + + d = P.size(0) + if rank > d: + P = torch.cat([P, torch.zeros(rank-d, N, dtype=dtype)], dim=0) # (rank N) + return P + +def nplr(measure, N, rank=1, dtype=torch.float): + """ Return w, p, q, V, B such that + (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V + i.e. A = V[w - p q^*]V^*, B = V B + """ + assert dtype == torch.float or torch.cfloat + if measure == 'random': + dtype = torch.cfloat if dtype == torch.float else torch.cdouble + # w = torch.randn(N//2, dtype=dtype) + w = -torch.exp(torch.randn(N//2)) + 1j*torch.randn(N//2) + P = torch.randn(rank, N//2, dtype=dtype) + B = torch.randn(N//2, dtype=dtype) + V = torch.eye(N, dtype=dtype)[..., :N//2] # Only used in testing + return w, P, B, V + + A, B = transition(measure, N) + A = torch.as_tensor(A, dtype=dtype) # (N, N) + B = torch.as_tensor(B, dtype=dtype)[:, 0] # (N,) + + P = rank_correction(measure, N, rank=rank, dtype=dtype) + AP = A + torch.sum(P.unsqueeze(-2)*P.unsqueeze(-1), dim=-3) + w, V = torch.linalg.eig(AP) # (..., N) (..., N, N) + # V w V^{-1} = A + + # Only keep one of the conjugate pairs + w = w[..., 0::2].contiguous() + V = V[..., 0::2].contiguous() + + V_inv = V.conj().transpose(-1, -2) + + B = contract('ij, j -> i', V_inv, B.to(V)) # V^* B + P = contract('ij, ...j -> ...i', V_inv, P.to(V)) # V^* P + + + return w, P, B, V + + +def bilinear(dt, A, B=None): + """ + dt: (...) timescales + A: (... N N) + B: (... N) + """ + N = A.shape[-1] + I = torch.eye(N).to(A) + A_backwards = I - dt[:, None, None] / 2 * A + A_forwards = I + dt[:, None, None] / 2 * A + + if B is None: + dB = None + else: + dB = dt[..., None] * torch.linalg.solve( + A_backwards, B.unsqueeze(-1) + ).squeeze(-1) # (... N) + + dA = torch.linalg.solve(A_backwards, A_forwards) # (... N N) + return dA, dB + + + + +class SSKernelNPLR(nn.Module): + """Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR) + + The class name stands for 'State-Space SSKernel for Normal Plus Low-Rank'. + The parameters of this function are as follows. + + A: (... N N) the state matrix + B: (... N) input matrix + C: (... N) output matrix + dt: (...) 
timescales / discretization step size + p, q: (... P N) low-rank correction to A, such that Ap=A+pq^T is a normal matrix + + The forward pass of this Module returns: + (... L) that represents represents FFT SSKernel_L(A^dt, B^dt, C) + + """ + + @torch.no_grad() + def _setup_C(self, double_length=False): + """ Construct C~ from C + + double_length: current C is for length L, convert it to length 2L + """ + C = _r2c(self.C) + self._setup_state() + dA_L = power(self.L, self.dA) + # Multiply C by I - dA_L + C_ = _conj(C) + prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_) + if double_length: prod = -prod # Multiply by I + dA_L instead + C_ = C_ - prod + C_ = C_[..., :self.N] # Take conjugate pairs again + self.C.copy_(_c2r(C_)) + + if double_length: + self.L *= 2 + self._omega(self.L, dtype=C.dtype, device=C.device, cache=True) + + def _omega(self, L, dtype, device, cache=True): + """ Calculate (and cache) FFT nodes and their "unprocessed" them with the bilinear transform + This should be called everytime the internal length self.L changes """ + omega = torch.tensor( + np.exp(-2j * np.pi / (L)), dtype=dtype, device=device + ) # \omega_{2L} + omega = omega ** torch.arange(0, L // 2 + 1, device=device) + z = 2 * (1 - omega) / (1 + omega) + if cache: + self.register_buffer("omega", _c2r(omega)) + self.register_buffer("z", _c2r(z)) + return omega, z + + def __init__( + self, + L, w, P, B, C, log_dt, + hurwitz=False, + trainable=None, + lr=None, + tie_state=False, + length_correction=True, + verbose=False, + ): + """ + L: Maximum length; this module computes an SSM kernel of length L + w: (N) + p: (r, N) low-rank correction to A + q: (r, N) + A represented by diag(w) - pq^* + + B: (N) + dt: (H) timescale per feature + C: (H, C, N) system is 1-D to c-D (channels) + + hurwitz: tie pq and ensure w has negative real part + trainable: toggle which of the parameters is trainable + lr: add hook to set lr of hippo parameters specially (everything besides C) + tie_state: tie all state parameters across the H hidden features + length_correction: multiply C by (I - dA^L) - can be turned off when L is large for slight speedup at initialization (only relevant when N large as well) + + Note: tensor shape N here denotes half the true state size, because of conjugate symmetry + """ + + super().__init__() + self.hurwitz = hurwitz + self.tie_state = tie_state + self.verbose = verbose + + # Rank of low-rank correction + self.rank = P.shape[-2] + assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1) + self.H = log_dt.size(-1) + self.N = w.size(-1) + + # Broadcast everything to correct shapes + C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (H, C, N) + H = 1 if self.tie_state else self.H + B = repeat(B, 'n -> 1 h n', h=H) + P = repeat(P, 'r n -> r h n', h=H) + w = repeat(w, 'n -> h n', h=H) + + # Cache Fourier nodes every time we set up a desired length + self.L = L + if self.L is not None: + self._omega(self.L, dtype=C.dtype, device=C.device, cache=True) + + # Register parameters + # C is a regular parameter, not state + # self.C = nn.Parameter(_c2r(C.conj().resolve_conj())) + self.C = nn.Parameter(_c2r(_resolve_conj(C))) + train = False + if trainable is None: trainable = {} + if trainable == False: trainable = {} + if trainable == True: trainable, train = {}, True + self.register("log_dt", log_dt, trainable.get('dt', train), lr, 0.0) + self.register("B", _c2r(B), trainable.get('B', train), lr, 0.0) + self.register("P", _c2r(P), trainable.get('P', train), lr, 0.0) + if 
self.hurwitz: + log_w_real = torch.log(-w.real + 1e-3) # Some of the HiPPO methods have real part 0 + w_imag = w.imag + self.register("log_w_real", log_w_real, trainable.get('A', 0), lr, 0.0) + self.register("w_imag", w_imag, trainable.get('A', train), lr, 0.0) + self.Q = None + else: + self.register("w", _c2r(w), trainable.get('A', train), lr, 0.0) + # self.register("Q", _c2r(P.clone().conj().resolve_conj()), trainable.get('P', train), lr, 0.0) + Q = _resolve_conj(P.clone()) + self.register("Q", _c2r(Q), trainable.get('P', train), lr, 0.0) + + if length_correction: + self._setup_C() + + def _w(self): + # Get the internal w (diagonal) parameter + if self.hurwitz: + w_real = -torch.exp(self.log_w_real) + w_imag = self.w_imag + w = w_real + 1j * w_imag + else: + w = _r2c(self.w) # (..., N) + return w + + def forward(self, state=None, rate=1.0, L=None): + """ + state: (..., s, N) extra tensor that augments B + rate: sampling rate factor + + returns: (..., c+s, L) + """ + # Handle sampling rate logic + # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) sampling rate rate + # If either are not passed in, assume we're not asked to change the scale of our kernel + assert not (rate is None and L is None) + if rate is None: + rate = self.L / L + if L is None: + L = int(self.L / rate) + + # Increase the internal length if needed + while rate * L > self.L: + self.double_length() + + dt = torch.exp(self.log_dt) * rate + B = _r2c(self.B) + C = _r2c(self.C) + P = _r2c(self.P) + Q = P.conj() if self.Q is None else _r2c(self.Q) + w = self._w() + + if rate == 1.0: + # Use cached FFT nodes + omega, z = _r2c(self.omega), _r2c(self.z) # (..., L) + else: + omega, z = self._omega(int(self.L/rate), dtype=w.dtype, device=w.device, cache=False) + + if self.tie_state: + B = repeat(B, '... 1 n -> ... h n', h=self.H) + P = repeat(P, '... 1 n -> ... h n', h=self.H) + Q = repeat(Q, '... 1 n -> ... h n', h=self.H) + + # Augment B + if state is not None: + # Have to "unbilinear" the state to put it into the same "type" as B + # Compute 1/dt * (I + dt/2 A) @ state + + # Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way + s = _conj(state) if state.size(-1) == self.N else state # (B H N) + sA = ( + s * _conj(w) # (B H N) + - contract('bhm, rhm, rhn -> bhn', s, _conj(Q), _conj(P)) + ) + s = s / dt.unsqueeze(-1) + sA / 2 + s = s[..., :self.N] + + B = torch.cat([s, B], dim=-3) # (s+1, H, N) + + # Incorporate dt into A + w = w * dt.unsqueeze(-1) # (H N) + + # Stack B and p, C and q for convenient batching + B = torch.cat([B, P], dim=-3) # (s+1+r, H, N) + C = torch.cat([C, Q], dim=-3) # (c+r, H, N) + + # Incorporate B and C batch dimensions + v = B.unsqueeze(-3) * C.unsqueeze(-4) # (s+1+r, c+r, H, N) + # w = w[None, None, ...] # (1, 1, H, N) + # z = z[None, None, None, ...] 
# (1, 1, 1, L) + + # Calculate resolvent at omega + if has_cauchy_extension and z.dtype == torch.cfloat: + # r = cauchy_mult(v, z, w, symmetric=True) + r = cauchy_mult(v, z, w) + elif has_pykeops: + r = cauchy_conj(v, z, w) + else: + r = cauchy_slow(v, z, w) + r = r * dt[None, None, :, None] # (S+1+R, C+R, H, L) + + # Low-rank Woodbury correction + if self.rank == 1: + k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (1 + r[-1:, -1:, :, :]) + elif self.rank == 2: + r00 = r[: -self.rank, : -self.rank, :, :] + r01 = r[: -self.rank, -self.rank :, :, :] + r10 = r[-self.rank :, : -self.rank, :, :] + r11 = r[-self.rank :, -self.rank :, :, :] + det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[:1, 1:, :, :] * r11[1:, :1, :, :] + s = ( + r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :] + + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :] + - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :] + - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :] + ) + s = s / det + k_f = r00 - s + else: + r00 = r[:-self.rank, :-self.rank, :, :] + r01 = r[:-self.rank, -self.rank:, :, :] + r10 = r[-self.rank:, :-self.rank, :, :] + r11 = r[-self.rank:, -self.rank:, :, :] + r11 = rearrange(r11, "a b h n -> h n a b") + r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11) + r11 = rearrange(r11, "h n a b -> a b h n") + k_f = r00 - torch.einsum("i j h n, j k h n, k l h n -> i l h n", r01, r11, r10) + + # Final correction for the bilinear transform + k_f = k_f * 2 / (1 + omega) + + # Move from frequency to coefficients + k = torch.fft.irfft(k_f) # (S+1, C, H, L) + + # Truncate to target length + k = k[..., :L] + + if state is not None: + k_state = k[:-1, :, :, :] # (S, C, H, L) + else: + k_state = None + k_B = k[-1, :, :, :] # (C H L) + return k_B, k_state + + @torch.no_grad() + def double_length(self): + if self.verbose: log.info(f"S4: Doubling length from L = {self.L} to {2*self.L}") + self._setup_C(double_length=True) + + def _setup_linear(self): + """ Create parameters that allow fast linear stepping of state """ + w = self._w() + B = _r2c(self.B) # (H N) + P = _r2c(self.P) + Q = P.conj() if self.Q is None else _r2c(self.Q) + + # Prepare Linear stepping + dt = torch.exp(self.log_dt) + D = (2.0 / dt.unsqueeze(-1) - w).reciprocal() # (H, N) + R = (torch.eye(self.rank, dtype=w.dtype, device=w.device) + 2*contract('r h n, h n, s h n -> h r s', Q, D, P).real) # (H r r) + Q_D = rearrange(Q*D, 'r h n -> h r n') + R = torch.linalg.solve(R.to(Q_D), Q_D) # (H r N) + R = rearrange(R, 'h r n -> r h n') + + self.step_params = { + "D": D, # (H N) + "R": R, # (r H N) + "P": P, # (r H N) + "Q": Q, # (r H N) + "B": B, # (1 H N) + "E": 2.0 / dt.unsqueeze(-1) + w, # (H N) + } + + def _step_state_linear(self, u=None, state=None): + """ + Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization. + + Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. 
Perhaps a fused CUDA kernel implementation would be much faster + + u: (H) input + state: (H, N/2) state with conjugate pairs + Optionally, the state can have last dimension N + Returns: same shape as state + """ + C = _r2c(self.C) # View used for dtype/device + + if u is None: # Special case used to find dA + u = torch.zeros(self.H, dtype=C.dtype, device=C.device) + if state is None: # Special case used to find dB + state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device) + + step_params = self.step_params.copy() + if state.size(-1) == self.N: # Only store half of the conjugate pairs; should be true by default + # There should be a slightly faster way using conjugate symmetry + contract_fn = lambda p, x, y: contract('r h n, r h m, ... h m -> ... h n', _conj(p), _conj(x), _conj(y))[..., :self.N] # inner outer product + else: + assert state.size(-1) == 2*self.N + step_params = {k: _conj(v) for k, v in step_params.items()} + # TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping + contract_fn = lambda p, x, y: contract('r h n, r h m, ... h m -> ... h n', p, x, y) # inner outer product + D = step_params["D"] # (H N) + E = step_params["E"] # (H N) + R = step_params["R"] # (r H N) + P = step_params["P"] # (r H N) + Q = step_params["Q"] # (r H N) + B = step_params["B"] # (1 H N) + + new_state = E * state - contract_fn(P, Q, state) # (B H N) + new_state = new_state + 2.0 * B * u.unsqueeze(-1) # (B H N) + new_state = D * (new_state - contract_fn(P, R, new_state)) + + return new_state + + def _setup_state(self): + """ Construct dA and dB for discretized state equation """ + + # Construct dA and dB by using the stepping + self._setup_linear() + C = _r2c(self.C) # Just returns a view that we use for finding dtype/device + + state = torch.eye(2*self.N, dtype=C.dtype, device=C.device).unsqueeze(-2) # (N 1 N) + dA = self._step_state_linear(state=state) + dA = rearrange(dA, "n h m -> h m n") + self.dA = dA # (H N N) + + u = C.new_ones(self.H) + dB = self._step_state_linear(u=u) + dB = _conj(dB) + self.dB = rearrange(dB, '1 h n -> h n') # (H N) + + def _step_state(self, u, state): + """ Must be called after self.default_state() is used to construct an initial state! 
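default_state() also pre-builds the dense state/input einsum contractions (state_contraction / input_contraction) that this method applies to dA and dB.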
""" + next_state = self.state_contraction(self.dA, state) + self.input_contraction(self.dB, u) + return next_state + + + def setup_step(self, mode='dense'): + """ Set up dA, dB, dC discretized parameters for stepping """ + self._setup_state() + + # Calculate original C + dA_L = power(self.L, self.dA) + I = torch.eye(self.dA.size(-1)).to(dA_L) + C = _conj(_r2c(self.C)) # (H C N) + + dC = torch.linalg.solve( + I - dA_L.transpose(-1, -2), + C.unsqueeze(-1), + ).squeeze(-1) + self.dC = dC + + # Do special preprocessing for different step modes + + self._step_mode = mode + if mode == 'linear': + # Linear case: special step function for the state, we need to handle output + # use conjugate symmetry by default, which affects the output projection + self.dC = 2*self.dC[:, :, :self.N] + elif mode == 'diagonal': + # Eigendecomposition of the A matrix + L, V = torch.linalg.eig(self.dA) + V_inv = torch.linalg.inv(V) + # Check that the eigendedecomposition is correct + if self.verbose: + print("Diagonalization error:", torch.dist(V @ torch.diag_embed(L) @ V_inv, self.dA)) + + # Change the parameterization to diagonalize + self.dA = L + self.dB = contract('h n m, h m -> h n', V_inv, self.dB) + self.dC = contract('h n m, c h n -> c h m', V, self.dC) + + elif mode == 'dense': + pass + else: raise NotImplementedError("NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}") + + + def default_state(self, *batch_shape): + C = _r2c(self.C) + N = C.size(-1) + H = C.size(-2) + + # Cache the tensor contractions we will later do, for efficiency + # These are put in this function because they depend on the batch size + if self._step_mode !='linear': + N *= 2 + + if self._step_mode == 'diagonal': + self.state_contraction = contract_expression( + "h n, ... h n -> ... h n", + (H, N), + batch_shape + (H, N), + ) + else: + # Dense (quadratic) case: expand all terms + self.state_contraction = contract_expression( + "h m n, ... h n -> ... h m", + (H, N, N), + batch_shape + (H, N), + ) + + self.input_contraction = contract_expression( + "h n, ... h -> ... h n", + (H, N), # self.dB.shape + batch_shape + (H,), + ) + + self.output_contraction = contract_expression( + "c h n, ... h n -> ... c h", + (C.shape[0], H, N), # self.dC.shape + batch_shape + (H, N), + ) + + state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device) + return state + + def step(self, u, state): + """ Must have called self.setup_step() and created state with self.default_state() before calling this """ + + if self._step_mode == 'linear': + new_state = self._step_state_linear(u, state) + else: + new_state = self._step_state(u, state) + y = self.output_contraction(self.dC, new_state) + return y, new_state + + def register(self, name, tensor, trainable=False, lr=None, wd=None): + """Utility method: register a tensor as a buffer or trainable parameter""" + + if trainable: + self.register_parameter(name, nn.Parameter(tensor)) + else: + self.register_buffer(name, tensor) + + optim = {} + if trainable and lr is not None: + optim["lr"] = lr + if trainable and wd is not None: + optim["weight_decay"] = wd + if len(optim) > 0: + setattr(getattr(self, name), "_optim", optim) + + +class HippoSSKernel(nn.Module): + + """Wrapper around SSKernel that generates A, B, C, dt according to HiPPO arguments. 
+ + The SSKernel is expected to support the interface + forward() + default_state() + setup_step() + step() + """ + + def __init__( + self, + H, + N=64, + L=1, + measure="legs", + rank=1, + channels=1, # 1-dim to C-dim map; can think of C as having separate "heads" + dt_min=0.001, + dt_max=0.1, + trainable=None, # Dictionary of options to train various HiPPO parameters + lr=None, # Hook to set LR of hippo parameters differently + length_correction=True, # Multiply by I-A|^L after initialization; can be turned off for initialization speed + hurwitz=False, + tie_state=False, # Tie parameters of HiPPO ODE across the H features + precision=1, # 1 (single) or 2 (double) for the kernel + resample=False, # If given inputs of different lengths, adjust the sampling rate. Note that L should always be provided in this case, as it assumes that L is the true underlying length of the continuous signal + verbose=False, + ): + super().__init__() + self.N = N + self.H = H + L = L or 1 + self.precision = precision + dtype = torch.double if self.precision == 2 else torch.float + cdtype = torch.cfloat if dtype == torch.float else torch.cdouble + self.rate = None if resample else 1.0 + self.channels = channels + + # Generate dt + log_dt = torch.rand(self.H, dtype=dtype) * ( + math.log(dt_max) - math.log(dt_min) + ) + math.log(dt_min) + + w, p, B, _ = nplr(measure, self.N, rank, dtype=dtype) + C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype) + self.kernel = SSKernelNPLR( + L, w, p, B, C, + log_dt, + hurwitz=hurwitz, + trainable=trainable, + lr=lr, + tie_state=tie_state, + length_correction=length_correction, + verbose=verbose, + ) + + def forward(self, L=None): + k, _ = self.kernel(rate=self.rate, L=L) + return k.float() + + def step(self, u, state, **kwargs): + u, state = self.kernel.step(u, state, **kwargs) + return u.float(), state + + def default_state(self, *args, **kwargs): + return self.kernel.default_state(*args, **kwargs) + + + + + + +def get_torch_trans(heads=8, layers=1, channels=64): + encoder_layer = nn.TransformerEncoderLayer( + d_model=channels, nhead=heads, dim_feedforward=64, activation="gelu") + return nn.TransformerEncoder(encoder_layer, num_layers=layers) + + + + +class S4(nn.Module): + + def __init__( + self, + d_model, + d_state=64, + l_max=1, # Maximum length of sequence. Fine if not provided: the kernel will keep doubling in length until longer than sequence. 
However, this can be marginally slower if the true length is not a power of 2 + channels=1, # maps 1-dim to C-dim + bidirectional=False, + # Arguments for FF + activation='gelu', # activation in between SS and FF + postact=None, # activation after FF + initializer=None, # initializer on FF + weight_norm=False, # weight normalization on FF + hyper_act=None, # Use a "hypernetwork" multiplication + dropout=0.0, + transposed=True, # axis ordering (B, L, D) or (B, D, L) + verbose=False, + # SSM Kernel arguments + **kernel_args, + ): + + + """ + d_state: the dimension of the state, also denoted by N + l_max: the maximum sequence length, also denoted by L + if this is not known at model creation, set l_max=1 + channels: can be interpreted as a number of "heads" + bidirectional: bidirectional + dropout: standard dropout argument + transposed: choose backbone axis ordering of (B, L, H) or (B, H, L) [B=batch size, L=sequence length, H=hidden dimension] + + Other options are all experimental and should not need to be configured + """ + + + super().__init__() + + self.h = d_model + self.n = d_state + self.bidirectional = bidirectional + self.channels = channels + self.transposed = transposed + + # optional multiplicative modulation GLU-style + # https://arxiv.org/abs/2002.05202 + self.hyper = hyper_act is not None + if self.hyper: + channels *= 2 + self.hyper_activation = Activation(hyper_act) + + self.D = nn.Parameter(torch.randn(channels, self.h)) + + if self.bidirectional: + channels *= 2 + + + # SSM Kernel + self.kernel = HippoSSKernel(self.h, N=self.n, L=l_max, channels=channels, verbose=verbose, **kernel_args) + + # Pointwise + self.activation = Activation(activation) + dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout + self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + + + # position-wise output transform to mix features + self.output_linear = LinearActivation( + self.h*self.channels, + self.h, + transposed=self.transposed, + initializer=initializer, + activation=postact, + activate=True, + weight_norm=weight_norm, + ) + + #self.time_transformer = get_torch_trans(heads=8, layers=1, channels=self.h) + + + + + def forward(self, u, **kwargs): # absorbs return_output and transformer src mask + """ + u: (B H L) if self.transposed else (B L H) + state: (H N) never needed unless you know what you're doing + + Returns: same shape as u + """ + if not self.transposed: u = u.transpose(-1, -2) + L = u.size(-1) + + # Compute SS Kernel + k = self.kernel(L=L) # (C H L) (B C H L) + + # Convolution + if self.bidirectional: + k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2) + k = F.pad(k0, (0, L)) \ + + F.pad(k1.flip(-1), (L, 0)) \ + + k_f = torch.fft.rfft(k, n=2*L) # (C H L) + u_f = torch.fft.rfft(u, n=2*L) # (B H L) + y_f = contract('bhl,chl->bchl', u_f, k_f) # k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L) + y = torch.fft.irfft(y_f, n=2*L)[..., :L] # (B C H L) + + + # Compute D term in state space equation - essentially a skip connection + y = y + contract('bhl,ch->bchl', u, self.D) # u.unsqueeze(-3) * self.D.unsqueeze(-1) + + # Optional hyper-network multiplication + if self.hyper: + y, yh = rearrange(y, 'b (s c) h l -> s b c h l', s=2) + y = self.hyper_activation(yh) * y + + # Reshape to flatten channels + y = rearrange(y, '... c h l -> ... 
(c h) l') + + y = self.dropout(self.activation(y)) + + if not self.transposed: y = y.transpose(-1, -2) + + y = self.output_linear(y) + + # ysize = b, k, l, requieres l, b, k + #y = self.time_transformer(y.permute(2,0,1)).permute(1,2,0) + + + return y, None + + + def step(self, u, state): + """ Step one time step as a recurrent model. Intended to be used during validation. + + u: (B H) + state: (B H N) + Returns: output (B H), state (B H N) + """ + assert not self.training + + y, next_state = self.kernel.step(u, state) # (B C H) + y = y + u.unsqueeze(-2) * self.D + y = rearrange(y, '... c h -> ... (c h)') + y = self.activation(y) + if self.transposed: + y = self.output_linear(y.unsqueeze(-1)).squeeze(-1) + else: + y = self.output_linear(y) + return y, next_state + + def default_state(self, *batch_shape, device=None): + return self.kernel.default_state(*batch_shape) + + @property + def d_state(self): + return self.h * self.n + + @property + def d_output(self): + return self.h + + @property + def state_to_tensor(self): + return lambda state: rearrange('... h n -> ... (h n)', state) + + + +class S4Layer(nn.Module): + #S4 Layer that can be used as a drop-in replacement for a TransformerEncoder + def __init__(self, features, lmax, N=64, dropout=0.0, bidirectional=True, layer_norm=True): + super().__init__() + self.s4_layer = S4(d_model=features, + d_state=N, + l_max=lmax, + bidirectional=bidirectional) + + self.norm_layer = nn.LayerNorm(features) if layer_norm else nn.Identity() + self.dropout = nn.Dropout2d(dropout) if dropout>0 else nn.Identity() + + def forward(self, x): + #x has shape seq, batch, feature + x = x.permute((1,2,0)) #batch, feature, seq (as expected from S4 with transposed=True) + xout, _ = self.s4_layer(x) #batch, feature, seq + xout = self.dropout(xout) + xout = xout + x # skip connection # batch, feature, seq + xout = xout.permute((2,0,1)) # seq, batch, feature + return self.norm_layer(xout) diff --git a/src/ecglib/models/architectures/sssd/sssd_ecg_nle.py b/src/ecglib/models/architectures/sssd/sssd_ecg_nle.py index 46ed491..d741281 100644 --- a/src/ecglib/models/architectures/sssd/sssd_ecg_nle.py +++ b/src/ecglib/models/architectures/sssd/sssd_ecg_nle.py @@ -1,2 +1,243 @@ -class SSSD_ECG_nle: - raise NotImplementedError \ No newline at end of file +import math +import torch +import torch.nn as nn + +from .s4 import S4Layer +from .util import swish, calc_diffusion_step_embedding + + +class Conv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1): + super(Conv, self).__init__() + self.padding = dilation * (kernel_size - 1) // 2 + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding) + self.conv = nn.utils.parametrizations.weight_norm(self.conv) + nn.init.kaiming_normal_(self.conv.weight) + + def forward(self, x): + out = self.conv(x) + return out + + +class ZeroConv1d(nn.Module): + def __init__(self, in_channel, out_channel): + super(ZeroConv1d, self).__init__() + self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0) + self.conv.weight.data.zero_() + self.conv.bias.data.zero_() + + def forward(self, x): + out = self.conv(x) + return out + + +class Residual_block(nn.Module): + def __init__(self, res_channels, skip_channels, + diffusion_step_embed_dim_out, in_channels, + s4_lmax, + s4_d_state, + s4_dropout, + s4_bidirectional, + s4_layernorm, + label_embed_dim=None, + gender_embed_dim=None): + super(Residual_block, self).__init__() + self.res_channels = res_channels + + + 
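+        # fc_t below projects the (already swish-activated) diffusion-step embedding down to
+        # res_channels; forward() reshapes it to (B, res_channels, 1) and adds it to the block
+        # input before the Conv -> S4 -> conditional-embedding -> gated-tanh path.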
self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels) + + self.S41 = S4Layer(features=2*self.res_channels, + lmax=s4_lmax, + N=s4_d_state, + dropout=s4_dropout, + bidirectional=s4_bidirectional, + layer_norm=s4_layernorm) + + self.conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3) + + self.S42 = S4Layer(features=2*self.res_channels, + lmax=s4_lmax, + N=s4_d_state, + dropout=s4_dropout, + bidirectional=s4_bidirectional, + layer_norm=s4_layernorm) + + self.res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1) + self.res_conv = nn.utils.parametrizations.weight_norm(self.res_conv) + nn.init.kaiming_normal_(self.res_conv.weight) + + + self.skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1) + self.skip_conv = nn.utils.parametrizations.weight_norm(self.skip_conv) + nn.init.kaiming_normal_(self.skip_conv.weight) + + + #the layer-specific fc for conditional embeddings (conditional case) + self.fc_label = nn.Linear(label_embed_dim, 2 * self.res_channels) if label_embed_dim is not None else None + self.fc_gender = nn.Linear(gender_embed_dim, 2 * self.res_channels) if gender_embed_dim is not None else None + + def forward(self, input_data): + x, label_embed, gender_embed, diffusion_step_embed = input_data + h = x + B, C, L = x.shape + assert C == self.res_channels + + part_t = self.fc_t(diffusion_step_embed) + part_t = part_t.view([B, self.res_channels, 1]) + h = h + part_t + + h = self.conv_layer(h) + h = self.S41(h.permute(2,0,1)).permute(1,2,0) + + # process label embedding + if(self.fc_label is not None): + label_embed = self.fc_label(label_embed).unsqueeze(2) #output B, 2C, 1 + h = h + label_embed + + # process gender embedding + if self.fc_gender is not None: + gender_embed = self.fc_gender(gender_embed).unsqueeze(2) + h = h + gender_embed + + # Doubled S4 feeding only if h added with label/gender embedding + if self.fc_label is not None or self.fc_gender is not None: + h = self.S42(h.permute(2,0,1)).permute(1,2,0) + + out = torch.tanh(h[:,:self.res_channels,:]) * torch.sigmoid(h[:,self.res_channels:,:]) + + res = self.res_conv(out) + assert x.shape == res.shape + skip = self.skip_conv(out) + + return (x + res) * math.sqrt(0.5), skip # normalize for training stability + + +class Residual_group(nn.Module): + def __init__(self, res_channels, skip_channels, num_res_layers, + diffusion_step_embed_dim_in, + diffusion_step_embed_dim_mid, + diffusion_step_embed_dim_out, + in_channels, + s4_lmax, + s4_d_state, + s4_dropout, + s4_bidirectional, + s4_layernorm, + label_embed_dim=None, + gender_embed_dim=None): + super(Residual_group, self).__init__() + self.num_res_layers = num_res_layers + self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in + + self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid) + self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out) + + self.residual_blocks = nn.ModuleList() + for n in range(self.num_res_layers): + self.residual_blocks.append(Residual_block(res_channels, skip_channels, + diffusion_step_embed_dim_out=diffusion_step_embed_dim_out, + in_channels=in_channels, + s4_lmax=s4_lmax, + s4_d_state=s4_d_state, + s4_dropout=s4_dropout, + s4_bidirectional=s4_bidirectional, + s4_layernorm=s4_layernorm, + label_embed_dim=label_embed_dim, + gender_embed_dim=gender_embed_dim)) + + + def forward(self, input_data): + noise, label_embed, gender_embed, diffusion_steps = input_data + + diffusion_step_embed = calc_diffusion_step_embedding(diffusion_steps, 
self.diffusion_step_embed_dim_in) + diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed)) + diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed)) + + h = noise + skip = 0 + for n in range(self.num_res_layers): + h, skip_n = self.residual_blocks[n]((h, label_embed, gender_embed, diffusion_step_embed)) + skip += skip_n + + return skip * math.sqrt(1.0 / self.num_res_layers) + + +class SSSD_ECG_nle(nn.Module): + def __init__(self, in_channels, res_channels, skip_channels, out_channels, + num_res_layers, + diffusion_step_embed_dim_in, + diffusion_step_embed_dim_mid, + diffusion_step_embed_dim_out, + s4_lmax, + s4_d_state, + s4_dropout, + s4_bidirectional, + s4_layernorm, + label_embed_classes=0, + label_embed_dim=128, + gender_embed_classes=0, + gender_embed_dim=128, + new_label_embed=False, + ): + super(SSSD_ECG_nle, self).__init__() + + self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU()) + + # embedding for global conditioning + self.new_label_embed = new_label_embed + if not self.new_label_embed: + self.label_embedding = nn.Embedding(label_embed_classes, label_embed_dim) if label_embed_classes>0 else None + else: + # Now have `label_embed_classes` pairs of embeddings + # so, doesn't need multiplication on embedding weights + # It useful because we define 0 as negative class and when multiply 0 we don't transform information + self.label_embedding = nn.ModuleList([nn.Embedding(2, label_embed_dim, padding_idx=-1) if label_embed_classes > 0 else None for _ in range(label_embed_classes)]) + self.label_embedding_conv = nn.Conv1d(in_channels=label_embed_classes, out_channels=1, kernel_size=1) if label_embed_classes > 1 else None + self.gender_embedding = nn.Embedding(gender_embed_classes, gender_embed_dim) if gender_embed_classes > 0 else None + + self.residual_layer = Residual_group(res_channels=res_channels, + skip_channels=skip_channels, + num_res_layers=num_res_layers, + diffusion_step_embed_dim_in=diffusion_step_embed_dim_in, + diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid, + diffusion_step_embed_dim_out=diffusion_step_embed_dim_out, + in_channels=in_channels, + s4_lmax=s4_lmax, + s4_d_state=s4_d_state, + s4_dropout=s4_dropout, + s4_bidirectional=s4_bidirectional, + s4_layernorm=s4_layernorm, + label_embed_dim=label_embed_dim if label_embed_classes > 0 else None, + gender_embed_dim=gender_embed_dim if gender_embed_classes > 0 else None) + + self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1), + nn.ReLU(), + ZeroConv1d(skip_channels, out_channels)) + + def forward(self, input_data): + + noise, label, gender, diffusion_steps = input_data + + if not self.new_label_embed: + # Multiplication on weights + label_embed = label @ self.label_embedding.weight if self.label_embedding is not None else None + else: + # Choose embedding in each pair + label_embed = torch.stack([embedding(label[:, i]).squeeze(1) for i, embedding in enumerate(self.label_embedding)]).permute(1, 0, 2) + if self.label_embedding_conv is not None: + label_embed = self.label_embedding_conv(label_embed).squeeze(1) + else: + label_embed = label_embed.squeeze(1) + gender_embed = self.gender_embedding(gender).squeeze(1) if self.gender_embedding is not None else None + + x = noise + x = self.init_conv(x) + x = self.residual_layer((x, label_embed, gender_embed, diffusion_steps)) + y = self.final_conv(x) + + return y + + +def sssd_ecg_nle(**kwargs): + return SSSD_ECG_nle(**kwargs) diff --git a/src/ecglib/models/architectures/sssd/util.py 
b/src/ecglib/models/architectures/sssd/util.py new file mode 100644 index 0000000..8ee5924 --- /dev/null +++ b/src/ecglib/models/architectures/sssd/util.py @@ -0,0 +1,33 @@ +import torch +import numpy as np + +def swish(x): + return x * torch.sigmoid(x) + + +def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in): + """ + Embed a diffusion step $t$ into a higher dimensional space + E.g. the embedding vector in the 128-dimensional space is + [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)), cos(t * 10^(0*4/63)), ... , cos(t * 10^(63*4/63))] + + Parameters: + diffusion_steps (torch.long tensor, shape=(batchsize, 1)): + diffusion steps for batch data + diffusion_step_embed_dim_in (int, default=128): + dimensionality of the embedding space for discrete diffusion steps + + Returns: + the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in)): + """ + + assert diffusion_step_embed_dim_in % 2 == 0 + + half_dim = diffusion_step_embed_dim_in // 2 + _embed = np.log(10000) / (half_dim - 1) + _embed = torch.exp(torch.arange(half_dim) * -_embed).cuda() + _embed = diffusion_steps * _embed + diffusion_step_embed = torch.cat((torch.sin(_embed), + torch.cos(_embed)), 1) + + return diffusion_step_embed \ No newline at end of file diff --git a/src/ecglib/models/config/model_configs.py b/src/ecglib/models/config/model_configs.py index 3b756d2..0e1e895 100644 --- a/src/ecglib/models/config/model_configs.py +++ b/src/ecglib/models/config/model_configs.py @@ -10,7 +10,8 @@ "ResNetConfig", "TabularNetConfig", "DenseNetConfig", - "CNN1dConfig" + "CNN1dConfig", + "SSSDConfig" ] @@ -110,3 +111,28 @@ class CNN1dConfig(BaseConfig): inp_channels: int = 12 inp_features: int = 1 cnn_ftrs: list = field(default_factory=lambda: [64, 32, 16]) + + +@dataclass(repr=True, eq=True) +class SSSDConfig(BaseConfig): + """ + Default parameters correspond SSSD_ECG model + """ + in_channels: int = 8 + res_channels: int = 256 + skip_channels: int = 256 + out_channels: int = 8 + num_res_layers: int = 36 + diffusion_step_embed_dim_in: int = 128 + diffusion_step_embed_dim_mid: int = 512 + diffusion_step_embed_dim_out: int = 512 + s4_lmax: int = 1000 + s4_d_state: int = 64 + s4_dropout: float = 0.0 + s4_bidirectional: bool = True + s4_layernorm: bool = True + label_embed_dim: int = 128 + label_embed_classes: int = 40 + gender_embed_classes: int = 0 + gender_embed_dim: int = 128 + new_label_embed: bool = True diff --git a/src/ecglib/models/config/registred_configs.py b/src/ecglib/models/config/registred_configs.py index c0224d1..0abdae9 100644 --- a/src/ecglib/models/config/registred_configs.py +++ b/src/ecglib/models/config/registred_configs.py @@ -3,6 +3,7 @@ ResNetConfig, DenseNetConfig, TabularNetConfig, + SSSDConfig, ) from ..architectures.model_types import MType @@ -16,6 +17,7 @@ MType.RESNET: ResNetConfig, MType.DENSENET: DenseNetConfig, MType.TABULAR: TabularNetConfig, + MType.SSSD: SSSDConfig } diff --git a/src/ecglib/preprocessing/functional.py b/src/ecglib/preprocessing/functional.py index 6a0f704..05e53d9 100644 --- a/src/ecglib/preprocessing/functional.py +++ b/src/ecglib/preprocessing/functional.py @@ -191,6 +191,19 @@ def z_normalization( s_norm[same_values] = 0 return s_norm +def identical_nomralization( + s: np.ndarray, +) -> np.ndarray: + """function to identical normalization + + Args: + s (np.ndarray): signal + + Returns: + np.ndarray: preprocessed (identical) signal + """ + return s + def DWT_filter( s: np.ndarray, diff --git a/src/ecglib/preprocessing/preprocess.py 
b/src/ecglib/preprocessing/preprocess.py index 520b2d1..6f065f7 100644 --- a/src/ecglib/preprocessing/preprocess.py +++ b/src/ecglib/preprocessing/preprocess.py @@ -155,7 +155,7 @@ def __call__(self, x): class Normalization: """ Apply normalization - :param norm_type: type of normalization ('z_norm', 'z_norm_constant_handle', and 'min_max') + :param norm_type: type of normalization ('z_norm', 'z_norm_constant_handle', 'min_max' and 'identical') :return: preprocessed data """ @@ -169,9 +169,11 @@ def __init__( self.func = F.minmax_normalization elif norm_type == "z_norm" or norm_type == "z_norm_constant_handle": self.func = F.z_normalization + elif norm_type == "identical": + self.func = F.identical_nomralization else: raise ValueError( - "norm_type must be one of [min_max, z_norm, z_norm_constant_handle]" + "norm_type must be one of [min_max, z_norm, z_norm_constant_handle, identical]" ) def apply_normalization(self, x): From 6a054f15b95039e5155f1c6abfebccf91d19aa65 Mon Sep 17 00:00:00 2001 From: Sergey Skorik Date: Fri, 12 Jul 2024 08:44:44 +0000 Subject: [PATCH 4/4] update the functionality of loading ptb-xl --- notebooks/sssd_ecg_nle.ipynb | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/notebooks/sssd_ecg_nle.ipynb b/notebooks/sssd_ecg_nle.ipynb index 3cbc5fb..84c42ab 100644 --- a/notebooks/sssd_ecg_nle.ipynb +++ b/notebooks/sssd_ecg_nle.ipynb @@ -77,10 +77,16 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "need to change a logic of downloading" + "path_to_unzip = '' # If you already download ptb_xl dataset, put unzip path here. \n", + " # Make sure that path to map file has a structure \n", + " # f'{path_to_unzip}/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3/ptbxl_database.csv\"\n", + "download = True if path_to_unzip else False\n", + "SAMPLE_FREQUENCY = 100" ] }, { @@ -89,9 +95,7 @@ "metadata": {}, "outputs": [], "source": [ - "path_to_unzip = '/home/ecg_data/physionet_data/raw_data/'\n", - "SAMPLE_FREQUENCY = 100\n", - "ptb_xl_info = load_ptb_xl(path_to_unzip=path_to_unzip, frequency=SAMPLE_FREQUENCY)\n", + "ptb_xl_info = load_ptb_xl(path_to_unzip=path_to_unzip, frequency=SAMPLE_FREQUENCY, download=download)\n", "ptb_xl_info['frequency'] = SAMPLE_FREQUENCY\n", "\n", "# Split in accordance with PTB-XL Benchmarking\n",
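To make the call signature of the new generation model concrete, the following is a minimal smoke-test sketch rather than code from the patches themselves: it instantiates SSSD_ECG_nle directly with the SSSDConfig defaults listed above (8 leads, 1000 samples, 40 binary labels, no gender conditioning), except that num_res_layers is shrunk to keep the test light. The batch size and the 200-step diffusion horizon are illustrative assumptions, a GPU is required because calc_diffusion_step_embedding moves its frequency table to CUDA unconditionally, and installing pykeops or the bundled cauchy extension is recommended since the pure-PyTorch fallback kernel is slow and memory-hungry.

import torch

from ecglib.models.architectures.sssd.sssd_ecg_nle import SSSD_ECG_nle

model = SSSD_ECG_nle(
    in_channels=8,
    res_channels=256,
    skip_channels=256,
    out_channels=8,
    num_res_layers=2,                  # SSSDConfig default is 36; reduced here for a quick test
    diffusion_step_embed_dim_in=128,
    diffusion_step_embed_dim_mid=512,
    diffusion_step_embed_dim_out=512,
    s4_lmax=1000,
    s4_d_state=64,
    s4_dropout=0.0,
    s4_bidirectional=True,
    s4_layernorm=True,
    label_embed_classes=40,
    label_embed_dim=128,
    gender_embed_classes=0,            # gender conditioning disabled, as in SSSDConfig
    gender_embed_dim=128,
    new_label_embed=True,
).cuda()                               # calc_diffusion_step_embedding calls .cuda() internally

batch, length, total_steps = 2, 1000, 200                   # 200-step horizon is an illustrative choice
noise = torch.randn(batch, 8, length, device="cuda")        # noisy 8-lead signal x_t
label = torch.randint(0, 2, (batch, 40), device="cuda")     # multi-hot diagnostic labels
diffusion_steps = torch.randint(0, total_steps, (batch, 1), device="cuda")

with torch.no_grad():
    out = model((noise, label, None, diffusion_steps))      # gender entry is ignored when classes == 0

print(out.shape)   # torch.Size([2, 8, 1000]) -- same shape as the noise input

The 8-channel output matches out_channels and, in the usual SSSD-ECG training setup, serves as the noise estimate of the diffusion objective; as in the SSSD-ECG paper, the remaining four standard leads of a 12-lead ECG are reconstructed from these eight afterwards.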