Skip to content

Instantly share code, notes, and snippets.

@shubhamagarwal92
Created July 25, 2017 08:35
Show Gist options
  • Save shubhamagarwal92/731f471d5a88a8bdf4d3313b9cfb570b to your computer and use it in GitHub Desktop.
Save shubhamagarwal92/731f471d5a88a8bdf4d3313b9cfb570b to your computer and use it in GitHub Desktop.
Generate top k predictions from beam search (tf-seq2seq)
#! /usr/bin/env python
# based on https://github.com/google/seq2seq/blob/master/bin/tools/generate_beam_viz.py
import numpy as np
import networkx as nx
import pickle
import argparse
import os
def _add_graph_level(graph, level, parent_ids, names, scores):
"""Adds a levelto the passed graph"""
for i, parent_id in enumerate(parent_ids):
new_node = (level, i)
parent_node = (level - 1, parent_id)
graph.add_node(new_node)
graph.node[new_node]["name"] = names[i]
graph.node[new_node]["score"] = str(scores[i])
graph.node[new_node]["size"] = 100
# Add an edge to the parent
graph.add_edge(parent_node, new_node)
def create_graph(predicted_ids, parent_ids, scores, vocab=None):
    """Build a directed tree from one sentence's beam-search output.

    Node ``(0, 0)`` is the root, named ``"START"``. Node ``(t + 1, i)`` is beam
    slot ``i`` after decoding step ``t``. Edges point from parent to child.

    Args:
        predicted_ids: per-step token ids. ``predicted_ids[t]`` holds the id
            chosen by each beam slot at step ``t``. It must expose ``.shape``
            (e.g. a numpy array); ``shape[0]`` is the number of steps.
        parent_ids: per-step parent slot indices, aligned with predicted_ids.
        scores: per-step scores, aligned with predicted_ids.
        vocab: optional sequence mapping token id -> token string. When it is
            falsy, node names are the ids rendered with ``str``.

    Returns:
        A ``networkx.DiGraph`` holding the beam tree.
    """
    def node_name(token_id):
        return vocab[token_id] if vocab else str(token_id)

    graph = nx.DiGraph()
    for step in range(predicted_ids.shape[0]):
        names = [node_name(token_id) for token_id in predicted_ids[step]]
        _add_graph_level(graph, step + 1, parent_ids[step], names, scores[step])
    # add_node both creates the root (needed when there are zero steps) and sets
    # its name. The old ``graph.node[...]`` accessor is gone in networkx >= 2.4.
    graph.add_node((0, 0), name="START")
    return graph
def get_path_to_root(graph, node):
    """Return the node names on the path from *node* up to the root of *graph*.

    The first element is *node*'s own name and the last is the root's. Each name
    is cut at its first tab, so an entry such as "token<TAB>extra" contributes
    only "token".

    The walk is iterative, so very long beams cannot hit Python's recursion
    limit. ``predecessors()`` is materialised with ``list()`` because it returns
    an iterator in networkx >= 2.0.

    Raises:
        ValueError: if a node on the path has more than one predecessor (the
            graph is expected to be a tree).
    """
    names = []
    while True:
        names.append(graph.nodes[node]['name'].split('\t')[0])
        parents = list(graph.predecessors(node))
        if len(parents) > 1:
            raise ValueError(
                "node %r has %d predecessors; expected a tree" % (node, len(parents)))
        if not parents:
            return names
        node = parents[0]
def _load_vocab(path):
    """Return the vocabulary tokens in *path* (one per line) plus special tokens.

    "UNK", "SEQUENCE_START" and "SEQUENCE_END" are appended after the file's
    entries, in that order. They therefore take the three ids following the
    last line. This presumably mirrors the model's id assignment (confirm
    against the training vocabulary setup).
    """
    with open(path, encoding="utf-8") as handle:
        tokens = [line.replace("\n", "") for line in handle]
    return tokens + ["UNK", "SEQUENCE_START", "SEQUENCE_END"]


def _ranked_hypotheses(graph):
    """Return the finished hypotheses in a beam *graph*, highest score first.

    A hypothesis ends at a ``SEQUENCE_END`` node whose parent is not itself
    ``SEQUENCE_END``. Hypotheses whose token path already contains an earlier
    ``SEQUENCE_END`` are discarded.

    Returns:
        A list of ``(tokens, score, prob)``. ``tokens`` is the tuple of names
        between ``START`` and the final ``SEQUENCE_END``. ``score`` is the float
        score stored on the end node, and ``prob`` is ``exp(score)``. The list
        is empty when no hypothesis finished.
    """
    end_nodes = set()
    for pos, attrs in graph.nodes(data=True):
        # in_degree/next(iter(...)) work whether predecessors() returns a list
        # (networkx 1.x) or an iterator (networkx >= 2.0).
        if attrs.get("name") != "SEQUENCE_END" or graph.in_degree(pos) == 0:
            continue
        parent = next(iter(graph.predecessors(pos)))
        if graph.nodes[parent].get("name") != "SEQUENCE_END":
            end_nodes.add(pos)

    hypotheses = []
    for pos in end_nodes:
        # The path runs end -> START: drop both markers and restore reading order.
        tokens = tuple(get_path_to_root(graph, pos)[1:-1][::-1])
        if "SEQUENCE_END" in tokens:
            continue
        hypotheses.append((tokens, float(graph.nodes[pos]["score"])))
    if not hypotheses:
        return []

    hypotheses.sort(key=lambda item: item[1], reverse=True)
    probs = np.exp(np.array([score for _, score in hypotheses]))
    return [(tokens, score, prob)
            for (tokens, score), prob in zip(hypotheses, probs)]


def main(data, vocab, top_k, nBestPath, scoresPath, sourceReadFilePath,
         sourceWriteFilePath, indexFilePath):
    """Write the top-k beam-search hypotheses for every row of the beam archive.

    Row i of the archive is paired with line i of the source file. For each
    row, the following files are written:
      * nBestPath: up to ``top_k`` hypotheses, one per line, best score first.
        Tokens are joined with no separator (this suits character-level
        vocabularies; NOTE(review): confirm for word-level ones).
      * scoresPath: "score<TAB>exp(score)" for each written hypothesis. The
        driver script calls this file logScores.txt, suggesting the scores are
        log-probabilities (not verified here).
      * sourceWriteFilePath: source line i, repeated once per written
        hypothesis, so that it lines up with nBestPath.
      * indexFilePath: the number of hypotheses written for row i (possibly 0),
        so every row keeps its own line.

    Args:
        data: path to the ``.npz`` archive with "predicted_ids",
            "beam_parent_ids" and "scores" entries, assumed to hold one row per
            source line.
        vocab: path to the vocabulary file, one token per line.
        top_k: maximum number of hypotheses kept per row.
        nBestPath, scoresPath, sourceWriteFilePath, indexFilePath: output paths.
        sourceReadFilePath: path to the source sentences, one per line.

    Raises:
        ValueError: if no vocabulary path is given, or the source file has fewer
            lines than the archive has rows.
    """
    if vocab is None:
        raise ValueError("a vocabulary file is required to decode predicted ids")
    vocab_tokens = _load_vocab(vocab)

    # NOTE(review): if the archive stores object (ragged) arrays, numpy >= 1.16.3
    # requires np.load(..., allow_pickle=True). Only enable that for trusted files.
    with np.load(data) as beam_data:
        rows = list(zip(beam_data["predicted_ids"],
                        beam_data["beam_parent_ids"],
                        beam_data["scores"]))

    with open(sourceReadFilePath, encoding="utf-8") as source_in, \
            open(sourceWriteFilePath, "w", encoding="utf-8") as source_out, \
            open(scoresPath, "w", encoding="utf-8") as score_out, \
            open(nBestPath, "w", encoding="utf-8") as nbest_out, \
            open(indexFilePath, "w", encoding="utf-8") as index_out:
        for row_i, (predicted_ids, parent_ids, scores) in enumerate(rows):
            # Consume exactly one source line per row, even when the row yields
            # no hypotheses, so every output stays aligned with the source.
            source_line = source_in.readline()
            if not source_line:
                raise ValueError(
                    "source file %r ended before beam row %d" % (sourceReadFilePath, row_i))
            if not source_line.endswith("\n"):
                source_line += "\n"  # last line without newline must not run on

            graph = create_graph(predicted_ids=predicted_ids,
                                 parent_ids=parent_ids,
                                 scores=scores,
                                 vocab=vocab_tokens)
            best = _ranked_hypotheses(graph)[:top_k]

            index_out.write(str(len(best)) + "\n")
            print(row_i)
            for tokens, score, prob in best:
                nbest_out.write("".join(tokens) + "\n")
                score_out.write(str(score) + "\t" + str(prob) + "\n")
                source_out.write(source_line)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Generate beam search top k")
    parser.add_argument(
        "-d", "--data", type=str, required=True,
        help="path to the beam search data file")
    parser.add_argument(
        "-k", "--top_k", type=int, required=True,
        help="number of top k to take")
    # main() needs every path below, so they are required. argparse then
    # reports a clear usage error instead of main() failing on a missing path.
    parser.add_argument(
        "-v", "--vocab", type=str, required=True,
        help="path to the vocabulary file")
    parser.add_argument(
        "-b", "--nBestPath", type=str, required=True,
        help="path to the nBest file")
    parser.add_argument(
        "-sc", "--scoresPath", type=str, required=True,
        help="path to the score file")
    parser.add_argument(
        "-sr", "--sourceReadFilePath", type=str, required=True,
        help="path to the source file")
    parser.add_argument(
        "-sw", "--sourceWriteFilePath", type=str, required=True,
        help="path to the source write file")
    parser.add_argument(
        "-n", "--indexFilePath", type=str, required=True,
        help="path to the index file (hypothesis count per source line)")
    args = parser.parse_args()
    main(args.data, args.vocab, args.top_k, args.nBestPath, args.scoresPath,
         args.sourceReadFilePath, args.sourceWriteFilePath, args.indexFilePath)
# Example driver for generate_beam_topk.py. Fill in the empty values below, or
# export them in the calling environment before running this script.
source activate venv
export CODE_DIR=${CODE_DIR:-}
export PROJECT_HOME=${PROJECT_HOME:-}
export OUTPUT_DIR=${OUTPUT_DIR:-}
# Abort early instead of silently resolving paths against "/" when a required
# location was left empty.
: "${PROJECT_HOME:?PROJECT_HOME is not set}"
: "${OUTPUT_DIR:?OUTPUT_DIR is not set}"
export DATA_HOME=$PROJECT_HOME/data
export VOCAB_DIR=$DATA_HOME/Vocab
export BEAM_FILE=beams_20.npz
# vocab target file
export VOCAB_FILE=$VOCAB_DIR/vocab.target.txt
export TOP_K=10
# input source sentences and output files
export SOURCE_READ_FILE=$DATA_HOME/source.txt
export NBEST_FILE=$OUTPUT_DIR/predictions_nbest.txt
export SOURCE_WRITE_FILE=$OUTPUT_DIR/source_repeated.txt
export SCORES_FILE=$OUTPUT_DIR/logScores.txt
export INDEX_FILE=$OUTPUT_DIR/indices.txt
# The Python script opens its outputs for writing but does not create folders.
mkdir -p "$OUTPUT_DIR"
# When CODE_DIR is set, the script is taken from there; otherwise it is taken
# from the current directory. Quoting keeps paths containing spaces intact.
python "${CODE_DIR:+$CODE_DIR/}generate_beam_topk.py" \
    --data="$BEAM_FILE" \
    --vocab="$VOCAB_FILE" \
    --top_k="$TOP_K" \
    --nBestPath="$NBEST_FILE" \
    --scoresPath="$SCORES_FILE" \
    --sourceReadFilePath="$SOURCE_READ_FILE" \
    --sourceWriteFilePath="$SOURCE_WRITE_FILE" \
    --indexFilePath="$INDEX_FILE"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment