{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "ND3MyDtqk4hX"
},
"source": [
"Updated 19/Nov/2021 by Yoshihisa Nitta \n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fQVCsdk0mByf"
},
"source": [
"# AutoEncoder Training for MNIST dataset with Tensorflow 2 on Google Colab\n",
"## MNISTデータセットに対して AutoEncoder を Google Colab 上で Tensorflow 2 で訓練する"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hHAN1RoasvZb"
},
"outputs": [],
"source": [
"#! pip install tensorflow==2.7.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 2265,
"status": "ok",
"timestamp": 1637562083627,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "MKpI5MGclKv9",
"outputId": "98d2c9ca-77aa-480e-c785-2cc212eeb3f4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.7.0\n"
]
}
],
"source": [
"%tensorflow_version 2.x\n",
"\n",
"import tensorflow as tf\n",
"print(tf.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d9Eo7tUQVTGj"
},
"source": [
"# AutoEncoder\n",
"\n",
"Reference: D. P. Kingma and M. Welling, \"Auto-Encoding Variational Bayes\", arXiv:1312.6114 (2013).\n",
"\n",
"「難しい事後分布を持つ連続潜在変数が存在し、さらにデータセットが大きい場合に、有向確率モデルを用いて効率的な推論と学習をどうやって行えばよいのだろうか」という問題がある。\n",
"この論文では、ある穏やかな(mild)微分可能性の条件下で適用できる、確率的変分推論と学習アルゴリズムを紹介する。\n",
"貢献する点は2点である。\n",
"\n",
"- 「変分下限の再パラメータ化により、普通のSGDを用いてtraining可能な下限推定量が得られる。」\n",
"- 「提案する下限推定量を用いて近似推論モデルを学習することによって、データポイント毎の連続潜在変数を持つ i.i.d データセットに対して事後推論が効率的に実行できる。」\n",
"\n",
"# Download the file from Google Drive or nw.tsuda.ac.jp\n",
"\n",
"Basically, download the file from Google Drive with `gdown`. Download it from nw.tsuda.ac.jp only if the specifications of Google Drive change and you cannot download from Google Drive.\n",
"\n",
"## Google Drive または nw.tsuda.ac.jp からファイルをダウンロードする\n",
"\n",
"基本的に Google Drive から `gdown` でダウンロードしてください。 Google Drive の仕様が変わってダウンロードができない場合にのみ、nw.tsuda.ac.jp からダウンロードしてください。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 2587,
"status": "ok",
"timestamp": 1637562115282,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "E_pTWqexKLEK",
"outputId": "9928a675-7968-46a3-a542-551bc363435a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading...\n",
"From: https://drive.google.com/uc?id=1ZDgWE7wmVwG_ZuQVUjuh_XHeIO-7Yn63\n",
"To: /content/nw/AutoEncoder.py\n",
"\r",
" 0% 0.00/13.9k [00:00, ?B/s]\r",
"100% 13.9k/13.9k [00:00<00:00, 21.5MB/s]\n"
]
}
],
"source": [
"# Download source file\n",
"nw_path = './nw'\n",
"! rm -rf {nw_path}\n",
"! mkdir -p {nw_path}\n",
"\n",
"if True: # from Google Drive\n",
" url_model = 'https://drive.google.com/uc?id=1ZDgWE7wmVwG_ZuQVUjuh_XHeIO-7Yn63'\n",
" ! (cd {nw_path}; gdown {url_model})\n",
"else: # from nw.tsuda.ac.jp\n",
" URL_NW = 'https://nw.tsuda.ac.jp/lec/GoogleColab/pub'\n",
" url_model = f'{URL_NW}/models/AutoEncoder.py'\n",
" ! wget -nd {url_model} -P {nw_path} # download to './nw/AutoEncoder.py'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 413,
"status": "ok",
"timestamp": 1637562115693,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "r4iGaqqgbx9s",
"outputId": "dfbe9b22-bc54-4e98-cc44-21e6b4114169"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import os\n",
"import pickle\n",
"import datetime\n",
"\n",
"class AutoEncoder():\n",
" def __init__(self, \n",
" input_dim,\n",
" encoder_conv_filters,\n",
" encoder_conv_kernel_size,\n",
" encoder_conv_strides,\n",
" decoder_conv_t_filters,\n",
" decoder_conv_t_kernel_size,\n",
" decoder_conv_t_strides,\n",
" z_dim,\n",
" use_batch_norm = False,\n",
" use_dropout = False,\n",
" epoch = 0\n",
" ):\n",
" self.name = 'autoencoder'\n",
" self.input_dim = input_dim\n",
" self.encoder_conv_filters = encoder_conv_filters\n",
" self.encoder_conv_kernel_size = encoder_conv_kernel_size\n",
" self.encoder_conv_strides = encoder_conv_strides\n",
" self.decoder_conv_t_filters = decoder_conv_t_filters\n",
" self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size\n",
" self.decoder_conv_t_strides = decoder_conv_t_strides\n",
" self.z_dim = z_dim\n",
" \n",
" self.use_batch_norm = use_batch_norm\n",
" self.use_dropout = use_dropout\n",
"\n",
" self.epoch = epoch\n",
" \n",
" self.n_layers_encoder = len(encoder_conv_filters)\n",
" self.n_layers_decoder = len(decoder_conv_t_filters)\n",
" \n",
" self._build()\n",
" \n",
"\n",
" def _build(self):\n",
" ### THE ENCODER\n",
" encoder_input = tf.keras.layers.Input(shape=self.input_dim, name='encoder_input')\n",
" x = encoder_input\n",
" \n",
" for i in range(self.n_layers_encoder):\n",
" x = tf.keras.layers.Conv2D(\n",
" filters = self.encoder_conv_filters[i],\n",
" kernel_size = self.encoder_conv_kernel_size[i],\n",
" strides = self.encoder_conv_strides[i],\n",
" padding = 'same',\n",
" name = 'encoder_conv_' + str(i)\n",
" )(x)\n",
" x = tf.keras.layers.LeakyReLU()(x)\n",
" if self.use_batch_norm:\n",
" x = tf.keras.layers.BatchNormalization()(x)\n",
" if self.use_dropout:\n",
" x = tf.keras.layers.Dropout(rate = 0.25)(x)\n",
" \n",
" shape_before_flattening = tf.keras.backend.int_shape(x)[1:] # shape for 1 data\n",
" \n",
" x = tf.keras.layers.Flatten()(x)\n",
" encoder_output = tf.keras.layers.Dense(self.z_dim, name='encoder_output')(x)\n",
" \n",
" self.encoder = tf.keras.models.Model(encoder_input, encoder_output)\n",
" \n",
" ### THE DECODER\n",
" decoder_input = tf.keras.layers.Input(shape=(self.z_dim,), name='decoder_input')\n",
" x = tf.keras.layers.Dense(np.prod(shape_before_flattening))(decoder_input)\n",
" x = tf.keras.layers.Reshape(shape_before_flattening)(x)\n",
" \n",
" for i in range(self.n_layers_decoder):\n",
" x = tf.keras.layers.Conv2DTranspose(\n",
" filters = self.decoder_conv_t_filters[i],\n",
" kernel_size = self.decoder_conv_t_kernel_size[i],\n",
" strides = self.decoder_conv_t_strides[i],\n",
" padding = 'same',\n",
" name = 'decoder_conv_t_' + str(i)\n",
" )(x)\n",
" \n",
" if i < self.n_layers_decoder - 1:\n",
" x = tf.keras.layers.LeakyReLU()(x)\n",
" if self.use_batch_norm:\n",
" x = tf.keras.layers.BatchNormalization()(x)\n",
" if self.use_dropout:\n",
" x = tf.keras.layers.Dropout(rate=0.25)(x)\n",
" else:\n",
" x = tf.keras.layers.Activation('sigmoid')(x)\n",
" \n",
" decoder_output = x\n",
" self.decoder = tf.keras.models.Model(decoder_input, decoder_output)\n",
" \n",
" ### THE FULL AUTOENCODER\n",
" model_input = encoder_input\n",
" model_output = self.decoder(encoder_output)\n",
" \n",
" self.model = tf.keras.models.Model(model_input, model_output)\n",
"\n",
"\n",
" def save(self, folder):\n",
" self.save_params(os.path.join(folder, 'params.pkl'))\n",
" self.save_weights(os.path.join(folder, 'weights/weights.h5'))\n",
"\n",
"\n",
" @staticmethod\n",
" def load(folder, epoch=None): # AutoEncoder.load(folder)\n",
" params = AutoEncoder.load_params(os.path.join(folder, 'params.pkl'))\n",
" AE = AutoEncoder(*params)\n",
" if epoch is None:\n",
" AE.model.load_weights(os.path.join(folder, 'weights/weights.h5'))\n",
" else:\n",
" AE.model.load_weights(os.path.join(folder, f'weights/weights_{epoch-1}.h5'))\n",
" AE.epoch = epoch\n",
"\n",
" return AE\n",
"\n",
"\n",
" def save_params(self, filepath):\n",
" dpath, fname = os.path.split(filepath)\n",
" if dpath != '' and not os.path.exists(dpath):\n",
" os.makedirs(dpath)\n",
" with open(filepath, 'wb') as f:\n",
" pickle.dump([\n",
" self.input_dim,\n",
" self.encoder_conv_filters,\n",
" self.encoder_conv_kernel_size,\n",
" self.encoder_conv_strides,\n",
" self.decoder_conv_t_filters,\n",
" self.decoder_conv_t_kernel_size,\n",
" self.decoder_conv_t_strides,\n",
" self.z_dim,\n",
" self.use_batch_norm,\n",
" self.use_dropout,\n",
" self.epoch\n",
" ], f)\n",
"\n",
"\n",
" @staticmethod\n",
" def load_params(filepath):\n",
" with open(filepath, 'rb') as f:\n",
" params = pickle.load(f)\n",
" return params\n",
"\n",
"\n",
" def save_weights(self, filepath):\n",
" dpath, fname = os.path.split(filepath)\n",
" if dpath != '' and not os.path.exists(dpath):\n",
" os.makedirs(dpath)\n",
" self.model.save_weights(filepath)\n",
" \n",
" \n",
" def load_weights(self, filepath):\n",
" self.model.load_weights(filepath)\n",
"\n",
"\n",
" def save_images(self, imgs, filepath):\n",
" z_points = self.encoder.predict(imgs)\n",
" reconst_imgs = self.decoder.predict(z_points)\n",
" txts = [ f'{p[0]:.3f}, {p[1]:.3f}' for p in z_points ]\n",
" AutoEncoder.showImages(imgs, reconst_imgs, txts, 1.4, 1.4, 0.5, filepath)\n",
" \n",
"\n",
" @staticmethod\n",
" def r_loss(y_true, y_pred):\n",
" return tf.keras.backend.mean(tf.keras.backend.square(y_true - y_pred), axis=[1,2,3])\n",
"\n",
"\n",
" def compile(self, learning_rate):\n",
" self.learning_rate = learning_rate\n",
" optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
" self.model.compile(optimizer=optimizer, loss = AutoEncoder.r_loss)\n",
"\n",
" \n",
" def train_with_fit(self,\n",
" x_train,\n",
" y_train,\n",
" batch_size,\n",
" epochs,\n",
" run_folder='run/',\n",
" validation_data=None\n",
" ):\n",
" history= self.model.fit(\n",
" x_train,\n",
" y_train,\n",
" batch_size = batch_size,\n",
" shuffle = True,\n",
" initial_epoch = self.epoch,\n",
" epochs = epochs,\n",
" validation_data = validation_data\n",
" )\n",
" if self.epoch < epochs:\n",
" self.epoch = epochs\n",
"\n",
" if run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(os.path.join(run_folder,f'weights/weights_{self.epoch-1}.h5'))\n",
" #idxs = np.random.choice(len(x_train), 10)\n",
" #self.save_images(x_train[idxs], os.path.join(run_folder, f'images/image_{self.epoch-1}.png'))\n",
"\n",
" return history\n",
" \n",
" \n",
" def train(self,\n",
" x_train,\n",
" y_train,\n",
" batch_size = 32,\n",
" epochs = 10,\n",
" shuffle=False,\n",
" run_folder='run/',\n",
" optimizer=None,\n",
" save_epoch_interval=100,\n",
" validation_data = None\n",
" ):\n",
" start_time = datetime.datetime.now()\n",
" steps = x_train.shape[0] // batch_size\n",
"\n",
" losses = []\n",
" val_losses = []\n",
"\n",
" for epoch in range(self.epoch, epochs):\n",
" epoch_loss = 0\n",
" indices = tf.range(x_train.shape[0], dtype=tf.int32)\n",
" if shuffle:\n",
" indices = tf.random.shuffle(indices)\n",
" x_ = x_train[indices]\n",
" y_ = y_train[indices]\n",
" \n",
" for step in range(steps):\n",
" start = batch_size * step\n",
" end = start + batch_size\n",
"\n",
" with tf.GradientTape() as tape:\n",
" outputs = self.model(x_[start:end])\n",
" tmp_loss = AutoEncoder.r_loss(y_[start:end], outputs)\n",
"\n",
" grads = tape.gradient(tmp_loss, self.model.trainable_variables)\n",
" optimizer.apply_gradients(zip(grads, self.model.trainable_variables))\n",
"\n",
" epoch_loss = np.mean(tmp_loss)\n",
" losses.append(epoch_loss)\n",
"\n",
" val_str = ''\n",
" if validation_data != None:\n",
" x_val, y_val = validation_data\n",
" outputs_val = self.model(x_val)\n",
" val_loss = np.mean(AutoEncoder.r_loss(y_val, outputs_val))\n",
" val_str = f'val loss: {val_loss:.4f} '\n",
" val_losses.append(val_loss)\n",
"\n",
"\n",
" if (epoch+1) % save_epoch_interval == 0 and run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(os.path.join(run_folder,f'weights/weights_{self.epoch}.h5'))\n",
" #idxs = np.random.choice(len(x_train), 10)\n",
" #self.save_images(x_train[idxs], os.path.join(run_folder, f'images/image_{self.epoch}.png'))\n",
"\n",
" elapsed_time = datetime.datetime.now() - start_time\n",
" print(f'{epoch+1}/{epochs} {steps} loss: {epoch_loss:.4f} {val_str}{elapsed_time}')\n",
"\n",
" self.epoch += 1\n",
"\n",
" if run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(os.path.join(run_folder,f'weights/weights_{self.epoch-1}.h5'))\n",
" #idxs = np.random.choice(len(x_train), 10)\n",
" #self.save_images(x_train[idxs], os.path.join(run_folder, f'images/image_{self.epoch-1}.png'))\n",
"\n",
" return losses, val_losses\n",
"\n",
" @staticmethod\n",
" @tf.function\n",
" def compute_loss_and_grads(model,x,y):\n",
" with tf.GradientTape() as tape:\n",
" outputs = model(x)\n",
" tmp_loss = AutoEncoder.r_loss(y,outputs)\n",
" grads = tape.gradient(tmp_loss, model.trainable_variables)\n",
" return tmp_loss, grads\n",
"\n",
"\n",
" def train_tf(self,\n",
" x_train,\n",
" y_train,\n",
" batch_size = 32,\n",
" epochs = 10,\n",
" shuffle=False,\n",
" run_folder='run/',\n",
" optimizer=None,\n",
" save_epoch_interval=100,\n",
" validation_data = None\n",
" ):\n",
" start_time = datetime.datetime.now()\n",
" steps = x_train.shape[0] // batch_size\n",
"\n",
" losses = []\n",
" val_losses = []\n",
"\n",
" for epoch in range(self.epoch, epochs):\n",
" epoch_loss = 0\n",
" indices = tf.range(x_train.shape[0], dtype=tf.int32)\n",
" if shuffle:\n",
" indices = tf.random.shuffle(indices)\n",
" x_ = x_train[indices]\n",
" y_ = y_train[indices]\n",
"\n",
" step_losses = []\n",
" for step in range(steps):\n",
" start = batch_size * step\n",
" end = start + batch_size\n",
"\n",
" tmp_loss, grads = AutoEncoder.compute_loss_and_grads(self.model, x_[start:end], y_[start:end])\n",
" optimizer.apply_gradients(zip(grads, self.model.trainable_variables))\n",
"\n",
" step_losses.append(np.mean(tmp_loss))\n",
"\n",
" epoch_loss = np.mean(step_losses)\n",
" losses.append(epoch_loss)\n",
"\n",
" val_str = ''\n",
" if validation_data != None:\n",
" x_val, y_val = validation_data\n",
" outputs_val = self.model(x_val)\n",
" val_loss = np.mean(AutoEncoder.r_loss(y_val, outputs_val))\n",
" val_str = f'val loss: {val_loss:.4f} '\n",
" val_losses.append(val_loss)\n",
"\n",
"\n",
" if (epoch+1) % save_epoch_interval == 0 and run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(os.path.join(run_folder,f'weights/weights_{self.epoch}.h5'))\n",
" #idxs = np.random.choice(len(x_train), 10)\n",
" #self.save_images(x_train[idxs], os.path.join(run_folder, f'images/image_{self.epoch}.png'))\n",
"\n",
" elapsed_time = datetime.datetime.now() - start_time\n",
" print(f'{epoch+1}/{epochs} {steps} loss: {epoch_loss:.4f} {val_str}{elapsed_time}')\n",
"\n",
" self.epoch += 1\n",
"\n",
" if run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(os.path.join(run_folder,f'weights/weights_{self.epoch-1}.h5'))\n",
" #idxs = np.random.choice(len(x_train), 10)\n",
" #self.save_images(x_train[idxs], os.path.join(run_folder, f'images/image_{self.epoch-1}.png'))\n",
"\n",
" return losses, val_losses\n",
"\n",
"\n",
" @staticmethod\n",
" def showImages(imgs1, imgs2, txts, w, h, vskip=0.5, filepath=None):\n",
" n = len(imgs1)\n",
" fig, ax = plt.subplots(2, n, figsize=(w * n, (2+vskip) * h))\n",
" for i in range(n):\n",
" if n == 1:\n",
" axis = ax[0]\n",
" else:\n",
" axis = ax[0][i]\n",
" img = imgs1[i].squeeze()\n",
" axis.imshow(img, cmap='gray_r')\n",
" axis.axis('off')\n",
"\n",
" axis.text(0.5, -0.35, txts[i], fontsize=10, ha='center', transform=axis.transAxes)\n",
"\n",
" if n == 1:\n",
" axis = ax[1]\n",
" else:\n",
" axis = ax[1][i]\n",
" img2 = imgs2[i].squeeze()\n",
" axis.imshow(img2, cmap='gray_r')\n",
" axis.axis('off')\n",
"\n",
" if not filepath is None:\n",
" dpath, fname = os.path.split(filepath)\n",
" if dpath != '' and not os.path.exists(dpath):\n",
" os.makedirs(dpath)\n",
" fig.savefig(filepath, dpi=600)\n",
" plt.close()\n",
" else:\n",
" plt.show()\n",
"\n",
" @staticmethod\n",
" def plot_history(vals, labels):\n",
" colors = ['red', 'blue', 'green', 'orange', 'black']\n",
" n = len(vals)\n",
" fig, ax = plt.subplots(1, 1, figsize=(9,4))\n",
" for i in range(n):\n",
" ax.plot(vals[i], c=colors[i], label=labels[i])\n",
" ax.legend(loc='upper right')\n",
" ax.set_xlabel('epochs')\n",
" # ax[0].set_ylabel('loss')\n",
" \n",
" plt.show()\n"
]
}
],
"source": [
"!cat {nw_path}/AutoEncoder.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WUuOKXf9wIrn"
},
"source": [
"# Preparing the MNIST datasets\n",
"## MNIST データセットを用意する"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ERKfpLv7UPvU"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 259,
"status": "ok",
"timestamp": 1637562115945,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "2POVZ4obViuQ",
"outputId": "611f5642-e101-4a9f-d2c5-6509cdd66993"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
"11493376/11490434 [==============================] - 0s 0us/step\n",
"11501568/11490434 [==============================] - 0s 0us/step\n",
"(60000, 28, 28)\n",
"(60000,)\n",
"(10000, 28, 28)\n",
"(10000,)\n"
]
}
],
"source": [
"# MNIST datasets\n",
"(x_train_raw, y_train_raw), (x_test_raw, y_test_raw) = tf.keras.datasets.mnist.load_data()\n",
"print(x_train_raw.shape)\n",
"print(y_train_raw.shape)\n",
"print(x_test_raw.shape)\n",
"print(y_test_raw.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 4,
"status": "ok",
"timestamp": 1637562115945,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "uI_CQRWMvxHB",
"outputId": "09da4335-e9ff-44b8-824d-3a2feaa36100"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(60000, 28, 28, 1)\n",
"(10000, 28, 28, 1)\n"
]
}
],
"source": [
"x_train = x_train_raw.reshape(x_train_raw.shape+(1,)).astype('float32') / 255.0\n",
"x_test = x_test_raw.reshape(x_test_raw.shape+(1,)).astype('float32') / 255.0\n",
"print(x_train.shape)\n",
"print(x_test.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "86P5y0nqzvBU"
},
"source": [
"# Define the Neural Network Model\n",
"\n",
"Use the `AutoEncoder` class downloaded above (from Google Drive or nw.tsuda.ac.jp).\n",
"\n",
"## ニューラルネットワーク・モデル の定義\n",
"\n",
"Google Drive または nw.tsuda.ac.jp からダウンロードした `AutoEncoder` クラスを使う。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tLCR0nt-zGOE"
},
"outputs": [],
"source": [
"from nw.AutoEncoder import AutoEncoder\n",
"\n",
"AE = AutoEncoder(\n",
" input_dim = (28, 28, 1),\n",
" encoder_conv_filters = [32, 64, 64, 64],\n",
" encoder_conv_kernel_size = [3, 3, 3, 3],\n",
" encoder_conv_strides = [1, 2, 2, 1],\n",
" decoder_conv_t_filters = [64, 64, 32, 1],\n",
" decoder_conv_t_kernel_size = [3, 3, 3, 3],\n",
" decoder_conv_t_strides = [1, 2, 2, 1],\n",
" z_dim = 2\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 14,
"status": "ok",
"timestamp": 1637562119696,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "jhSugt50koOt",
"outputId": "8e071c5c-eec9-4aa6-d02e-95321bf540d0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" encoder_input (InputLayer) [(None, 28, 28, 1)] 0 \n",
" \n",
" encoder_conv_0 (Conv2D) (None, 28, 28, 32) 320 \n",
" \n",
" leaky_re_lu (LeakyReLU) (None, 28, 28, 32) 0 \n",
" \n",
" encoder_conv_1 (Conv2D) (None, 14, 14, 64) 18496 \n",
" \n",
" leaky_re_lu_1 (LeakyReLU) (None, 14, 14, 64) 0 \n",
" \n",
" encoder_conv_2 (Conv2D) (None, 7, 7, 64) 36928 \n",
" \n",
" leaky_re_lu_2 (LeakyReLU) (None, 7, 7, 64) 0 \n",
" \n",
" encoder_conv_3 (Conv2D) (None, 7, 7, 64) 36928 \n",
" \n",
" leaky_re_lu_3 (LeakyReLU) (None, 7, 7, 64) 0 \n",
" \n",
" flatten (Flatten) (None, 3136) 0 \n",
" \n",
" encoder_output (Dense) (None, 2) 6274 \n",
" \n",
"=================================================================\n",
"Total params: 98,946\n",
"Trainable params: 98,946\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"AE.encoder.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 12,
"status": "ok",
"timestamp": 1637562119697,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "zG6ifGO2SjTD",
"outputId": "520660c4-5149-4070-be4e-82374fed726f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model_1\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" decoder_input (InputLayer) [(None, 2)] 0 \n",
" \n",
" dense (Dense) (None, 3136) 9408 \n",
" \n",
" reshape (Reshape) (None, 7, 7, 64) 0 \n",
" \n",
" decoder_conv_t_0 (Conv2DTra (None, 7, 7, 64) 36928 \n",
" nspose) \n",
" \n",
" leaky_re_lu_4 (LeakyReLU) (None, 7, 7, 64) 0 \n",
" \n",
" decoder_conv_t_1 (Conv2DTra (None, 14, 14, 64) 36928 \n",
" nspose) \n",
" \n",
" leaky_re_lu_5 (LeakyReLU) (None, 14, 14, 64) 0 \n",
" \n",
" decoder_conv_t_2 (Conv2DTra (None, 28, 28, 32) 18464 \n",
" nspose) \n",
" \n",
" leaky_re_lu_6 (LeakyReLU) (None, 28, 28, 32) 0 \n",
" \n",
" decoder_conv_t_3 (Conv2DTra (None, 28, 28, 1) 289 \n",
" nspose) \n",
" \n",
" activation (Activation) (None, 28, 28, 1) 0 \n",
" \n",
"=================================================================\n",
"Total params: 102,017\n",
"Trainable params: 102,017\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"AE.decoder.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z3DuY-9Z9Yf-"
},
"source": [
"# Training the Neural Model\n",
"\n",
"\n",
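"Try the training in 3 ways:\n",
"\n",
"1. `train_with_fit()`: simple training with `fit()`\n",
"2. `train()`: a custom training loop using `tf.GradientTape`\n",
"3. `train_tf()`: a custom training loop whose loss/gradient computation is compiled with `@tf.function`\n",
"\n",
"In each case, first train for a few epochs and save the state to files.\n",
"Then load the saved state and continue training.\n",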
"\n",
"## ニューラルモデルを学習する\n",
"\n",
"3通りの方法で学習を試みる。\n",
"どの方法においても、まず数回学習を進めて、状態をファイルに保存する。\n",
"そして、保存した状態をロードしてから、さらに学習を進める。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "i4olTqB1xt0m"
},
"outputs": [],
"source": [
"MAX_EPOCHS = 200"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nA-X-4s1Z7q6"
},
"outputs": [],
"source": [
"learning_rate = 0.0005"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P7ae_ydlhZPn"
},
"source": [
"# (1) Simple Training with fit()\n",
"\n",
"\n",
"Instead of using callbacks, simply train with the `fit()` function.\n",
"\n",
"## (1) fit() 関数を使った単純なTraining\n",
"\n",
"callbackは使わずに、単純にfit()を使ってtrainingしてみる。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uS1trkRkvTLz"
},
"outputs": [],
"source": [
"save_path1 = '/content/drive/MyDrive/ColabRun/AE01'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FSy_oa-qiFTV"
},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
"AE.model.compile(optimizer=optimizer, loss=AutoEncoder.r_loss)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 84826,
"status": "ok",
"timestamp": 1637562204883,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "d14V6JZUvOhs",
"outputId": "89346613-692f-4c64-97f2-5ef7642701a7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"1875/1875 [==============================] - 27s 6ms/step - loss: 0.0550 - val_loss: 0.0487\n",
"Epoch 2/3\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0464 - val_loss: 0.0449\n",
"Epoch 3/3\n",
"1875/1875 [==============================] - 10s 6ms/step - loss: 0.0444 - val_loss: 0.0439\n"
]
}
],
"source": [
"# At first, train for a few epochs.\n",
"# まず、少ない回数 training してみる\n",
"\n",
"history=AE.train_with_fit(\n",
" x_train,\n",
" x_train,\n",
" batch_size=32,\n",
" epochs = 3,\n",
" run_folder = save_path1,\n",
" validation_data = (x_test, x_test)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 6,
"status": "ok",
"timestamp": 1637562204884,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "49j-U8xx9kzZ",
"outputId": "932b4ab5-abbd-4f22-c980-a6b4e8d32930"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'loss': [0.05495461821556091, 0.0464060977101326, 0.04438251256942749], 'val_loss': [0.04873434454202652, 0.04490825906395912, 0.043926868587732315]}\n"
]
}
],
"source": [
"print(history.history)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 3,
"status": "ok",
"timestamp": 1637562204884,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "K0Msc4koyR-J",
"outputId": "b6a066ea-9e18-4f9c-b519-964469f196f5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3\n"
]
}
],
"source": [
"# Load the trained states saved before\n",
"# 保存されている学習結果をロードする\n",
"\n",
"AE_work = AutoEncoder.load(save_path1)\n",
"\n",
"# display the epoch count of training\n",
"# training のepoch回数を表示する\n",
"print(AE_work.epoch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 1972373,
"status": "ok",
"timestamp": 1637564177734,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "fiPR6yjwi2_1",
"outputId": "a2f8a1f0-8b7d-4a28-8202-2cb9be708e95"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/200\n",
"1875/1875 [==============================] - 11s 5ms/step - loss: 0.0439 - val_loss: 0.0428\n",
"Epoch 5/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0422 - val_loss: 0.0421\n",
"Epoch 6/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0417 - val_loss: 0.0414\n",
"Epoch 7/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0412 - val_loss: 0.0412\n",
"Epoch 8/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0409 - val_loss: 0.0409\n",
"Epoch 9/200\n",
"1875/1875 [==============================] - 10s 6ms/step - loss: 0.0406 - val_loss: 0.0411\n",
"Epoch 10/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0404 - val_loss: 0.0405\n",
"Epoch 11/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0402 - val_loss: 0.0403\n",
"Epoch 12/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0401 - val_loss: 0.0402\n",
"Epoch 13/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0399 - val_loss: 0.0402\n",
"Epoch 14/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0398 - val_loss: 0.0399\n",
"Epoch 15/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0396 - val_loss: 0.0401\n",
"Epoch 16/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0395 - val_loss: 0.0406\n",
"Epoch 17/200\n",
"1875/1875 [==============================] - 10s 6ms/step - loss: 0.0394 - val_loss: 0.0399\n",
"Epoch 18/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0393 - val_loss: 0.0400\n",
"Epoch 19/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0392 - val_loss: 0.0395\n",
"Epoch 20/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0391 - val_loss: 0.0393\n",
"Epoch 21/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0390 - val_loss: 0.0393\n",
"Epoch 22/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0390 - val_loss: 0.0397\n",
"Epoch 23/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0389 - val_loss: 0.0394\n",
"Epoch 24/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0388 - val_loss: 0.0395\n",
"Epoch 25/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0388 - val_loss: 0.0393\n",
"Epoch 26/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0387 - val_loss: 0.0398\n",
"Epoch 27/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0387 - val_loss: 0.0395\n",
"Epoch 28/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0386 - val_loss: 0.0395\n",
"Epoch 29/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0385 - val_loss: 0.0390\n",
"Epoch 30/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0385 - val_loss: 0.0391\n",
"Epoch 31/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0384 - val_loss: 0.0395\n",
"Epoch 32/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0384 - val_loss: 0.0391\n",
"Epoch 33/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0383 - val_loss: 0.0394\n",
"Epoch 34/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0383 - val_loss: 0.0390\n",
"Epoch 35/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0382 - val_loss: 0.0393\n",
"Epoch 36/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0382 - val_loss: 0.0391\n",
"Epoch 37/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0381 - val_loss: 0.0390\n",
"Epoch 38/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0381 - val_loss: 0.0391\n",
"Epoch 39/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0381 - val_loss: 0.0388\n",
"Epoch 40/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0380 - val_loss: 0.0392\n",
"Epoch 41/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0380 - val_loss: 0.0394\n",
"Epoch 42/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0379 - val_loss: 0.0389\n",
"Epoch 43/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0379 - val_loss: 0.0392\n",
"Epoch 44/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0379 - val_loss: 0.0393\n",
"Epoch 45/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0378 - val_loss: 0.0390\n",
"Epoch 46/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0378 - val_loss: 0.0389\n",
"Epoch 47/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0378 - val_loss: 0.0392\n",
"Epoch 48/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0377 - val_loss: 0.0390\n",
"Epoch 49/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0377 - val_loss: 0.0386\n",
"Epoch 50/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0376 - val_loss: 0.0391\n",
"Epoch 51/200\n",
"1875/1875 [==============================] - 10s 6ms/step - loss: 0.0377 - val_loss: 0.0392\n",
"Epoch 52/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0376 - val_loss: 0.0391\n",
"Epoch 53/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0376 - val_loss: 0.0385\n",
"Epoch 54/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0375 - val_loss: 0.0386\n",
"Epoch 55/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0375 - val_loss: 0.0386\n",
"Epoch 56/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0375 - val_loss: 0.0387\n",
"Epoch 57/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0375 - val_loss: 0.0385\n",
"Epoch 58/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0375 - val_loss: 0.0386\n",
"Epoch 59/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0374 - val_loss: 0.0389\n",
"Epoch 60/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0374 - val_loss: 0.0387\n",
"Epoch 61/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0374 - val_loss: 0.0387\n",
"Epoch 62/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0373 - val_loss: 0.0386\n",
"Epoch 63/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0373 - val_loss: 0.0387\n",
"Epoch 64/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0373 - val_loss: 0.0387\n",
"Epoch 65/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0373 - val_loss: 0.0384\n",
"Epoch 66/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0372 - val_loss: 0.0385\n",
"Epoch 67/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0372 - val_loss: 0.0386\n",
"Epoch 68/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0372 - val_loss: 0.0386\n",
"Epoch 69/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0372 - val_loss: 0.0389\n",
"Epoch 70/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0371 - val_loss: 0.0385\n",
"Epoch 71/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0372 - val_loss: 0.0384\n",
"Epoch 72/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0371 - val_loss: 0.0387\n",
"Epoch 73/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0371 - val_loss: 0.0386\n",
"Epoch 74/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0371 - val_loss: 0.0384\n",
"Epoch 75/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0371 - val_loss: 0.0387\n",
"Epoch 76/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0370 - val_loss: 0.0387\n",
"Epoch 77/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0370 - val_loss: 0.0383\n",
"Epoch 78/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0370 - val_loss: 0.0385\n",
"Epoch 79/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0370 - val_loss: 0.0384\n",
"Epoch 80/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0369 - val_loss: 0.0385\n",
"Epoch 81/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0369 - val_loss: 0.0384\n",
"Epoch 82/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0369 - val_loss: 0.0383\n",
"Epoch 83/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0369 - val_loss: 0.0385\n",
"Epoch 84/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0369 - val_loss: 0.0385\n",
"Epoch 85/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0368 - val_loss: 0.0386\n",
"Epoch 86/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0369 - val_loss: 0.0386\n",
"Epoch 87/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0368 - val_loss: 0.0384\n",
"Epoch 88/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0368 - val_loss: 0.0383\n",
"Epoch 89/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0368 - val_loss: 0.0385\n",
"Epoch 90/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0368 - val_loss: 0.0384\n",
"Epoch 91/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0368 - val_loss: 0.0384\n",
"Epoch 92/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0367 - val_loss: 0.0386\n",
"Epoch 93/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0367 - val_loss: 0.0385\n",
"Epoch 94/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0367 - val_loss: 0.0382\n",
"Epoch 95/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0367 - val_loss: 0.0383\n",
"Epoch 96/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0367 - val_loss: 0.0383\n",
"Epoch 97/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0367 - val_loss: 0.0384\n",
"Epoch 98/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0366 - val_loss: 0.0383\n",
"Epoch 99/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0366 - val_loss: 0.0385\n",
"Epoch 100/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0366 - val_loss: 0.0385\n",
"Epoch 101/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0366 - val_loss: 0.0383\n",
"Epoch 102/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0366 - val_loss: 0.0384\n",
"Epoch 103/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0366 - val_loss: 0.0384\n",
"Epoch 104/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0365 - val_loss: 0.0385\n",
"Epoch 105/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0366 - val_loss: 0.0386\n",
"Epoch 106/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0365 - val_loss: 0.0382\n",
"Epoch 107/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0365 - val_loss: 0.0383\n",
"Epoch 108/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0365 - val_loss: 0.0384\n",
"Epoch 109/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0364 - val_loss: 0.0381\n",
"Epoch 110/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0365 - val_loss: 0.0385\n",
"Epoch 111/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0365 - val_loss: 0.0385\n",
"Epoch 112/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0364 - val_loss: 0.0385\n",
"Epoch 113/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0364 - val_loss: 0.0383\n",
"Epoch 114/200\n",
"1875/1875 [==============================] - 10s 6ms/step - loss: 0.0364 - val_loss: 0.0384\n",
"Epoch 115/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0364 - val_loss: 0.0381\n",
"Epoch 116/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0364 - val_loss: 0.0385\n",
"Epoch 117/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0364 - val_loss: 0.0382\n",
"Epoch 118/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0364 - val_loss: 0.0386\n",
"Epoch 119/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0363 - val_loss: 0.0383\n",
"Epoch 120/200\n",
"1875/1875 [==============================] - 10s 6ms/step - loss: 0.0364 - val_loss: 0.0383\n",
"Epoch 121/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0363 - val_loss: 0.0385\n",
"Epoch 122/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0363 - val_loss: 0.0389\n",
"Epoch 123/200\n",
"1875/1875 [==============================] - 10s 6ms/step - loss: 0.0363 - val_loss: 0.0385\n",
"Epoch 124/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0363 - val_loss: 0.0383\n",
"Epoch 125/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0363 - val_loss: 0.0385\n",
"Epoch 126/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0363 - val_loss: 0.0384\n",
"Epoch 127/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0384\n",
"Epoch 128/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0363 - val_loss: 0.0384\n",
"Epoch 129/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0384\n",
"Epoch 130/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0388\n",
"Epoch 131/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0384\n",
"Epoch 132/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0384\n",
"Epoch 133/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0384\n",
"Epoch 134/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0384\n",
"Epoch 135/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0386\n",
"Epoch 136/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0385\n",
"Epoch 137/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0383\n",
"Epoch 138/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0383\n",
"Epoch 139/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0383\n",
"Epoch 140/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0384\n",
"Epoch 141/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0385\n",
"Epoch 142/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0385\n",
"Epoch 143/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0386\n",
"Epoch 144/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0382\n",
"Epoch 145/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0384\n",
"Epoch 146/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0387\n",
"Epoch 147/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0384\n",
"Epoch 148/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0385\n",
"Epoch 149/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0382\n",
"Epoch 150/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0381\n",
"Epoch 151/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0385\n",
"Epoch 152/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0382\n",
"Epoch 153/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0385\n",
"Epoch 154/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0384\n",
"Epoch 155/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0384\n",
"Epoch 156/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0382\n",
"Epoch 157/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0384\n",
"Epoch 158/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0384\n",
"Epoch 159/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0383\n",
"Epoch 160/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0387\n",
"Epoch 161/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0383\n",
"Epoch 162/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0383\n",
"Epoch 163/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0384\n",
"Epoch 164/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0386\n",
"Epoch 165/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0383\n",
"Epoch 166/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0382\n",
"Epoch 167/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0384\n",
"Epoch 168/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0382\n",
"Epoch 169/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0387\n",
"Epoch 170/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0389\n",
"Epoch 171/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0381\n",
"Epoch 172/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0390\n",
"Epoch 173/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0386\n",
"Epoch 174/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0382\n",
"Epoch 175/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0382\n",
"Epoch 176/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0388\n",
"Epoch 177/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0383\n",
"Epoch 178/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0383\n",
"Epoch 179/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0382\n",
"Epoch 180/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0384\n",
"Epoch 181/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383\n",
"Epoch 182/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383\n",
"Epoch 183/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383\n",
"Epoch 184/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0385\n",
"Epoch 185/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383\n",
"Epoch 186/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0386\n",
"Epoch 187/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383\n",
"Epoch 188/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383\n",
"Epoch 189/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0386\n",
"Epoch 190/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383\n",
"Epoch 191/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383\n",
"Epoch 192/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0356 - val_loss: 0.0382\n",
"Epoch 193/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0386\n",
"Epoch 194/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0356 - val_loss: 0.0383\n",
"Epoch 195/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0356 - val_loss: 0.0385\n",
"Epoch 196/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0356 - val_loss: 0.0383\n",
"Epoch 197/200\n",
"1875/1875 [==============================] - 10s 6ms/step - loss: 0.0356 - val_loss: 0.0382\n",
"Epoch 198/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0356 - val_loss: 0.0385\n",
"Epoch 199/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0356 - val_loss: 0.0384\n",
"Epoch 200/200\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0356 - val_loss: 0.0384\n"
]
}
],
"source": [
"# Then, train for more epochs. Training resumes from the saved epoch count (self.epoch) and continues up to the specified number of epochs.\n",
"# 追加でtrainingする。保存されている現在のepoch数から始めて、指定したepochs までtrainingが進む。\n",
"\n",
"AE_work.model.compile(optimizer, loss=AutoEncoder.r_loss)\n",
"\n",
"history_work = AE_work.train_with_fit(\n",
" x_train,\n",
" x_train,\n",
" batch_size=32,\n",
" epochs=MAX_EPOCHS,\n",
" run_folder = save_path1,\n",
" validation_data=(x_test, x_test)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 17,
"status": "ok",
"timestamp": 1637564177735,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "134mgICN_4X3",
"outputId": "9991cf88-3cda-4027-8a54-d6e4bc2e9860"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"197\n"
]
}
],
"source": [
"# The returned history contains only the losses from this additional training run, so its length is the number of additional epochs (not the total).\n",
"# 追加で行ったtraining時のlossが返り値に含まれる\n",
"print(len(history_work.history['loss']))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NSTlPaDhBSC6"
},
"outputs": [],
"source": [
"loss1_1 = history.history['loss']\n",
"vloss1_1 = history.history['val_loss']\n",
"\n",
"loss1_2 = history_work.history['loss']\n",
"vloss1_2 = history_work.history['val_loss']\n",
"\n",
"loss1 = np.concatenate([loss1_1, loss1_2], axis=0)\n",
"val_loss1 = np.concatenate([vloss1_1, vloss1_2], axis=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 279
},
"executionInfo": {
"elapsed": 8,
"status": "ok",
"timestamp": 1637564177736,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "0Y20iXDaCB-x",
"outputId": "bb7e7449-485a-424a-ed74-63770b1bfead"
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"tf.GradientTape()
function.\n",
"\n",
"\n",
"Instead of using <code>fit()</code>, calculate the loss in your own <code>train()</code> function, compute the gradients, and apply them to the variables.\n",
"\n",
"The <code>train_tf()</code> function is sped up by decorating the <code>compute_loss_and_grads()</code> function, which computes the loss and the gradients, with <code>@tf.function</code>.\n",
"\n",
"\n",
"## (2) <code>tf.GradientTape()</code> 関数を使った学習\n",
"\n",
"\n",
"<code>fit()</code> 関数を使わずに、自分で記述した <code>train()</code> 関数内で loss を計算し、gradients を求めて、変数に適用する。\n",
"\n",
"<code>train_tf()</code> 関数では、lossとgradientsの計算を行う <code>compute_loss_and_grads()</code> 関数を <code>@tf.function</code> 宣言することで高速化を図っている。\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L-s-ylxkD9O6"
},
"outputs": [],
"source": [
"save_path2 = '/content/drive/MyDrive/ColabRun/AE02/'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YpFibxw-CVzk"
},
"outputs": [],
"source": [
"from nw.AutoEncoder import AutoEncoder\n",
"\n",
"AE2 = AutoEncoder(\n",
" input_dim = (28, 28, 1),\n",
" encoder_conv_filters = [32, 64, 64, 64],\n",
" encoder_conv_kernel_size = [3, 3, 3, 3],\n",
" encoder_conv_strides = [1, 2, 2, 1],\n",
" decoder_conv_t_filters = [64, 64, 32, 1],\n",
" decoder_conv_t_kernel_size = [3, 3, 3, 3],\n",
" decoder_conv_t_strides = [1, 2, 2, 1],\n",
" z_dim = 2\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xYyyVaOw_CeI"
},
"outputs": [],
"source": [
"optimizer2 = tf.keras.optimizers.Adam(learning_rate=learning_rate)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 117923,
"status": "ok",
"timestamp": 1637564297137,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "JmYIPWvfCd6E",
"outputId": "2c1d1f30-afb3-4ee9-d75b-f0994c82a923"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1/3 1875 loss: 0.0406 val loss: 0.0481 0:00:39.580740\n",
"2/3 1875 loss: 0.0391 val loss: 0.0448 0:01:18.183180\n",
"3/3 1875 loss: 0.0529 val loss: 0.0432 0:01:56.549291\n"
]
}
],
"source": [
"# First, train for only a few epochs.\n",
"# まず、少ない回数 training してみる\n",
"\n",
"loss2_1, vloss2_1 = AE2.train(\n",
" x_train,\n",
" x_train,\n",
" batch_size=32,\n",
" epochs = 3, \n",
" shuffle=True,\n",
" run_folder= save_path2,\n",
" optimizer = optimizer2,\n",
" save_epoch_interval=50,\n",
" validation_data=(x_test, x_test)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 5,
"status": "ok",
"timestamp": 1637564297137,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "ECMO7RjRDx9d",
"outputId": "d4e7ad93-5757-4f1e-f1f4-bfdedd92dd52"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3\n"
]
}
],
"source": [
"# Load the parameters and the weights saved before.\n",
"# 保存したパラメータと、重みを読み込む。\n",
"\n",
"AE2_work = AutoEncoder.load(save_path2)\n",
"print(AE2_work.epoch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 3066556,
"status": "ok",
"timestamp": 1637567363691,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "O6zeeIiBFEXv",
"outputId": "0e2a95ce-c41d-4a3f-8291-41b32fff9f84"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4/200 1875 loss: 0.0441 val loss: 0.0430 0:00:16.042304\n",
"5/200 1875 loss: 0.0425 val loss: 0.0423 0:00:31.745769\n",
"6/200 1875 loss: 0.0419 val loss: 0.0420 0:00:47.550041\n",
"7/200 1875 loss: 0.0415 val loss: 0.0414 0:01:03.278043\n",
"8/200 1875 loss: 0.0412 val loss: 0.0411 0:01:18.886151\n",
"9/200 1875 loss: 0.0409 val loss: 0.0408 0:01:34.403325\n",
"10/200 1875 loss: 0.0406 val loss: 0.0404 0:01:49.879346\n",
"11/200 1875 loss: 0.0404 val loss: 0.0406 0:02:05.370021\n",
"12/200 1875 loss: 0.0403 val loss: 0.0406 0:02:21.087446\n",
"13/200 1875 loss: 0.0401 val loss: 0.0401 0:02:36.769709\n",
"14/200 1875 loss: 0.0399 val loss: 0.0401 0:02:52.391754\n",
"15/200 1875 loss: 0.0398 val loss: 0.0401 0:03:07.901212\n",
"16/200 1875 loss: 0.0397 val loss: 0.0400 0:03:23.392766\n",
"17/200 1875 loss: 0.0396 val loss: 0.0397 0:03:38.888386\n",
"18/200 1875 loss: 0.0394 val loss: 0.0396 0:03:54.470497\n",
"19/200 1875 loss: 0.0393 val loss: 0.0397 0:04:10.006842\n",
"20/200 1875 loss: 0.0392 val loss: 0.0396 0:04:25.572727\n",
"21/200 1875 loss: 0.0391 val loss: 0.0395 0:04:41.166492\n",
"22/200 1875 loss: 0.0391 val loss: 0.0395 0:04:56.777322\n",
"23/200 1875 loss: 0.0390 val loss: 0.0393 0:05:12.350585\n",
"24/200 1875 loss: 0.0389 val loss: 0.0394 0:05:28.061553\n",
"25/200 1875 loss: 0.0388 val loss: 0.0400 0:05:43.522782\n",
"26/200 1875 loss: 0.0387 val loss: 0.0391 0:05:59.371244\n",
"27/200 1875 loss: 0.0387 val loss: 0.0394 0:06:15.072191\n",
"28/200 1875 loss: 0.0386 val loss: 0.0394 0:06:30.590397\n",
"29/200 1875 loss: 0.0385 val loss: 0.0389 0:06:46.193198\n",
"30/200 1875 loss: 0.0385 val loss: 0.0393 0:07:01.744139\n",
"31/200 1875 loss: 0.0385 val loss: 0.0392 0:07:17.398188\n",
"32/200 1875 loss: 0.0384 val loss: 0.0391 0:07:33.097819\n",
"33/200 1875 loss: 0.0383 val loss: 0.0388 0:07:48.744927\n",
"34/200 1875 loss: 0.0382 val loss: 0.0388 0:08:04.346553\n",
"35/200 1875 loss: 0.0382 val loss: 0.0389 0:08:19.798364\n",
"36/200 1875 loss: 0.0381 val loss: 0.0390 0:08:35.371745\n",
"37/200 1875 loss: 0.0381 val loss: 0.0386 0:08:51.082935\n",
"38/200 1875 loss: 0.0380 val loss: 0.0388 0:09:06.822892\n",
"39/200 1875 loss: 0.0380 val loss: 0.0385 0:09:22.394035\n",
"40/200 1875 loss: 0.0379 val loss: 0.0389 0:09:37.953852\n",
"41/200 1875 loss: 0.0379 val loss: 0.0387 0:09:53.514515\n",
"42/200 1875 loss: 0.0378 val loss: 0.0387 0:10:09.035459\n",
"43/200 1875 loss: 0.0378 val loss: 0.0386 0:10:24.666275\n",
"44/200 1875 loss: 0.0378 val loss: 0.0388 0:10:40.257557\n",
"45/200 1875 loss: 0.0377 val loss: 0.0385 0:10:55.745129\n",
"46/200 1875 loss: 0.0377 val loss: 0.0388 0:11:11.432226\n",
"47/200 1875 loss: 0.0376 val loss: 0.0386 0:11:27.144229\n",
"48/200 1875 loss: 0.0376 val loss: 0.0390 0:11:42.676218\n",
"49/200 1875 loss: 0.0375 val loss: 0.0388 0:11:58.487087\n",
"50/200 1875 loss: 0.0375 val loss: 0.0383 0:12:14.850839\n",
"51/200 1875 loss: 0.0375 val loss: 0.0390 0:12:30.468638\n",
"52/200 1875 loss: 0.0375 val loss: 0.0388 0:12:46.049661\n",
"53/200 1875 loss: 0.0374 val loss: 0.0384 0:13:01.636156\n",
"54/200 1875 loss: 0.0374 val loss: 0.0383 0:13:17.325480\n",
"55/200 1875 loss: 0.0374 val loss: 0.0385 0:13:32.836645\n",
"56/200 1875 loss: 0.0374 val loss: 0.0388 0:13:48.441919\n",
"57/200 1875 loss: 0.0373 val loss: 0.0384 0:14:03.917869\n",
"58/200 1875 loss: 0.0373 val loss: 0.0388 0:14:19.634660\n",
"59/200 1875 loss: 0.0372 val loss: 0.0389 0:14:35.261167\n",
"60/200 1875 loss: 0.0372 val loss: 0.0384 0:14:50.896159\n",
"61/200 1875 loss: 0.0372 val loss: 0.0390 0:15:06.445663\n",
"62/200 1875 loss: 0.0372 val loss: 0.0381 0:15:22.134292\n",
"63/200 1875 loss: 0.0372 val loss: 0.0382 0:15:37.757501\n",
"64/200 1875 loss: 0.0371 val loss: 0.0384 0:15:53.316315\n",
"65/200 1875 loss: 0.0371 val loss: 0.0382 0:16:08.820412\n",
"66/200 1875 loss: 0.0371 val loss: 0.0385 0:16:24.565601\n",
"67/200 1875 loss: 0.0370 val loss: 0.0384 0:16:40.101123\n",
"68/200 1875 loss: 0.0370 val loss: 0.0383 0:16:55.609609\n",
"69/200 1875 loss: 0.0370 val loss: 0.0382 0:17:11.264953\n",
"70/200 1875 loss: 0.0370 val loss: 0.0383 0:17:26.949355\n",
"71/200 1875 loss: 0.0370 val loss: 0.0381 0:17:42.623016\n",
"72/200 1875 loss: 0.0369 val loss: 0.0381 0:17:58.321779\n",
"73/200 1875 loss: 0.0369 val loss: 0.0382 0:18:13.832138\n",
"74/200 1875 loss: 0.0369 val loss: 0.0381 0:18:29.598127\n",
"75/200 1875 loss: 0.0369 val loss: 0.0383 0:18:45.208392\n",
"76/200 1875 loss: 0.0368 val loss: 0.0385 0:19:00.743062\n",
"77/200 1875 loss: 0.0368 val loss: 0.0381 0:19:16.186948\n",
"78/200 1875 loss: 0.0368 val loss: 0.0381 0:19:31.760451\n",
"79/200 1875 loss: 0.0368 val loss: 0.0385 0:19:47.388234\n",
"80/200 1875 loss: 0.0367 val loss: 0.0383 0:20:02.935055\n",
"81/200 1875 loss: 0.0367 val loss: 0.0385 0:20:18.402500\n",
"82/200 1875 loss: 0.0367 val loss: 0.0381 0:20:33.940910\n",
"83/200 1875 loss: 0.0367 val loss: 0.0384 0:20:49.569920\n",
"84/200 1875 loss: 0.0367 val loss: 0.0385 0:21:05.242798\n",
"85/200 1875 loss: 0.0366 val loss: 0.0382 0:21:20.880114\n",
"86/200 1875 loss: 0.0367 val loss: 0.0381 0:21:36.641503\n",
"87/200 1875 loss: 0.0366 val loss: 0.0381 0:21:52.095492\n",
"88/200 1875 loss: 0.0366 val loss: 0.0379 0:22:07.601546\n",
"89/200 1875 loss: 0.0366 val loss: 0.0381 0:22:23.401748\n",
"90/200 1875 loss: 0.0366 val loss: 0.0387 0:22:39.066528\n",
"91/200 1875 loss: 0.0366 val loss: 0.0387 0:22:54.610725\n",
"92/200 1875 loss: 0.0365 val loss: 0.0385 0:23:10.169099\n",
"93/200 1875 loss: 0.0365 val loss: 0.0385 0:23:25.674254\n",
"94/200 1875 loss: 0.0365 val loss: 0.0381 0:23:41.366783\n",
"95/200 1875 loss: 0.0365 val loss: 0.0382 0:23:56.902391\n",
"96/200 1875 loss: 0.0365 val loss: 0.0382 0:24:12.496421\n",
"97/200 1875 loss: 0.0365 val loss: 0.0383 0:24:28.063963\n",
"98/200 1875 loss: 0.0364 val loss: 0.0384 0:24:43.599283\n",
"99/200 1875 loss: 0.0365 val loss: 0.0381 0:24:59.157835\n",
"100/200 1875 loss: 0.0364 val loss: 0.0379 0:25:15.526026\n",
"101/200 1875 loss: 0.0364 val loss: 0.0387 0:25:31.212898\n",
"102/200 1875 loss: 0.0364 val loss: 0.0383 0:25:46.802330\n",
"103/200 1875 loss: 0.0364 val loss: 0.0382 0:26:02.178094\n",
"104/200 1875 loss: 0.0364 val loss: 0.0382 0:26:17.746102\n",
"105/200 1875 loss: 0.0363 val loss: 0.0382 0:26:33.309578\n",
"106/200 1875 loss: 0.0363 val loss: 0.0384 0:26:49.121648\n",
"107/200 1875 loss: 0.0363 val loss: 0.0381 0:27:04.702489\n",
"108/200 1875 loss: 0.0363 val loss: 0.0382 0:27:20.170574\n",
"109/200 1875 loss: 0.0363 val loss: 0.0379 0:27:35.856174\n",
"110/200 1875 loss: 0.0363 val loss: 0.0381 0:27:51.299808\n",
"111/200 1875 loss: 0.0362 val loss: 0.0384 0:28:06.870872\n",
"112/200 1875 loss: 0.0362 val loss: 0.0381 0:28:22.438025\n",
"113/200 1875 loss: 0.0362 val loss: 0.0383 0:28:37.875336\n",
"114/200 1875 loss: 0.0362 val loss: 0.0385 0:28:53.328504\n",
"115/200 1875 loss: 0.0362 val loss: 0.0382 0:29:08.972971\n",
"116/200 1875 loss: 0.0362 val loss: 0.0379 0:29:24.502631\n",
"117/200 1875 loss: 0.0362 val loss: 0.0382 0:29:39.941896\n",
"118/200 1875 loss: 0.0362 val loss: 0.0381 0:29:55.477538\n",
"119/200 1875 loss: 0.0362 val loss: 0.0384 0:30:11.112526\n",
"120/200 1875 loss: 0.0361 val loss: 0.0381 0:30:26.374847\n",
"121/200 1875 loss: 0.0361 val loss: 0.0380 0:30:41.861327\n",
"122/200 1875 loss: 0.0361 val loss: 0.0383 0:30:57.370377\n",
"123/200 1875 loss: 0.0361 val loss: 0.0381 0:31:12.900791\n",
"124/200 1875 loss: 0.0361 val loss: 0.0380 0:31:28.312363\n",
"125/200 1875 loss: 0.0361 val loss: 0.0380 0:31:43.843139\n",
"126/200 1875 loss: 0.0361 val loss: 0.0385 0:31:59.553265\n",
"127/200 1875 loss: 0.0361 val loss: 0.0385 0:32:14.916876\n",
"128/200 1875 loss: 0.0361 val loss: 0.0381 0:32:30.487089\n",
"129/200 1875 loss: 0.0360 val loss: 0.0380 0:32:45.878726\n",
"130/200 1875 loss: 0.0360 val loss: 0.0382 0:33:01.336908\n",
"131/200 1875 loss: 0.0360 val loss: 0.0377 0:33:16.793144\n",
"132/200 1875 loss: 0.0360 val loss: 0.0383 0:33:32.367575\n",
"133/200 1875 loss: 0.0360 val loss: 0.0383 0:33:47.764421\n",
"134/200 1875 loss: 0.0360 val loss: 0.0381 0:34:03.307962\n",
"135/200 1875 loss: 0.0360 val loss: 0.0383 0:34:18.773369\n",
"136/200 1875 loss: 0.0360 val loss: 0.0380 0:34:34.307721\n",
"137/200 1875 loss: 0.0360 val loss: 0.0382 0:34:49.981894\n",
"138/200 1875 loss: 0.0360 val loss: 0.0384 0:35:05.470105\n",
"139/200 1875 loss: 0.0359 val loss: 0.0383 0:35:20.803749\n",
"140/200 1875 loss: 0.0359 val loss: 0.0379 0:35:36.185748\n",
"141/200 1875 loss: 0.0359 val loss: 0.0382 0:35:51.533243\n",
"142/200 1875 loss: 0.0359 val loss: 0.0380 0:36:06.931450\n",
"143/200 1875 loss: 0.0359 val loss: 0.0381 0:36:22.431496\n",
"144/200 1875 loss: 0.0359 val loss: 0.0381 0:36:37.869902\n",
"145/200 1875 loss: 0.0359 val loss: 0.0384 0:36:53.547983\n",
"146/200 1875 loss: 0.0359 val loss: 0.0383 0:37:09.217082\n",
"147/200 1875 loss: 0.0359 val loss: 0.0382 0:37:24.778358\n",
"148/200 1875 loss: 0.0358 val loss: 0.0379 0:37:40.239433\n",
"149/200 1875 loss: 0.0359 val loss: 0.0381 0:37:55.704042\n",
"150/200 1875 loss: 0.0358 val loss: 0.0381 0:38:12.101171\n",
"151/200 1875 loss: 0.0358 val loss: 0.0380 0:38:27.662723\n",
"152/200 1875 loss: 0.0358 val loss: 0.0379 0:38:43.202257\n",
"153/200 1875 loss: 0.0358 val loss: 0.0385 0:38:58.810277\n",
"154/200 1875 loss: 0.0358 val loss: 0.0380 0:39:14.231378\n",
"155/200 1875 loss: 0.0358 val loss: 0.0381 0:39:29.652152\n",
"156/200 1875 loss: 0.0358 val loss: 0.0379 0:39:45.085332\n",
"157/200 1875 loss: 0.0358 val loss: 0.0380 0:40:00.572288\n",
"158/200 1875 loss: 0.0358 val loss: 0.0381 0:40:16.141797\n",
"159/200 1875 loss: 0.0357 val loss: 0.0381 0:40:31.634852\n",
"160/200 1875 loss: 0.0357 val loss: 0.0381 0:40:47.056919\n",
"161/200 1875 loss: 0.0357 val loss: 0.0383 0:41:02.554172\n",
"162/200 1875 loss: 0.0358 val loss: 0.0380 0:41:18.121788\n",
"163/200 1875 loss: 0.0357 val loss: 0.0379 0:41:33.599777\n",
"164/200 1875 loss: 0.0357 val loss: 0.0385 0:41:49.118886\n",
"165/200 1875 loss: 0.0357 val loss: 0.0378 0:42:04.560262\n",
"166/200 1875 loss: 0.0357 val loss: 0.0381 0:42:20.288644\n",
"167/200 1875 loss: 0.0357 val loss: 0.0381 0:42:35.660883\n",
"168/200 1875 loss: 0.0357 val loss: 0.0383 0:42:51.115505\n",
"169/200 1875 loss: 0.0357 val loss: 0.0380 0:43:06.762465\n",
"170/200 1875 loss: 0.0356 val loss: 0.0383 0:43:22.257651\n",
"171/200 1875 loss: 0.0356 val loss: 0.0383 0:43:37.670103\n",
"172/200 1875 loss: 0.0357 val loss: 0.0380 0:43:53.056826\n",
"173/200 1875 loss: 0.0357 val loss: 0.0381 0:44:08.524716\n",
"174/200 1875 loss: 0.0356 val loss: 0.0381 0:44:24.027149\n",
"175/200 1875 loss: 0.0356 val loss: 0.0379 0:44:39.346028\n",
"176/200 1875 loss: 0.0356 val loss: 0.0381 0:44:54.734347\n",
"177/200 1875 loss: 0.0356 val loss: 0.0384 0:45:10.213102\n",
"178/200 1875 loss: 0.0356 val loss: 0.0379 0:45:25.773002\n",
"179/200 1875 loss: 0.0356 val loss: 0.0382 0:45:41.326772\n",
"180/200 1875 loss: 0.0356 val loss: 0.0380 0:45:56.666135\n",
"181/200 1875 loss: 0.0356 val loss: 0.0382 0:46:11.978621\n",
"182/200 1875 loss: 0.0356 val loss: 0.0384 0:46:27.301725\n",
"183/200 1875 loss: 0.0356 val loss: 0.0381 0:46:42.745618\n",
"184/200 1875 loss: 0.0356 val loss: 0.0380 0:46:58.128569\n",
"185/200 1875 loss: 0.0355 val loss: 0.0380 0:47:13.711115\n",
"186/200 1875 loss: 0.0355 val loss: 0.0379 0:47:29.307111\n",
"187/200 1875 loss: 0.0356 val loss: 0.0382 0:47:44.756529\n",
"188/200 1875 loss: 0.0355 val loss: 0.0381 0:48:00.214915\n",
"189/200 1875 loss: 0.0355 val loss: 0.0383 0:48:15.668996\n",
"190/200 1875 loss: 0.0355 val loss: 0.0381 0:48:31.229319\n",
"191/200 1875 loss: 0.0355 val loss: 0.0382 0:48:46.675617\n",
"192/200 1875 loss: 0.0355 val loss: 0.0380 0:49:02.254153\n",
"193/200 1875 loss: 0.0355 val loss: 0.0382 0:49:17.595616\n",
"194/200 1875 loss: 0.0355 val loss: 0.0380 0:49:32.985089\n",
"195/200 1875 loss: 0.0355 val loss: 0.0381 0:49:48.470253\n",
"196/200 1875 loss: 0.0354 val loss: 0.0382 0:50:03.960498\n",
"197/200 1875 loss: 0.0355 val loss: 0.0383 0:50:19.343814\n",
"198/200 1875 loss: 0.0355 val loss: 0.0382 0:50:34.860656\n",
"199/200 1875 loss: 0.0355 val loss: 0.0381 0:50:50.302304\n",
"200/200 1875 loss: 0.0355 val loss: 0.0380 0:51:06.366335\n"
]
}
],
"source": [
"# Additional training.\n",
"# 追加でtrainingする。\n",
"\n",
"# train_tf() compiles the loss and gradient computation into a TensorFlow 2 graph, so it is a little over twice as fast as train(). However, it is still nearly twice as slow as fit().\n",
"# train_tf() は loss と gradients を求める部分を tf のgraphにコンパイルしているので、train()よりも2倍強高速になっている。しかし、それでもfit()よりは2倍近く遅い。\n",
"\n",
"# An autoencoder reconstructs its input, so inputs and targets are the same arrays\n",
"# (x_train for training, x_test for validation).\n",
"loss2_2, vloss2_2 = AE2_work.train_tf(\n",
"    x_train,\n",
"    x_train,\n",
"    batch_size=32,\n",
"    epochs=MAX_EPOCHS,\n",
"    shuffle=True,\n",
"    run_folder=save_path2,\n",
"    optimizer=optimizer2,\n",
"    save_epoch_interval=50,\n",
"    validation_data=(x_test, x_test),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 279
},
"executionInfo": {
"elapsed": 12,
"status": "ok",
"timestamp": 1637567363692,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "ez2T78hkFQRG",
"outputId": "1818b76a-2336-491f-da40-da1d8724bd05"
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"tf.GradientTape()
function and Learning rate decay\n",
"\n",
"Compute the loss and the gradients inside a `tf.GradientTape()` context, then apply the gradients to the variables.\n",
"In addition, apply learning rate decay through the optimizer's learning-rate schedule.\n",
"\n",
"[Caution] If you call the `save_images()` function during training, `encoder.predict()` and `decoder.predict()` are invoked, which makes execution very slow.\n",
"\n",
"## (3) `tf.GradientTape()` 関数と学習率減衰を使った学習\n",
"\n",
"`tf.GradientTape()` 関数を使って loss と gradients を計算して、gradients を変数に適用する。\n",
"さらに、optimizer において Learning rate decay を行う。\n",
"\n",
"(注意) trainingの途中で `save_images()` 関数を呼び出すと、`encoder.predict()` と `decoder.predict()` が動作して、実行が非常に遅くなるので注意すること。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TYuLK26XIUYl"
},
"outputs": [],
"source": [
"# Output folder (a Google Drive path on Colab) where AE3's parameters and weights are saved by training and later reloaded.\n",
"save_path3 = '/content/drive/MyDrive/ColabRun/AE03/'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5_oyibFQbiBj"
},
"outputs": [],
"source": [
"from nw.AutoEncoder import AutoEncoder\n",
"\n",
"# Convolutional AutoEncoder for 28x28x1 MNIST images with a 2-dimensional latent space (z_dim).\n",
"# The decoder uses the same strides as the encoder, and its last layer has 1 filter to match\n",
"# the single-channel input.\n",
"AE3 = AutoEncoder(\n",
"    input_dim=(28, 28, 1),\n",
"    encoder_conv_filters=[32, 64, 64, 64],\n",
"    encoder_conv_kernel_size=[3, 3, 3, 3],\n",
"    encoder_conv_strides=[1, 2, 2, 1],\n",
"    decoder_conv_t_filters=[64, 64, 32, 1],\n",
"    decoder_conv_t_kernel_size=[3, 3, 3, 3],\n",
"    decoder_conv_t_strides=[1, 2, 2, 1],\n",
"    z_dim=2,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MMkuVG5TCDzr"
},
"outputs": [],
"source": [
"# ExponentialDecay with the default staircase=False gives, at optimizer step `step`:\n",
"#   initial_learning_rate * decay_rate ^ (step / decay_steps)\n",
"# (the exponent is the integer division step // decay_steps only when staircase=True).\n",
"# The training cells use batch_size=32, i.e. 1875 steps per epoch on this data, so the\n",
"# rate shrinks by about 0.96 ** (1875 / 1000) ≈ 0.93 per epoch, to ≈2% of the initial\n",
"# value after 50 epochs; use a much larger decay_steps for a slower decay.\n",
"\n",
"lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(\n",
"    initial_learning_rate=learning_rate,\n",
"    decay_steps=1000,\n",
"    decay_rate=0.96,\n",
")\n",
"\n",
"optimizer3 = tf.keras.optimizers.Adam(learning_rate=lr_schedule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 117330,
"status": "ok",
"timestamp": 1637567482503,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "YNKUQta0bm9E",
"outputId": "29c08517-aed3-4a87-82e2-b435647313d4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1/3 1875 loss: 0.0493 val loss: 0.0491 0:00:38.910057\n",
"2/3 1875 loss: 0.0451 val loss: 0.0451 0:01:17.719193\n",
"3/3 1875 loss: 0.0463 val loss: 0.0439 0:01:56.448946\n"
]
}
],
"source": [
"# First, train for only a few epochs; the model is saved under run_folder (save_path3).\n",
"# まず、少ない回数 training してみる\n",
"\n",
"loss3_1, vloss3_1 = AE3.train(\n",
"    x_train,\n",
"    x_train,\n",
"    batch_size=32,\n",
"    epochs=3,\n",
"    shuffle=True,\n",
"    run_folder=save_path3,\n",
"    optimizer=optimizer3,\n",
"    save_epoch_interval=50,\n",
"    validation_data=(x_test, x_test),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 5,
"status": "ok",
"timestamp": 1637567482504,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "cUDPmvBqepKF",
"outputId": "bb83cce5-bc96-4024-8a26-d6112eb0ff67"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3\n"
]
}
],
"source": [
"# Load the parameters and the weights saved before.\n",
"# 保存したパラメータと、重みを読み込む。\n",
"\n",
"# The restored model is a new instance whose epoch counter is restored as well, so the\n",
"# following training continues from epoch AE3_work.epoch + 1.\n",
"AE3_work = AutoEncoder.load(save_path3)\n",
"print(AE3_work.epoch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 3167903,
"status": "ok",
"timestamp": 1637570650405,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "UuZRByoze7hn",
"outputId": "ff1a83ed-4193-42f6-d4f4-049b99a99d11"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4/200 1875 loss: 0.0441 val loss: 0.0432 0:00:16.470529\n",
"5/200 1875 loss: 0.0425 val loss: 0.0424 0:00:32.578862\n",
"6/200 1875 loss: 0.0418 val loss: 0.0416 0:00:48.702078\n",
"7/200 1875 loss: 0.0413 val loss: 0.0411 0:01:04.721034\n",
"8/200 1875 loss: 0.0410 val loss: 0.0408 0:01:20.828852\n",
"9/200 1875 loss: 0.0406 val loss: 0.0408 0:01:36.866276\n",
"10/200 1875 loss: 0.0403 val loss: 0.0405 0:01:52.938620\n",
"11/200 1875 loss: 0.0401 val loss: 0.0403 0:02:08.915218\n",
"12/200 1875 loss: 0.0398 val loss: 0.0403 0:02:24.993334\n",
"13/200 1875 loss: 0.0397 val loss: 0.0402 0:02:40.943671\n",
"14/200 1875 loss: 0.0395 val loss: 0.0400 0:02:57.026531\n",
"15/200 1875 loss: 0.0393 val loss: 0.0396 0:03:13.174460\n",
"16/200 1875 loss: 0.0392 val loss: 0.0397 0:03:29.264352\n",
"17/200 1875 loss: 0.0391 val loss: 0.0397 0:03:45.378318\n",
"18/200 1875 loss: 0.0389 val loss: 0.0395 0:04:01.329874\n",
"19/200 1875 loss: 0.0388 val loss: 0.0393 0:04:17.245191\n",
"20/200 1875 loss: 0.0387 val loss: 0.0394 0:04:33.400665\n",
"21/200 1875 loss: 0.0386 val loss: 0.0392 0:04:49.561788\n",
"22/200 1875 loss: 0.0385 val loss: 0.0391 0:05:05.574827\n",
"23/200 1875 loss: 0.0384 val loss: 0.0392 0:05:21.680731\n",
"24/200 1875 loss: 0.0384 val loss: 0.0390 0:05:37.694906\n",
"25/200 1875 loss: 0.0383 val loss: 0.0391 0:05:53.796563\n",
"26/200 1875 loss: 0.0382 val loss: 0.0390 0:06:09.769820\n",
"27/200 1875 loss: 0.0382 val loss: 0.0388 0:06:25.773891\n",
"28/200 1875 loss: 0.0381 val loss: 0.0389 0:06:41.814141\n",
"29/200 1875 loss: 0.0380 val loss: 0.0389 0:06:57.785303\n",
"30/200 1875 loss: 0.0380 val loss: 0.0389 0:07:13.735852\n",
"31/200 1875 loss: 0.0379 val loss: 0.0388 0:07:29.669668\n",
"32/200 1875 loss: 0.0379 val loss: 0.0388 0:07:45.531500\n",
"33/200 1875 loss: 0.0379 val loss: 0.0388 0:08:01.576522\n",
"34/200 1875 loss: 0.0378 val loss: 0.0387 0:08:17.534843\n",
"35/200 1875 loss: 0.0378 val loss: 0.0387 0:08:33.530811\n",
"36/200 1875 loss: 0.0378 val loss: 0.0387 0:08:49.452516\n",
"37/200 1875 loss: 0.0377 val loss: 0.0386 0:09:05.483624\n",
"38/200 1875 loss: 0.0377 val loss: 0.0387 0:09:21.607403\n",
"39/200 1875 loss: 0.0377 val loss: 0.0386 0:09:37.595950\n",
"40/200 1875 loss: 0.0376 val loss: 0.0386 0:09:53.711715\n",
"41/200 1875 loss: 0.0376 val loss: 0.0386 0:10:09.573170\n",
"42/200 1875 loss: 0.0376 val loss: 0.0386 0:10:25.672371\n",
"43/200 1875 loss: 0.0376 val loss: 0.0386 0:10:41.658941\n",
"44/200 1875 loss: 0.0376 val loss: 0.0386 0:10:57.703071\n",
"45/200 1875 loss: 0.0375 val loss: 0.0386 0:11:13.708545\n",
"46/200 1875 loss: 0.0375 val loss: 0.0386 0:11:29.633509\n",
"47/200 1875 loss: 0.0375 val loss: 0.0386 0:11:45.618107\n",
"48/200 1875 loss: 0.0375 val loss: 0.0386 0:12:01.542282\n",
"49/200 1875 loss: 0.0375 val loss: 0.0386 0:12:17.577870\n",
"50/200 1875 loss: 0.0375 val loss: 0.0386 0:12:34.355703\n",
"51/200 1875 loss: 0.0375 val loss: 0.0386 0:12:50.369939\n",
"52/200 1875 loss: 0.0374 val loss: 0.0385 0:13:06.322242\n",
"53/200 1875 loss: 0.0374 val loss: 0.0385 0:13:22.294449\n",
"54/200 1875 loss: 0.0374 val loss: 0.0386 0:13:38.360678\n",
"55/200 1875 loss: 0.0374 val loss: 0.0385 0:13:54.381169\n",
"56/200 1875 loss: 0.0374 val loss: 0.0385 0:14:10.436163\n",
"57/200 1875 loss: 0.0374 val loss: 0.0385 0:14:26.374065\n",
"58/200 1875 loss: 0.0374 val loss: 0.0385 0:14:42.461288\n",
"59/200 1875 loss: 0.0374 val loss: 0.0385 0:14:58.804114\n",
"60/200 1875 loss: 0.0374 val loss: 0.0385 0:15:14.948368\n",
"61/200 1875 loss: 0.0374 val loss: 0.0385 0:15:31.003709\n",
"62/200 1875 loss: 0.0374 val loss: 0.0385 0:15:47.036744\n",
"63/200 1875 loss: 0.0374 val loss: 0.0385 0:16:02.927857\n",
"64/200 1875 loss: 0.0374 val loss: 0.0385 0:16:19.033550\n",
"65/200 1875 loss: 0.0374 val loss: 0.0385 0:16:35.146949\n",
"66/200 1875 loss: 0.0374 val loss: 0.0385 0:16:51.198622\n",
"67/200 1875 loss: 0.0374 val loss: 0.0385 0:17:07.332641\n",
"68/200 1875 loss: 0.0374 val loss: 0.0385 0:17:23.460586\n",
"69/200 1875 loss: 0.0373 val loss: 0.0385 0:17:39.446418\n",
"70/200 1875 loss: 0.0373 val loss: 0.0385 0:17:55.528039\n",
"71/200 1875 loss: 0.0373 val loss: 0.0385 0:18:11.502721\n",
"72/200 1875 loss: 0.0373 val loss: 0.0385 0:18:27.473277\n",
"73/200 1875 loss: 0.0373 val loss: 0.0385 0:18:43.449839\n",
"74/200 1875 loss: 0.0373 val loss: 0.0385 0:18:59.486046\n",
"75/200 1875 loss: 0.0373 val loss: 0.0385 0:19:15.552210\n",
"76/200 1875 loss: 0.0373 val loss: 0.0385 0:19:31.626925\n",
"77/200 1875 loss: 0.0373 val loss: 0.0385 0:19:47.513213\n",
"78/200 1875 loss: 0.0373 val loss: 0.0385 0:20:03.722407\n",
"79/200 1875 loss: 0.0373 val loss: 0.0385 0:20:19.804345\n",
"80/200 1875 loss: 0.0373 val loss: 0.0385 0:20:35.791387\n",
"81/200 1875 loss: 0.0373 val loss: 0.0385 0:20:51.787138\n",
"82/200 1875 loss: 0.0373 val loss: 0.0385 0:21:07.730577\n",
"83/200 1875 loss: 0.0373 val loss: 0.0385 0:21:23.787720\n",
"84/200 1875 loss: 0.0373 val loss: 0.0385 0:21:39.687327\n",
"85/200 1875 loss: 0.0373 val loss: 0.0385 0:21:55.584867\n",
"86/200 1875 loss: 0.0373 val loss: 0.0385 0:22:11.523757\n",
"87/200 1875 loss: 0.0373 val loss: 0.0385 0:22:27.390298\n",
"88/200 1875 loss: 0.0373 val loss: 0.0385 0:22:43.472842\n",
"89/200 1875 loss: 0.0373 val loss: 0.0385 0:22:59.672602\n",
"90/200 1875 loss: 0.0373 val loss: 0.0385 0:23:15.707643\n",
"91/200 1875 loss: 0.0373 val loss: 0.0385 0:23:31.828762\n",
"92/200 1875 loss: 0.0373 val loss: 0.0385 0:23:47.825785\n",
"93/200 1875 loss: 0.0373 val loss: 0.0385 0:24:03.841613\n",
"94/200 1875 loss: 0.0373 val loss: 0.0385 0:24:19.741862\n",
"95/200 1875 loss: 0.0373 val loss: 0.0385 0:24:35.758435\n",
"96/200 1875 loss: 0.0373 val loss: 0.0385 0:24:51.756826\n",
"97/200 1875 loss: 0.0373 val loss: 0.0385 0:25:07.796537\n",
"98/200 1875 loss: 0.0373 val loss: 0.0385 0:25:24.006386\n",
"99/200 1875 loss: 0.0373 val loss: 0.0385 0:25:39.953159\n",
"100/200 1875 loss: 0.0373 val loss: 0.0385 0:25:56.680088\n",
"101/200 1875 loss: 0.0373 val loss: 0.0385 0:26:12.840496\n",
"102/200 1875 loss: 0.0373 val loss: 0.0385 0:26:28.849194\n",
"103/200 1875 loss: 0.0373 val loss: 0.0385 0:26:44.841663\n",
"104/200 1875 loss: 0.0373 val loss: 0.0385 0:27:00.859847\n",
"105/200 1875 loss: 0.0373 val loss: 0.0385 0:27:17.004825\n",
"106/200 1875 loss: 0.0373 val loss: 0.0385 0:27:33.131367\n",
"107/200 1875 loss: 0.0373 val loss: 0.0385 0:27:49.215559\n",
"108/200 1875 loss: 0.0373 val loss: 0.0385 0:28:05.210165\n",
"109/200 1875 loss: 0.0373 val loss: 0.0385 0:28:21.219173\n",
"110/200 1875 loss: 0.0373 val loss: 0.0385 0:28:37.222607\n",
"111/200 1875 loss: 0.0373 val loss: 0.0385 0:28:53.464457\n",
"112/200 1875 loss: 0.0373 val loss: 0.0385 0:29:09.544785\n",
"113/200 1875 loss: 0.0373 val loss: 0.0385 0:29:25.686480\n",
"114/200 1875 loss: 0.0373 val loss: 0.0385 0:29:41.597519\n",
"115/200 1875 loss: 0.0373 val loss: 0.0385 0:29:57.662729\n",
"116/200 1875 loss: 0.0373 val loss: 0.0385 0:30:13.863950\n",
"117/200 1875 loss: 0.0373 val loss: 0.0385 0:30:30.316802\n",
"118/200 1875 loss: 0.0373 val loss: 0.0385 0:30:46.468128\n",
"119/200 1875 loss: 0.0373 val loss: 0.0385 0:31:02.499419\n",
"120/200 1875 loss: 0.0373 val loss: 0.0385 0:31:18.614042\n",
"121/200 1875 loss: 0.0373 val loss: 0.0385 0:31:35.006121\n",
"122/200 1875 loss: 0.0373 val loss: 0.0385 0:31:51.303524\n",
"123/200 1875 loss: 0.0373 val loss: 0.0385 0:32:07.364468\n",
"124/200 1875 loss: 0.0373 val loss: 0.0385 0:32:23.438019\n",
"125/200 1875 loss: 0.0373 val loss: 0.0385 0:32:39.503072\n",
"126/200 1875 loss: 0.0373 val loss: 0.0385 0:32:55.600306\n",
"127/200 1875 loss: 0.0373 val loss: 0.0385 0:33:11.714923\n",
"128/200 1875 loss: 0.0373 val loss: 0.0385 0:33:27.757243\n",
"129/200 1875 loss: 0.0373 val loss: 0.0385 0:33:43.846899\n",
"130/200 1875 loss: 0.0373 val loss: 0.0385 0:33:59.986440\n",
"131/200 1875 loss: 0.0373 val loss: 0.0385 0:34:16.036559\n",
"132/200 1875 loss: 0.0373 val loss: 0.0385 0:34:32.108351\n",
"133/200 1875 loss: 0.0373 val loss: 0.0385 0:34:48.126711\n",
"134/200 1875 loss: 0.0373 val loss: 0.0385 0:35:04.243871\n",
"135/200 1875 loss: 0.0373 val loss: 0.0385 0:35:20.342651\n",
"136/200 1875 loss: 0.0373 val loss: 0.0385 0:35:36.834930\n",
"137/200 1875 loss: 0.0373 val loss: 0.0385 0:35:53.240132\n",
"138/200 1875 loss: 0.0373 val loss: 0.0385 0:36:09.494644\n",
"139/200 1875 loss: 0.0373 val loss: 0.0385 0:36:25.711864\n",
"140/200 1875 loss: 0.0373 val loss: 0.0385 0:36:41.715972\n",
"141/200 1875 loss: 0.0373 val loss: 0.0385 0:36:57.798700\n",
"142/200 1875 loss: 0.0373 val loss: 0.0385 0:37:14.012655\n",
"143/200 1875 loss: 0.0373 val loss: 0.0385 0:37:29.964565\n",
"144/200 1875 loss: 0.0373 val loss: 0.0385 0:37:45.947853\n",
"145/200 1875 loss: 0.0373 val loss: 0.0385 0:38:01.922393\n",
"146/200 1875 loss: 0.0373 val loss: 0.0385 0:38:18.113317\n",
"147/200 1875 loss: 0.0373 val loss: 0.0385 0:38:34.302843\n",
"148/200 1875 loss: 0.0373 val loss: 0.0385 0:38:50.460746\n",
"149/200 1875 loss: 0.0373 val loss: 0.0385 0:39:06.514157\n",
"150/200 1875 loss: 0.0373 val loss: 0.0385 0:39:23.229584\n",
"151/200 1875 loss: 0.0373 val loss: 0.0385 0:39:39.470055\n",
"152/200 1875 loss: 0.0373 val loss: 0.0385 0:39:55.751577\n",
"153/200 1875 loss: 0.0373 val loss: 0.0385 0:40:12.016928\n",
"154/200 1875 loss: 0.0373 val loss: 0.0385 0:40:28.212051\n",
"155/200 1875 loss: 0.0373 val loss: 0.0385 0:40:44.578358\n",
"156/200 1875 loss: 0.0373 val loss: 0.0385 0:41:00.970150\n",
"157/200 1875 loss: 0.0373 val loss: 0.0385 0:41:17.138623\n",
"158/200 1875 loss: 0.0373 val loss: 0.0385 0:41:33.213235\n",
"159/200 1875 loss: 0.0373 val loss: 0.0385 0:41:49.222445\n",
"160/200 1875 loss: 0.0373 val loss: 0.0385 0:42:05.256580\n",
"161/200 1875 loss: 0.0373 val loss: 0.0385 0:42:21.305456\n",
"162/200 1875 loss: 0.0373 val loss: 0.0385 0:42:37.293515\n",
"163/200 1875 loss: 0.0373 val loss: 0.0385 0:42:53.218787\n",
"164/200 1875 loss: 0.0373 val loss: 0.0385 0:43:09.221011\n",
"165/200 1875 loss: 0.0373 val loss: 0.0385 0:43:25.241022\n",
"166/200 1875 loss: 0.0373 val loss: 0.0385 0:43:41.364001\n",
"167/200 1875 loss: 0.0373 val loss: 0.0385 0:43:57.468188\n",
"168/200 1875 loss: 0.0373 val loss: 0.0385 0:44:13.478234\n",
"169/200 1875 loss: 0.0373 val loss: 0.0385 0:44:29.480388\n",
"170/200 1875 loss: 0.0373 val loss: 0.0385 0:44:45.507111\n",
"171/200 1875 loss: 0.0373 val loss: 0.0385 0:45:01.481223\n",
"172/200 1875 loss: 0.0373 val loss: 0.0385 0:45:17.528489\n",
"173/200 1875 loss: 0.0373 val loss: 0.0385 0:45:33.518117\n",
"174/200 1875 loss: 0.0373 val loss: 0.0385 0:45:49.637809\n",
"175/200 1875 loss: 0.0373 val loss: 0.0385 0:46:05.890984\n",
"176/200 1875 loss: 0.0373 val loss: 0.0385 0:46:21.956460\n",
"177/200 1875 loss: 0.0373 val loss: 0.0385 0:46:38.062078\n",
"178/200 1875 loss: 0.0373 val loss: 0.0385 0:46:54.053728\n",
"179/200 1875 loss: 0.0373 val loss: 0.0385 0:47:09.917701\n",
"180/200 1875 loss: 0.0373 val loss: 0.0385 0:47:25.930845\n",
"181/200 1875 loss: 0.0373 val loss: 0.0385 0:47:41.949047\n",
"182/200 1875 loss: 0.0373 val loss: 0.0385 0:47:57.975498\n",
"183/200 1875 loss: 0.0373 val loss: 0.0385 0:48:13.821878\n",
"184/200 1875 loss: 0.0373 val loss: 0.0385 0:48:29.675598\n",
"185/200 1875 loss: 0.0373 val loss: 0.0385 0:48:45.563435\n",
"186/200 1875 loss: 0.0373 val loss: 0.0385 0:49:01.501470\n",
"187/200 1875 loss: 0.0373 val loss: 0.0385 0:49:17.551920\n",
"188/200 1875 loss: 0.0373 val loss: 0.0385 0:49:33.465461\n",
"189/200 1875 loss: 0.0373 val loss: 0.0385 0:49:49.437168\n",
"190/200 1875 loss: 0.0373 val loss: 0.0385 0:50:05.438445\n",
"191/200 1875 loss: 0.0373 val loss: 0.0385 0:50:21.633043\n",
"192/200 1875 loss: 0.0373 val loss: 0.0385 0:50:37.783884\n",
"193/200 1875 loss: 0.0373 val loss: 0.0385 0:50:53.888186\n",
"194/200 1875 loss: 0.0373 val loss: 0.0385 0:51:10.185000\n",
"195/200 1875 loss: 0.0373 val loss: 0.0385 0:51:26.449233\n",
"196/200 1875 loss: 0.0373 val loss: 0.0385 0:51:42.700071\n",
"197/200 1875 loss: 0.0373 val loss: 0.0385 0:51:58.913375\n",
"198/200 1875 loss: 0.0373 val loss: 0.0385 0:52:15.079322\n",
"199/200 1875 loss: 0.0373 val loss: 0.0385 0:52:30.994392\n",
"200/200 1875 loss: 0.0373 val loss: 0.0385 0:52:47.570349\n"
]
}
],
"source": [
"# Additional training.\n",
"# 追加でtrainingする。\n",
"\n",
"# train_tf() compiles the loss and gradient computation into a TensorFlow 2 graph, so it is a little over twice as fast as train(). However, it is still nearly twice as slow as fit().\n",
"# train_tf() は loss と gradients を求める部分を tf のgraphにコンパイルしているので、train()よりも2倍強高速になっている。しかし、それでもfit()よりは2倍近く遅い。\n",
"\n",
"# Training resumes from the epoch restored by AutoEncoder.load(), so `epochs` is the total\n",
"# number of epochs to reach, not the number of additional epochs.\n",
"# optimizer3 is the same instance used for the first run, so its learning-rate schedule\n",
"# presumably continues from the step count reached there (assumes train()/train_tf()\n",
"# step the optimizer via apply_gradients - confirm).\n",
"loss3_2, vloss3_2 = AE3_work.train_tf(\n",
"    x_train,\n",
"    x_train,\n",
"    batch_size=32,\n",
"    epochs=MAX_EPOCHS,\n",
"    shuffle=True,\n",
"    run_folder=save_path3,\n",
"    optimizer=optimizer3,\n",
"    save_epoch_interval=50,\n",
"    validation_data=(x_test, x_test),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 16,
"status": "ok",
"timestamp": 1637570650405,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "XLDhBSeefgX_",
"outputId": "c7b046cc-9b3c-4098-e586-2295fd4bea23"
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"