Skip to content

Commit

Permalink
Support for neanderthal 0.49
Browse files Browse the repository at this point in the history
  • Loading branch information
cnuernber committed Sep 11, 2024
1 parent 36a8cb7 commit 33487be
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 44 deletions.
1 change: 1 addition & 0 deletions deps.edn
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
{:extra-deps {org.clojure/clojure {:mvn/version "1.11.1"}
criterium/criterium {:mvn/version "0.4.5"}
net.java.dev.jna/jna {:mvn/version "5.12.1"}
;; uncomplicate/neanderthal {:mvn/version "0.45.0"}
uncomplicate/neanderthal {:mvn/version "0.49.1"}
org.bytedeco/mkl {:mvn/version "2024.0-1.5.10"}
com.taoensso/nippy {:mvn/version "3.2.0"}
Expand Down
27 changes: 18 additions & 9 deletions src/tech/v3/datatype/native_buffer.clj
Original file line number Diff line number Diff line change
Expand Up @@ -720,15 +720,24 @@
(defn set-native-datatype
"Set the datatype of a native buffer. n-elems will be recalculated."
^NativeBuffer [item datatype]
(let [nb (as-native-buffer item)
original-size (.n-elems nb)
n-bytes (* original-size (casting/numeric-byte-width
(dtype-proto/elemwise-datatype item)))
new-byte-width (casting/numeric-byte-width
(casting/un-alias-datatype datatype))]
(NativeBuffer. (.address nb) (quot n-bytes new-byte-width)
datatype (.endianness nb)
(.resource-type nb) (meta nb) nil item)))
(if (= datatype (dtype-proto/elemwise-datatype item))
item
(let [nb (as-native-buffer item)
original-size (.n-elems nb)
n-bytes (* original-size (casting/numeric-byte-width
(dtype-proto/elemwise-datatype item)))
new-byte-width (casting/numeric-byte-width
(casting/un-alias-datatype datatype))]
(NativeBuffer. (.address nb) (quot n-bytes new-byte-width)
datatype (.endianness nb)
(.resource-type nb) (meta nb) nil item))))


(defn set-gc-obj
^NativeBuffer [^NativeBuffer nb gc-obj]
(NativeBuffer. (.address nb) (.n-elems nb)
(.elemwise-datatype nb) (.endianness nb)
(.resource-type nb) (meta nb) nil gc-obj))


(defn set-parent
Expand Down
41 changes: 34 additions & 7 deletions src/tech/v3/libs/neanderthal.clj
Original file line number Diff line number Diff line change
@@ -1,16 +1,43 @@
(ns tech.v3.neanderthal)
(ns tech.v3.neanderthal
(:require [uncomplicate.neanderthal.native :as n-native]
[tech.v3.datatype.errors :as errors]
[tech.v3.tensor :as dtt]
[tech.v3.datatype.copy-make-container :as dt-cmc]
[tech.v3.datatype.base :as dtype-base]
[clojure.tools.logging :as log]))


(def ^:private impl
{:tensor->matrix (requiring-resolve 'tech.v3.libs.neanderthal-pre-48/tensor->matrix)
:dtype->native-factory (requiring-resolve 'tech.v3.libs.neanderthal-pre-48/datatype->native-factory)})
(try
;;test for uncomplicate javacpp support
(require '[uncomplicate.clojure-cpp])
(require '[tech.v3.libs.neanderthal-post-48])
(catch Exception e
(require '[tech.v3.libs.neanderthal-pre-48])))


(defn tensor->matrix
([tens] ((get impl :tensor->matrix) tens))
([tens layout datatype]
((get impl :tensor->matrix) tens layout datatype)))
(let [tshape (dtype-base/shape tens)
_ (errors/when-not-errorf (== 2 (count tshape))
"Only 2D tensors can transform to neanderthal matrix")
[n-rows n-cols] tshape
layout (or layout :column)
nmat (case (or datatype (dtype-base/elemwise-datatype tens))
:float64
(n-native/dge n-rows n-cols {:layout layout})
:float32
(n-native/fge n-rows n-cols {:layout layout}))
ntens (dtt/as-tensor nmat)]
(dt-cmc/copy! tens ntens)
nmat))
([tens]
(tensor->matrix tens nil nil)))


(defn datatype->native-factory
[dtype]
((get impl :datatype->native-factory) dtype))
(case dtype
:float64
(n-native/factory-by-type Double/TYPE)
:float32
(n-native/factory-by-type Float/TYPE)))
28 changes: 0 additions & 28 deletions src/tech/v3/libs/neanderthal_pre_48.clj
Original file line number Diff line number Diff line change
Expand Up @@ -92,31 +92,3 @@
item)))]
(-> (dtype-base/sub-buffer ptr-val item-offset)
(resource/chain-resources item)))))


(defn tensor->matrix
([tens layout datatype]
(let [tshape (dtype-base/shape tens)
_ (errors/when-not-errorf (== 2 (count tshape))
"Only 2D tensors can transform to neanderthal matrix")
[n-rows n-cols] tshape
layout (or layout :column)
nmat (case (or datatype (dtype-base/elemwise-datatype tens))
:float64
(n-native/dge n-rows n-cols {:layout layout})
:float32
(n-native/fge n-rows n-cols {:layout layout}))
ntens (dtt/as-tensor nmat)]
(dt-cmc/copy! tens ntens)
nmat))
([tens]
(tensor->matrix tens nil nil)))


(defn datatype->native-factory
[dtype]
(case dtype
:float64
(n-native/factory-by-type Double/TYPE)
:float32
(n-native/factory-by-type Float/TYPE)))

0 comments on commit 33487be

Please sign in to comment.