Skip to content

Instantly share code, notes, and snippets.

@chunhualiao
Created April 23, 2023 03:06
Show Gist options
  • Save chunhualiao/0dec705a10814b3603f20bd6e4fe5a62 to your computer and use it in GitHub Desktop.
Save chunhualiao/0dec705a10814b3603f20bd6e4fe5a62 to your computer and use it in GitHub Desktop.
DeepSpeed-Chat.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"machine_shape": "hm",
"authorship_tag": "ABX9TyO7g/S92fuy0+K0kosLkokF",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"gpuClass": "premium"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/chunhualiao/0dec705a10814b3603f20bd6e4fe5a62/deepspeed-chat.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kRZ2fVVQBg5e"
},
"outputs": [],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "code",
"source": [
"!pip install \"deepspeed>=0.9.0\""
],
"metadata": {
"id": "3ofrU86GB0eP"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!pwd"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LhH3UmPVCAXA",
"outputId": "dc8c04c6-3dc9-4231-f222-cae8772ce2c5"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!git clone https://github.com/microsoft/DeepSpeedExamples.git"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dlt8fMLZCESB",
"outputId": "d02f3285-e5d4-4463-be73-acc3eba3d3e0"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cloning into 'DeepSpeedExamples'...\n",
"remote: Enumerating objects: 6134, done.\u001b[K\n",
"remote: Counting objects: 100% (220/220), done.\u001b[K\n",
"remote: Compressing objects: 100% (133/133), done.\u001b[K\n",
"remote: Total 6134 (delta 105), reused 142 (delta 58), pack-reused 5914\u001b[K\n",
"Receiving objects: 100% (6134/6134), 21.83 MiB | 16.27 MiB/s, done.\n",
"Resolving deltas: 100% (3292/3292), done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!cd DeepSpeedExamples/applications/DeepSpeed-Chat/"
],
"metadata": {
"id": "I-sh1NkjCTPG"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!pwd\n",
"!ls\n",
"%cd ./DeepSpeedExamples/applications/DeepSpeed-Chat/\n",
"!pwd"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ktUfF6-XCgCg",
"outputId": "adb00e2f-7dc0-4b9e-fd13-95dac9b414cc"
},
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content\n",
"'=0.9.0' DeepSpeedExamples sample_data\n",
"/content/DeepSpeedExamples/applications/DeepSpeed-Chat\n",
"/content/DeepSpeedExamples/applications/DeepSpeed-Chat\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!pip install -r requirements.txt"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "i9Ht-joDCWv1",
"outputId": "1002b46c-e6c3-42ae-ee61-cd003d353a95"
},
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting git+https://github.com/huggingface/transformers (from -r requirements.txt (line 7))\n",
" Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-05wo4zds\n",
" Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers /tmp/pip-req-build-05wo4zds\n",
" Resolved https://github.com/huggingface/transformers to commit d04ec99bec8a0b432fc03ed60cea9a1a20ebaf3c\n",
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Collecting datasets>=2.8.0\n",
" Downloading datasets-2.11.0-py3-none-any.whl (468 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m468.7/468.7 kB\u001b[0m \u001b[31m35.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting sentencepiece>=0.1.97\n",
" Downloading sentencepiece-0.1.98-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m77.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: protobuf==3.20.3 in /usr/local/lib/python3.9/dist-packages (from -r requirements.txt (line 3)) (3.20.3)\n",
"Collecting accelerate>=0.15.0\n",
" Downloading accelerate-0.18.0-py3-none-any.whl (215 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m215.3/215.3 kB\u001b[0m \u001b[31m30.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: torch>=1.12.0 in /usr/local/lib/python3.9/dist-packages (from -r requirements.txt (line 5)) (2.0.0+cu118)\n",
"Requirement already satisfied: deepspeed>=0.9.0 in /usr/local/lib/python3.9/dist-packages (from -r requirements.txt (line 6)) (0.9.1)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (1.22.4)\n",
"Collecting responses<0.19\n",
" Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n",
"Collecting multiprocess\n",
" Downloading multiprocess-0.70.14-py39-none-any.whl (132 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.9/132.9 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (2023.4.0)\n",
"Collecting aiohttp\n",
" Downloading aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m75.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting huggingface-hub<1.0.0,>=0.11.0\n",
" Downloading huggingface_hub-0.13.4-py3-none-any.whl (200 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m200.1/200.1 kB\u001b[0m \u001b[31m29.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (2.27.1)\n",
"Collecting dill<0.3.7,>=0.3.0\n",
" Downloading dill-0.3.6-py3-none-any.whl (110 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m110.5/110.5 kB\u001b[0m \u001b[31m19.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (9.0.0)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (4.65.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (6.0)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (1.5.3)\n",
"Collecting xxhash\n",
" Downloading xxhash-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m33.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (23.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.9/dist-packages (from accelerate>=0.15.0->-r requirements.txt (line 4)) (5.9.5)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.9/dist-packages (from torch>=1.12.0->-r requirements.txt (line 5)) (3.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from torch>=1.12.0->-r requirements.txt (line 5)) (3.1.2)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch>=1.12.0->-r requirements.txt (line 5)) (4.5.0)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.9/dist-packages (from torch>=1.12.0->-r requirements.txt (line 5)) (2.0.0)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from torch>=1.12.0->-r requirements.txt (line 5)) (3.11.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.9/dist-packages (from torch>=1.12.0->-r requirements.txt (line 5)) (1.11.1)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch>=1.12.0->-r requirements.txt (line 5)) (16.0.1)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch>=1.12.0->-r requirements.txt (line 5)) (3.25.2)\n",
"Requirement already satisfied: ninja in /usr/local/lib/python3.9/dist-packages (from deepspeed>=0.9.0->-r requirements.txt (line 6)) (1.11.1)\n",
"Requirement already satisfied: hjson in /usr/local/lib/python3.9/dist-packages (from deepspeed>=0.9.0->-r requirements.txt (line 6)) (3.1.0)\n",
"Requirement already satisfied: py-cpuinfo in /usr/local/lib/python3.9/dist-packages (from deepspeed>=0.9.0->-r requirements.txt (line 6)) (9.0.0)\n",
"Requirement already satisfied: pydantic<2.0.0 in /usr/local/lib/python3.9/dist-packages (from deepspeed>=0.9.0->-r requirements.txt (line 6)) (1.10.7)\n",
"Collecting tokenizers!=0.11.3,<0.14,>=0.11.1\n",
" Downloading tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m102.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.9/dist-packages (from transformers==4.29.0.dev0->-r requirements.txt (line 7)) (2022.10.31)\n",
"Collecting yarl<2.0,>=1.0\n",
" Downloading yarl-1.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (269 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m269.3/269.3 kB\u001b[0m \u001b[31m36.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.8.0->-r requirements.txt (line 1)) (23.1.0)\n",
"Collecting multidict<7.0,>=4.5\n",
" Downloading multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (114 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m114.2/114.2 kB\u001b[0m \u001b[31m18.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting async-timeout<5.0,>=4.0.0a3\n",
" Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)\n",
"Collecting aiosignal>=1.1.2\n",
" Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.8.0->-r requirements.txt (line 1)) (2.0.12)\n",
"Collecting frozenlist>=1.1.1\n",
" Downloading frozenlist-1.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (158 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m158.8/158.8 kB\u001b[0m \u001b[31m24.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets>=2.8.0->-r requirements.txt (line 1)) (2022.12.7)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets>=2.8.0->-r requirements.txt (line 1)) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets>=2.8.0->-r requirements.txt (line 1)) (1.26.15)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.9/dist-packages (from jinja2->torch>=1.12.0->-r requirements.txt (line 5)) (2.1.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas->datasets>=2.8.0->-r requirements.txt (line 1)) (2022.7.1)\n",
"Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas->datasets>=2.8.0->-r requirements.txt (line 1)) (2.8.2)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.9/dist-packages (from sympy->torch>=1.12.0->-r requirements.txt (line 5)) (1.3.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.9/dist-packages (from python-dateutil>=2.8.1->pandas->datasets>=2.8.0->-r requirements.txt (line 1)) (1.16.0)\n",
"Building wheels for collected packages: transformers\n",
" Building wheel for transformers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for transformers: filename=transformers-4.29.0.dev0-py3-none-any.whl size=6973407 sha256=00fd0ea5744e942e9044b8160d46fb0a8b4d578472a10b74ce7a28d74ac7ad25\n",
" Stored in directory: /tmp/pip-ephem-wheel-cache-h62gc5dw/wheels/14/a0/7b/8f6b25ba4110aa215fcb8d6aedd6cd4f9b9b6619190999ac2b\n",
"Successfully built transformers\n",
"Installing collected packages: tokenizers, sentencepiece, xxhash, multidict, frozenlist, dill, async-timeout, yarl, responses, multiprocess, huggingface-hub, aiosignal, transformers, aiohttp, datasets, accelerate\n",
"Successfully installed accelerate-0.18.0 aiohttp-3.8.4 aiosignal-1.3.1 async-timeout-4.0.2 datasets-2.11.0 dill-0.3.6 frozenlist-1.3.3 huggingface-hub-0.13.4 multidict-6.0.4 multiprocess-0.70.14 responses-0.18.0 sentencepiece-0.1.98 tokenizers-0.13.3 transformers-4.29.0.dev0 xxhash-3.2.0 yarl-1.9.1\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!python train.py --actor-model facebook/opt-1.3b --reward-model facebook/opt-350m --deployment-type single_gpu"
],
"metadata": {
"id": "YeqzSbJKETkG"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HOB1p9IOB9dG",
"outputId": "38a23feb-2c77-4e90-82af-41d959ca7e23"
},
"execution_count": 33,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
]
}
]
}
@chunhualiao
Copy link
Author

chunhualiao commented Apr 23, 2023

I was able to train DeepSpeed Chat using a single A100 GPU from Google Colab (I am a paying Colab Pro subscriber with access to premium GPUs). There were some minor issues I had to deal with; mostly they were related to CUDA out-of-memory errors.

To apply the following patch, save it into a text file like my_changes.patch.

Then
cd DeepSpeedExamples
patch -p0 < my_changes.patch

diff --git applications/DeepSpeed-Chat/training/step1_supervised_finetuning/training_scripts/single_gpu/run_1.3b.sh applications/DeepSpeed-Chat/training/step1_supervised_finetuning/training_scripts/single_gpu/run_1.3b.sh
index 8d2865c..3cb36cd 100644
--- applications/DeepSpeed-Chat/training/step1_supervised_finetuning/training_scripts/single_gpu/run_1.3b.sh
+++ applications/DeepSpeed-Chat/training/step1_supervised_finetuning/training_scripts/single_gpu/run_1.3b.sh
@@ -16,5 +16,5 @@ fi
 mkdir -p $OUTPUT
 
 deepspeed --num_gpus 1 main.py --model_name_or_path facebook/opt-1.3b \
-   --gradient_accumulation_steps 8 --lora_dim 128 --zero_stage $ZERO_STAGE \
+   --gradient_accumulation_steps 8 --gradient_checkpointing --lora_dim 128 --zero_stage $ZERO_STAGE \
    --deepspeed --output_dir $OUTPUT &> $OUTPUT/training.log
diff --git applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/single_gpu/run_350m.sh applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/single_gpu/run_350m.sh
index 435de2c..35ea226 100644
--- applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/single_gpu/run_350m.sh
+++ applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/single_gpu/run_350m.sh
@@ -14,5 +14,5 @@ fi
 mkdir -p $OUTPUT
 
 deepspeed --num_gpus 1 main.py --model_name_or_path facebook/opt-350m \
-   --num_padding_at_beginning 1 --weight_decay 0.1 --disable_dropout --gradient_accumulation_steps 4 --zero_stage $ZERO_STAGE \
+   --num_padding_at_beginning 1 --weight_decay 0.1 --disable_dropout --gradient_checkpointing --gradient_accumulation_steps 4 --zero_stage $ZERO_STAGE \
    --deepspeed --output_dir $OUTPUT &> $OUTPUT/training.log
diff --git applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/single_gpu/run_1.3b.sh applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/single_gpu/run_1.3b.sh
index b33e3ad..d061e6a 100644
--- applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/single_gpu/run_1.3b.sh
+++ applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/single_gpu/run_1.3b.sh
@@ -22,6 +22,6 @@ mkdir -p $OUTPUT
 deepspeed main.py \
    --actor_model_name_or_path $ACTOR_MODEL_PATH --critic_model_name_or_path $CRITIC_MODEL_PATH \
    --actor_zero_stage $ACTOR_ZERO_STAGE --critic_zero_stage $CRITIC_ZERO_STAGE \
-   --num_padding_at_beginning 1 --gradient_accumulation_steps 2 \
+   --num_padding_at_beginning 1 --gradient_accumulation_steps 2 --critic_gradient_checkpointing \
    --deepspeed --actor_lora_dim 128 --enable_hybrid_engine --actor_gradient_checkpointing --disable_actor_dropout \
    --output_dir $OUTPUT &> $OUTPUT/training.log

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment