{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "
An introduction to Deep Learning with PyTorch\n", "\n", "Yifei Huang\n", "MAFS6010U\n", "HKUST\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prerequisites" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Materials are available on Piazza\n", "\n", "- A Python distribution with scientific packages (NumPy, SciPy, scikit-learn, pandas, PyTorch) is required\n", "\n", "- One can refer to this link to install PyTorch\n", "\n", "- One can refer to this link to learn the basic operations in PyTorch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Outline\n", "\n", "* Data preprocessing with the PyTorch DataLoader\n", "* RNN and CNN\n", "* Generative Adversarial Networks\n", "* Reinforcement Learning" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load some packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn, optim\n", "from torch.utils.data import DataLoader, Dataset\n", "import torch.nn.functional as F\n", "\n", "import os\n", "import argparse\n", "import pandas as pd\n", "import numpy as np\n", "\n", "# choose which GPU to use; this must be set before CUDA is initialized\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '1'\n", "\n", "try:\n", "    from mpl_finance import candlestick_ohlc\n", "except ImportError:\n", "    !pip install mpl_finance\n", "    from mpl_finance import candlestick_ohlc\n", "\n", "import matplotlib.pyplot as plt\n", "import matplotlib.dates as mdates\n", "%matplotlib inline\n", "plt.style.use('ggplot')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Loading external data with pandas" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "data = pd.read_csv('./okex_future_BTC_USD_this_week_1H.csv',\n", "                   parse_dates=['candle_begin_time'])" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " candle_begin_time open high low close volume\n", "0 2018-05-15 00:00:00 8674.33 8761.76 8663.17 8696.66 172152.0\n", "1 2018-05-15 01:00:00 8694.66 8728.17 8662.17 8701.85 122668.0\n", "2 2018-05-15 02:00:00 8699.98 8741.04 8642.36 8673.34 181228.0\n", "3 2018-05-15 03:00:00 8673.19 8717.70 8638.95 8707.10 193166.0\n", "4 2018-05-15 04:00:00 8710.92 8737.36 8684.20 8699.16 155912.0" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualization" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "data['timestamp'] = data.candle_begin_time.apply(mdates.date2num)\n", "fig, ax = plt.subplots(1, 1, figsize=(40, 10))\n", "plt.title('OKEX-FUTURE_BTC-USD_1H', fontsize=30)\n", "# the x-axis unit is one day, so width=.6 / 24 draws each hourly candle 0.6 hours wide\n", "candlestick_ohlc(ax,\n", "                 data[['timestamp', 'open', 'high', 'low', 'close']].values,\n", "                 width=.6 / 24,\n", "                 colorup='g',\n", "                 alpha=1)\n", "ax.xaxis_date()\n", "ax.xaxis.set_major_formatter(mdates.DateFormatter('%y-%m-%d'))\n", "ax.set_ylabel('Price', fontdict={'size': 30})\n", "ax.set_xlabel('Date', fontdict={'size': 30})\n", "data = data.drop(['timestamp'], axis=1)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "parse = argparse.ArgumentParser(description='Tutorial2MAFS6010U')\n", "\n", "parse.add_argument('--input_length', default=12, type=int)\n", "parse.add_argument('--input_dim', default=5, type=int)\n", "parse.add_argument('--training_ratio', default=0.8, type=float)\n", "parse.add_argument('--batch_size', default=64, type=int)\n", "parse.add_argument('--hidden_dim', default=128, type=int)\n", "parse.add_argument('--rnn_layers', default=1, type=int)\n", "parse.add_argument('--lr', default=1e-4, type=float)\n", "parse.add_argument('--epochs', default=10, type=int)\n", "\n", "parse.add_argument('--cnn_kernels', default=[2, 3, 4], nargs='+', type=int)\n", "parse.add_argument('--cnn_channels', default=64, type=int)\n", "\n", "# parse_known_args ignores the extra arguments Jupyter passes to the kernel\n", "config, _ = parse.parse_known_args()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# First Part" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## RNN and CNN\n", "### We build a recurrent neural network and a convolutional neural network to predict the next close price" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### RNN:\n", "* Vanilla RNN\n", "* GRU\n", "* LSTM" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://drive.google.com/uc?id=1B9zjuzsSfSjxI_qL7ahaCMmibSSC0EHi)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we write a custom `Dataset` that turns the price series into (input window, next close price) pairs." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "class FinDataset(Dataset):\n", "    def __init__(self, input_length=12, training_ratio=.8, training=True):\n", "        data = pd.read_csv('./okex_future_BTC_USD_this_week_1H.csv',\n", "                           parse_dates=['candle_begin_time'])\n", "        data.index = data['candle_begin_time']\n", "        data = data.drop(['candle_begin_time'], axis=1)\n", "        # chronological split: the first part for training, the rest for validation\n", "        training_size = int(len(data) * training_ratio)\n", "        self.data = data.iloc[:training_size] if training else data.iloc[training_size:]\n", "        self.input_length = input_length\n", "\n", "    def __getitem__(self, idx):\n", "        # input: the previous input_length rows (open, high, low, close, volume)\n", "        # label: the close price of the next hour\n", "        input_seq = self.data.values[idx:idx + self.input_length]\n", "        label = self.data.close.iloc[idx + self.input_length]\n", "        mean = np.mean(input_seq, axis=0, keepdims=True)\n", "        std = np.std(input_seq, axis=0, keepdims=True)\n", "\n", "        # standardize each feature within the input window\n", "        input_seq = (input_seq - mean) / (std + 1e-8)\n", "\n", "        # self.mean / self.std hold the close-price statistics of the most recently\n", "        # fetched window; they are reused later to map predictions back to prices\n", "        close_price = self.data.close.iloc[idx:idx + self.input_length]\n", "        self.mean = close_price.mean()\n", "        self.std = close_price.std()\n", "\n", "        # standardize the label with the same window's statistics\n", "        label = (label - self.mean) / (self.std + 1e-8)\n", "        return input_seq, label\n", "\n", "    def __len__(self):\n", "        return len(self.data) - self.input_length" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "training_dataset = FinDataset(input_length=12,\n", "                              training_ratio=.8,\n", "                              training=True)\n", "val_dataset = FinDataset(input_length=12, training_ratio=.8, training=False)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training size: 5287, Val
size: 1313\n" ] } ], "source": [ "print(f'Training size: {len(training_dataset)}, Val size: {len(val_dataset)}')" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([[-1.37962864, 0.18147287, -0.59518499, -0.80949459, 0.40760911],\n", " [-0.56938742, -0.83929303, -0.63166397, -0.61047626, -0.66016787],\n", " [-0.35736168, -0.44818684, -1.3543126 , -1.70373495, 0.6034531 ],\n", " [-1.42506273, -1.15746567, -1.47870593, -0.40915714, 0.86105397],\n", " [ 0.07864614, -0.56001829, 0.171968 , -0.71362835, 0.0571787 ],\n", " [-0.45659929, -0.17590156, 0.18728917, -0.01342127, -0.25320251],\n", " [ 0.05074802, 0.91810179, 0.92854209, 0.53723445, 0.2370548 ],\n", " [ 0.845446 , 0.01463736, 1.11312574, 0.35777284, -1.17381598],\n", " [ 0.63979697, -0.85053696, -1.14820636, -0.00498504, 0.04125397],\n", " [ 0.23407854, -0.63963742, 0.40871659, -0.32671217, -1.22931518],\n", " [-0.14971993, 2.51200389, 0.60825663, 1.98366439, 2.37972958],\n", " [ 2.48904403, 1.04482385, 1.79017564, 1.71293811, -1.27083169]]),\n", " 1.4061453421112702)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "training_dataset[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In practice we wrap the `Dataset` in a PyTorch `DataLoader`, which batches and shuffles the samples and can load them in parallel worker processes (`num_workers`)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def get_loader(config):\n", "    training_dataset = FinDataset(input_length=config.input_length,\n", "                                  training_ratio=config.training_ratio,\n", "                                  training=True)\n", "    val_dataset = FinDataset(input_length=config.input_length,\n", "                             training_ratio=config.training_ratio,\n", "                             training=False)\n", "\n", "    # batch_size is a hyperparameter to tune: very small batches give noisy gradients,\n", "    # very large ones need more memory and can generalize worse\n", "    # shuffle the training data every epoch, but keep the validation data in order\n", "    # pin_memory=True puts batches in page-locked memory, which speeds up\n", "    # host-to-GPU copies; it only helps when training on a GPU\n", "    training_loader = DataLoader(training_dataset,\n", "                                 config.batch_size,\n", "                                 shuffle=True,\n", "                                 num_workers=4,\n", "                                 pin_memory=True)\n", "    val_loader = DataLoader(val_dataset,\n", "                            config.batch_size,\n", "                            shuffle=False,\n", "                            num_workers=4,\n", "                            pin_memory=True)\n", "\n", "    return {'train': training_loader, 'val': val_loader}" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "class RNNModel(nn.Module):\n", "    def __init__(self, input_dim, hidden_dim, rnn_layers=1):\n", "        super(RNNModel, self).__init__()\n", "        self.input_dim = input_dim\n", "        self.hidden_dim = hidden_dim\n", "\n", "        # recurrent part: a GRU that reads [batch_size, t, input_dim] sequences\n", "        self.gru = nn.GRU(input_dim, hidden_dim, rnn_layers, batch_first=True)\n", "\n", "        # regression head: maps the last GRU output to one (scaled) close price\n", "        self.regressor = nn.Sequential(nn.Linear(hidden_dim, hidden_dim // 2),\n", "                                       nn.ReLU(inplace=True), nn.Dropout(.5),\n", "                                       nn.Linear(hidden_dim // 2, 1))\n", "\n", "        # initialization: orthogonal weights; GRU biases are laid out as\n", "        # (reset | update | new), each of size hidden_dim, so setting the update-gate\n", "        # bias to 1 makes the cell lean towards keeping its previous state\n", "        for name, param in self.gru.named_parameters():\n", "            if 'weight' in name:\n", "                nn.init.orthogonal_(param)\n", "            if 'bias' in name:\n", "                start, end = hidden_dim, 2 * hidden_dim\n", "                param.data[start:end].fill_(1.)\n", "\n", "    def forward(self, s):\n", "        # s: torch.Tensor [batch_size, t, m]\n", "        # where t is the sequence length and m is the number of features\n", "        output, hidden = self.gru(s)\n", "\n", "        # output: torch.Tensor [batch_size, t, hidden_dim]\n", "        # the output at the last time step, output[:, -1, :], is fed to the regressor\n", "        output = self.regressor(output[:, -1, :])\n", "        return output" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: [0/10]\n", "iters: [0/83], training_loss: 4.6170\n", "iters: [80/83], training_loss: 4.1754\n", "val_loss: 4.9485\n", "Epoch: [1/10]\n", "iters: [0/83], training_loss: 4.2196\n", "iters: [80/83], training_loss: 3.8386\n", "val_loss: 4.5632\n", "Epoch: [2/10]\n", "iters: [0/83], training_loss: 2.0318\n", "iters: [80/83], training_loss: 3.3538\n", "val_loss: 4.1762\n", "Epoch: [3/10]\n", "iters: [0/83], training_loss: 9.0688\n", "iters: [80/83], training_loss: 3.1163\n", "val_loss: 4.0551\n", "Epoch: [4/10]\n", "iters: [0/83], training_loss: 1.7701\n", "iters: [80/83], training_loss: 3.0344\n", "val_loss: 3.9876\n", "Epoch: [5/10]\n", "iters: [0/83], training_loss: 3.3667\n", "iters: [80/83], training_loss: 2.9157\n", "val_loss: 3.9482\n", "Epoch: [6/10]\n", "iters: [0/83], training_loss: 0.7015\n", "iters: [80/83], training_loss: 2.9102\n", "val_loss: 3.9085\n", "Epoch: [7/10]\n", "iters: [0/83], training_loss: 0.9521\n", "iters: [80/83], training_loss: 2.6907\n", "val_loss: 3.8947\n", "Epoch: [8/10]\n", "iters: [0/83], training_loss: 2.2297\n", "iters: [80/83], training_loss: 2.8626\n", "val_loss: 3.8831\n", "Epoch: [9/10]\n", "iters: [0/83], training_loss: 2.0857\n", "iters: [80/83], training_loss: 2.8488\n", "val_loss: 3.8781\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAEJCAYAAABv6GdPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xd8FHX+x/HXd3Y3m56QBBLS6FZUVBQEC4LoIZ5Y7r56p56eepZTQbGdFcR2nmc99dSzo6d+LfzEgoqCFbuiHoecCkhCKCm0EEjZmd8fs6RRUkgy2d3P8/HII7uzM5vPfgnv+eY7M99RjuMghBAiulheFyCEEKLjSbgLIUQUknAXQogoJOEuhBBRSMJdCCGikIS7EEJEIQl3EfWUUk8opd7xug4hupKEuxBCRCEJdyGEiEIS7iKmKNdlSqnFSqkapdTPSqmLm60zQSn1jVKqSim1Vin1uVJq3/BrAaXUnUqpYqVUtVJqhVLqOW8+jRDb5/e6ACG62J+BG4FJwFxgDHC3UmqD4ziPKqVygBeAa8Pf44F9gbrw9hcBGjgVWAxkAyO79BMI0QoS7iLW/AX4h+M4D4ef/6iU2hW4BngU6A0EAOM4ztLwOgsbbd8H+B/wvuNOzLQM+KIrCheiLWRYRsQMpVQqkA980Oyl94G+SqlE4DvgLeA/SqkZSqlJSqmCRus+DuwF/KSUelApdaJSKq4r6heiLSTcRSxqPhWqqn/BcULAOGA0bo/8ROB/Sqljwq/PB/oBlwE1wD3A/PCOQ4huQ8JdxAzHcdYDxcBhzV46FFjiOE5VeD3HcZzPHce5xXGcQ3F79n9s9D6VjuPMcBxnIjAU2H0b7ymEp2TMXcSaW4E7lFI/Au/h9tDPBy4AUEqNwD3I+jawAhgE7I07Ho9S6nKgBJgPVAG/A0K44/BCdBsS7iLW/BNIAq4GHgCKgL84jvNo+PV1wEG4Yd8DWAk8g3uGDcB6YDJu6Fu4B1tPdBxnUVd9ACFaQ8mdmIQQIvrImLsQQkQhCXchhIhCEu5CCBGFJNyFECIKeXm2jBzJFUKI9lEtreDpqZAlJSXt2i4rK4uysrIOriZySXs0Je3RQNqiqWhoj9zc3Fat16pw11ovBTbgXqxRZ4wZ2ux1hXsZ9tG4F3acYYz5ug31CiGE6EBt6bkfbozZ3i5vHO5FHYOAYbgXigzbydqEEEK0U0cdUJ0APGWMcYwxnwLpWuveHfTeQggh2qi1PXcHeFtr7QAPGWMebvZ6Hu5l3FsUh5etaLyS1voc4BwAYwxZWVntK9rvb/e20UjaoylpjwbtaQvHcaioqKCurq7llSPM6tWriZSr8v1+PxkZGSjV4rHTbW/fyvVGGmNKtNa9gNla6x+MMY3nxN7WT9+qBcM7hS07Bqe9Bzai4aBIR5L2aErao0F72mLTpk0EAgH8/uibesrv90fMTqu2tpbi4mISEhKaLG/tAdVWDcsYY0rC31cDM4ADm61SDDS+oUE+7sx5QogIY9t2VAZ7pPH7/di23f7tW1pBa50EWMaYDeHHRwLTmq02E7hQa/0c7oHUdcaYFQghIk57hwFEx9uZf4vW7J6zgRla6y3r/9sY86bW+jwAY8yDwBu4p0H+hHsq5B+38147bcWGGl5YtJQTd0nCkl9CIYTYJi+n/HXacxHTjP+W88Q3pRyzaw/O3r+X9DKQMebmpD0atKctqqqqSExM7KSKvBVJY+6w7X+L8Jh7i8EXcXPLHLd7BnpILq8tWsOz38t/YCGizbp163jiiSfavN1pp53GunXrdrjObbfdxgcfNL8/+s4ZNGhQh75fR4m4cFdKMfHQfhwxII3nvy/nlYUVXpckhOhA69ev56mnntpqeSgU2uF206dPJy0tbYfrXHnllRx66KE7VV+kiMhD4kop/nxgDlW1No99vZrEgMXYgelelyVE1LGf+xdO0ZIOfU9V0A/r5D9t9/VbbrmFX375hbF
jxxIIBEhMTCQ7O5sFCxbw3nvvceaZZ1JSUkJ1dTVnnXUWp556KgDDhg1j1qxZbNy4kVNPPZUDDzyQL7/8kpycHB577DESEhKYOHEio0eP5phjjmHYsGH89re/Zfbs2dTV1fHQQw8xcOBAysvLueCCC1izZg377LMP7733Hm+++SYZGRk7/FyO43DTTTcxd+5ctxM6cSITJkxg1apVnH/++WzYsIFQKMStt97K0KFDufTSS/nuu+9QSnHSSSdxzjnndGg7R2S4A/gsxeQRvdlUa/PA5ytJDFiM7JPqdVlCiJ109dVXs2jRImbPns28efP4wx/+wJw5cygsLATgjjvuoEePHmzatInx48dz9NFHbxW8S5Ys4f777+f222/n3HPP5Y033uDEE0/c6mdlZGTw1ltv8cQTT/Dggw/y97//nTvvvJORI0dy0UUXMXfuXJ555plW1f3GG2+wYMECZs+eTUVFBUcffTTDhw9nxowZHHbYYUyaNIlQKMSmTZtYsGABK1euZM6cOQAtDie1R8SGO0DAZ3HVoXlMnVPEnfNKiPdb7J+X7HVZQkSNHfWwu8qQIUPqgx3gscceY9asWYA7s+ySJUu2CveCggIGDx4MwN57701RURHbMm7cuPp1trzn559/zqOPuvdLP/zww0lPb92owOeff85xxx2Hz+ejZ8+eDB8+nG+//ZYhQ4Zw6aWXUldXx1FHHcXgwYMpLCxk2bJlXHvttYwZM4bDDjusDS3SOhE35t5c0G9x7ah8CtOC/PXD5SxYXeV1SUKIDtT4bJF58+bx4Ycf8uqrr/LOO+8wePBgqqurt9omGAzWP/b5fNsdr9+yXuN12nsG4fa2Gz58OC+99BI5OTlMmjSJF154gfT0dGbPns1BBx3EE088wWWXXdaun7kjER/uAElxPqaMLqBnUoAb5xbzU/lmr0sSQrRTUlISlZWV23xtw4YNpKWlkZCQwE8//cTXX3f8zOIHHnggr776KgDvv/8+a9eubdV2w4cPZ+bMmYRCIcrLy/nss88YMmQIxcXFZGVlccopp3DyySfz/fffU1FRgW3bjB8/nssvv5zvv/++wz9HRA/LNJYe72famAKuevsXbphbxC1jCylIC7a8oRCiW8nIyOCAAw5g9OjRxMfHN5n4bNSoUUyfPp0jjjiC/v37s99++3X4z588eTJ//vOfmTlzJsOHDyc7O5ukpKQWtxs3bhxfffUVY8eORSnFNddcQ69evTDG8OCDD+L3+0lKSuKee+5hxYoVTJ48uX56gauuuqrDP0fEXcQEO74wo2R9DVfN/gWfUtx6ZCHZyXE7U2NEkIt2mpL2aCAXMTXVmouYqqur8fl8+P1+vvzyS6666ipmz57dRRU2tTMXMUVNz32L3NQ4bhhdwDXvLOP6d4u49cg+ZCRE3ccUQnSS5cuXc95552HbNnFxcdx+++1el9QuUZl6fXvEc/3hBVz/7jKmvlvEzWMLSQn6vC5LCBEB+vfvz9tvv91kWUVFBSeddNJW6z7//PMtnv/ulagMd4BdsxK4+rB8ps0t5oa5RUwbU0BiQAJeCNF2GRkZng3NtFdUnC2zPfvkJHHFwbn8XLGZW95fTk2o/XMjCyFEJInqcAcYVpDCpIN68/2qKv72YQl1dmTcYksIIXZG1Ic7wKh+aZx7QDZfLK/knk9WYEfIPRSFEKK9onbMvbmjd+lBVY3N9G9LSQxYnHdAtswFL4SIWjHRc9/ixD0zOGGPDN78cS3T55d6XY4QogPsaD71oqIiRo8e3YXVdB8x03MHd6rgPwzpycYam5f+W0FinI/f7JnpdVlCCNHhYircwQ34cw/IZlOtzfT5pSQFLMbt0sPrsoTolh75chVL1nTsXE39esRz9tDs7b5+8803k5eXxxlnnAG4U/wqpfj0009Zt24ddXV1XHHFFRx11FFt+rmbN2/mmmuuYf78+fh8PqZMmcLIkSNZtGgRkydPpqamBsdxePjhh8nJyeHcc89
lxYoV2LbNpEmTmDBhws587C4Xc+EO7lzwk0b0ZlNdiIe+WEVCwGJUvx3fwUUI0TUmTJjAlClT6sP91Vdf5ZlnnuFPf/oTKSkpVFRU8Otf/5ojjzyyTcfNtty679133+Wnn37id7/7HR9++CHTp0/nrLPO4oQTTqCmpoZQKMScOXPIyclh+vTpgHt3qEgTk+EO4LcUlx+cx7T3irnnkxUkBCyG5ad4XZYQ3cqOetidZfDgwZSVlbFy5UrKy8tJS0ujV69eTJ06lc8++wylFCtXrqS0tJRevXq1+n2/+OILzj77bAAGDhxIfn4+ixcvZv/99+fee+9lxYoVjBs3jv79+7Pbbrtx4403cvPNN3PEEUcwbNiwzvq4nSamDqg2F/RbXHNYHgMy4rn9wxK+W7nR65KEEMD48eN5/fXXmTlzJhMmTODll1+mvLycWbNmMXv2bLKysrY5j/uObG+SxOOPP57HH3+c+Ph4TjnlFD766CMGDBjArFmz2G233bj11lu56667OuJjdamYDneAxICP6w8voHdKgJvfL2ZR2SavSxIi5k2YMIFXXnmF119/nfHjx7NhwwaysrIIBAJ8/PHHFBcXt/k9hw0bxksvvQTAzz//zPLlyxkwYAC//PILffr04ayzzmLs2LEsXLiQlStXkpCQwIknnsh5553XKfOtd7aYD3eA1KCPqaML3Dnh5xaxtIMPIAkh2mbXXXdl48aN5OTkkJ2dzQknnMC3337LuHHjmDFjBgMHDmzze55++unYts2YMWM4//zzueuuuwgGg8ycOZPRo0czduxYfv75Z37zm9/www8/cMwxxzB27FjuvfdeJk2a1AmfsnNF3XzuO2NVZQ1/eXsZjuNw65F96J0SGXPBy/zlTUl7NJD53JtqzXzu3cnOzOcuPfdGspPjuGFMAXUOXP/uMsqqar0uSQgh2iVmz5bZnsK0IFMOz+e6d4qY8q57u760eGkmIbqzhQsXMnHixCbLgsEgr732mkcVeU9SaxsGZSZw3ah8ps4tYuqcIm46opCkOJkLXsQGD4dq22333XePuPnWW2Nn/i1kWGY79sxO5MpD8vhlbTU3vVdMdZ3MBS9ig2VZETUuHa3q6uqwrPZHtPTcd2BoXjKXjMjljo9L+OsHy7n6sHwCPplJUkS3+Ph4Nm/eTHV1ddTNnBoMBtt8frwXHMfBsizi4+Pb/R4S7i04pG8qm+ps7v9sJXfNK+HSkbn4rOj6hReiMaUUCQkJXpfRKWLpTCoJ91Y4cmA6VbUhHv+6lITASi4clhN1PRohRHSRcG+l43bPZGONjflPOYkBizP36yUBL4TotiTc2+D3e2exsdZm5g9rSIrzcfJeWV6XJIQQ2yTh3gZKKc7evxebakM8+10ZSQGLX++W4XVZQgixFQn3NrKU4sJhvamqtXnkq9UkBizGDEj3uiwhhGhCznNvB5+luGxkLvvkJHLfZyv5ZNkGr0sSQogmJNzbKeCzuOrQfAZlxvP3j0uYv0LmghdCdB+tHpbRWvuAL4Hlxphjmr12BnA7sDy86D5jzCMdVWR3lRCwuG5UAde8s4xbPyhm2phCds2KzvODhRCRpS1j7pOAhUDqdl5/3hhz4c6XFFlSgj5uGF3AX97+hRvmFnHLEYX07dH+q8qEEKIjtGpYRmudD4wHor433h49EvxMG1NA0GcxZU4RKzbUeF2SECLGtbbnfjdwBbCjO0ifqLU+FPgfcIkxpqj5Clrrc4BzAIwxZGW17zxxv9/f7m07S1YW3HtiOhe8+B1T5y7nQb03PZODXfKzu2N7eEnao4G0RVOx1B4thrvW+hhgtTHmK631qO2s9irwrDGmWmt9HvAkMLr5SsaYh4GHw0+d9s7x0F3nh0gBrhuVx3XvFHHhC99y69hCUrtgLvju2h5ekfZoIG3RVDS0R/hOTC1qzbDMSOBYrfVS4DlgtNb66cYrGGPKjTFbplr7F7B/60uNLoMyE7hmVB6rKmuZOreYqtq
Q1yUJIWJQi+FujLnKGJNvjOkLnAzMMcac2ngdrXXvRk+PxT3wGrP2yk7iykPyWLpmMzfLXPBCCA+0+zx3rfU0rfWx4acTtdYLtNbfAhOBMzqiuEh2QH4ykw7qzYLVm7j9o+XU2ZF3dxshRORSHt5SyykpKWnXhpE0bjbrf2t48ItVHNo3lUtG9MbqhJkkI6k9uoK0RwNpi6aioT3CY+4tBonMLdPJxu3Sg401NtO/LSUpYHHuAdkyVbAQotNJuHeBE/fMoLImxIyFFSTF+ThtSE+vSxJCRDkJ9y6glOL0fXuysTbEiwvKSYqzOGGPTK/LEkJEMQn3LqKU4rwDcthYY/PkN6Ukx/k4cqBMFSyE6BwS7l3IZykuGZHL5rpiHvhsJYkBi4P7bG+qHiGEaD+Z8reLBXyKKw/JY/eeCdw1r4Svlld6XZIQIgpJuHsg6Le4dlQ+hWlB/vrhchasrvK6JCFElJFw90hSnI8powvISgxw03vFLK7Y7HVJQogoIuHuofR4d6rgxIDF1DlFFK+vbnkjIYRoBQl3j/VMCjBtTCEAU94tonRjrccVCSGigYR7N5CXGsfU0QVsqrW5/t0i1m6u87okIUSEk3DvJvpnxHPdqHzKqmqZOqeIyhqZKlgI0X4S7t3I7r0SuerQPIrWVXOTTBUshNgJEu7dzH65yUwekcuisk389YPl1IZkqmAhRNtJuHdDI/ukcv6BOXy9YiN3zSshJHPBCyHaSKYf6KaOHJhOVW2Ix78uJTGwkguG5chUwUKIVpNw78aO2z2TymqbFxaUkxTn44x9e0rACyFaRcK9mztlnyw21ob4v4UVJMdZ/HZwltclCSEigIR7N6eU4k9Ds6mqsXn62zKS43yM26WH12UJIbo5CfcIYCnFRQf1ZmOtzUNfrCIhYDGqX5rXZQkhujE5WyZC+C3FFYfksmd2Ivd8soLPizd4XZIQohuTcI8gcT6Law7LY0BGPH/7sITvV230uiQhRDcl4R5hEgM+rj+8gJyUADe9t5wfyzd5XZIQohuScI9AqUEfN4wucL/PKWLBShmiEUI0JeEeoTITA0wbU4DfZ3HO899ywauLefa7UpatkznhhRBytkxE650Sx71H92V+ucNb/y3h+e/Lee77cgrT4hhZmMrIPikUpAW9LlMI4QEJ9wiXGu/nhH2yODQvQMWmOj5ZtoGPl63nue/LePb7MvqkBRnRJ4WDC1PIl6AXImZEXLg7jkNo9QqwAl6X0u1kJPgZv2sPxu/ag/KqWj4tquSjX9bz3HdlPPtdGX3Sg4wsTGFknxTyUyXohYhmkRfub7xA+dszUBdcg9plsNfldFuZiYEmQf9J0QY+/mUD//6ujH9/V0bf9CAj+6QwsjCVvNQ4r8sVQnQw5TieTSfrlJSUtH2jilLUvdMIrV6Bdc7lqCHDOqG0yJKVlUVZWVmr1i2vqmXesg18vGwDC0vd0yj79Qj36AtTyY2CoG9Le0Q7aYumoqE9cnNzAVqcQTDiwh0gI85P6ZRJsOxn1OkXYY0Y08GlRZb2/sKWbQn6XzbwQ1lD0B8cPhjbOyUygz4a/gN3FGmLpqKhPaI63LOysigtXob9wK2w8FvUb8/EOvK4Di4vcnTEL2zpxoYe/aJw0PfvEWRkn1RGFkZW0EfDf+COIm3RVDS0R2vDPeLG3LdQ8YlYF12P/egdOC88hl25HnX8aTLfeTv1TAowYfcMJuye0Sjo1zN9finT55cyICPonl5ZmEJOBAW9ELEqYsMdQAUCWOdcjvPMQzizXoTK9XDq+SjL53VpEa1x0K+urGVe0Xo+/mUDT80v5an5pQzMiK8/6yY7WYJeiO4oosMdcIP81PMhORXnDYOzsRLr7EtRATlVsiP0Sg5w3O6ZHLd7Jqsqa+qHbp6cX8qT80sZlBnPiMIURhZK0AvRnUR8uIN7Qwt1/KnYySk45lHseyuxLrgaFZ/odWlRJTs5juP3yOT4Pdyg/zh8MPbJb0p58pt
SeiUF6NcjSP8e8fTrEaRvjyC9kgIyVCaEB6Ii3Lewxk7ATkrBefJe7L9fizVpCipFbmrRGbKT4zhhj0xO2COTlRtq+LR4Az+Wb2bJmmo+L65ky2H6pDiLfulB+oUDv1+PeArS4gj4ZFojITpTVIU7gDViNE5SMvZDf8P+21+wLp6GyuzpdVlRLScljuN2z6x/vrnO5pe11SxZ44b9kjWbefuntVSH3Mj3KShIC9aHvdvLjyc1KMdKhOgorQ53rbUP+BJYbow5ptlrQeApYH+gHDjJGLO0A+tsE7XPgVgX34B9303Yt12JdckNqN4FXpUTc+L9FrtmJbBrVkL9spDtsKKyhqVrqusDf/7KKuYuWV+/Tlaiv1EP3w3+7OQAlgzrCNFmbem5TwIWAqnbeO0sYI0xZqDW+mTgNuCkDqiv3dQue2Jdfgv23VPcHvzEKah+u3hZUkzzWYr81CD5qUEO7tOwfO3mOpauqWZxo17+VyWV2OFxnQS/Rd8eTXv5hWlBgn4Z1hFiR1oV7lrrfGA8cDMweRurTACmhh+/CNyntVbGGM+ukAJQBf2wrrwN+67rse+4FuvPV6P2GOJlSaKZ9Hg/Q3r7GdI7qX5ZdZ3NsnUNPfwla6qZs3g9m+vWAmApyEuNqw/7/j3i6dsjSHp81I0yCtFurf3fcDdwBZCyndfzgCIAY0yd1nodkAk0uRRMa30OcE54PbKystpTM36/v/XbZmURuu1frJ12CXX/mEbaxVOJHzm6XT+3u2pTe0SIvBw4qNFz23EoWbeZH0s38mPZRn4qreSH0o18sLRhWCczKY5BWUnkpq8nKzFAz+Q4eiUH6Znifk8IxN6YfjT+buyMWGqPFsNda30MsNoY85XWetR2VtvWoOhWvXZjzMPAw1teb+9lwO25hNi55Ea470bW3XEd61csxxo1rl0/uzuKhkuqWyMe2KsH7NUjCQa5Pf311SGWhnv3i9ds5pe1Vfx31QbWb67bavukOIushACZiX4yE/1kJQbISPSTlegnM9FdnhSwourUzVj53WitaGiP8PQDLWpNz30kcKzW+mjc/1+pWuunjTGnNlqnGCgAirXWfiANqGhbyZ1LJSVjXTwN+6HbcJ75J/bGDaijfxtV/5FjUWrQx945Seyd0zCsk5WVxfKVq6nYVEdZVS3lVXWUVdVR3ujx4jWbWbs5tNX7xftVfdBnJfrJTAjU7wi27BRSgz75vRHdXovhboy5CrgKINxzv6xZsAPMBE4HPgF+A8zxerx9W1QwiPXnq3GevBfn/552pyv47ZkoSw7ORZug36J3StwOJzyrDTms2eSGfllVHeWbwt/DO4JvV1axZlNd/cHdLQKWqg/6zMRAuOff+HGAtKAPnyU7AOGddh+B0lpPA740xswEHgWma61/wu2xn9xB9XU45ffDHy+GpBScd2a6AX/6RHe5iCkBn6JXcoBeydufqiJkO6zdvCXwG/4S2PJ4Udkm5lXVUddsD+BT7p2xmvwV0Cj8MxP9ZCT4ZQcgOk3ETvm7s+NmjuPgvPGC24PfayjWuVeigpF567loGEfsSF3dHrbjsL46REXV9oeByqpqqQk1/b9mKfdsocxm4/5ZjXYIGQn+nbqaV343moqG9oj6KX93llIKNV670xX8+0Hsu6dgXXQtKjHZ69JEhLGUIj3eT3q8n/4Z8dtcx3EcNtbYDT3/8PGAso3uTqBoXQ3frKhic5291bZp8T6ytjn8Ez4onOCX8/7FVmI23LewRo1zA/7RO7Fvvxpr0lRUeobXZYkoo5QiOegjOeijb4/tr1dVG2oy7t/4L4CVlbUsWF1FZc3WO4CUoC98ALhp+BdWWtRsrCLOrwj6LILh71uey7BQ9Ir5cAewDjgYJykJ+4Fbw/PR3IDq1dvrskQMSgz4KEzzUZi2/SHCzXX2NsN/y+MfyzezrnrLmUArd/jz/BZu2PsUQb/VKPjd53Hb2CHUP9+yTfPn9cvDy3wWfgs5w6iLSbi
HqT32xZp8I/a909z5aC6+AVXQz+uyhNhKvN8iLzWOvB3czLwmZFNRVQfxyawuX0N1nUNNyKY65FBdZ1MdssPLmj93v1eHbNZtDlEdqm1YJ7y8+dlDrWEp9yyjOL9FnKUI+BRxPkXA5+4MGj+Pa/w4vG7QZzVaRxHXZL2G5wGfIs4Kv+ZXBKzY3bHE7AHV7XFKlmHfPRU2b8K66DrUoD065ed0pGg4SNSRpD0adEZb1NlbdghbQt92dwCNdgw1dY2eh1+vtd2dR03I3anUhh/XhhyqGz+3nfA6NrXhdXeGgvodQ5zfB46DT4FlKfe7UviUwrLcx5bCfa7cOZG2fr6j17azbqP39ykY0juJfj22fXymJXJAtZ1UbqE7H83d12PfdT3WeVei9j7A67KE6Db8lsIf5yOp5VU7hOM41NUHftPQb/LcdsLLG3YajXcmNSGbQFw8GzdtwnbAth1sB0KOg+04hMLLQg71z0O2Q40DtmO7y+yG17b1vOl7uttva990fsBqd7i3loT7NqjMnlhX/BX7nhuw778ZdcYkrIMO97osIWKSUu5wS8DHTu9QvPqrrumOwCHQBRdOyvlT26FS0rAuuwl2GYzz2F3Y77zidUlCiAhlhXdQQb9FYsBHwNf5xwAk3HdAxSdiTbwe9jsI5/lHsWc8jYfHKIQQotUk3FugAnFY516BOuRInDcMztP/xLG3nnBKCCG6ExlzbwVl+eC0CyA5BWfWSzgb12OddSkqsP05SYQQwksS7q2klEKdcDp2cirOC49jV2107+wUn9DyxkII0cVkWKaNrCOPR/1xEiz6HvvuKTi1NV6XJIQQW5FwbwdrxBisP10GP/+A88Q/5CCrEKLbkWGZdlJDD0atKnGnDM4tQI3XXpckhBD1JNx3gjr6t1BShPN/T+P0LkDtd1DLGwkhRBeQYZmdoJRCnX4h9NsF+9E7cZYt9rokIYQAJNx3mopz78tKUgr2/TfhrFvjdUlCCCHh3hFUegbWBddA5XrsB26RM2iEEJ6TcO8gqs8ArDMnw+JFOE/dJ2fQCCE8JeHegdT+I1ATfo/z6Xs4b77kdTlCiBgmZ8t0MDX+JPcMmhnTcXrno4YM97okIUQMkp7aIzswAAARWklEQVR7B1NKoc6YCH0GYj9yJ07REq9LEkLEIAn3TqDiglgXXA0JSdj33YSzXs6gEUJ0LQn3TqLSM7EuvAYq12E/cCtOba3XJQkhYoiEeydSfQZi/fFidw6a6XIGjRCi68gB1U6mhh6MKinCefVZyC1E/epEr0sSQsQACfcuoH59Mqwownn5KZycfNSQYV6XJISIcjIs0wXcM2gmQeEA9wya4qVelySEiHIS7l1EBYPuFAXxCeEzaNZ6XZIQIopJuHch1SPTDfj1a7H/KWfQCCE6j4R7F1P9Brm36ftpIc7TD8gZNEKITiEHVD1gHXAIdkkRzmvPuWfQHHW81yUJIaKM9Nw9on59Muw/AuelJ3C+/cLrcoQQUUbC3SPKsrD+eAkU9Mf+199xlv/idUlCiCgi4e6hJmfQ/ONGnA3rvC5JCBElJNw9pjKy3EnGtpxBUydn0Aghdl6LB1S11vHAB0AwvP6LxpgpzdY5A7gdWB5edJ8x5pGOLTV6qX67oE6/COeRO3Ce/iecfhFKKa/LEkJEsNacLVMNjDbGVGqtA8BHWutZxphPm633vDHmwo4vMTZYww7DXlGE87pxz6A58jivSxJCRLAWw90Y4wCV4aeB8JecnN0J1LG/x1lRjPPiE+5dnPYa6nVJQogI1aoxd621T2s9H1gNzDbGfLaN1U7UWn+ntX5Ra13QoVXGCGVZWGdeDAV9sR++HWf5Mq9LEkJEKNWWKyS11unADOAiY8x/Gi3PBCqNMdVa6/MAbYwZvY3tzwHOATDG7F9TU9Ouov1+P3V1de3aNhKEylZRcflZqGA8GX97BCs1fYfrR3t7tJW0RwNpi6aioT3i4uIAWjwo16ZwB9BaTwE2GmP
+vp3XfUCFMSathbdySkpK2vSzt8jKyqKsrKxd20YKZ/Ei7Nuvhv67YF0yDeUPbHfdWGiPtpD2aCBt0VQ0tEdubi60ItxbHJbRWvcM99jRWicARwA/NFund6OnxwIL21Ks2Jrqv6t7o+3/LcD590MyB40Qok1ac7ZMb+DJcI/cAowx5jWt9TTgS2PMTGCi1vpYoA6oAM7orIJjiTXsMOySZThvvAC5BagjJnhdkhAiQrR5WKYDybBMKzi2jf3gX2H+51gXXYfaa/+t1oml9mgNaY8G0hZNRUN7dNiwjPCWewbNJZDXB/tft+OsKPK6JCFEBJBwjwAqPgHrwmshEOfOQVO53uuShBDdnIR7hFCZPbH+fDWsKcN+8DaZg0YIsUMS7hFEDdgNdfpFsOh7nGcfljNohBDbJXdiijDW8MPdM2hmveTOQTPm116XJITohqTnHoHUcafBkGE4zz+K85+vvS5HCNENSbhHIGVZWGdNhrxC7Idvp65oqdclCSG6GQn3CFV/Bo3fT8WVZ2O//yaObXtdlhCim5Bwj2AqsxfWVbcTGLQHztMPYN95Hc7q9l0YJoSILhLuEU71zCF96j2oP1wIyxZjT52I/dYMnFDI69KEEB6ScI8CSimsQ47EmnYf7LkvzouPY996OU7xEq9LE0J4RMI9iqj0TKw/X4065wqoKMW+aTL2K//GqZULnoSINRLuUUYphXXAwVjT7kcdcAjOa89h33gxzs8/tLyxECJqSLhHKZWcinXWZKyJ18PmTdi3XYn9/KM41Zu9Lk0I0QUk3KOc2mso1g33oQ77Fc47r2BPvQhn4bdelyWE6GQS7jFAJSRinXI+1uW3gOXDvvM67Kfuw6mq9Lo0IUQnkXCPIWqXwVhT7kEddQLOR+9gX38hzvxPvS5LCNEJJNxjjIoLYv3mDKyrb4eUVOz7b8F++Hac9Wu9Lk0I0YEk3GOU6jsI65o7URNOwfnmE+zrL8D+dK5MIyxElJBwj2HK78c65iSs6+6G7FycR+9y7/RUUep1aUKInSThLlC5hVhX/hV10tmw6HvsKRdiv/eGTEQmRASTcBcAKMuHdcSxWFP/Af12wXnmQew7rsFZJRORCRGJJNxFE6pnDtYl09zb+RUtxb5hIvabL8lEZEJEGAl3sRWlFNbBY8MTke2H89KT7kRkRTIRmRCRQsJdbJc7EdlVWOeGJyK7eTL2/z0tE5EJEQHkBtlih5RSMPRgrN32xjGP4rxucL7+BOv0i1ADdvO6PCHEdkjPXbSKSk7FOvMSrIlToDo8Edlz/5KJyITopiTcRZuovfZ3JyIbNQ7n3Vexp1yI89/5XpclhGhGwl20mYpPxPr9eViX3wo+P/Zd12M/cS9ORZnXpQkhwmTMXbSb2mVPrCn34Lz2HM5bM3A+fgcG7Iba7yDUfiNQWdlelyhEzJJwFztFxQVRJ5yOc/BYnC8+wvl6Hs4Lj+O88Dj0GYjaf4Qb9Nm5XpcqREyRcBcdQvXKRY3XMF7jlK50Q/6reTgvP4Xz8lOQ39cN+f1HoHILvS5XiKgn4S46nOqZgzrqBDjqBJyKUpyvP3GD/tVncWb+G3Ly63v0FPRzT7cUQnQoCXfRqVRGT9QRx8IRx+KsrcD55lOcrz7GeeNFnNcN9MwJ9+hHQt+BEvRCdBAJd9FlVHoG6vCj4fCjcTasc4P+63k477yC89bLkNGzfuiG/ruiLDmZS4j2knAXnlApaahDj4JDj8LZuAHn28/doZv3Xsd55xVIz0Dte5Ab9IP2QFk+r0sWIqJIuAvPqaQU1IgxMGIMzqYqnO++cHv0H83Gmfs6pKSh9h3uBv0ue6H88msrREvkf4noVlRCImrYYTDsMJzNm+A/X7kHZD97H+eDtyApBTXkQPdg7O5DUIGA1yUL0S21GO5a63jgAyAYXv9FY8yUZusEgaeA/YFy4CRjzNIOr1bEFBWfAEMPRg09GKemGv77jTt08/UnOB+/CwmJqL0PcA/G7rk
vKi7odclCdButOWJVDYw2xuwDDAF+pbUe3myds4A1xpiBwF3AbR1bpoh1Ki6IGjIc66zJWHdMx5p4PWq/ETj/+Rr7gVuwJ5+G/dDfsL/4kFDZKrnRt4h5LfbcjTEOUBl+Ggh/Nf+fMwGYGn78InCf1lqFtxWiQ6lAAPYaitprKE5dHfzvP26P/ptP4MuPKHsYSEx2z6Ev6Od+z+8HuQUovwzjiNjQqjF3rbUP+AoYCNxvjPms2Sp5QBGAMaZOa70OyATKmr3POcA54fXIyspqX9F+f7u3jUYx3x45OXDoETihELU/LcRe+hM1ixdRu+RH6j54C2qq3d6I348/vy/+voPw9x1IoN8g/H0HYaWmef0JOk3M/240E0vt0apwN8aEgCFa63RghtZ6sDHmP41W2daVJ1v12o0xDwMPb3m9rKx9swhmZWXR3m2jkbRHI5k5ZO06mMpwe1h2CFatwCleAkWLqStaSt03n8F7sxq26ZHlTo9Q0B9V0BcK+rsXV0XBefbyu9FUNLRHbm7r5mlq09kyxpi1Wuv3gF8BjcO9GCgAirXWfiANqGjLewvRGZTlg975qN75cMAh9cud9WuheAlO0dLw9yU4C77GsW13hWA85PVxh3Xyw8M7+X1RwXhvPogQbdSas2V6ArXhYE8AjmDrA6YzgdOBT4DfAHNkvF10Zyo1HfbYF7XHvvXLnNoaKCkK9/LDgf/5h/D+m+6foUpBr1xUft9G4/n93QuuZNoE0c20pufeG3gyPO5uAcYY85rWehrwpTFmJvAoMF1r/RNuj/3kTqtYiE6iAnHQZwCqz4D6ZY7jQEUpFC3GKVqKU7QYZ9nP8NXHDeOOySlQ0D8c+uGhnZwCudhKeEp5eMqYU1JS0q4No2HcrCNJezTVFe3hbKqCYjfsKV6Ks2wxlCyD2hp3Bctyx/IzslAZPSH8pTIbPU5I7NQaQX43mouG9giPubf4p6J0LYRoB5WQ6M55M2iP+mVOKASrluMULYGSIqgodac8/vkH+PIjCIWanmWQkAQZWVuH/padQXoGyidz6oj2kXAXooMonw9yC7d5MxLHDsG6tfWBT0UplDc8dhYvgo0b3HXr39CCHhlNA3/L40x3p0BCkoz3i22ScBeiCyjLBz0yoUcmasBu21zH2bwJ1pQ1CX13Z1CGs+R/8NU8CNU17f3HJzTr8WdBZsPOwElP75LPJ7ofCXchugkVnwC9C6B3wbYvHLFtWL+2UeiXQkUZTnn4+dIfoXK9u254m9WWD7KyISfPvY9tTh4qOw+y8yCth/T6o5iEuxARQlkWpGe4X/133fYOoLoa1oTDvryUhI3rqVryk3ssYOG3UFvT0POPT4DscOhn50F2LionH7J7o+I7/2Cv6FwS7kJEERUMQk6+e59aICUri+rw2SGObbvDPquW46xcDqtKcFYudw/4fvEhOE5D8KdlNPT2s8O9/Zw8yOwlp3hGCPlXEiJGKMuCzF5uQDe6eAtwp1QuXQkrl+OsCgf/quU4X8+Dyg0Noe/zQc+chsDPzm0I/tR0GebpRiTchRDuXPh5fdwpF5q95lSur+/ls6pR+C/4BupqG4I/IdG9gndL2GfnonLy3GXxCV38iYSEuxBih1RyKiSnbnWWj2OHoKJsq+B3fl4IX3zQdJgnGO9O31D/ZbmX4ShrJ5c3eh3ci8e297pSrElIIGQ74Pe70z/7AxAIgM/f8Njf+Mtf/10FtrW8+Tb+Jo+9/EtGwl0I0S5qy5k4WdmoPbcxzLN6RcP4/sYN7ik8TnhiNtsGxwEcsJ2tl9d/2Q3bhZc521m+vfXrl4ds7Mo62LwJ6upw6mqhthbqaiFU1/B4O9p1Lb/fD74ABPxNdhrq2N9hNZrIrjNIuAshOpyKC0J+X3cmTa+LaSSzhekHHMdxg76uFmrrILRlBxBetmWHUL9OrbuTqP9qWN5km2bbq6TkTv+sEu5CCBGmlGroYbdydufutPNqLPLvRiCEEGIrEu5CCBG
FJNyFECIKSbgLIUQUknAXQogoJOEuhBBRSMJdCCGikIS7EEJEIU9vkO3VDxZCiAjX4rVTXvbcVXu/tNZf7cz20fYl7SHtIW0Rc+3RIhmWEUKIKCThLoQQUShSw/1hrwvoZqQ9mpL2aCBt0VTMtIeXB1SFEEJ0kkjtuQshhNgBCXchhIhCEXezDq31r4B7AB/wiDHmrx6X5AmtdQHwFJAD2MDDxph7vK3Ke1prH/AlsNwYc4zX9XhJa50OPAIMxr2u5ExjzCfeVuUNrfUlwNm47fA98EdjzGZvq+pcEdVzD//HvR8YB+wB/E5rvYe3VXmmDrjUGLM7MBy4IIbborFJwEKvi+gm7gHeNMbsBuxDjLaL1joPmAgMNcYMxu0YnuxtVZ0v0nruBwI/GWMWA2itnwMmAP/1tCoPGGNWACvCjzdorRcCecRgW2yhtc4HxgM3A5M9LsdTWutU4FDgDABjTA1Q42VNHvMDCVrrWiARKPG4nk4XUT133PAqavS8OLwspmmt+wL7Ap95XIrX7gauwB2minX9gVLgca31N1rrR7TWSV4X5QVjzHLg78Ay3A7ROmPM295W1fkiLdy3ddltTJ/LqbVOBl4CLjbGrPe6Hq9orY8BVhtjvvK6lm7CD+wH/NMYsy+wEfiLtyV5Q2vdA/cv/H5ALpCktT7V26o6X6SFezFQ0Oh5PjHw59X2aK0DuMH+jDHmZa/r8dhI4Fit9VLgOWC01vppb0vyVDFQbIzZ8tfci7hhH4uOAJYYY0qNMbXAy8AIj2vqdJE25v4FMEhr3Q9YjntQ5PfeluQNrbUCHgUWGmPu9LoerxljrgKuAtBajwIuM8ZEfe9se4wxK7XWRVrrXY0xi4AxxO7xmGXAcK11IrAJty2+9LakzhdRPXdjTB1wIfAW7pF/Y4xZ4G1VnhkJnIbbQ50f/jra66JEt3IR8IzW+jtgCHCLx/V4IvzXy4vA17inQVrEwDQEMv2AEEJEoYjquQshhGgdCXchhIhCEu5CCBGFJNyFECIKSbgLIUQUknAXQogoJOEuhBBR6P8B0B/mW2/7QF0AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "use_gpu = torch.cuda.is_available()\n", "loaders = get_loader(config)\n", "\n", "# initial our model\n", "model = RNNModel(config.input_dim, config.hidden_dim, config.rnn_layers)\n", "if use_gpu:\n", " model = model.cuda()\n", "\n", "# define our loss\n", "criterion = nn.MSELoss()\n", "# define our optimizer, usually we use adam or sgd\n", "# you should tune the lr usually from 0.1 to 1e-4\n", "optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=5e-7)\n", "\n", "losses = {'train': [], 'val': []}\n", "for epoch in range(config.epochs):\n", " print(f'Epoch: [{epoch}/{config.epochs}]')\n", " training_size, val_size = 0, 0\n", " training_loss, val_loss = 0, 0\n", " \n", " # before training, do not forget to set the model to be training mode\n", " model.train()\n", " \n", " # use loader to load our training data\n", " for i, (inputs, labels) in enumerate(loaders['train']):\n", " \n", " # float our inputs and labels\n", " inputs, labels = inputs.float(), labels.float()\n", " if use_gpu:\n", " inputs = inputs.cuda()\n", " labels = labels.cuda()\n", " outputs = model(inputs)\n", " loss = criterion(outputs.view(-1, ), labels)\n", "\n", " # after getting the loss, we need to do gradient desent by the following three lines\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " training_size += inputs.size()[0]\n", " training_loss = training_loss + loss.item() * inputs.size()[0]\n", "\n", " if i % 80 == 0:\n", " print(\n", " f\"iters: [{i}/{len(loaders['train'])}], training_loss: {training_loss / training_size:.4f}\"\n", " )\n", "\n", " losses['train'].append(training_loss / training_size)\n", " \n", " # before do validation, do not forget to set the model to be eval mode,\n", " # because it will affect batch_normalization and dropout layers\n", " model.eval()\n", " \n", " # and set torch.no_grad() to save GPU memory and be 
faster\n", " with torch.no_grad():\n", " \n", " # iterate over the validation data\n", " for i, (inputs, labels) in enumerate(loaders['val']):\n", " inputs, labels = inputs.float(), labels.float()\n", " if use_gpu:\n", " inputs = inputs.cuda()\n", " labels = labels.cuda()\n", " outputs = model(inputs)\n", " loss = criterion(outputs.view(-1, ), labels)\n", "\n", " val_size += inputs.size()[0]\n", " val_loss += loss.item() * inputs.size()[0]\n", "\n", " losses['val'].append(val_loss / val_size)\n", " print(f\"val_loss: {val_loss / val_size:.4f}\")\n", "\n", "# plot the training and validation losses\n", "plt.title('loss')\n", "plt.plot(losses['train'], label='training_loss')\n", "plt.plot(losses['val'], label='val_loss')\n", "plt.legend()\n", "plt.show()\n", "\n", "# test the model on new data (here we simply reuse the validation data);\n", "# labels and predictions are mapped back to prices by undoing the normalization\n", "close_real = []\n", "close_predict = []\n", "for i in range(10):\n", " val_input, label = val_dataset[i]\n", " close_real.append(label * (val_dataset.std + 1e-8) + val_dataset.mean)\n", " val_input = torch.tensor(val_input).float().unsqueeze(0)\n", " if use_gpu:\n", " val_input = val_input.cuda()\n", " val_output = model(val_input)\n", " close_predict.append(val_output.item() * (val_dataset.std + 1e-8) +\n", " val_dataset.mean)\n", "\n", "plt.title('RNN')\n", "plt.plot(close_real, label='real close')\n", "plt.plot(close_predict, label='predicted close')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Further Improvement\n", "* Adjust the learning rate across epochs, e.g., step-wise or cosine decay\n", "* Try different weight initializations\n", "* Clip the gradient norm of the parameters to prevent exploding gradients\n", "* Enlarge the model with more layers or a bidirectional RNN\n", "* Use more data\n", "* ...\n", "\n", "For more tips on training RNNs, see these [tricks](https://danijar.com/tips-for-training-recurrent-neural-networks/)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### CNN (Time Series):\n", "* 
Text CNN" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://drive.google.com/uc?id=1GdNE4-llPnsDeCXp5cRHPZjZQWziiPLb)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "class TextCNN(nn.Module):\n", " def __init__(self, input_dim, kernels=(2, 3, 4), channels=64):\n", " super(TextCNN, self).__init__()\n", " self.input_dim = input_dim\n", " self.dropout = nn.Dropout(.5) # note: not used in forward(); the regressor applies its own dropout\n", " self.kernels = kernels\n", " self.channels = channels\n", "\n", " # one Conv2d branch per kernel size; each kernel covers the full feature dimension\n", " self.convs = self._make_conv_layers(1, self.channels, self.kernels)\n", "\n", " hidden_dim = self.channels * len(self.kernels)\n", " self.regressor = nn.Sequential(nn.Linear(hidden_dim, hidden_dim // 2),\n", " nn.ReLU(inplace=True), nn.Dropout(.5),\n", " nn.Linear(hidden_dim // 2, 1))\n", "\n", " def forward(self, s):\n", " # s: (batch, seq_len, input_dim); unsqueeze(1) adds a channel dim: (batch, 1, seq_len, input_dim)\n", " # each branch outputs (batch, channels, seq_len - kernel + 1, 1); squeeze(3) drops the last dim\n", " conv_feats = [\n", " conv(s.unsqueeze(1)).squeeze(3) for conv in self.convs\n", " ]\n", " # max-pool over time, then concatenate the branches: (batch, channels * len(kernels))\n", " conv_feats_max = torch.cat(\n", " [torch.max(feats, dim=2)[0] for feats in conv_feats], dim=1)\n", "\n", " output = self.regressor(conv_feats_max)\n", " return output\n", "\n", " def _make_conv_layers(self, input_channel, channels, kernels):\n", " convs = []\n", " for kernel in kernels:\n", " convs.append(\n", " nn.Sequential(\n", " nn.Conv2d(input_channel,\n", " channels, (kernel, self.input_dim),\n", " stride=1), nn.BatchNorm2d(channels),\n", " nn.ReLU(inplace=True)))\n", " return nn.ModuleList(convs)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: [0/10]\n", "iters: [0/83], training_loss: 2.8204\n", "iters: [80/83], training_loss: 4.1512\n", "val_loss: 4.9487\n", "Epoch: [1/10]\n", "iters: [0/83], training_loss: 19.0332\n", "iters: [80/83], training_loss: 3.8837\n", "val_loss: 4.7732\n", "Epoch: [2/10]\n", "iters: [0/83], training_loss: 2.3307\n", "iters: [80/83], training_loss: 3.7391\n", "val_loss: 4.6696\n", "Epoch: 
[3/10]\n", "iters: [0/83], training_loss: 3.1343\n", "iters: [80/83], training_loss: 3.6378\n", "val_loss: 4.6036\n", "Epoch: [4/10]\n", "iters: [0/83], training_loss: 7.6225\n", "iters: [80/83], training_loss: 3.4913\n", "val_loss: 4.5464\n", "Epoch: [5/10]\n", "iters: [0/83], training_loss: 1.6343\n", "iters: [80/83], training_loss: 3.4901\n", "val_loss: 4.4952\n", "Epoch: [6/10]\n", "iters: [0/83], training_loss: 0.7859\n", "iters: [80/83], training_loss: 3.4487\n", "val_loss: 4.4516\n", "Epoch: [7/10]\n", "iters: [0/83], training_loss: 1.6998\n", "iters: [80/83], training_loss: 3.3861\n", "val_loss: 4.4020\n", "Epoch: [8/10]\n", "iters: [0/83], training_loss: 1.9916\n", "iters: [80/83], training_loss: 3.3331\n", "val_loss: 4.3623\n", "Epoch: [9/10]\n", "iters: [0/83], training_loss: 3.0594\n", "iters: [80/83], training_loss: 3.2131\n", "val_loss: 4.3253\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "use_gpu = torch.cuda.is_available()\n", "loaders = get_loader(config)\n", "model = TextCNN(config.input_dim, config.cnn_kernels, config.cnn_channels)\n", "if use_gpu:\n", " model = model.cuda()\n", "\n", "criterion = nn.MSELoss()\n", "optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=5e-7)\n", "\n", "losses = {'train': [], 'val': []}\n", "for epoch in range(config.epochs):\n", " print(f'Epoch: [{epoch}/{config.epochs}]')\n", " training_size, val_size = 0, 0\n", " training_loss, val_loss = 0, 0\n", "\n", " model.train()\n", " for i, (inputs, labels) in enumerate(loaders['train']):\n", " inputs, labels = inputs.float(), labels.float()\n", " if use_gpu:\n", " inputs = inputs.cuda()\n", " labels = labels.cuda()\n", " outputs = model(inputs)\n", " loss = criterion(outputs.view(-1, ), labels)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " training_size += inputs.size()[0]\n", " training_loss = training_loss + loss.item() * inputs.size()[0]\n", "\n", " if i % 80 == 0:\n", " print(\n", " f\"iters: [{i}/{len(loaders['train'])}], training_loss: {training_loss / training_size:.4f}\"\n", " )\n", " losses['train'].append(training_loss / training_size)\n", "\n", " model.eval()\n", " with torch.no_grad():\n", " for i, (inputs, labels) in enumerate(loaders['val']):\n", " inputs, labels = inputs.float(), labels.float()\n", " if use_gpu:\n", " inputs = inputs.cuda()\n", " labels = labels.cuda()\n", " outputs = model(inputs.float())\n", " loss = criterion(outputs.view(-1, ), labels.float())\n", "\n", " val_size += inputs.size()[0]\n", " val_loss += loss.item() * inputs.size()[0]\n", "\n", " losses['val'].append(val_loss / val_size)\n", " print(f\"val_loss: {losses['val'][-1]:.4f}\")\n", "\n", "plt.title('loss')\n", "plt.plot(losses['train'], label='training_loss')\n", "plt.plot(losses['val'], label='val_loss')\n", 
"plt.legend()\n", "plt.show()\n", "\n", "close_real = []\n", "close_predict = []\n", "for i in range(10):\n", " val_input, label = val_dataset[i]\n", " close_real.append(label * (val_dataset.std + 1e-8) + val_dataset.mean)\n", " val_input = torch.tensor(val_input).float().unsqueeze(0)\n", " if use_gpu:\n", " val_input = val_input.cuda()\n", " val_output = model(val_input)\n", " close_predict.append(val_output.item() * (val_dataset.std + 1e-8) +\n", " val_dataset.mean)\n", "\n", "plt.title('TextCNN')\n", "plt.plot(close_real, label='real close')\n", "plt.plot(close_predict, label='predicted close')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Further Improvement\n", "* Adjust learning rate during epochs, eg, step wise or cosine decay\n", "* Use different inilization\n", "* More cnn layers\n", "* Use more kernels and larger channels\n", "* Adjust dropout rate\n", "* More data\n", "* ...\n", "\n", "For more tricks on CNN, you can refer this [tricks](https://towardsdatascience.com/a-bunch-of-tips-and-tricks-for-training-deep-neural-networks-3ca24c31ddc8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Second Part" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## GAN\n", "### We want to build a GAN which can generate the next price based on previous market\n", "* Generator: RNN used to generate the next price\n", "* Discriminator: MLP or RNN used to classify whether the generated price is real or fake" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://drive.google.com/uc?id=1H7hXQxXCB5K7clYOruIPcGne6PZ2Xu3K)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "class Generator(nn.Module):\n", " def __init__(self, model_type='rnn'):\n", " super(Generator, self).__init__()\n", " \n", " # we can choose different model based what we had writen above\n", " if model_type == 'rnn':\n", " self.g_model = RNNModel(config.input_dim, 
config.hidden_dim,\n", " config.rnn_layers)\n", " elif model_type == 'cnn':\n", " self.g_model = TextCNN(config.input_dim, config.cnn_kernels,\n", " config.cnn_channels)\n", " else:\n", " raise ValueError(f'unknown model_type: {model_type}')\n", "\n", " def forward(self, s):\n", " return self.g_model(s)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "class Discriminator(nn.Module):\n", " def __init__(self, model_type='rnn'):\n", " super(Discriminator, self).__init__()\n", " \n", " # same model choices as the generator, but the input is a univariate price sequence (input_dim=1)\n", " if model_type == 'rnn':\n", " self.d_model = RNNModel(1, config.hidden_dim,\n", " config.rnn_layers)\n", " elif model_type == 'cnn':\n", " self.d_model = TextCNN(1, config.cnn_kernels,\n", " config.cnn_channels)\n", " else:\n", " raise ValueError(f'unknown model_type: {model_type}')\n", "\n", " def forward(self, s):\n", " return self.d_model(s)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: [0/10]\n", "iters: [0/83], training_g_loss: 0.6500 training_d_loss: 1.3865\n", "iters: [80/83], training_g_loss: 0.6651 training_d_loss: 1.3830\n", "val_g_loss: 0.6767 val_d_loss: 1.3791\n", "Epoch: [1/10]\n", "iters: [0/83], training_g_loss: 0.6757 training_d_loss: 1.3787\n", "iters: [80/83], training_g_loss: 0.6921 training_d_loss: 1.3684\n", "val_g_loss: 0.7320 val_d_loss: 1.3150\n", "Epoch: [2/10]\n", "iters: [0/83], training_g_loss: 0.7306 training_d_loss: 1.3015\n", "iters: [80/83], training_g_loss: 0.6980 training_d_loss: 1.3662\n", "val_g_loss: 0.7066 val_d_loss: 1.3852\n", "Epoch: [3/10]\n", "iters: [0/83], training_g_loss: 0.7081 training_d_loss: 1.3878\n", "iters: [80/83], training_g_loss: 0.7105 training_d_loss: 1.3841\n", "val_g_loss: 0.7091 val_d_loss: 1.3826\n", "Epoch: [4/10]\n", "iters: [0/83], training_g_loss: 0.7111 training_d_loss: 1.3781\n", "iters: [80/83], training_g_loss: 0.7034 training_d_loss: 1.3694\n", "val_g_loss: 0.6917 val_d_loss: 1.3434\n", "Epoch: [5/10]\n", "iters: [0/83], training_g_loss: 0.6957 training_d_loss: 1.3519\n", "iters: 
[80/83], training_g_loss: 0.7027 training_d_loss: 1.3515\n", "val_g_loss: 0.7111 val_d_loss: 1.3590\n", "Epoch: [6/10]\n", "iters: [0/83], training_g_loss: 0.7133 training_d_loss: 1.3684\n", "iters: [80/83], training_g_loss: 0.7003 training_d_loss: 1.3682\n", "val_g_loss: 0.6533 val_d_loss: 1.3407\n", "Epoch: [7/10]\n", "iters: [0/83], training_g_loss: 0.6468 training_d_loss: 1.3560\n", "iters: [80/83], training_g_loss: 0.6478 training_d_loss: 1.4676\n", "val_g_loss: 0.6987 val_d_loss: 1.4350\n", "Epoch: [8/10]\n", "iters: [0/83], training_g_loss: 0.6970 training_d_loss: 1.4223\n", "iters: [80/83], training_g_loss: 0.8546 training_d_loss: 1.2868\n", "val_g_loss: 0.7373 val_d_loss: 1.3527\n", "Epoch: [9/10]\n", "iters: [0/83], training_g_loss: 0.7254 training_d_loss: 1.3441\n", "iters: [80/83], training_g_loss: 0.6804 training_d_loss: 1.3892\n", "val_g_loss: 0.6439 val_d_loss: 1.4037\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "use_gpu = torch.cuda.is_available()\n", "loaders = get_loader(config)\n", "\n", "# initial generator and discriminator\n", "# you can try different model type\n", "generator = Generator(model_type='rnn')\n", "discriminator = Discriminator(model_type='rnn')\n", "if use_gpu:\n", " generator = generator.cuda()\n", " discriminator = discriminator.cuda()\n", "\n", "# the loss is different from before because here we want to do classification \n", "# whether the generated sequence is real or fake\n", "criterion = nn.BCEWithLogitsLoss()\n", "optimizer_g = optim.Adam(generator.parameters(),\n", " lr=config.lr,\n", " weight_decay=5e-7)\n", "optimizer_d = optim.Adam(discriminator.parameters(),\n", " lr=config.lr,\n", " weight_decay=5e-7)\n", "\n", "losses = {'train_g': [], 'train_d': [], 'val_g': [], 'val_d': []}\n", "for epoch in range(config.epochs):\n", " print(f'Epoch: [{epoch}/{config.epochs}]')\n", " training_size, val_size = 0, 0\n", " training_g_loss, val_g_loss = 0, 0\n", " training_d_loss, val_d_loss = 0, 0\n", "\n", " generator.train()\n", " discriminator.train()\n", " for i, (inputs, labels) in enumerate(loaders['train']):\n", " inputs, labels = inputs.float(), labels.float()\n", " if use_gpu:\n", " inputs = inputs.cuda()\n", " labels = labels.cuda()\n", " fakes = generator(inputs)\n", " \n", " # generator tries to make the generated sequence real in discriminator\n", " output_g = discriminator(\n", " torch.cat([inputs[:, :, 3], fakes], dim=1).unsqueeze(2))\n", " # so the label should be 1\n", " loss_g = criterion(output_g, torch.ones_like(output_g))\n", "\n", " optimizer_g.zero_grad()\n", " loss_g.backward()\n", " optimizer_g.step()\n", "\n", " # however, the discriminator tries to make the generated sequence by generator to be fake\n", " # while the true sequence to be real\n", " output_real = discriminator(\n", " torch.cat([inputs[:, :, 3], 
labels.unsqueeze(1)],\n", " dim=1).unsqueeze(2))\n", " # detach() stops the discriminator loss from backpropagating into the generator\n", " output_fake = discriminator(\n", " torch.cat([inputs[:, :, 3], fakes.detach()], dim=1).unsqueeze(2))\n", " \n", " # real sequences should be classified as real (label 1)\n", " loss_d_real = criterion(output_real, torch.ones_like(output_real))\n", " # generated sequences should be classified as fake (label 0)\n", " loss_d_fake = criterion(output_fake, torch.zeros_like(output_fake))\n", " loss_d = loss_d_real + loss_d_fake\n", "\n", " optimizer_d.zero_grad()\n", " loss_d.backward()\n", " optimizer_d.step()\n", "\n", " training_size += inputs.size()[0]\n", " training_g_loss = training_g_loss + loss_g.item() * inputs.size()[0]\n", " training_d_loss = training_d_loss + loss_d.item() * inputs.size()[0]\n", "\n", " if i % 80 == 0:\n", " print(\n", " f\"iters: [{i}/{len(loaders['train'])}], training_g_loss: {training_g_loss/training_size:.4f} \"\n", " f\"training_d_loss: {training_d_loss/training_size:.4f}\")\n", "\n", " losses['train_g'].append(training_g_loss / training_size)\n", " losses['train_d'].append(training_d_loss / training_size)\n", "\n", " generator.eval()\n", " discriminator.eval()\n", " with torch.no_grad():\n", " for i, (inputs, labels) in enumerate(loaders['val']):\n", " inputs, labels = inputs.float(), labels.float()\n", " if use_gpu:\n", " inputs = inputs.cuda()\n", " labels = labels.cuda()\n", " fakes = generator(inputs)\n", " output_g = discriminator(\n", " torch.cat([inputs[:, :, 3], fakes], dim=1).unsqueeze(2))\n", " loss_g = criterion(output_g, torch.ones_like(output_g))\n", "\n", " output_real = discriminator(\n", " torch.cat(\n", " [inputs[:, :, 3], labels.unsqueeze(1)],\n", " dim=1).unsqueeze(2))\n", " output_fake = discriminator(\n", " torch.cat([inputs[:, :, 3], fakes.detach()],\n", " dim=1).unsqueeze(2))\n", " loss_d_real = criterion(output_real, torch.ones_like(output_real))\n", " loss_d_fake = criterion(output_fake, torch.zeros_like(output_fake))\n", " loss_d = loss_d_real + loss_d_fake\n", "\n", " 
val_size += inputs.size()[0]\n", " val_g_loss = val_g_loss + loss_g.item() * inputs.size()[0]\n", " val_d_loss = val_d_loss + loss_d.item() * inputs.size()[0]\n", "\n", " losses['val_g'].append(val_g_loss / val_size)\n", " losses['val_d'].append(val_d_loss / val_size)\n", " print(\n", " f\"val_g_loss: {val_g_loss/val_size:.4f} val_d_loss: {val_d_loss/val_size:.4f}\"\n", " )\n", "\n", "plt.title('loss')\n", "plt.plot(losses['train_g'], label='training_g_loss')\n", "plt.plot(losses['train_d'], label='training_d_loss')\n", "plt.plot(losses['val_g'], label='val_g_loss')\n", "plt.plot(losses['val_d'], label='val_d_loss')\n", "plt.legend()\n", "plt.show()\n", "\n", "close_real = []\n", "close_predict = []\n", "for i in range(10):\n", " val_input, label = val_dataset[i]\n", " close_real.append(label * (val_dataset.std + 1e-8) + val_dataset.mean)\n", " val_input = torch.tensor(val_input).float().unsqueeze(0)\n", " if use_gpu:\n", " val_input = val_input.cuda()\n", " val_output = generator(val_input)\n", " close_predict.append(val_output.item() * (val_dataset.std + 1e-8) +\n", " val_dataset.mean)\n", "\n", "plt.title('GAN')\n", "plt.plot(close_real, label='real close')\n", "plt.plot(close_predict, label='predicted close')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Further Improvement\n", "For more tips, see these [GAN Tricks](https://github.com/soumith/ganhacks)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Third Part" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reinforcement Learning\n", "* Deep Q-Learning\n", "* Policy Gradient\n", "* DDPG\n", "* A3C" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://drive.google.com/uc?id=1lcj8jH1xEoYV_yGGH7sa2FXRQO5XvONU)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "from torch.distributions import Categorical\n", "\n", "# two-headed network: fc2 outputs action logits (the policy), fc21 outputs a state-value estimate\n", "class net(nn.Module):\n", " def __init__(self, input_size, output_size):\n", " super(net, self).__init__()\n", " 
self.fc1 = nn.Linear(input_size, 10)\n", " self.fc2 = nn.Linear(10, output_size)\n", " self.fc21 = nn.Linear(10, 1)\n", "\n", " def forward(self, s):\n", " if len(s.size()) == 1:\n", " s = s.unsqueeze(0)\n", " s = self.fc1(s)\n", " s = F.relu(s, inplace=True)\n", " logits = self.fc2(s)\n", " values = self.fc21(s)\n", " return logits, values\n", "\n", "class RL_pg:\n", " def __init__(self, n_actions, observation_shape, lr=0.01, reward_decay=0.9):\n", " self.n_actions = n_actions\n", " self.observation_shape = observation_shape\n", " self.lr = lr\n", " self.gama = reward_decay # discount factor (gamma)\n", "\n", " self.s = []\n", " self.a = []\n", " self.r = []\n", " self.p = []\n", " self.v = []\n", "\n", " self.eval = net(self.observation_shape, self.n_actions)\n", " \n", " if torch.cuda.is_available():\n", " self.eval=self.eval.cuda()\n", "\n", " # create the optimizer once, so Adam's running moment estimates persist across episodes\n", " self.optimizer = optim.Adam(self.eval.parameters(), lr=self.lr)\n", "\n", " def store_transition(self, s, a, r):\n", " # store the transition; learning happens once the episode is over\n", " self.s.append(s)\n", " self.a.append(a)\n", " self.r.append(r)\n", "\n", " def choose_action(self, s):\n", " # sample an action from the softmax policy for the current state\n", " s = torch.tensor(s).float()\n", " if torch.cuda.is_available():\n", " s=s.cuda()\n", "\n", " p, v = self.eval.forward(s)\n", " self.p.append(F.softmax(p, -1))\n", " self.v.append(v)\n", " prob_actions = self.p[-1].cpu().data.numpy()\n", " action = np.random.choice(prob_actions.shape[1], p=prob_actions.reshape(-1, ))\n", " return action\n", "\n", " def learn(self):\n", " # before learning, turn the rewards into discounted returns and normalize them\n", " normed_r = self.norm_reward()\n", "\n", " normed_r = torch.FloatTensor(normed_r)\n", "\n", " optimizer = self.optimizer # reuse the optimizer created in __init__\n", "\n", " m = Categorical(torch.cat(self.p))\n", " a = torch.tensor(self.a).long()\n", " # keep the value estimates in the graph so the smooth L1 term below trains the value head\n", " v = torch.cat(self.v).view(-1)\n", " \n", " if torch.cuda.is_available():\n", " a=a.cuda()\n", " normed_r=normed_r.cuda()\n", " \n", " # REINFORCE loss with the value estimate as a baseline (advantage = return - value),\n", " # plus a smooth L1 loss that fits the value head to the returns\n", " loss = torch.mean(torch.mul(-m.log_prob(a), normed_r - v.detach())) + \\\n", " 
F.smooth_l1_loss(v, normed_r)\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " # clear the stored episode so the next one starts from scratch\n", " self.s = []\n", " self.a = []\n", " self.r = []\n", " self.p = []\n", " self.v = []\n", " return normed_r.cpu().data.numpy()\n", "\n", " def norm_reward(self):\n", " # use a float array so the in-place normalization below also works for integer rewards\n", " normed_r = np.zeros(len(self.r), dtype=np.float32)\n", "\n", " # discounted return, accumulated backwards: G_t = r_t + gamma * G_{t+1}\n", " add = 0\n", " for i in reversed(range(0, len(self.r))):\n", " add = add * self.gama + self.r[i]\n", " normed_r[i] = add\n", "\n", " normed_r -= np.mean(normed_r)\n", " normed_r = normed_r / np.std(normed_r) if np.std(normed_r) != 0 else normed_r\n", "\n", " return normed_r" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed future-0.17.1 gym-0.12.1 pyglet-1.3.2\n", 
"episode: 0, reward: 21.0, running_reward: 0.21\n", "episode: 1, reward: 13.0, running_reward: 0.34\n", "episode: 2, reward: 18.0, running_reward: 0.51\n", "episode: 3, reward: 13.0, running_reward: 0.64\n", "episode: 4, reward: 9.0, running_reward: 0.72\n", "episode: 5, reward: 11.0, running_reward: 0.83\n", "episode: 6, reward: 13.0, running_reward: 0.95\n", "episode: 7, reward: 11.0, running_reward: 1.05\n", "episode: 8, reward: 14.0, running_reward: 1.18\n", "episode: 9, reward: 15.0, running_reward: 1.32\n", "episode: 10, reward: 23.0, running_reward: 1.53\n", "episode: 11, reward: 25.0, running_reward: 1.77\n", "episode: 12, reward: 10.0, running_reward: 1.85\n", "episode: 13, reward: 16.0, running_reward: 1.99\n", "episode: 14, reward: 14.0, running_reward: 2.11\n", "episode: 15, reward: 33.0, running_reward: 2.42\n", "episode: 16, reward: 28.0, running_reward: 2.68\n", "episode: 17, reward: 14.0, running_reward: 2.79\n", "episode: 18, reward: 14.0, running_reward: 2.90\n", "episode: 19, reward: 14.0, running_reward: 3.01\n", "episode: 20, reward: 13.0, running_reward: 3.11\n", "episode: 21, reward: 15.0, running_reward: 3.23\n", "episode: 22, reward: 11.0, running_reward: 3.31\n", "episode: 23, reward: 11.0, running_reward: 3.39\n", "episode: 24, reward: 19.0, running_reward: 3.54\n", "episode: 25, reward: 10.0, running_reward: 3.61\n", "episode: 26, reward: 12.0, running_reward: 3.69\n", "episode: 27, reward: 33.0, running_reward: 3.98\n", "episode: 28, reward: 18.0, running_reward: 4.12\n", "episode: 29, reward: 19.0, running_reward: 4.27\n", "episode: 30, reward: 8.0, running_reward: 4.31\n", "episode: 31, reward: 25.0, running_reward: 4.52\n", 
"episode: 32, reward: 16.0, running_reward: 4.63\n", "episode: 33, reward: 17.0, running_reward: 4.75\n", "episode: 34, reward: 10.0, running_reward: 4.81\n", "episode: 35, reward: 13.0, running_reward: 4.89\n", "episode: 36, reward: 12.0, running_reward: 4.96\n", "episode: 37, reward: 15.0, running_reward: 5.06\n", "episode: 38, reward: 16.0, running_reward: 5.17\n", "episode: 39, reward: 18.0, running_reward: 5.30\n", "episode: 40, reward: 18.0, running_reward: 5.43\n", "episode: 41, reward: 10.0, running_reward: 5.47\n", "episode: 42, reward: 13.0, running_reward: 5.55\n", "episode: 43, reward: 11.0, running_reward: 5.60\n", "episode: 44, reward: 14.0, running_reward: 5.69\n", "episode: 45, reward: 13.0, running_reward: 5.76\n", "episode: 46, reward: 21.0, running_reward: 5.91\n", "episode: 47, reward: 10.0, running_reward: 5.95\n", "episode: 48, reward: 22.0, running_reward: 6.11\n", "episode: 49, reward: 12.0, running_reward: 6.17\n", "episode: 50, reward: 27.0, running_reward: 6.38\n", "episode: 51, reward: 16.0, running_reward: 6.48\n", "episode: 52, reward: 14.0, running_reward: 6.55\n", "episode: 53, reward: 14.0, running_reward: 6.63\n", "episode: 54, reward: 20.0, running_reward: 6.76\n", "episode: 55, reward: 14.0, running_reward: 6.83\n", "episode: 56, reward: 16.0, running_reward: 6.92\n", "episode: 57, reward: 10.0, running_reward: 6.95\n", "episode: 58, reward: 14.0, running_reward: 7.02\n", "episode: 59, reward: 12.0, running_reward: 7.07\n", "episode: 60, reward: 15.0, running_reward: 7.15\n", "episode: 61, reward: 14.0, running_reward: 7.22\n", "episode: 62, reward: 61.0, running_reward: 7.76\n", "episode: 63, reward: 9.0, running_reward: 7.77\n", "episode: 64, reward: 14.0, running_reward: 7.83\n", "episode: 65, reward: 26.0, running_reward: 8.02\n", "episode: 66, reward: 34.0, running_reward: 8.28\n", "episode: 67, reward: 25.0, running_reward: 8.44\n", "episode: 68, reward: 17.0, running_reward: 8.53\n", "episode: 69, reward: 18.0, 
running_reward: 8.62\n", "episode: 70, reward: 14.0, running_reward: 8.68\n", "episode: 71, reward: 19.0, running_reward: 8.78\n", "episode: 72, reward: 16.0, running_reward: 8.85\n", "episode: 73, reward: 13.0, running_reward: 8.89\n", "episode: 74, reward: 13.0, running_reward: 8.93\n", "episode: 75, reward: 19.0, running_reward: 9.04\n", "episode: 76, reward: 15.0, running_reward: 9.10\n", "episode: 77, reward: 14.0, running_reward: 9.14\n", "episode: 78, reward: 13.0, running_reward: 9.18\n", "episode: 79, reward: 14.0, running_reward: 9.23\n", "episode: 80, reward: 17.0, running_reward: 9.31\n", "episode: 81, reward: 25.0, running_reward: 9.47\n", "episode: 82, reward: 10.0, running_reward: 9.47\n", "episode: 83, reward: 30.0, running_reward: 9.68\n", "episode: 84, reward: 20.0, running_reward: 9.78\n", "episode: 85, reward: 16.0, running_reward: 9.84\n", "episode: 86, reward: 36.0, running_reward: 10.10\n", "episode: 87, reward: 16.0, running_reward: 10.16\n", "episode: 88, reward: 26.0, running_reward: 10.32\n", "episode: 89, reward: 17.0, running_reward: 10.39\n", "episode: 90, reward: 14.0, running_reward: 10.42\n", "episode: 91, reward: 29.0, running_reward: 10.61\n", "episode: 92, reward: 27.0, running_reward: 10.77\n", "episode: 93, reward: 11.0, running_reward: 10.78\n", "episode: 94, reward: 12.0, running_reward: 10.79\n", "episode: 95, reward: 42.0, running_reward: 11.10\n", "episode: 96, reward: 12.0, running_reward: 11.11\n", "episode: 97, reward: 23.0, running_reward: 11.23\n", "episode: 98, reward: 16.0, running_reward: 11.28\n", "episode: 99, reward: 13.0, running_reward: 11.29\n", "episode: 100, reward: 19.0, running_reward: 11.37\n", "episode: 101, reward: 29.0, running_reward: 11.55\n", "episode: 102, reward: 84.0, running_reward: 12.27\n", "episode: 103, reward: 23.0, running_reward: 12.38\n", "episode: 104, reward: 25.0, running_reward: 12.50\n", "episode: 105, reward: 9.0, running_reward: 12.47\n", "episode: 106, reward: 28.0, 
running_reward: 12.62\n", "episode: 107, reward: 16.0, running_reward: 12.66\n", "episode: 108, reward: 17.0, running_reward: 12.70\n", "episode: 109, reward: 22.0, running_reward: 12.79\n", "episode: 110, reward: 36.0, running_reward: 13.03\n", "episode: 111, reward: 20.0, running_reward: 13.10\n", "episode: 112, reward: 19.0, running_reward: 13.16\n", "episode: 113, reward: 23.0, running_reward: 13.25\n", "episode: 114, reward: 22.0, running_reward: 13.34\n", "episode: 115, reward: 29.0, running_reward: 13.50\n", "episode: 116, reward: 15.0, running_reward: 13.51\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "episode: 117, reward: 15.0, running_reward: 13.53\n", "episode: 118, reward: 29.0, running_reward: 13.68\n", "episode: 119, reward: 22.0, running_reward: 13.77\n", "episode: 120, reward: 38.0, running_reward: 14.01\n", "episode: 121, reward: 27.0, running_reward: 14.14\n", "episode: 122, reward: 14.0, running_reward: 14.14\n", "episode: 123, reward: 25.0, running_reward: 14.25\n", "episode: 124, reward: 33.0, running_reward: 14.43\n", "episode: 125, reward: 36.0, running_reward: 14.65\n", "episode: 126, reward: 13.0, running_reward: 14.63\n", "episode: 127, reward: 25.0, running_reward: 14.74\n", "episode: 128, reward: 33.0, running_reward: 14.92\n", "episode: 129, reward: 55.0, running_reward: 15.32\n", "episode: 130, reward: 26.0, running_reward: 15.43\n", "episode: 131, reward: 23.0, running_reward: 15.50\n", "episode: 132, reward: 34.0, running_reward: 15.69\n", "episode: 133, reward: 25.0, running_reward: 15.78\n", "episode: 134, reward: 18.0, running_reward: 15.80\n", "episode: 135, reward: 42.0, running_reward: 16.06\n", "episode: 136, reward: 23.0, running_reward: 16.13\n", "episode: 137, reward: 37.0, running_reward: 16.34\n", "episode: 138, reward: 52.0, running_reward: 16.70\n", "episode: 139, reward: 21.0, running_reward: 16.74\n", "episode: 140, reward: 15.0, running_reward: 16.72\n", "episode: 141, reward: 32.0, running_reward: 
16.88\n", "episode: 142, reward: 23.0, running_reward: 16.94\n", "episode: 143, reward: 40.0, running_reward: 17.17\n", "episode: 144, reward: 39.0, running_reward: 17.39\n", "episode: 145, reward: 16.0, running_reward: 17.37\n", "episode: 146, reward: 90.0, running_reward: 18.10\n", "episode: 147, reward: 24.0, running_reward: 18.16\n", "episode: 148, reward: 55.0, running_reward: 18.53\n", "episode: 149, reward: 30.0, running_reward: 18.64\n", "episode: 150, reward: 23.0, running_reward: 18.69\n", "episode: 151, reward: 47.0, running_reward: 18.97\n", "episode: 152, reward: 23.0, running_reward: 19.01\n", "episode: 153, reward: 27.0, running_reward: 19.09\n", "episode: 154, reward: 37.0, running_reward: 19.27\n", "episode: 155, reward: 65.0, running_reward: 19.73\n", "episode: 156, reward: 32.0, running_reward: 19.85\n", "episode: 157, reward: 61.0, running_reward: 20.26\n", "episode: 158, reward: 76.0, running_reward: 20.82\n", "episode: 159, reward: 144.0, running_reward: 22.05\n", "episode: 160, reward: 44.0, running_reward: 22.27\n", "episode: 161, reward: 62.0, running_reward: 22.67\n", "episode: 162, reward: 30.0, running_reward: 22.74\n", "episode: 163, reward: 54.0, running_reward: 23.05\n", "episode: 164, reward: 15.0, running_reward: 22.97\n", "episode: 165, reward: 43.0, running_reward: 23.17\n", "episode: 166, reward: 26.0, running_reward: 23.20\n", "episode: 167, reward: 19.0, running_reward: 23.16\n", "episode: 168, reward: 105.0, running_reward: 23.98\n", "episode: 169, reward: 26.0, running_reward: 24.00\n", "episode: 170, reward: 70.0, running_reward: 24.46\n", "episode: 171, reward: 58.0, running_reward: 24.79\n", "episode: 172, reward: 16.0, running_reward: 24.70\n", "episode: 173, reward: 50.0, running_reward: 24.96\n", "episode: 174, reward: 74.0, running_reward: 25.45\n", "episode: 175, reward: 73.0, running_reward: 25.92\n", "episode: 176, reward: 74.0, running_reward: 26.40\n", "episode: 177, reward: 30.0, running_reward: 26.44\n", 
"episode: 178, reward: 115.0, running_reward: 27.32\n", "episode: 179, reward: 55.0, running_reward: 27.60\n", "episode: 180, reward: 64.0, running_reward: 27.97\n", "episode: 181, reward: 91.0, running_reward: 28.60\n", "episode: 182, reward: 43.0, running_reward: 28.74\n", "episode: 183, reward: 31.0, running_reward: 28.76\n", "episode: 184, reward: 70.0, running_reward: 29.17\n", "episode: 185, reward: 96.0, running_reward: 29.84\n", "episode: 186, reward: 93.0, running_reward: 30.47\n", "episode: 187, reward: 34.0, running_reward: 30.51\n", "episode: 188, reward: 29.0, running_reward: 30.49\n", "episode: 189, reward: 75.0, running_reward: 30.94\n", "episode: 190, reward: 19.0, running_reward: 30.82\n", "episode: 191, reward: 84.0, running_reward: 31.35\n", "episode: 192, reward: 29.0, running_reward: 31.33\n", "episode: 193, reward: 43.0, running_reward: 31.45\n", "episode: 194, reward: 90.0, running_reward: 32.03\n", "episode: 195, reward: 46.0, running_reward: 32.17\n", "episode: 196, reward: 59.0, running_reward: 32.44\n", "episode: 197, reward: 87.0, running_reward: 32.98\n", "episode: 198, reward: 24.0, running_reward: 32.89\n", "episode: 199, reward: 88.0, running_reward: 33.45\n", "episode: 200, reward: 64.0, running_reward: 33.75\n", "episode: 201, reward: 60.0, running_reward: 34.01\n", "episode: 202, reward: 33.0, running_reward: 34.00\n", "episode: 203, reward: 92.0, running_reward: 34.58\n", "episode: 204, reward: 141.0, running_reward: 35.65\n", "episode: 205, reward: 101.0, running_reward: 36.30\n", "episode: 206, reward: 157.0, running_reward: 37.51\n", "episode: 207, reward: 95.0, running_reward: 38.08\n", "episode: 208, reward: 66.0, running_reward: 38.36\n", "episode: 209, reward: 140.0, running_reward: 39.38\n", "episode: 210, reward: 237.0, running_reward: 41.36\n", "episode: 211, reward: 29.0, running_reward: 41.23\n", "episode: 212, reward: 17.0, running_reward: 40.99\n", "episode: 213, reward: 25.0, running_reward: 40.83\n", "episode: 
214, reward: 16.0, running_reward: 40.58\n", "episode: 215, reward: 22.0, running_reward: 40.40\n", "episode: 216, reward: 44.0, running_reward: 40.43\n", "episode: 217, reward: 23.0, running_reward: 40.26\n", "episode: 218, reward: 63.0, running_reward: 40.48\n", "episode: 219, reward: 16.0, running_reward: 40.24\n", "episode: 220, reward: 124.0, running_reward: 41.08\n", "episode: 221, reward: 82.0, running_reward: 41.49\n", "episode: 222, reward: 70.0, running_reward: 41.77\n", "episode: 223, reward: 57.0, running_reward: 41.92\n", "episode: 224, reward: 68.0, running_reward: 42.18\n", "episode: 225, reward: 56.0, running_reward: 42.32\n", "episode: 226, reward: 27.0, running_reward: 42.17\n", "episode: 227, reward: 69.0, running_reward: 42.44\n", "episode: 228, reward: 31.0, running_reward: 42.32\n", "episode: 229, reward: 161.0, running_reward: 43.51\n", "episode: 230, reward: 22.0, running_reward: 43.30\n", "episode: 231, reward: 49.0, running_reward: 43.35\n", "episode: 232, reward: 337.0, running_reward: 46.29\n", "episode: 233, reward: 152.0, running_reward: 47.35\n", "episode: 234, reward: 152.0, running_reward: 48.39\n", "episode: 235, reward: 153.0, running_reward: 49.44\n", "episode: 236, reward: 131.0, running_reward: 50.25\n", "episode: 237, reward: 164.0, running_reward: 51.39\n", "episode: 238, reward: 69.0, running_reward: 51.57\n", "episode: 239, reward: 120.0, running_reward: 52.25\n", "episode: 240, reward: 40.0, running_reward: 52.13\n", "episode: 241, reward: 171.0, running_reward: 53.32\n", "episode: 242, reward: 39.0, running_reward: 53.17\n", "episode: 243, reward: 137.0, running_reward: 54.01\n", "episode: 244, reward: 209.0, running_reward: 55.56\n", "episode: 245, reward: 97.0, running_reward: 55.98\n", "episode: 246, reward: 163.0, running_reward: 57.05\n", "episode: 247, reward: 135.0, running_reward: 57.83\n", "episode: 248, reward: 111.0, running_reward: 58.36\n", "episode: 249, reward: 118.0, running_reward: 58.96\n", "episode: 
250, reward: 93.0, running_reward: 59.30\n", "episode: 251, reward: 70.0, running_reward: 59.40\n", "episode: 252, reward: 194.0, running_reward: 60.75\n", "episode: 253, reward: 65.0, running_reward: 60.79\n", "episode: 254, reward: 142.0, running_reward: 61.60\n", "episode: 255, reward: 154.0, running_reward: 62.53\n", "episode: 256, reward: 222.0, running_reward: 64.12\n", "episode: 257, reward: 248.0, running_reward: 65.96\n", "episode: 258, reward: 252.0, running_reward: 67.82\n", "episode: 259, reward: 295.0, running_reward: 70.09\n", "episode: 260, reward: 288.0, running_reward: 72.27\n", "episode: 261, reward: 230.0, running_reward: 73.85\n", "episode: 262, reward: 81.0, running_reward: 73.92\n", "episode: 263, reward: 261.0, running_reward: 75.79\n", "episode: 264, reward: 119.0, running_reward: 76.22\n", "episode: 265, reward: 85.0, running_reward: 76.31\n", "episode: 266, reward: 112.0, running_reward: 76.67\n", "episode: 267, reward: 109.0, running_reward: 76.99\n", "episode: 268, reward: 64.0, running_reward: 76.86\n", "episode: 269, reward: 376.0, running_reward: 79.85\n", "episode: 270, reward: 347.0, running_reward: 82.52\n", "episode: 271, reward: 76.0, running_reward: 82.46\n", "episode: 272, reward: 191.0, running_reward: 83.54\n", "episode: 273, reward: 173.0, running_reward: 84.44\n", "episode: 274, reward: 152.0, running_reward: 85.11\n", "episode: 275, reward: 299.0, running_reward: 87.25\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "episode: 276, reward: 169.0, running_reward: 88.07\n", "episode: 277, reward: 241.0, running_reward: 89.60\n", "episode: 278, reward: 241.0, running_reward: 91.11\n", "episode: 279, reward: 231.0, running_reward: 92.51\n", "episode: 280, reward: 174.0, running_reward: 93.33\n", "episode: 281, reward: 427.0, running_reward: 96.66\n", "episode: 282, reward: 119.0, running_reward: 96.89\n", "episode: 283, reward: 284.0, running_reward: 98.76\n", "episode: 284, reward: 98.0, running_reward: 
98.75\n", "episode: 285, reward: 424.0, running_reward: 102.00\n", "episode: 286, reward: 186.0, running_reward: 102.84\n", "episode: 287, reward: 576.0, running_reward: 107.58\n", "episode: 288, reward: 183.0, running_reward: 108.33\n", "episode: 289, reward: 249.0, running_reward: 109.74\n", "episode: 290, reward: 188.0, running_reward: 110.52\n", "episode: 291, reward: 272.0, running_reward: 112.13\n", "episode: 292, reward: 225.0, running_reward: 113.26\n", "episode: 293, reward: 341.0, running_reward: 115.54\n", "episode: 294, reward: 230.0, running_reward: 116.68\n", "episode: 295, reward: 265.0, running_reward: 118.17\n", "episode: 296, reward: 275.0, running_reward: 119.74\n", "episode: 297, reward: 131.0, running_reward: 119.85\n", "episode: 298, reward: 166.0, running_reward: 120.31\n", "episode: 299, reward: 266.0, running_reward: 121.77\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# we use gym for reinforcement learning\n", "try:\n", " import gym\n", "except:\n", " ! pip install gym\n", " import gym\n", "\n", "from torch.distributions import Categorical\n", "\n", "# use the game cartpole\n", "env = gym.make('CartPole-v0')\n", "env = env.unwrapped\n", "\n", "# initial our policy gradient agent\n", "rl_pg = RL_pg(\n", " n_actions = env.action_space.n,\n", " observation_shape = env.observation_space.shape[0],\n", " lr = 0.01,\n", " reward_decay = 0.99\n", ")\n", "\n", "rewards = []\n", "cumulated_reward=0\n", "for i_episode in range(300):\n", "\n", " # reset the env to get the initial observation\n", " observation = env.reset()\n", "\n", " while True:\n", " # env.render()\n", " \n", " # choose action by agent based on the current observation\n", " action = rl_pg.choose_action(observation)\n", "\n", " # give the action to env to get the reward and the next observation\n", " observation_, reward, done, info = env.step(action)\n", "\n", " # store the state\n", " rl_pg.store_transition(observation, action, reward)\n", "\n", " if done:\n", " r_sum = sum(rl_pg.r)\n", " cumulated_reward = cumulated_reward * 0.99 + r_sum * 0.01\n", " rewards.append(cumulated_reward)\n", "\n", " print(f\"episode: {i_episode}, reward: {r_sum}, running_reward: {cumulated_reward:.2f}\")\n", " \n", " # after the game done, the agent should learn\n", " vt = rl_pg.learn()\n", " break\n", "\n", " # set the observation to the new observation\n", " observation = observation_\n", " \n", "plt.title('Cumulated Reward')\n", "plt.plot(rewards)\n", "plt.xlabel('episode')\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": 
"ipython3", "version": "3.7.1" } }, "nbformat": 4, "nbformat_minor": 2 }