Skip to content

Commit

Permalink
Merge pull request #56 from JoeyT1994/tcifix
Browse files Browse the repository at this point in the history
Fix TCI Extension
  • Loading branch information
JoeyT1994 authored Dec 18, 2024
2 parents 0e9b061 + f3a9f86 commit e4919ec
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions ext/ITensorNumericalAnalysisTCIExt/tci_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ using ITensorNumericalAnalysis:
const_itn,
ITensorNetworkFunction,
itensornetwork,
base
base,
dimension

random_initial_pivot(s::IndsNetworkMap) = [v => rand(1:dim(v)) for v in vertices(s)]
random_initial_pivot(s::IndsNetworkMap) = [v => rand(1:dim(s[v])) for v in vertices(s)]

#f should be an ndimensional function that maps a vector of scalars of length ndimensional to a scalar
function ITensorTCI.interpolate(
Expand All @@ -37,11 +38,12 @@ function ITensorTCI.interpolate(
initial_pivot = random_initial_pivot(s_renamed)
else
# manually rename
initial_pivot = [forward_dict[v] => initial_pivot[v] for v in vertices(s)]
# assuming from calculate_ind_values
initial_pivot = [forward_dict[v] => initial_pivot[only(s[v])] + 1 for v in vertices(s)]
end

tn = ITensorTCI.interpolate(
input -> f(input_to_scalars(input; base=float(base(s)))),
input -> f(input_to_scalars(input; ndims=dimension(s), base=float(base(s)))),
ttn(itensornetwork(tn));
initial_pivot,
kwargs...,
Expand All @@ -53,8 +55,7 @@ function ITensorTCI.interpolate(
end

#Takes a vector of [(dimension, digit) => bit] and converts to vector of scalars
function input_to_scalars(input; base=2.0)
ndims = maximum(first.(input))
function input_to_scalars(input; ndims, base=2.0)
x = zeros(ndims)
for pair in input
(i, j) = pair[1]
Expand Down

0 comments on commit e4919ec

Please sign in to comment.