{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "sCe7schSzJE9"
},
"source": [
"Jul/29/2023 by Yoshihisa Nitta
\n",
"Mar/24/2024 by Yoshihisa Nitta
\n",
"Google Colab 対応コード
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qaf96rA2EHFx"
},
"source": [
"[本ノートブックの実行に関する注意点]\n",
"
LOAD_MODEL
を False
に設定してノートブック全体を実行すること。学習したマリオのモデルが Google Drive 上に保存される。LOAD_MODEL
を True
に設定してノートブック全体を何度も実行すること。以前に訓練したマリオのモデルをロードして追加学習し、Google Drive 上のモデルを更新する。\n",
"\n", "このページは、以下の URL に示す \"PyTorch の公式チュートリアルの強化学習(Mario)\" の Web ページの内容を元に、nitta が自分で改変したものである。\n", "
\n", "\n", "\n", "Train A Mario-Playing RL Agent\n", "Authors: Yuansong Feng, Suraj Subramanian, Howard Wang, Steven Guo.\n", "\n", "\n", "https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html\n", "\n", "
\n", "モデルの保存 (checkpoint) では、項目が1変数 (mario.curr_step) 増えている。\n", "元のページの github から checkpoint ファイルをダウンロードして使う場合は、注意すること。\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "RWhjBCI5B6qD" }, "source": [ "next_state
変数でエージェントに返される。\n",
"各状態は [3, 240, 250] サイズの配列で表現されている。\n",
"状態 は、\n",
"パイプの色や空の色など、マリオの行動に関係ない情報も含んでいる。\n",
"\n",
"\n",
"環境から返されたデータを エージェント に渡す前に Wrapper が前処理を行う。\n",
"\n",
"GrayScaleObservation は、RGB画像をグレースケール画像へ変換するときによく利用される Wrapper である。\n",
"必要な情報を失うこと無しに、状態表現のサイズを減らすことができる。\n",
"これにより、各状態 のサイズは [1, 240, 256] になる。\n",
"\n",
"ResizeObservation は各観測を正方形画像へとダウンサンプルし、新しいサイズ [1, 84, 84] に変換する。\n",
"\n",
"SkipFrame は独自のラッパーで、gym.Wrapper
を継承して step()
関数を実装している。\n",
"連続したフレームは変化があまりないので、多くの情報を失うことなく n-中間フレームをスキップできる。n番目のフレームはスキップされた各フレームの報酬を累積した報酬和を得る。\n",
"\n",
"FrameStack は連続したフレームを一つの観測時点にまとめて学習モデルに与えることができるラッパーである。\n",
"事前の数フレームでの動作方向に基づいて、\n",
"マリオがジャンプしているのか着地しているのかを\n",
"識別できる。\n",
"\n",
"\n", "(H,W,3) x 4 ---> (1, H, W) x 4 --> (1, h, w) x 4 --> (4, h, w)\n", "H = 240, W = 256\n", "h = w = 84\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 74 }, "executionInfo": { "elapsed": 25, "status": "ok", "timestamp": 1711991190580, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "dFsfLefd4nhC", "outputId": "5bc2be0b-8c0e-4c95-acb5-5083606532ea" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", " and should_run_async(code)\n" ] } ], "source": [ "class SkipFrame(gym.Wrapper):\n", " def __init__(self, env, skip):\n", " \"\"\" skip 枚スキップした後のフレームだけを返す \"\"\"\n", " super().__init__(env)\n", " self._skip = skip\n", " self.original_observations = [] # added by nitta\n", "\n", " def step(self, action):\n", " \"\"\" 行動を繰り返し、報酬和を求める \"\"\"\n", " total_reward = 0.0\n", " self.original_observations = [] # added by nitta\n", " for i in range(self._skip):\n", " # 報酬を累積し、同じ行動を繰り返す\n", " obs, reward, done, trunc, info = self.env.step(action)\n", " self.original_observations.append(obs) # added by nitta\n", " total_reward += reward\n", " if done or trunc:\n", " break\n", " return obs, total_reward, done, trunc, info\n", "\n", " def get_observations(self): # added by nitta\n", " \"\"\" step で skip 枚の画面をまとめて処理しているが、元の画面の配列を返す。 \"\"\"\n", " return self.original_observations;\n", "\n", "class GrayScaleObservation(gym.ObservationWrapper):\n", " def __init__(self, env):\n", " super().__init__(env)\n", " obs_shape = self.observation_space.shape[:2]\n", " self.observation_space = gym.spaces.Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)\n", "\n", " def permute_orientation(self, observation):\n", " # (H, W, C) array --> (C, H, W) tensor に変換する\n", " observation = np.transpose(observation, (2, 0, 1))\n", " observation = torch.tensor(observation.copy(), dtype=torch.float)\n", " return observation\n", "\n", " def observation(self, observation):\n", " observation = self.permute_orientation(observation)\n", " transform = torchvision.transforms.Grayscale()\n", " observation = transform(observation)\n", " return observation\n", "\n", "class ResizeObservation(gym.ObservationWrapper):\n", " def __init__(self, env, shape):\n", " super().__init__(env)\n", " if isinstance(shape, int):\n", " self.shape = (shape, shape)\n", " else:\n", " self.shape = tuple(shape)\n", " obs_shape = self.shape + self.observation_space.shape[2:]\n", " self.observation_space = gym.spaces.Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)\n", "\n", " def observation(self, observation):\n", " transforms = torchvision.transforms.Compose([\n", " torchvision.transforms.Resize(self.shape, antialias=True),\n", " torchvision.transforms.Normalize(0, 255)\n", " ])\n", " observation = transforms(observation).squeeze(0)\n", " return observation" ] }, { "cell_type": "markdown", "metadata": { "id": "i8_yFq-kZEHW" }, "source": [ "### 環境: 前処理済みの環境 (4, 84, 84)\n", "\n", "上の Wrapper を環境に適用すると、最終的にラップされた状態は、4枚の連続するグレースケールフレームがひとつに積み重ねた状態から構成されている。\n", "\n", "マリオがアクションする度に、環境はこの構造の状態で応答する。\n", "その構造は [4, 84, 84] のサイズの3次元配列で表現される。" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 110 }, "executionInfo": { "elapsed": 546, "status": "ok", "timestamp": 1711991191102, "user": { "displayName": "Yoshihisa Nitta", "userId": "15888006800030996813" }, "user_tz": -540 }, "id": "b2bpdSn656z4", "outputId": "7687b0aa-10a4-49a7-bfa6-8aa7c02f0415" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/gym/envs/registration.py:593: UserWarning: \u001b[33mWARN: The environment SuperMarioBros-1-1-v0 is out of date. You should consider upgrading to version `v3`.\u001b[0m\n", " logger.warn(\n", "/usr/local/lib/python3.10/dist-packages/gym/core.py:317: DeprecationWarning: \u001b[33mWARN: Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n", " deprecation(\n" ] } ], "source": [ "if gym.__version__ < '0.26':\n", " env = gym_super_mario_bros.make(ENV, new_step_api=True)\n", "else:\n", " env = gym_super_mario_bros.make(ENV, render_mode='rgb', apply_api_compatibility=True)\n", "\n", "# action-space\n", "# 0: walk right\n", "# 1: jump right\n", "#env = JoypadSpace(env, [[\"right\"], [\"right\", \"A\"]])\n", "env = JoypadSpace(env, RIGHT_ONLY)\n", "\n", "env = SkipFrame(env, skip=4) # 1つのstep()関数で複数の step() を呼び出す\n", "env = GrayScaleObservation(env) # グレースケール画像に変更する\n", "env = ResizeObservation(env, shape=84) # 画面サイズを変更する\n", "if gym.__version__ < '0.26':\n", " env = gym.wrappers.FrameStack(env, num_stack=4, new_step_api=True)\n", "else:\n", " env = gym.wrappers.FrameStack(env, num_stack=4) # 連続したフレームを1つにまとめる" ] }, { "cell_type": "markdown", "metadata": { "id": "oDAZEdzFCYzj" }, "source": [ "
{
"cell_type": "markdown",
"metadata": {
"id": "oDAZEdzFCYzj"
},
"source": [
"- `act()` ... Choose an action for the current state of the environment according to the policy, and execute it.\n",
"- Remember ... Store experiences.\n",
"  An experience is (current_state, current_action, reward, next_state).\n",
"  To update its action policy, Mario stores experiences (cache) and looks back at them later (recall).\n",
"- `learn()` ... Learn a better policy. With a probability based on the value of `self.exploration_rate`, Mario explores and chooses an action at random; when it exploits instead, it acts according to `MarioNet`, which provides the optimal action."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "48PmUqwaDkUt"
},
"outputs": [],
"source": [
"# Act\n",
"class Mario:\n",
" def __init__(self, state_dim, action_dim, save_dir):\n",
" self.state_dim = state_dim\n",
" self.action_dim = action_dim\n",
" self.save_dir = save_dir\n",
"\n",
" self.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
" # 最適な行動を予測する Mario の DNN (Learn セクションで定義する)\n",
" self.net = MarioNet(self.state_dim, self.action_dim).float() # define in Train\n",
" self.net = self.net.to(device=self.device)\n",
"\n",
" self.exploration_rate = 1\n",
" self.exploration_rate_decay = 0.99999975\n",
" self.exploration_rate_min = 0.1\n",
" self.curr_step = 0\n",
" self.base_step = 0 # added by nitta\n",
"\n",
" self.save_every = 5e5 # Mario Net を保存する経験数 (save interval)\n",
"\n",
" def act(self, state):\n",
" \"\"\"\n",
" 与えられた状態において、ε-greedy 法で行動を選択し、ステップ値を更新する。\n",
" 入力:\n",
" state (LazyFrame) : 現在の状態に対する1つの観測。state_dim 次元。\n",
" Return:\n",
" action_idx (int): Marioが実行した行動のインデックス値\n",
" \"\"\"\n",
"\n",
" if np.random.rand() < self.exploration_rate: # explore\n",
" action_idx = np.random.randint(self.action_dim)\n",
" else: # exploit\n",
" state = state[0].__array__() if isinstance(state,tuple) else state.__array__()\n",
" state = torch.tensor(state, device=self.device).unsqueeze(0)\n",
" action_values = self.net(state, model=\"online\")\n",
" action_idx = torch.argmax(action_values, axis=1).item()\n",
"\n",
" # exploration_rate を減じる\n",
" self.exploration_rate *= self.exploration_rate_decay\n",
" self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)\n",
"\n",
" # step 数を増やす\n",
" self.curr_step += 1\n",
" return action_idx"
]
},
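{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `exploration_rate_decay = 0.99999975`, the exploration rate decays very slowly. A rough, illustrative calculation of how many steps it takes to fall from 1.0 to `exploration_rate_min = 0.1` (using only the constants defined above):\n",
"\n",
"```python\n",
"import math\n",
"\n",
"# Solve 0.99999975 ** n = 0.1 for n (illustrative calculation).\n",
"decay = 0.99999975\n",
"n = math.log(0.1) / math.log(decay)\n",
"print(round(n))   # roughly 9.2 million steps\n",
"```"
]
},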
{
"cell_type": "markdown",
"metadata": {
"id": "zTV1SAXBaz_M"
},
"source": [
"cache()
... 行動を選択する度に、その経験(state, next_state, action, reward, done)を記憶する。recall()
... メモリからランダムに過去の経験をバッチサイズ個サンプリングし、学習する。\n",
"現在の報酬と、次の状態 $s'$ における推測された $Q^{*}$ の合計は
\n",
"$\\quad a' = \\mbox{argmax}_{a} Q_{online}(s',a)$
\n",
"$\\quad \\displaystyle TD_{t} = r + \\gamma ~ Q^{*}_{target}(s',a')$
\n",
"
\n", "次の行動 $a'$ はわからないので、次状態 $s'$ において $Q_{online}$ を最大化する行動 $a'$ を使う。\n", "
\n", "\n",
"ここで勾配計算を無効にするために、\n",
"td_target()
に対して\n",
"@torch_grad()
修飾子を使うことに注意する必要がある。\n",
"($\\theta_{target}$ に対して逆伝播する必要がないので)\n",
"
save()
関数は、MarioNet
のモデルと、exploration_rate
変数だけを保存するものであった。\n",
"\n",
"以下のコードの save()
関数と load()
関数では、curr_step
変数も記録するように変更している。\n",
"この変数を使うと、追加で学習したときに通算で何ステップ学習したかがわかる。\n",
"\n",
"元の save()
関数は save_org()
関数に変更した。また、save_org()
関数で保存したチェックポイントをロードする関数は load_org()
関数として定義した。"
]
},
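{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the `td_estimate()` / `td_target()` pair that the formulas above describe, following the original tutorial. It assumes `self.batch_size`, `self.gamma`, and a `MarioNet` with `online` and `target` models, all of which this notebook defines in other cells:\n",
"\n",
"```python\n",
"class Mario(Mario):   # sketch only; the notebook's real definitions live elsewhere\n",
"    def td_estimate(self, state, action):\n",
"        # Q_online(s, a) for the actions that were actually taken\n",
"        current_Q = self.net(state, model=\"online\")[\n",
"            np.arange(0, self.batch_size), action\n",
"        ]\n",
"        return current_Q\n",
"\n",
"    @torch.no_grad()\n",
"    def td_target(self, reward, next_state, done):\n",
"        # a' = argmax_a Q_online(s', a);  TD = r + gamma * Q_target(s', a')\n",
"        next_state_Q = self.net(next_state, model=\"online\")\n",
"        best_action = torch.argmax(next_state_Q, axis=1)\n",
"        next_Q = self.net(next_state, model=\"target\")[\n",
"            np.arange(0, self.batch_size), best_action\n",
"        ]\n",
"        return (reward + (1 - done.float()) * self.gamma * next_Q).float()\n",
"```"
]
},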
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zY1GP7IkNaIv"
},
"outputs": [],
"source": [
"# Save a checkpoint\n",
"class Mario(Mario):\n",
" def save(self):\n",
" '''\n",
" save_path = (\n",
" self.save_dir / f\"mario_net_{int(self.curr_step // self.save_every)}.chkpt\"\n",
" )\n",
" torch.save(\n",
" dict(model=self.net.state_dict(), exploration_rate=self.exploration_rate),\n",
" save_path,\n",
" )\n",
" '''\n",
" # changed by nitta\n",
" os.makedirs(self.save_dir, exist_ok=True)\n",
" d = dict(\n",
" model=self.net.state_dict(),\n",
" exploration_rate=self.exploration_rate,\n",
" curr_step=self.curr_step,\n",
" )\n",
" save_path = self.save_dir / f\"mario_net_{self.curr_step}.chkpt\"\n",
" if (os.path.isfile(save_path)):\n",
" os.remove(save_path)\n",
" torch.save(d, save_path)\n",
"\n",
" save_path2 = self.save_dir / \"mario_net.chkpt\"\n",
" if (os.path.isfile(save_path2)):\n",
" os.remove(save_path2)\n",
" torch.save(d, save_path2) # latest model\n",
"\n",
" print(f\"MarioNet saved to {save_path} at step {self.curr_step}\")\n",
"\n",
" def load(self, path): # defined by nitta\n",
" checkpoint = torch.load(load_path, map_location=(\n",
" 'cuda' if torch.cuda.is_available else 'cpu'\n",
" ))\n",
" mario.net.load_state_dict(checkpoint['model'])\n",
" mario.exploration_rate = checkpoint['exploration_rate']\n",
" mario.curr_step = checkpoint['curr_step']\n",
" mario.base_step = mario.curr_step\n"
]
},
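{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check, the checkpoint written by `save()` above can be inspected like this (a sketch; the path shown is hypothetical and should match your own `SAVE_PREFIX`):\n",
"\n",
"```python\n",
"import torch\n",
"from pathlib import Path\n",
"\n",
"# Hypothetical path; adjust to your own checkpoints directory.\n",
"ckpt_path = Path(\"checkpoints\") / \"mario_net.chkpt\"\n",
"ckpt = torch.load(ckpt_path, map_location=\"cpu\")\n",
"print(list(ckpt.keys()))   # ['model', 'exploration_rate', 'curr_step']\n",
"print(ckpt[\"exploration_rate\"], ckpt[\"curr_step\"])\n",
"```"
]
},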
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IJ_iRpydjgCz"
},
"outputs": [],
"source": [
"class Mario(Mario):\n",
" def save_org(self): # 元の save(self) 関数\n",
" save_path = (\n",
" self.save_dir / f\"mario_net_{int(self.curr_step // self.save_every)}.chkpt\"\n",
" )\n",
" torch.save(\n",
" dict(model=self.net.state_dict(), exploration_rate=self.exploration_rate),\n",
" save_path,\n",
" )\n",
" print(f\"MarioNet saved to {save_path} at step {self.curr_step}\")\n",
"\n",
" def load_org(self, path): # defined by nitta\n",
" checkpoint = torch.load(load_path, map_location=(\n",
" 'cuda' if torch.cuda.is_available else 'cpu'\n",
" ))\n",
" mario.net.load_state_dict(checkpoint['model'])\n",
" mario.exploration_rate = checkpoint['exploration_rate']\n",
" mario.curr_step = 20000 * 200 # may be\n",
" mario.base_step = mario.curr_step"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fT2MrA97johD"
},
"source": [
"### 全部を一つにまとめる"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TZcuAM6nOTQw"
},
"outputs": [],
"source": [
"# Gather all into one\n",
"class Mario(Mario):\n",
" def __init__(self, state_dim, action_dim, save_dir):\n",
" super().__init__(state_dim, action_dim, save_dir)\n",
" self.burnin = 1e4 # 訓練の前の経験の最小値\n",
" self.learn_every = 3 # Q_online を更新する間の経験数\n",
" self.sync_every = 1e4 # Q_target と Q_online を同期させる間の経験数\n",
"\n",
" def learn(self):\n",
" if self.curr_step % self.sync_every == 0:\n",
" self.sync_Q_target()\n",
"\n",
" if self.curr_step % self.save_every == 0:\n",
" self.save()\n",
"\n",
" if self.curr_step - self.base_step < self.burnin: # changed by nitta (base_step)\n",
" return None, None\n",
"\n",
" if self.curr_step % self.learn_every != 0:\n",
" return None, None\n",
"\n",
" # sampling from memory\n",
" state, next_state, action, reward, done = self.recall()\n",
"\n",
" # Get TD Estimate\n",
" td_est = self.td_estimate(state, action)\n",
"\n",
" # Get TD Target\n",
" td_tgt = self.td_target(reward, next_state, done)\n",
"\n",
" # Backpropagate loss through Q_online\n",
" loss = self.update_Q_online(td_est, td_tgt)\n",
"\n",
" return (td_est.mean().item(), loss)"
]
},
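{
"cell_type": "markdown",
"metadata": {},
"source": [
"`learn()` relies on `recall()`, `sync_Q_target()`, and `update_Q_online()`, which this notebook defines in earlier cells (not shown here). As a reference, a minimal sketch of the last two, following the original tutorial and assuming `self.optimizer`, `self.loss_fn`, and a `MarioNet` that exposes `online` and `target` sub-networks:\n",
"\n",
"```python\n",
"class Mario(Mario):   # sketch only; the notebook's real definitions live elsewhere\n",
"    def update_Q_online(self, td_estimate, td_target):\n",
"        # Backpropagate the TD error through Q_online\n",
"        loss = self.loss_fn(td_estimate, td_target)\n",
"        self.optimizer.zero_grad()\n",
"        loss.backward()\n",
"        self.optimizer.step()\n",
"        return loss.item()\n",
"\n",
"    def sync_Q_target(self):\n",
"        # Copy the online weights into the frozen target network\n",
"        self.net.target.load_state_dict(self.net.online.state_dict())\n",
"```"
]
},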
{
"cell_type": "markdown",
"metadata": {
"id": "YZmrwigusQMN"
},
"source": [
"LOAD_MODEL
が False
の場合は、新しくモデルを生成して Training (学習)を行う。\n",
"\n",
"LOAD_MODEL
が True
の場合は、保存されているモデルを読み込んで、追加 Training (学習)を行う。\n",
"\n",
"\n",
"\n",
"元の Web ページによれば、\n",
"マリオが自分の世界でのやり方を本当に学ぶには、少なくとも 40,000エピソードを繰り返す必要があるとのこと。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 274
},
"executionInfo": {
"elapsed": 291499,
"status": "ok",
"timestamp": 1711991482587,
"user": {
"displayName": "Yoshihisa Nitta",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "ZpzvPXqYH4B2",
"outputId": "8bcebabb-3bf6-4bb0-9b59-9fd20767581a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using CUDA: True\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/gym/utils/passive_env_checker.py:227: DeprecationWarning: \u001b[33mWARN: Core environment is written in old step API which returns one bool instead of two. It is recommended to rewrite the environment with new step API. \u001b[0m\n",
" logger.deprecation(\n",
"/usr/local/lib/python3.10/dist-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`. (Deprecated NumPy 1.24)\n",
" if not isinstance(done, (bool, np.bool8)):\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode 20 - Step 8304027 - Epsilon 0.1257214844542129 - Mean Reward 1220.5 - Mean Length 201.35 - Mean Loss 0.0 - Mean Q Value 0.0 - Time Delta 51.369 - Time 2024-04-01T17:07:17\n",
"Episode 40 - Step 8308032 - Epsilon 0.1255956687991323 - Mean Reward 1338.5 - Mean Length 200.8 - Mean Loss 0.0 - Mean Q Value 0.0 - Time Delta 50.368 - Time 2024-04-01T17:08:08\n",
"MarioNet saved to /content/drive/MyDrive/PyTorch/ReinforcementLearning/checkpoints/mario_net_8310000.chkpt at step 8310000\n",
"Episode 60 - Step 8311633 - Epsilon 0.12548265216339902 - Mean Reward 1335.05 - Mean Length 193.883 - Mean Loss 0.696 - Mean Q Value 14.699 - Time Delta 50.085 - Time 2024-04-01T17:08:58\n",
"Episode 80 - Step 8316233 - Epsilon 0.12533843003897904 - Mean Reward 1413.488 - Mean Length 202.912 - Mean Loss 1.287 - Mean Q Value 31.291 - Time Delta 70.559 - Time 2024-04-01T17:10:08\n",
"MarioNet saved to /content/drive/MyDrive/PyTorch/ReinforcementLearning/checkpoints/mario_net_8320000.chkpt at step 8320000\n",
"Episode 100 - Step 8321006 - Epsilon 0.12518895913444447 - Mean Reward 1425.01 - Mean Length 210.06 - Mean Loss 1.611 - Mean Q Value 41.079 - Time Delta 73.172 - Time 2024-04-01T17:11:22\n"
]
}
],
"source": [
"from pathlib import Path\n",
"\n",
"use_cuda = torch.cuda.is_available()\n",
"print(f\"Using CUDA: {use_cuda}\\n\")\n",
"\n",
"save_dir = Path(SAVE_PREFIX + \"/checkpoints\")\n",
"#os.makedirs(save_dir, exist_ok=True)\n",
"\n",
"mario = Mario(state_dim=(4, 84, 84), action_dim=env.action_space.n, save_dir=save_dir)\n",
"logger = MetricLogger(save_dir / datetime.datetime.now().strftime(\"%Y-%m-%dT%H-%M-%S\"))\n",
"\n",
"### Load the pre-trained Model of Mario\n",
"if LOAD_MODEL:\n",
" load_path = Path(SAVE_PREFIX + \"/checkpoints\") / \"mario_net.chkpt\"\n",
" mario.load(load_path)\n",
"###\n",
"\n",
"mario.save_every = 10000 # 20000\n",
"\n",
"episodes = 100 # 2000\n",
"for e in range(episodes):\n",
" state = env.reset()\n",
" # Play the game\n",
" while True:\n",
" # 状態の中でエージェントを動かす\n",
" action = mario.act(state)\n",
" # エージェントは行動する\n",
" next_state, reward, done, trunc, info = env.step(action)\n",
" # 記憶する\n",
" mario.cache(state, next_state, action, reward, done)\n",
" # 訓練する\n",
" q, loss = mario.learn()\n",
"\n",
" # ログに記録する\n",
" logger.log_step(reward, loss, q)\n",
"\n",
" # 状態を更新する\n",
" state = next_state\n",
"\n",
" if done or info[\"flag_get\"] or trunc:\n",
" break\n",
"\n",
" logger.log_episode()\n",
"\n",
" if ((e+1) %20 == 0) or ((e+1) == episodes):\n",
" logger.record(episode=(e+1), epsilon=mario.exploration_rate, step=mario.curr_step)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HEDxo373R8XG"
},
"source": [
"env
から返される観測 observation
は LazyFrames
クラスのインスタンスであり、エージェント mario
が見ているゲーム画面を表す。\n",
"\n",
"observation = env.reset()
observation, reward, done, trunc, info = env.step(action)
LazyFrames
クラスのインスタンスは\n",
"\\_\\_array\\_\\_()
で配列に変換できる。\n",
"上書きされないように copy()
しておくと安全である。\n",
"変換された配列の形式は\n",
"(C, H, W) = (4, 84, 84)\n",
"であり、4枚のグレースケール画像 ($84 \\times 84$ )である。\n",
"C(=4) は時系列順にまとめてグループ化したものであり、ある画面に至る前の数ステップ分の画面があれば、画面内のキャラクタの動きを判定できる。\n",
"\n",
"\n",
"\n",
"\n"
]
},
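{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of that conversion (it assumes the wrapped `env` from the cells above):\n",
"\n",
"```python\n",
"obs = env.reset()\n",
"obs = obs[0] if isinstance(obs, tuple) else obs   # newer gym may return (obs, info)\n",
"frame = obs.__array__().copy()                    # copy so later steps cannot overwrite it\n",
"print(frame.shape)                                # (4, 84, 84)\n",
"```"
]
},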
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 92
},
"executionInfo": {
"elapsed": 871,
"status": "ok",
"timestamp": 1711991483443,
"user": {
"displayName": "Yoshihisa Nitta",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "nwqM6NRagbYC",
"outputId": "5d04204f-2a3d-495d-f9a9-751b8f38ea2e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SkipFrame
クラスの step()
関数は $skip (=4)$ 枚ごとに画像を処理しているが、元の画像をリストとして保存するコードを追加した。\n",
"また、 保存した元のゲーム画面を返す get_observations()
関数を追加した。\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 54
},
"executionInfo": {
"elapsed": 6011,
"status": "ok",
"timestamp": 1711991548925,
"user": {
"displayName": "Yoshihisa Nitta",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "Z2cb6kFaBA4S",
"outputId": "7ae0c644-ac6d-4832-df07-9c52f3f22cd7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1579\n",
"(240, 256, 3)\n"
]
}
],
"source": [
"bak_exploration_rate = mario.exploration_rate\n",
"mario.exploration_rate = 0.0\n",
"\n",
"frames = []\n",
"\n",
"observation = env.reset()\n",
"c,h,w = observation.__array__().shape\n",
"\n",
"for step in range(4000):\n",
" action = mario.act(observation)\n",
" observation, reward, done, trunc, info = env.step(action)\n",
"\n",
" for color_obs in env.get_observations():\n",
" frames.append(color_obs.copy())\n",
"\n",
" if done or trunc or info[\"flag_get\"]:\n",
" break;\n",
"\n",
"print(len(frames))\n",
"print(frames[0].shape)\n",
"\n",
"mario.exploration_rate = bak_exploration_rate"
]
},
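{
"cell_type": "markdown",
"metadata": {},
"source": [
"One possible way to turn the collected RGB frames into an animation (a sketch using matplotlib; the notebook's own display cell follows):\n",
"\n",
"```python\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import animation\n",
"\n",
"fig, ax = plt.subplots()\n",
"im = ax.imshow(frames[0])\n",
"ax.axis(\"off\")\n",
"\n",
"def update(i):\n",
"    im.set_data(frames[i])\n",
"    return (im,)\n",
"\n",
"anim = animation.FuncAnimation(fig, update, frames=len(frames), interval=1000/30)\n",
"# In Colab: from IPython.display import HTML; HTML(anim.to_jshtml())\n",
"```"
]
},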
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 262
},
"executionInfo": {
"elapsed": 32521,
"status": "ok",
"timestamp": 1711991585389,
"user": {
"displayName": "Yoshihisa Nitta",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "AMKVZjWYtaDM",
"outputId": "cf846b28-3239-46af-a353-7049c2520548"
},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/plain": [
"