From e70def727a99e73a68044e30ae68b7a87406fd40 Mon Sep 17 00:00:00 2001
From: Matthew K Defenderfer <mdefende@uab.edu>
Date: Thu, 22 Aug 2019 14:09:35 -0500
Subject: [PATCH] Set up basic file structure and add current ipynbs

---
 ctmodel-ml.yml                | 170 +++++++
 model-py/regression-MKD.ipynb | 824 ++++++++++++++++++++++++++++++++++
 model-py/regression.ipynb     | 535 ++++++++++++++++++++++
 3 files changed, 1529 insertions(+)
 create mode 100644 ctmodel-ml.yml
 create mode 100644 model-py/regression-MKD.ipynb
 create mode 100644 model-py/regression.ipynb

diff --git a/ctmodel-ml.yml b/ctmodel-ml.yml
new file mode 100644
index 0000000..5e28789
--- /dev/null
+++ b/ctmodel-ml.yml
@@ -0,0 +1,170 @@
+name: ctmodel-ml
+channels:
+  - conda-forge
+  - anaconda
+  - defaults
+dependencies:
+  - _tflow_1100_select=0.0.1=gpu
+  - absl-py=0.4.1=py36_0
+  - astor=0.7.1=py36_0
+  - bleach=1.5.0=py36_0
+  - cudatoolkit=9.0=h13b8566_0
+  - cudnn=7.1.2=cuda9.0_0
+  - cupti=9.0.176=0
+  - gast=0.2.0=py36_0
+  - grpcio=1.12.1=py36hdbcaa40_0
+  - hdf5=1.10.2=hba1933b_1
+  - html5lib=0.9999999=py36_0
+  - libprotobuf=3.6.0=hdbcaa40_0
+  - markdown=2.6.11=py36_0
+  - pyyaml=3.13=py36h14c3975_0
+  - termcolor=1.1.0=py36_1
+  - werkzeug=0.14.1=py36_0
+  - yaml=0.1.7=h96e3832_1
+  - nibabel=2.3.0=pyh24bf2e0_1
+  - pydicom=1.1.0=py_0
+  - appdirs=1.4.3=py36h28b3542_0
+  - asn1crypto=0.24.0=py36_0
+  - attrs=18.2.0=py36h28b3542_0
+  - automat=0.7.0=py36_0
+  - backcall=0.1.0=py36_0
+  - blas=1.0=mkl
+  - ca-certificates=2018.03.07=0
+  - certifi=2018.10.15=py36_0
+  - cffi=1.11.5=py36he75722e_1
+  - constantly=15.1.0=py36h28b3542_0
+  - cryptography=2.3.1=py36hc365091_0
+  - cycler=0.10.0=py36_0
+  - cython=0.28.5=py36hf484d3e_0
+  - dbus=1.13.2=h714fa37_1
+  - decorator=4.3.0=py36_0
+  - entrypoints=0.2.3=py36_2
+  - expat=2.2.6=he6710b0_0
+  - fontconfig=2.13.0=h9420a91_0
+  - freetype=2.9.1=h8a8886c_1
+  - future=0.16.0=py36_0
+  - glib=2.56.2=hd408876_0
+  - gmp=6.1.2=h6c8ec71_1
+  - gst-plugins-base=1.14.0=hbbd80ab_1
+  - gstreamer=1.14.0=hb453b48_1
+  - h5py=2.8.0=py36h989c5e5_3
+  - hyperlink=18.0.0=py36_0
+  - icu=58.2=h9c2bf20_1
+  - idna=2.7=py36_0
+  - incremental=17.5.0=py36_0
+  - intel-openmp=2019.0=118
+  - ipykernel=4.9.0=py36_0
+  - ipython=6.5.0=py36_0
+  - ipython_genutils=0.2.0=py36_0
+  - jedi=0.12.1=py36_0
+  - jinja2=2.10=py36_0
+  - jpeg=9b=h024ee3a_2
+  - jsonschema=2.6.0=py36_0
+  - jupyter_client=5.2.3=py36_0
+  - jupyter_core=4.4.0=py36_0
+  - keras-applications=1.0.4=py36_1
+  - keras-base=2.2.2=py36_0
+  - keras-preprocessing=1.0.2=py36_1
+  - kiwisolver=1.0.1=py36hf484d3e_0
+  - libedit=3.1.20170329=h6b74fdf_2
+  - libffi=3.2.1=hd88cf55_4
+  - libgcc-ng=8.2.0=hdf63c60_1
+  - libgfortran-ng=7.3.0=hdf63c60_0
+  - libpng=1.6.34=hb9fc6fc_0
+  - libsodium=1.0.16=h1bed415_0
+  - libstdcxx-ng=8.2.0=hdf63c60_1
+  - libuuid=1.0.3=h1bed415_2
+  - libxcb=1.13=h1bed415_1
+  - libxml2=2.9.8=h26e45fe_1
+  - markupsafe=1.0=py36h14c3975_1
+  - matplotlib=2.2.3=py36hb69df0a_0
+  - mistune=0.8.3=py36h14c3975_1
+  - mkl=2019.0=118
+  - mkl_fft=1.0.4=py36h4414c95_1
+  - mkl_random=1.0.1=py36h4414c95_1
+  - nbconvert=5.3.1=py36_0
+  - nbformat=4.4.0=py36_0
+  - ncurses=6.1=hf484d3e_0
+  - notebook=5.6.0=py36_0
+  - numpy=1.14.5=py36h1b885b7_4
+  - numpy-base=1.14.5=py36hdbf6ddf_4
+  - openssl=1.0.2p=h14c3975_0
+  - pandas=0.23.4=py36h04863e7_0
+  - pandoc=2.2.3.2=0
+  - pandocfilters=1.4.2=py36_1
+  - parso=0.3.1=py36_0
+  - patsy=0.5.0=py36_0
+  - pcre=8.42=h439df22_0
+  - pexpect=4.6.0=py36_0
+  - pickleshare=0.7.4=py36_0
+  - pip=10.0.1=py36_0
+  - prometheus_client=0.3.1=py36h28b3542_0
+  - prompt_toolkit=1.0.15=py36_0
+  - protobuf=3.6.0=py36hf484d3e_0
+  - ptyprocess=0.6.0=py36_0
+  - pyasn1=0.4.4=py36h28b3542_0
+  - pyasn1-modules=0.2.2=py36_0
+  - pycparser=2.18=py36_1
+  - pydot=1.2.4=py36_0
+  - pygments=2.2.0=py36_0
+  - pyopenssl=18.0.0=py36_0
+  - pyparsing=2.2.0=py36_1
+  - pyqt=5.9.2=py36h05f1152_2
+  - python=3.6.6=hc3d631a_0
+  - python-dateutil=2.7.3=py36_0
+  - pytz=2018.5=py36_0
+  - pyzmq=17.1.2=py36h14c3975_0
+  - qt=5.9.6=h8703b6f_2
+  - readline=7.0=h7b6447c_5
+  - scikit-learn=0.20.0=py36h4989274_1
+  - scipy=1.1.0=py36hd20e5f9_0
+  - seaborn=0.9.0=py36_0
+  - send2trash=1.5.0=py36_0
+  - service_identity=17.0.0=py36h28b3542_0
+  - setuptools=39.1.0=py36_0
+  - simplegeneric=0.8.1=py36_2
+  - sip=4.19.8=py36hf484d3e_0
+  - six=1.11.0=py36_1
+  - sqlite=3.24.0=h84994c4_0
+  - statsmodels=0.9.0=py36h035aef0_0
+  - tensorboard=1.10.0=py36hf484d3e_0
+  - tensorflow=1.10.0=gpu_py36h97a2126_0
+  - tensorflow-base=1.10.0=gpu_py36h6ecc378_0
+  - tensorflow-gpu=1.10.0=hf154084_0
+  - terminado=0.8.1=py36_1
+  - testpath=0.3.1=py36_0
+  - tk=8.6.8=hbc83047_0
+  - tornado=5.1=py36h14c3975_0
+  - traitlets=4.3.2=py36_0
+  - twisted=18.7.0=py36h14c3975_1
+  - wcwidth=0.1.7=py36_0
+  - webencodings=0.5.1=py36_1
+  - wheel=0.31.1=py36_0
+  - xz=5.2.4=h14c3975_4
+  - zeromq=4.2.5=hf484d3e_1
+  - zlib=1.2.11=ha838bed_2
+  - zope=1.0=py36_1
+  - zope.interface=4.5.0=py36h14c3975_0
+  - pip:
+    - blinker==1.4
+    - chardet==3.0.4
+    - cloudpickle==0.5.6
+    - configparser==3.5.0
+    - dask==0.19.1
+    - kaggle==1.4.7.1
+    - keras==2.2.2
+    - keras-vis==0.4.1
+    - networkx==2.2rc1
+    - packaging==17.1
+    - pillow==5.2.0
+    - pyhamcrest==1.9.0
+    - python-slugify==1.2.6
+    - pywavelets==1.0.0
+    - requests==2.19.1
+    - scikit-image==0.14.0
+    - simpleitk==1.1.0
+    - toolz==0.9.0
+    - tqdm==4.26.0
+    - unidecode==1.0.22
+    - urllib3==1.22
+    - xgboost==0.90
diff --git a/model-py/regression-MKD.ipynb b/model-py/regression-MKD.ipynb
new file mode 100644
index 0000000..80dcc18
--- /dev/null
+++ b/model-py/regression-MKD.ipynb
@@ -0,0 +1,824 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "ModuleNotFoundError",
+     "evalue": "No module named 'xgboost'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-17-a78a46995753>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mmpl\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpyplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mxgboost\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mxgb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     11\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msklearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpreprocessing\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPolynomialFeatures\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msklearn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mensemble\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'xgboost'"
+     ]
+    }
+   ],
+   "source": [
+    "# memory management functions\n",
+    "import os\n",
+    "import psutil\n",
+    "\n",
+    "# math and models\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import matplotlib as mpl\n",
+    "import matplotlib.pyplot as plt\n",
+    "import xgboost as xgb\n",
+    "from sklearn.preprocessing import PolynomialFeatures\n",
+    "from sklearn import ensemble\n",
+    "from sklearn.metrics import r2_score\n",
+    "from sklearn import linear_model\n",
+    "from sklearn import gaussian_process as gpm\n",
+    "from sklearn.model_selection import train_test_split\n",
+    "from sklearn import tree"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load dataset\n",
+    "df = pd.read_csv(\"/data/user/mdefende/Projects/CorticalThicknessModel/A/Scripts/Data/BigThick_Restricted_cleaned.csv\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "         Subject Gender  Age Hemi  Group        Ecc         Pol   normEcc  \\\n",
+      "0         100206      M   27   LH  Train   3.567456  177.411148  0.071349   \n",
+      "1         100206      M   27   LH  Train   3.537558  173.980743  0.070751   \n",
+      "2         100206      M   27   LH  Train   3.522834  177.403961  0.070457   \n",
+      "3         100206      M   27   LH  Train   3.497096  177.711700  0.069942   \n",
+      "4         100206      M   27   LH  Train   1.356428  156.274979  0.027129   \n",
+      "5         100206      M   27   LH  Train   1.323316  157.835205  0.026466   \n",
+      "6         100206      M   27   LH  Train   1.208326  166.347137  0.024167   \n",
+      "7         100206      M   27   LH  Train   1.176058  167.972031  0.023521   \n",
+      "8         100206      M   27   LH  Train   1.407720  148.838181  0.028154   \n",
+      "9         100206      M   27   LH  Train   1.341403  152.340546  0.026828   \n",
+      "10        100206      M   27   LH  Train   1.269300  157.293900  0.025386   \n",
+      "11        100206      M   27   LH  Train   1.198244  159.868393  0.023965   \n",
+      "12        100206      M   27   LH  Train   1.162055  163.119431  0.023241   \n",
+      "13        100206      M   27   LH  Train   1.461745  141.727310  0.029235   \n",
+      "14        100206      M   27   LH  Train   1.432764  143.279037  0.028655   \n",
+      "15        100206      M   27   LH  Train   1.354926  145.016632  0.027099   \n",
+      "16        100206      M   27   LH  Train   1.278286  147.711746  0.025566   \n",
+      "17        100206      M   27   LH  Train   1.213237  153.133057  0.024265   \n",
+      "18        100206      M   27   LH  Train   1.156069  157.608017  0.023121   \n",
+      "19        100206      M   27   LH  Train   1.452105  139.754242  0.029042   \n",
+      "20        100206      M   27   LH  Train   1.389827  139.581390  0.027797   \n",
+      "21        100206      M   27   LH  Train   1.312138  139.130508  0.026243   \n",
+      "22        100206      M   27   LH  Train   1.235464  141.854599  0.024709   \n",
+      "23        100206      M   27   LH  Train   1.288646  134.436020  0.025773   \n",
+      "24        100206      M   27   LH  Train   1.256301  137.291229  0.025126   \n",
+      "25        100206      M   27   LH  Train   1.265432  132.407181  0.025309   \n",
+      "26        100206      M   27   LH  Train   1.244605  134.291290  0.024892   \n",
+      "27        100206      M   27   LH  Train   4.478133  167.913254  0.089563   \n",
+      "28        100206      M   27   LH  Train   4.474024  169.971039  0.089481   \n",
+      "29        100206      M   27   LH  Train   4.458443  174.424789  0.089169   \n",
+      "...          ...    ...  ...  ...    ...        ...         ...       ...   \n",
+      "6787055   996782      F   28   RH  Train  34.812057    2.016624  0.696242   \n",
+      "6787056   996782      F   28   RH  Train  37.058197    0.325366  0.741165   \n",
+      "6787057   996782      F   28   RH  Train  36.378635    4.269973  0.727574   \n",
+      "6787058   996782      F   28   RH  Train  39.452621    4.759562  0.789053   \n",
+      "6787059   996782      F   28   RH  Train  38.722595    3.068579  0.774453   \n",
+      "6787060   996782      F   28   RH  Train  44.115536   12.932929  0.882312   \n",
+      "6787061   996782      F   28   RH  Train  42.664555   10.658121  0.853292   \n",
+      "6787062   996782      F   28   RH  Train  41.358227    8.524337  0.827166   \n",
+      "6787063   996782      F   28   RH  Train  40.848404    6.113272  0.816969   \n",
+      "6787064   996782      F   28   RH  Train  48.257523   15.541472  0.965152   \n",
+      "6787065   996782      F   28   RH  Train  45.951336   14.551472  0.919028   \n",
+      "6787066   996782      F   28   RH  Train  41.813080    2.612780  0.836263   \n",
+      "6787067   996782      F   28   RH  Train  40.911617    4.453729  0.818233   \n",
+      "6787068   996782      F   28   RH  Train  46.174717   10.218830  0.923495   \n",
+      "6787069   996782      F   28   RH  Train  45.041702    6.822008  0.900835   \n",
+      "6787070   996782      F   28   RH  Train  43.202980    5.293739  0.864061   \n",
+      "6787071   996782      F   28   RH  Train  42.252956    4.160832  0.845060   \n",
+      "6787072   996782      F   28   RH  Train  48.620617    9.229843  0.972414   \n",
+      "6787073   996782      F   28   RH  Train  46.881737    7.386467  0.937636   \n",
+      "6787074   996782      F   28   RH  Train  46.225166    2.961761  0.924504   \n",
+      "6787075   996782      F   28   RH  Train  45.898434    0.491990  0.917970   \n",
+      "6787076   996782      F   28   RH  Train  44.115589    2.866768  0.882313   \n",
+      "6787077   996782      F   28   RH  Train  49.843601    4.330243  0.996873   \n",
+      "6787078   996782      F   28   RH  Train  47.795212    3.122164  0.955905   \n",
+      "6787079   996782      F   28   RH  Train  47.197891    0.911866  0.943959   \n",
+      "6787080   996782      F   28   RH  Train  38.595798    3.042604  0.771917   \n",
+      "6787081   996782      F   28   RH  Train  16.830524    8.680782  0.336611   \n",
+      "6787082   996782      F   28   RH  Train  16.522949   11.613343  0.330459   \n",
+      "6787083   996782      F   28   RH  Train  20.728981  164.827118  0.414580   \n",
+      "6787084   996782      F   28   RH  Train  20.537703  161.797836  0.410755   \n",
+      "\n",
+      "          normPol     Sulc0    ...     PialArea5  PialArea10      LGI0  \\\n",
+      "0        0.485617 -0.969500    ...      0.740433    0.694201  2.111248   \n",
+      "1        0.466560 -1.018326    ...      0.750461    0.702102  2.128481   \n",
+      "2        0.485578 -0.995660    ...      0.737149    0.687681  2.125945   \n",
+      "3        0.487287 -1.027502    ...      0.731389    0.685778  2.124727   \n",
+      "4        0.368194 -0.753537    ...      0.721735    0.697463  2.133806   \n",
+      "5        0.376862 -0.753183    ...      0.742576    0.704914  2.135973   \n",
+      "6        0.424151 -0.791031    ...      0.830400    0.740700  2.152152   \n",
+      "7        0.433178 -0.818936    ...      0.848792    0.751308  2.154360   \n",
+      "8        0.326879 -0.780384    ...      0.690919    0.696721  2.129827   \n",
+      "9        0.346336 -0.789114    ...      0.730944    0.708385  2.134366   \n",
+      "10       0.373855 -0.779502    ...      0.780105    0.724135  2.146819   \n",
+      "11       0.388158 -0.816658    ...      0.821948    0.742594  2.150895   \n",
+      "12       0.406219 -0.842135    ...      0.845164    0.752521  2.153529   \n",
+      "13       0.287374 -0.792315    ...      0.665475    0.697910  2.119450   \n",
+      "14       0.295995 -0.779278    ...      0.678812    0.700818  2.128336   \n",
+      "15       0.305648 -0.814661    ...      0.717915    0.714605  2.132625   \n",
+      "16       0.320621 -0.822019    ...      0.768804    0.733194  2.145043   \n",
+      "17       0.350739 -0.816008    ...      0.809952    0.746315  2.149433   \n",
+      "18       0.375600 -0.837289    ...      0.839957    0.758518  2.154962   \n",
+      "19       0.276412 -0.802535    ...      0.673126    0.702462  2.120098   \n",
+      "20       0.275452 -0.798065    ...      0.701170    0.714524  2.130581   \n",
+      "21       0.272947 -0.827945    ...      0.745896    0.733508  2.136259   \n",
+      "22       0.288081 -0.807783    ...      0.787246    0.750777  2.148183   \n",
+      "23       0.246867 -0.808789    ...      0.762774    0.745232  2.135591   \n",
+      "24       0.262729 -0.802400    ...      0.777009    0.750630  2.145497   \n",
+      "25       0.235595 -0.769775    ...      0.773159    0.752456  2.135851   \n",
+      "26       0.246063 -0.761871    ...      0.781852    0.755997  2.145127   \n",
+      "27       0.432851 -1.085865    ...      0.932740    0.876360  2.153401   \n",
+      "28       0.444284 -1.017086    ...      0.913269    0.858155  2.154156   \n",
+      "29       0.469027 -0.917080    ...      0.867826    0.820888  2.152521   \n",
+      "...           ...       ...    ...           ...         ...       ...   \n",
+      "6787055 -0.488797 -0.660271    ...      1.141246    1.158239  3.289434   \n",
+      "6787056 -0.498192 -0.331669    ...      1.154787    1.123169  3.285268   \n",
+      "6787057 -0.476278 -0.484029    ...      1.129800    1.148906  3.287321   \n",
+      "6787058 -0.473558 -0.127753    ...      1.199641    1.064739  3.281233   \n",
+      "6787059 -0.482952 -0.216425    ...      1.187146    1.099584  3.279390   \n",
+      "6787060 -0.428150  0.329903    ...      0.963070    0.903809  3.274556   \n",
+      "6787061 -0.440788  0.185611    ...      1.071004    0.955178  3.276377   \n",
+      "6787062 -0.452643  0.052616    ...      1.155293    1.000820  3.278221   \n",
+      "6787063 -0.466037 -0.039510    ...      1.201619    1.035581  3.278041   \n",
+      "6787064 -0.413658  0.539896    ...      0.760521    0.798241  3.267780   \n",
+      "6787065 -0.419158  0.452117    ...      0.856887    0.852150  3.271790   \n",
+      "6787066 -0.485485 -0.058604    ...      1.223373    1.060463  3.271061   \n",
+      "6787067 -0.475257 -0.149212    ...      1.204664    1.094330  3.270529   \n",
+      "6787068 -0.443229  0.360516    ...      0.955320    0.896229  3.267207   \n",
+      "6787069 -0.462100  0.245724    ...      1.067826    0.946347  3.265846   \n",
+      "6787070 -0.470590  0.106576    ...      1.160608    0.996125  3.269724   \n",
+      "6787071 -0.476884  0.013787    ...      1.207522    1.028556  3.272161   \n",
+      "6787072 -0.448723  0.433305    ...      0.887894    0.862750  3.260500   \n",
+      "6787073 -0.458964  0.337906    ...      0.995089    0.912124  3.262475   \n",
+      "6787074 -0.483546  0.227673    ...      1.118196    0.965010  3.259609   \n",
+      "6787075 -0.497267  0.153000    ...      1.166645    0.994987  3.259659   \n",
+      "6787076 -0.484074  0.031406    ...      1.217565    1.043899  3.261460   \n",
+      "6787077 -0.475943  0.400194    ...      0.956013    0.894331  3.251189   \n",
+      "6787078 -0.482655  0.306412    ...      1.056971    0.938883  3.255619   \n",
+      "6787079 -0.494934  0.229746    ...      1.137643    0.974020  3.257046   \n",
+      "6787080 -0.483097  0.658627    ...      0.420545    0.492509  3.210731   \n",
+      "6787081 -0.451773 -0.431437    ...      0.999800    1.061583  3.218198   \n",
+      "6787082 -0.435481 -0.535194    ...      1.011416    1.075808  3.224614   \n",
+      "6787083  0.415706 -1.041680    ...      1.024205    1.104097  3.504762   \n",
+      "6787084  0.398877 -0.901922    ...      1.004598    1.093961  3.491673   \n",
+      "\n",
+      "             LGI2      LGI5     LGI10    Thick0    Thick2    Thick5   Thick10  \n",
+      "0        2.121056  2.126482  2.140920  2.766883  2.791073  2.667062  2.302786  \n",
+      "1        2.123977  2.126430  2.139703  2.754594  2.762720  2.609997  2.266771  \n",
+      "2        2.120204  2.123042  2.136645  2.563326  2.709144  2.606383  2.279676  \n",
+      "3        2.122354  2.121592  2.134473  2.659624  2.671067  2.558893  2.259900  \n",
+      "4        2.134276  2.135392  2.149016  2.353268  2.326346  2.211919  2.088902  \n",
+      "5        2.137616  2.138694  2.152027  2.451419  2.417072  2.275516  2.118037  \n",
+      "6        2.152114  2.153917  2.167282  2.743103  2.707918  2.467123  2.229415  \n",
+      "7        2.155263  2.158118  2.171818  2.707542  2.705797  2.487149  2.254589  \n",
+      "8        2.129190  2.130467  2.144590  2.246106  2.236198  2.093843  2.024529  \n",
+      "9        2.136051  2.136491  2.149751  2.388690  2.379972  2.229577  2.091792  \n",
+      "10       2.144282  2.144277  2.157050  2.621783  2.559522  2.366980  2.163589  \n",
+      "11       2.150865  2.151188  2.163861  2.714028  2.666185  2.455846  2.219241  \n",
+      "12       2.154374  2.156217  2.169374  2.716976  2.686966  2.489331  2.250992  \n",
+      "13       2.121997  2.126662  2.141163  2.196349  2.171223  1.991173  1.955873  \n",
+      "14       2.125494  2.128615  2.142797  2.239169  2.217598  2.042992  1.985392  \n",
+      "15       2.133461  2.134429  2.147610  2.380691  2.340416  2.174114  2.053984  \n",
+      "16       2.142832  2.142251  2.154533  2.551684  2.460895  2.311058  2.134976  \n",
+      "17       2.149742  2.149195  2.161355  2.636964  2.582885  2.418631  2.202731  \n",
+      "18       2.154959  2.155726  2.168393  2.679090  2.629709  2.476009  2.250769  \n",
+      "19       2.122582  2.127754  2.141892  2.208363  2.188120  2.000082  1.956584  \n",
+      "20       2.128525  2.131949  2.145266  2.298500  2.292803  2.100548  2.006651  \n",
+      "21       2.136660  2.138601  2.150841  2.398277  2.338111  2.209595  2.074679  \n",
+      "22       2.146215  2.145615  2.157313  2.387249  2.382713  2.320101  2.152505  \n",
+      "23       2.137532  2.141211  2.152861  2.224107  2.249597  2.198439  2.077878  \n",
+      "24       2.142820  2.143733  2.155350  2.257848  2.285086  2.262644  2.119502  \n",
+      "25       2.138554  2.143358  2.154782  2.174568  2.192579  2.195138  2.085485  \n",
+      "26       2.141907  2.144859  2.156304  2.168352  2.216717  2.240871  2.114422  \n",
+      "27       2.155405  2.166630  2.193880  2.406428  2.444317  2.371739  2.277808  \n",
+      "28       2.155127  2.166830  2.193293  2.537748  2.476763  2.369525  2.277276  \n",
+      "29       2.155168  2.167406  2.192119  2.596547  2.486246  2.366198  2.276222  \n",
+      "...           ...       ...       ...       ...       ...       ...       ...  \n",
+      "6787055  3.290782  3.279967  3.266013  1.949855  2.043284  1.987879  2.085439  \n",
+      "6787056  3.283882  3.274675  3.259445  1.949113  1.768108  1.809772  1.988021  \n",
+      "6787057  3.284873  3.273376  3.258670  1.996609  1.997058  1.924324  2.049818  \n",
+      "6787058  3.280645  3.275146  3.259971  1.269301  1.359185  1.584634  1.872575  \n",
+      "6787059  3.279444  3.271110  3.255348  1.547230  1.551830  1.705185  1.928807  \n",
+      "6787060  3.274055  3.272218  3.263479  1.147584  1.196215  1.370170  1.750175  \n",
+      "6787061  3.275824  3.273360  3.262072  1.047854  1.186943  1.385340  1.766917  \n",
+      "6787062  3.277665  3.274319  3.261054  1.111057  1.208556  1.431513  1.795516  \n",
+      "6787063  3.277556  3.272818  3.258115  1.216371  1.253253  1.494926  1.826654  \n",
+      "6787064  3.267160  3.265626  3.261921  1.358387  1.344649  1.387528  1.755428  \n",
+      "6787065  3.271195  3.269597  3.263417  1.181586  1.252182  1.378257  1.747487  \n",
+      "6787066  3.269981  3.262042  3.246612  1.228040  1.353159  1.573190  1.853610  \n",
+      "6787067  3.269943  3.260183  3.244612  1.397627  1.504455  1.694168  1.910865  \n",
+      "6787068  3.266852  3.263960  3.253390  0.983319  1.096354  1.321554  1.729233  \n",
+      "6787069  3.265382  3.260795  3.247450  0.967814  1.095351  1.341753  1.743206  \n",
+      "6787070  3.268543  3.262680  3.248484  1.237979  1.202344  1.418325  1.780673  \n",
+      "6787071  3.271250  3.264894  3.249985  1.200267  1.254715  1.478789  1.811468  \n",
+      "6787072  3.260072  3.257145  3.246468  1.077476  1.111898  1.291877  1.723117  \n",
+      "6787073  3.262141  3.258257  3.245415  0.943625  1.041473  1.299829  1.724172  \n",
+      "6787074  3.259224  3.252843  3.237209  0.984138  1.082332  1.336643  1.742954  \n",
+      "6787075  3.258257  3.249845  3.234604  1.194583  1.209887  1.413663  1.775008  \n",
+      "6787076  3.259477  3.249144  3.234652  1.509456  1.398174  1.551992  1.834241  \n",
+      "6787077  3.251195  3.246758  3.231455  0.938096  1.013600  1.261594  1.722768  \n",
+      "6787078  3.255130  3.249057  3.233668  1.010567  1.053410  1.313520  1.736925  \n",
+      "6787079  3.256105  3.248528  3.232362  1.054192  1.108953  1.352336  1.750564  \n",
+      "6787080  3.210914  3.211640  3.200303  2.044960  2.000274  1.990393  2.203855  \n",
+      "6787081  3.228139  3.252401  3.268084  2.604342  2.548727  2.245316  2.072108  \n",
+      "6787082  3.237871  3.259982  3.272573  2.712809  2.573531  2.248212  2.068789  \n",
+      "6787083  3.498766  3.497312  3.485821  2.278427  2.137778  1.888041  1.689229  \n",
+      "6787084  3.493212  3.492619  3.482487  2.073916  2.011232  1.824685  1.675286  \n",
+      "\n",
+      "[6787085 rows x 38 columns]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(df)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array(['Subject', 'Gender', 'Age', 'Hemi', 'Group', 'Ecc', 'Pol',\n",
+       "       'normEcc', 'normPol', 'Sulc0', 'Curv0', 'Curv2', 'Curv5', 'Curv10',\n",
+       "       'PialCurv0', 'PialCurv2', 'PialCurv5', 'PialCurv10', 'Area0',\n",
+       "       'Area2', 'Area5', 'Area10', 'MidArea', 'MidArea2', 'MidArea5',\n",
+       "       'MidArea10', 'PialArea0', 'PialArea2', 'PialArea5', 'PialArea10',\n",
+       "       'LGI0', 'LGI2', 'LGI5', 'LGI10', 'Thick0', 'Thick2', 'Thick5',\n",
+       "       'Thick10'], dtype=object)"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "df.columns.values"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Code dummy variables for subject, hemisphere, and gender\n",
+    "g_d = pd.get_dummies(df['Gender'])\n",
+    "h_d = pd.get_dummies(df['Hemi'])\n",
+    "s_d = pd.get_dummies(df['Subject'])\n",
+    "\n",
+    "df = pd.concat([df,g_d,h_d,s_d], axis = 1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Split into test and train dataframes\n",
+    "df_train = df[df.Group == \"Train\"]\n",
+    "df_test = df[df.Group == \"Test\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# x_train and x_test are your input features\n",
+    "x_train = pd.concat([df_train[['normEcc','normPol','Sulc0','Curv5','Area5','LGI5']]]) #, df_train.loc[:,'F':'RH']], axis = 1)\n",
+    "x_test = pd.concat([df_test[['normEcc','normPol','Sulc0','Curv5','Area5','LGI5']]]) #, df_test.loc[:,'F':'RH']], axis = 1)\n",
+    "\n",
+    "# y_train and y_test hold your output feature (the one you want to predict)\n",
+    "y_train = df_train[['Thick5']]\n",
+    "y_test = df_test[['Thick5']]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Once deleted, variables cannot be recovered. Proceed (y/[n])?  y\n"
+     ]
+    }
+   ],
+   "source": [
+    "%reset_selective df"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "          normEcc   normPol     Sulc0     Curv5     Area5      LGI5\n",
+      "0        0.071349  0.485617 -0.969500 -0.175257  0.365139  2.126482\n",
+      "1        0.070751  0.466560 -1.018326 -0.194592  0.355679  2.126430\n",
+      "2        0.070457  0.485578 -0.995660 -0.206786  0.339150  2.123042\n",
+      "3        0.069942  0.487287 -1.027502 -0.226170  0.325459  2.121592\n",
+      "4        0.027129  0.368194 -0.753537 -0.167363  0.467997  2.135392\n",
+      "5        0.026466  0.376862 -0.753183 -0.164006  0.468342  2.138694\n",
+      "6        0.024167  0.424151 -0.791031 -0.162873  0.469467  2.153917\n",
+      "7        0.023521  0.433178 -0.818936 -0.164371  0.468871  2.158118\n",
+      "8        0.028154  0.326879 -0.780384 -0.201563  0.442048  2.130467\n",
+      "9        0.026828  0.346336 -0.789114 -0.183105  0.449047  2.136491\n",
+      "10       0.025386  0.373855 -0.779502 -0.169546  0.457724  2.144277\n",
+      "11       0.023965  0.388158 -0.816658 -0.168231  0.456454  2.151188\n",
+      "12       0.023241  0.406219 -0.842135 -0.166211  0.460725  2.156217\n",
+      "13       0.029235  0.287374 -0.792315 -0.241704  0.411581  2.126662\n",
+      "14       0.028655  0.295995 -0.779278 -0.229542  0.417538  2.128615\n",
+      "15       0.027099  0.305648 -0.814661 -0.211252  0.422046  2.134429\n",
+      "16       0.025566  0.320621 -0.822019 -0.193055  0.432451  2.142251\n",
+      "17       0.024265  0.350739 -0.816008 -0.177756  0.444004  2.149195\n",
+      "18       0.023121  0.375600 -0.837289 -0.170196  0.451620  2.155726\n",
+      "19       0.029042  0.276412 -0.802535 -0.244678  0.407580  2.127754\n",
+      "20       0.027797  0.275452 -0.798065 -0.233359  0.405415  2.131949\n",
+      "21       0.026243  0.272947 -0.827945 -0.218111  0.412554  2.138601\n",
+      "22       0.024709  0.288081 -0.807783 -0.199763  0.425284  2.145615\n",
+      "23       0.025773  0.246867 -0.808789 -0.219269  0.413580  2.141211\n",
+      "24       0.025126  0.262729 -0.802400 -0.210798  0.417964  2.143733\n",
+      "25       0.025309  0.235595 -0.769775 -0.215277  0.418453  2.143358\n",
+      "26       0.024892  0.246063 -0.761871 -0.211038  0.419737  2.144859\n",
+      "27       0.089563  0.432851 -1.085865 -0.227432  0.521610  2.166630\n",
+      "28       0.089481  0.444284 -1.017086 -0.214535  0.533691  2.166830\n",
+      "29       0.089169  0.469027 -0.917080 -0.183402  0.558986  2.167406\n",
+      "...           ...       ...       ...       ...       ...       ...\n",
+      "6787055  0.696242 -0.488797 -0.660271 -0.252420  0.719446  3.279967\n",
+      "6787056  0.741165 -0.498192 -0.331669 -0.221859  0.742614  3.274675\n",
+      "6787057  0.727574 -0.476278 -0.484029 -0.255619  0.715220  3.273376\n",
+      "6787058  0.789053 -0.473558 -0.127753 -0.135195  0.808798  3.275146\n",
+      "6787059  0.774453 -0.482952 -0.216425 -0.193565  0.769054  3.271110\n",
+      "6787060  0.882312 -0.428150  0.329903  0.037845  0.842381  3.272218\n",
+      "6787061  0.853292 -0.440788  0.185611  0.000564  0.853427  3.273360\n",
+      "6787062  0.827166 -0.452643  0.052616 -0.045616  0.850372  3.274319\n",
+      "6787063  0.816969 -0.466037 -0.039510 -0.092075  0.834119  3.272818\n",
+      "6787064  0.965152 -0.413658  0.539896  0.071233  0.782826  3.265626\n",
+      "6787065  0.919028 -0.419158  0.452117  0.059935  0.817088  3.269597\n",
+      "6787066  0.836263 -0.485485 -0.058604 -0.140059  0.804893  3.262042\n",
+      "6787067  0.818233 -0.475257 -0.149212 -0.197167  0.768328  3.260183\n",
+      "6787068  0.923495 -0.443229  0.360516  0.039484  0.821977  3.263960\n",
+      "6787069  0.900835 -0.462100  0.245724  0.001199  0.829692  3.260795\n",
+      "6787070  0.864061 -0.470590  0.106576 -0.049715  0.834898  3.262680\n",
+      "6787071  0.845060 -0.476884  0.013787 -0.088670  0.830544  3.264894\n",
+      "6787072  0.972414 -0.448723  0.433305  0.049380  0.791541  3.257145\n",
+      "6787073  0.937636 -0.458964  0.337906  0.026758  0.814743  3.258257\n",
+      "6787074  0.924504 -0.483546  0.227673 -0.018039  0.819589  3.252843\n",
+      "6787075  0.917970 -0.497267  0.153000 -0.057800  0.814234  3.249845\n",
+      "6787076  0.882313 -0.484074  0.031406 -0.127112  0.798611  3.249144\n",
+      "6787077  0.996873 -0.475943  0.400194  0.028088  0.785327  3.246758\n",
+      "6787078  0.955905 -0.482655  0.306412 -0.001943  0.805595  3.249057\n",
+      "6787079  0.943959 -0.494934  0.229746 -0.031343  0.813216  3.248528\n",
+      "6787080  0.771917 -0.483097  0.658627  0.117806  0.740952  3.211640\n",
+      "6787081  0.336611 -0.451773 -0.431437 -0.245727  0.728921  3.252401\n",
+      "6787082  0.330459 -0.435481 -0.535194 -0.267580  0.693892  3.259982\n",
+      "6787083  0.414580  0.415706 -1.041680 -0.293193  0.648885  3.497312\n",
+      "6787084  0.410755  0.398877 -0.901922 -0.296768  0.653152  3.492619\n",
+      "\n",
+      "[5124388 rows x 6 columns]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(x_train)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>normEcc</th>\n",
+       "      <th>normPol</th>\n",
+       "      <th>Sulc0</th>\n",
+       "      <th>Curv5</th>\n",
+       "      <th>Area5</th>\n",
+       "      <th>LGI5</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>count</th>\n",
+       "      <td>5.124388e+06</td>\n",
+       "      <td>5.124388e+06</td>\n",
+       "      <td>5.124388e+06</td>\n",
+       "      <td>5.124388e+06</td>\n",
+       "      <td>5.124388e+06</td>\n",
+       "      <td>5.124388e+06</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>mean</th>\n",
+       "      <td>2.386589e-01</td>\n",
+       "      <td>4.113428e-02</td>\n",
+       "      <td>-1.747371e-01</td>\n",
+       "      <td>-5.394791e-02</td>\n",
+       "      <td>6.745831e-01</td>\n",
+       "      <td>2.809199e+00</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>std</th>\n",
+       "      <td>2.288182e-01</td>\n",
+       "      <td>2.760123e-01</td>\n",
+       "      <td>5.721855e-01</td>\n",
+       "      <td>1.514886e-01</td>\n",
+       "      <td>9.922468e-02</td>\n",
+       "      <td>4.028898e-01</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>min</th>\n",
+       "      <td>2.424506e-05</td>\n",
+       "      <td>-4.999273e-01</td>\n",
+       "      <td>-1.971219e+00</td>\n",
+       "      <td>-2.368729e+00</td>\n",
+       "      <td>2.437784e-01</td>\n",
+       "      <td>1.959253e+00</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>25%</th>\n",
+       "      <td>7.679665e-02</td>\n",
+       "      <td>-1.714563e-01</td>\n",
+       "      <td>-6.353022e-01</td>\n",
+       "      <td>-1.814308e-01</td>\n",
+       "      <td>6.198345e-01</td>\n",
+       "      <td>2.474502e+00</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>50%</th>\n",
+       "      <td>1.539186e-01</td>\n",
+       "      <td>6.200975e-02</td>\n",
+       "      <td>-2.448606e-01</td>\n",
+       "      <td>-6.423837e-02</td>\n",
+       "      <td>6.823665e-01</td>\n",
+       "      <td>2.741358e+00</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>75%</th>\n",
+       "      <td>3.278529e-01</td>\n",
+       "      <td>2.621307e-01</td>\n",
+       "      <td>2.495521e-01</td>\n",
+       "      <td>7.672764e-02</td>\n",
+       "      <td>7.347262e-01</td>\n",
+       "      <td>3.102216e+00</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>max</th>\n",
+       "      <td>1.000000e+00</td>\n",
+       "      <td>4.999642e-01</td>\n",
+       "      <td>2.009938e+00</td>\n",
+       "      <td>1.398352e+00</td>\n",
+       "      <td>2.217891e+00</td>\n",
+       "      <td>4.249628e+00</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "            normEcc       normPol         Sulc0         Curv5         Area5  \\\n",
+       "count  5.124388e+06  5.124388e+06  5.124388e+06  5.124388e+06  5.124388e+06   \n",
+       "mean   2.386589e-01  4.113428e-02 -1.747371e-01 -5.394791e-02  6.745831e-01   \n",
+       "std    2.288182e-01  2.760123e-01  5.721855e-01  1.514886e-01  9.922468e-02   \n",
+       "min    2.424506e-05 -4.999273e-01 -1.971219e+00 -2.368729e+00  2.437784e-01   \n",
+       "25%    7.679665e-02 -1.714563e-01 -6.353022e-01 -1.814308e-01  6.198345e-01   \n",
+       "50%    1.539186e-01  6.200975e-02 -2.448606e-01 -6.423837e-02  6.823665e-01   \n",
+       "75%    3.278529e-01  2.621307e-01  2.495521e-01  7.672764e-02  7.347262e-01   \n",
+       "max    1.000000e+00  4.999642e-01  2.009938e+00  1.398352e+00  2.217891e+00   \n",
+       "\n",
+       "               LGI5  \n",
+       "count  5.124388e+06  \n",
+       "mean   2.809199e+00  \n",
+       "std    4.028898e-01  \n",
+       "min    1.959253e+00  \n",
+       "25%    2.474502e+00  \n",
+       "50%    2.741358e+00  \n",
+       "75%    3.102216e+00  \n",
+       "max    4.249628e+00  "
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# print summary statistics for each feature column\n",
+    "x_train.describe()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Expand the predictors with PolynomialFeatures (with degree=1 this only adds a bias column; raise degree to get interaction terms)\n",
+    "interaction = PolynomialFeatures(degree=1, include_bias=True, interaction_only=False)\n",
+    "X_train = interaction.fit_transform(x_train)\n",
+    "X_test = interaction.fit_transform(x_test)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(5124388, 7)"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "X_train.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Linear models\n",
+    "#reg=linear_model.LinearRegression(fit_intercept=True)\n",
+    "#reg=linear_model.Ridge(alpha=0.9)\n",
+    "#reg=linear_model.Lasso(alpha=0.1)\n",
+    "#reg=linear_model.Lars()\n",
+    "\n",
+    "# non linear model. From simulations below, looks like ideal max_depth ~ 16 and min_samples_leaf ~ 60\n",
+    "#reg=tree.DecisionTreeRegressor(min_samples_leaf=60, max_depth=16)\n",
+    "#reg=gpm.GaussianProcessRegressor()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "TypeError",
+     "evalue": "__init__() got an unexpected keyword argument 'tree_method'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-16-afab54435052>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m params = {'n_estimators': 500, 'max_depth': 20, 'min_samples_split': 20,\n\u001b[1;32m      3\u001b[0m           'learning_rate': 0.01, 'loss': 'ls', 'tree_method': 'gpu_hist', 'verbose': 1}\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mreg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mensemble\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGradientBoostingRegressor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;31mTypeError\u001b[0m: __init__() got an unexpected keyword argument 'tree_method'"
+     ]
+    }
+   ],
+   "source": [
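+    "# Testing out Gradient Boosting decision trees (NOTE(review): sklearn's GradientBoostingRegressor rejects 'tree_method', see the recorded TypeError; 'gpu_hist' looks like an XGBoost setting)\n",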
+    "params = {'n_estimators': 500, 'max_depth': 20, 'min_samples_split': 20,\n",
+    "          'learning_rate': 0.01, 'loss': 'ls', 'tree_method': 'gpu_hist', 'verbose': 1}\n",
+    "reg = ensemble.GradientBoostingRegressor(**params)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# this is where your model is being trained\n",
+    "\n",
+    "reg.fit(x_train,y_train)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# for linear models, uncomment the following line to inspect all the coefficients of your linear model\n",
+    "#reg.coef_"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# the 20% of the test data that are used for testing purpose\n",
+    "#a=reg.predict(x_test)\n",
+    "b=reg.predict(x_train)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ensuring that the output has the correct dimensions\n",
+    "#a.shape = (a.size, 1)\n",
+    "b.shape = (b.size, 1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# calculate the root mean square error and the mean absolute error (the lower the error, the better your model is)\n",
+    "#TestRMSE = np.sqrt(np.mean((a-y_test.values)**2))\n",
+    "#display('Test RMSE = {}'.format(TestRMSE))\n",
+    "\n",
+    "TrainRMSE = np.sqrt(np.mean((b-y_train.values)**2))\n",
+    "display('Train RMSE = {}'.format(TrainRMSE))\n",
+    "\n",
+    "#TestMAE = np.mean(np.abs(a-y_test.values))\n",
+    "#display('Test MAE = {}'.format(TestMAE))\n",
+    "\n",
+    "TrainMAE = np.mean(np.abs(b-y_train.values))\n",
+    "display('Train MAE = {}'.format(TrainMAE))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# R squared metric (NOTE(review): a is only defined when the test-set prediction a=reg.predict(x_test) is re-enabled)\n",
+    "display(\"Test R squared = {}\".format(r2_score(y_test,a)))\n",
+    "display(\"Train R squared = {}\".format(r2_score(y_train,b)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# max_depth parameter tuning for the decision tree\n",
+    "max_depths = np.linspace(1, 32, 32, endpoint=True)\n",
+    "train_results = []\n",
+    "test_results = []\n",
+    "for max_depth in max_depths:\n",
+    "    # Create the tree object\n",
+    "    dt = tree.DecisionTreeRegressor(max_depth=max_depth)\n",
+    "    \n",
+    "    # Train the tree\n",
+    "    dt.fit(x_train, y_train)\n",
+    "    \n",
+    "    # Predict the values for the training set\n",
+    "    train_pred = dt.predict(x_train)   \n",
+    "    train_pred.shape = (train_pred.size,1)\n",
+    "    \n",
+    "    # Calculate RMSE and add to the train_results array\n",
+    "    TrainRMSE = np.sqrt(np.mean((train_pred-y_train.values)**2))\n",
+    "    train_results.append(TrainRMSE)   \n",
+    "    \n",
+    "    # Predict the values for the test set\n",
+    "    test_pred = dt.predict(x_test)  \n",
+    "    test_pred.shape = (test_pred.size,1)\n",
+    "    \n",
+    "    # Calculate RMSE and add to the test_results array\n",
+    "    TestRMSE = np.sqrt(np.mean((test_pred-y_test.values)**2))\n",
+    "    test_results.append(TestRMSE)\n",
+    "\n",
+    "from matplotlib.legend_handler import HandlerLine2D\n",
+    "line1, = plt.plot(max_depths, train_results, 'b', label = \"Train RMSE\")\n",
+    "line2, = plt.plot(max_depths, test_results, 'r', label = \"Test RMSE\")\n",
+    "\n",
+    "plt.legend(handler_map={line1: HandlerLine2D(numpoints=2)})\n",
+    "plt.ylabel('RMSE')\n",
+    "plt.xlabel('Tree depth')\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# min_samples_leaf parameter tuning for the decision tree\n",
+    "min_samples = np.linspace(5, 100, 20, endpoint=True)\n",
+    "train_results = []\n",
+    "test_results = []\n",
+    "for n in min_samples:\n",
+    "    # Create the tree object\n",
+    "    dt = tree.DecisionTreeRegressor(min_samples_leaf=n/x_train.size)\n",
+    "    \n",
+    "    # Train the tree\n",
+    "    dt.fit(x_train, y_train)\n",
+    "    \n",
+    "    # Predict the values for the training set\n",
+    "    train_pred = dt.predict(x_train)   \n",
+    "    train_pred.shape = (train_pred.size,1)\n",
+    "    \n",
+    "    # Calculate RMSE and add to the train_results array\n",
+    "    TrainRMSE = np.sqrt(np.mean((train_pred-y_train.values)**2))\n",
+    "    train_results.append(TrainRMSE)   \n",
+    "    \n",
+    "    # Predict the values for the test set\n",
+    "    test_pred = dt.predict(x_test)  \n",
+    "    test_pred.shape = (test_pred.size,1)\n",
+    "    \n",
+    "    # Calculate RMSE and add to the test_results array\n",
+    "    TestRMSE = np.sqrt(np.mean((test_pred-y_test.values)**2))\n",
+    "    test_results.append(TestRMSE)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from matplotlib.legend_handler import HandlerLine2D\n",
+    "line1, = plt.plot(min_samples, train_results, 'b', label = \"Train RMSE\")\n",
+    "line2, = plt.plot(min_samples, test_results, 'r', label = \"Test RMSE\")\n",
+    "\n",
+    "plt.legend(handler_map={line1: HandlerLine2D(numpoints=2)})\n",
+    "plt.ylabel('RMSE')\n",
+    "plt.xlabel('Min Samples per Leaf')\n",
+    "plt.show()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/model-py/regression.ipynb b/model-py/regression.ipynb
new file mode 100644
index 0000000..f52aaa9
--- /dev/null
+++ b/model-py/regression.ipynb
@@ -0,0 +1,535 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "from sklearn import linear_model\n",
+    "from sklearn.model_selection import train_test_split\n",
+    "from sklearn import tree"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sklearn.datasets import load_boston\n",
+    "boston =load_boston()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(506, 13)\n"
+     ]
+    }
+   ],
+   "source": [
+    "print (boston.data.shape)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[6.3200e-03 1.8000e+01 2.3100e+00 ... 1.5300e+01 3.9690e+02 4.9800e+00]\n",
+      " [2.7310e-02 0.0000e+00 7.0700e+00 ... 1.7800e+01 3.9690e+02 9.1400e+00]\n",
+      " [2.7290e-02 0.0000e+00 7.0700e+00 ... 1.7800e+01 3.9283e+02 4.0300e+00]\n",
+      " ...\n",
+      " [6.0760e-02 0.0000e+00 1.1930e+01 ... 2.1000e+01 3.9690e+02 5.6400e+00]\n",
+      " [1.0959e-01 0.0000e+00 1.1930e+01 ... 2.1000e+01 3.9345e+02 6.4800e+00]\n",
+      " [4.7410e-02 0.0000e+00 1.1930e+01 ... 2.1000e+01 3.9690e+02 7.8800e+00]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print (boston.data)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# df_x are your input features\n",
+    "df_x=pd.DataFrame(boston.data,columns=boston.feature_names)\n",
+    "\n",
+    "# df_y is your output feature (the one you want to predict)\n",
+    "df_y=pd.DataFrame(boston.target)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>CRIM</th>\n",
+       "      <th>ZN</th>\n",
+       "      <th>INDUS</th>\n",
+       "      <th>CHAS</th>\n",
+       "      <th>NOX</th>\n",
+       "      <th>RM</th>\n",
+       "      <th>AGE</th>\n",
+       "      <th>DIS</th>\n",
+       "      <th>RAD</th>\n",
+       "      <th>TAX</th>\n",
+       "      <th>PTRATIO</th>\n",
+       "      <th>B</th>\n",
+       "      <th>LSTAT</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>count</th>\n",
+       "      <td>506.000000</td>\n",
+       "      <td>506.000000</td>\n",
+       "      <td>506.000000</td>\n",
+       "      <td>506.000000</td>\n",
+       "      <td>506.000000</td>\n",
+       "      <td>506.000000</td>\n",
+       "      <td>506.000000</td>\n",
+       "      <td>506.000000</td>\n",
+       "      <td>506.000000</td>\n",
+       "      <td>506.000000</td>\n",
+       "      <td>506.000000</td>\n",
+       "      <td>506.000000</td>\n",
+       "      <td>506.000000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>mean</th>\n",
+       "      <td>3.613524</td>\n",
+       "      <td>11.363636</td>\n",
+       "      <td>11.136779</td>\n",
+       "      <td>0.069170</td>\n",
+       "      <td>0.554695</td>\n",
+       "      <td>6.284634</td>\n",
+       "      <td>68.574901</td>\n",
+       "      <td>3.795043</td>\n",
+       "      <td>9.549407</td>\n",
+       "      <td>408.237154</td>\n",
+       "      <td>18.455534</td>\n",
+       "      <td>356.674032</td>\n",
+       "      <td>12.653063</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>std</th>\n",
+       "      <td>8.601545</td>\n",
+       "      <td>23.322453</td>\n",
+       "      <td>6.860353</td>\n",
+       "      <td>0.253994</td>\n",
+       "      <td>0.115878</td>\n",
+       "      <td>0.702617</td>\n",
+       "      <td>28.148861</td>\n",
+       "      <td>2.105710</td>\n",
+       "      <td>8.707259</td>\n",
+       "      <td>168.537116</td>\n",
+       "      <td>2.164946</td>\n",
+       "      <td>91.294864</td>\n",
+       "      <td>7.141062</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>min</th>\n",
+       "      <td>0.006320</td>\n",
+       "      <td>0.000000</td>\n",
+       "      <td>0.460000</td>\n",
+       "      <td>0.000000</td>\n",
+       "      <td>0.385000</td>\n",
+       "      <td>3.561000</td>\n",
+       "      <td>2.900000</td>\n",
+       "      <td>1.129600</td>\n",
+       "      <td>1.000000</td>\n",
+       "      <td>187.000000</td>\n",
+       "      <td>12.600000</td>\n",
+       "      <td>0.320000</td>\n",
+       "      <td>1.730000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>25%</th>\n",
+       "      <td>0.082045</td>\n",
+       "      <td>0.000000</td>\n",
+       "      <td>5.190000</td>\n",
+       "      <td>0.000000</td>\n",
+       "      <td>0.449000</td>\n",
+       "      <td>5.885500</td>\n",
+       "      <td>45.025000</td>\n",
+       "      <td>2.100175</td>\n",
+       "      <td>4.000000</td>\n",
+       "      <td>279.000000</td>\n",
+       "      <td>17.400000</td>\n",
+       "      <td>375.377500</td>\n",
+       "      <td>6.950000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>50%</th>\n",
+       "      <td>0.256510</td>\n",
+       "      <td>0.000000</td>\n",
+       "      <td>9.690000</td>\n",
+       "      <td>0.000000</td>\n",
+       "      <td>0.538000</td>\n",
+       "      <td>6.208500</td>\n",
+       "      <td>77.500000</td>\n",
+       "      <td>3.207450</td>\n",
+       "      <td>5.000000</td>\n",
+       "      <td>330.000000</td>\n",
+       "      <td>19.050000</td>\n",
+       "      <td>391.440000</td>\n",
+       "      <td>11.360000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>75%</th>\n",
+       "      <td>3.677083</td>\n",
+       "      <td>12.500000</td>\n",
+       "      <td>18.100000</td>\n",
+       "      <td>0.000000</td>\n",
+       "      <td>0.624000</td>\n",
+       "      <td>6.623500</td>\n",
+       "      <td>94.075000</td>\n",
+       "      <td>5.188425</td>\n",
+       "      <td>24.000000</td>\n",
+       "      <td>666.000000</td>\n",
+       "      <td>20.200000</td>\n",
+       "      <td>396.225000</td>\n",
+       "      <td>16.955000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>max</th>\n",
+       "      <td>88.976200</td>\n",
+       "      <td>100.000000</td>\n",
+       "      <td>27.740000</td>\n",
+       "      <td>1.000000</td>\n",
+       "      <td>0.871000</td>\n",
+       "      <td>8.780000</td>\n",
+       "      <td>100.000000</td>\n",
+       "      <td>12.126500</td>\n",
+       "      <td>24.000000</td>\n",
+       "      <td>711.000000</td>\n",
+       "      <td>22.000000</td>\n",
+       "      <td>396.900000</td>\n",
+       "      <td>37.970000</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "             CRIM          ZN       INDUS        CHAS         NOX          RM  \\\n",
+       "count  506.000000  506.000000  506.000000  506.000000  506.000000  506.000000   \n",
+       "mean     3.613524   11.363636   11.136779    0.069170    0.554695    6.284634   \n",
+       "std      8.601545   23.322453    6.860353    0.253994    0.115878    0.702617   \n",
+       "min      0.006320    0.000000    0.460000    0.000000    0.385000    3.561000   \n",
+       "25%      0.082045    0.000000    5.190000    0.000000    0.449000    5.885500   \n",
+       "50%      0.256510    0.000000    9.690000    0.000000    0.538000    6.208500   \n",
+       "75%      3.677083   12.500000   18.100000    0.000000    0.624000    6.623500   \n",
+       "max     88.976200  100.000000   27.740000    1.000000    0.871000    8.780000   \n",
+       "\n",
+       "              AGE         DIS         RAD         TAX     PTRATIO           B  \\\n",
+       "count  506.000000  506.000000  506.000000  506.000000  506.000000  506.000000   \n",
+       "mean    68.574901    3.795043    9.549407  408.237154   18.455534  356.674032   \n",
+       "std     28.148861    2.105710    8.707259  168.537116    2.164946   91.294864   \n",
+       "min      2.900000    1.129600    1.000000  187.000000   12.600000    0.320000   \n",
+       "25%     45.025000    2.100175    4.000000  279.000000   17.400000  375.377500   \n",
+       "50%     77.500000    3.207450    5.000000  330.000000   19.050000  391.440000   \n",
+       "75%     94.075000    5.188425   24.000000  666.000000   20.200000  396.225000   \n",
+       "max    100.000000   12.126500   24.000000  711.000000   22.000000  396.900000   \n",
+       "\n",
+       "            LSTAT  \n",
+       "count  506.000000  \n",
+       "mean    12.653063  \n",
+       "std      7.141062  \n",
+       "min      1.730000  \n",
+       "25%      6.950000  \n",
+       "50%     11.360000  \n",
+       "75%     16.955000  \n",
+       "max     37.970000  "
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# print summary statistics for each feature column\n",
+    "df_x.describe()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 65,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Linear models\n",
+    "#reg=linear_model.LinearRegression()\n",
+    "#reg=linear_model.Ridge(alpha=0.9)\n",
+    "reg=linear_model.Lasso(alpha=0.1)\n",
+    "#reg=linear_model.Lars()\n",
+    "\n",
+    "# non linear model\n",
+    "#reg=tree.DecisionTreeRegressor(max_depth=10)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 66,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# test_size=0.2 means 20% of all input data will be used to test the model and the remaining 80% to train it\n",
+    "# x_train, y_train: the 80% of the data that is used for training\n",
+    "# x_test, y_test: the 20% of the data that is used for testing your model\n",
+    "\n",
+    "x_train,x_test,y_train,y_test = train_test_split(df_x,df_y,test_size=0.2, random_state=4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 67,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Lasso(alpha=0.1, copy_X=True, fit_intercept=True, max_iter=1000,\n",
+       "   normalize=False, positive=False, precompute=False, random_state=None,\n",
+       "   selection='cyclic', tol=0.0001, warm_start=False)"
+      ]
+     },
+     "execution_count": 67,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# this is where your model is being trained\n",
+    "\n",
+    "reg.fit(x_train,y_train)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 68,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# for linear models, uncomment the following line to inspect all the coefficients of your linear model\n",
+    "#reg.coef_"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 69,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# predict on the 20% of the data that was held out for testing\n",
+    "a=reg.predict(x_test)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 70,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([[11.35769564],\n",
+       "       [26.63065774],\n",
+       "       [17.07212795],\n",
+       "       [14.88066872],\n",
+       "       [36.41257162],\n",
+       "       [24.9585628 ],\n",
+       "       [31.94678858],\n",
+       "       [18.71968836],\n",
+       "       [18.03333259],\n",
+       "       [24.31205853],\n",
+       "       [29.37517697],\n",
+       "       [28.21096667],\n",
+       "       [19.31608322],\n",
+       "       [29.77081668],\n",
+       "       [22.02956911],\n",
+       "       [15.80101535],\n",
+       "       [21.40010518],\n",
+       "       [11.55888929],\n",
+       "       [10.03696639],\n",
+       "       [14.21676695],\n",
+       "       [ 5.93013576],\n",
+       "       [20.67875375],\n",
+       "       [20.28901268],\n",
+       "       [22.045776  ],\n",
+       "       [16.91359062],\n",
+       "       [20.01181348],\n",
+       "       [14.60702376],\n",
+       "       [14.47106462],\n",
+       "       [19.94873173],\n",
+       "       [16.80168678],\n",
+       "       [14.47686108],\n",
+       "       [23.94606474],\n",
+       "       [35.12159987],\n",
+       "       [22.18768349],\n",
+       "       [17.36984591],\n",
+       "       [19.82812892],\n",
+       "       [30.64690326],\n",
+       "       [35.83418403],\n",
+       "       [24.01776652],\n",
+       "       [24.25497155],\n",
+       "       [36.65259504],\n",
+       "       [31.76859732],\n",
+       "       [19.93445419],\n",
+       "       [31.94878121],\n",
+       "       [30.55626307],\n",
+       "       [24.85315173],\n",
+       "       [40.25718892],\n",
+       "       [17.35967841],\n",
+       "       [20.58594129],\n",
+       "       [23.65915748],\n",
+       "       [33.33055041],\n",
+       "       [25.46122166],\n",
+       "       [18.25223929],\n",
+       "       [27.45084254],\n",
+       "       [13.61083007],\n",
+       "       [22.98211928],\n",
+       "       [24.36098849],\n",
+       "       [33.24773708],\n",
+       "       [17.77844029],\n",
+       "       [34.20142858],\n",
+       "       [16.18855141],\n",
+       "       [20.46046193],\n",
+       "       [31.34454514],\n",
+       "       [14.83719596],\n",
+       "       [39.59888611],\n",
+       "       [28.30333095],\n",
+       "       [29.56328051],\n",
+       "       [ 9.50186015],\n",
+       "       [18.44151744],\n",
+       "       [21.6004166 ],\n",
+       "       [23.21248274],\n",
+       "       [22.98510649],\n",
+       "       [23.36802977],\n",
+       "       [27.76775011],\n",
+       "       [16.2541867 ],\n",
+       "       [23.87249049],\n",
+       "       [16.730009  ],\n",
+       "       [25.39057665],\n",
+       "       [14.1120706 ],\n",
+       "       [19.35830505],\n",
+       "       [22.16813444],\n",
+       "       [19.3767848 ],\n",
+       "       [28.29171375],\n",
+       "       [20.01027133],\n",
+       "       [30.10784765],\n",
+       "       [23.22987591],\n",
+       "       [30.21376239],\n",
+       "       [19.89171106],\n",
+       "       [21.10474733],\n",
+       "       [37.4997143 ],\n",
+       "       [31.49817385],\n",
+       "       [41.24683197],\n",
+       "       [18.88741907],\n",
+       "       [37.34892447],\n",
+       "       [20.22051301],\n",
+       "       [23.61804532],\n",
+       "       [23.95396337],\n",
+       "       [22.14526984],\n",
+       "       [12.45469347],\n",
+       "       [21.69669994],\n",
+       "       [ 9.7580847 ],\n",
+       "       [25.09790195]])"
+      ]
+     },
+     "execution_count": 70,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# ensuring that the output has the correct dimensions\n",
+    "a.shape = (a.size, 1)\n",
+    "a"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 71,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "0    26.452889\n",
+       "dtype: float64"
+      ]
+     },
+     "execution_count": 71,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# calculate the mean square error (the lower the error, the better your model is)\n",
+    "np.mean((a-y_test)**2)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
-- 
GitLab