```python
# This file is useful for reading the contents of the ops generated by Ruby.
# You can read any graph definition in pb/pbtxt format generated by Ruby
# or by Python and then convert it back and forth between human-readable and binary format.
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.platform import gfile

def pbtxt_to_graphdef(filename):
    with open(filename, 'r') as f:
        graph_def = tf.GraphDef()
        file_content = f.read()
        text_format.Merge(file_content, graph_def)
        tf.import_graph_def(graph_def, name='')
        tf.train.write_graph(graph_def, 'pbtxt/', 'protobuf.pb', as_text=False)

def graphdef_to_pbtxt(filename):
    with gfile.FastGFile(filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
        tf.train.write_graph(graph_def, 'pbtxt/', 'protobuf.pbtxt', as_text=True)
    return

graphdef_to_pbtxt('graph.pb')  # here you can write the name of the file to be converted
# and then a new file will be made in the pbtxt directory.
```
For `pb` to `pbtxt`, it's not necessary to actually import the `graph_def` into TensorFlow (which is what causes errors in TF2). Simply converting the `graph_def` to a string will return the text representation of the protobuf message:

```python
import tensorflow as tf

def graphdef_to_pbtxt(filename):
    with open(filename, 'rb') as f:
        # tf.GraphDef is gone from the top-level namespace in TF2; the compat alias works in both
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with open('protobuf.txt', 'w') as fp:
        fp.write(str(graph_def))  # str() of a protobuf message is its text format
```
Also, using `GFile`s may be faster, but it makes it harder to keep track of where in the TF API they are defined, so normal Python files work just fine in my case.
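The opposite direction (`pbtxt` to `pb`) can be sketched the same way, also without calling `tf.import_graph_def`: parse the text into a GraphDef with `text_format.Parse` and write out `SerializeToString()`. This is a minimal sketch, assuming the text file really holds a GraphDef; the name `pbtxt_to_pb` and its `src`/`dst` parameters are hypothetical.

```python
import tensorflow as tf
from google.protobuf import text_format


def pbtxt_to_pb(src, dst):
    """Hypothetical helper: text-format GraphDef -> binary GraphDef, no graph import."""
    graph_def = tf.compat.v1.GraphDef()
    with open(src, "r") as f:
        # Raises text_format.ParseError on unknown fields, e.g. if the file is not a GraphDef
        text_format.Parse(f.read(), graph_def)
    with open(dst, "wb") as f:
        f.write(graph_def.SerializeToString())
```

Because nothing is imported into a graph, this should behave the same under TF 1.x and TF2, as long as `tf.compat.v1.GraphDef` is available.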
Hi @GPhilo, I'm trying to use `pbtxt_to_graphdef` on TensorFlow 1.15, and the error message says:

```
Traceback (most recent call last):
  File "/Users/kiddo/Library/Preferences/PyCharm2019.3/scratches/scratch_1.py", line 32, in <module>
    pbtxt_to_graphdef('/Users/kiddo/Desktop/kiddo_model/saved_model.pbtxt.txt')
  File "/Users/kiddo/Library/Preferences/PyCharm2019.3/scratches/scratch_1.py", line 20, in pbtxt_to_graphdef
    text_format.Merge(file_content, graph_def)
  File "/Users/kiddo/anaconda/lib/python3.6/site-packages/google/protobuf/text_format.py", line 693, in Merge
    allow_unknown_field=allow_unknown_field)
  File "/Users/kiddo/anaconda/lib/python3.6/site-packages/google/protobuf/text_format.py", line 760, in MergeLines
    return parser.MergeLines(lines, message)
  File "/Users/kiddo/anaconda/lib/python3.6/site-packages/google/protobuf/text_format.py", line 785, in MergeLines
    self._ParseOrMerge(lines, message)
  File "/Users/kiddo/anaconda/lib/python3.6/site-packages/google/protobuf/text_format.py", line 807, in _ParseOrMerge
    self._MergeField(tokenizer, message)
  File "/Users/kiddo/anaconda/lib/python3.6/site-packages/google/protobuf/text_format.py", line 899, in _MergeField
    (message_descriptor.full_name, name))
google.protobuf.text_format.ParseError: 1:1 : Message type "tensorflow.GraphDef" has no field named "meta_info_def".
```

Could you please help me find out what's wrong with the parsing?
Random speculation here, but are you sure you're passing a GraphDef file? `.pb` and `.pbtxt` are generic protobuf file extensions; these methods only work for converting GraphDef messages between binary and text formats. GraphDefs are, for example, the output of `freeze_graph`.
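Since the `.pbtxt` extension only tells you the encoding, not the message type, one way to check what a text file actually contains is to try parsing it against a few candidate message classes. This is a minimal sketch; `matching_message_types` is a hypothetical name, and it only detects unknown field names, so an empty or very small file can match several types. It relies on the text format carrying field names; the binary format only carries field numbers, so the same trick is much less reliable on a `.pb`. For TensorFlow files, the candidates might be `graph_pb2.GraphDef` (from `tensorflow.core.framework`) and `meta_graph_pb2.MetaGraphDef` and `saved_model_pb2.SavedModel` (both from `tensorflow.core.protobuf`); the demo below uses protobuf's own descriptor messages so it runs without TensorFlow.

```python
from google.protobuf import text_format


def matching_message_types(text, candidates):
    """Return the full names of the candidate message classes that can parse `text`."""
    matches = []
    for message_cls in candidates:
        try:
            text_format.Parse(text, message_cls())
        except text_format.ParseError:
            continue  # e.g. 'Message type "X" has no field named "y"'
        matches.append(message_cls.DESCRIPTOR.full_name)
    return matches


if __name__ == "__main__":
    from google.protobuf import descriptor_pb2

    sample = 'name: "demo.proto"\npackage: "demo"\n'
    # DescriptorProto has no "package" field, so only FileDescriptorProto matches
    print(matching_message_types(
        sample, [descriptor_pb2.DescriptorProto, descriptor_pb2.FileDescriptorProto]))
    # ['google.protobuf.FileDescriptorProto']
```

In the traceback above, `meta_info_def` is a field of `MetaGraphDef`, so that file would be expected to match `MetaGraphDef` rather than `GraphDef`.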
Thanks for your fast reply. Let me explain my actual goal. I'm currently re-training a new custom BERT from scratch (https://github.com/google-research/bert). BERT also comes as a TensorFlow Hub module (https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1), which makes it really easy to load the model in both TensorFlow and Keras.

As you may guess, the module contains a SavedModel, which means it provides the trained variables and the `saved_model.pb`. The `saved_model.pb` contains the actual graph definition, which in the case of Google's default BERT has a vocabulary of 30522, so the graph includes that size in its specifications. In order to load my own custom BERT with a different vocab and the same specifications for the rest of the layers, I need to amend `saved_model.pb`, so I used this piece of code:
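```python
import tensorflow as tf  # TF 1.x API

with tf.Session(graph=tf.Graph()) as sess:
    # loader.load returns a MetaGraphDef (not a GraphDef)
    meta_graph_def = tf.saved_model.loader.load(sess, ["train"], "bert")
    tf.train.write_graph(meta_graph_def, "t", "saved_model.pbtxt", as_text=True)
```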
Given the `saved_model.pbtxt`, I can now see the specifications of the model in raw text format and amend all occurrences of 30522 to 32000 in my case. Now I need to do the reverse and wrap it up again as a `.pb` file, because that's what TensorFlow Hub reads... I still struggle with this step, and that's how I found your code :)
I would appreciate any help in order to find a working solution!
The SavedModel protobuf message is not a GraphDef, hence your error. There definitely is a way to work with protobuf messages directly without having to interpret them as valid TensorFlow objects, at least for converting between binary and text formats. I can't remember right now which module it was; I'll have a look and post another comment if I find it.
Ok, I think I actually found it. Look up the `google.protobuf` module and/or see if you can find a `saved_model_pb2` file you could import (that would be the generated Python wrapper for the SavedModel message definition, through which I think it should be possible to load the file and convert it between text and binary format).
Any solution that gets me from `saved_model.pb` -> `saved_model.pbtxt` -> `saved_model.pb`, or any other way of amending `saved_model.pb`, is welcome. Thanks for your help, I appreciate it!
Could you please provide a minimal example of how to check whether the `google.protobuf` module has a `saved_model_pb2` and load it, in order to save it back to `.pbtxt`? Sorry, but even the term `google.protobuf` is vague to me...
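A minimal sketch of the `saved_model.pb` -> `saved_model.pbtxt` -> `saved_model.pb` round trip, assuming a TensorFlow install where the generated wrapper is importable as `tensorflow.core.protobuf.saved_model_pb2`, with `google.protobuf.text_format` doing the text conversion. The helper names are hypothetical. It only changes the encoding; it does not check that hand edits (such as a changed vocabulary size) stay consistent with the checkpoint in the `variables/` directory.

```python
from google.protobuf import text_format
from tensorflow.core.protobuf import saved_model_pb2


def saved_model_pb_to_pbtxt(pb_path, pbtxt_path):
    """Binary SavedModel -> human-readable text (hypothetical helper)."""
    saved_model = saved_model_pb2.SavedModel()
    with open(pb_path, "rb") as f:
        saved_model.ParseFromString(f.read())
    with open(pbtxt_path, "w") as f:
        f.write(text_format.MessageToString(saved_model))


def saved_model_pbtxt_to_pb(pbtxt_path, pb_path):
    """Text SavedModel (e.g. after hand edits) -> binary (hypothetical helper)."""
    saved_model = saved_model_pb2.SavedModel()
    with open(pbtxt_path, "r") as f:
        # Parse fails with ParseError if an edit breaks the text syntax or a field name
        text_format.Parse(f.read(), saved_model)
    with open(pb_path, "wb") as f:
        f.write(saved_model.SerializeToString())
```

Writing the rebuilt `.pb` into a copy of the model directory, rather than over the original, keeps a known-good version around if the edited model fails to load.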
I tried to implement the same, but I am getting the following error:

```
in graphdef_to_pbtxt(filename)
      7 with open(filename,'rb') as f:
      8     graph_def = tf.compat.v1.GraphDef()
----> 9     graph_def.ParseFromString(f.read())
     10 with open('protobuf.txt', 'w') as fp:
     11     fp.write(str(graph_def))

DecodeError: Error parsing message
```

The code I used is as follows:

```python
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.platform import gfile

def graphdef_to_pbtxt(filename):
    with open(filename, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with open('protobuf.txt', 'w') as fp:
        fp.write(str(graph_def))

graphdef_to_pbtxt('saved_model.pb')
```

Can anybody help me with this?
This snippet, obtained from here, will do the trick:

```python
import tensorflow as tf
import sys
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

model_filename = 'saved_model.pb'
with gfile.FastGFile(model_filename, 'rb') as f:
    data = compat.as_bytes(f.read())
    # saved_model.pb holds a SavedModel message, not a bare GraphDef
    sm = saved_model_pb2.SavedModel()
    sm.ParseFromString(data)
    # The GraphDef lives inside the first MetaGraphDef
    g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
```
It doesn't run in TF2 anymore :|
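A sketch of a TF2-compatible version of the snippet above, assuming the TF 2.x names `tf.io.gfile.GFile` (in place of `gfile.FastGFile`) and `tf.graph_util.import_graph_def` (in place of the top-level `tf.import_graph_def`). The import happens inside an explicit `tf.Graph`, since TF2 runs eagerly by default. `load_graph_from_saved_model_pb` is a hypothetical name, and by default it only imports the first MetaGraphDef, as the original snippet does.

```python
import tensorflow as tf
from tensorflow.core.protobuf import saved_model_pb2


def load_graph_from_saved_model_pb(path, meta_graph_index=0):
    """Parse saved_model.pb as a SavedModel and import one of its GraphDefs into a new tf.Graph."""
    saved_model = saved_model_pb2.SavedModel()
    with tf.io.gfile.GFile(path, "rb") as f:
        saved_model.ParseFromString(f.read())
    graph_def = saved_model.meta_graphs[meta_graph_index].graph_def
    graph = tf.Graph()
    with graph.as_default():
        # name="" keeps the original op names instead of prefixing them with "import/"
        tf.graph_util.import_graph_def(graph_def, name="")
    return graph
```

If only the text form is needed, `str(saved_model.meta_graphs[0].graph_def)` is enough and no graph import is required at all.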