Skip to content

Commit

Permalink
Rename
Browse files Browse the repository at this point in the history
  • Loading branch information
infrub committed Jun 28, 2019
1 parent d4bd1fe commit f72ed97
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions netcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import itertools


class HistTensorFrame:
class TensorFrame:
"""Tensor class for netcon.
Attributes:
Expand All @@ -32,20 +32,20 @@ def __init__(self,rpn=[],bits=0,bonds=[],cost=0.0,is_new=True):
self.is_new = is_new

def __repr__(self):
return "HistTensorFrame({0}, bonds={1}, cost={2:.6e}, bits={3}, is_new={4})".format(
return "TensorFrame({0}, bonds={1}, cost={2:.6e}, bits={3}, is_new={4})".format(
self.rpn, self.bonds, self.cost, self.bits, self.is_new)

def __str__(self):
return "{0} : bonds={1} cost={2:.6e} bits={3} new={4}".format(
self.rpn, self.bonds, self.cost, self.bits, self.is_new)


class NetconClass:
class NetconOptimizer:
def __init__(self, prime_tensors, bond_dims):
self.prime_tensors = prime_tensors
self.BOND_DIMS = bond_dims[:]

def calc(self):
def optimize(self):
"""Find optimal contraction sequence.
Args:
Expand Down Expand Up @@ -106,7 +106,7 @@ def init_tensordict_of_size(self):
if i>=0: bits += (1<<i)
bonds = frozenset(t.bonds)
cost = 0.0
tensordict_of_size[1].update({bits:HistTensorFrame(rpn,bits,bonds,cost)})
tensordict_of_size[1].update({bits:TensorFrame(rpn,bits,bonds,cost)})
return tensordict_of_size


Expand All @@ -126,7 +126,7 @@ def contract(self,t1,t2):
bits = t1.bits ^ t2.bits # XOR
bonds = frozenset(t1.bonds ^ t2.bonds)
cost = self.get_contracting_cost(t1,t2)
return HistTensorFrame(rpn,bits,bonds,cost)
return TensorFrame(rpn,bits,bonds,cost)


def are_direct_product(self,t1,t2):
Expand Down
2 changes: 1 addition & 1 deletion tdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def parse_args():
logging.basicConfig(format="%(levelname)s:%(message)s", level=config.LOGGING_LEVEL)

tn.output_log("input")
rpn, cpu = netcon.NetconClass(tn.tensors, BOND_DIMS).calc()
rpn, cpu = netcon.NetconOptimizer(tn.tensors, BOND_DIMS).optimize()
mem = get_memory(tn, rpn)

TENSOR_MATH_NAMES = TENSOR_NAMES[:]
Expand Down

0 comments on commit f72ed97

Please sign in to comment.