Created
January 20, 2021 22:16
-
-
Save petewarden/927b11914b905d10f50894453f3fbf7e to your computer and use it in GitHub Desktop.
lstm_quantization.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ==============================================================================
"""LSTM quantization with Python."""
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import pathlib | |
from absl import app | |
from absl import flags | |
import numpy as np | |
import tensorflow as tf | |
# Command-line flags.
# `train_steps` is forwarded to `model.fit(epochs=...)` by train(), so it
# counts full passes over the training data, not individual batches. The flag
# name is kept for backward compatibility; only the help text is corrected.
flags.DEFINE_integer('train_steps', 2, 'Number of training epochs.')
# main() writes the quantized model here; nothing is restored from it.
flags.DEFINE_string('tflite_dir', '/tmp/lstm/tflite',
                    'Directory in which to save the quantized tflite model.')
def load_data(training_data_points, test_data_points):
    """Fetch MNIST, keep a leading slice of each split and rescale the images.

    Args:
      training_data_points: Number of leading training examples to keep.
      test_data_points: Number of leading test examples to keep.

    Returns:
      A tuple `(train_images, train_labels, test_images, test_labels)`. Both
      image arrays are cast to float32 and divided by 255.0; the labels are
      returned as loaded (only truncated).
    """
    tf.print('Loading data...\n')

    (raw_train, raw_train_labels), (raw_test, raw_test_labels) = (
        tf.keras.datasets.mnist.load_data())

    # Truncate first so only the retained examples are converted to float32.
    train_images = raw_train[:training_data_points].astype(np.float32) / 255.0
    test_images = raw_test[:test_data_points].astype(np.float32) / 255.0
    train_labels = raw_train_labels[:training_data_points]
    test_labels = raw_test_labels[:test_data_points]

    return (train_images, train_labels, test_images, test_labels)
def train(model, train_images, train_labels, test_images, test_labels, steps):
    """Fit `model` on the training split while validating on the test split.

    Args:
      model: Compiled model exposing a Keras-style `fit` method.
      train_images: Training inputs.
      train_labels: Training targets.
      test_images: Validation inputs.
      test_labels: Validation targets.
      steps: Forwarded to `fit` as `epochs`.
    """
    tf.print('Training model...\n')

    held_out = (test_images, test_labels)
    # No batch size is passed, so `fit` uses its default (32 for Keras); with
    # the 960 examples main() loads that is 30 iterations per epoch.
    model.fit(train_images, train_labels, epochs=steps,
              validation_data=held_out)
def build_model():
    """Create and compile the LSTM classifier.

    The network takes a (28, 28) input, runs a 20-unit LSTM that returns the
    full output sequence, flattens it, and ends in a 10-way softmax layer.

    Returns:
      The compiled `tf.keras.models.Sequential` model.
    """
    tf.print('Building LSTM model.\n')

    layers = tf.keras.layers
    stack = [
        layers.Input(shape=(28, 28), name='input'),
        layers.LSTM(20, time_major=False, return_sequences=True),
        layers.Flatten(),
        layers.Dense(10, activation=tf.nn.softmax, name='output'),
    ]
    model = tf.keras.models.Sequential(stack)

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    model.summary()
    return model
def generate_data():
    """Yield calibration samples for the TFLite converter.

    Uses the first 64 MNIST training images, cast to float32 and divided by
    255.0.

    Yields:
      A one-element list holding a single image with a leading batch axis
      added, i.e. [28, 28] becomes [1, 28, 28].
    """
    tf.print('Generating calibration data...\n')

    (train_x, _), _ = tf.keras.datasets.mnist.load_data()
    calibration = train_x[:64].astype(np.float32) / 255.0

    for sample in calibration:
        # Prepend the batch dimension expected by the tflite model.
        yield [sample[np.newaxis, ...]]
def convert_and_quantize_model(model):
    """Export `model` as a SavedModel and convert it to a quantized TFLite model.

    The model is traced with a fixed [1, 28, 28] input, saved under
    '/tmp/keras_lstm', and converted with the default optimizations, using
    `generate_data` as the representative dataset.

    Args:
      model: Keras model to convert.

    Returns:
      The result of `converter.convert()` (the serialized TFLite model).
    """
    # NOTE(review): 'Quatizing' is a typo for 'Quantizing'; the message is
    # left untouched here so the program's output does not change.
    tf.print('Quatizing fused LSTM model...\n')

    # Pin the signature to [batch, time steps, features].
    fixed_shape = [1, 28, 28]
    fixed_input = tf.TensorSpec(fixed_shape, model.inputs[0].dtype)
    traced = tf.function(lambda x: model(x))
    concrete_func = traced.get_concrete_function(fixed_input)

    # Round-trip through a SavedModel so the converter reads it from disk.
    saved_model_path = '/tmp/keras_lstm'
    model.save(saved_model_path, save_format='tf', signatures=concrete_func)

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = generate_data
    return converter.convert()
def tflite_float_inference(model, images, expected_labels):
    """Run the first image through a TFLite model and print the outcome.

    Args:
      model: Serialized TFLite model, passed to the interpreter as
        `model_content`.
      images: Indexable collection of images; only `images[0]` is evaluated.
      expected_labels: Indexable collection of labels; only
        `expected_labels[0]` is reported.
    """
    tf.print('Running tflite_float_inference...\n')

    interpreter = tf.lite.Interpreter(model_content=model)
    interpreter.allocate_tensors()

    # The model has a single input and a single output; take the first entry.
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    for title, details in (('input_details:', input_details),
                           ('output_details:', output_details)):
        print(title)
        print(details)

    # Add the batch axis and cast to the dtype the interpreter reports.
    first_image = np.expand_dims(images[0], axis=0)
    first_image = first_image.astype(input_details['dtype'])
    expected_label = expected_labels[0]

    interpreter.set_tensor(input_details['index'], first_image)
    interpreter.invoke()

    batch_output = interpreter.get_tensor(output_details['index'])
    prediction = batch_output[0].argmax()  # [0] selects the single batch row.

    print('output:')
    print(batch_output)
    print('Expected', expected_label, ' and predicted', prediction, '\n')
def save_tflite_files(model, path, name):
    """Write the serialized TFLite model to `path`/`name`.

    Args:
      model: Bytes of the TFLite model.
      path: Destination directory; created (with parents) if missing.
      name: File name inside `path`.
    """
    tf.print('Saving tflite model', path, '/', name, '...\n')

    output_dir = pathlib.Path(path)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / name).write_bytes(model)
def main(_):
    """Train the LSTM, quantize it, run one inference and save the model."""
    # 960 training and 320 test examples.
    dataset = load_data(960, 320)
    train_images, train_labels, test_images, test_labels = dataset

    model = build_model()
    train(model, train_images, train_labels, test_images, test_labels,
          flags.FLAGS.train_steps)

    quantized = convert_and_quantize_model(model)

    # Check the converted model on one test image before writing it out.
    tflite_float_inference(quantized, test_images, test_labels)

    save_tflite_files(quantized, flags.FLAGS.tflite_dir, 'lstm_quant.tflite')
# Script entry point: absl parses the command-line flags, then calls main().
if __name__ == '__main__':
    app.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment