diff --git a/ctmodel-ml.yml b/ctmodel-ml.yml new file mode 100644 index 0000000000000000000000000000000000000000..5e2878948d10e5ea09cfcc688aec509def0402f9 --- /dev/null +++ b/ctmodel-ml.yml @@ -0,0 +1,170 @@ +name: ctmodel-ml +channels: + - conda-forge + - anaconda + - defaults +dependencies: + - _tflow_1100_select=0.0.1=gpu + - absl-py=0.4.1=py36_0 + - astor=0.7.1=py36_0 + - bleach=1.5.0=py36_0 + - cudatoolkit=9.0=h13b8566_0 + - cudnn=7.1.2=cuda9.0_0 + - cupti=9.0.176=0 + - gast=0.2.0=py36_0 + - grpcio=1.12.1=py36hdbcaa40_0 + - hdf5=1.10.2=hba1933b_1 + - html5lib=0.9999999=py36_0 + - libprotobuf=3.6.0=hdbcaa40_0 + - markdown=2.6.11=py36_0 + - pyyaml=3.13=py36h14c3975_0 + - termcolor=1.1.0=py36_1 + - werkzeug=0.14.1=py36_0 + - yaml=0.1.7=h96e3832_1 + - nibabel=2.3.0=pyh24bf2e0_1 + - pydicom=1.1.0=py_0 + - appdirs=1.4.3=py36h28b3542_0 + - asn1crypto=0.24.0=py36_0 + - attrs=18.2.0=py36h28b3542_0 + - automat=0.7.0=py36_0 + - backcall=0.1.0=py36_0 + - blas=1.0=mkl + - ca-certificates=2018.03.07=0 + - certifi=2018.10.15=py36_0 + - cffi=1.11.5=py36he75722e_1 + - constantly=15.1.0=py36h28b3542_0 + - cryptography=2.3.1=py36hc365091_0 + - cycler=0.10.0=py36_0 + - cython=0.28.5=py36hf484d3e_0 + - dbus=1.13.2=h714fa37_1 + - decorator=4.3.0=py36_0 + - entrypoints=0.2.3=py36_2 + - expat=2.2.6=he6710b0_0 + - fontconfig=2.13.0=h9420a91_0 + - freetype=2.9.1=h8a8886c_1 + - future=0.16.0=py36_0 + - glib=2.56.2=hd408876_0 + - gmp=6.1.2=h6c8ec71_1 + - gst-plugins-base=1.14.0=hbbd80ab_1 + - gstreamer=1.14.0=hb453b48_1 + - h5py=2.8.0=py36h989c5e5_3 + - hyperlink=18.0.0=py36_0 + - icu=58.2=h9c2bf20_1 + - idna=2.7=py36_0 + - incremental=17.5.0=py36_0 + - intel-openmp=2019.0=118 + - ipykernel=4.9.0=py36_0 + - ipython=6.5.0=py36_0 + - ipython_genutils=0.2.0=py36_0 + - jedi=0.12.1=py36_0 + - jinja2=2.10=py36_0 + - jpeg=9b=h024ee3a_2 + - jsonschema=2.6.0=py36_0 + - jupyter_client=5.2.3=py36_0 + - jupyter_core=4.4.0=py36_0 + - keras-applications=1.0.4=py36_1 + - keras-base=2.2.2=py36_0 + - 
keras-preprocessing=1.0.2=py36_1 + - kiwisolver=1.0.1=py36hf484d3e_0 + - libedit=3.1.20170329=h6b74fdf_2 + - libffi=3.2.1=hd88cf55_4 + - libgcc-ng=8.2.0=hdf63c60_1 + - libgfortran-ng=7.3.0=hdf63c60_0 + - libpng=1.6.34=hb9fc6fc_0 + - libsodium=1.0.16=h1bed415_0 + - libstdcxx-ng=8.2.0=hdf63c60_1 + - libuuid=1.0.3=h1bed415_2 + - libxcb=1.13=h1bed415_1 + - libxml2=2.9.8=h26e45fe_1 + - markupsafe=1.0=py36h14c3975_1 + - matplotlib=2.2.3=py36hb69df0a_0 + - mistune=0.8.3=py36h14c3975_1 + - mkl=2019.0=118 + - mkl_fft=1.0.4=py36h4414c95_1 + - mkl_random=1.0.1=py36h4414c95_1 + - nbconvert=5.3.1=py36_0 + - nbformat=4.4.0=py36_0 + - ncurses=6.1=hf484d3e_0 + - notebook=5.6.0=py36_0 + - numpy=1.14.5=py36h1b885b7_4 + - numpy-base=1.14.5=py36hdbf6ddf_4 + - openssl=1.0.2p=h14c3975_0 + - pandas=0.23.4=py36h04863e7_0 + - pandoc=2.2.3.2=0 + - pandocfilters=1.4.2=py36_1 + - parso=0.3.1=py36_0 + - patsy=0.5.0=py36_0 + - pcre=8.42=h439df22_0 + - pexpect=4.6.0=py36_0 + - pickleshare=0.7.4=py36_0 + - pip=10.0.1=py36_0 + - prometheus_client=0.3.1=py36h28b3542_0 + - prompt_toolkit=1.0.15=py36_0 + - protobuf=3.6.0=py36hf484d3e_0 + - ptyprocess=0.6.0=py36_0 + - pyasn1=0.4.4=py36h28b3542_0 + - pyasn1-modules=0.2.2=py36_0 + - pycparser=2.18=py36_1 + - pydot=1.2.4=py36_0 + - pygments=2.2.0=py36_0 + - pyopenssl=18.0.0=py36_0 + - pyparsing=2.2.0=py36_1 + - pyqt=5.9.2=py36h05f1152_2 + - python=3.6.6=hc3d631a_0 + - python-dateutil=2.7.3=py36_0 + - pytz=2018.5=py36_0 + - pyzmq=17.1.2=py36h14c3975_0 + - qt=5.9.6=h8703b6f_2 + - readline=7.0=h7b6447c_5 + - scikit-learn=0.20.0=py36h4989274_1 + - scipy=1.1.0=py36hd20e5f9_0 + - seaborn=0.9.0=py36_0 + - send2trash=1.5.0=py36_0 + - service_identity=17.0.0=py36h28b3542_0 + - setuptools=39.1.0=py36_0 + - simplegeneric=0.8.1=py36_2 + - sip=4.19.8=py36hf484d3e_0 + - six=1.11.0=py36_1 + - sqlite=3.24.0=h84994c4_0 + - statsmodels=0.9.0=py36h035aef0_0 + - tensorboard=1.10.0=py36hf484d3e_0 + - tensorflow=1.10.0=gpu_py36h97a2126_0 + - 
tensorflow-base=1.10.0=gpu_py36h6ecc378_0 + - tensorflow-gpu=1.10.0=hf154084_0 + - terminado=0.8.1=py36_1 + - testpath=0.3.1=py36_0 + - tk=8.6.8=hbc83047_0 + - tornado=5.1=py36h14c3975_0 + - traitlets=4.3.2=py36_0 + - twisted=18.7.0=py36h14c3975_1 + - wcwidth=0.1.7=py36_0 + - webencodings=0.5.1=py36_1 + - wheel=0.31.1=py36_0 + - xz=5.2.4=h14c3975_4 + - zeromq=4.2.5=hf484d3e_1 + - zlib=1.2.11=ha838bed_2 + - zope=1.0=py36_1 + - zope.interface=4.5.0=py36h14c3975_0 + - pip: + - blinker==1.4 + - chardet==3.0.4 + - cloudpickle==0.5.6 + - configparser==3.5.0 + - dask==0.19.1 + - kaggle==1.4.7.1 + - keras==2.2.2 + - keras-vis==0.4.1 + - networkx==2.2rc1 + - packaging==17.1 + - pillow==5.2.0 + - pyhamcrest==1.9.0 + - python-slugify==1.2.6 + - pywavelets==1.0.0 + - requests==2.19.1 + - scikit-image==0.14.0 + - simpleitk==1.1.0 + - toolz==0.9.0 + - tqdm==4.26.0 + - unidecode==1.0.22 + - urllib3==1.22 + - xgboost==0.90 diff --git a/model-py/regression-MKD.ipynb b/model-py/regression-MKD.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..80dcc187d4ade159609c77e97cbb517682cffdeb --- /dev/null +++ b/model-py/regression-MKD.ipynb @@ -0,0 +1,824 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'xgboost'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-17-a78a46995753>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mmpl\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpyplot\u001b[0m 
\u001b[0;32mas\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mxgboost\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mxgb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msklearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpreprocessing\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPolynomialFeatures\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msklearn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mensemble\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'xgboost'" + ] + } + ], + "source": [ + "# memory management functions\n", + "import os\n", + "import psutil\n", + "\n", + "# math and models\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "import xgboost as xgb\n", + "from sklearn.preprocessing import PolynomialFeatures\n", + "from sklearn import ensemble\n", + "from sklearn.metrics import r2_score\n", + "from sklearn import linear_model\n", + "from sklearn import gaussian_process as gpm\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn import tree" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Load dataset\n", + "df = pd.read_csv(\"/data/user/mdefende/Projects/CorticalThicknessModel/A/Scripts/Data/BigThick_Restricted_cleaned.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Subject Gender Age Hemi Group Ecc Pol normEcc \\\n", + "0 100206 M 27 LH Train 3.567456 177.411148 0.071349 \n", + "1 100206 M 27 LH Train 3.537558 173.980743 0.070751 \n", + "2 
100206 M 27 LH Train 3.522834 177.403961 0.070457 \n", + "3 100206 M 27 LH Train 3.497096 177.711700 0.069942 \n", + "4 100206 M 27 LH Train 1.356428 156.274979 0.027129 \n", + "5 100206 M 27 LH Train 1.323316 157.835205 0.026466 \n", + "6 100206 M 27 LH Train 1.208326 166.347137 0.024167 \n", + "7 100206 M 27 LH Train 1.176058 167.972031 0.023521 \n", + "8 100206 M 27 LH Train 1.407720 148.838181 0.028154 \n", + "9 100206 M 27 LH Train 1.341403 152.340546 0.026828 \n", + "10 100206 M 27 LH Train 1.269300 157.293900 0.025386 \n", + "11 100206 M 27 LH Train 1.198244 159.868393 0.023965 \n", + "12 100206 M 27 LH Train 1.162055 163.119431 0.023241 \n", + "13 100206 M 27 LH Train 1.461745 141.727310 0.029235 \n", + "14 100206 M 27 LH Train 1.432764 143.279037 0.028655 \n", + "15 100206 M 27 LH Train 1.354926 145.016632 0.027099 \n", + "16 100206 M 27 LH Train 1.278286 147.711746 0.025566 \n", + "17 100206 M 27 LH Train 1.213237 153.133057 0.024265 \n", + "18 100206 M 27 LH Train 1.156069 157.608017 0.023121 \n", + "19 100206 M 27 LH Train 1.452105 139.754242 0.029042 \n", + "20 100206 M 27 LH Train 1.389827 139.581390 0.027797 \n", + "21 100206 M 27 LH Train 1.312138 139.130508 0.026243 \n", + "22 100206 M 27 LH Train 1.235464 141.854599 0.024709 \n", + "23 100206 M 27 LH Train 1.288646 134.436020 0.025773 \n", + "24 100206 M 27 LH Train 1.256301 137.291229 0.025126 \n", + "25 100206 M 27 LH Train 1.265432 132.407181 0.025309 \n", + "26 100206 M 27 LH Train 1.244605 134.291290 0.024892 \n", + "27 100206 M 27 LH Train 4.478133 167.913254 0.089563 \n", + "28 100206 M 27 LH Train 4.474024 169.971039 0.089481 \n", + "29 100206 M 27 LH Train 4.458443 174.424789 0.089169 \n", + "... ... ... ... ... ... ... ... ... 
\n", + "6787055 996782 F 28 RH Train 34.812057 2.016624 0.696242 \n", + "6787056 996782 F 28 RH Train 37.058197 0.325366 0.741165 \n", + "6787057 996782 F 28 RH Train 36.378635 4.269973 0.727574 \n", + "6787058 996782 F 28 RH Train 39.452621 4.759562 0.789053 \n", + "6787059 996782 F 28 RH Train 38.722595 3.068579 0.774453 \n", + "6787060 996782 F 28 RH Train 44.115536 12.932929 0.882312 \n", + "6787061 996782 F 28 RH Train 42.664555 10.658121 0.853292 \n", + "6787062 996782 F 28 RH Train 41.358227 8.524337 0.827166 \n", + "6787063 996782 F 28 RH Train 40.848404 6.113272 0.816969 \n", + "6787064 996782 F 28 RH Train 48.257523 15.541472 0.965152 \n", + "6787065 996782 F 28 RH Train 45.951336 14.551472 0.919028 \n", + "6787066 996782 F 28 RH Train 41.813080 2.612780 0.836263 \n", + "6787067 996782 F 28 RH Train 40.911617 4.453729 0.818233 \n", + "6787068 996782 F 28 RH Train 46.174717 10.218830 0.923495 \n", + "6787069 996782 F 28 RH Train 45.041702 6.822008 0.900835 \n", + "6787070 996782 F 28 RH Train 43.202980 5.293739 0.864061 \n", + "6787071 996782 F 28 RH Train 42.252956 4.160832 0.845060 \n", + "6787072 996782 F 28 RH Train 48.620617 9.229843 0.972414 \n", + "6787073 996782 F 28 RH Train 46.881737 7.386467 0.937636 \n", + "6787074 996782 F 28 RH Train 46.225166 2.961761 0.924504 \n", + "6787075 996782 F 28 RH Train 45.898434 0.491990 0.917970 \n", + "6787076 996782 F 28 RH Train 44.115589 2.866768 0.882313 \n", + "6787077 996782 F 28 RH Train 49.843601 4.330243 0.996873 \n", + "6787078 996782 F 28 RH Train 47.795212 3.122164 0.955905 \n", + "6787079 996782 F 28 RH Train 47.197891 0.911866 0.943959 \n", + "6787080 996782 F 28 RH Train 38.595798 3.042604 0.771917 \n", + "6787081 996782 F 28 RH Train 16.830524 8.680782 0.336611 \n", + "6787082 996782 F 28 RH Train 16.522949 11.613343 0.330459 \n", + "6787083 996782 F 28 RH Train 20.728981 164.827118 0.414580 \n", + "6787084 996782 F 28 RH Train 20.537703 161.797836 0.410755 \n", + "\n", + " normPol Sulc0 ... 
PialArea5 PialArea10 LGI0 \\\n", + "0 0.485617 -0.969500 ... 0.740433 0.694201 2.111248 \n", + "1 0.466560 -1.018326 ... 0.750461 0.702102 2.128481 \n", + "2 0.485578 -0.995660 ... 0.737149 0.687681 2.125945 \n", + "3 0.487287 -1.027502 ... 0.731389 0.685778 2.124727 \n", + "4 0.368194 -0.753537 ... 0.721735 0.697463 2.133806 \n", + "5 0.376862 -0.753183 ... 0.742576 0.704914 2.135973 \n", + "6 0.424151 -0.791031 ... 0.830400 0.740700 2.152152 \n", + "7 0.433178 -0.818936 ... 0.848792 0.751308 2.154360 \n", + "8 0.326879 -0.780384 ... 0.690919 0.696721 2.129827 \n", + "9 0.346336 -0.789114 ... 0.730944 0.708385 2.134366 \n", + "10 0.373855 -0.779502 ... 0.780105 0.724135 2.146819 \n", + "11 0.388158 -0.816658 ... 0.821948 0.742594 2.150895 \n", + "12 0.406219 -0.842135 ... 0.845164 0.752521 2.153529 \n", + "13 0.287374 -0.792315 ... 0.665475 0.697910 2.119450 \n", + "14 0.295995 -0.779278 ... 0.678812 0.700818 2.128336 \n", + "15 0.305648 -0.814661 ... 0.717915 0.714605 2.132625 \n", + "16 0.320621 -0.822019 ... 0.768804 0.733194 2.145043 \n", + "17 0.350739 -0.816008 ... 0.809952 0.746315 2.149433 \n", + "18 0.375600 -0.837289 ... 0.839957 0.758518 2.154962 \n", + "19 0.276412 -0.802535 ... 0.673126 0.702462 2.120098 \n", + "20 0.275452 -0.798065 ... 0.701170 0.714524 2.130581 \n", + "21 0.272947 -0.827945 ... 0.745896 0.733508 2.136259 \n", + "22 0.288081 -0.807783 ... 0.787246 0.750777 2.148183 \n", + "23 0.246867 -0.808789 ... 0.762774 0.745232 2.135591 \n", + "24 0.262729 -0.802400 ... 0.777009 0.750630 2.145497 \n", + "25 0.235595 -0.769775 ... 0.773159 0.752456 2.135851 \n", + "26 0.246063 -0.761871 ... 0.781852 0.755997 2.145127 \n", + "27 0.432851 -1.085865 ... 0.932740 0.876360 2.153401 \n", + "28 0.444284 -1.017086 ... 0.913269 0.858155 2.154156 \n", + "29 0.469027 -0.917080 ... 0.867826 0.820888 2.152521 \n", + "... ... ... ... ... ... ... \n", + "6787055 -0.488797 -0.660271 ... 1.141246 1.158239 3.289434 \n", + "6787056 -0.498192 -0.331669 ... 
1.154787 1.123169 3.285268 \n", + "6787057 -0.476278 -0.484029 ... 1.129800 1.148906 3.287321 \n", + "6787058 -0.473558 -0.127753 ... 1.199641 1.064739 3.281233 \n", + "6787059 -0.482952 -0.216425 ... 1.187146 1.099584 3.279390 \n", + "6787060 -0.428150 0.329903 ... 0.963070 0.903809 3.274556 \n", + "6787061 -0.440788 0.185611 ... 1.071004 0.955178 3.276377 \n", + "6787062 -0.452643 0.052616 ... 1.155293 1.000820 3.278221 \n", + "6787063 -0.466037 -0.039510 ... 1.201619 1.035581 3.278041 \n", + "6787064 -0.413658 0.539896 ... 0.760521 0.798241 3.267780 \n", + "6787065 -0.419158 0.452117 ... 0.856887 0.852150 3.271790 \n", + "6787066 -0.485485 -0.058604 ... 1.223373 1.060463 3.271061 \n", + "6787067 -0.475257 -0.149212 ... 1.204664 1.094330 3.270529 \n", + "6787068 -0.443229 0.360516 ... 0.955320 0.896229 3.267207 \n", + "6787069 -0.462100 0.245724 ... 1.067826 0.946347 3.265846 \n", + "6787070 -0.470590 0.106576 ... 1.160608 0.996125 3.269724 \n", + "6787071 -0.476884 0.013787 ... 1.207522 1.028556 3.272161 \n", + "6787072 -0.448723 0.433305 ... 0.887894 0.862750 3.260500 \n", + "6787073 -0.458964 0.337906 ... 0.995089 0.912124 3.262475 \n", + "6787074 -0.483546 0.227673 ... 1.118196 0.965010 3.259609 \n", + "6787075 -0.497267 0.153000 ... 1.166645 0.994987 3.259659 \n", + "6787076 -0.484074 0.031406 ... 1.217565 1.043899 3.261460 \n", + "6787077 -0.475943 0.400194 ... 0.956013 0.894331 3.251189 \n", + "6787078 -0.482655 0.306412 ... 1.056971 0.938883 3.255619 \n", + "6787079 -0.494934 0.229746 ... 1.137643 0.974020 3.257046 \n", + "6787080 -0.483097 0.658627 ... 0.420545 0.492509 3.210731 \n", + "6787081 -0.451773 -0.431437 ... 0.999800 1.061583 3.218198 \n", + "6787082 -0.435481 -0.535194 ... 1.011416 1.075808 3.224614 \n", + "6787083 0.415706 -1.041680 ... 1.024205 1.104097 3.504762 \n", + "6787084 0.398877 -0.901922 ... 
1.004598 1.093961 3.491673 \n", + "\n", + " LGI2 LGI5 LGI10 Thick0 Thick2 Thick5 Thick10 \n", + "0 2.121056 2.126482 2.140920 2.766883 2.791073 2.667062 2.302786 \n", + "1 2.123977 2.126430 2.139703 2.754594 2.762720 2.609997 2.266771 \n", + "2 2.120204 2.123042 2.136645 2.563326 2.709144 2.606383 2.279676 \n", + "3 2.122354 2.121592 2.134473 2.659624 2.671067 2.558893 2.259900 \n", + "4 2.134276 2.135392 2.149016 2.353268 2.326346 2.211919 2.088902 \n", + "5 2.137616 2.138694 2.152027 2.451419 2.417072 2.275516 2.118037 \n", + "6 2.152114 2.153917 2.167282 2.743103 2.707918 2.467123 2.229415 \n", + "7 2.155263 2.158118 2.171818 2.707542 2.705797 2.487149 2.254589 \n", + "8 2.129190 2.130467 2.144590 2.246106 2.236198 2.093843 2.024529 \n", + "9 2.136051 2.136491 2.149751 2.388690 2.379972 2.229577 2.091792 \n", + "10 2.144282 2.144277 2.157050 2.621783 2.559522 2.366980 2.163589 \n", + "11 2.150865 2.151188 2.163861 2.714028 2.666185 2.455846 2.219241 \n", + "12 2.154374 2.156217 2.169374 2.716976 2.686966 2.489331 2.250992 \n", + "13 2.121997 2.126662 2.141163 2.196349 2.171223 1.991173 1.955873 \n", + "14 2.125494 2.128615 2.142797 2.239169 2.217598 2.042992 1.985392 \n", + "15 2.133461 2.134429 2.147610 2.380691 2.340416 2.174114 2.053984 \n", + "16 2.142832 2.142251 2.154533 2.551684 2.460895 2.311058 2.134976 \n", + "17 2.149742 2.149195 2.161355 2.636964 2.582885 2.418631 2.202731 \n", + "18 2.154959 2.155726 2.168393 2.679090 2.629709 2.476009 2.250769 \n", + "19 2.122582 2.127754 2.141892 2.208363 2.188120 2.000082 1.956584 \n", + "20 2.128525 2.131949 2.145266 2.298500 2.292803 2.100548 2.006651 \n", + "21 2.136660 2.138601 2.150841 2.398277 2.338111 2.209595 2.074679 \n", + "22 2.146215 2.145615 2.157313 2.387249 2.382713 2.320101 2.152505 \n", + "23 2.137532 2.141211 2.152861 2.224107 2.249597 2.198439 2.077878 \n", + "24 2.142820 2.143733 2.155350 2.257848 2.285086 2.262644 2.119502 \n", + "25 2.138554 2.143358 2.154782 2.174568 2.192579 2.195138 
2.085485 \n", + "26 2.141907 2.144859 2.156304 2.168352 2.216717 2.240871 2.114422 \n", + "27 2.155405 2.166630 2.193880 2.406428 2.444317 2.371739 2.277808 \n", + "28 2.155127 2.166830 2.193293 2.537748 2.476763 2.369525 2.277276 \n", + "29 2.155168 2.167406 2.192119 2.596547 2.486246 2.366198 2.276222 \n", + "... ... ... ... ... ... ... ... \n", + "6787055 3.290782 3.279967 3.266013 1.949855 2.043284 1.987879 2.085439 \n", + "6787056 3.283882 3.274675 3.259445 1.949113 1.768108 1.809772 1.988021 \n", + "6787057 3.284873 3.273376 3.258670 1.996609 1.997058 1.924324 2.049818 \n", + "6787058 3.280645 3.275146 3.259971 1.269301 1.359185 1.584634 1.872575 \n", + "6787059 3.279444 3.271110 3.255348 1.547230 1.551830 1.705185 1.928807 \n", + "6787060 3.274055 3.272218 3.263479 1.147584 1.196215 1.370170 1.750175 \n", + "6787061 3.275824 3.273360 3.262072 1.047854 1.186943 1.385340 1.766917 \n", + "6787062 3.277665 3.274319 3.261054 1.111057 1.208556 1.431513 1.795516 \n", + "6787063 3.277556 3.272818 3.258115 1.216371 1.253253 1.494926 1.826654 \n", + "6787064 3.267160 3.265626 3.261921 1.358387 1.344649 1.387528 1.755428 \n", + "6787065 3.271195 3.269597 3.263417 1.181586 1.252182 1.378257 1.747487 \n", + "6787066 3.269981 3.262042 3.246612 1.228040 1.353159 1.573190 1.853610 \n", + "6787067 3.269943 3.260183 3.244612 1.397627 1.504455 1.694168 1.910865 \n", + "6787068 3.266852 3.263960 3.253390 0.983319 1.096354 1.321554 1.729233 \n", + "6787069 3.265382 3.260795 3.247450 0.967814 1.095351 1.341753 1.743206 \n", + "6787070 3.268543 3.262680 3.248484 1.237979 1.202344 1.418325 1.780673 \n", + "6787071 3.271250 3.264894 3.249985 1.200267 1.254715 1.478789 1.811468 \n", + "6787072 3.260072 3.257145 3.246468 1.077476 1.111898 1.291877 1.723117 \n", + "6787073 3.262141 3.258257 3.245415 0.943625 1.041473 1.299829 1.724172 \n", + "6787074 3.259224 3.252843 3.237209 0.984138 1.082332 1.336643 1.742954 \n", + "6787075 3.258257 3.249845 3.234604 1.194583 1.209887 1.413663 
1.775008 \n", + "6787076 3.259477 3.249144 3.234652 1.509456 1.398174 1.551992 1.834241 \n", + "6787077 3.251195 3.246758 3.231455 0.938096 1.013600 1.261594 1.722768 \n", + "6787078 3.255130 3.249057 3.233668 1.010567 1.053410 1.313520 1.736925 \n", + "6787079 3.256105 3.248528 3.232362 1.054192 1.108953 1.352336 1.750564 \n", + "6787080 3.210914 3.211640 3.200303 2.044960 2.000274 1.990393 2.203855 \n", + "6787081 3.228139 3.252401 3.268084 2.604342 2.548727 2.245316 2.072108 \n", + "6787082 3.237871 3.259982 3.272573 2.712809 2.573531 2.248212 2.068789 \n", + "6787083 3.498766 3.497312 3.485821 2.278427 2.137778 1.888041 1.689229 \n", + "6787084 3.493212 3.492619 3.482487 2.073916 2.011232 1.824685 1.675286 \n", + "\n", + "[6787085 rows x 38 columns]\n" + ] + } + ], + "source": [ + "print(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(['Subject', 'Gender', 'Age', 'Hemi', 'Group', 'Ecc', 'Pol',\n", + " 'normEcc', 'normPol', 'Sulc0', 'Curv0', 'Curv2', 'Curv5', 'Curv10',\n", + " 'PialCurv0', 'PialCurv2', 'PialCurv5', 'PialCurv10', 'Area0',\n", + " 'Area2', 'Area5', 'Area10', 'MidArea', 'MidArea2', 'MidArea5',\n", + " 'MidArea10', 'PialArea0', 'PialArea2', 'PialArea5', 'PialArea10',\n", + " 'LGI0', 'LGI2', 'LGI5', 'LGI10', 'Thick0', 'Thick2', 'Thick5',\n", + " 'Thick10'], dtype=object)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.columns.values" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Code dummy variables for subject, hemisphere, and gender\n", + "g_d = pd.get_dummies(df['Gender'])\n", + "h_d = pd.get_dummies(df['Hemi'])\n", + "s_d = pd.get_dummies(df['Subject'])\n", + "\n", + "df = pd.concat([df,g_d,h_d,s_d], axis = 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# 
Split into test and train dataframes\n", + "df_train = df[df.Group == \"Train\"]\n", + "df_test = df[df.Group == \"Test\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# df_x are your input features\n", + "x_train = pd.concat([df_train[['normEcc','normPol','Sulc0','Curv5','Area5','LGI5']]]) #, df_train.loc[:,'F':'RH']], axis = 1)\n", + "x_test = pd.concat([df_test[['normEcc','normPol','Sulc0','Curv5','Area5','LGI5']]]) #, df_test.loc[:,'F':'RH']], axis = 1)\n", + "\n", + "# df_y is your output feature (the one you want to predict)\n", + "y_train = df_train[['Thick5']]\n", + "y_test = df_test[['Thick5']]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n" + ] + } + ], + "source": [ + "%reset_selective df" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " normEcc normPol Sulc0 Curv5 Area5 LGI5\n", + "0 0.071349 0.485617 -0.969500 -0.175257 0.365139 2.126482\n", + "1 0.070751 0.466560 -1.018326 -0.194592 0.355679 2.126430\n", + "2 0.070457 0.485578 -0.995660 -0.206786 0.339150 2.123042\n", + "3 0.069942 0.487287 -1.027502 -0.226170 0.325459 2.121592\n", + "4 0.027129 0.368194 -0.753537 -0.167363 0.467997 2.135392\n", + "5 0.026466 0.376862 -0.753183 -0.164006 0.468342 2.138694\n", + "6 0.024167 0.424151 -0.791031 -0.162873 0.469467 2.153917\n", + "7 0.023521 0.433178 -0.818936 -0.164371 0.468871 2.158118\n", + "8 0.028154 0.326879 -0.780384 -0.201563 0.442048 2.130467\n", + "9 0.026828 0.346336 -0.789114 -0.183105 0.449047 2.136491\n", + "10 0.025386 0.373855 -0.779502 -0.169546 0.457724 2.144277\n", + "11 0.023965 0.388158 -0.816658 -0.168231 0.456454 2.151188\n", + "12 0.023241 0.406219 -0.842135 -0.166211 0.460725 
2.156217\n", + "13 0.029235 0.287374 -0.792315 -0.241704 0.411581 2.126662\n", + "14 0.028655 0.295995 -0.779278 -0.229542 0.417538 2.128615\n", + "15 0.027099 0.305648 -0.814661 -0.211252 0.422046 2.134429\n", + "16 0.025566 0.320621 -0.822019 -0.193055 0.432451 2.142251\n", + "17 0.024265 0.350739 -0.816008 -0.177756 0.444004 2.149195\n", + "18 0.023121 0.375600 -0.837289 -0.170196 0.451620 2.155726\n", + "19 0.029042 0.276412 -0.802535 -0.244678 0.407580 2.127754\n", + "20 0.027797 0.275452 -0.798065 -0.233359 0.405415 2.131949\n", + "21 0.026243 0.272947 -0.827945 -0.218111 0.412554 2.138601\n", + "22 0.024709 0.288081 -0.807783 -0.199763 0.425284 2.145615\n", + "23 0.025773 0.246867 -0.808789 -0.219269 0.413580 2.141211\n", + "24 0.025126 0.262729 -0.802400 -0.210798 0.417964 2.143733\n", + "25 0.025309 0.235595 -0.769775 -0.215277 0.418453 2.143358\n", + "26 0.024892 0.246063 -0.761871 -0.211038 0.419737 2.144859\n", + "27 0.089563 0.432851 -1.085865 -0.227432 0.521610 2.166630\n", + "28 0.089481 0.444284 -1.017086 -0.214535 0.533691 2.166830\n", + "29 0.089169 0.469027 -0.917080 -0.183402 0.558986 2.167406\n", + "... ... ... ... ... ... 
...\n", + "6787055 0.696242 -0.488797 -0.660271 -0.252420 0.719446 3.279967\n", + "6787056 0.741165 -0.498192 -0.331669 -0.221859 0.742614 3.274675\n", + "6787057 0.727574 -0.476278 -0.484029 -0.255619 0.715220 3.273376\n", + "6787058 0.789053 -0.473558 -0.127753 -0.135195 0.808798 3.275146\n", + "6787059 0.774453 -0.482952 -0.216425 -0.193565 0.769054 3.271110\n", + "6787060 0.882312 -0.428150 0.329903 0.037845 0.842381 3.272218\n", + "6787061 0.853292 -0.440788 0.185611 0.000564 0.853427 3.273360\n", + "6787062 0.827166 -0.452643 0.052616 -0.045616 0.850372 3.274319\n", + "6787063 0.816969 -0.466037 -0.039510 -0.092075 0.834119 3.272818\n", + "6787064 0.965152 -0.413658 0.539896 0.071233 0.782826 3.265626\n", + "6787065 0.919028 -0.419158 0.452117 0.059935 0.817088 3.269597\n", + "6787066 0.836263 -0.485485 -0.058604 -0.140059 0.804893 3.262042\n", + "6787067 0.818233 -0.475257 -0.149212 -0.197167 0.768328 3.260183\n", + "6787068 0.923495 -0.443229 0.360516 0.039484 0.821977 3.263960\n", + "6787069 0.900835 -0.462100 0.245724 0.001199 0.829692 3.260795\n", + "6787070 0.864061 -0.470590 0.106576 -0.049715 0.834898 3.262680\n", + "6787071 0.845060 -0.476884 0.013787 -0.088670 0.830544 3.264894\n", + "6787072 0.972414 -0.448723 0.433305 0.049380 0.791541 3.257145\n", + "6787073 0.937636 -0.458964 0.337906 0.026758 0.814743 3.258257\n", + "6787074 0.924504 -0.483546 0.227673 -0.018039 0.819589 3.252843\n", + "6787075 0.917970 -0.497267 0.153000 -0.057800 0.814234 3.249845\n", + "6787076 0.882313 -0.484074 0.031406 -0.127112 0.798611 3.249144\n", + "6787077 0.996873 -0.475943 0.400194 0.028088 0.785327 3.246758\n", + "6787078 0.955905 -0.482655 0.306412 -0.001943 0.805595 3.249057\n", + "6787079 0.943959 -0.494934 0.229746 -0.031343 0.813216 3.248528\n", + "6787080 0.771917 -0.483097 0.658627 0.117806 0.740952 3.211640\n", + "6787081 0.336611 -0.451773 -0.431437 -0.245727 0.728921 3.252401\n", + "6787082 0.330459 -0.435481 -0.535194 -0.267580 0.693892 3.259982\n", + 
"6787083 0.414580 0.415706 -1.041680 -0.293193 0.648885 3.497312\n", + "6787084 0.410755 0.398877 -0.901922 -0.296768 0.653152 3.492619\n", + "\n", + "[5124388 rows x 6 columns]\n" + ] + } + ], + "source": [ + "print(x_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>normEcc</th>\n", + " <th>normPol</th>\n", + " <th>Sulc0</th>\n", + " <th>Curv5</th>\n", + " <th>Area5</th>\n", + " <th>LGI5</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>count</th>\n", + " <td>5.124388e+06</td>\n", + " <td>5.124388e+06</td>\n", + " <td>5.124388e+06</td>\n", + " <td>5.124388e+06</td>\n", + " <td>5.124388e+06</td>\n", + " <td>5.124388e+06</td>\n", + " </tr>\n", + " <tr>\n", + " <th>mean</th>\n", + " <td>2.386589e-01</td>\n", + " <td>4.113428e-02</td>\n", + " <td>-1.747371e-01</td>\n", + " <td>-5.394791e-02</td>\n", + " <td>6.745831e-01</td>\n", + " <td>2.809199e+00</td>\n", + " </tr>\n", + " <tr>\n", + " <th>std</th>\n", + " <td>2.288182e-01</td>\n", + " <td>2.760123e-01</td>\n", + " <td>5.721855e-01</td>\n", + " <td>1.514886e-01</td>\n", + " <td>9.922468e-02</td>\n", + " <td>4.028898e-01</td>\n", + " </tr>\n", + " <tr>\n", + " <th>min</th>\n", + " <td>2.424506e-05</td>\n", + " <td>-4.999273e-01</td>\n", + " <td>-1.971219e+00</td>\n", + " <td>-2.368729e+00</td>\n", + " <td>2.437784e-01</td>\n", + " <td>1.959253e+00</td>\n", + " </tr>\n", + " <tr>\n", + " <th>25%</th>\n", + " <td>7.679665e-02</td>\n", + " <td>-1.714563e-01</td>\n", + " 
<td>-6.353022e-01</td>\n", + " <td>-1.814308e-01</td>\n", + " <td>6.198345e-01</td>\n", + " <td>2.474502e+00</td>\n", + " </tr>\n", + " <tr>\n", + " <th>50%</th>\n", + " <td>1.539186e-01</td>\n", + " <td>6.200975e-02</td>\n", + " <td>-2.448606e-01</td>\n", + " <td>-6.423837e-02</td>\n", + " <td>6.823665e-01</td>\n", + " <td>2.741358e+00</td>\n", + " </tr>\n", + " <tr>\n", + " <th>75%</th>\n", + " <td>3.278529e-01</td>\n", + " <td>2.621307e-01</td>\n", + " <td>2.495521e-01</td>\n", + " <td>7.672764e-02</td>\n", + " <td>7.347262e-01</td>\n", + " <td>3.102216e+00</td>\n", + " </tr>\n", + " <tr>\n", + " <th>max</th>\n", + " <td>1.000000e+00</td>\n", + " <td>4.999642e-01</td>\n", + " <td>2.009938e+00</td>\n", + " <td>1.398352e+00</td>\n", + " <td>2.217891e+00</td>\n", + " <td>4.249628e+00</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " normEcc normPol Sulc0 Curv5 Area5 \\\n", + "count 5.124388e+06 5.124388e+06 5.124388e+06 5.124388e+06 5.124388e+06 \n", + "mean 2.386589e-01 4.113428e-02 -1.747371e-01 -5.394791e-02 6.745831e-01 \n", + "std 2.288182e-01 2.760123e-01 5.721855e-01 1.514886e-01 9.922468e-02 \n", + "min 2.424506e-05 -4.999273e-01 -1.971219e+00 -2.368729e+00 2.437784e-01 \n", + "25% 7.679665e-02 -1.714563e-01 -6.353022e-01 -1.814308e-01 6.198345e-01 \n", + "50% 1.539186e-01 6.200975e-02 -2.448606e-01 -6.423837e-02 6.823665e-01 \n", + "75% 3.278529e-01 2.621307e-01 2.495521e-01 7.672764e-02 7.347262e-01 \n", + "max 1.000000e+00 4.999642e-01 2.009938e+00 1.398352e+00 2.217891e+00 \n", + "\n", + " LGI5 \n", + "count 5.124388e+06 \n", + "mean 2.809199e+00 \n", + "std 4.028898e-01 \n", + "min 1.959253e+00 \n", + "25% 2.474502e+00 \n", + "50% 2.741358e+00 \n", + "75% 3.102216e+00 \n", + "max 4.249628e+00 " + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# prints out bunch of statistics\n", + "x_train.describe()" + ] + }, + { + "cell_type": "code", + 
"execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# Create interaction terms for the predictors\n", + "interaction = PolynomialFeatures(degree=1, include_bias=True, interaction_only=False)\n", + "X_train = interaction.fit_transform(x_train)\n", + "X_test = interaction.fit_transform(x_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(5124388, 7)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_train.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# Linear models\n", + "#reg=linear_model.LinearRegression(fit_intercept=True)\n", + "#reg=linear_model.Ridge(alpha=0.9)\n", + "#reg=linear_model.Lasso(alpha=0.1)\n", + "#reg=linear_model.Lars()\n", + "\n", + "# non linear model. From simulations below, looks like ideal max_depth ~ 16 and min_samples_leaf ~ 60\n", + "#reg=tree.DecisionTreeRegressor(min_samples_leaf=60, max_depth=16)\n", + "#reg=gpm.GaussianProcessRegressor()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "__init__() got an unexpected keyword argument 'tree_method'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-16-afab54435052>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m params = {'n_estimators': 500, 'max_depth': 20, 'min_samples_split': 20,\n\u001b[1;32m 3\u001b[0m 'learning_rate': 0.01, 'loss': 'ls', 'tree_method': 'gpu_hist', 'verbose': 1}\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mreg\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mensemble\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGradientBoostingRegressor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m: __init__() got an unexpected keyword argument 'tree_method'" + ] + } + ], + "source": [ + "# Testing out Gradient Boost Decision Tree\n", + "params = {'n_estimators': 500, 'max_depth': 20, 'min_samples_split': 20,\n", + " 'learning_rate': 0.01, 'loss': 'ls', 'tree_method': 'gpu_hist', 'verbose': 1}\n", + "reg = ensemble.GradientBoostingRegressor(**params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# this is where your model is being trained\n", + "\n", + "reg.fit(x_train,y_train)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# for linear models you can enable the following comment and check out all the coefficients of yout linear model\n", + "#reg.coef_" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# the 20% of the test data that are used for testing purpose\n", + "#a=reg.predict(x_test)\n", + "b=reg.predict(x_train)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ensuring that the output has the correct dimensions\n", + "#a.shape = (a.size, 1)\n", + "b.shape = (b.size, 1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# calculate the mean square error (lower the error better yourmodel is)\n", + "#TestRMSE = np.sqrt(np.mean((a-y_test.values)**2))\n", + "#display('Test RMSE = {}'.format(TestRMSE))\n", + "\n", + "TrainRMSE = np.sqrt(np.mean((b-y_train.values)**2))\n", + "display('Train RMSE = {}'.format(TrainRMSE))\n", + "\n", + "#TestMAE = 
np.mean(np.abs(a-y_test.values))\n", + "#display('Test MAE = {}'.format(TestMAE))\n", + "\n", + "TrainMAE = np.mean(np.abs(b-y_train.values))\n", + "display('Train MAE = {}'.format(TrainMAE))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# R squared metric\n", + "display(\"Test R squared = {}\".format(r2_score(y_test,a)))\n", + "display(\"Train R squared = {}\".format(r2_score(y_train,b)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# max_depth parameter tuning for the decision tree\n", + "max_depths = np.linspace(1, 32, 32, endpoint=True)\n", + "train_results = []\n", + "test_results = []\n", + "for max_depth in max_depths:\n", + "    # Create the tree object\n", + "    dt = tree.DecisionTreeRegressor(max_depth=max_depth)\n", + "    \n", + "    # Train the tree\n", + "    dt.fit(x_train, y_train)\n", + "    \n", + "    # Predict the values for the training set\n", + "    train_pred = dt.predict(x_train)  \n", + "    train_pred.shape = (train_pred.size,1)\n", + "    \n", + "    # Calculate RMSE and add to the train_results array\n", + "    TrainRMSE = np.sqrt(np.mean((train_pred-y_train.values)**2))\n", + "    train_results.append(TrainRMSE)  \n", + "    \n", + "    # Predict the values for the test set\n", + "    test_pred = dt.predict(x_test)  \n", + "    test_pred.shape = (test_pred.size,1)\n", + "    \n", + "    # Calculate RMSE and add to the test_results array\n", + "    TestRMSE = np.sqrt(np.mean((test_pred-y_test.values)**2))\n", + "    test_results.append(TestRMSE)\n", + "\n", + "from matplotlib.legend_handler import HandlerLine2D\n", + "line1, = plt.plot(max_depths, train_results, 'b', label = \"Train RMSE\")\n", + "line2, = plt.plot(max_depths, test_results, 'r', label = \"Test RMSE\")\n", + "\n", + "plt.legend(handler_map={line1: HandlerLine2D(numpoints=2)})\n", + "plt.ylabel('RMSE')\n", + "plt.xlabel('Tree depth')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# min_samples_leaf parameter tuning for the decision tree\n", + "min_samples = np.linspace(5, 100, 20, endpoint=True)\n", + "train_results = []\n", + "test_results = []\n", + "for n in min_samples:\n", + "    # Create the tree object\n", + "    dt = tree.DecisionTreeRegressor(min_samples_leaf=n/x_train.size)\n", + "    \n", + "    # Train the tree\n", + "    dt.fit(x_train, y_train)\n", + "    \n", + "    # Predict the values for the training set\n", + "    train_pred = dt.predict(x_train)  \n", + "    train_pred.shape = (train_pred.size,1)\n", + "    \n", + "    # Calculate RMSE and add to the train_results array\n", + "    TrainRMSE = np.sqrt(np.mean((train_pred-y_train.values)**2))\n", + "    train_results.append(TrainRMSE)  \n", + "    \n", + "    # Predict the values for the test set\n", + "    test_pred = dt.predict(x_test)  \n", + "    test_pred.shape = (test_pred.size,1)\n", + "    \n", + "    # Calculate RMSE and add to the test_results array\n", + "    TestRMSE = np.sqrt(np.mean((test_pred-y_test.values)**2))\n", + "    test_results.append(TestRMSE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from matplotlib.legend_handler import HandlerLine2D\n", + "line1, = plt.plot(min_samples, train_results, 'b', label = \"Train RMSE\")\n", + "line2, = plt.plot(min_samples, test_results, 'r', label = \"Test RMSE\")\n", + "\n", + "plt.legend(handler_map={line1: HandlerLine2D(numpoints=2)})\n", + "plt.ylabel('RMSE')\n", + "plt.xlabel('Min Samples per Leaf')\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.2" + } + }, + "nbformat": 4, + 
"nbformat_minor": 2 +} diff --git a/model-py/regression.ipynb b/model-py/regression.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..f52aaa99e6b6ac45ecbf26d3064a7e68dcf05e72 --- /dev/null +++ b/model-py/regression.ipynb @@ -0,0 +1,535 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "from sklearn import linear_model\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn import tree" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.datasets import load_boston\n", + "boston =load_boston()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(506, 13)\n" + ] + } + ], + "source": [ + "print (boston.data.shape)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[6.3200e-03 1.8000e+01 2.3100e+00 ... 1.5300e+01 3.9690e+02 4.9800e+00]\n", + " [2.7310e-02 0.0000e+00 7.0700e+00 ... 1.7800e+01 3.9690e+02 9.1400e+00]\n", + " [2.7290e-02 0.0000e+00 7.0700e+00 ... 1.7800e+01 3.9283e+02 4.0300e+00]\n", + " ...\n", + " [6.0760e-02 0.0000e+00 1.1930e+01 ... 2.1000e+01 3.9690e+02 5.6400e+00]\n", + " [1.0959e-01 0.0000e+00 1.1930e+01 ... 2.1000e+01 3.9345e+02 6.4800e+00]\n", + " [4.7410e-02 0.0000e+00 1.1930e+01 ... 
2.1000e+01 3.9690e+02 7.8800e+00]]\n" + ] + } + ], + "source": [ + "print (boston.data)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# df_x are your input features\n", + "df_x=pd.DataFrame(boston.data,columns=boston.feature_names)\n", + "\n", + "# df_y is your output feature (the one you want to predict)\n", + "df_y=pd.DataFrame(boston.target)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>CRIM</th>\n", + " <th>ZN</th>\n", + " <th>INDUS</th>\n", + " <th>CHAS</th>\n", + " <th>NOX</th>\n", + " <th>RM</th>\n", + " <th>AGE</th>\n", + " <th>DIS</th>\n", + " <th>RAD</th>\n", + " <th>TAX</th>\n", + " <th>PTRATIO</th>\n", + " <th>B</th>\n", + " <th>LSTAT</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>count</th>\n", + " <td>506.000000</td>\n", + " <td>506.000000</td>\n", + " <td>506.000000</td>\n", + " <td>506.000000</td>\n", + " <td>506.000000</td>\n", + " <td>506.000000</td>\n", + " <td>506.000000</td>\n", + " <td>506.000000</td>\n", + " <td>506.000000</td>\n", + " <td>506.000000</td>\n", + " <td>506.000000</td>\n", + " <td>506.000000</td>\n", + " <td>506.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>mean</th>\n", + " <td>3.613524</td>\n", + " <td>11.363636</td>\n", + " <td>11.136779</td>\n", + " <td>0.069170</td>\n", + " <td>0.554695</td>\n", + " <td>6.284634</td>\n", + " <td>68.574901</td>\n", + " <td>3.795043</td>\n", + " <td>9.549407</td>\n", + " <td>408.237154</td>\n", 
+ " <td>18.455534</td>\n", + " <td>356.674032</td>\n", + " <td>12.653063</td>\n", + " </tr>\n", + " <tr>\n", + " <th>std</th>\n", + " <td>8.601545</td>\n", + " <td>23.322453</td>\n", + " <td>6.860353</td>\n", + " <td>0.253994</td>\n", + " <td>0.115878</td>\n", + " <td>0.702617</td>\n", + " <td>28.148861</td>\n", + " <td>2.105710</td>\n", + " <td>8.707259</td>\n", + " <td>168.537116</td>\n", + " <td>2.164946</td>\n", + " <td>91.294864</td>\n", + " <td>7.141062</td>\n", + " </tr>\n", + " <tr>\n", + " <th>min</th>\n", + " <td>0.006320</td>\n", + " <td>0.000000</td>\n", + " <td>0.460000</td>\n", + " <td>0.000000</td>\n", + " <td>0.385000</td>\n", + " <td>3.561000</td>\n", + " <td>2.900000</td>\n", + " <td>1.129600</td>\n", + " <td>1.000000</td>\n", + " <td>187.000000</td>\n", + " <td>12.600000</td>\n", + " <td>0.320000</td>\n", + " <td>1.730000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>25%</th>\n", + " <td>0.082045</td>\n", + " <td>0.000000</td>\n", + " <td>5.190000</td>\n", + " <td>0.000000</td>\n", + " <td>0.449000</td>\n", + " <td>5.885500</td>\n", + " <td>45.025000</td>\n", + " <td>2.100175</td>\n", + " <td>4.000000</td>\n", + " <td>279.000000</td>\n", + " <td>17.400000</td>\n", + " <td>375.377500</td>\n", + " <td>6.950000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>50%</th>\n", + " <td>0.256510</td>\n", + " <td>0.000000</td>\n", + " <td>9.690000</td>\n", + " <td>0.000000</td>\n", + " <td>0.538000</td>\n", + " <td>6.208500</td>\n", + " <td>77.500000</td>\n", + " <td>3.207450</td>\n", + " <td>5.000000</td>\n", + " <td>330.000000</td>\n", + " <td>19.050000</td>\n", + " <td>391.440000</td>\n", + " <td>11.360000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>75%</th>\n", + " <td>3.677083</td>\n", + " <td>12.500000</td>\n", + " <td>18.100000</td>\n", + " <td>0.000000</td>\n", + " <td>0.624000</td>\n", + " <td>6.623500</td>\n", + " <td>94.075000</td>\n", + " <td>5.188425</td>\n", + " <td>24.000000</td>\n", + " <td>666.000000</td>\n", + " <td>20.200000</td>\n", + " 
<td>396.225000</td>\n", + " <td>16.955000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>max</th>\n", + " <td>88.976200</td>\n", + " <td>100.000000</td>\n", + " <td>27.740000</td>\n", + " <td>1.000000</td>\n", + " <td>0.871000</td>\n", + " <td>8.780000</td>\n", + " <td>100.000000</td>\n", + " <td>12.126500</td>\n", + " <td>24.000000</td>\n", + " <td>711.000000</td>\n", + " <td>22.000000</td>\n", + " <td>396.900000</td>\n", + " <td>37.970000</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " CRIM ZN INDUS CHAS NOX RM \\\n", + "count 506.000000 506.000000 506.000000 506.000000 506.000000 506.000000 \n", + "mean 3.613524 11.363636 11.136779 0.069170 0.554695 6.284634 \n", + "std 8.601545 23.322453 6.860353 0.253994 0.115878 0.702617 \n", + "min 0.006320 0.000000 0.460000 0.000000 0.385000 3.561000 \n", + "25% 0.082045 0.000000 5.190000 0.000000 0.449000 5.885500 \n", + "50% 0.256510 0.000000 9.690000 0.000000 0.538000 6.208500 \n", + "75% 3.677083 12.500000 18.100000 0.000000 0.624000 6.623500 \n", + "max 88.976200 100.000000 27.740000 1.000000 0.871000 8.780000 \n", + "\n", + " AGE DIS RAD TAX PTRATIO B \\\n", + "count 506.000000 506.000000 506.000000 506.000000 506.000000 506.000000 \n", + "mean 68.574901 3.795043 9.549407 408.237154 18.455534 356.674032 \n", + "std 28.148861 2.105710 8.707259 168.537116 2.164946 91.294864 \n", + "min 2.900000 1.129600 1.000000 187.000000 12.600000 0.320000 \n", + "25% 45.025000 2.100175 4.000000 279.000000 17.400000 375.377500 \n", + "50% 77.500000 3.207450 5.000000 330.000000 19.050000 391.440000 \n", + "75% 94.075000 5.188425 24.000000 666.000000 20.200000 396.225000 \n", + "max 100.000000 12.126500 24.000000 711.000000 22.000000 396.900000 \n", + "\n", + " LSTAT \n", + "count 506.000000 \n", + "mean 12.653063 \n", + "std 7.141062 \n", + "min 1.730000 \n", + "25% 6.950000 \n", + "50% 11.360000 \n", + "75% 16.955000 \n", + "max 37.970000 " + ] + }, + "execution_count": 7, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "# prints out bunch of statistics\n", + "df_x.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [], + "source": [ + "# Linear models\n", + "#reg=linear_model.LinearRegression()\n", + "#reg=linear_model.Ridge(alpha=0.9)\n", + "reg=linear_model.Lasso(alpha=0.1)\n", + "#reg=linear_model.Lars()\n", + "\n", + "# non linear model\n", + "#reg=tree.DecisionTreeRegressor(max_depth=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [], + "source": [ + "# 0.2 refers to 20% of all input data will be used to test the model and the remaining 80% to train the model\n", + "# x_train y_train 80% of data that are used for training\n", + "# x_test y_test 20% of data that are used for testing your model\n", + "\n", + "x_train,x_test,y_train,y_test = train_test_split(df_x,df_y,test_size=0.2, random_state=4)" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Lasso(alpha=0.1, copy_X=True, fit_intercept=True, max_iter=1000,\n", + "   normalize=False, positive=False, precompute=False, random_state=None,\n", + "   selection='cyclic', tol=0.0001, warm_start=False)" + ] + }, + "execution_count": 67, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# this is where your model is being trained\n", + "\n", + "reg.fit(x_train,y_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [], + "source": [ + "# for linear models you can enable the following comment and check out all the coefficients of your linear model\n", + "#reg.coef_" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": {}, + "outputs": [], + "source": [ + "# the 20% of the test data that are used for testing purpose\n", + "a=reg.predict(x_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + 
"metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[11.35769564],\n", + " [26.63065774],\n", + " [17.07212795],\n", + " [14.88066872],\n", + " [36.41257162],\n", + " [24.9585628 ],\n", + " [31.94678858],\n", + " [18.71968836],\n", + " [18.03333259],\n", + " [24.31205853],\n", + " [29.37517697],\n", + " [28.21096667],\n", + " [19.31608322],\n", + " [29.77081668],\n", + " [22.02956911],\n", + " [15.80101535],\n", + " [21.40010518],\n", + " [11.55888929],\n", + " [10.03696639],\n", + " [14.21676695],\n", + " [ 5.93013576],\n", + " [20.67875375],\n", + " [20.28901268],\n", + " [22.045776 ],\n", + " [16.91359062],\n", + " [20.01181348],\n", + " [14.60702376],\n", + " [14.47106462],\n", + " [19.94873173],\n", + " [16.80168678],\n", + " [14.47686108],\n", + " [23.94606474],\n", + " [35.12159987],\n", + " [22.18768349],\n", + " [17.36984591],\n", + " [19.82812892],\n", + " [30.64690326],\n", + " [35.83418403],\n", + " [24.01776652],\n", + " [24.25497155],\n", + " [36.65259504],\n", + " [31.76859732],\n", + " [19.93445419],\n", + " [31.94878121],\n", + " [30.55626307],\n", + " [24.85315173],\n", + " [40.25718892],\n", + " [17.35967841],\n", + " [20.58594129],\n", + " [23.65915748],\n", + " [33.33055041],\n", + " [25.46122166],\n", + " [18.25223929],\n", + " [27.45084254],\n", + " [13.61083007],\n", + " [22.98211928],\n", + " [24.36098849],\n", + " [33.24773708],\n", + " [17.77844029],\n", + " [34.20142858],\n", + " [16.18855141],\n", + " [20.46046193],\n", + " [31.34454514],\n", + " [14.83719596],\n", + " [39.59888611],\n", + " [28.30333095],\n", + " [29.56328051],\n", + " [ 9.50186015],\n", + " [18.44151744],\n", + " [21.6004166 ],\n", + " [23.21248274],\n", + " [22.98510649],\n", + " [23.36802977],\n", + " [27.76775011],\n", + " [16.2541867 ],\n", + " [23.87249049],\n", + " [16.730009 ],\n", + " [25.39057665],\n", + " [14.1120706 ],\n", + " [19.35830505],\n", + " [22.16813444],\n", + " [19.3767848 ],\n", + " [28.29171375],\n", + " 
[20.01027133],\n", + "       [30.10784765],\n", + "       [23.22987591],\n", + "       [30.21376239],\n", + "       [19.89171106],\n", + "       [21.10474733],\n", + "       [37.4997143 ],\n", + "       [31.49817385],\n", + "       [41.24683197],\n", + "       [18.88741907],\n", + "       [37.34892447],\n", + "       [20.22051301],\n", + "       [23.61804532],\n", + "       [23.95396337],\n", + "       [22.14526984],\n", + "       [12.45469347],\n", + "       [21.69669994],\n", + "       [ 9.7580847 ],\n", + "       [25.09790195]])" + ] + }, + "execution_count": 70, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ensuring that the output has the correct dimensions\n", + "a.shape = (a.size, 1)\n", + "a" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0    26.452889\n", + "dtype: float64" + ] + }, + "execution_count": 71, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# calculate the mean square error (lower the error better your model is)\n", + "np.mean((a-y_test)**2)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}