diff --git a/src/generate_unconditional_samples.py b/src/generate_unconditional_samples.py
index d9e2131..366c411 100755
--- a/src/generate_unconditional_samples.py
+++ b/src/generate_unconditional_samples.py
@@ -9,7 +9,7 @@ import tensorflow as tf
import model, sample, encoder
def sample_model(
- model_name='117M',
+ model_name='1558M',
seed=None,
nsamples=0,
batch_size=1,
@@ -69,7 +69,7 @@ def sample_model(
out = sess.run(output)
for i in range(batch_size):
generated += batch_size
- text = enc.decode(out[i])
+ text = enc.decode(out[i]).encode('utf-8')
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text)
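
The hunks above switch the sampling script's default model to the 1558M checkpoint and convert each decoded sample to UTF-8 bytes before printing (the interactive script below gets the same treatment). A minimal sketch of why the byte conversion helps, not part of the patch; the sample string is hypothetical:

# Hypothetical decoded sample containing non-ASCII characters.
text = u'caf\u00e9 ... sample \u2603'
try:
    print(text)                  # can raise UnicodeEncodeError on a console
except UnicodeEncodeError:       # whose encoding cannot represent the text
    pass
print(text.encode('utf-8'))      # the patched behaviour: never raises; on
                                 # Python 3 it prints the bytes repr (b'...')
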
diff --git a/src/interactive_conditional_samples.py b/src/interactive_conditional_samples.py
index c1650bb..193f5c8 100755
--- a/src/interactive_conditional_samples.py
+++ b/src/interactive_conditional_samples.py
@@ -9,7 +9,7 @@ import tensorflow as tf
import model, sample, encoder
def interact_model(
- model_name='117M',
+ model_name='1558M',
seed=None,
nsamples=1,
batch_size=1,
@@ -64,7 +64,7 @@ def interact_model(
)
saver = tf.train.Saver()
- ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
+ ckpt = tf.train.latest_checkpoint('checkpoint/run1')
saver.restore(sess, ckpt)
while True:
@@ -80,7 +80,7 @@ def interact_model(
})[:, len(context_tokens):]
for i in range(batch_size):
generated += 1
- text = enc.decode(out[i])
+ text = enc.decode(out[i]).encode('utf-8')
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text)
print("=" * 80)
diff --git a/src/memory_saving_gradients.py b/src/memory_saving_gradients.py
index 659691f..9b46e89 100644
--- a/src/memory_saving_gradients.py
+++ b/src/memory_saving_gradients.py
@@ -108,6 +108,7 @@ def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
ts_all = [t for t in ts_all if 'dropout' not in t.name]
# DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
ts_all = [t for t in ts_all if 'Cast' not in t.name]
+ ts_all = [t for t in ts_all if 'SparseSoftmaxCrossEntropyWithLogits' not in t.name]
# filter out all tensors that are inputs of the backward graph
with util.capture_ops() as bwd_ops:
@@ -120,11 +121,12 @@ def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
# try two slightly different ways of getting bottlenecks tensors
# to checkpoint
- for ts in [ts_filtered, ts_all]:
+ # for ts in [ts_filtered, ts_all]:
+ for ts in [ts_filtered]:
# get all bottlenecks in the graph
bottleneck_ts = []
- for t in ts:
+ for i, t in enumerate(ts):
b = set(ge.get_backward_walk_ops(t.op, inclusive=True, within_ops=fwd_ops))
f = set(ge.get_forward_walk_ops(t.op, inclusive=False, within_ops=fwd_ops))
# check that there are not shortcuts
@@ -133,7 +135,12 @@ def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all):
bottleneck_ts.append(t) # we have a bottleneck!
else:
- debug_print("Rejected bottleneck candidate and ops %s", [t] + list(set(ts_all) - set(b_inp) - set(f_inp)))
+ debug_print("%s/%s: Rejected bottleneck candidate and ops %s; %s found of %s",
+ i, len(ts),
+ [t] + list(set(ts_all) - set(b_inp) - set(f_inp)),
+ len(bottleneck_ts),
+ np.sqrt(len(ts_filtered)) * (i / len(ts))
+ )
# success? or try again without filtering?
if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)): # yes, enough bottlenecks found!
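
The memory_saving_gradients.py hunks exclude one more op type from the checkpoint candidates, drop the second (unfiltered) pass of the bottleneck search, and add progress output to what is otherwise a long, silent scan over the 48-layer graph. A standalone sketch of the name-based filtering idea (the tensor names and the _T stand-in are illustrative, not from the repository):

# Illustrative only: drop tensors whose op names mark them as poor
# recomputation checkpoints (dropout masks, FP16 casts, fused softmax losses).
EXCLUDED = ('dropout', 'Cast', 'SparseSoftmaxCrossEntropyWithLogits')

def filter_candidates(tensors):
    return [t for t in tensors if not any(key in t.name for key in EXCLUDED)]

class _T:                                     # minimal stand-in for a tf.Tensor
    def __init__(self, name):
        self.name = name

names = ['model/h0/attn/MatMul:0', 'model/h0/attn/dropout/mul:0', 'loss/Cast:0']
print([t.name for t in filter_candidates([_T(n) for n in names])])
# -> ['model/h0/attn/MatMul:0']
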
diff --git a/src/model.py b/src/model.py
index 4e942d8..71092bc 100644
--- a/src/model.py
+++ b/src/model.py
@@ -124,10 +124,10 @@ def block(x, scope, *, past, hparams):
with tf.variable_scope(scope):
nx = x.shape[-1].value
a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams)
- x = x + a
+ x = x1 = x + a
m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams)
x = x + m
- return x, present
+ return x, present, x1
def past_shape(*, hparams, batch_size=None, sequence=None):
return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, hparams.n_embd // hparams.n_head]
@@ -161,9 +161,9 @@ def model(hparams, X, past=None, scope='model', reuse=tf.AUTO_REUSE):
pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
assert len(pasts) == hparams.n_layer
for layer, past in enumerate(pasts):
- h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
- if layer == 10:
- tf.add_to_collection('checkpoints', h)
+ h, present, x1 = block(h, 'h%d' % layer, past=past, hparams=hparams)
+ if layer < 48:
+ tf.add_to_collection('checkpoints', x1)
presents.append(present)
results['present'] = tf.stack(presents, axis=1)
h = norm(h, 'ln_f')
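
The model.py hunks make every transformer block also return its post-attention residual x1, and register each layer's x1 in the 'checkpoints' graph collection instead of checkpointing only layer 10's output; memory_saving_gradients reads that collection when called with checkpoints='collection'. A minimal TF 1.x sketch of the collection mechanism (the toy two-layer graph and layer width are illustrative, not the real model):

import tensorflow as tf  # TF 1.x graph-mode assumed

x = tf.placeholder(tf.float32, [None, 1600])
for layer in range(2):                                  # 48 layers in the real model
    with tf.variable_scope('h%d' % layer):
        a = tf.layers.dense(x, 1600)                    # stand-in for attn(norm(x))
        x1 = x + a                                      # post-attention residual
        m = tf.layers.dense(tf.nn.relu(x1), 1600)       # stand-in for mlp(norm(x1))
        x = x1 + m
    tf.add_to_collection('checkpoints', x1)             # what the patched block feeds

# Later, memory_saving_gradients.gradients(..., checkpoints='collection')
# recomputes activations between these tensors instead of storing them all:
print(tf.get_collection('checkpoints'))                 # one tensor per layer
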
diff --git a/train.py b/train.py
index 57e4ef9..0cb705e 100755
--- a/train.py
+++ b/train.py
@@ -118,9 +118,9 @@ def main():
train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars
if args.optimizer == 'adam':
- opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
+ opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate / args.batch_size)
elif args.optimizer == 'sgd':
- opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
+ opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate / args.batch_size)
else:
exit('Bad optimizer:', args.optimizer)
@@ -136,7 +136,7 @@ def main():
summary_loss = tf.summary.scalar('loss', opt_apply)
else:
if args.memory_saving_gradients:
- opt_grads = memory_saving_gradients.gradients(loss, train_vars)
+ opt_grads = memory_saving_gradients.gradients(loss, train_vars, checkpoints='collection')
else:
opt_grads = tf.gradients(loss, train_vars)
opt_grads = list(zip(opt_grads, train_vars))
@@ -219,7 +219,7 @@ def main():
tf_sample,
feed_dict={context: args.batch_size * [context_tokens]})
for i in range(min(args.sample_num - index, args.batch_size)):
- text = enc.decode(out[i])
+ text = enc.decode(out[i]).encode('utf-8')
text = '======== SAMPLE {} ========\n{}\n'.format(
index + 1, text)
all_text.append(text)
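
Finally, train.py divides the configured learning rate by the batch size (so batch_size times the effective step size stays equal to the configured rate), passes checkpoints='collection' so the per-layer tensors registered in model.py are actually used, and applies the same UTF-8 conversion to printed samples. The learning-rate change is plain arithmetic; a worked example with hypothetical values for args.learning_rate and args.batch_size:

# Hypothetical values for args.learning_rate and args.batch_size.
learning_rate = 1e-4
batch_size = 4

# Step size handed to AdamOptimizer / GradientDescentOptimizer after the patch:
effective_lr = learning_rate / batch_size
print(effective_lr)       # 2.5e-05
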