{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "xRlGzpOI8eO0" }, "source": [ "\n", "[![AnalyticsDojo](https://github.com/rpi-techfundamentals/spring2019-materials/blob/master/fig/final-logo.png?raw=1)](http://rpi.analyticsdojo.com)\n", "

Titanic PCA

\n", "

introml.analyticsdojo.com

\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "-XovA71E3XFM" }, "source": [ "# Titanic PCA" ] }, { "cell_type": "markdown", "metadata": { "id": "7pW1UhJT8ePk" }, "source": [ "As an example of how to work with both categorical and numerical data, we will perform survival prediction for the passengers of the RMS Titanic.\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bvj3Wids8ePm", "outputId": "3c075657-ff1a-424c-9757-3eb6ef1c2b18" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Index(['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp',\n", " 'Parch', 'Ticket', 'Fare', 'Cabin', 'Embarked'],\n", " dtype='object') Index(['PassengerId', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp', 'Parch',\n", " 'Ticket', 'Fare', 'Cabin', 'Embarked'],\n", " dtype='object')\n" ] } ], "source": [ "import os\n", "import pandas as pd\n", "train = pd.read_csv('https://raw.githubusercontent.com/rpi-techfundamentals/spring2019-materials/master/input/train.csv')\n", "test = pd.read_csv('https://raw.githubusercontent.com/rpi-techfundamentals/spring2019-materials/master/input/test.csv')\n", "\n", "print(train.columns, test.columns)" ] }, { "cell_type": "markdown", "metadata": { "id": "0xqjk2-P8ePp" }, "source": [ "Here is a broad description of the keys and what they mean (the last three, `boat`, `body`, and `home.dest`, do not appear among the columns printed above):\n", "\n", "```\n",
"pclass          Passenger Class\n",
"                (1 = 1st; 2 = 2nd; 3 = 3rd)\n",
"survival        Survival\n",
"                (0 = No; 1 = Yes)\n",
"name            Name\n",
"sex             Sex\n",
"age             Age\n",
"sibsp           Number of Siblings/Spouses Aboard\n",
"parch           Number of Parents/Children Aboard\n",
"ticket          Ticket Number\n",
"fare            Passenger Fare\n",
"cabin           Cabin\n",
"embarked        Port of Embarkation\n",
"                (C = Cherbourg; Q = Queenstown; S = Southampton)\n",
"boat            Lifeboat\n",
"body            Body Identification Number\n",
"home.dest       Home/Destination\n",
"```\n", "\n", "In general, it looks like `name`, `sex`, `cabin`, `embarked`, `boat`, `body`, and 
`home.dest` may be candidates for categorical features, while the rest appear to be numerical features. We can also look at the first few rows of the dataset to get a better understanding:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 548 }, "id": "bqmMR9G78ePr", "outputId": "b8b2ab48-ac65-48a9-f7ff-4575b1d63f8d" }, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table>\n",
"<thead>\n",
"<tr><th></th><th>PassengerId</th><th>Survived</th><th>Pclass</th><th>Name</th><th>Sex</th><th>Age</th><th>SibSp</th><th>Parch</th><th>Ticket</th><th>Fare</th><th>Cabin</th><th>Embarked</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>1</td><td>0</td><td>3</td><td>Braund, Mr. Owen Harris</td><td>male</td><td>22.0</td><td>1</td><td>0</td><td>A/5 21171</td><td>7.2500</td><td>NaN</td><td>S</td></tr>\n",
"<tr><th>1</th><td>2</td><td>1</td><td>1</td><td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td><td>female</td><td>38.0</td><td>1</td><td>0</td><td>PC 17599</td><td>71.2833</td><td>C85</td><td>C</td></tr>\n",
"<tr><th>2</th><td>3</td><td>1</td><td>3</td><td>Heikkinen, Miss. Laina</td><td>female</td><td>26.0</td><td>0</td><td>0</td><td>STON/O2. 3101282</td><td>7.9250</td><td>NaN</td><td>S</td></tr>\n",
"<tr><th>3</th><td>4</td><td>1</td><td>1</td><td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td><td>female</td><td>35.0</td><td>1</td><td>0</td><td>113803</td><td>53.1000</td><td>C123</td><td>S</td></tr>\n",
"<tr><th>4</th><td>5</td><td>0</td><td>3</td><td>Allen, Mr. William Henry</td><td>male</td><td>35.0</td><td>0</td><td>0</td><td>373450</td><td>8.0500</td><td>NaN</td><td>S</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
\n" ], "text/plain": [ " PassengerId Survived Pclass \\\n", "0 1 0 3 \n", "1 2 1 1 \n", "2 3 1 3 \n", "3 4 1 1 \n", "4 5 0 3 \n", "\n", " Name Sex Age SibSp \\\n", "0 Braund, Mr. Owen Harris male 22.0 1 \n", "1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n", "2 Heikkinen, Miss. Laina female 26.0 0 \n", "3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n", "4 Allen, Mr. William Henry male 35.0 0 \n", "\n", " Parch Ticket Fare Cabin Embarked \n", "0 0 A/5 21171 7.2500 NaN S \n", "1 0 PC 17599 71.2833 C85 C \n", "2 0 STON/O2. 3101282 7.9250 NaN S \n", "3 0 113803 53.1000 C123 S \n", "4 0 373450 8.0500 NaN S " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.head()" ] }, { "cell_type": "markdown", "metadata": { "id": "54WY6zD78ePv" }, "source": [ "### Preprocessing function\n", "\n", "We want a single preprocessing function that applies the same transformations (imputation and one-hot encoding) to both our train and test sets. " ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FKX26KU34Ti6", "outputId": "a9bd50ee-68c8-4ffb-f077-77c700aecbed" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total missing values before processing: 179\n", "Total missing values after processing: 0\n", "Total missing values before processing: 87\n", "Total missing values after processing: 0\n" ] } ], "source": [
"from sklearn.impute import SimpleImputer\n",
"import numpy as np\n",
"\n",
"cat_features = ['Pclass', 'Sex', 'Embarked']\n",
"num_features = ['Age', 'SibSp', 'Parch', 'Fare']\n",
"\n",
"\n",
"def preprocess(df, num_features, cat_features, dv):\n",
"    features = cat_features + num_features\n",
"    # Return the dependent variable only when it is present (the test set has none).\n",
"    if dv in df.columns:\n",
"        y = df[dv]\n",
"    else:\n",
"        y = None\n",
"    # Address missing values.\n",
"    print(\"Total missing values before processing:\", df[features].isna().sum().sum())\n",
"\n",
"    # Categorical features: fill with the most frequent value.\n",
"    imp_mode = SimpleImputer(missing_values=np.nan, strategy='most_frequent')\n",
"    df[cat_features] = imp_mode.fit_transform(df[cat_features])\n",
"    # Numerical features: fill with the column mean.\n",
"    imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')\n",
"    df[num_features] = imp_mean.fit_transform(df[num_features])\n",
"    print(\"Total missing values after processing:\", df[features].isna().sum().sum())\n",
"\n",
"    # One-hot encode the categorical features, dropping the first level of each.\n",
"    X = pd.get_dummies(df[features], columns=cat_features, drop_first=True)\n",
"    return y, X\n",
"\n",
"y, X = preprocess(train, num_features, cat_features, 'Survived')\n",
"test_y, test_X = preprocess(test, num_features, cat_features, 'Survived')" ] }, { "cell_type": "markdown", "metadata": { "id": "yIEMMxHGwEXG" }, "source": [ "# PCA Analysis\n", "\n", "See [Documentation](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html). \n", "\n", "You can specify PCA either by the number of components to keep or by the fraction of variance to be explained." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "pRaYU2YCvyNw" }, "outputs": [], "source": [
"from sklearn.decomposition import PCA\n",
"# An integer n_components keeps exactly that many components.\n",
"pca = PCA(n_components=5)\n",
"pca.fit(X)\n",
"X2=pca.transform(X)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Q_NMT2q9v5e2", "outputId": "ddfa29b0-f907-4fc8-ca84-a30731610608" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2.47107661e+03 1.67651481e+02 1.25165106e+00 4.73653673e-01\n", " 3.18808533e-01]\n" ] } ], "source": [
"#This indicates the amount of variance explained by each of the principal components.\n",
"print(pca.explained_variance_)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "6P5EwsE5wROv" }, "outputs": [], "source": [
"from sklearn.decomposition import PCA\n",
"# A float between 0 and 1 keeps enough components to explain that fraction of the variance.\n",
"pca2 = PCA(n_components=0.97)\n",
"pca2.fit(X)\n",
"X3=pca2.transform(X)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "o7jTEHLBzIzw", "outputId": "b29c6774-34c8-4a44-f03f-21824a1d319a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2471.07660618 167.65148116]\n" ] } ], "source": [ "print(pca2.explained_variance_)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "YUBZDEqxv8s1", "outputId": "7d115fee-2c83-4b9e-cac0-fd3e8f31e6d5" }, "outputs": [ { "data": { "text/plain": [ "array([[1.00000000e+00, 1.90521771e-16],\n", " [1.90521771e-16, 1.00000000e+00]])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Correlation matrix of the retained components; off-diagonal values near zero show they are uncorrelated.\n",
"corr_data = np.corrcoef(X3.T)\n",
"corr_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Elbow Plot and Kaiser's Rule Cutoff\n", "\n", "[Here](https://docs.displayr.com/wiki/Kaiser_Rule) is a link to documentation of Kaiser's Rule, which retains the components whose eigenvalue is greater than 1. \n" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data passed Bartlett’s test for sphericity.\n", "Performing PCA using rotation: quartimax factors: 4 and standardization: False\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
"from factor_analyzer import FactorAnalyzer\n",
"from factor_analyzer.factor_analyzer import calculate_bartlett_sphericity\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"def scree_plot(eigvals):\n",
"    fig = plt.figure(figsize=(8,5))\n",
"    sing_vals = np.arange(len(eigvals)) + 1\n",
"    plt.plot(sing_vals, eigvals, 'ro-', linewidth=2)\n",
"    # Horizontal line at eigenvalue = 1 (Kaiser's Rule cutoff)\n",
"    horiz_line_data = np.array([1 for i in range(len(sing_vals))])\n",
"    plt.plot(sing_vals, horiz_line_data, 'r--')\n",
"    plt.title('Scree Plot for PCA')\n",
"    plt.xlabel('Principal Component')\n",
"    plt.ylabel('Eigenvalue')\n",
"    # Use a compact, semi-transparent legend so it does not cover up the data.\n",
"    leg = plt.legend(['Eigenvalues from PCA', 'Kaisers Rule Cutoff'], loc='best', borderpad=0.3,\n",
"                     shadow=False, prop=matplotlib.font_manager.FontProperties(size='small'),\n",
"                     markerscale=0.4)\n",
"    leg.get_frame().set_alpha(0.4)\n",
"\n",
"    #plt.savefig(os.path.join(save_dir / (name +'.jpg')))\n",
"    return plt\n",
"\n",
"def pca_workflow(X, factors=-1, standardize=False, rotation='quartimax'):\n",
"    \"\"\"\n",
"    This will perform factor analysis, calculating the number of factors.\n",
"    Printing scree plots, etc.\n",
"    \"\"\"\n",
"\n",
"    chi_square_value,p_value=calculate_bartlett_sphericity(X)\n",
"\n",
"    if round(p_value,2)<=0.05:\n",
"        print(\"Data passed Bartlett’s test for sphericity.\")\n",
"    else:\n",
"        print(\"Data failed Bartlett’s test for sphericity, use PCA with caution.\")\n",
"\n",
"    # The eigenvalues are always needed for the scree plot, so compute them\n",
"    # even when the caller supplies the number of factors.\n",
"    fa = FactorAnalyzer(n_factors=X.shape[1], rotation=None, method='ml')\n",
"    fa.fit(X)\n",
"    # Check Eigenvalues\n",
"    ev, v = fa.get_eigenvalues()\n",
"    if factors ==-1:\n",
"        #set the number of factors as where Eigenvalue > 1.0\n",
"        factors = np.sum(ev>1.0)\n",
"    print (\"Performing PCA using rotation:\", rotation, \" factors: \", factors, \"and standardization: \", standardize)\n",
"    loading_cols=['F'+str(x+1) for x in range(factors)]\n",
"    plot=scree_plot(ev)\n",
"\n",
"    if standardize:\n",
"        X = StandardScaler().fit_transform(X)\n",
"\n",
"    fa = FactorAnalyzer(n_factors=factors, method='principal', rotation=rotation)\n",
"    fa.fit(X)\n",
"\n",
"    #Change it back to a dataframe.\n",
"    results=pd.DataFrame(fa.transform(X),columns=loading_cols)\n",
"\n",
"    return results\n",
"\n",
"X4= pca_workflow(X)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table>\n",
"<thead>\n",
"<tr><th></th><th>Age</th><th>SibSp</th><th>Parch</th><th>Fare</th><th>Pclass_2</th><th>Pclass_3</th><th>Sex_male</th><th>Embarked_Q</th><th>Embarked_S</th><th>F1</th><th>F2</th><th>F3</th><th>F4</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>Age</th><td>1.000000</td><td>-0.232625</td><td>-0.179191</td><td>0.091566</td><td>0.006589</td><td>-0.281004</td><td>0.084153</td><td>-0.013855</td><td>-0.019336</td><td>-5.866885e-01</td><td>3.811172e-02</td><td>-3.979659e-02</td><td>4.697929e-01</td></tr>\n",
"<tr><th>SibSp</th><td>-0.232625</td><td>1.000000</td><td>0.414838</td><td>0.159651</td><td>-0.055932</td><td>0.092548</td><td>-0.114631</td><td>-0.026354</td><td>0.068734</td><td>7.266238e-01</td><td>1.159037e-01</td><td>1.523544e-01</td><td>9.665757e-02</td></tr>\n",
"<tr><th>Parch</th><td>-0.179191</td><td>0.414838</td><td>1.000000</td><td>0.216225</td><td>-0.000734</td><td>0.015790</td><td>-0.245489</td><td>-0.081228</td><td>0.060814</td><td>7.530177e-01</td><td>9.012244e-02</td><td>1.703344e-02</td><td>2.089974e-01</td></tr>\n",
"<tr><th>Fare</th><td>0.091566</td><td>0.159651</td><td>0.216225</td><td>1.000000</td><td>-0.118557</td><td>-0.413333</td><td>-0.182333</td><td>-0.117216</td><td>-0.162184</td><td>1.894779e-01</td><td>-3.740721e-02</td><td>1.871599e-02</td><td>8.756367e-01</td></tr>\n",
"<tr><th>Pclass_2</th><td>0.006589</td><td>-0.055932</td><td>-0.000734</td><td>-0.118557</td><td>1.000000</td><td>-0.565210</td><td>-0.064746</td><td>-0.127301</td><td>0.189980</td><td>4.552872e-04</td><td>1.211537e-01</td><td>-9.175741e-01</td><td>-1.588046e-01</td></tr>\n",
"<tr><th>Pclass_3</th><td>-0.281004</td><td>0.092548</td><td>0.015790</td><td>-0.413333</td><td>-0.565210</td><td>1.000000</td><td>0.137143</td><td>0.237449</td><td>-0.015104</td><td>1.356613e-01</td><td>-8.931807e-02</td><td>7.454824e-01</td><td>-5.398062e-01</td></tr>\n",
"<tr><th>Sex_male</th><td>0.084153</td><td>-0.114631</td><td>-0.245489</td><td>-0.182333</td><td>-0.064746</td><td>0.137143</td><td>1.000000</td><td>-0.074115</td><td>0.119224</td><td>-4.686491e-01</td><td>3.593321e-01</td><td>3.241968e-01</td><td>-1.889437e-01</td></tr>\n",
"<tr><th>Embarked_Q</th><td>-0.013855</td><td>-0.026354</td><td>-0.081228</td><td>-0.117216</td><td>-0.127301</td><td>0.237449</td><td>-0.074115</td><td>1.000000</td><td>-0.499421</td><td>-3.289038e-02</td><td>-8.285253e-01</td><td>1.156588e-01</td><td>-2.081933e-01</td></tr>\n",
"<tr><th>Embarked_S</th><td>-0.019336</td><td>0.068734</td><td>0.060814</td><td>-0.162184</td><td>0.189980</td><td>-0.015104</td><td>0.119224</td><td>-0.499421</td><td>1.000000</td><td>6.820785e-02</td><td>8.340531e-01</td><td>-1.170429e-01</td><td>-2.047724e-01</td></tr>\n",
"<tr><th>F1</th><td>-0.586688</td><td>0.726624</td><td>0.753018</td><td>0.189478</td><td>0.000455</td><td>0.135661</td><td>-0.468649</td><td>-0.032890</td><td>0.068208</td><td>1.000000e+00</td><td>-8.423241e-16</td><td>-1.197446e-16</td><td>4.285137e-16</td></tr>\n",
"<tr><th>F2</th><td>0.038112</td><td>0.115904</td><td>0.090122</td><td>-0.037407</td><td>0.121154</td><td>-0.089318</td><td>0.359332</td><td>-0.828525</td><td>0.834053</td><td>-8.423241e-16</td><td>1.000000e+00</td><td>-1.106983e-15</td><td>4.455845e-16</td></tr>\n",
"<tr><th>F3</th><td>-0.039797</td><td>0.152354</td><td>0.017033</td><td>0.018716</td><td>-0.917574</td><td>0.745482</td><td>0.324197</td><td>0.115659</td><td>-0.117043</td><td>-1.197446e-16</td><td>-1.106983e-15</td><td>1.000000e+00</td><td>-8.211414e-16</td></tr>\n",
"<tr><th>F4</th><td>0.469793</td><td>0.096658</td><td>0.208997</td><td>0.875637</td><td>-0.158805</td><td>-0.539806</td><td>-0.188944</td><td>-0.208193</td><td>-0.204772</td><td>4.285137e-16</td><td>4.455845e-16</td><td>-8.211414e-16</td><td>1.000000e+00</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " Age SibSp Parch Fare Pclass_2 Pclass_3 \\\n", "Age 1.000000 -0.232625 -0.179191 0.091566 0.006589 -0.281004 \n", "SibSp -0.232625 1.000000 0.414838 0.159651 -0.055932 0.092548 \n", "Parch -0.179191 0.414838 1.000000 0.216225 -0.000734 0.015790 \n", "Fare 0.091566 0.159651 0.216225 1.000000 -0.118557 -0.413333 \n", "Pclass_2 0.006589 -0.055932 -0.000734 -0.118557 1.000000 -0.565210 \n", "Pclass_3 -0.281004 0.092548 0.015790 -0.413333 -0.565210 1.000000 \n", "Sex_male 0.084153 -0.114631 -0.245489 -0.182333 -0.064746 0.137143 \n", "Embarked_Q -0.013855 -0.026354 -0.081228 -0.117216 -0.127301 0.237449 \n", "Embarked_S -0.019336 0.068734 0.060814 -0.162184 0.189980 -0.015104 \n", "F1 -0.586688 0.726624 0.753018 0.189478 0.000455 0.135661 \n", "F2 0.038112 0.115904 0.090122 -0.037407 0.121154 -0.089318 \n", "F3 -0.039797 0.152354 0.017033 0.018716 -0.917574 0.745482 \n", "F4 0.469793 0.096658 0.208997 0.875637 -0.158805 -0.539806 \n", "\n", " Sex_male Embarked_Q Embarked_S F1 F2 \\\n", "Age 0.084153 -0.013855 -0.019336 -5.866885e-01 3.811172e-02 \n", "SibSp -0.114631 -0.026354 0.068734 7.266238e-01 1.159037e-01 \n", "Parch -0.245489 -0.081228 0.060814 7.530177e-01 9.012244e-02 \n", "Fare -0.182333 -0.117216 -0.162184 1.894779e-01 -3.740721e-02 \n", "Pclass_2 -0.064746 -0.127301 0.189980 4.552872e-04 1.211537e-01 \n", "Pclass_3 0.137143 0.237449 -0.015104 1.356613e-01 -8.931807e-02 \n", "Sex_male 1.000000 -0.074115 0.119224 -4.686491e-01 3.593321e-01 \n", "Embarked_Q -0.074115 1.000000 -0.499421 -3.289038e-02 -8.285253e-01 \n", "Embarked_S 0.119224 -0.499421 1.000000 6.820785e-02 8.340531e-01 \n", "F1 -0.468649 -0.032890 0.068208 1.000000e+00 -8.423241e-16 \n", "F2 0.359332 -0.828525 0.834053 -8.423241e-16 1.000000e+00 \n", "F3 0.324197 0.115659 -0.117043 -1.197446e-16 -1.106983e-15 \n", "F4 -0.188944 -0.208193 -0.204772 4.285137e-16 4.455845e-16 \n", "\n", " F3 F4 \n", "Age -3.979659e-02 4.697929e-01 \n", "SibSp 1.523544e-01 9.665757e-02 \n", 
"Parch 1.703344e-02 2.089974e-01 \n", "Fare 1.871599e-02 8.756367e-01 \n", "Pclass_2 -9.175741e-01 -1.588046e-01 \n", "Pclass_3 7.454824e-01 -5.398062e-01 \n", "Sex_male 3.241968e-01 -1.889437e-01 \n", "Embarked_Q 1.156588e-01 -2.081933e-01 \n", "Embarked_S -1.170429e-01 -2.047724e-01 \n", "F1 -1.197446e-16 4.285137e-16 \n", "F2 -1.106983e-15 4.455845e-16 \n", "F3 1.000000e+00 -8.211414e-16 \n", "F4 -8.211414e-16 1.000000e+00 " ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "X4ALL = pd.concat([X,X4], axis =1) \n", "\n", "import seaborn as sb\n", "corr = X4ALL.corr()\n", "sb.heatmap(corr, cmap=\"Reds\")\n", "corr" ] }, { "cell_type": "markdown", "metadata": { "id": "bV5s-bSMJPne" }, "source": [ "### Train Test Split\n", "\n", "Now we are ready to model. We are going to separate our Kaggle given data into a \"Train\" and a \"Validation\" set. \n", "\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "icKFkwZQpvCs" }, "outputs": [], "source": [ "#Import Module\n", "from sklearn.model_selection import train_test_split\n", "train_X, val_X, train_y, val_y = train_test_split(X, y, train_size=0.7, test_size=0.3, random_state=122,stratify=y)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "jGoUxc7brPIg" }, "outputs": [], "source": [ "from sklearn.neural_network import MLPClassifier\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.svm import SVC\n", "from sklearn.gaussian_process import GaussianProcessClassifier\n", "from sklearn.gaussian_process.kernels import RBF\n", "from sklearn.tree import DecisionTreeClassifier\n", "from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier\n", "from sklearn.naive_bayes import GaussianNB\n", "from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis\n", "from sklearn import metrics" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6kHwslmYrcRw", "outputId": "77bc3c4f-c5ec-4240-9faf-1905c8d0d129" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Metrics score train: 0.7929373996789727\n", "Metrics score validation: 0.8134328358208955\n" ] } ], "source": [ "from sklearn import tree\n", "classifier = tree.DecisionTreeClassifier(max_depth=3)\n", "#This fits the model object to the data.\n", 
"classifier.fit(train_X[['Age','Sex_male']], train_y)\n", "#This creates the prediction. \n", "train_y_pred = classifier.predict(train_X[['Age','Sex_male']])\n", "val_y_pred = classifier.predict(val_X[['Age','Sex_male']])\n", "test['Survived'] = classifier.predict(test_X[['Age','Sex_male']])\n", "print(\"Metrics score train: \", metrics.accuracy_score(train_y, train_y_pred) )\n", "print(\"Metrics score validation: \", metrics.accuracy_score(val_y, val_y_pred) )" ] } ], "metadata": { "colab": { "collapsed_sections": [], "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 1 }