Skip to content

Commit

Permalink
Fixing remaining python test
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Nov 15, 2024
1 parent adda5fc commit c4b8590
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions python/pylibraft/pylibraft/test/test_handle.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,27 +17,37 @@
import pytest

from pylibraft.common import DeviceResources, Stream, device_ndarray
from pylibraft.distance import pairwise_distance
from pylibraft.random import rmat

cupy = pytest.importorskip("cupy")


@pytest.mark.parametrize("stream", [cupy.cuda.Stream().ptr, Stream()])
def test_handle_external_stream(stream):
def generate_theta(r_scale, c_scale):
max_scale = max(r_scale, c_scale)
theta = np.random.random_sample(max_scale * 4)
for i in range(max_scale):
a = theta[4 * i]
b = theta[4 * i + 1]
c = theta[4 * i + 2]
d = theta[4 * i + 3]
total = a + b + c + d
theta[4 * i] = a / total
theta[4 * i + 1] = b / total
theta[4 * i + 2] = c / total
theta[4 * i + 3] = d / total
theta_device = device_ndarray(theta)
return theta, theta_device

input1 = np.random.random_sample((50, 3))
input1 = np.asarray(input1, order="F").astype("float")

output = np.zeros((50, 50), dtype="float")
@pytest.mark.parametrize("stream", [cupy.cuda.Stream().ptr, Stream()])
def test_handle_external_stream(stream):

input1_device = device_ndarray(input1)
output_device = device_ndarray(output)
theta, theta_device = generate_theta(16, 16)
out_buff = np.empty((1000, 2), dtype=np.int32)
output_device = device_ndarray(out_buff)

# We are just testing that this doesn't segfault
handle = DeviceResources(stream)
pairwise_distance(
input1_device, input1_device, output_device, "euclidean", handle=handle
)
handle = DeviceResources()
rmat(output_device, theta_device, 16, 16, 12345, handle=handle)
handle.sync()

with pytest.raises(ValueError):
Expand Down

0 comments on commit c4b8590

Please sign in to comment.