Skip to content

Commit

Permalink
Merge pull request #3 from mini-sora/update_doc
Browse files Browse the repository at this point in the history
Add doc for reproduction
  • Loading branch information
PeterH0323 authored Mar 25, 2024
2 parents a6cd2ed + cf20bb1 commit 7aa6083
Showing 1 changed file with 43 additions and 2 deletions.
45 changes: 43 additions & 2 deletions README_EN.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,56 @@ English | [简体中文](./README_CN.md)

## 数据集

...
- ImageNet-1K

可以在 OpenDatalab 进行下载 [ImageNet-1K](https://opendatalab.org.cn/OpenDataLab/ImageNet-1K)

```shell
pip install openxlab #安装
pip install -U openxlab #版本升级
openxlab login #进行登录,输入对应的AK/SK

cd ${dataset_dir}
openxlab dataset get --dataset-repo OpenDataLab/ImageNet-1K #数据集下载
```

## 复现步骤

1. 数据集预处理

因为在原版 Meta 的 [DiT](https://github.com/facebookresearch/DiT) 中,每个 iter 都会对数据进行重复计算,为了节省训练的时间,可以先对图片进行预处理,在训练的时候可以节省这部分的时间

详见 dev 分支中的 [extract_features.py#L163](https://github.com/mini-sora/MiniSora-DiT/blob/ad13c58370842db333c77253709e3fbbc1e9a092/extract_features.py#L163-L177) ,处理需要时间较久,大概 1~2小时。

```python
for x, y in loader:
x = x.to(device)
y = y.to(device)
with torch.no_grad():
# Map input images to latent space + normalize latents:
x = vae.encode(x).latent_dist.sample().mul_(0.18215)

x = x.detach().cpu().numpy() # (1, 4, 32, 32)
np.save(f'{args.features_path}/imagenet256_features/{train_steps}.npy', x)

y = y.detach().cpu().numpy() # (1,)
np.save(f'{args.features_path}/imagenet256_labels/{train_steps}.npy', y)

train_steps += 1
print(train_steps)
```

2. 使用 mmengine 重写数据流
3. 重写 loss 计算
4. 使用 xtuner 调训练 pipeline

## 模型架构

...

## 算力需求

...
论文原版是:`8 x A100` ,但是使用混合精度,可以使用 `2 x A100`

## 其他项目

Expand Down

0 comments on commit 7aa6083

Please sign in to comment.