{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "sCe7schSzJE9" }, "source": [ "Jul/29/2023 by Yoshihisa Nitta
\n", "Mar/24/2024 by Yoshihisa Nitta
\n", "Google Colab 対応コード
" ] }, { "cell_type": "markdown", "metadata": { "id": "qaf96rA2EHFx" }, "source": [ "[本ノートブックの実行に関する注意点]\n", "
    \n", "
  1. 最初に LOAD_MODELFalse に設定してノートブック全体を実行すること。学習したマリオのモデルが Google Drive 上に保存される。
  2. \n", "
  3. 以後、LOAD_MODELTrue に設定してノートブック全体を何度も実行すること。以前に訓練したマリオのモデルをロードして追加学習し、Google Drive 上のモデルを更新する。\n", "
  4. \n", "
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Cz0X4xmbMA66" }, "outputs": [], "source": [ "LOAD_MODEL = True # Set to *False* for initial training and *True* for additional training." ] }, { "cell_type": "markdown", "metadata": { "id": "45lkAWCm6haN" }, "source": [ "
\n", "\n", "# PyTorch: Reinforcement Learning (PyTorch を用いた強化学習)\n", "# SuperMario with D-DQN (D-DQN によるスーパーマリオの強化学習)\n", "\n", "
\n", "\n", "スーパーマリオの操作を D-DQN (Double Deep Q-Network) を用いて強化学習する。\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "qY2hQD0hZJ06" }, "source": [ "

\n", "このページは、以下の URL に示す \"PyTorch の公式チュートリアルの強化学習(Mario)\" の Web ページの内容を元に、nitta が自分で改変したものである。\n", "

\n", "\n", "
\n",
    "Train A Mario-Playing RL Agent\n",
    "Authors: Yuansong Feng, Suraj Subramanian, Howard Wang, Steven Guo.\n",
    "
\n", "\n", "https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html\n", "\n", "
\n", "\n", "

\n", "モデルの保存 (checkpoint) では、項目が1変数 (mario.curr_step) 増えている。\n", "元のページの github から checkpoint ファイルをダウンロードして使う場合は、注意すること。\n", "

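\n", "
\n", "
For reference, a minimal sketch of inspecting such a checkpoint (the file name matches the save location used later in this notebook, but the relative path here is only an example):\n",
    "\n",
    "import torch\n",
    "\n",
    "ckpt = torch.load(\"checkpoints/mario_net.chkpt\", map_location=\"cpu\")\n",
    "print(list(ckpt.keys()))   # expected here: model, exploration_rate, curr_step\n",
    "# a checkpoint from the original tutorial contains no curr_step entry\n",
    "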
" ] }, { "cell_type": "markdown", "metadata": { "id": "RWhjBCI5B6qD" }, "source": [ "
\n", "\n", "# Google Colab\n", "\n", "
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 2091, "status": "ok", "timestamp": 1711991150688, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "u2Tw6NeZzlUL", "outputId": "bb347e6e-b6af-4bff-96c2-8f160cf88a79" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" ] } ], "source": [ "# by nitta\n", "import os\n", "is_colab = 'google.colab' in str(get_ipython()) # for Google Colab\n", "\n", "if is_colab:\n", " from google.colab import drive\n", " drive.mount('/content/drive')\n", " SAVE_PREFIX='/content/drive/MyDrive/PyTorch/ReinforcementLearning'\n", "else:\n", " SAVE_PREFIX='.'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 11589, "status": "ok", "timestamp": 1711991162273, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "05DpcF_A9q0m", "outputId": "20b9193b-4097-4fe1-e66b-b494c18724e3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4 packages can be upgraded. Run 'apt list --upgradable' to see them.\n", "The following packages have been kept back:\n", " libcudnn8 libcudnn8-dev libnccl-dev libnccl2\n", "0 upgraded, 0 newly installed, 0 to remove and 4 not upgraded.\n", "xvfb is already the newest version (2:21.1.4-2ubuntu1.7~22.04.8).\n", "0 upgraded, 0 newly installed, 0 to remove and 4 not upgraded.\n" ] } ], "source": [ "# packages\n", "if is_colab:\n", " ! apt update -qq\n", " ! apt upgrade -qq\n", " ! apt install -qq xvfb\n", " ! pip -q install pyvirtualdisplay\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import matplotlib.patches as patches\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MhNeJEhn0Act" }, "outputs": [], "source": [ "# by nitta\n", "# Show multiple images as an animation (Colab compatible)\n", "%matplotlib notebook\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from matplotlib import animation\n", "from IPython import display\n", "import os\n", "import matplotlib\n", "matplotlib.rcParams['animation.embed_limit'] = 2**128\n", "\n", "def display_frames_as_anim(frames, filepath=None, html5=False):\n", " \"\"\"\n", " Displays a list of frames as a gif, with controls\n", " \"\"\"\n", " H, W, _ = frames[0].shape\n", " fig, ax = plt.subplots(1, 1, figsize=(W/100.0, H/100.0))\n", " ax.axis('off')\n", " patch = plt.imshow(frames[0])\n", "\n", " def animate(i):\n", " display.clear_output(wait=True)\n", " patch.set_data(frames[i])\n", " return patch\n", "\n", " anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=30, repeat=False)\n", "\n", " if not filepath is None:\n", " dpath, fname = os.path.split(filepath)\n", " if dpath != '' and not os.path.exists(dpath):\n", " os.makedirs(dpath)\n", " anim.save(filepath)\n", "\n", " if is_colab:\n", " if html5:\n", " display.display(display.HTML(anim.to_html5_video()))\n", " else:\n", " display.display(display.HTML(anim.to_jshtml()))\n", " #plt.close()\n", " else:\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "eccT5ygFCDF7" }, "source": [ "
\n", "\n", "# OpenAI Gym: SuperMarioBros\n", "\n", "
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 1064, "status": "ok", "timestamp": 1711991163332, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "1cLWeYH28qxy", "outputId": "ac7d4cd2-a508-4fd4-b405-a67d9bd6c986" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[31mERROR: Could not find a version that satisfies the requirement gym-super-mario-bros== (from versions: 0.4.0, 0.4.1, 0.4.2, 0.5.0, 0.5.1, 0.6.0, 0.6.1, 0.6.2, 0.6.3, 0.6.4, 0.7.0, 0.8.0, 0.8.1, 0.9.0, 0.9.1, 0.9.2, 0.9.3, 0.9.4, 0.9.5, 0.10.0, 0.10.1, 0.10.2, 0.10.3, 0.10.4, 0.11.0, 0.11.1, 0.11.2, 0.11.3, 0.11.4, 0.11.5, 0.12.0, 0.13.0, 0.13.1, 0.14.0, 0.14.1, 0.14.2, 0.14.3, 0.15.0, 0.15.1, 0.15.2, 0.15.3, 0.15.4, 0.15.5, 0.15.6, 1.0.0, 1.0.1, 1.1.0, 1.1.1, 1.1.2, 1.1.3, 2.0.0, 2.1.0, 2.1.1, 2.1.2, 2.1.3, 2.2.0, 2.2.1, 2.3.0, 2.3.1, 3.0.0, 3.0.1, 3.0.2, 3.0.3, 3.0.4, 3.0.5, 3.0.6, 3.0.7, 3.0.8, 3.1.0, 3.1.1, 3.1.2, 4.0.0, 4.0.1, 4.0.2, 4.1.0, 5.0.0, 5.0.1, 6.0.1, 6.0.2, 6.0.3, 6.0.4, 7.0.0, 7.0.1, 7.1.0, 7.1.1, 7.1.2, 7.1.3, 7.1.5, 7.1.6, 7.2.1, 7.2.3, 7.3.0, 7.3.1, 7.3.2, 7.3.3, 7.4.0)\u001b[0m\u001b[31m\n", "\u001b[0m\u001b[31mERROR: No matching distribution found for gym-super-mario-bros==\u001b[0m\u001b[31m\n", "\u001b[0m" ] } ], "source": [ "# マリオゲームの使用可能なバージョンを調べる\n", "# エラーが表示されるが、気にしない事。\n", "! pip install gym-super-mario-bros==" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 4661, "status": "ok", "timestamp": 1711991167988, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "-jSudvF99D80", "outputId": "afeec919-b24b-4e47-b950-ba8d544964c2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: gym-super-mario-bros==7.4.0 in /usr/local/lib/python3.10/dist-packages (7.4.0)\n", "Requirement already satisfied: nes-py>=8.1.4 in /usr/local/lib/python3.10/dist-packages (from gym-super-mario-bros==7.4.0) (8.2.1)\n", "Requirement already satisfied: gym>=0.17.2 in /usr/local/lib/python3.10/dist-packages (from nes-py>=8.1.4->gym-super-mario-bros==7.4.0) (0.25.2)\n", "Requirement already satisfied: numpy>=1.18.5 in /usr/local/lib/python3.10/dist-packages (from nes-py>=8.1.4->gym-super-mario-bros==7.4.0) (1.25.2)\n", "Requirement already satisfied: pyglet<=1.5.21,>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from nes-py>=8.1.4->gym-super-mario-bros==7.4.0) (1.5.21)\n", "Requirement already satisfied: tqdm>=4.48.2 in /usr/local/lib/python3.10/dist-packages (from nes-py>=8.1.4->gym-super-mario-bros==7.4.0) (4.66.2)\n", "Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from gym>=0.17.2->nes-py>=8.1.4->gym-super-mario-bros==7.4.0) (2.2.1)\n", "Requirement already satisfied: gym-notices>=0.0.4 in /usr/local/lib/python3.10/dist-packages (from gym>=0.17.2->nes-py>=8.1.4->gym-super-mario-bros==7.4.0) (0.0.8)\n" ] } ], "source": [ "! 
pip install gym-super-mario-bros==7.4.0" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 2036, "status": "ok", "timestamp": 1711991170017, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "eLFNng4Xr1WK", "outputId": "fa2f2e21-c2e2-45bc-db2b-e8834a56d7a7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[31mERROR: Could not find a version that satisfies the requirement tensordict== (from versions: 0.0.1a0, 0.0.1b0, 0.0.1rc0, 0.0.2a0, 0.0.2b0, 0.0.3, 0.1.0, 0.1.1, 0.1.2, 0.2.0, 0.2.1, 0.3.0, 0.3.1)\u001b[0m\u001b[31m\n", "\u001b[0m\u001b[31mERROR: No matching distribution found for tensordict==\u001b[0m\u001b[31m\n", "\u001b[0m" ] } ], "source": [ "# PyTorch の強化学習用ライブラリのバージョンを調べる\n", "# エラーが表示されるが、気にしない事。\n", "! pip install tensordict==" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 4324, "status": "ok", "timestamp": 1711991174339, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "u1EORSkzs2Af", "outputId": "229a9fd3-29b1-4ba2-8d40-6d445ef4a590" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: tensordict==0.3.1 in /usr/local/lib/python3.10/dist-packages (0.3.1)\n", "Requirement already satisfied: torch>=2.2.1 in /usr/local/lib/python3.10/dist-packages (from tensordict==0.3.1) (2.2.1+cu121)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from tensordict==0.3.1) (1.25.2)\n", "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.10/dist-packages (from tensordict==0.3.1) (2.2.1)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (3.13.3)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (4.10.0)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (3.1.3)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (2023.6.0)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (12.1.105)\n", "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (8.9.2.26)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (12.1.3.1)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in 
/usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (11.0.2.54)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (10.3.2.106)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (11.4.5.107)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (12.1.0.106)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (2.19.3)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (12.1.105)\n", "Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.2.1->tensordict==0.3.1) (2.2.0)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=2.2.1->tensordict==0.3.1) (12.4.99)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.2.1->tensordict==0.3.1) (2.1.5)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.2.1->tensordict==0.3.1) (1.3.0)\n" ] } ], "source": [ "! pip install tensordict==0.3.1" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 923, "status": "ok", "timestamp": 1711991175258, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "Wpg2lQ1Ls9Ui", "outputId": "df1ba561-8c2c-4867-eaa7-9baf69ad7fd2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[31mERROR: Could not find a version that satisfies the requirement torchrl== (from versions: 0.0.1a0, 0.0.1b0, 0.0.1rc0, 0.0.2a0, 0.0.3, 0.0.4a0, 0.0.4b0, 0.0.4, 0.0.5, 0.1.0, 0.1.1, 0.2.0, 0.2.1, 0.3.0, 0.3.1)\u001b[0m\u001b[31m\n", "\u001b[0m\u001b[31mERROR: No matching distribution found for torchrl==\u001b[0m\u001b[31m\n", "\u001b[0m" ] } ], "source": [ "# PyTorch の強化学習用ライブラリのバージョンを調べる\n", "# エラーが表示されるが、気にしない事。\n", "! 
pip install torchrl==" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 4069, "status": "ok", "timestamp": 1711991179324, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "jlpZI5Zbs9cr", "outputId": "8ed482a6-f702-4fe1-b577-8f56545054f7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: torchrl==0.3.1 in /usr/local/lib/python3.10/dist-packages (0.3.1)\n", "Requirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from torchrl==0.3.1) (2.2.1+cu121)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchrl==0.3.1) (1.25.2)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from torchrl==0.3.1) (24.0)\n", "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.10/dist-packages (from torchrl==0.3.1) (2.2.1)\n", "Requirement already satisfied: tensordict>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from torchrl==0.3.1) (0.3.1)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (3.13.3)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (4.10.0)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (3.1.3)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (2023.6.0)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (12.1.105)\n", "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (8.9.2.26)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (12.1.3.1)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (11.0.2.54)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (10.3.2.106)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (11.4.5.107)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (12.1.0.106)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) 
(2.19.3)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (12.1.105)\n", "Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->torchrl==0.3.1) (2.2.0)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=2.1.0->torchrl==0.3.1) (12.4.99)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->torchrl==0.3.1) (2.1.5)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->torchrl==0.3.1) (1.3.0)\n" ] } ], "source": [ "!pip install torchrl==0.3.1" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KEKYgbiBMhcg" }, "outputs": [], "source": [ "import gym\n", "import gym_super_mario_bros\n", "\n", "ENV = 'SuperMarioBros-1-1-v0'\n", "#ENV = 'SuperMarioBros-1-1-v3'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 3456, "status": "ok", "timestamp": 1711991183630, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "97D0NBlc6aYr", "outputId": "4e049059-8d6e-4450-ae9e-42c9bfc4f5b7" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", " and should_run_async(code)\n" ] } ], "source": [ "import torch\n", "import torchvision\n", "import PIL\n", "import pathlib\n", "import numpy as np\n", "import collections\n", "import random, datetime, os, copy" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 10, "status": "ok", "timestamp": 1711991183630, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "7Kobsh6RNxqp", "outputId": "1f4877ea-e2eb-43b2-baee-3f19443447a5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.25.2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. 
Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", " and should_run_async(code)\n" ] } ], "source": [ "print(gym.__version__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 10, "status": "ok", "timestamp": 1711991183631, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "MXh2IReY5XZA", "outputId": "acfa7adb-8b95-4ade-d2d6-c2347f094fcb" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/gym/envs/registration.py:593: UserWarning: \u001b[33mWARN: The environment SuperMarioBros-1-1-v0 is out of date. You should consider upgrading to version `v3`.\u001b[0m\n", " logger.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "(240, 256, 3)\n", "Discrete(5)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/gym/core.py:317: DeprecationWarning: \u001b[33mWARN: Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n", " deprecation(\n" ] } ], "source": [ "# 本来のマリオゲームにおけるobservation は、画面そのもの(numpy.ndarray)。\n", "# 強化学習時の observation は、以前は「画面そのもの」だったが、Mar/25/2024時点では LazyFrames クラスのオブジェクトに変更されている\n", "\n", "# NES Emulator for OpenAI Gym\n", "from nes_py.wrappers import JoypadSpace\n", "from gym_super_mario_bros.actions import RIGHT_ONLY\n", "\n", "if gym.__version__ < '0.26':\n", " env = gym_super_mario_bros.make(ENV, new_step_api=True)\n", "else:\n", " env = gym_super_mario_bros.make(ENV, render_mode='rgb', apply_api_compatibility=True)\n", "\n", "# action-space\n", "# 0: walk right\n", "# 1: jump right\n", "#env = JoypadSpace(env, [[\"right\"], [\"right\", \"A\"]])\n", "env = JoypadSpace(env, RIGHT_ONLY)\n", "\n", "observation = env.reset()\n", "\n", "print(type(observation))\n", "print(observation.shape)\n", "print(env.action_space)" ] }, { "cell_type": "markdown", "metadata": { "id": "Wnm4yq-8bYpZ" }, "source": [ "## マリオゲームを実行する(行動はランダム選択)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aNLXWGrWaTzi" }, "outputs": [], "source": [ "np.random.seed(12345)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 8, "status": "ok", "timestamp": 1711991183631, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "xx_D4I3GbPoB", "outputId": "f8fb9dcf-73eb-4cbf-efc7-c23af831e708" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/gym/utils/passive_env_checker.py:227: DeprecationWarning: \u001b[33mWARN: Core environment is written in old step API which returns one bool instead of two. It is recommended to rewrite the environment with new step API. \u001b[0m\n", " logger.deprecation(\n", "/usr/local/lib/python3.10/dist-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. 
(Deprecated NumPy 1.24)\n", " if not isinstance(done, (bool, np.bool8)):\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "101\n" ] } ], "source": [ "observation = env.reset()\n", "frames = [observation.copy()] # [注意] copy() してから保存しないと、全部同じ画像になってしまう\n", "\n", "for _ in range(100):\n", " action = env.action_space.sample() # np.random.choice(2) # 0: walk left, 1: jump\n", " observation, reward, done, trunc, info = env.step(action) # needs action from DNN\n", "\n", " frames.append(observation.copy())\n", "\n", " if done or trunc or info[\"flag_get\"]:\n", " break;\n", "\n", "print(len(frames))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 338 }, "executionInfo": { "elapsed": 6954, "status": "ok", "timestamp": 1711991190579, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "BY-NaoIHbWFS", "outputId": "ea6d0531-abb0-4c9e-beb3-538aded5fcca" }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "
\n", " \n", "
\n", " \n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
\n", "
\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# アニメーション表示 by nitta\n", "display_frames_as_anim(frames, 'mario_video0.mp4')\n", "\n", "if is_colab: # copy to google drive\n", " ! mkdir -p {SAVE_PREFIX}\n", " ! cp mario_video0.mp4 {SAVE_PREFIX} # copy to the Google Drive" ] }, { "cell_type": "markdown", "metadata": { "id": "9OizCuuUeR1M" }, "source": [ "### 実行の様子を動画で表示する (HTML)\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "6aZbLITjMX1T" }, "source": [ "
\n", "\n", "# 強化学習のクラス定義 (Reinforcement Learning Definitions)\n", "\n", "
\n", "\n", "
    \n", "
  1. 環境 (Environment) ... エージェント が相互作用し、学習をする世界
  2. \n", "
  3. 行動 (Action) $a$ ... エージェント環境 に対応する方法。すべての取り得る行動の集合は 行動空間 (action-space) と呼ばれる。
  4. \n", "
  5. 状態 (State) $s$ ... 環境 の現在の状況。環境 が取り得る全ての状態の集合は状態空間 (state-space) と呼ばれる。
  6. \n", "
  7. 報酬 (reward) $r$ ... 報酬環境エージェント にフィードバックする鍵である。これによりエージェント は学習し、将来の行動 を変更する。複数の時間ステップに渡って得られた報酬 の総和は利得 (Return) と呼ばれる。
  8. \n", "
  9. 最適行動値関数 (OPtimal Action-Value function) $Q^{*}(s, a)$ ... 状態 $s$ において任意の行動 $a$ を取ったのち、将来の時間ステップにおいて利得を最大化する行動を取るときの利得の期待値を返す。\n", "$Q$ はある状態における行動の「質」を表すものといえる。この関数を近似する。
  10. \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "GByHsUTL-xwv" }, "source": [ "
\n", "\n", "## 環境 (Environment)\n", "\n", "### 環境: 初期化\n", "\n", "マリオでは、環境は土管 (tube), キノコ (mushroom) やその他のコンポーネントから構成されている。\n", "\n", "マリオが行動をすると、環境は変更された「次の状態」や「報酬」やその他の情報を返す。" ] }, { "cell_type": "markdown", "metadata": { "id": "jel8Jk1vX8l8" }, "source": [ "### 環境: 前処理 (Preprocess Environment)\n", "\n", "環境次の状態next_state 変数でエージェントに返される。\n", "各状態は [3, 240, 250] サイズの配列で表現されている。\n", "状態 は、\n", "パイプの色や空の色など、マリオの行動に関係ない情報も含んでいる。\n", "\n", "\n", "環境から返されたデータを エージェント に渡す前に Wrapper が前処理を行う。\n", "\n", "GrayScaleObservation は、RGB画像をグレースケール画像へ変換するときによく利用される Wrapper である。\n", "必要な情報を失うこと無しに、状態表現のサイズを減らすことができる。\n", "これにより、各状態 のサイズは [1, 240, 256] になる。\n", "\n", "ResizeObservation は各観測を正方形画像へとダウンサンプルし、新しいサイズ [1, 84, 84] に変換する。\n", "\n", "SkipFrame は独自のラッパーで、gym.Wrapper を継承して step() 関数を実装している。\n", "連続したフレームは変化があまりないので、多くの情報を失うことなく n-中間フレームをスキップできる。n番目のフレームはスキップされた各フレームの報酬を累積した報酬和を得る。\n", "\n", "FrameStack は連続したフレームを一つの観測時点にまとめて学習モデルに与えることができるラッパーである。\n", "事前の数フレームでの動作方向に基づいて、\n", "マリオがジャンプしているのか着地しているのかを\n", "識別できる。\n", "\n", "
\n",
    "(H,W,3) x 4 ---> (1, H, W) x 4 --> (1, h, w) x 4 --> (4, h, w)\n",
    "H = 240, W = 256\n",
    "h = w = 84\n",
    "
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 74 }, "executionInfo": { "elapsed": 25, "status": "ok", "timestamp": 1711991190580, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "dFsfLefd4nhC", "outputId": "5bc2be0b-8c0e-4c95-acb5-5083606532ea" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", " and should_run_async(code)\n" ] } ], "source": [ "class SkipFrame(gym.Wrapper):\n", " def __init__(self, env, skip):\n", " \"\"\" skip 枚スキップした後のフレームだけを返す \"\"\"\n", " super().__init__(env)\n", " self._skip = skip\n", " self.original_observations = [] # added by nitta\n", "\n", " def step(self, action):\n", " \"\"\" 行動を繰り返し、報酬和を求める \"\"\"\n", " total_reward = 0.0\n", " self.original_observations = [] # added by nitta\n", " for i in range(self._skip):\n", " # 報酬を累積し、同じ行動を繰り返す\n", " obs, reward, done, trunc, info = self.env.step(action)\n", " self.original_observations.append(obs) # added by nitta\n", " total_reward += reward\n", " if done or trunc:\n", " break\n", " return obs, total_reward, done, trunc, info\n", "\n", " def get_observations(self): # added by nitta\n", " \"\"\" step で skip 枚の画面をまとめて処理しているが、元の画面の配列を返す。 \"\"\"\n", " return self.original_observations;\n", "\n", "class GrayScaleObservation(gym.ObservationWrapper):\n", " def __init__(self, env):\n", " super().__init__(env)\n", " obs_shape = self.observation_space.shape[:2]\n", " self.observation_space = gym.spaces.Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)\n", "\n", " def permute_orientation(self, observation):\n", " # (H, W, C) array --> (C, H, W) tensor に変換する\n", " observation = np.transpose(observation, (2, 0, 1))\n", " observation = torch.tensor(observation.copy(), dtype=torch.float)\n", " return observation\n", "\n", " def observation(self, observation):\n", " observation = self.permute_orientation(observation)\n", " transform = torchvision.transforms.Grayscale()\n", " observation = transform(observation)\n", " return observation\n", "\n", "class ResizeObservation(gym.ObservationWrapper):\n", " def __init__(self, env, shape):\n", " super().__init__(env)\n", " if isinstance(shape, int):\n", " self.shape = (shape, shape)\n", " else:\n", " self.shape = tuple(shape)\n", " obs_shape = self.shape + self.observation_space.shape[2:]\n", " self.observation_space = gym.spaces.Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)\n", "\n", " def observation(self, observation):\n", " transforms = torchvision.transforms.Compose([\n", " torchvision.transforms.Resize(self.shape, antialias=True),\n", " torchvision.transforms.Normalize(0, 255)\n", " ])\n", " observation = transforms(observation).squeeze(0)\n", " return observation" ] }, { "cell_type": "markdown", "metadata": { "id": "i8_yFq-kZEHW" }, "source": [ "### 環境: 前処理済みの環境 (4, 84, 84)\n", "\n", "上の Wrapper を環境に適用すると、最終的にラップされた状態は、4枚の連続するグレースケールフレームがひとつに積み重ねた状態から構成されている。\n", "\n", "マリオがアクションする度に、環境はこの構造の状態で応答する。\n", "その構造は [4, 84, 84] のサイズの3次元配列で表現される。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 110 }, "executionInfo": 
{ "elapsed": 546, "status": "ok", "timestamp": 1711991191102, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "b2bpdSn656z4", "outputId": "7687b0aa-10a4-49a7-bfa6-8aa7c02f0415" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/gym/envs/registration.py:593: UserWarning: \u001b[33mWARN: The environment SuperMarioBros-1-1-v0 is out of date. You should consider upgrading to version `v3`.\u001b[0m\n", " logger.warn(\n", "/usr/local/lib/python3.10/dist-packages/gym/core.py:317: DeprecationWarning: \u001b[33mWARN: Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n", " deprecation(\n" ] } ], "source": [ "if gym.__version__ < '0.26':\n", " env = gym_super_mario_bros.make(ENV, new_step_api=True)\n", "else:\n", " env = gym_super_mario_bros.make(ENV, render_mode='rgb', apply_api_compatibility=True)\n", "\n", "# action-space\n", "# 0: walk right\n", "# 1: jump right\n", "#env = JoypadSpace(env, [[\"right\"], [\"right\", \"A\"]])\n", "env = JoypadSpace(env, RIGHT_ONLY)\n", "\n", "env = SkipFrame(env, skip=4) # 1つのstep()関数で複数の step() を呼び出す\n", "env = GrayScaleObservation(env) # グレースケール画像に変更する\n", "env = ResizeObservation(env, shape=84) # 画面サイズを変更する\n", "if gym.__version__ < '0.26':\n", " env = gym.wrappers.FrameStack(env, num_stack=4, new_step_api=True)\n", "else:\n", " env = gym.wrappers.FrameStack(env, num_stack=4) # 連続したフレームを1つにまとめる" ] }, { "cell_type": "markdown", "metadata": { "id": "oDAZEdzFCYzj" }, "source": [ "
\n", "\n", "## エージェント (Agent)\n", "\n", "\n", "このゲームのエージェントを表現する\n", "Mario クラスを作成した。\n", "\n", "Mario は以下のことができる。\n", "\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mAu7FMmh_wxA" }, "outputs": [], "source": [ "class Mario:\n", " def __init__():\n", " pass\n", "\n", " def act(self, state):\n", " \"状態が与えられると、ε-greedy 法にしたがって、行動を選択する\"\n", " pass\n", "\n", " def cache(self, experience):\n", " \"\"\" 経験(experience) を記憶に加える \"\"\"\n", " pass\n", "\n", " def recall(self):\n", " \"記憶から経験のサンプリングを行う\"\n", " pass\n", "\n", " def learn(self):\n", " \"経験データのバッチを用いて、オンライン行動価値 Q 関数を更新する\"\n", " pass" ] }, { "cell_type": "markdown", "metadata": { "id": "b3_i6RxlapDR" }, "source": [ "### エージェント: 行動の選択 (Act)\n", "\n", "各状態において、次の2通りの方法から行動を選択する。\n", "\n", "\n", "self.exploration_rate の値に基づく確率で、explore を選択し、ランダムに行動を選択する。exploit を選択した場合は、最適な行動を提供する MariNet に基づいて行動する。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "48PmUqwaDkUt" }, "outputs": [], "source": [ "# Act\n", "class Mario:\n", " def __init__(self, state_dim, action_dim, save_dir):\n", " self.state_dim = state_dim\n", " self.action_dim = action_dim\n", " self.save_dir = save_dir\n", "\n", " self.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", " # 最適な行動を予測する Mario の DNN (Learn セクションで定義する)\n", " self.net = MarioNet(self.state_dim, self.action_dim).float() # define in Train\n", " self.net = self.net.to(device=self.device)\n", "\n", " self.exploration_rate = 1\n", " self.exploration_rate_decay = 0.99999975\n", " self.exploration_rate_min = 0.1\n", " self.curr_step = 0\n", " self.base_step = 0 # added by nitta\n", "\n", " self.save_every = 5e5 # Mario Net を保存する経験数 (save interval)\n", "\n", " def act(self, state):\n", " \"\"\"\n", " 与えられた状態において、ε-greedy 法で行動を選択し、ステップ値を更新する。\n", " 入力:\n", " state (LazyFrame) : 現在の状態に対する1つの観測。state_dim 次元。\n", " Return:\n", " action_idx (int): Marioが実行した行動のインデックス値\n", " \"\"\"\n", "\n", " if np.random.rand() < self.exploration_rate: # explore\n", " action_idx = np.random.randint(self.action_dim)\n", " else: # exploit\n", " state = state[0].__array__() if isinstance(state,tuple) else state.__array__()\n", " state = torch.tensor(state, device=self.device).unsqueeze(0)\n", " action_values = self.net(state, model=\"online\")\n", " action_idx = torch.argmax(action_values, axis=1).item()\n", "\n", " # exploration_rate を減じる\n", " self.exploration_rate *= self.exploration_rate_decay\n", " self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)\n", "\n", " # step 数を増やす\n", " self.curr_step += 1\n", " return action_idx" ] }, { "cell_type": "markdown", "metadata": { "id": "zTV1SAXBaz_M" }, "source": [ "
\n", "\n", "## 記憶と振り返り (Cache and Recall)\n", "\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2bX3DYW52q5T" }, "outputs": [], "source": [ "# Cache and Recall\n", "import torchrl\n", "import tensordict\n", "\n", "class Mario(Mario):\n", " def __init__(self, state_dim, action_dim, save_dir):\n", " super().__init__(state_dim, action_dim, save_dir)\n", " #self.memory = collections.deque(maxlen=100000)\n", " self.memory = torchrl.data.TensorDictReplayBuffer(\n", " storage=torchrl.data.LazyMemmapStorage(\n", " 100000,\n", " device=torch.device(\"cpu\")\n", " )\n", " )\n", " self.batch_size = 32\n", "\n", " def cache(self, state, next_state, action, reward, done):\n", " \"\"\"\n", " self.memory (リプレイバッファ)に経験を保存する\n", "\n", " Inputs:\n", " state (LazyFrame)\n", " next_state (LazyFrame)\n", " action (int)\n", " reward (float)\n", " done (bool)\n", " \"\"\"\n", "\n", " def first_if_tuple(x):\n", " return x[0] if isinstance(x, tuple) else x\n", "\n", " state = first_if_tuple(state).__array__()\n", " next_state = first_if_tuple(next_state).__array__()\n", "\n", " state = torch.tensor(state)\n", " next_state = torch.tensor(next_state)\n", " action = torch.tensor([action])\n", " reward = torch.tensor([reward])\n", " done = torch.tensor([done])\n", "\n", " # self.memory.append([state, next_state, action, reward, done,])\n", " self.memory.add(\n", " tensordict.TensorDict({\n", " 'state': state,\n", " 'next_state': next_state,\n", " 'action': action,\n", " 'reward': reward,\n", " \"done\" :done\n", " }, batch_size=[])\n", " )\n", "\n", " def recall(self):\n", " \"\"\"\n", " メモリから経験のバッチを取り出す\n", " \"\"\"\n", " batch = self.memory.sample(self.batch_size).to(self.device)\n", " state, next_state, action, reward, done = (batch.get(key) for key in ('state', 'next_state', 'action', 'reward', 'done'))\n", " return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()" ] }, { "cell_type": "markdown", "metadata": { "id": "DxqHSzlga7T5" }, "source": [ "
\n", "\n", "## 学習 (Learn, Train)\n", "\n", "Mario は内部で DDQN アルゴリズムを用いる。\n", "\n", "DDQN は $Q_{online}$ と $Q_{target}$ という2つの畳み込みネットワーク (ConvNets) を使用する。これらはそれぞれ独立して最適行動値関数を近似する。\n", "\n", "本実装では、$Q_{online}$ と $Q_{target}$ では同じ特徴生成器を使うが、全結合層 (FC, Fully Connected) 分類器としては別々に更新される。\n", "$Q_{target}$ のパラメータである $\\theta_{target}$ は逆伝播時には、更新されないように固定される。\n", "その代わり、\n", "定期的に $\\theta_{online}$ と同期される。" ] }, { "cell_type": "markdown", "metadata": { "id": "Tunhj5gvbw3U" }, "source": [ "### 学習: ニューラルネットワーク" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7IzcrS4t8lvZ" }, "outputs": [], "source": [ "# Learn (Train)\n", "# Neural Network\n", "\n", "class MarioNet(torch.nn.Module):\n", " \"\"\"\n", " 単純な CNN 構造。\n", " 入力 -> (conv2d + relu) * 3 -> flatten -> (dense + relu) * 2 -> 出力\n", " \"\"\"\n", " def __init__(self, input_dim, output_dim):\n", " super().__init__()\n", " c, h, w = input_dim\n", "\n", " if h != 84:\n", " raise ValueError(f'Expecting input height: 84, but got {h}')\n", " if w != 84:\n", " raise ValueError(f'Expecting input width: 84, but got {w}')\n", "\n", " self.online = self.__build_cnn(c, output_dim)\n", "\n", " self.target = self.__build_cnn(c, output_dim)\n", " self.target.load_state_dict(self.online.state_dict())\n", "\n", " # Q_target parameters are frozen.\n", " for p in self.target.parameters():\n", " p.requires_grad = False\n", "\n", " def forward(self, input, model):\n", " if model == 'online':\n", " return self.online(input)\n", " elif model == 'target':\n", " return self.target(input)\n", "\n", " def __build_cnn(self, c, output_dim):\n", " return torch.nn.Sequential(\n", " torch.nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),\n", " torch.nn.ReLU(),\n", " torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),\n", " torch.nn.ReLU(),\n", " torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),\n", " torch.nn.ReLU(),\n", " torch.nn.Flatten(),\n", " torch.nn.Linear(3136, 512),\n", " torch.nn.ReLU(),\n", " torch.nn.Linear(512, output_dim),\n", " )" ] }, { "cell_type": "markdown", "metadata": { "id": "hljIczLPb-Bw" }, "source": [ "### 学習: TD Estimate & TD Target\n", "\n", "学習時には2つの値が使われる。(TD = Temporal Difference)\n", "\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cRoHyoge_mjU" }, "outputs": [], "source": [ "# TD Estimate and TD Target\n", "\n", "class Mario(Mario):\n", " def __init__(self, state_dim, action_dim, save_dir):\n", " super().__init__(state_dim, action_dim, save_dir)\n", " self.gamma = 0.9\n", "\n", " def td_estimate(self, state, action):\n", " current_Q = self.net(state, model='online')[\n", " np.arange(0, self.batch_size),\n", " action\n", " ] # Q_online(s, a)\n", " return current_Q\n", "\n", " @torch.no_grad()\n", " def td_target(self, reward, next_state, done):\n", " next_state_Q = self.net(next_state, model='online')\n", " best_action = torch.argmax(next_state_Q, axis=1)\n", " next_Q = self.net(next_state, model='target')[\n", " np.arange(0, self.batch_size),\n", " best_action\n", " ]\n", " return (reward + (1 - done.float()) * self.gamma * next_Q).float()" ] }, { "cell_type": "markdown", "metadata": { "id": "w2l164xydDCN" }, "source": [ "### 学習 : モデルを更新する (Update the model)\n", "\n", "
    \n", "
  1. \n", "Mario はリプレイバッファから入力をサンプリングしながら、$TD_t$ と $TD_{\\epsilon}$ を計算する。
  2. \n", "
  3. そして、この損失を $Q_{online}$ に逆伝播して、パラメータ $\\theta_{online}$ を更新する。
  4. \n", "
\n", "\n", "$\\alpha$ はオプティマイザに渡される学習率 $lr$ である。\n", "\n", "$\\quad \\theta_{online} \\leftarrow \\theta_{online} + \\alpha ~ \\nabla (TD_{\\epsilon} - TD_{t})$
\n", "\n", "$\\theta_{target}$ は逆伝播の間は更新されない。\n", "その代わり、定期的に$\\theta_{online}$ を $\\theta_{target}$ にコピーする。\n", "\n", "$\\quad \\theta_{target} \\leftarrow \\theta_{online}$
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7j7N33f-Hlhg" }, "outputs": [], "source": [ "# Update the model\n", "class Mario(Mario):\n", " def __init__(self, state_dim, action_dim, save_dir):\n", " super().__init__(state_dim, action_dim, save_dir)\n", " self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00025)\n", " self.loss_fn = torch.nn.SmoothL1Loss()\n", "\n", " def update_Q_online(self, td_estimate, td_target):\n", " loss = self.loss_fn(td_estimate, td_target)\n", " self.optimizer.zero_grad()\n", " loss.backward()\n", " self.optimizer.step()\n", " return loss.item()\n", "\n", " def sync_Q_target(self):\n", " self.net.target.load_state_dict(self.net.online.state_dict())" ] }, { "cell_type": "markdown", "metadata": { "id": "C4hdvRs_d3bk" }, "source": [ "### モデルを保存する (チェックポイント)\n", "\n", "元原稿の save() 関数は、MarioNet のモデルと、exploration_rate 変数だけを保存するものであった。\n", "\n", "以下のコードの save() 関数と load() 関数では、curr_step 変数も記録するように変更している。\n", "この変数を使うと、追加で学習したときに通算で何ステップ学習したかがわかる。\n", "\n", "元の save() 関数は save_org() 関数に変更した。また、save_org() 関数で保存したチェックポイントをロードする関数は load_org() 関数として定義した。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zY1GP7IkNaIv" }, "outputs": [], "source": [ "# Save a checkpoint\n", "class Mario(Mario):\n", " def save(self):\n", " '''\n", " save_path = (\n", " self.save_dir / f\"mario_net_{int(self.curr_step // self.save_every)}.chkpt\"\n", " )\n", " torch.save(\n", " dict(model=self.net.state_dict(), exploration_rate=self.exploration_rate),\n", " save_path,\n", " )\n", " '''\n", " # changed by nitta\n", " os.makedirs(self.save_dir, exist_ok=True)\n", " d = dict(\n", " model=self.net.state_dict(),\n", " exploration_rate=self.exploration_rate,\n", " curr_step=self.curr_step,\n", " )\n", " save_path = self.save_dir / f\"mario_net_{self.curr_step}.chkpt\"\n", " if (os.path.isfile(save_path)):\n", " os.remove(save_path)\n", " torch.save(d, save_path)\n", "\n", " save_path2 = self.save_dir / \"mario_net.chkpt\"\n", " if (os.path.isfile(save_path2)):\n", " os.remove(save_path2)\n", " torch.save(d, save_path2) # latest model\n", "\n", " print(f\"MarioNet saved to {save_path} at step {self.curr_step}\")\n", "\n", " def load(self, path): # defined by nitta\n", " checkpoint = torch.load(load_path, map_location=(\n", " 'cuda' if torch.cuda.is_available else 'cpu'\n", " ))\n", " mario.net.load_state_dict(checkpoint['model'])\n", " mario.exploration_rate = checkpoint['exploration_rate']\n", " mario.curr_step = checkpoint['curr_step']\n", " mario.base_step = mario.curr_step\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IJ_iRpydjgCz" }, "outputs": [], "source": [ "class Mario(Mario):\n", " def save_org(self): # 元の save(self) 関数\n", " save_path = (\n", " self.save_dir / f\"mario_net_{int(self.curr_step // self.save_every)}.chkpt\"\n", " )\n", " torch.save(\n", " dict(model=self.net.state_dict(), exploration_rate=self.exploration_rate),\n", " save_path,\n", " )\n", " print(f\"MarioNet saved to {save_path} at step {self.curr_step}\")\n", "\n", " def load_org(self, path): # defined by nitta\n", " checkpoint = torch.load(load_path, map_location=(\n", " 'cuda' if torch.cuda.is_available else 'cpu'\n", " ))\n", " mario.net.load_state_dict(checkpoint['model'])\n", " mario.exploration_rate = checkpoint['exploration_rate']\n", " mario.curr_step = 20000 * 200 # may be\n", " mario.base_step = mario.curr_step" ] }, { "cell_type": "markdown", "metadata": { "id": "fT2MrA97johD" }, "source": [ "### 全部を一つにまとめる" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": { "id": "TZcuAM6nOTQw" }, "outputs": [], "source": [ "# Gather all into one\n", "class Mario(Mario):\n", " def __init__(self, state_dim, action_dim, save_dir):\n", " super().__init__(state_dim, action_dim, save_dir)\n", " self.burnin = 1e4 # 訓練の前の経験の最小値\n", " self.learn_every = 3 # Q_online を更新する間の経験数\n", " self.sync_every = 1e4 # Q_target と Q_online を同期させる間の経験数\n", "\n", " def learn(self):\n", " if self.curr_step % self.sync_every == 0:\n", " self.sync_Q_target()\n", "\n", " if self.curr_step % self.save_every == 0:\n", " self.save()\n", "\n", " if self.curr_step - self.base_step < self.burnin: # changed by nitta (base_step)\n", " return None, None\n", "\n", " if self.curr_step % self.learn_every != 0:\n", " return None, None\n", "\n", " # sampling from memory\n", " state, next_state, action, reward, done = self.recall()\n", "\n", " # Get TD Estimate\n", " td_est = self.td_estimate(state, action)\n", "\n", " # Get TD Target\n", " td_tgt = self.td_target(reward, next_state, done)\n", "\n", " # Backpropagate loss through Q_online\n", " loss = self.update_Q_online(td_est, td_tgt)\n", "\n", " return (td_est.mean().item(), loss)" ] }, { "cell_type": "markdown", "metadata": { "id": "YZmrwigusQMN" }, "source": [ "
\n", "\n", "# Logging" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3PPR9-OzqjW_" }, "outputs": [], "source": [ "import numpy as np\n", "import time, datetime\n", "\n", "class MetricLogger:\n", " def __init__(self, save_dir):\n", " os.makedirs(save_dir, exist_ok=True)\n", " self.save_log = save_dir / \"log\"\n", " with open(self.save_log, \"w\") as f:\n", " f.write(\n", " f\"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}\"\n", " f\"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}\"\n", " f\"{'TimeDelta':>15}{'Time':>20}\\n\"\n", " )\n", " self.ep_rewards_plot = save_dir / \"reward_plot.jpg\"\n", " self.ep_lengths_plot = save_dir / \"length_plot.jpg\"\n", " self.ep_avg_losses_plot = save_dir / \"loss_plot.jpg\"\n", " self.ep_avg_qs_plot = save_dir / \"q_plot.jpg\"\n", "\n", " # history metrics\n", " self.ep_rewards = []\n", " self.ep_lengths = []\n", " self.ep_avg_losses = []\n", " self.ep_avg_qs = []\n", "\n", " # moving averages, added for every call to record()\n", " self.moving_avg_ep_rewards = []\n", " self.moving_avg_ep_lengths = []\n", " self.moving_avg_ep_avg_losses = []\n", " self.moving_avg_ep_avg_qs = []\n", "\n", " # current episode metric\n", " self.init_episode()\n", "\n", " # timing\n", " self.record_time = time.time()\n", "\n", " def log_step(self, reward, loss, q):\n", " self.curr_ep_reward += reward\n", " self.curr_ep_length += 1\n", " if loss:\n", " self.curr_ep_loss += loss\n", " self.curr_ep_q += q\n", " self.curr_ep_loss_length += 1\n", "\n", " def log_episode(self):\n", " \"Mark end of episode\"\n", " self.ep_rewards.append(self.curr_ep_reward)\n", " self.ep_lengths.append(self.curr_ep_length)\n", " if self.curr_ep_loss_length == 0:\n", " ep_avg_loss = 0\n", " ep_avg_q = 0\n", " else:\n", " ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)\n", " ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)\n", " self.ep_avg_losses.append(ep_avg_loss)\n", " self.ep_avg_qs.append(ep_avg_q)\n", "\n", " self.init_episode()\n", "\n", " def init_episode(self):\n", " self.curr_ep_reward = 0.0\n", " self.curr_ep_length = 0\n", " self.curr_ep_loss = 0.0\n", " self.curr_ep_q = 0.0\n", " self.curr_ep_loss_length = 0\n", "\n", " def record(self, episode, epsilon, step):\n", " mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)\n", " mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)\n", " mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)\n", " mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)\n", " self.moving_avg_ep_rewards.append(mean_ep_reward)\n", " self.moving_avg_ep_lengths.append(mean_ep_length)\n", " self.moving_avg_ep_avg_losses.append(mean_ep_loss)\n", " self.moving_avg_ep_avg_qs.append(mean_ep_q)\n", "\n", " last_record_time = self.record_time\n", " self.record_time = time.time()\n", " time_since_last_record = np.round(self.record_time - last_record_time, 3)\n", "\n", " print(\n", " f\"Episode {episode} - \"\n", " f\"Step {step} - \"\n", " f\"Epsilon {epsilon} - \"\n", " f\"Mean Reward {mean_ep_reward} - \"\n", " f\"Mean Length {mean_ep_length} - \"\n", " f\"Mean Loss {mean_ep_loss} - \"\n", " f\"Mean Q Value {mean_ep_q} - \"\n", " f\"Time Delta {time_since_last_record} - \"\n", " f\"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}\"\n", " )\n", " with open(self.save_log, \"a\") as f:\n", " f.write(\n", " f\"{episode:8d}{step:8d}{epsilon:10.3f}\"\n", " f\"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}\"\n", " 
f\"{mean_ep_loss:15.3f}{mean_ep_q:15.3f}\"\n", " f\"{time_since_last_record:15.3f}\"\n", " f\"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\\n\"\n", " )\n", "\n", " for metric in [\"ep_lengths\", \"ep_avg_losses\", \"ep_avg_qs\", \"ep_rewards\"]:\n", " plt.clf()\n", " plt.plot(getattr(self, f\"moving_avg_{metric}\"),label=f\"moving_avg_{metric}\")\n", " plt.legend()\n", " plt.savefig(getattr(self,f\"{metric}_plot\"))" ] }, { "cell_type": "markdown", "metadata": { "id": "lnTz381JHLg3" }, "source": [ "
\n", "\n", "# 強化学習: モデルを訓練(学習)する\n", "\n", "
\n", "\n", "LOAD_MODELFalse の場合は、新しくモデルを生成して Training (学習)を行う。\n", "\n", "LOAD_MODELTrue の場合は、保存されているモデルを読み込んで、追加 Training (学習)を行う。\n", "\n", "\n", "\n", "元の Web ページによれば、\n", "マリオが自分の世界でのやり方を本当に学ぶには、少なくとも 40,000エピソードを繰り返す必要があるとのこと。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 274 }, "executionInfo": { "elapsed": 291499, "status": "ok", "timestamp": 1711991482587, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "ZpzvPXqYH4B2", "outputId": "8bcebabb-3bf6-4bb0-9b59-9fd20767581a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using CUDA: True\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/gym/utils/passive_env_checker.py:227: DeprecationWarning: \u001b[33mWARN: Core environment is written in old step API which returns one bool instead of two. It is recommended to rewrite the environment with new step API. \u001b[0m\n", " logger.deprecation(\n", "/usr/local/lib/python3.10/dist-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. (Deprecated NumPy 1.24)\n", " if not isinstance(done, (bool, np.bool8)):\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Episode 20 - Step 8304027 - Epsilon 0.1257214844542129 - Mean Reward 1220.5 - Mean Length 201.35 - Mean Loss 0.0 - Mean Q Value 0.0 - Time Delta 51.369 - Time 2024-04-01T17:07:17\n", "Episode 40 - Step 8308032 - Epsilon 0.1255956687991323 - Mean Reward 1338.5 - Mean Length 200.8 - Mean Loss 0.0 - Mean Q Value 0.0 - Time Delta 50.368 - Time 2024-04-01T17:08:08\n", "MarioNet saved to /content/drive/MyDrive/PyTorch/ReinforcementLearning/checkpoints/mario_net_8310000.chkpt at step 8310000\n", "Episode 60 - Step 8311633 - Epsilon 0.12548265216339902 - Mean Reward 1335.05 - Mean Length 193.883 - Mean Loss 0.696 - Mean Q Value 14.699 - Time Delta 50.085 - Time 2024-04-01T17:08:58\n", "Episode 80 - Step 8316233 - Epsilon 0.12533843003897904 - Mean Reward 1413.488 - Mean Length 202.912 - Mean Loss 1.287 - Mean Q Value 31.291 - Time Delta 70.559 - Time 2024-04-01T17:10:08\n", "MarioNet saved to /content/drive/MyDrive/PyTorch/ReinforcementLearning/checkpoints/mario_net_8320000.chkpt at step 8320000\n", "Episode 100 - Step 8321006 - Epsilon 0.12518895913444447 - Mean Reward 1425.01 - Mean Length 210.06 - Mean Loss 1.611 - Mean Q Value 41.079 - Time Delta 73.172 - Time 2024-04-01T17:11:22\n" ] } ], "source": [ "from pathlib import Path\n", "\n", "use_cuda = torch.cuda.is_available()\n", "print(f\"Using CUDA: {use_cuda}\\n\")\n", "\n", "save_dir = Path(SAVE_PREFIX + \"/checkpoints\")\n", "#os.makedirs(save_dir, exist_ok=True)\n", "\n", "mario = Mario(state_dim=(4, 84, 84), action_dim=env.action_space.n, save_dir=save_dir)\n", "logger = MetricLogger(save_dir / datetime.datetime.now().strftime(\"%Y-%m-%dT%H-%M-%S\"))\n", "\n", "### Load the pre-trained Model of Mario\n", "if LOAD_MODEL:\n", " load_path = Path(SAVE_PREFIX + \"/checkpoints\") / \"mario_net.chkpt\"\n", " mario.load(load_path)\n", "###\n", "\n", "mario.save_every = 10000 # 20000\n", "\n", "episodes = 100 # 2000\n", "for e in range(episodes):\n", " state = env.reset()\n", " # Play the game\n", " while True:\n", " # 状態の中でエージェントを動かす\n", " action = mario.act(state)\n", " # エージェントは行動する\n", " next_state, reward, done, trunc, info = env.step(action)\n", " # 記憶する\n", " mario.cache(state, next_state, action, 
reward, done)\n", " # 訓練する\n", " q, loss = mario.learn()\n", "\n", " # ログに記録する\n", " logger.log_step(reward, loss, q)\n", "\n", " # 状態を更新する\n", " state = next_state\n", "\n", " if done or info[\"flag_get\"] or trunc:\n", " break\n", "\n", " logger.log_episode()\n", "\n", " if ((e+1) %20 == 0) or ((e+1) == episodes):\n", " logger.record(episode=(e+1), epsilon=mario.exploration_rate, step=mario.curr_step)" ] }, { "cell_type": "markdown", "metadata": { "id": "HEDxo373R8XG" }, "source": [ "
\n", "\n", "# 実験 by nitta\n", "\n", "
\n", "\n", "\n", "## 説明\n", "\n", "ここで、環境 env から返される観測 observationLazyFrames クラスのインスタンスであり、エージェント mario が見ているゲーム画面を表す。\n", "\n", "observation = env.reset()
\n", "observation, reward, done, trunc, info = env.step(action)
\n", "\n", "LazyFrames クラスのインスタンスは\n", "\\_\\_array\\_\\_() で配列に変換できる。\n", "上書きされないように copy() しておくと安全である。\n", "変換された配列の形式は\n", "(C, H, W) = (4, 84, 84)\n", "であり、4枚のグレースケール画像 ($84 \\times 84$ )である。\n", "C(=4) は時系列順にまとめてグループ化したものであり、ある画面に至る前の数ステップ分の画面があれば、画面内のキャラクタの動きを判定できる。\n", "\n", "\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 92 }, "executionInfo": { "elapsed": 871, "status": "ok", "timestamp": 1711991483443, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "nwqM6NRagbYC", "outputId": "5d04204f-2a3d-495d-f9a9-751b8f38ea2e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", " and should_run_async(code)\n" ] } ], "source": [ "observation = env.reset()\n", "print(type(observation))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "executionInfo": { "elapsed": 9, "status": "ok", "timestamp": 1711991483443, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "NBIPJPtgi6yG", "outputId": "f050796b-59e2-4cd2-818c-4d4876c9b4b8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 84, 84)\n" ] } ], "source": [ "print(observation.__array__().shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "jPEVzp3tvx5A" }, "source": [ "## ゲーム実行の様子の動画 (グレースケール表示)\n", "\n", "エージェントが見ている「環境」を動画にした。\n", "\n", "時刻 $t$ においてエージェントに与えられる「環境」は、\n", "$84 \\times 84$のグレースケール画像を時系列順\n", "($t-3$, $t-2$, $t-1$, $t$)\n", "に4フレーム重ねた\n", "$4 \\times 84 \\times 84$ のデータである。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 54 }, "executionInfo": { "elapsed": 3271, "status": "ok", "timestamp": 1711991510605, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "pnkgvQb0uhYw", "outputId": "8b4cf9f2-280e-473e-8496-1581e8bd7d9a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "215\n", "(84, 84, 4)\n" ] } ], "source": [ "# ゲームを実行してみる\n", "\n", "bak_exploration_rate = mario.exploration_rate\n", "mario.exploration_rate = 0.0\n", "\n", "observation = env.reset() # gym.wrappers.frame_stack.LazyFrames (4,84,84)\n", "\n", "frames = [ observation.__array__().copy().transpose(1,2,0) ] # [注意] copy() してから保存しないと、全部同じ画像になってしまう\n", "\n", "for step in range(2000):\n", " action = mario.act(observation) # np.random.choice(2) # 0: walk left, 1: jump\n", " observation, reward, done, trunc, info = env.step(action) # needs action from DNN\n", " frames.append(observation.__array__().copy().transpose(1,2,0) )\n", "\n", " if done or trunc or info[\"flag_get\"]:\n", " break\n", "\n", "print(len(frames))\n", "print(frames[0].shape)\n", "\n", "mario.exploration_rate = bak_exploration_rate" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 182 }, "executionInfo": { "elapsed": 5524, "status": "ok", "timestamp": 
1711991520037, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "cEJWb5hUQpdK", "outputId": "b4da573f-4cf5-42b4-cfef-1f88b190261f" }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "
\n", " \n", "
\n", " \n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
\n", "
\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# アニメーション表示 by nitta\n", "display_frames_as_anim(frames, 'mario_video1.mp4')\n", "\n", "if is_colab: # copy to google drive\n", " ! mkdir -p {SAVE_PREFIX}\n", " ! cp mario_video1.mp4 {SAVE_PREFIX} # copy to the Google Drive" ] }, { "cell_type": "markdown", "metadata": { "id": "NOM6nsxheLwz" }, "source": [ "### 実行の様子を動画で表示する (HTML)\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "WzZ5WZbKwe_F" }, "source": [ "## ゲーム実行の様子の動画 (カラー表示)\n", "\n", "「エージェント mario 」はサイズが $84\\times 84$ のグレースケール画像に変換されたゲーム画面を、$skip (=4)$ 画面毎に与えられ、行動を選択する。\n", "エージェントは\n", "$skip(=4)$ された画面の間は、同じ行動を選択し続ける。\n", "\n", "[自分へのメモ by nitta]\n", "SkipFrame クラスの step() 関数は $skip (=4)$ 枚ごとに画像を処理しているが、元の画像をリストとして保存するコードを追加した。\n", "また、 保存した元のゲーム画面を返す get_observations() 関数を追加した。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 54 }, "executionInfo": { "elapsed": 6011, "status": "ok", "timestamp": 1711991548925, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "Z2cb6kFaBA4S", "outputId": "7ae0c644-ac6d-4832-df07-9c52f3f22cd7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1579\n", "(240, 256, 3)\n" ] } ], "source": [ "bak_exploration_rate = mario.exploration_rate\n", "mario.exploration_rate = 0.0\n", "\n", "frames = []\n", "\n", "observation = env.reset()\n", "c,h,w = observation.__array__().shape\n", "\n", "for step in range(4000):\n", " action = mario.act(observation)\n", " observation, reward, done, trunc, info = env.step(action)\n", "\n", " for color_obs in env.get_observations():\n", " frames.append(color_obs.copy())\n", "\n", " if done or trunc or info[\"flag_get\"]:\n", " break;\n", "\n", "print(len(frames))\n", "print(frames[0].shape)\n", "\n", "mario.exploration_rate = bak_exploration_rate" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 262 }, "executionInfo": { "elapsed": 32521, "status": "ok", "timestamp": 1711991585389, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "AMKVZjWYtaDM", "outputId": "cf846b28-3239-46af-a353-7049c2520548" }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# アニメーション表示 with HTML5 (data is too many for jstml)\n", "display_frames_as_anim(frames, 'mario_video2.mp4', True) # save to file\n", "\n", "if is_colab: # copy to google drive\n", " ! mkdir -p {SAVE_PREFIX}\n", " ! 
cp mario_video2.mp4 {SAVE_PREFIX} # copy to the Google Drive" ] }, { "cell_type": "markdown", "metadata": { "id": "nv_48uIfvJL1" }, "source": [ "### 実行の様子を動画で表示する (HTML)\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "zoG_fY3W07XE" }, "source": [ "## 最終状態のモデルを保存する" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "executionInfo": { "elapsed": 671, "status": "ok", "timestamp": 1711991652075, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "P3Th7vb6icI9", "outputId": "c0a13696-93f7-4fc5-8198-9de86d29a008" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MarioNet saved to /content/drive/MyDrive/PyTorch/ReinforcementLearning/checkpoints/mario_net_8321903.chkpt at step 8321903\n" ] } ], "source": [ "# 最終状態のモデルを保存する\n", "mario.save()" ] }, { "cell_type": "markdown", "metadata": { "id": "QdOYh3f1VdZ6" }, "source": [ "# 成功例\n", "\n", "6500000 steps の訓練をしたあたりから、たまにマリオがゴールに到達できるようになってきた。\n", "\n", "約8320000 steps 訓練後したモデルでは\n", "exploration_rate = 0.0 で実行して、\n", "数回に1回の割合でゴールまでマリオが進む。\n", "1 episode で約 200 steps 平均実行するので、これは\n", "約 40000 episode 訓練した計算になる。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "gNzpysnFu8gv" }, "source": [ "### 実行の様子を動画で表示する (HTML)\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uxlLVzoNyHtk" }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "authorship_tag": "ABX9TyPp9ylmFinFvVVdrlDYTanb", "gpuType": "T4", "machine_shape": "hm", "provenance": [ { "file_id": "1JCg9VzfqIMLRq-_Sfp-cFKWJoDdGEbV0", "timestamp": 1711989801906 }, { "file_id": "1vL9VANSm6YSUn1PWHoXJ637Z68n2c8CC", "timestamp": 1711680186719 }, { "file_id": "1FMky_9GL1kct9j0w6MC1aI0fnTAFBuzP", "timestamp": 1711546691816 }, { "file_id": "1zQTf5uiBEfSR6jsmMkobqkYqK4HgZKjj", "timestamp": 1711513173624 }, { "file_id": "1YLWD5uZkEXVaUdRFt6XyotozFRxjoOXl", "timestamp": 1711384485189 }, { "file_id": "1uLaN1cvAK6YTzdWM9kZqI1NzDrGFiG6U", "timestamp": 1711352242824 } ] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 1 }