Created
July 25, 2017 08:35
-
-
Save shubhamagarwal92/731f471d5a88a8bdf4d3313b9cfb570b to your computer and use it in GitHub Desktop.
Generate top k predictions from beam search (tf-seq2seq)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python
# based on https://github.com/google/seq2seq/blob/master/bin/tools/generate_beam_viz.py
import numpy as np | |
import networkx as nx | |
import pickle | |
import argparse | |
import os | |
def _add_graph_level(graph, level, parent_ids, names, scores):
    """Add one level of nodes to ``graph`` and link each to its parent.

    For every index ``i`` in ``parent_ids`` a node ``(level, i)`` is added
    with the attributes ``name``, ``score`` (stored as a string) and ``size``
    (always 100), plus an edge from ``(level - 1, parent_ids[i])``.

    Args:
        graph: Graph exposing ``add_node(node, **attrs)`` and
            ``add_edge(u, v)``; callers in this file pass ``nx.DiGraph``.
        level: Integer level of the nodes being added.
        parent_ids: Iterable giving, per new node, the index of its parent
            on the previous level.
        names: Indexable of node names, one per new node.
        scores: Indexable of node scores, one per new node.
    """
    for i, parent_id in enumerate(parent_ids):
        new_node = (level, i)
        parent_node = (level - 1, parent_id)
        # Pass attributes through add_node rather than the ``graph.node``
        # mapping: that mapping was removed in networkx 2.4, while
        # ``add_node(n, **attrs)`` is accepted by every networkx release.
        graph.add_node(
            new_node, name=names[i], score=str(scores[i]), size=100)
        # Add an edge to the parent (creates the parent node if missing).
        graph.add_edge(parent_node, new_node)
def create_graph(predicted_ids, parent_ids, scores, vocab=None):
    """Build the beam-search tree for one example as a directed graph.

    Level ``k + 1`` of the graph is built from ``predicted_ids[k]``,
    ``parent_ids[k]`` and ``scores[k]``; the node ``(0, 0)`` is named
    ``"START"``.

    Args:
        predicted_ids: Array with a ``.shape`` attribute; its first axis is
            iterated as levels (presumably ``(seq_length, beam_width)`` --
            confirm against the beam dump).
        parent_ids: Indexable by level; each entry lists parent indices.
        scores: Indexable by level; each entry lists node scores.
        vocab: Optional sequence mapping a predicted id to its token. When
            falsy, the id itself is used as the node name.

    Returns:
        The populated ``nx.DiGraph``.
    """
    def get_node_name(pred):
        return vocab[pred] if vocab else str(pred)

    seq_length = predicted_ids.shape[0]
    graph = nx.DiGraph()
    for level in range(seq_length):
        names = [get_node_name(pred) for pred in predicted_ids[level]]
        _add_graph_level(
            graph, level + 1, parent_ids[level], names, scores[level])
    # add_node updates the attributes of an existing node and creates the
    # node otherwise, so this works for an empty sequence too and does not
    # depend on the ``graph.node`` mapping removed in networkx 2.4.
    graph.add_node((0, 0), name="START")
    return graph
def get_path_to_root(graph, node):
    """Return the node names on the path from ``node`` up to the root.

    The first element belongs to ``node`` itself and the last to the root
    (the first ancestor without predecessors). Only the part of each name
    before the first tab character is kept.

    Requires networkx >= 2.0 (uses the ``graph.nodes[...]`` mapping).

    Args:
        graph: Directed graph in which every node has at most one
            predecessor and a ``name`` attribute.
        node: Node to start from.

    Returns:
        List of name strings, ordered from ``node`` to the root.
    """
    path = []
    current = node
    # Walk iteratively: edges in this file always go from one level to the
    # next, so the walk terminates, and long sequences cannot hit the
    # recursion limit.
    while True:
        # predecessors() returns an iterator on networkx >= 2.0; materialise
        # it so it can be measured and indexed.
        parents = list(graph.predecessors(current))
        assert len(parents) <= 1
        path.append(graph.nodes[current]['name'].split('\t')[0])
        if not parents:
            return path
        current = parents[0]
def main(data, vocab, top_k, nBestPath, scoresPath, sourceReadFilePath,
         sourceWriteFilePath, indexFilePath):
    """Write the top-k beam-search hypotheses of every example to text files.

    Requires networkx >= 2.0 (uses ``graph.nodes`` and iterator-returning
    ``graph.predecessors``).

    Args:
        data: Path passed to ``np.load``; the loaded object must provide the
            keys ``predicted_ids``, ``beam_parent_ids`` and ``scores``.
        vocab: Path to the vocabulary file, one token per line. ``UNK``,
            ``SEQUENCE_START`` and ``SEQUENCE_END`` are appended to it.
        top_k: Maximum number of hypotheses written per example.
        nBestPath: Output file; one hypothesis per line.
        scoresPath: Output file; ``score<TAB>exp(score)`` per hypothesis.
        sourceReadFilePath: Input file with one source line per example.
        sourceWriteFilePath: Output file; the source line repeated once per
            written hypothesis.
        indexFilePath: Output file; number of hypotheses written for each
            example that produced at least one.
    """
    beam_data = np.load(data)

    with open(vocab) as vocab_file:
        vocab = [line.replace("\n", "") for line in vocab_file]
    vocab += ["UNK", "SEQUENCE_START", "SEQUENCE_END"]

    data_iterator = zip(beam_data["predicted_ids"],
                        beam_data["beam_parent_ids"],
                        beam_data["scores"])

    # Context managers close all five files even if a row raises.
    with open(sourceReadFilePath, 'r') as sourceReadFile, \
            open(sourceWriteFilePath, 'w') as sourceWriteFile, \
            open(scoresPath, 'w') as scoreFile, \
            open(nBestPath, 'w') as nBestFile, \
            open(indexFilePath, 'w') as indexFile:
        for row_i, (predicted_ids, parent_ids, scores) in enumerate(
                data_iterator):
            # Consume the source line before any `continue`: otherwise a
            # skipped example shifts every later source/prediction pairing.
            sourceLine = sourceReadFile.readline()

            graph = create_graph(
                predicted_ids=predicted_ids,
                parent_ids=parent_ids,
                scores=scores,
                vocab=vocab)

            # A hypothesis ends at the first SEQUENCE_END on its path, i.e.
            # a SEQUENCE_END node whose parent is not itself SEQUENCE_END.
            pred_end_nodes = set()
            for pos, attrs in graph.nodes.items():
                if attrs.get('name') != 'SEQUENCE_END':
                    continue
                parents = list(graph.predecessors(pos))
                if not parents:
                    continue
                if graph.nodes[parents[0]].get('name') != 'SEQUENCE_END':
                    pred_end_nodes.add(pos)

            # [1:-1] drops the trailing SEQUENCE_END and the START root;
            # [::-1] puts the tokens in reading order.
            result = [(tuple(get_path_to_root(graph, pos)[1:-1][::-1]),
                       float(graph.nodes[pos]['score']))
                      for pos in pred_end_nodes]
            candidates = [item for item in result
                          if 'SEQUENCE_END' not in item[0]]
            if not candidates:
                # Nothing usable for this example (also covers the case
                # where filtering removed every hypothesis).
                continue

            best = sorted(
                candidates, key=lambda x: x[1], reverse=True)[:top_k]
            probs = np.exp(np.array([score for _, score in best]))

            indexFile.write(str(len(best)))
            indexFile.write('\n')
            print(row_i)
            for (path, score), prob in zip(best, probs):
                targetLine = "".join(path)
                nBestFile.write(targetLine)
                nBestFile.write('\n')
                scoreFile.write(str(score) + '\t' + str(prob))
                scoreFile.write('\n')
                sourceWriteFile.write(sourceLine)
if __name__ == '__main__':
    # Command-line entry point. Every path below is opened unconditionally
    # by main(), so all of them are declared required: a missing option is
    # reported by argparse instead of surfacing later as a TypeError from
    # open(None).
    parser = argparse.ArgumentParser(
        description="Generate beam search top k")
    parser.add_argument(
        "-d", "--data", type=str, required=True,
        help="path to the beam search data file")
    parser.add_argument(
        "-k", "--top_k", type=int, required=True,
        help="number of top k to take")
    parser.add_argument(
        "-v", "--vocab", type=str, required=True,
        help="path to the vocabulary file")
    parser.add_argument(
        "-b", "--nBestPath", type=str, required=True,
        help="path to the nBest file")
    parser.add_argument(
        "-sc", "--scoresPath", type=str, required=True,
        help="path to the score file")
    parser.add_argument(
        "-sr", "--sourceReadFilePath", type=str, required=True,
        help="path to the source file")
    parser.add_argument(
        "-sw", "--sourceWriteFilePath", type=str, required=True,
        help="path to the source write file")
    parser.add_argument(
        "-n", "--indexFilePath", type=str, required=True,
        help="path to the index file (number of predictions per example)")
    args = parser.parse_args()
    main(args.data, args.vocab, args.top_k, args.nBestPath, args.scoresPath,
         args.sourceReadFilePath, args.sourceWriteFilePath,
         args.indexFilePath)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Launcher for generate_beam_topk.py.
# Activate the `venv` environment before running the script.
source activate venv

# Fill in the empty locations below before running.
export CODE_DIR=
export PROJECT_HOME=
export DATA_HOME="$PROJECT_HOME/data/"
export VOCAB_DIR="${DATA_HOME}/Vocab"
export OUTPUT_DIR=

# Stop early if the placeholders were left empty; otherwise the paths below
# would silently resolve to /data/... and /... at the filesystem root.
: "${PROJECT_HOME:?set PROJECT_HOME in this script}"
: "${OUTPUT_DIR:?set OUTPUT_DIR in this script}"

export BEAM_FILE=beams_20.npz
# vocab target file
export VOCAB_FILE="$VOCAB_DIR/vocab.target.txt"
export TOP_K=10
# source file (one line per example in the beam dump)
export SOURCE_READ_FILE="$DATA_HOME/source.txt"
# output files
export NBEST_FILE="$OUTPUT_DIR/predictions_nbest.txt"
export SOURCE_WRITE_FILE="$OUTPUT_DIR/source_repeated.txt"
export SCORES_FILE="$OUTPUT_DIR/logScores.txt"
export INDEX_FILE="$OUTPUT_DIR/indices.txt"

# Quote every expansion so paths containing spaces stay single arguments.
python generate_beam_topk.py \
  --data="$BEAM_FILE" \
  --vocab="$VOCAB_FILE" \
  --top_k="$TOP_K" \
  --nBestPath="$NBEST_FILE" \
  --scoresPath="$SCORES_FILE" \
  --sourceReadFilePath="$SOURCE_READ_FILE" \
  --sourceWriteFilePath="$SOURCE_WRITE_FILE" \
  --indexFilePath="$INDEX_FILE"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment