Skip to content

Instantly share code, notes, and snippets.

@vincentchu
Created January 23, 2018 18:04
Show Gist options
  • Save vincentchu/b8484a97d0f5c080c613463554757390 to your computer and use it in GitHub Desktop.
Save vincentchu/b8484a97d0f5c080c613463554757390 to your computer and use it in GitHub Desktop.
#!/usr/bin/python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import tensorflow as tf
import numpy as np
# Name of the graph tensor that main() feeds with a dummy image batch.
# NOTE(review): the "import/" prefix presumably comes from
# tf.import_graph_def's default name scope -- confirm if the import call changes.
IMG_INPUT_TENSOR = "import/Preprocessor/sub:0"
# Name of the graph tensor that main() evaluates to obtain the anchor boxes.
ANCHOR_BOXES_TENSOR = "import/MultipleGridAnchorGenerator/Identity:0"
# Number of anchor boxes (rows) the graph is expected to produce; main()
# checks the fetched array against (EXPECTED_ANCHORS, 4).
EXPECTED_ANCHORS = 1917
def code_gen(anchors):
    """Print Swift source for an ``Anchors`` struct to stdout.

    The struct exposes ``numAnchors`` and a computed ``ssdAnchors`` property
    that fills a ``[[Float32]]`` array with one entry per anchor box.

    The anchor count is taken from ``anchors`` itself rather than from a
    module constant, so the emitted ``numAnchors`` always matches the number
    of ``arr[i] = ...`` lines that follow.

    Args:
        anchors: 2-D array-like of anchor boxes, one row per box. The values
            of each row are written in order, formatted with ``"% .8f"``.
    """
    num_anchors = len(anchors)

    print("/**")
    print(" * Anchors")
    print(" *")
    print(" * SSD Anchor boxes for SSD/Mobilenet architectures")
    print(" * num_layers = 6")
    print(" * min_scale = 0.2")
    print(" * max_scale = 0.95")
    print(" * aspect_ratios = [1.0, 2.0, 0.5, 3.0, 0.3333]")
    print(" *")
    print(" * See: https://github.com/tensorflow/models/blob/master/research/object_detection/anchor_generators/multiple_grid_anchor_generator.py#L248")
    print(" */")
    print("struct Anchors {")
    print(" static let numAnchors = %d" % num_anchors)
    print(" static var ssdAnchors: [[Float32]] {")
    print(" var arr: [[Float32]] = Array(repeating: Array(repeating: 0.0, count: 4), count: numAnchors)")
    print("")

    for i, box in enumerate(anchors):
        # "% .8f" reserves a column for the sign so positive and negative
        # values line up in the generated source.
        box_str = ", ".join("% .8f" % pt for pt in box)
        print(" arr[%d] = [ %s ]" % (i, box_str))

    print("")
    print(" return arr")
    print(" }")
    print("}")
def main(_):
    """Load a serialized graph, evaluate its anchor boxes, and print Swift code.

    The graph file path is read from ``sys.argv[1]``. The graph is imported
    into its own ``tf.Graph``, the anchor tensor is evaluated while feeding a
    zero-filled (1, 300, 300, 3) image batch, and the result is handed to
    ``code_gen``.

    Args:
        _: Ignored; present because ``tf.app.run`` calls ``main`` with one
            positional argument.

    Raises:
        SystemExit: With status 1 if exactly one command-line argument (the
            graph file) was not supplied.
        ValueError: If the evaluated anchors do not have shape
            ``(EXPECTED_ANCHORS, 4)``.
    """
    if len(sys.argv) != 2:
        # Usage goes to stderr: stdout carries the generated Swift source and
        # is typically redirected into a file.
        print("Must specify an input graph file! Usage: %s [Graph PB file]" % (sys.argv[0]),
              file=sys.stderr)
        sys.exit(1)

    graph_file = sys.argv[1]
    with open(graph_file, 'rb') as f:
        serialized = f.read()

    tf.reset_default_graph()
    original_gdef = tf.GraphDef()
    original_gdef.ParseFromString(serialized)

    # tf.import_graph_def returns None when no return_elements are requested,
    # so import into an explicit graph and hand that graph to the session.
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(original_gdef)

    img_placeholder = np.zeros((1, 300, 300, 3))
    with tf.Session(graph=graph) as sess:
        image_input_tensor = sess.graph.get_tensor_by_name(IMG_INPUT_TENSOR)
        anchors_tensor = sess.graph.get_tensor_by_name(ANCHOR_BOXES_TENSOR)
        anchors = sess.run(anchors_tensor, feed_dict={image_input_tensor: img_placeholder})

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    expected_shape = (EXPECTED_ANCHORS, 4)
    if anchors.shape != expected_shape:
        raise ValueError("Unexpected anchors shape %s; expected %s"
                         % (anchors.shape, expected_shape))

    code_gen(anchors)
# Script entry point: let tf.app.run drive main(); the function is named
# explicitly here instead of being looked up on the __main__ module.
if __name__ == "__main__":
    tf.app.run(main=main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment