danjar_peek.py
import tensorflow as tf
from tensorflow.python.client import timeline
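# A tf.FIFOQueue subclass with a peek() op: the first enqueued element is
# mirrored in a TF variable so it can be read without dequeuing.  Running the
# graph repeatedly shows the peeked value is not deterministic (see the
# Expected/Actual comment at the bottom); a Chrome trace and the raw step
# stats are dumped for every run so the op ordering can be inspected.
# Uses the TensorFlow 1.x graph-mode API (tf.Session, tf.get_variable).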
class Queue(tf.FIFOQueue):
  def __init__(self, capacity):
    s = ()        # scalar element shape
    d = tf.int32  # element dtype
    # The first element lives in the mirror variable below, so the wrapped
    # queue is created with one slot less than the requested capacity.
    super().__init__(capacity - 1, [d], [s])
    # Mirror of the first enqueued element, readable via peek() without dequeuing.
    self._first = tf.get_variable(name="var1",
                                  initializer=tf.ones_initializer(),
                                  shape=s, dtype=d, use_resource=False)
    # Count of elements pushed through enqueue().
    self._size = tf.get_variable(name="size", shape=(),
                                 initializer=tf.zeros_initializer(),
                                 dtype=tf.int32, use_resource=False)
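  # Read the mirrored first element without removing anything from the queue.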
  def peek(self):
    return self._first.read_value()
  def enqueue(self, element):
    super_ = super()  # captured for use inside the nested branch functions
    def first():
      # Predicate saw an empty queue: store the element in the mirror variable.
      assigns = [self._first.assign(element)]
      with tf.control_dependencies(assigns):
        return tf.constant(0)
    def other():
      # Otherwise push the element onto the underlying FIFOQueue.
      with tf.control_dependencies([super_.enqueue(element)]):
        return tf.constant(0)
    # Bump the size, then dispatch on whether the queue was empty.  Both
    # variables are ref variables (use_resource=False); what the predicate and
    # peek() actually observe is what this script is probing.
    with tf.control_dependencies([self._size.assign_add(1)]):
      dummy = tf.cond(tf.equal(self._size, 0), first, other)
    return tf.identity(dummy)
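# Build the graph once: a peek op plus a single enqueue op for the constant -2.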
queue = Queue(10)
queue_peek = queue.peek()
print("Peek op is "+str(queue_peek))
queue_init = queue.enqueue(tf.constant(-2))
print(tf.get_default_graph().as_graph_def())
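# Run the same graph 20 times in fresh sessions, tracing each peek.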
for i in range(20):
  sess = tf.Session()
  sess.run(tf.global_variables_initializer())
  sess.run(queue_init)
  print("queue size", sess.run(queue.size()))
  sess.run(queue.close())
  # print("Printing queue")
  # while True:
  #   print(sess.run(queue.dequeue()))
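  # Collect a full execution trace for the peek so its ordering can be inspected offline.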
  run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  run_options.output_partition_graphs = True
  run_metadata = tf.RunMetadata()
  #import pdb; pdb.set_trace()
  # queue_peek,
  result = sess.run(queue_peek, run_metadata=run_metadata,
                    options=run_options)
  tl = timeline.Timeline(run_metadata.step_stats)
  ctf = tl.generate_chrome_trace_format()
  with open('timeline-%d.json'%(i,), 'w') as f:
    f.write(ctf)
  with open('stepstats-%d.json'%(i,), 'w') as f:
    f.write(str(run_metadata))
  print(result, end=' ')
# Expected: 1 1 1 1 1 1 1 1 1 1
# Actual: 0 1 0 0 1 1 0 0 0 1