forked from tencent-quantum-lab/tensorcircuit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathomeinsum_contractor.py
155 lines (135 loc) · 5.23 KB
/
omeinsum_contractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import json
from typing import List, Set, Dict, Tuple
import tempfile
import warnings
import cotengra as ctg
# Prerequisites for running this example:
# Step 1: install julia, see https://julialang.org/download/,
# Please install julia >= 1.8.5, the 1.6.7 LTS version raises:
# `Error in python: free(): invalid pointer`
# Step 2: add julia path to the PATH env variable so that juliacall can find it
# Step 3: install juliacall via `pip install juliacall`, this example was tested with juliacall 0.9.9
# Step 4: install julia package `OMEinsum`, this example was tested with OMEinsum v0.7.2,
# see https://docs.julialang.org/en/v1/stdlib/Pkg/ for more details on julia's package manager
# Step 5: for julia multi-threading, set env variable PYTHON_JULIACALL_THREADS=<N|auto>.
# However, in order to use julia multi-threading in juliacall,
# we have to turn off julia GC at the risk of OOM.
# See https://github.com/cjdoris/PythonCall.jl/issues/219 for more details.
from juliacall import Main as jl
jl.seval("using OMEinsum")
import tensorcircuit as tc
tc.set_backend("tensorflow")
class OMEinsumTreeSAOptimizer(object):
    """Contraction-path optimizer backed by OMEinsum.jl's TreeSA algorithm.

    Instances are callable with the opt_einsum custom-optimizer signature
    ``(inputs, output, size, memory_limit)`` and return a contraction path,
    i.e. a list of queue positions to contract at each step, suitable for
    tensorcircuit's ``set_contractor("custom", optimizer=...)``.
    """

    def __init__(
        self,
        sc_target: float = 20,
        betas: Tuple[float, float, float] = (0.01, 0.01, 15),
        ntrials: int = 10,
        niters: int = 50,
        sc_weight: float = 1.0,
        rw_weight: float = 0.2,
    ):
        """Store the TreeSA simulated-annealing hyperparameters.

        :param sc_target: target space complexity passed to ``TreeSA``
        :param betas: ``(start, step, stop)`` of the inverse-temperature
            schedule, forwarded to julia as ``range(start, step=..., stop=...)``
        :param ntrials: number of independent annealing trials
        :param niters: number of iterations per trial
        :param sc_weight: weight of the space-complexity term
        :param rw_weight: weight of the read-write term
        """
        self.sc_target = sc_target
        self.betas = betas
        self.ntrials = ntrials
        self.niters = niters
        self.sc_weight = sc_weight
        self.rw_weight = rw_weight

    def _contraction_tree_to_contraction_path(self, ei, queue, path, idx):
        """Recursively flatten OMEinsum's binary contraction tree into an
        opt_einsum-style contraction path.

        :param ei: (sub)tree node; leaves carry ``isleaf=True`` and a
            1-based ``tensorindex``, internal nodes carry two ``args``
        :param queue: live list of tensor ids still available for
            contraction (initially ``list(range(num_tensors))``); mutated
            in place, intermediates are appended with fresh ids
        :param path: output list of contraction steps; each step is the
            pair of *positions* popped from ``queue``; mutated in place
        :param idx: next unused id for an intermediate tensor
        :return: the next unused intermediate id after processing ``ei``
        """
        if ei["isleaf"]:
            # OMEinsum provides 1-based tensor indices, while contraction
            # paths are 0-based, so shift the leaf index in place.
            ei["tensorindex"] -= 1
            return idx
        assert len(ei["args"]) == 2, "must be a binary tree"
        for child in ei["args"]:
            idx = self._contraction_tree_to_contraction_path(child, queue, path, idx)
            assert "tensorindex" in child
        # Positions (not tensor ids) of the two operands in the live queue,
        # popped in descending order so the first pop does not shift the
        # position of the second.
        lhs_args = sorted(
            [queue.index(child["tensorindex"]) for child in ei["args"]], reverse=True
        )
        for arg in lhs_args:
            queue.pop(arg)
        ei["tensorindex"] = idx
        path.append(lhs_args)
        queue.append(idx)
        return idx + 1

    def __call__(
        self,
        inputs: List[Set[str]],
        output: Set[str],
        size: Dict[str, int],
        memory_limit=None,
    ) -> List[List[int]]:
        """Find a contraction path for the given einsum specification.

        :param inputs: index labels of each input tensor
        :param output: index labels of the output tensor
        :param size: dimension of each index label (labels absent from
            ``size`` keep the uniform default of 2 set below)
        :param memory_limit: unused; kept for optimizer-interface
            compatibility with opt_einsum custom optimizers
        :return: the contraction path; each entry lists the queue positions
            of the two tensors contracted at that step (note: entries are
            lists, produced by ``sorted``)
        """
        inputs_omeinsum = tuple(map(tuple, inputs))
        output_omeinsum = tuple(output)
        eincode = jl.OMEinsum.EinCode(inputs_omeinsum, output_omeinsum)
        size_dict = jl.OMEinsum.uniformsize(eincode, 2)
        for k, v in size.items():
            size_dict[k] = v
        algorithm = jl.OMEinsum.TreeSA(
            sc_target=self.sc_target,
            βs=jl.range(self.betas[0], step=self.betas[1], stop=self.betas[2]),
            ntrials=self.ntrials,
            niters=self.niters,
            sc_weight=self.sc_weight,
            rw_weight=self.rw_weight,
        )
        nthreads = jl.Threads.nthreads()
        if nthreads > 1:
            # Julia multi-threading under juliacall requires disabling the
            # julia GC for the duration of the optimization (OOM risk).
            warnings.warn(
                "Julia receives Threads.nthreads()={0}. "
                "However, in order to use julia multi-threading in juliacall, "
                "we have to turn off julia GC at the risk of OOM. "
                "That means you may need a large memory machine. "
                "Please see https://github.com/cjdoris/PythonCall.jl/issues/219 "
                "for more details.".format(nthreads)
            )
            jl.GC.enable(False)
        optcode = jl.OMEinsum.optimize_code(eincode, size_dict, algorithm)
        if nthreads > 1:
            jl.GC.enable(True)
        # jl.println("time and space complexity computed by OMEinsum: ",
        #            jl.OMEinsum.timespace_complexity(optcode, size_dict))
        # Round-trip the optimized code through a temporary JSON file so the
        # contraction tree can be inspected from Python.
        fp = tempfile.NamedTemporaryFile(suffix=".json", delete=False)
        fp.close()
        try:
            jl.OMEinsum.writejson(fp.name, optcode)
            with open(fp.name, "r") as f:
                contraction_tree = json.load(f)
        finally:
            # Remove the temp file even if julia or json.load raises,
            # so failed optimizations do not leak files.
            os.unlink(fp.name)
        num_tensors = len(contraction_tree["inputs"])
        assert num_tensors == len(
            inputs
        ), "should have the same number of input tensors"
        queue = list(range(num_tensors))
        path = []
        self._contraction_tree_to_contraction_path(
            contraction_tree["tree"], queue, path, num_tensors
        )
        return path
# For more random circuits, please refer to
# https://datadryad.org/stash/dataset/doi:10.5061/dryad.k6t1rj8
# Benchmark: evaluate the same 12-qubit random-circuit expectation value
# with two different contraction-path optimizers for comparison.
c = tc.Circuit.from_qsim_file("circuit_n12_m14_s0_e0_pEFGH.qsim")
# Baseline path finder: cotengra's reusable hyper-optimizer searching
# greedy and kahypar trees, minimizing flops.
opt = ctg.ReusableHyperOptimizer(
    methods=["greedy", "kahypar"],
    parallel=True,
    minimize="flops",
    max_repeats=1024,
    progbar=False,
)
print("cotengra contractor")
# contraction_info/debug_level presumably print contraction statistics for
# the chosen path — confirm against tensorcircuit's set_contractor docs.
tc.set_contractor(
    "custom", optimizer=opt, preprocessing=True, contraction_info=True, debug_level=2
)
c.expectation_ps(z=[0], reuse=False)
print("OMEinsum contractor")
# sc_weight=0.0 and rw_weight=0.0 drop the space-complexity and read-write
# terms, so TreeSA optimizes time complexity only.
opt_treesa = OMEinsumTreeSAOptimizer(sc_target=30, sc_weight=0.0, rw_weight=0.0)
tc.set_contractor(
    "custom",
    optimizer=opt_treesa,
    preprocessing=True,
    contraction_info=True,
    debug_level=2,
)
c.expectation_ps(z=[0], reuse=False)