tensordot and reshape bug (#37)
christopherzimmerman authored Jul 19, 2020
1 parent b734c00 commit 5a7632a
Showing 2 changed files with 116 additions and 45 deletions.
94 changes: 94 additions & 0 deletions src/tensor/linalg.cr
@@ -569,6 +569,100 @@ class Tensor(T)
    dest
  end

  # Compute the tensor dot product along specified axes.
  #
  # Given two tensors, `a` and `b`, and an array containing two arrays,
  # `(a_axes, b_axes)`, sum the products of `a`'s and `b`'s elements
  # (components) over the axes specified by `a_axes` and `b_axes`. The
  # `axes` argument can also be a single non-negative integer, `N`; in
  # that case the last `N` dimensions of `a` and the first `N`
  # dimensions of `b` are summed over.
  #
  # Arguments
  # ---------
  # *b* : Tensor
  #   Right hand side of the dot products
  # *axes* : Array(Array(Int)) | Array(Int) | Int
  #   Axes of summation
  #
  # Examples
  # --------
  # ```
  # a = Tensor.range(60.0).reshape(3, 4, 5)
  # b = Tensor.range(24.0).reshape(4, 3, 2)
  # puts a.tensordot(b, axes: [[1, 0], [0, 1]])
  #
  # # [[4400, 4730],
  # #  [4532, 4874],
  # #  [4664, 5018],
  # #  [4796, 5162],
  # #  [4928, 5306]]
  # ```
  def tensordot(b : Tensor(T), axes : Array(Array(Int)))
    axes_a, axes_b = axes
    na = axes_a.size
    nb = axes_b.size
    as_ = self.shape
    nda = self.rank
    bs = b.shape
    ndb = b.rank
    equal = na == nb
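    # Check that each pair of contracted axes has matching extents,
    # normalizing negative axis indices along the way.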
    na.times do |k|
      if as_[axes_a[k]] != bs[axes_b[k]]
        equal = false
        break
      end
      if axes_a[k] < 0
        axes_a[k] += nda
      end
      if axes_b[k] < 0
        axes_b[k] += ndb
      end
    end
    unless equal
      raise Num::Internal::ShapeError.new("Shape mismatch for sum")
    end
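    # Move the free (non-contracted) axes of `self` to the front and
    # collapse the contracted axes into a single trailing dimension.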
    notin = (0...nda).select do |k|
      !axes_a.includes?(k)
    end
    newaxes_a = notin + axes_a
    n2 = 1
    axes_a.each do |axis|
      n2 *= as_[axis]
    end
    newshape_a = [(notin.map { |ax| as_[ax] }).product, n2]
    olda = notin.map { |ax| as_[ax] }

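    # Same for `b`, but its contracted axes come first so the collapsed
    # dimensions line up for the matrix product.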
    notin = (0...ndb).select do |k|
      !axes_b.includes?(k)
    end
    newaxes_b = axes_b + notin
    n2 = 1
    axes_b.each do |axis|
      n2 *= bs[axis]
    end
    newshape_b = [n2, (notin.map { |ax| bs[ax] }).product]
    oldb = notin.map { |ax| bs[ax] }

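    # The contraction is now an ordinary 2-D matrix product; the final
    # reshape restores the free axes of both operands.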
    at = self.transpose(newaxes_a).reshape(newshape_a)
    bt = b.transpose(newaxes_b).reshape(newshape_b)
    res = at.matmul(bt)
    res.reshape(olda + oldb)
  end

  # :ditto:
  def tensordot(b : Tensor(T), axes : Int)
    axes_a = (-axes...0).to_a
    axes_b = (0...axes).to_a
    self.tensordot(b, [axes_a, axes_b])
  end

  # :ditto:
  def tensordot(b : Tensor(T), axes : Array(Int))
    axes_a, axes_b = axes
    self.tensordot(b, [[axes_a], [axes_b]])
  end

  # :nodoc:
  def is_matrix
    unless self.rank == 2
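
For a concrete sense of what the new method does, here is the doc example traced by hand through the same transpose/reshape/matmul pipeline. This is a sketch against the `Tensor` API used in this diff (`range`, `transpose`, `reshape`, `matmul`), not an additional method:

```
a = Tensor.range(60.0).reshape(3, 4, 5)
b = Tensor.range(24.0).reshape(4, 3, 2)

# axes: [[1, 0], [0, 1]] pairs a's axes (1, 0) with b's axes (0, 1),
# both of extent (4, 3). The free axes are a's axis 2 (extent 5) and
# b's axis 2 (extent 2).
at = a.transpose([2, 1, 0]).reshape([5, 12]) # free axis first, 4 * 3 = 12 summed
bt = b.transpose([0, 1, 2]).reshape([12, 2]) # 12 summed first, free axis last

# (5, 12) matmul (12, 2) => (5, 2), the shape of the documented output.
puts at.matmul(bt)
```

The `axes : Int` overload builds these axis lists automatically: `axes: N` contracts the last `N` axes of `self` against the first `N` axes of `b`.
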
67 changes: 22 additions & 45 deletions src/tensor/tensor.cr
@@ -1217,64 +1217,41 @@ class Tensor(T)
   # # [3, 4]]
   # ```
   def reshape(new_shape : Array(Int))
-    result_shape = new_shape.map &.to_i
-
-    if result_shape == @shape
+    newshape = new_shape.map &.to_i
+    if newshape == shape
       return self.view
     end
-
-    n = 1
-    c = @size
-    auto = -1
-
-    result_shape.each_with_index do |v, i|
-      if v < 0
-        if auto >= 0
-          raise Num::Internal::ValueError.new(
-            "Only a single dimension can be inferred"
-          )
+    newsize = 1
+    cur_size = size
+    autosize = -1
+    newshape.each_with_index do |val, i|
+      if val < 0
+        if autosize >= 0
+          raise Num::Internal::ValueError.new("Only shape dimension can be automatic")
         end
-        auto = i
+        autosize = i
       else
-        n *= v
+        newsize *= val
       end
     end
 
-    if auto >= 0
-      result_shape = result_shape.dup
-      result_shape[auto] = c // n
-      n *= result_shape[auto]
+    if autosize >= 0
+      newshape = newshape.dup
+      newshape[autosize] = cur_size // newsize
+      newsize *= newshape[autosize]
     end
 
-    if n != c
-      raise Num::Internal::ShapeError.new(
-        "Shape #{@shape} cannot be reshaped to #{result_shape}"
-      )
+    if newsize != cur_size
+      raise Num::Internal::ShapeError.new "Shapes #{shape} cannot be reshaped to #{newshape}"
     end
 
+    newstrides = Num::Internal.shape_to_strides(newshape, Num::RowMajor)
+
     if @flags.contiguous?
-      new_strides = Num::Internal.shape_to_strides(
-        result_shape,
-        Num::RowMajor
-      )
-      t = Tensor(T).new(@buffer, result_shape, new_strides)
-      t.flags &= ~Num::ArrayFlags::OwnData
-      t
-    elsif @flags.fortran?
-      new_strides = Num::Internal.shape_to_strides(
-        result_shape,
-        Num::ColMajor
-      )
-      t = Tensor(T).new(@buffer, result_shape, new_strides)
-      t.flags &= ~Num::ArrayFlags::OwnData
-      t
+      self.class.new(@buffer, newshape, newstrides)
     else
-      t = dup(Num::ColMajor)
-      new_strides = Num::Internal.shape_to_strides(
-        result_shape,
-        Num::ColMajor
-      )
-      t
+      tmp = self.dup(Num::RowMajor)
+      self.class.new(tmp.to_unsafe, newshape, newstrides)
     end
   end

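
The reshape bug fixed here is visible in the removed `else` branch: it computed strides for the new shape but returned the bare `dup` untouched, so reshaping a non-contiguous tensor silently handed back the original shape. The rewritten method copies to row-major order first and builds the reshaped result explicitly. A small usage sketch, assuming the `Tensor` API shown in these diffs (`range`, `reshape`, and argument-free `transpose` reversing the axes):

```
a = Tensor.range(6.0).reshape(2, 3)

# A -1 entry marks the single dimension to infer: 6 // 3 = 2.
puts a.reshape([3, -1]).shape # => [3, 2]

# A transposed view is non-contiguous, so it takes the fixed branch:
# the data is duplicated in row-major order before re-striding.
t = a.transpose             # shape [3, 2], strided view
puts t.reshape([6]).shape   # => [6]; previously the shape came back unchanged
```
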
