Commit

init
yifanfeng97 committed Jan 29, 2024
1 parent c4bbc17 commit a8daf38
Showing 41 changed files with 10,066 additions and 2 deletions.
58 changes: 56 additions & 2 deletions README.md
@@ -1,2 +1,56 @@
# HGM2R
Source code for IEEE TPAMI 2024 "Hypergraph-Based Multi-Modal Representation for Open-Set 3D Object Retrieval"
# Hypergraph-Based Multi-Modal Representation for Open-Set 3D Object Retrieval
This repository contains the source code for the paper "Hypergraph-Based Multi-Modal Representation for Open-Set 3D Object Retrieval", published in IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI) 2024, by [Yifan Feng](https://fengyifan.site/), Shuyi Ji, Yu-Shen Liu, Shaoyi Du, Qionghai Dai, and Yue Gao*. The paper is available [here](https://ieeexplore.ieee.org/abstract/document/10319392/).

![framework](doc/fw.jpg)

## Introduction
In this repository, we provide our implementation of the Hypergraph-Based Multi-Modal Representation (HGM$^2$R), which is built on the following environment:
* [python 3.9](https://www.python.org/): basic programming language.
* [torch 1.12.1](https://pytorch.org/): for computation.
* [hydra-core 1.3.2](https://hydra.cc/docs/intro/): for configuration and multi-run management.

## Installation
1. Clone this repository.
2. Install the required libraries.
``` bash
pip install -r requirements.txt
```

## Downloads
In this paper, we release four datasets (OS-ESB-core, OS-NTU-core, OS-MN40-core, and OS-ABO-core) for the open-set 3D object retrieval task, which can be downloaded [here](https://moon-lab.tech/os3dor). The dataset split files for the four datasets can be downloaded [here](https://pan.baidu.com/s/1uvrT9xBps6DENndxyu8V_A?pwd=evby), and the pre-extracted features for the four datasets can be downloaded [here](https://pan.baidu.com/s/1gylhaVO9dXmTbfIqV3DMBw?pwd=y774). The pre-extracted features should be placed in the `feature` folder.
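
As a quick sanity check, a downloaded feature file (for example `feature/esb__level_all__t2r8__mv12_resnet18.npy`, which ships with this commit) can be inspected with NumPy before training. The array layout assumed below (objects by feature dimension) is an assumption, not a documented format, so verify the shape on your own copy:
```python
import numpy as np

# Inspect one of the pre-extracted multi-view feature files placed in `feature/`.
# The (num_objects, feature_dim) layout is an assumption, not a documented format.
ft = np.load("feature/esb__level_all__t2r8__mv12_resnet18.npy", allow_pickle=True)
print(type(ft), getattr(ft, "shape", None))
```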

## Usage
First, you should compress the voxel features with the following command:
```bash
python pre_vox_ft_compress.py
```
Then, you can train the HGM$^2$R model with the following command:
```bash
python train_hgm2r.py
```
To change the dataset, modify line 272 of `train_hgm2r.py` and set the dataset name to one of `esb`, `ntu`, `mn40`, or `abo`:
```python
dataset = "esb" # esb, ntu, mn40, abo
```
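
Before training, it can help to confirm that the feature files for the chosen dataset are in place. The snippet below is only a sketch that follows the `<dataset>__level_all__t2r8__*` naming of the files added in this commit; the file names for `ntu`, `mn40`, and `abo` are assumed to follow the same pattern:
```python
from pathlib import Path

dataset = "esb"  # esb, ntu, mn40, abo
# Hypothetical check; names follow the pattern of the feature files shipped in this commit.
expected = [
    f"feature/{dataset}__level_all__t2r8__mv12_resnet18.npy",
    f"feature/{dataset}__level_all__t2r8__vox32_voxnet.npy",
]
for p in expected:
    print(p, "found" if Path(p).exists() else "MISSING")
```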

## Citation
If you find this repository useful in your research, please cite the following papers:
```
@article{feng2023hypergraph,
  title={Hypergraph-Based Multi-Modal Representation for Open-Set 3D Object Retrieval},
  author={Feng, Yifan and Ji, Shuyi and Liu, Yu-Shen and Du, Shaoyi and Dai, Qionghai and Gao, Yue},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  year={2023},
  publisher={IEEE}
}
@inproceedings{feng2019hypergraph,
  title={Hypergraph neural networks},
  author={Feng, Yifan and You, Haoxuan and Zhang, Zizhao and Ji, Rongrong and Gao, Yue},
  booktitle={Proceedings of the AAAI conference on artificial intelligence},
  volume={33},
  number={01},
  pages={3558--3565},
  year={2019}
}
```
Binary file added doc/fw.jpg
1 change: 1 addition & 0 deletions feature/README.md
@@ -0,0 +1 @@
More pre-extracted features can be downloaded from [here](https://pan.baidu.com/s/1gylhaVO9dXmTbfIqV3DMBw?pwd=y774).
Binary file added feature/esb__level_all__t2r8__mv12_resnet18.npy
Binary file added feature/esb__level_all__t2r8__vox32_voxnet.npy
28 changes: 28 additions & 0 deletions feature_extraction/README.md
@@ -0,0 +1,28 @@
# Multi-modal 3D Object Feature Extraction
This repository contains the code for extracting multi-modal features from 3D objects. The features are extracted from the following modalities:
- Multi-view: MVCNN
- Point clouds: PointNet, DGCNN
- Voxel: 3DShapeNets

## Requirements
- Python 3.6
- PyTorch 1.12
- Hydra-core 1.3.2

## Installation
1. Clone this repository
2. Install dependencies
```bash
pip install -r requirements.txt
```

## Usage
### Train Model
```bash
python mv_resnet18_train.py
```
### Extract Features
```bash
python mv_resnet18_gen_ft.py
```
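
The scripts above are not reproduced here, but the multi-view branch follows the usual MVCNN recipe: a shared ResNet-18 encodes each rendered view, and the per-view features are pooled into one object descriptor. The sketch below only illustrates that pattern under stated assumptions; it is not this repository's exact model:
```python
import torch
import torchvision.models as models

# MVCNN-style view pooling sketch (assumed pattern, not this repository's exact code).
backbone = models.resnet18(weights=None)   # load ImageNet weights in practice
backbone.fc = torch.nn.Identity()          # keep the 512-d global feature

imgs = torch.randn(2, 12, 3, 224, 224)     # (batch, n_view, C, H, W); 12 views as in "mv12_resnet18"
b, v, c, h, w = imgs.shape
view_ft = backbone(imgs.reshape(b * v, c, h, w)).reshape(b, v, -1)
obj_ft = view_ft.max(dim=1).values         # (batch, 512) object-level descriptor
print(obj_ft.shape)
```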

Empty file added feature_extraction/__init__.py
216 changes: 216 additions & 0 deletions feature_extraction/datasets.py
@@ -0,0 +1,216 @@
import json
import torch
import numpy as np
import open3d as o3d
from PIL import Image
from pathlib import Path
from functools import partial
import torchvision.transforms as T
from typing import List, Union, Dict
from torch.utils.data import Dataset, DataLoader
# from torchdata.datapipes.iter import IterableWrapper, Mapper

def fetch_img_list(path: Union[str, Path], n_view, real=False):
    # Collect rendered (or real-photo) view images and evenly sub-sample n_view of them.
    if not real:
        all_filenames = sorted(list(Path(path).glob('image/h_*.jpg')))
    else:
        all_filenames = sorted(list(Path(path).glob('real_image/main_*.jpg')))
    all_view = len(all_filenames)
    filenames = all_filenames[::all_view // n_view][:n_view]
    return filenames

def read_image(path_list: Union[List[str], List[Path]], augment=False, img_size=224):
    if augment:
        transform = T.Compose([
            T.RandomResizedCrop((img_size, img_size)),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
        ])
    else:
        transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
        ])
    imgs = [transform(Image.open(v).convert("RGB")) for v in path_list]
    imgs = torch.stack(imgs)
    return imgs

def fetch_pt_path(path: Union[str, Path], n_pt):
    p = Path(path) / 'pointcloud' / f'pt_{n_pt}.pts'
    if p.exists():
        return p
    else:
        return Path(path).with_suffix('.pts')

def read_pointcloud(path: Union[str, Path], augment=False):
    # Load the point cloud, center it at the origin, and scale it into the unit sphere.
    pt = np.asarray(o3d.io.read_point_cloud(str(path)).points)
    pt = pt - np.expand_dims(np.mean(pt, axis=0), 0)
    dist = np.max(np.sqrt(np.sum(pt ** 2, axis=1)), 0)
    pt = pt / dist
    if augment:
        # Random anisotropic scaling and translation.
        xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3])
        xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
        pt = np.add(np.multiply(pt, xyz1), xyz2).astype('float32')
    pt = torch.from_numpy(pt.astype(np.float32))
    return pt.transpose(0, 1)

def fetch_vox_path(path: Union[str, Path], d_vox):
    return Path(path) / 'voxel' / f'vox_{d_vox}.ply'

def read_voxel(path: Union[str, Path], d_vox, augment=False):
    # Read the voxel grid and rasterize its occupied indices into a dense d_vox^3 occupancy volume.
    vox_3d = o3d.io.read_voxel_grid(str(path))
    vox_idx = torch.from_numpy(np.array([v.grid_index for v in vox_3d.get_voxels()])).float()
    vox_idx = vox_idx / vox_idx.max()
    # Random rotation augmentation around the z-axis (currently disabled).
    if False and augment:
        vox_idx = vox_idx * 2 - 1
        theta = torch.rand(1) * 2 * torch.pi
        rot_mat = torch.tensor([[torch.cos(theta), -torch.sin(theta), 0],
                                [torch.sin(theta), torch.cos(theta), 0],
                                [0, 0, 1]])
        vox_idx = torch.mm(vox_idx, rot_mat)
        vox_idx = (vox_idx - vox_idx.min(0)[0]) / 2
    vox_idx = vox_idx * d_vox
    vox_idx = torch.clamp(vox_idx, 0, d_vox - 1).long()
    vox = torch.zeros((d_vox, d_vox, d_vox))
    vox[vox_idx[:, 0], vox_idx[:, 1], vox_idx[:, 2]] = 1
    return vox.unsqueeze(0)

class VOMM_dataset(Dataset):
    def __init__(self, data_root, sample_list, modality_cfg, augment=False, real=False):
        super().__init__()
        data_root = Path(data_root)
        self.real = real
        self.augment = augment
        self.cfg = modality_cfg
        self.path_list = [data_root / sample['path'] for sample in sample_list]
        self.label_idx_list = [sample['label_idx'] for sample in sample_list]
        self.name_list = [s['path'] for s in sample_list]
        self.label_name_list = [s['label'] for s in sample_list]
        self.n_class = len(set(self.label_idx_list))

    def __getitem__(self, index):
        cur_path = self.path_list[index]
        ret = [self.label_idx_list[index], ]
        for m, m_cfg in self.cfg.items():
            if m == 'image':
                img_list = fetch_img_list(cur_path, **m_cfg, real=self.real)
                data = read_image(img_list, augment=self.augment)
            elif m == 'pointcloud':
                path = fetch_pt_path(cur_path, **m_cfg)
                data = read_pointcloud(path, augment=self.augment)
            elif m == 'voxel':
                path = fetch_vox_path(cur_path, **m_cfg)
                data = read_voxel(path, **m_cfg, augment=self.augment)
            ret.append(data)
        return ret

    def __len__(self):
        return len(self.path_list)

def VOMM_Shot_Data(data_root, split_file, modality_cfg):
    with open(split_file, 'r') as f:
        split = json.load(f)
    train_set = VOMM_dataset(data_root, split['train'], modality_cfg, augment=True)
    val_set = VOMM_dataset(data_root, split['validation'], modality_cfg, augment=False)
    test_set = VOMM_dataset(data_root, split['test'], modality_cfg, augment=False)
    return train_set, val_set, test_set

def VOMM_OSR_Data(data_root, split_file, modality_cfg, data_ret_root=None, train_real=False, real=False):
    with open(split_file, 'r') as f:
        split = json.load(f)
    if data_ret_root is None:
        data_ret_root = data_root
    if train_real:
        train_data = VOMM_dataset(data_root, split['train']['sample'], modality_cfg, augment=True, real=True)
    else:
        train_data = VOMM_dataset(data_root, split['train']['sample'], modality_cfg, augment=True)
    query_data = VOMM_dataset(data_ret_root, split['retrieval']['query'], modality_cfg, augment=False, real=real)
    target_data = VOMM_dataset(data_ret_root, split['retrieval']['target'], modality_cfg, augment=False, real=real)
    return train_data, query_data, target_data

# def VOMM_datapipe(data_root: Union[str, Path], sample_list: List, modality_cfg: Dict[str, int], augment=False):
#     data_root = Path(data_root)
#     if isinstance(modality_cfg, str):
#         modality_cfg = {modality_cfg: None}
#     for m in modality_cfg.keys():
#         if m not in ['image', 'pointcloud', 'voxel']:
#             raise ValueError(f'Unknown modality: {m}')

#     path_list = [data_root / sample['path'] for sample in sample_list]
#     lbl_idx_list = [sample['label_idx'] for sample in sample_list]

#     # for different modality
#     data_dp_list = []
#     for m, m_cfg in modality_cfg.items():
#         if m == 'image':
#             cur_dp = IterableWrapper(path_list)
#             _fetch_img_list = partial(fetch_img_list, **m_cfg)
#             cur_dp = Mapper(cur_dp, _fetch_img_list)
#             _read_img = partial(read_image, augment=augment)
#             cur_dp = Mapper(cur_dp, _read_img)
#         elif m == 'pointcloud':
#             cur_dp = IterableWrapper(path_list)
#             _fetch_pt_path = partial(fetch_pt_path, **m_cfg)
#             cur_dp = Mapper(cur_dp, _fetch_pt_path)
#             _read_pt = partial(read_pointcloud, augment=augment)
#             cur_dp = Mapper(cur_dp, _read_pt)
#         elif m == 'voxel':
#             cur_dp = IterableWrapper(path_list)
#             _fetch_vox_path = partial(fetch_vox_path, **m_cfg)
#             cur_dp = Mapper(cur_dp, _fetch_vox_path)
#             _read_vox = partial(read_voxel, augment=augment)
#             cur_dp = Mapper(cur_dp, _read_vox)
#         data_dp_list.append(cur_dp)
#     # for label
#     dp_lbl = IterableWrapper(lbl_idx_list)
#     return dp_lbl.zip(*data_dp_list)

# def _get_datapipe(data_root, sample_list, modality_cfg, label_map, augment=False, detail=False):
#     cur_data = VOMM_datapipe(data_root, sample_list, modality_cfg, augment=augment)
#     cur_dataset = {'data': cur_data, 'n_class': len(label_map)}
#     if detail:
#         cur_dataset['name'] = [s['path'] for s in sample_list]
#         cur_dataset['label_name'] = [s['label'] for s in sample_list]
#     return cur_dataset

# def VOMM_Shot_Data(data_root, split_file, modality_cfg, detail=False):
#     with open(split_file, 'r') as f:
#         split = json.load(f)
#     train_set = _get_datapipe(data_root, split['train'], modality_cfg, split['label_map'], augment=False, detail=detail)
#     val_set = _get_datapipe(data_root, split['validation'], modality_cfg, split['label_map'], augment=False, detail=detail)
#     test_set = _get_datapipe(data_root, split['test'], modality_cfg, split['label_map'], augment=False, detail=detail)
#     return train_set, val_set, test_set


if __name__ == '__main__':
    data_root = '/home2/fengyifan/data/modelnet/40/ModelNet40_MM'
    osr_json_path = '/home2/fengyifan/code/OSR/Extract-Feature/splits/mn40__level_all__t2r8.json'
    shot_json_path = '/home2/fengyifan/code/OSR/Extract-Feature/splits/mn40__level_0__t20v20.json'
    modality_cfg = {
        'image': {'n_view': 8},
        # 'pointcloud': {'n_pt': 1024},
        # 'voxel': {'n_vox': 32}
    }
    # img_root = Path('/media/fengyifan/本地磁盘/NTU/NTU_2000_MM/chess/Y3813_pawn/image')
    # img_list = sorted([str(p) for p in img_root.glob('*.jpg')])[::4]
    # imgs = read_image(img_list, augment=True)
    # print(imgs.shape)

    # pt = read_pointcloud('/media/fengyifan/本地磁盘/NTU/NTU_2000_MM/chess/Y3813_pawn/pointcloud/pt_1024.pts')
    # print(pt.shape)

    # vox = read_voxel('/media/fengyifan/本地磁盘/NTU/NTU_2000_MM/chess/Y3813_pawn/voxel/vox_32.ply')
    # print(vox.shape)

    # import json
    # with open(json_path, 'r') as fp:
    #     data = json.load(fp)
    # dp = VOMM_datapipe(data_root, data['train']['sample'], modality_cfg, augment=True)
    # batch = next(iter(dp))
    # print(batch[1].shape)

    train_set, val_set, test_set = VOMM_Shot_Data(data_root, shot_json_path, modality_cfg)
    train_dataloader = DataLoader(train_set, batch_size=32, shuffle=True)
    for idx, (lbl, sample) in enumerate(train_dataloader):
        print(lbl)
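    # Example (sketch): the open-set retrieval splits can be built analogously with
    # VOMM_OSR_Data and the osr_json_path defined above:
    # train_set, query_set, target_set = VOMM_OSR_Data(data_root, osr_json_path, modality_cfg)
    # query_loader = DataLoader(query_set, batch_size=32, shuffle=False)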
44 changes: 44 additions & 0 deletions feature_extraction/mn40_shot.yaml
@@ -0,0 +1,44 @@

mark: mn40_shot_10

arch:
  pretrained: true
  n_view: 12
  img_size: 224
  ft_dim: 512

# 1080Ti
# n_view  bz
# 12      32
batch_size: 64

# train
lr: 0.001
momentum: 0.9
weight_decay: 5e-4
max_epoch: 80
val_interval: 5
cos:
  T_max: 40
  eta_min: 1e-6
es:
  patience: 10
  threshold: 0.005

n_class: 40
n_worker: 16

path:
  wd: /home2/fengyifan/code/OSR/Extract-Feature
  data_root: /home2/fengyifan/data/modelnet/40/ModelNet40_MM
  split: /home2/fengyifan/code/OSR/Extract-Feature/splits/mn40__level_0__t10v10.json
  # outputs
  out: ${path.wd}/cache
  res: ${path.out}/${mark}

hydra:
  run:
    dir: ${path.res}/${now:%Y-%m-%d}_${now:%H-%M-%S}
  sweep:
    dir: ${path.res}/${now:%Y-%m-%d}_${now:%H-%M-%S}
    subdir: ${hydra.job.num}
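
For reference, a Hydra-based training script typically consumes a config like `mn40_shot.yaml` through `hydra.main`, so every field above can be overridden from the command line or swept with `--multirun`. The entry point below is a generic Hydra 1.3 sketch under that assumption; the actual script name and the exact fields it reads are not documented in this commit:
```python
import hydra
from omegaconf import DictConfig, OmegaConf

# Generic Hydra 1.3 entry-point sketch (assumed usage, not this repository's exact training script).
@hydra.main(version_base=None, config_path=".", config_name="mn40_shot")
def main(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))     # fully resolved config mirroring mn40_shot.yaml
    print(cfg.lr, cfg.max_epoch)      # e.g. 0.001, 80

if __name__ == "__main__":
    main()
```
A sweep such as `python train.py lr=0.01,0.001 --multirun` would then launch one run per value, each in its own subdirectory under the `hydra.sweep.dir` configured above.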
