From 7ddf4472bedc807d4de0afd1ec41a677d11bb594 Mon Sep 17 00:00:00 2001 From: Chris Nuernberger Date: Sun, 10 Mar 2024 08:17:18 -0600 Subject: [PATCH] Fixes for #94 --- src/tech/v3/tensor_api.clj | 13 ++++++++----- test/tech/v3/tensor/integration_test.clj | 9 +++++++++ 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/tech/v3/tensor_api.clj b/src/tech/v3/tensor_api.clj index e1c54630..ba1817a2 100644 --- a/src/tech/v3/tensor_api.clj +++ b/src/tech/v3/tensor_api.clj @@ -146,7 +146,6 @@ (apply dims/select (.dimensions t) select-args) buffer (or (.buffer t) (.bufferIO t)) buf-offset (long buf-offset) - ;; _ (println "buflen" buf-len) new-buffer (if (and buf-len (== 0 (long buf-len))) buffer (if-not (and (== buf-offset 0) @@ -215,10 +214,14 @@ (reify ObjectReader (lsize [rdr] n-offsets) (readObject [rdr idx] - (construct-tensor (dtype-base/sub-buffer - tens-buf - (.readLong offsets idx) - buf-ecount) + (construct-tensor (if buf-ecount + (dtype-base/sub-buffer + tens-buf + (.readLong offsets idx) + buf-ecount) + (dtype-base/sub-buffer + tens-buf + (.readLong offsets idx))) dimensions))))))) (mget [t idx-seq] (.ndReadObjectIter t idx-seq)) diff --git a/test/tech/v3/tensor/integration_test.clj b/test/tech/v3/tensor/integration_test.clj index bc7867f1..eb4feb98 100644 --- a/test/tech/v3/tensor/integration_test.clj +++ b/test/tech/v3/tensor/integration_test.clj @@ -289,3 +289,12 @@ (is (= [0 10] (dtype/shape (dtt/select a 2 :lla)))) (is (= [0] (dtype/shape (dtt/select a 2 :all 4)))) (is (= [2 0 3] (dtype/shape (dtt/select a (range 2 4) :all (range 6 9))))))) + + +(deftest issue-94 + (let [n-neurons 10 + weights (dtt/clone (dtt/compute-tensor [n-neurons n-neurons] (fn [_ _] (< (double (rand)) 0.1)) :boolean)) + activations (dtt/->tensor (range n-neurons)) + ww (dtt/select weights (range n-neurons)) + answer (dtt/reduce-axis (dtt/select weights (dtt/->tensor (range n-neurons))) dfn/sum 0 :float32)] + (is (not (nil? (.toString ^Object answer))))))