06. Training models in sklearn

Training Models in scikit-learn

In this section, we'll keep working with the dataset from the previous sections.

In the last section, we learned some of the most important classification algorithms in Machine Learning, including the following:

  • Logistic Regression
  • Neural Networks
  • Decision Trees
  • Support Vector Machines

Now we'll get the chance to use them on real data! In sklearn this is very easy: all we do is define our classifier and then use the following line to fit it to the data (which we call X and y):

classifier.fit(X, y)
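Every sklearn classifier shares this interface: `fit` learns from X and y, `predict` assigns a label to each new point, and `score` reports the mean accuracy on the data you pass it. A minimal sketch with a tiny made-up dataset (illustrative values only, not the section's data):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Tiny illustrative dataset: two features per point, binary labels
X = np.array([[0.0, 0.0], [0.1, 0.2], [1.0, 1.0], [0.9, 1.1]])
y = np.array([0, 0, 1, 1])

classifier = DecisionTreeClassifier()
classifier.fit(X, y)  # learn from the data

# predict() returns one label per row of its input
predictions = classifier.predict(np.array([[0.05, 0.1], [1.0, 0.9]]))
print(predictions)  # [0 1]

# score() returns the mean accuracy on the given data
print(classifier.score(X, y))  # 1.0
```

Swapping `DecisionTreeClassifier()` for any of the classifiers below leaves the `fit`/`predict`/`score` calls unchanged.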

Here are the main classifiers, together with the import each one needs:

Logistic Regression

from sklearn.linear_model import LogisticRegression
classifier = LogisticRegression()
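LogisticRegression also accepts optional parameters. A sketch on toy data made up for illustration: `C` is the inverse of the regularization strength (default 1.0, so larger values regularize less), `max_iter` caps the solver's iterations, and `predict_proba` returns one probability per class.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy data, made up for illustration (not the section's dataset)
X = np.array([[0.0, 0.0], [0.2, 0.1], [1.0, 1.0], [1.2, 0.9]])
y = np.array([0, 0, 1, 1])

# C is the inverse of the regularization strength (larger C = weaker
# regularization); max_iter caps the number of solver iterations
classifier = LogisticRegression(C=10.0, max_iter=1000)
classifier.fit(X, y)

new_points = np.array([[0.1, 0.0], [1.1, 1.0]])
predictions = classifier.predict(new_points)          # one label per point
probabilities = classifier.predict_proba(new_points)  # one column per class
print(predictions)
print(probabilities)
```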

Neural Networks

(Note: MLPClassifier is only available in scikit-learn version 0.18 or higher.)

from sklearn.neural_network import MLPClassifier
classifier = MLPClassifier()
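A sketch of a small network, assuming synthetic data (two well-separated clouds of points, generated for illustration). `hidden_layer_sizes=(10,)` asks for a single hidden layer of 10 neurons (the default is one layer of 100), `max_iter` raises the cap on training iterations, and `random_state` makes the random weight initialization reproducible.

```python
import numpy as np
from sklearn.neural_network import MLPClassifier

rng = np.random.RandomState(0)
# Two well-separated clouds of points (illustrative synthetic data)
X = np.vstack([rng.normal(-2.0, 0.5, size=(20, 2)),
               rng.normal(2.0, 0.5, size=(20, 2))])
y = np.array([0] * 20 + [1] * 20)

# One hidden layer with 10 neurons; fixed seed for reproducible weights
classifier = MLPClassifier(hidden_layer_sizes=(10,), max_iter=2000,
                           random_state=0)
classifier.fit(X, y)
print(classifier.score(X, y))  # accuracy on the training data
```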

Decision Trees

from sklearn.tree import DecisionTreeClassifier
classifier = DecisionTreeClassifier()
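The most common knob on a decision tree is its depth. In this sketch (synthetic data made up for illustration: points labeled 1 when they fall inside a circle), an unrestricted tree keeps splitting until every leaf is pure, which here means every training point is classified correctly, while `max_depth=2` stops after two levels of splits. `get_depth()` assumes scikit-learn 0.21 or later.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.RandomState(0)
X = rng.uniform(-1, 1, size=(100, 2))
# Label 1 for points inside a circle of radius sqrt(0.5), 0 otherwise
y = (X[:, 0] ** 2 + X[:, 1] ** 2 < 0.5).astype(int)

# fit() returns the fitted estimator, so it can be chained
deep_tree = DecisionTreeClassifier(random_state=0).fit(X, y)
shallow_tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

print(deep_tree.get_depth(), deep_tree.score(X, y))
print(shallow_tree.get_depth(), shallow_tree.score(X, y))
```

A shallower tree gives a simpler, blockier boundary; an unrestricted one can memorize the training data.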

Support Vector Machines

from sklearn.svm import SVC
classifier = SVC()
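SVC's `kernel` parameter controls the shape of the boundary; the default is `'rbf'`. This sketch uses sklearn's `make_circles` to generate synthetic data (one class inside a ring of the other, centered at the origin) and compares a linear kernel, which can only draw a straight line, with a degree-2 polynomial kernel, whose boundary is a quadratic curve that can trace a circle around the origin. `C=10.0` is an arbitrary illustrative value.

```python
from sklearn.datasets import make_circles
from sklearn.svm import SVC

# Synthetic data: inner circle (label 1) surrounded by an outer ring (label 0)
X, y = make_circles(n_samples=200, noise=0.05, factor=0.5, random_state=0)

# Straight-line boundary vs. quadratic boundary
linear_svm = SVC(kernel='linear', C=10.0).fit(X, y)
poly_svm = SVC(kernel='poly', degree=2, C=10.0).fit(X, y)

print(linear_svm.score(X, y))  # accuracy of the linear kernel
print(poly_svm.score(X, y))    # accuracy of the degree-2 kernel
```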

Example: Logistic Regression

Let's walk through an end-to-end example of training a classifier. Suppose we carry over X and y from the previous section. Then the following commands train a Logistic Regression classifier:

from sklearn.linear_model import LogisticRegression
classifier = LogisticRegression()
classifier.fit(X, y)

This gives us the following decision boundary:

[Figure: the decision boundary found by the Logistic Regression classifier]
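For Logistic Regression, the boundary is a straight line, w1*x1 + w2*x2 + b = 0, and the fitted weights are stored in `coef_` and `intercept_`. A sketch on toy data made up for illustration: it solves the line equation for one point and checks that the model assigns that point probability 0.5 for each class.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy data, made up for illustration (not the section's dataset)
X = np.array([[0.0, 0.0], [0.3, 0.1], [0.1, 0.4],
              [1.0, 1.0], [1.2, 0.8], [0.9, 1.3]])
y = np.array([0, 0, 0, 1, 1, 1])

classifier = LogisticRegression().fit(X, y)

# The learned boundary is the line w1*x1 + w2*x2 + b = 0
w1, w2 = classifier.coef_[0]
b = classifier.intercept_[0]

# Pick x1 = 0.5 and solve the boundary equation for x2
x1 = 0.5
x2 = -(w1 * x1 + b) / w2
point_on_boundary = np.array([[x1, x2]])

# On the boundary, the model is exactly undecided between the two classes
print(classifier.predict_proba(point_on_boundary))
```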

Quiz: Train your own model

Now it's your turn to shine! In the quiz below, we'll work with the following dataset (stored in data.csv):

[Figure: the quiz dataset]

Your goal is to use one of the classifiers above (Logistic Regression, Decision Trees, or Support Vector Machines) and see which one fits the data best. (Sorry, Neural Networks aren't available in the version of sklearn this quiz runs on, but we will be upgrading soon!) Click Test Run to see the graphical output of your classifier, and in the quiz underneath, enter the classifier that you think fits the data best.

Starter code:

import pandas
import numpy

# Read the data
data = pandas.read_csv('data.csv')

# Split the data into X and y
X = numpy.array(data[['x1', 'x2']])
y = numpy.array(data['y'])

# import statements for the classification algorithms
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC

# TODO: Pick an algorithm from the list:
# - Logistic Regression
# - Decision Trees
# - Support Vector Machines
# Define a classifier (bonus: specify some parameters!)
# and use it to fit the data. Make sure you name the variable "classifier".
# Click `Test Run` to see how well your algorithm fits the data!
data.csv:

x1,x2,y
0.336493583877,-0.985950993354,0.0
-0.0110425297266,-0.10552856162,1.0
0.238159509297,-0.61741666482,1.0
-0.366782883496,-0.713818716912,1.0
1.22192307438,-1.03939898614,0.0
-1.30456799971,0.59261847015,0.0
-0.407809098981,-0.509110509763,1.0
0.893188941965,1.18285985648,0.0
-0.00546337259365,-0.589551228864,1.0
0.406423768278,0.611062234636,1.0
-0.145506766722,0.0365463997206,1.0
-0.0404887876421,-0.0566500319512,1.0
1.60355997627,0.0908139379574,0.0
-0.604838450284,-0.111340204903,1.0
-0.534401237223,-1.04875779188,0.0
0.977706756346,-1.35281793296,0.0
-0.422036924523,-0.274418973593,1.0
1.69051344717,-0.929766839195,0.0
0.655534595433,-0.244533046405,1.0
0.384609916121,-0.334328465856,1.0
-0.109341027267,0.273694976361,1.0
-1.28710021847,-0.406756443289,0.0
0.435217566287,-0.192221316649,1.0
0.0555208008113,1.024011876,0.0
1.5088217057,-0.799489053235,0.0
0.75932306599,0.775189603256,0.0
0.967078497167,-0.707726241999,0.0
-0.0231301769156,1.34060202328,0.0
-0.274591142835,-0.549682228079,1.0
-1.2080749077,-1.41385342554,0.0
0.381259079564,-0.852947496234,1.0
0.404870623291,-0.38564643089,1.0
0.0173135930664,0.787433467901,1.0
-0.650474497449,0.377281547969,1.0
-0.175095703948,0.557524657143,1.0
0.090747012995,0.146764389396,1.0
-0.23406335446,-1.14282728744,0.0
-0.023240502157,0.0329251073349,1.0
-0.98177853269,-0.614024199162,0.0
0.863118366276,0.626452589641,0.0
-0.494201528321,-1.2458627184,0.0
0.560657440533,0.960463847964,0.0
0.517532460272,-1.015620433,0.0
-1.07674778462,1.64110648889,0.0
-0.40295146753,1.74395283754,0.0
1.26250128528,-0.0880456579187,0.0
-1.13554604657,0.691274079866,0.0
-1.88154070755,0.579520022541,0.0
1.61949373896,-1.16815366758,0.0
-0.167382068846,0.318140979545,1.0
-0.731515970032,-0.626052631824,1.0
0.14962052078,1.24000574432,0.0
1.16720084422,0.521580749715,0.0
-0.436063303539,0.043680311306,1.0
-0.827638902506,0.275166403707,1.0
1.36953107467,0.971233523422,0.0
0.690612759144,-1.27804624607,0.0
1.26986688391,0.575808793135,0.0
0.208866020688,-0.146742455013,1.0
-0.437203222578,0.52116507147,1.0
-0.378363762158,-0.0769780148552,1.0
-0.423820115256,-0.836137209863,1.0
-0.560756181289,-0.41037775405,1.0
0.336052960763,-0.224802048045,1.0
-1.33543072512,-0.990358481473,0.0
-0.0289733996866,0.441010128386,1.0
-1.3193906415,-0.37764709941,0.0
-0.808411080806,1.2283790386,0.0
1.35995943884,1.12161870845,0.0
-0.872069364163,-0.252522725967,1.0
-1.88887517471,0.144098536459,0.0
1.60845822722,-0.774759253864,0.0
-0.358639909549,0.784305199745,1.0
0.520332593218,-0.62185400704,1.0
0.306204273961,0.25448089669,1.0
-1.51072939376,0.00594704976351,0.0
0.956067338203,-0.533023015577,0.0
0.288866739458,-0.725155662248,1.0
0.403468553933,-1.75945770781,0.0
0.0859415686163,-0.958846823471,1.0
0.381957047469,0.0124143718471,1.0
0.336004016976,-0.259620737798,1.0
1.02869639688,-0.785051442286,0.0
-0.181058441906,0.00266871780379,1.0
0.279139768315,0.148068778283,1.0
-0.700587484192,0.118422440942,1.0
-0.474343136475,-0.162548759675,1.0
-1.29581526521,0.755926314388,0.0
0.140673267698,-1.60264376179,0.0
0.328196143279,0.444738575921,1.0
-0.940761503292,-1.00437673463,0.0
0.4177654822,1.11423358886,0.0
-0.802874871784,-1.27790346857,0.0
-0.596842011584,0.593623894204,1.0
-0.112331263254,0.174318514314,1.0
-1.45753325136,-1.30679050369,0.0
1.63561447039,0.27394296313,0.0
0.113120402388,0.0204651461722,1.0
0.753405102224,0.1938301221,1.0
-0.538129041247,-0.000723035827331,1.0
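Solution:

import pandas
import numpy

# Read the data and split it into X and y (as in the starter code)
data = pandas.read_csv('data.csv')
X = numpy.array(data[['x1', 'x2']])
y = numpy.array(data['y'])

# import statements for the classification algorithms
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC

# Each block below redefines "classifier", so only the last one (the SVC)
# is kept. To test a different algorithm, keep only its block.

# Logistic Regression Classifier
classifier = LogisticRegression()
classifier.fit(X, y)

# Decision Tree Classifier
classifier = DecisionTreeClassifier()
classifier.fit(X, y)

# Support Vector Machine Classifier
classifier = SVC()
classifier.fit(X, y)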
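To compare the three classifiers with a number rather than a plot, `score` gives each one's accuracy. This sketch does not load the quiz's data.csv; it swaps in synthetic data from sklearn's `make_circles` (one class surrounded by a ring of the other), an assumption chosen to illustrate a dataset that no straight line can separate. Note that these are accuracies on the training data, so they measure how well each model fits, not how well it generalizes.

```python
from sklearn.datasets import make_circles
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC

# Synthetic stand-in for the quiz data (assumption, for illustration only)
X, y = make_circles(n_samples=200, noise=0.05, factor=0.5, random_state=0)

classifiers = {
    'Logistic Regression': LogisticRegression(),
    'Decision Tree': DecisionTreeClassifier(random_state=0),
    'Support Vector Machine': SVC(),
}

scores = {}
for name, classifier in classifiers.items():
    classifier.fit(X, y)
    scores[name] = classifier.score(X, y)  # accuracy on the training data
    print(f"{name}: {scores[name]:.2f}")
```

Logistic Regression is limited to a straight-line boundary, while the tree's axis-aligned splits and the SVC's default RBF kernel can bend around the inner class.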


Training Models Quiz

Which of the previous algorithms managed to fit the data well?

SOLUTION:
  • Decision Tree
  • Support Vector Machine