Created
November 23, 2017 23:28
-
-
Save nlothian/0cd4540389f7091717ece6f4b89b6604 to your computer and use it in GitHub Desktop.
Convert a FastText model to a form suitable for viewing in Tensorboard
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tensorflow.contrib.tensorboard.plugins import projector | |
import tensorflow as tf | |
import numpy as np | |
import os | |
meta_file = "g2x_metadata.tsv" | |
output_path = "./projections" | |
# read embedding file into list and get the size | |
with open('./ft_model.vec', 'r') as embedding_file: | |
embedding_content = embedding_file.readlines() | |
embedding_content = [x.strip() for x in embedding_content] | |
num_lines = len(embedding_content) - 1 # skip the header | |
num_dims = len(embedding_content[1].split()) - 1 # -1 because of the label column | |
print("Detected dimensions:", num_lines, " X ", num_dims) | |
placeholder = np.zeros((num_lines, num_dims)) | |
print(placeholder.shape) | |
z = 0 | |
with open(os.path.join(output_path, meta_file), 'w') as file_metadata: | |
i = 0 | |
for line in embedding_content[1:]: # skip the header line | |
values = line.split() | |
raw_label = values[0] | |
#print(label) | |
col = 0 | |
for val in values[1:]: # skip the label | |
placeholder[i][col] = val | |
z = i + col | |
col = col + 1 | |
i = i + 1 | |
if raw_label == '': | |
file_metadata.write("<Empty Line>\n") | |
else: | |
label = raw_label | |
file_metadata.write(label + "\n") | |
print("z = ", z) | |
# define the model without training | |
sess = tf.InteractiveSession() | |
embedding = tf.Variable(placeholder, trainable=False, name='g2x_metadata') | |
tf.global_variables_initializer().run() | |
saver = tf.train.Saver() | |
writer = tf.summary.FileWriter(output_path, sess.graph) | |
# adding into projector | |
config = projector.ProjectorConfig() | |
embed = config.embeddings.add() | |
embed.tensor_name = 'g2x_metadata' | |
embed.metadata_path = meta_file | |
# Specify the width and height of a single thumbnail. | |
projector.visualize_embeddings(writer, config) | |
saver.save(sess, os.path.join(output_path, 'g2x_metadata.ckpt')) | |
print('Num nodes: {}'.format(num_lines)) | |
print('Run `tensorboard --logdir={0}` to run visualize result on tensorboard'.format(output_path)) | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This script generates the tsv file for you right here: https://gist.github.com/nlothian/0cd4540389f7091717ece6f4b89b6604#file-fasttext_to_tensorboard-py-L59
It requires the trained vector file (called ft_model.vec) in the same directory as the script.
(This is really a pretty rough script I created for my own use. But I hope you find it helpful)