{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# A simple example of a fully connected deep neural network (FCN)\n", "# The FCN trains on images of handwritten numbers from 0 to 9, then\n", "# takes an image of a handwritten number and estimates which number\n", "# it represents" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Import matplotlib to display digit images\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "# Import PyTorch modules used in example program\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torchvision.datasets as dsets\n", "import torchvision.transforms as xforms\n", "from torch.autograd import Variable" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "input_size = 784 # The image size = 28x28 = 784\n", "hidden_size = 500 # Number of nodes in the hidden layer\n", "num_classes = 10 # Number of output classes (0..9)\n", "num_epochs = 5 # Number of times DNN is trained on training dataset\n", "batch_size = 100 # The size of input data for one iteration\n", "learning_rate = 0.001 # Step size the optimizer uses for each weight update" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Read in training and test datasets, these are included as part of\n", "# PyTorch's dataset repository\n", "\n", "train_dataset = dsets.MNIST( root ='./data', train=True, transform=xforms.ToTensor(), download=True )\n", "test_dataset = dsets.MNIST( root='./data', train=False, transform=xforms.ToTensor() )" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "5 0\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAC7CAYAAAB1qmWGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAEX5JREFUeJzt3XmMlWWWx/HfEcW4tAt2RAIirUEc7SAqILFRcWHiKMZdm6hoNMIfkqAxpNWgrZOgRMEe98AoCuoAbWhH1DhqFEEjISJiizAOaBy6pAIqFqtLgDN/cJmUPM+17vLe5X3q+0nIrTp17n3PW3Xq8NZ9N3N3AQDyb69GFwAAyAYDHQASwUAHgEQw0AEgEQx0AEgEAx0AEsFAB4BEMNABIBFVDXQzO8/MPjez1WZ2e1ZFAY1GbyOPrNIzRc2si6T/kTRcUoukDyWNdPcV2ZUH1B+9jbzau4rnDpa02t2/lCQzmy3pIklFm97MuM4AasrdLYOXobfRdErp7Wrecukp6R/tPm8pxIC8o7eRS9Vsocf+twi2UsxstKTRVSwHqDd6G7lUzUBvkXRku897SVq7Z5K7T5M0TeLPUuQGvY1cquYtlw8l9TWz35lZV0l/lDQvm7KAhqK3kUsVb6G7+3YzGyvpDUldJE13988yqwxoEHobeVXxYYsVLYw/S1FjGR3lUjZ6G7VW66NcAABNhIEOAIlgoANAIhjoAJAIBjoAJIKBDgCJYKADQCIY6ACQCAY6ACSCgQ4AiWCgA0AiGOgAkAgGOgAkgoEOAImo5o5FAJCpU045JYiNHTs2mjtq1KggNnPmzGjuo48+GsSWLl1aZnXNjy10AEgEAx0AEsFAB4BEMNABIBFV3VPUzL6StFnSDknb3X1gB/ncd1FSly5dgtjBBx9c1WsW23G0//77B7F+/fpFc2+++eYgNnny5GjuyJEjg9iPP/4YzZ00aVIQu/fee6O51crqnqL0dm0NGDAgGn/nnXeC2EEHHVT18jZu3BjEDjvssKpft55K6e0sjnI5y92/zeB1gGZDbyNXeMsFABJR7UB3SW+a2UdmNjqLgoAmQW8jd6p9y+UP7r7WzA6X9JaZ/be7L2yfUPhl4BcCeUNvI3eq2kJ397WFx/WSXpI0OJIzzd0HdrRTCWgm9DbyqOItdDM7QNJe7r658PE/S/rXzCprAr179w5iXbt2jeaedtppQWzo0KHR3EMOOSSIXXbZZWVWV7mWlpZo/JFHHglil1xySTR38+bNQeyTTz6J5i5YsKCM6hqvM/R2PQ0eHPxfqLlz50ZzY0d7FTsSL9aDP//8czQ3dkTLkCFDormxSwIUe91mU81bLt0lvWRmu1/nP9z9vzKpCmgsehu5VPFAd/cvJZ2YYS1AU6C3kVcctggAiWCgA0Aiqjr1v+yFNenp0eWchlztKfr1tnPnziB2ww03RHO3bNlS8uu2trYGse+//z6a+/nnn5f8utXK6tT/cjVrb9dK7JISknTyyScHseeffz6I9erVK/r8wn6LXyg2o2I7Lx944IFo7uzZs0taliRNmDAhiN1///3R3HoqpbfZQgeARDDQASARDHQASAQDHQASwUAHgERkcT303FuzZk00/t133wWxeh7lsnjx4mi8ra0tiJ111lnR3Ngpy88991x1haHTmzp1ajQeu/FJrcSOqDnwwAOjubHLTwwbNiya279//6rqaiS20AEgEQx0AEgEAx0AEsFAB4BEsFNU0oYNG6Lx8ePHB7ERI0ZEcz/++OMgFru+eDHLli0LYsOHD4/mbt26NYidcMIJ0dxx48aVXAMQc8oppwSxCy64IJpb7HT6PRW7Rv4rr7wSxCZPnhzNXbt2bRCL/R5K8ctSnH322dHcUtehGbGFDgCJYKADQCIY6ACQCAY6ACSiw4FuZtPNbL2ZLW8X62Zmb5nZqsLjobUtE8gevY3UdHiDCzM7Q9IWSTPd/feF2AOSNrj7JDO
7XdKh7v6nDheWwE0ADjrooGg8dgfyYqdH33jjjUHsmmuuCWKzZs0qszqUc4MLevuXyrnRS7Hfg5jXX389iBW7RMCZZ54ZxIqdiv/UU08FsW+++abkunbs2BGNb9u2raS6pPhNNmolkxtcuPtCSXse13eRpBmFj2dIurjs6oAGo7eRmkrfQ+/u7q2SVHg8PLuSgIait5FbNT+xyMxGSxpd6+UA9UZvo9lUuoW+zsx6SFLhcX2xRHef5u4D3X1ghcsC6oneRm5VuoU+T9J1kiYVHl/OrKImt2nTppJzN27cWHLuTTfdFMTmzJkTzd25c2fJr4uydYrePvbYY4NY7FIXUvweAN9++200t7W1NYjNmDEjiG3ZsiX6/Ndee62kWC3tt99+Qey2226L5l599dW1LqcspRy2OEvSIkn9zKzFzG7UrmYfbmarJA0vfA7kCr2N1HS4he7uxW5Bck7GtQB1RW8jNZwpCgCJYKADQCIY6ACQCG5wUUP33HNPNB67YUDs1OJzzz03+vw333yzqrrQeey7777ReOymEeeff340N3ZZi1GjRkVzlyxZEsRiR43kTe/evRtdQknYQgeARDDQASARDHQASAQDHQAS0eH10DNdWALXjM7CMcccE8Ri11Vua2uLPn/+/PlBLLYzSpIef/zxIFbPn3m9lXM99Cw1a28PGTIkGn///fdLfo1zzgnPs1qwYEHFNTWLYtdDj/1+LFq0KJp7+umnZ1rTr8nkeugAgHxgoANAIhjoAJAIBjoAJIIzRRvgiy++CGLXX399EHvmmWeiz7/22mtLiknSAQccEMRmzpwZzY1dyxr59tBDD0XjZuH+tWI7OlPYARqz117x7dk832+ALXQASAQDHQASwUAHgEQw0AEgEaXcU3S6ma03s+XtYveY2ddmtqzwL37dTaCJ0dtITSlHuTwr6TFJex4a8Rd3Dy+qjIq89NJLQWzVqlXR3NiRC7HTsyXpvvvuC2JHHXVUNHfixIlB7Ouvv47mJuJZJdTbI0aMCGIDBgyI5sZOb583b17mNTWzYkezxL43y5Ytq3U5mehwC93dF0raUIdagLqit5Gaat5DH2tmfy/82XpoZhUBjUdvI5cqHehPSjpG0gBJrZKmFEs0s9FmtsTM4pcDBJoLvY3cqmigu/s6d9/h7jsl/bukwb+SO83dB7r7wEqLBOqF3kaeVXTqv5n1cPfd54lfImn5r+WjMsuXx7+tV155ZRC78MILo7mxyweMGTMmmtu3b98gNnz48F8rMTl57u3YzZi7du0azV2/fn0QmzNnTuY11Vuxm2IXu2F7zDvvvBPE7rjjjkpLqqsOB7qZzZI0TNJvzaxF0p8lDTOzAZJc0leS4hMCaGL0NlLT4UB395GR8NM1qAWoK3obqeFMUQBIBAMdABLBQAeARHCDixxqa2sLYs8991w096mnngpie+8d/7GfccYZQWzYsGHR3Hfffbd4gWh6P/30UxDL2w1OYke0TJgwIZo7fvz4INbS0hLNnTIlPPVgy5YtZVbXGGyhA0AiGOgAkAgGOgAkgoEOAIlgp2gT69+/fzR++eWXB7FBgwZFc4vtAI1ZsWJFEFu4cGHJz0d+5Ona58Wu6R7b0XnVVVdFc19++eUgdtlll1VXWBNiCx0AEsFAB4BEMNABIBEMdABIBAMdABLBUS4N0K9fvyA2duzYIHbppZdGn3/EEUdUtfwdO3ZE47FTv4vdGR3Nx8xKiknSxRdfHMTGjRuXeU3luvXWW4PYXXfdFc09+OCDg9gLL7wQzR01alR1heUEW+gAkAgGOgAkgoEOAInocKCb2ZFmNt/MVprZZ2Y2rhDvZmZvmdmqwuOhtS8XyA69jdSUslN0u6Tb3H2pmf1G0kdm9pak6yW97e6TzOx2SbdL+lPtSm1usR2VI0fGblkZ3wHap0+frEuSJC1ZsiSITZw4MZqbp9PBM5JUb7t7STEp3q+PPPJINHf69OlB7LvvvovmDhkyJIhde+2
1QezEE0+MPr9Xr15BbM2aNdHcN954I4g98cQT0dzOosMtdHdvdfelhY83S1opqaekiyTNKKTNkBTuNgeaGL2N1JT1HrqZ9ZF0kqTFkrq7e6u06xdD0uFZFwfUC72NFJR8HLqZHShprqRb3H1TseNbI88bLWl0ZeUBtUdvIxUlbaGb2T7a1fAvuPvfCuF1Ztaj8PUektbHnuvu09x9oLsPzKJgIEv0NlJSylEuJulpSSvd/aF2X5on6brCx9dJCi84DDQxehupsWJ7wf8/wWyopPckfSpp93ngd2rXe41/ldRb0hpJV7j7hg5e69cX1mS6d+8exI4//vho7mOPPRbEjjvuuMxrkqTFixcHsQcffDCaG7uwf8qn87t7ae+XKL3evuKKK4LYrFmzqn7ddevWBbFNmzZFc/v27VvVshYtWhTE5s+fH829++67q1pW3pTS2x2+h+7u70sq9kLnlFsU0CzobaSGM0UBIBEMdABIBAMdABLR4U7RTBfWBDuOunXrFsSmTp0azY3dbfzoo4/OvCZJ+uCDD4LYlClTormxU55/+OGHzGvKo3J2imapGXo7dtr8iy++GM0dNGhQya8bOy6/nLkRu0zA7Nmzo7nNcE32ZlVKb7OFDgCJYKADQCIY6ACQCAY6ACSCgQ4AiUjiKJdTTz01iI0fPz6aO3jw4CDWs2fPzGuSpG3btkXjsRsJ3HfffUFs69atmdeUus58lEtMjx49ovExY8YEsQkTJkRzyznK5eGHHw5iTz75ZBBbvXp19PkojqNcAKATYaADQCIY6ACQCAY6ACQiiZ2ikyZNCmLFdoqWY8WKFUHs1VdfjeZu3749iBU7db+tra26wlAUO0WRKnaKAkAnwkAHgEQw0AEgEaXcJPpIM5tvZivN7DMzG1eI32NmX5vZssK/82tfLpAdehup6fCeopK2S7rN3Zea2W8kfWRmbxW+9hd3n1y78oCaoreRlLKPcjGzlyU9JukPkraU0/QcCYBaq+YoF3obzSzzo1zMrI+kkyQtLoTGmtnfzWy6mR1adoVAk6C3kYKSB7qZHShprqRb3H2TpCclHSNpgKRWSdGDrs1stJktMbMlGdQLZI7eRipKesvFzPaR9KqkN9z9ocjX+0h61d1/38Hr8Gcpaqrct1zobeRFJm+52K5rZz4taWX7hjez9tflvETS8kqKBBqF3kZqOtxCN7Ohkt6T9KmknYXwnZJGatefpC7pK0lj3L21g9diKwY1Vc4WOr2NPCmlt5O4lguwG9dyQaq4lgsAdCIMdABIBAMdABLBQAeARDDQASARDHQASAQDHQASwUAHgEQw0AEgEaXc4CJL30r638LHvy18nhrWq3GOauCyd/d2Hr5PlUp13fKwXiX1dl1P/f/Fgs2WuPvAhiy8hlivzi3l71Oq65bSevGWCwAkgoEOAIlo5ECf1sBl1xLr1bml/H1Kdd2SWa+GvYcOAMgWb7kAQCLqPtDN7Dwz+9zMVpvZ7fVefpYKd4Rfb2bL28W6mdlbZraq8Ji7O8ab2ZFmNt/MVprZZ2Y2rhDP/brVUiq9TV/nb912q+tAN7Mukh6X9C+Sjpc00syOr2cNGXtW0nl7xG6X9La795X0duHzvNku6TZ3/ydJQyTdXPg5pbBuNZFYbz8r+jqX6r2FPljSanf/0t1/ljRb0kV1riEz7r5Q0oY9whdJmlH4eIaki+taVAbcvdXdlxY+3ixppaSeSmDdaiiZ3qav87duu9V7oPeU9I92n7cUYinpvvuGwoXHwxtcT1XMrI+kkyQtVmLrlrHUezupn32qfV3vgR67ySmH2TQpMztQ0lxJt7j7pkbX0+To7ZxIua/rPdBbJB3Z7vNektbWuYZaW2dmPSSp8Li+wfVUxMz20a6mf8Hd/1YIJ7FuNZJ6byfxs0+9r+s90D+U1NfMfmdmXSX9UdK8OtdQa/MkXVf4+DpJLzewloqYmUl6WtJKd3+o3Zdyv241lHpv5/5n3xn6uu4nFpn
Z+ZL+TVIXSdPdfWJdC8iQmc2SNEy7rta2TtKfJf2npL9K6i1pjaQr3H3PHUxNzcyGSnpP0qeSdhbCd2rX+425XrdaSqW36ev8rdtunCkKAIngTFEASAQDHQASwUAHgEQw0AEgEQx0AEgEAx0AEsFAB4BEMNABIBH/BynEqYcYobnEAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Let's look at two of the images in the training dataset, and the\n", "# corresponding label (which defines the number the image represents)\n", "\n", "# Grab the first two images and corresponding numbers they represent\n", "# from the training set\n", "\n", "img_0, val_0 = train_dataset[ 0 ]\n", "img_1, val_1 = train_dataset[ 1 ]\n", "\n", "# Print out what numbers the two images are meant to represent\n", "\n", "print( val_0, ' ', val_1 )\n", "\n", "# Create a single row of two images, place first image in the first\n", "# image position\n", "\n", "plt.subplot( 1, 2, 1 )\n", "plt.imshow( img_0.numpy()[ 0 ], cmap='gray' )\n", "\n", "# Place the second image in the second image position\n", "\n", "plt.subplot( 1, 2, 2 )\n", "plt.imshow( img_1.numpy()[ 0 ], cmap='gray' )" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Create training and test dataloaders, these will load and cache data of batch_size during\n", "# looping, training data is shuffled randomly, test data is not\n", "\n", "train_loader = torch.utils.data.DataLoader( dataset=train_dataset, batch_size=batch_size, shuffle=True )\n", "test_loader = torch.utils.data.DataLoader( dataset=test_dataset, batch_size=batch_size, shuffle=False )" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "net = nn.Sequential(\n", " nn.Linear( input_size, hidden_size ), # fully connected layer, 784 (input data) -> 500 (hidden layer)\n", " nn.ReLU(), # Rectified liner unit filter, max( 0, x )\n", " nn.Linear( hidden_size, num_classes ) # output layer, 500 (hidden layer) -> 10 (output classes)\n", ")\n", "\n", "#net.cuda() # Uncomment to enable GPU computation" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Choose a cross-entropy loss function to use during evaluation of DNN performance, 
and an\n", "# Adam optimizer (Adam, see https://arxiv.org/abs/1412.6980) to update DNN weights\n", "\n", "criterion = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.Adam( net.parameters(), lr=learning_rate )" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch [1/5], Step [100/600], Loss: 0.5752\n", "Epoch [1/5], Step [200/600], Loss: 0.3144\n", "Epoch [1/5], Step [300/600], Loss: 0.1291\n", "Epoch [1/5], Step [400/600], Loss: 0.2598\n", "Epoch [1/5], Step [500/600], Loss: 0.1865\n", "Epoch [1/5], Step [600/600], Loss: 0.1451\n", "Epoch [2/5], Step [100/600], Loss: 0.0957\n", "Epoch [2/5], Step [200/600], Loss: 0.0728\n", "Epoch [2/5], Step [300/600], Loss: 0.0587\n", "Epoch [2/5], Step [400/600], Loss: 0.0266\n", "Epoch [2/5], Step [500/600], Loss: 0.0251\n", "Epoch [2/5], Step [600/600], Loss: 0.0734\n", "Epoch [3/5], Step [100/600], Loss: 0.1127\n", "Epoch [3/5], Step [200/600], Loss: 0.0569\n", "Epoch [3/5], Step [300/600], Loss: 0.0640\n", "Epoch [3/5], Step [400/600], Loss: 0.0495\n", "Epoch [3/5], Step [500/600], Loss: 0.0445\n", "Epoch [3/5], Step [600/600], Loss: 0.0815\n", "Epoch [4/5], Step [100/600], Loss: 0.0929\n", "Epoch [4/5], Step [200/600], Loss: 0.0259\n", "Epoch [4/5], Step [300/600], Loss: 0.0104\n", "Epoch [4/5], Step [400/600], Loss: 0.1127\n", "Epoch [4/5], Step [500/600], Loss: 0.0153\n", "Epoch [4/5], Step [600/600], Loss: 0.0465\n", "Epoch [5/5], Step [100/600], Loss: 0.0414\n", "Epoch [5/5], Step [200/600], Loss: 0.0582\n", "Epoch [5/5], Step [300/600], Loss: 0.0360\n", "Epoch [5/5], Step [400/600], Loss: 0.0339\n", "Epoch [5/5], Step [500/600], Loss: 0.0069\n", "Epoch [5/5], Step [600/600], Loss: 0.0065\n" ] } ], "source": [
"# Train the DNN\n",
"\n",
"for epoch in range( num_epochs ): # One full pass over the training dataset per epoch\n",
"\n",
"    for i, (images, labels) in enumerate( train_loader ): # For each batch of batch_size images\n",
"\n",
"        images = images.view( -1, input_size ) # Flatten each 28x28 image into a 784-element vector\n",
"\n",
"        optimizer.zero_grad() # Zero gradients left over from the previous step\n",
"        outputs = net( images ) # Forward pass, compute a score per class for each image\n",
"        loss = criterion( outputs, labels ) # Compute loss, mismatch between class scores and actual labels\n",
"        loss.backward() # Backpropagation, compute gradient of the loss for every weight\n",
"        optimizer.step() # Update network weights from the gradients\n",
"\n",
"        if ( i + 1 ) % 100 == 0: # Log progress every 100 steps\n",
"            print( 'Epoch [%d/%d], Step [%d/%d], Loss: %.4f' % ( epoch + 1, num_epochs, i + 1, len( train_loader ), loss.item() ) )"
] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy on 10K test images: 97 %\n" ] } ], "source": [
"# Test performance on test images\n",
"\n",
"correct = 0\n",
"total = 0\n",
"\n",
"with torch.no_grad(): # Gradients are not needed when only measuring accuracy\n",
"\n",
"    for images, labels in test_loader: # For each batch of images in the test dataset\n",
"\n",
"        images = images.view( -1, input_size ) # Flatten each 28x28 image into a 784-element vector\n",
"        outputs = net( images ) # Class scores estimated by the DNN for each image\n",
"\n",
"        _, predicted = torch.max( outputs, 1 ) # Choose best class from DNN output\n",
"        total += labels.size( 0 ) # Increment total count by the number of images in the batch\n",
"        correct += ( predicted == labels ).sum().item() # Increment correct prediction count, kept as a Python int\n",
"\n",
"# Print overall accuracy of labeling\n",
"\n",
"print( 'Accuracy on 10K test images: %d %%' % ( 100 * correct / total ) )"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", 
"mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 2 }