From 33487be85af451943b28cfc22ac1d6d423e9eba0 Mon Sep 17 00:00:00 2001 From: Chris Nuernberger Date: Wed, 11 Sep 2024 11:38:01 -0600 Subject: [PATCH] Support for neanderthal 0.49 --- deps.edn | 1 + src/tech/v3/datatype/native_buffer.clj | 27 ++++++++++------ src/tech/v3/libs/neanderthal.clj | 41 ++++++++++++++++++++----- src/tech/v3/libs/neanderthal_pre_48.clj | 28 ----------------- 4 files changed, 53 insertions(+), 44 deletions(-) diff --git a/deps.edn b/deps.edn index 60b9bd0e..9249f969 100644 --- a/deps.edn +++ b/deps.edn @@ -74,6 +74,7 @@ {:extra-deps {org.clojure/clojure {:mvn/version "1.11.1"} criterium/criterium {:mvn/version "0.4.5"} net.java.dev.jna/jna {:mvn/version "5.12.1"} + ;; uncomplicate/neanderthal {:mvn/version "0.45.0"} uncomplicate/neanderthal {:mvn/version "0.49.1"} org.bytedeco/mkl {:mvn/version "2024.0-1.5.10"} com.taoensso/nippy {:mvn/version "3.2.0"} diff --git a/src/tech/v3/datatype/native_buffer.clj b/src/tech/v3/datatype/native_buffer.clj index dbbf8121..ecedc1f4 100644 --- a/src/tech/v3/datatype/native_buffer.clj +++ b/src/tech/v3/datatype/native_buffer.clj @@ -720,15 +720,24 @@ (defn set-native-datatype "Set the datatype of a native buffer. n-elems will be recalculated." ^NativeBuffer [item datatype] - (let [nb (as-native-buffer item) - original-size (.n-elems nb) - n-bytes (* original-size (casting/numeric-byte-width - (dtype-proto/elemwise-datatype item))) - new-byte-width (casting/numeric-byte-width - (casting/un-alias-datatype datatype))] - (NativeBuffer. (.address nb) (quot n-bytes new-byte-width) - datatype (.endianness nb) - (.resource-type nb) (meta nb) nil item))) + (if (= datatype (dtype-proto/elemwise-datatype item)) + item + (let [nb (as-native-buffer item) + original-size (.n-elems nb) + n-bytes (* original-size (casting/numeric-byte-width + (dtype-proto/elemwise-datatype item))) + new-byte-width (casting/numeric-byte-width + (casting/un-alias-datatype datatype))] + (NativeBuffer. (.address nb) (quot n-bytes new-byte-width) + datatype (.endianness nb) + (.resource-type nb) (meta nb) nil item)))) + + +(defn set-gc-obj + ^NativeBuffer [^NativeBuffer nb gc-obj] + (NativeBuffer. (.address nb) (.n-elems nb) + (.elemwise-datatype nb) (.endianness nb) + (.resource-type nb) (meta nb) nil gc-obj)) (defn set-parent diff --git a/src/tech/v3/libs/neanderthal.clj b/src/tech/v3/libs/neanderthal.clj index 27924c3d..05528618 100644 --- a/src/tech/v3/libs/neanderthal.clj +++ b/src/tech/v3/libs/neanderthal.clj @@ -1,16 +1,43 @@ -(ns tech.v3.neanderthal) +(ns tech.v3.neanderthal + (:require [uncomplicate.neanderthal.native :as n-native] + [tech.v3.datatype.errors :as errors] + [tech.v3.tensor :as dtt] + [tech.v3.datatype.copy-make-container :as dt-cmc] + [tech.v3.datatype.base :as dtype-base] + [clojure.tools.logging :as log])) -(def ^:private impl - {:tensor->matrix (requiring-resolve 'tech.v3.libs.neanderthal-pre-48/tensor->matrix) - :dtype->native-factory (requiring-resolve 'tech.v3.libs.neanderthal-pre-48/datatype->native-factory)}) +(try + ;;test for uncomplicate javacpp support + (require '[uncomplicate.clojure-cpp]) + (require '[tech.v3.libs.neanderthal-post-48]) + (catch Exception e + (require '[tech.v3.libs.neanderthal-pre-48]))) (defn tensor->matrix - ([tens] ((get impl :tensor->matrix) tens)) ([tens layout datatype] - ((get impl :tensor->matrix) tens layout datatype))) + (let [tshape (dtype-base/shape tens) + _ (errors/when-not-errorf (== 2 (count tshape)) + "Only 2D tensors can transform to neanderthal matrix") + [n-rows n-cols] tshape + layout (or layout :column) + nmat (case (or datatype (dtype-base/elemwise-datatype tens)) + :float64 + (n-native/dge n-rows n-cols {:layout layout}) + :float32 + (n-native/fge n-rows n-cols {:layout layout})) + ntens (dtt/as-tensor nmat)] + (dt-cmc/copy! tens ntens) + nmat)) + ([tens] + (tensor->matrix tens nil nil))) + (defn datatype->native-factory [dtype] - ((get impl :datatype->native-factory) dtype)) + (case dtype + :float64 + (n-native/factory-by-type Double/TYPE) + :float32 + (n-native/factory-by-type Float/TYPE))) diff --git a/src/tech/v3/libs/neanderthal_pre_48.clj b/src/tech/v3/libs/neanderthal_pre_48.clj index 302a56d1..43acce43 100644 --- a/src/tech/v3/libs/neanderthal_pre_48.clj +++ b/src/tech/v3/libs/neanderthal_pre_48.clj @@ -92,31 +92,3 @@ item)))] (-> (dtype-base/sub-buffer ptr-val item-offset) (resource/chain-resources item))))) - - -(defn tensor->matrix - ([tens layout datatype] - (let [tshape (dtype-base/shape tens) - _ (errors/when-not-errorf (== 2 (count tshape)) - "Only 2D tensors can transform to neanderthal matrix") - [n-rows n-cols] tshape - layout (or layout :column) - nmat (case (or datatype (dtype-base/elemwise-datatype tens)) - :float64 - (n-native/dge n-rows n-cols {:layout layout}) - :float32 - (n-native/fge n-rows n-cols {:layout layout})) - ntens (dtt/as-tensor nmat)] - (dt-cmc/copy! tens ntens) - nmat)) - ([tens] - (tensor->matrix tens nil nil))) - - -(defn datatype->native-factory - [dtype] - (case dtype - :float64 - (n-native/factory-by-type Double/TYPE) - :float32 - (n-native/factory-by-type Float/TYPE)))