Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed May 20, 2024
1 parent a1a80a1 commit 3a9244c
Showing 1 changed file with 32 additions and 30 deletions.
62 changes: 32 additions & 30 deletions NDTensors/ext/NDTensorsCUDAExt/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,41 @@ using CUDA: CuArray
## In this function we convert the DiagTensor to a dense tensor and
## Feed it back into contract
function NDTensors.contract!(
output_tensor::Exposed{<:CuArray, <:DenseTensor},
output_tensor::Exposed{<:CuArray,<:DenseTensor},
labelsoutput_tensor,
tensor1::Exposed{<:Any,<:DiagTensor},
labelstensor1,
tensor2::Exposed{<:CuArray,<:DenseTensor},
labelstensor2,
α::Number=one(Bool),
β::Number=zero(Bool),
)
tensor1 = unexpose(tensor1)
## convert tensor1 to a dense
tensor1 = adapt(parenttype(typeof(tensor2)), dense(tensor1))
return contract!(
output_tensor,
labelsoutput_tensor,
tensor1::Exposed{<:Any, <:DiagTensor},
expose(tensor1),
labelstensor1,
tensor2::Exposed{<:CuArray, <:DenseTensor},
tensor2,
labelstensor2,
α::Number=one(Bool),
β::Number=zero(Bool),
α,
β,
)
tensor1 = unexpose(tensor1)
## convert tensor1 to a dense
tensor1 = adapt(parenttype(typeof(tensor2)), dense(tensor1))
return contract!(
output_tensor,
labelsoutput_tensor,
expose(tensor1),
labelstensor1,
tensor2,
labelstensor2,
α,
β,
)
end
end

function NDTensors.contract!(
output_tensor::Exposed{<:CuArray, <:DenseTensor},
labelsoutput_tensor,
tensor1::Exposed{<:CuArray, <:DenseTensor},
labelstensor1,
tensor2::Exposed{<:Any, <:DiagTensor},
labelstensor2,
α::Number=one(Bool),
β::Number=zero(Bool),
function NDTensors.contract!(
output_tensor::Exposed{<:CuArray,<:DenseTensor},
labelsoutput_tensor,
tensor1::Exposed{<:CuArray,<:DenseTensor},
labelstensor1,
tensor2::Exposed{<:Any,<:DiagTensor},
labelstensor2,
α::Number=one(Bool),
β::Number=zero(Bool),
)
return contract!(
output_tensor, labelsoutput_tensor, tensor2, labelstensor2, tensor1, labelstensor1, α, β
)
contract!(output_tensor, labelsoutput_tensor, tensor2, labelstensor2, tensor1, labelstensor1, α, β)
end
end

0 comments on commit 3a9244c

Please sign in to comment.