forked from gcross/numpy-contractor-utils
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
237 lines (213 loc) · 10.8 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
#@+leo-ver=4-thin
#@+node:gcross.20100923134429.1858:@thin utils.py
#@@language Python
#@<< Import needed modules >>
#@+node:gcross.20100923134429.1899:<< Import needed modules >>
import __builtin__
from numpy import tensordot, multiply
from numpy.random import rand
#@-node:gcross.20100923134429.1899:<< Import needed modules >>
#@nl
#@+others
#@+node:gcross.20100923134429.1900:Functions
#@+node:gcross.20100923134429.1869:n2l
#@+at
# Lookup table converting tensor positions (0, 1, 2, ...) to the letters 'A', 'B', 'C', ... used as parameter names in generated code.
#@-at
#@@c
# Lookup table mapping a tensor's position in a contractor's argument list to
# the name of the matching parameter in the generated code: 0 -> 'A',
# 1 -> 'B', ..., 25 -> 'Z'.  It is built as a real list (not a bare ``map``)
# so that the indexing and slicing done by make_contractor also work on
# Python 3, where ``map`` returns a lazy iterator.  Its length caps a
# contractor at 26 input tensors.
n2l = [chr(code) for code in range(ord('A'), ord('Z') + 1)]
#@-node:gcross.20100923134429.1869:n2l
#@+node:gcross.20100923134429.1870:make_contractor
def _find_joined_tensors(tensor_index_labels, join):
    """Locate the two tensors connected by ``join`` and the axes it joins.

    Entries of ``tensor_index_labels`` that are ``None`` (tensors already
    absorbed by an earlier join) are skipped.  The first tensor containing
    either label of ``join`` is reported first, whichever order the labels
    were given in.  The other label is then searched for only in later
    tensors, so joining two axes of the same tensor is not supported (it
    fails the second assertion).

    Returns:
        ``(tensor_ids, axes)`` where ``tensor_ids[0] < tensor_ids[1]`` and
        ``axes[k]`` is the position of the joined label within tensor
        ``tensor_ids[k]``.
    """
    unmatched = list(join)
    axes = [None, None]
    first = 0
    while first < len(tensor_index_labels):
        labels = tensor_index_labels[first]
        if labels is not None:
            if unmatched[0] in labels:
                axes[0] = labels.index(unmatched.pop(0))
                break
            if unmatched[1] in labels:
                axes[0] = labels.index(unmatched.pop(1))
                break
        first += 1
    assert len(unmatched) == 1  # otherwise neither label was found in any tensor
    second = first + 1
    while second < len(tensor_index_labels):
        labels = tensor_index_labels[second]
        if labels is not None and unmatched[0] in labels:
            axes[1] = labels.index(unmatched.pop(0))
            break
        second += 1
    assert len(unmatched) == 0  # otherwise the other label was not found in a later tensor
    return [first, second], axes


def _join_statements(tensor_index_labels, index_join_pairs):
    """Emit one tensordot contracting the first join and all its siblings.

    Every other pair in ``index_join_pairs`` that connects the same two
    tensors is contracted by the same tensordot call.  The processed pairs
    are deleted from ``index_join_pairs``, and ``tensor_index_labels`` is
    updated in place.  The lower-numbered tensor receives the leftover
    labels of both operands, in tensordot's output order.  The absorbed
    tensor is replaced by ``None``.
    """
    tensor_ids, axes = _find_joined_tensors(tensor_index_labels, index_join_pairs[0])
    consumed = [0]
    contracted_axes = [[axes[0]], [axes[1]]]
    for position in range(1, len(index_join_pairs)):
        other_ids, other_axes = _find_joined_tensors(tensor_index_labels, index_join_pairs[position])
        if other_ids == tensor_ids:
            consumed.append(position)
            contracted_axes[0].append(other_axes[0])
            contracted_axes[1].append(other_axes[1])

    first_var, second_var = n2l[tensor_ids[0]], n2l[tensor_ids[1]]
    statements = [
        "try:",
        " %s = tensordot(%s,%s,%s)" % (first_var, first_var, second_var, contracted_axes),
        # The generated code drops its reference to the absorbed operand
        # (presumably so intermediates can be freed early).
        " del %s" % second_var,
        "except ValueError:",
        # Replace tensordot's error with one naming the offending tensors.
        " raise ValueError('indices %%s do not match for tensor %%i, shape %%s, and tensor %%i, shape %%s.' %% (%s,%i,%s.shape,%i,%s.shape))" % (contracted_axes, tensor_ids[0], first_var, tensor_ids[1], second_var),
    ]

    # Delete from the back so earlier positions stay valid.
    for position in reversed(consumed):
        del index_join_pairs[position]

    merged_labels = []
    for tensor_id, tensor_axes in zip(tensor_ids, contracted_axes):
        kept = list(tensor_index_labels[tensor_id])
        for axis in sorted(tensor_axes, reverse=True):
            del kept[axis]
        merged_labels += kept
    tensor_index_labels[tensor_ids[0]] = merged_labels
    tensor_index_labels[tensor_ids[1]] = None
    return statements


def _final_statements(final_index_labels, result_index_labels):
    """Emit the statements that return the single remaining tensor ``A``.

    The axes of ``A`` are transposed into the order given by
    ``result_index_labels``.  Each group of labels is then merged into one
    output axis (the product of the group's dimensions) by a reshape.  With
    no result labels, ``A`` is returned unchanged.
    """
    if len(result_index_labels) == 0:
        return ["return A"]
    result_axes = [[final_index_labels.index(label) for label in group] for group in result_index_labels]
    transposed_axes = [axis for group in result_axes for axis in group]
    assert len(final_index_labels) == len(transposed_axes)
    new_shape = ",".join(["(%s)" % "*".join(["shape[%i]" % axis for axis in group]) for group in result_axes])
    return ["shape=A.shape", "return A.transpose(%s).reshape(%s)" % (transposed_axes, new_shape)]


def _build_statements(tensor_index_labels, index_join_pairs, result_index_labels):
    """Plan the whole contraction and return the generated function's body.

    The statements run in three phases:

    1. While any join remains, contract the two tensors of the first join
       (together with every other join between them) using one tensordot.
    2. Once no joins remain, fold the remaining tensors into ``A``, from the
       last to the first.  This uses an outer product, or a plain product
       for a tensor with no axes.
    3. Transpose and reshape ``A`` into the requested result layout.

    Both list arguments are consumed (mutated) in the process.
    """
    statements = []
    while len(tensor_index_labels) != 1:
        if index_join_pairs:
            statements += _join_statements(tensor_index_labels, index_join_pairs)
            continue
        last_labels = tensor_index_labels.pop()
        if last_labels is None:
            continue  # already absorbed into another tensor by a join
        var = n2l[len(tensor_index_labels)]
        if len(last_labels) == 0:
            # Tensor with no axes.  Rebind A rather than multiplying in place
            # ("A*="): when A is still the caller's array, an in-place
            # product would modify the caller's data.  It would also be
            # rejected by NumPy's casting rules for, e.g., an integer A times
            # a float scalar.
            statements += ["A = A*%s" % var, "del %s" % var]
        else:
            tensor_index_labels[0] += last_labels
            statements += ["A = multiply.outer(A,%s)" % var, "del %s" % var]
    return statements + _final_statements(tensor_index_labels[0], result_index_labels)


def make_contractor(tensor_index_labels, index_join_pairs, result_index_labels, name="f"):
    """Compile a function that contracts a fixed network of NumPy tensors.

    Args:
        tensor_index_labels: one sequence of labels per input tensor, giving
            one label per axis in axis order.  Labels are compared by
            equality.  At most ``len(n2l)`` (26) tensors are supported.
        index_join_pairs: ``(label, label)`` pairs naming two axes, on two
            different tensors, that are summed over.
        result_index_labels: one entry per axis of the result, in order.
            An entry is either a single label, or a sequence of labels
            (anything with ``__getitem__``) whose axes are merged into one
            output axis.  Note that a string entry therefore splits into its
            characters.  When empty, the fully contracted result is returned
            with no transpose or reshape.
        name: name given to the generated function.

    Returns:
        A function taking one array per tensor, positionally, and returning
        the contracted result.  Its generated source is stored in its
        ``source`` attribute.  When tensordot raises ValueError, the
        generated function re-raises a ValueError naming the two tensors and
        their shapes.

    Raises:
        AssertionError: if the total number of tensor labels differs from
            the number of result labels plus twice the number of joins, or
            if a join cannot be located.
    """
    # Work on private copies: the statement builder mutates these lists.
    tensor_index_labels = [list(labels) for labels in tensor_index_labels]
    index_join_pairs = list(index_join_pairs)
    result_index_labels = [list(group) if hasattr(group, "__getitem__") else [group]
                           for group in result_index_labels]
    assert sum(len(labels) for labels in tensor_index_labels) == (
        sum(len(group) for group in result_index_labels) + 2 * len(index_join_pairs)), \
        "every label must be either joined or listed in result_index_labels"

    header = "def %s(%s):" % (name, ",".join(n2l[:len(tensor_index_labels)]))
    body = _build_statements(tensor_index_labels, index_join_pairs, result_index_labels)
    function_definition = "\n".join([header] + ["\t" + statement for statement in body]) + "\n"

    # Labels never appear in the generated text, only integers and the
    # parameter letters; ``name`` is the one caller string interpolated.
    f_globals = {"tensordot": tensordot, "multiply": multiply}
    f_locals = {}
    # Function-call form of exec: valid on Python 3, and accepted by the
    # Python 2.7 exec statement as its tuple form.
    exec(function_definition, f_globals, f_locals)
    f = f_locals[name]
    f.source = function_definition
    return f
#@nonl
#@-node:gcross.20100923134429.1870:make_contractor
#@+node:gcross.20100923134429.1875:make_contractor_from_implicit_joins
def make_contractor_from_implicit_joins(tensor_index_labels, result_index_labels, name="f"):
    """Build a contractor in which a label shared by two tensors denotes a join.

    Every label that occurs twice across ``tensor_index_labels`` is
    contracted.  Its second occurrence, found in tensor ``i``, is renamed to
    the tuple ``(i, label)`` so that the two ends of the join carry distinct
    labels.  The pair ``(label, (i, label))`` is then passed to
    make_contractor.  Labels that occur once stay free and may be listed in
    ``result_index_labels``.  Labels must be hashable.

    Args:
        tensor_index_labels: one sequence of labels per input tensor.
        result_index_labels: passed unchanged to make_contractor.
        name: name given to the generated function.

    Returns:
        The function built by make_contractor.

    Raises:
        ValueError: if a label occurs a third time.  The message says "more
            than two tensors", although the occurrences may share a tensor.
    """
    tensor_index_labels = [list(labels) for labels in tensor_index_labels]
    # label -> index of the tensor it was first seen in, or None once its
    # second occurrence has been paired into a join.
    found_indices = {}
    index_join_pairs = []
    for i, labels in enumerate(tensor_index_labels):
        for index_position, index in enumerate(labels):
            if index not in found_indices:
                found_indices[index] = i
            elif found_indices[index] is None:
                # Format with a 1-tuple: a tuple label passed directly to "%"
                # would be unpacked as several arguments and raise TypeError.
                raise ValueError("index label %s found in more than two tensors" % (index,))
            else:
                # Rename this occurrence and record the join between the two.
                labels[index_position] = (i, index)
                index_join_pairs.append((index, (i, index)))
                found_indices[index] = None
    return make_contractor(tensor_index_labels, index_join_pairs, result_index_labels, name)
#@nonl
#@-node:gcross.20100923134429.1875:make_contractor_from_implicit_joins
#@+node:gcross.20100923134429.1877:crand
def crand(*shape):
    """Return complex samples whose real and imaginary parts are uniform on [-1, 1).

    ``shape`` is forwarded to ``numpy.random.rand``.  The real parts are drawn
    before the imaginary parts.  With no arguments a single complex number is
    returned.
    """
    real_part = 2 * rand(*shape) - 1
    imag_part = 2 * rand(*shape) - 1
    return real_part + 1j * imag_part
#@-node:gcross.20100923134429.1877:crand
#@+node:gcross.20100923134429.1879:form_contractor
def form_contractor(input_tensors, edges, output_tensor, name="f"):
    """Build a contractor from a graph description of a tensor network.

    Axis ``k`` of a tensor named ``t`` is the vertex ``t + str(k)``; for
    example, axis 0 of ``"A"`` is ``"A0"``.  Each edge is labelled by its
    position in ``edges``, and that label is attached to both endpoint
    vertices.  An edge between two input axes therefore becomes a
    contraction, and an edge ending on an output axis puts its label in the
    result.

    Args:
        input_tensors: ``(name, number_of_axes)`` pairs, one per argument of
            the generated function, in argument order.
        edges: ``(vertex, vertex)`` pairs.  Each vertex may occur only once
            across all edges.
        output_tensor: ``(name, number_of_axes)`` describing the result.
        name: name given to the generated function (defaults to ``"f"``).

    Returns:
        The function built by make_contractor_from_implicit_joins.

    Raises:
        ValueError: if a vertex occurs in two edges, or at both ends of the
            same edge.
        KeyError: if some input or output axis is not an endpoint of any edge.

    NOTE(review): vertex names are plain concatenations, so axis 11 of "A"
    and axis 1 of "A1" are both "A11"; tensor names must not collide this way.
    """
    edge_of_vertex = {}
    for edge_id, (v1, v2) in enumerate(edges):
        # Register each endpoint before checking the next one, so that a
        # self-loop (v1 == v2) is reported instead of silently accepted.
        for vertex in (v1, v2):
            if vertex in edge_of_vertex:
                raise ValueError("vertex {0} appears twice".format(vertex))
            edge_of_vertex[vertex] = edge_id
    output_name, output_size = output_tensor
    return make_contractor_from_implicit_joins(
        [[edge_of_vertex[input_name + str(axis)] for axis in range(input_size)]
         for (input_name, input_size) in input_tensors],
        [edge_of_vertex[output_name + str(axis)] for axis in range(output_size)],
        name,
    )
#@-node:gcross.20100923134429.1879:form_contractor
#@-node:gcross.20100923134429.1900:Functions
#@-others
#@-node:gcross.20100923134429.1858:@thin utils.py
#@-leo