Skip to content

Commit

Permalink
MOre updates
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Nov 8, 2024
1 parent 160e913 commit 2a28f50
Show file tree
Hide file tree
Showing 17 changed files with 82 additions and 1,252 deletions.
72 changes: 72 additions & 0 deletions cpp/include/raft/matrix/detail/select_k-ext.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_resources.hpp>
#include <raft/matrix/select_k_types.hpp>
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

#include <cuda_fp16.h> // __half

#include <cstdint> // uint32_t

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

namespace raft::matrix::detail {

template <typename T, typename IdxT>
void select_k(raft::resources const& handle,
const T* in_val,
const IdxT* in_idx,
size_t batch_size,
size_t len,
int k,
T* out_val,
IdxT* out_idx,
bool select_min,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto,
const IdxT* len_i = nullptr) RAFT_EXPLICIT;
} // namespace raft::matrix::detail

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
extern template void raft::matrix::detail::select_k(raft::resources const& handle, \
const T* in_val, \
const IdxT* in_idx, \
size_t batch_size, \
size_t len, \
int k, \
T* out_val, \
IdxT* out_idx, \
bool select_min, \
bool sorted, \
raft::matrix::SelectAlgo algo, \
const IdxT* len_i)
instantiate_raft_matrix_detail_select_k(__half, uint32_t);
instantiate_raft_matrix_detail_select_k(__half, int64_t);
instantiate_raft_matrix_detail_select_k(float, int64_t);
instantiate_raft_matrix_detail_select_k(float, uint32_t);
// needed for brute force knn
instantiate_raft_matrix_detail_select_k(float, int);
// We did not have these two for double before, but there are tests for them. We
// therefore include them here.
instantiate_raft_matrix_detail_select_k(double, int64_t);
instantiate_raft_matrix_detail_select_k(double, uint32_t);

#undef instantiate_raft_matrix_detail_select_k
10 changes: 4 additions & 6 deletions cpp/include/raft/matrix/detail/select_k.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,10 +15,8 @@
*/
#pragma once

#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include "select_k-inl.cuh"
#endif

#ifdef RAFT_COMPILED
#include "select_k-ext.cuh"
#endif
// #ifdef RAFT_COMPILED
// #include "select_k-ext.cuh"
// #endif
96 changes: 0 additions & 96 deletions cpp/include/raft_runtime/cluster/kmeans.hpp

This file was deleted.

62 changes: 0 additions & 62 deletions cpp/include/raft_runtime/distance/fused_distance_nn.hpp

This file was deleted.

67 changes: 0 additions & 67 deletions cpp/include/raft_runtime/distance/fused_l2_nn.hpp

This file was deleted.

50 changes: 0 additions & 50 deletions cpp/include/raft_runtime/distance/pairwise_distance.hpp

This file was deleted.

Loading

0 comments on commit 2a28f50

Please sign in to comment.