Created
April 23, 2023 03:06
-
-
Save chunhualiao/0dec705a10814b3603f20bd6e4fe5a62 to your computer and use it in GitHub Desktop.
DeepSpeed-Chat.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"machine_shape": "hm", | |
"authorship_tag": "ABX9TyO7g/S92fuy0+K0kosLkokF", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU", | |
"gpuClass": "premium" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/chunhualiao/0dec705a10814b3603f20bd6e4fe5a62/deepspeed-chat.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "kRZ2fVVQBg5e" | |
}, | |
"outputs": [], | |
"source": [ | |
"!nvidia-smi" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install 'deepspeed>=0.9.0'" | |
], | |
"metadata": { | |
"id": "3ofrU86GB0eP" | |
}, | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pwd" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "LhH3UmPVCAXA", | |
"outputId": "dc8c04c6-3dc9-4231-f222-cae8772ce2c5" | |
}, | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"/content\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!git clone https://github.com/microsoft/DeepSpeedExamples.git" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "dlt8fMLZCESB", | |
"outputId": "d02f3285-e5d4-4463-be73-acc3eba3d3e0" | |
}, | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Cloning into 'DeepSpeedExamples'...\n", | |
"remote: Enumerating objects: 6134, done.\u001b[K\n", | |
"remote: Counting objects: 100% (220/220), done.\u001b[K\n", | |
"remote: Compressing objects: 100% (133/133), done.\u001b[K\n", | |
"remote: Total 6134 (delta 105), reused 142 (delta 58), pack-reused 5914\u001b[K\n", | |
"Receiving objects: 100% (6134/6134), 21.83 MiB | 16.27 MiB/s, done.\n", | |
"Resolving deltas: 100% (3292/3292), done.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%cd /content/DeepSpeedExamples/applications/DeepSpeed-Chat/" | |
], | |
"metadata": { | |
"id": "I-sh1NkjCTPG" | |
}, | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pwd\n", | |
"!ls\n", | |
"%cd /content/DeepSpeedExamples/applications/DeepSpeed-Chat/\n", | |
"!pwd" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "ktUfF6-XCgCg", | |
"outputId": "adb00e2f-7dc0-4b9e-fd13-95dac9b414cc" | |
}, | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"/content\n", | |
"'=0.9.0' DeepSpeedExamples sample_data\n", | |
"/content/DeepSpeedExamples/applications/DeepSpeed-Chat\n", | |
"/content/DeepSpeedExamples/applications/DeepSpeed-Chat\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install -r requirements.txt" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "i9Ht-joDCWv1", | |
"outputId": "1002b46c-e6c3-42ae-ee61-cd003d353a95" | |
}, | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Collecting git+https://github.com/huggingface/transformers (from -r requirements.txt (line 7))\n", | |
" Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-05wo4zds\n", | |
" Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers /tmp/pip-req-build-05wo4zds\n", | |
" Resolved https://github.com/huggingface/transformers to commit d04ec99bec8a0b432fc03ed60cea9a1a20ebaf3c\n", | |
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", | |
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", | |
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", | |
"Collecting datasets>=2.8.0\n", | |
" Downloading datasets-2.11.0-py3-none-any.whl (468 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m468.7/468.7 kB\u001b[0m \u001b[31m35.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting sentencepiece>=0.1.97\n", | |
" Downloading sentencepiece-0.1.98-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m77.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: protobuf==3.20.3 in /usr/local/lib/python3.9/dist-packages (from -r requirements.txt (line 3)) (3.20.3)\n", | |
"Collecting accelerate>=0.15.0\n", | |
" Downloading accelerate-0.18.0-py3-none-any.whl (215 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m215.3/215.3 kB\u001b[0m \u001b[31m30.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: torch>=1.12.0 in /usr/local/lib/python3.9/dist-packages (from -r requirements.txt (line 5)) (2.0.0+cu118)\n", | |
"Requirement already satisfied: deepspeed>=0.9.0 in /usr/local/lib/python3.9/dist-packages (from -r requirements.txt (line 6)) (0.9.1)\n", | |
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (1.22.4)\n", | |
"Collecting responses<0.19\n", | |
" Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n", | |
"Collecting multiprocess\n", | |
" Downloading multiprocess-0.70.14-py39-none-any.whl (132 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.9/132.9 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (2023.4.0)\n", | |
"Collecting aiohttp\n", | |
" Downloading aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m75.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting huggingface-hub<1.0.0,>=0.11.0\n", | |
" Downloading huggingface_hub-0.13.4-py3-none-any.whl (200 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m200.1/200.1 kB\u001b[0m \u001b[31m29.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (2.27.1)\n", | |
"Collecting dill<0.3.7,>=0.3.0\n", | |
" Downloading dill-0.3.6-py3-none-any.whl (110 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m110.5/110.5 kB\u001b[0m \u001b[31m19.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (9.0.0)\n", | |
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (4.65.0)\n", | |
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (6.0)\n", | |
"Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (1.5.3)\n", | |
"Collecting xxhash\n", | |
" Downloading xxhash-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m33.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from datasets>=2.8.0->-r requirements.txt (line 1)) (23.1)\n", | |
"Requirement already satisfied: psutil in /usr/local/lib/python3.9/dist-packages (from accelerate>=0.15.0->-r requirements.txt (line 4)) (5.9.5)\n", | |
"Requirement already satisfied: networkx in /usr/local/lib/python3.9/dist-packages (from torch>=1.12.0->-r requirements.txt (line 5)) (3.1)\n", | |
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from torch>=1.12.0->-r requirements.txt (line 5)) (3.1.2)\n", | |
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch>=1.12.0->-r requirements.txt (line 5)) (4.5.0)\n", | |
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.9/dist-packages (from torch>=1.12.0->-r requirements.txt (line 5)) (2.0.0)\n", | |
"Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from torch>=1.12.0->-r requirements.txt (line 5)) (3.11.0)\n", | |
"Requirement already satisfied: sympy in /usr/local/lib/python3.9/dist-packages (from torch>=1.12.0->-r requirements.txt (line 5)) (1.11.1)\n", | |
"Requirement already satisfied: lit in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch>=1.12.0->-r requirements.txt (line 5)) (16.0.1)\n", | |
"Requirement already satisfied: cmake in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch>=1.12.0->-r requirements.txt (line 5)) (3.25.2)\n", | |
"Requirement already satisfied: ninja in /usr/local/lib/python3.9/dist-packages (from deepspeed>=0.9.0->-r requirements.txt (line 6)) (1.11.1)\n", | |
"Requirement already satisfied: hjson in /usr/local/lib/python3.9/dist-packages (from deepspeed>=0.9.0->-r requirements.txt (line 6)) (3.1.0)\n", | |
"Requirement already satisfied: py-cpuinfo in /usr/local/lib/python3.9/dist-packages (from deepspeed>=0.9.0->-r requirements.txt (line 6)) (9.0.0)\n", | |
"Requirement already satisfied: pydantic<2.0.0 in /usr/local/lib/python3.9/dist-packages (from deepspeed>=0.9.0->-r requirements.txt (line 6)) (1.10.7)\n", | |
"Collecting tokenizers!=0.11.3,<0.14,>=0.11.1\n", | |
" Downloading tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m102.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.9/dist-packages (from transformers==4.29.0.dev0->-r requirements.txt (line 7)) (2022.10.31)\n", | |
"Collecting yarl<2.0,>=1.0\n", | |
" Downloading yarl-1.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (269 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m269.3/269.3 kB\u001b[0m \u001b[31m36.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.8.0->-r requirements.txt (line 1)) (23.1.0)\n", | |
"Collecting multidict<7.0,>=4.5\n", | |
" Downloading multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (114 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m114.2/114.2 kB\u001b[0m \u001b[31m18.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hCollecting async-timeout<5.0,>=4.0.0a3\n", | |
" Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)\n", | |
"Collecting aiosignal>=1.1.2\n", | |
" Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)\n", | |
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.8.0->-r requirements.txt (line 1)) (2.0.12)\n", | |
"Collecting frozenlist>=1.1.1\n", | |
" Downloading frozenlist-1.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (158 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m158.8/158.8 kB\u001b[0m \u001b[31m24.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hRequirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets>=2.8.0->-r requirements.txt (line 1)) (2022.12.7)\n", | |
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets>=2.8.0->-r requirements.txt (line 1)) (3.4)\n", | |
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets>=2.8.0->-r requirements.txt (line 1)) (1.26.15)\n", | |
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.9/dist-packages (from jinja2->torch>=1.12.0->-r requirements.txt (line 5)) (2.1.2)\n", | |
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas->datasets>=2.8.0->-r requirements.txt (line 1)) (2022.7.1)\n", | |
"Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas->datasets>=2.8.0->-r requirements.txt (line 1)) (2.8.2)\n", | |
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.9/dist-packages (from sympy->torch>=1.12.0->-r requirements.txt (line 5)) (1.3.0)\n", | |
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.9/dist-packages (from python-dateutil>=2.8.1->pandas->datasets>=2.8.0->-r requirements.txt (line 1)) (1.16.0)\n", | |
"Building wheels for collected packages: transformers\n", | |
" Building wheel for transformers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for transformers: filename=transformers-4.29.0.dev0-py3-none-any.whl size=6973407 sha256=00fd0ea5744e942e9044b8160d46fb0a8b4d578472a10b74ce7a28d74ac7ad25\n", | |
" Stored in directory: /tmp/pip-ephem-wheel-cache-h62gc5dw/wheels/14/a0/7b/8f6b25ba4110aa215fcb8d6aedd6cd4f9b9b6619190999ac2b\n", | |
"Successfully built transformers\n", | |
"Installing collected packages: tokenizers, sentencepiece, xxhash, multidict, frozenlist, dill, async-timeout, yarl, responses, multiprocess, huggingface-hub, aiosignal, transformers, aiohttp, datasets, accelerate\n", | |
"Successfully installed accelerate-0.18.0 aiohttp-3.8.4 aiosignal-1.3.1 async-timeout-4.0.2 datasets-2.11.0 dill-0.3.6 frozenlist-1.3.3 huggingface-hub-0.13.4 multidict-6.0.4 multiprocess-0.70.14 responses-0.18.0 sentencepiece-0.1.98 tokenizers-0.13.3 transformers-4.29.0.dev0 xxhash-3.2.0 yarl-1.9.1\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!python train.py --actor-model facebook/opt-1.3b --reward-model facebook/opt-350m --deployment-type single_gpu" | |
], | |
"metadata": { | |
"id": "YeqzSbJKETkG" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from google.colab import drive\n", | |
"drive.mount('/content/drive')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "HOB1p9IOB9dG", | |
"outputId": "38a23feb-2c77-4e90-82af-41d959ca7e23" | |
}, | |
"execution_count": 33, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Mounted at /content/drive\n" | |
] | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I was able to train DeepSpeed Chat on a single A100 GPU in Google Colab (I am a paying Colab Pro subscriber with access to premium GPUs). I ran into a few minor issues along the way, mostly CUDA out-of-memory errors.
To apply the following patch, save it to a text file such as my_changes.patch, then run:
cd DeepSpeedExamples
patch -p0 < my_changes.patch