Commit 0786495

fix various bugs

alexbarghi-nv committed Dec 5, 2024
1 parent 4587bd9 commit 0786495
Showing 4 changed files with 75 additions and 13 deletions.
5 changes: 3 additions & 2 deletions python/cugraph-pyg/cugraph_pyg/loader/link_loader.py
@@ -129,8 +129,9 @@ def __init__(
         )

         # Note reverse of standard convention here
-        edge_label_index[0] += data[1]._vertex_offsets[input_type[0]]
-        edge_label_index[1] += data[1]._vertex_offsets[input_type[2]]
+        if input_type is not None:
+            edge_label_index[0] += data[1]._vertex_offsets[input_type[0]]
+            edge_label_index[1] += data[1]._vertex_offsets[input_type[2]]

         self.__input_data = torch_geometric.sampler.EdgeSamplerInput(
             input_id=torch.arange(
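The new guard matters because the loader uses _vertex_offsets to translate per-type vertex IDs into the graph's single global ID space; when no edge type is supplied (input_type is None), there is no per-type offset to apply. Below is a minimal sketch of the shift with invented offsets rather than the real GraphStore internals:

import torch

# Invented per-type vertex offsets: authors occupy global IDs 0-3,
# papers occupy global IDs 4-9 (illustration only).
vertex_offsets = {"author": 0, "paper": 4}

# An ("author", "writes", "paper") edge list given with per-type (local) IDs.
edge_label_index = torch.tensor([[0, 1, 2], [0, 1, 5]])
input_type = ("author", "writes", "paper")

# The same shift the loader performs: row 0 is offset by the edge type's
# first node type, row 1 by its last node type.
if input_type is not None:
    edge_label_index[0] += vertex_offsets[input_type[0]]
    edge_label_index[1] += vertex_offsets[input_type[2]]

print(edge_label_index)  # tensor([[0, 1, 2], [4, 5, 9]])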
4 changes: 2 additions & 2 deletions python/cugraph-pyg/cugraph_pyg/loader/node_loader.py
@@ -109,8 +109,8 @@ def __init__(
             input_nodes,
             input_id,
         )
-
-        input_nodes += data[1]._vertex_offsets[input_type]
+        if input_type is not None:
+            input_nodes += data[1]._vertex_offsets[input_type]

         self.__input_data = torch_geometric.sampler.NodeSamplerInput(
             input_id=torch.arange(len(input_nodes), dtype=torch.int64, device="cuda")
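The same idea applies on the node side: typed seed nodes are shifted by their type's offset, while an untyped (None) seed tensor is presumably already expressed in global IDs and must be left alone, which is what the new guard ensures. A small standalone illustration with invented offsets, not the actual GraphStore state:

import torch

vertex_offsets = {"author": 0, "paper": 4}  # invented for illustration

def to_global(input_nodes, input_type, offsets):
    # Mirror of the guarded shift above: only typed inputs need an offset.
    if input_type is not None:
        input_nodes = input_nodes + offsets[input_type]
    return input_nodes

print(to_global(torch.tensor([0, 1]), "paper", vertex_offsets))  # tensor([4, 5])
print(to_global(torch.tensor([4, 5]), None, vertex_offsets))     # tensor([4, 5])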
18 changes: 9 additions & 9 deletions python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
@@ -189,11 +189,11 @@ def __next__(self):
         self.__raw_sample_data, start_inclusive, end_inclusive = next(
             self.__base_reader
         )
-        print(self.__raw_sample_data)

         lho_name = (
             "label_type_hop_offsets"
             if "label_type_hop_offsets" in self.__raw_sample_data
-            else "label_type_hop_offsets"
+            else "label_hop_offsets"
         )

         self.__raw_sample_data["input_offsets"] -= self.__raw_sample_data[
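Besides dropping the debug print, this hunk fixes a fallback that could never fire: both branches of the old conditional expression returned "label_type_hop_offsets", so sample data carrying only "label_hop_offsets" would later fail with a KeyError. A tiny sketch of the intended key selection, using hypothetical output dicts (the real sampler output has many more fields):

# Hypothetical outputs: one carries per-edge-type offsets, the other only
# per-hop offsets (key names taken from the diff, values invented).
hetero_out = {"label_type_hop_offsets": [0, 2, 5]}
homo_out = {"label_hop_offsets": [0, 2, 5]}

def pick_offsets_key(raw_sample_data):
    # Prefer the per-edge-type offsets when present, otherwise fall back.
    return (
        "label_type_hop_offsets"
        if "label_type_hop_offsets" in raw_sample_data
        else "label_hop_offsets"
    )

assert pick_offsets_key(hetero_out) == "label_type_hop_offsets"
assert pick_offsets_key(homo_out) == "label_hop_offsets"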
@@ -279,9 +279,6 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
         for etype in range(num_edge_types):
             pyg_can_etype = self.__edge_types[etype]

-            print(raw_sample_data["map"])
-            print(raw_sample_data["renumber_map_offsets"])
-
             jx = self.__src_types[etype] + index * self.__num_vertex_types
             map_ptr_src_beg = raw_sample_data["renumber_map_offsets"][jx]
             map_ptr_src_end = raw_sample_data["renumber_map_offsets"][jx + 1]
@@ -306,20 +303,23 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
             edge_ptr_end = (
                 index * num_edge_types * fanout_length + (etype + 1) * fanout_length
             )
-            lho = raw_sample_data["label_type_hop_offsets"][edge_ptr_beg:edge_ptr_end]
+            lho = raw_sample_data["label_type_hop_offsets"][
+                edge_ptr_beg : edge_ptr_end + 1
+            ]

             num_sampled_edges[pyg_can_etype] = (lho).diff().cpu()

-            eid_i = raw_sample_data["edge_id"][edge_ptr_beg:edge_ptr_end]
+            eid_i = raw_sample_data["edge_id"][lho[0] : lho[-1]]

             eirx = (index * num_edge_types) + etype
             edge_id_ptr_beg = raw_sample_data["edge_renumber_map_offsets"][eirx]
             edge_id_ptr_end = raw_sample_data["edge_renumber_map_offsets"][eirx + 1]

             emap = raw_sample_data["edge_renumber_map"][edge_id_ptr_beg:edge_id_ptr_end]
             edge[pyg_can_etype] = emap[eid_i]

-            col[pyg_can_etype] = raw_sample_data["majors"][edge_ptr_beg:edge_ptr_end]
-            row[pyg_can_etype] = raw_sample_data["minors"][edge_ptr_beg:edge_ptr_end]
+            col[pyg_can_etype] = raw_sample_data["majors"][lho[0] : lho[-1]]
+            row[pyg_can_etype] = raw_sample_data["minors"][lho[0] : lho[-1]]

         num_sampled_nodes = {}

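The indexing changes in __decode_coo are easier to see on a toy example: an offsets tensor with k+1 entries describes k per-hop edge ranges, so the slice that feeds .diff() needs the trailing + 1, and the edge payloads live at the absolute positions lho[0]:lho[-1] given by the offset values, not at the offset indices themselves. A standalone illustration with made-up numbers:

import torch

# Made-up flattened sampler output: six edges split across two hops for a
# single edge type of a single batch.
edge_id = torch.tensor([10, 11, 12, 13, 14, 15])
label_type_hop_offsets = torch.tensor([0, 2, 6])  # hop 0 -> [0, 2), hop 1 -> [2, 6)

edge_ptr_beg, edge_ptr_end = 0, 2  # this edge type's span in the offsets array

# The old slice [0:2] dropped the final boundary; keeping one extra element
# preserves every hop boundary for this edge type.
lho = label_type_hop_offsets[edge_ptr_beg : edge_ptr_end + 1]  # tensor([0, 2, 6])

# Per-hop edge counts are adjacent differences of the offsets.
print(lho.diff())  # tensor([2, 4])

# Edge IDs (and majors/minors) are addressed by the offset values.
print(edge_id[lho[0] : lho[-1]])  # tensor([10, 11, 12, 13, 14, 15])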
@@ -194,3 +194,64 @@ def test_link_neighbor_loader_negative_sampling_uneven(batch_size):
     elx = torch.tensor_split(elx, eix.numel() // batch_size, dim=1)
     for i, batch in enumerate(loader):
         assert batch.edge_label[0] == 1.0
+
+
+@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
+@pytest.mark.sg
+def test_neighbor_loader_hetero_basic():
+    src = torch.tensor([0, 1, 2, 4, 3, 4, 5, 5])  # paper
+    dst = torch.tensor([4, 5, 4, 3, 2, 1, 0, 1])  # paper
+
+    asrc = torch.tensor([0, 1, 2, 3, 3, 0])  # author
+    adst = torch.tensor([0, 1, 2, 3, 4, 5])  # paper
+
+    graph_store = GraphStore()
+    feature_store = TensorDictFeatureStore()
+
+    graph_store[("paper", "cites", "paper"), "coo"] = [src, dst]
+    graph_store[("author", "writes", "paper"), "coo"] = [asrc, adst]
+
+    from cugraph_pyg.loader import NeighborLoader
+
+    loader = NeighborLoader(
+        (feature_store, graph_store),
+        num_neighbors=[1, 1, 1, 1],
+        input_nodes=("paper", torch.tensor([0, 1])),
+        batch_size=2,
+    )
+
+    out = next(iter(loader))
+
+    assert sorted(out["paper"].n_id.tolist()) == [0, 1, 4, 5]
+    assert sorted(out["author"].n_id.tolist()) == [0, 1, 3]
+
+
+@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
+@pytest.mark.sg
+def test_neighbor_loader_hetero_single_etype():
+    src = torch.tensor([0, 1, 2, 4, 3, 4, 5, 5])  # paper
+    dst = torch.tensor([4, 5, 4, 3, 2, 1, 0, 1])  # paper
+
+    asrc = torch.tensor([0, 1, 2, 3, 3, 0])  # author
+    adst = torch.tensor([0, 1, 2, 3, 4, 5])  # paper
+
+    graph_store = GraphStore()
+    feature_store = TensorDictFeatureStore()
+
+    graph_store[("paper", "cites", "paper"), "coo"] = [src, dst]
+    graph_store[("author", "writes", "paper"), "coo"] = [asrc, adst]
+
+    from cugraph_pyg.loader import NeighborLoader
+
+    loader = NeighborLoader(
+        (feature_store, graph_store),
+        num_neighbors=[0, 1, 0, 1],
+        input_nodes=("paper", torch.tensor([0, 1])),
+        batch_size=2,
+    )
+
+    out = next(iter(loader))
+
+    assert out["author"].n_id.numel() == 0
+    assert out["author", "writes", "paper"].edge_index.numel() == 0
+    assert out["author", "writes", "paper"].num_sampled_edges.tolist() == [0, 0]
