Contents

Machine Learning | Training on the Kaggle Iris Dataset
1. Preparation: Importing the Machine Learning Libraries
2. Exploratory Data Visualization
3. Logistic Regression Classification Results
  3.1 Classifying on Sepal Features
  3.2 Classifying on Petal Features
4. Splitting the Dataset (Training Set + Test Set)
5. Model Training and Validation
  5.1 Logistic Regression
  5.2 Decision Tree
  5.3 K-Nearest Neighbors
  5.4 Support Vector Machine
6. Neural Network
7. Training on Sepal and Petal Features Separately
  7.1 Training on Sepal Features
  7.2 Training on Petal Features
Machine Learning | Training on the Kaggle Iris Dataset
Wenxuan Zeng 2020.10.3
1. Preparation: Importing the Machine Learning Libraries
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelBinarizer
from sklearn import svm
from sklearn import model_selection
from sklearn import metrics
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings("ignore")
2. Exploratory Data Visualization
This dataset is small (150 samples), has only four features, and uses three simple class labels, and the task is plain classification. So, unlike Titanic or crime-prediction problems, there is no need to dig deeply into subtle interactions between features during visualization; we can simply plot the scatter points for each pair of features in one coordinate system and inspect the classes directly.
from sklearn.datasets import load_iris
iris = load_iris()
print(iris.data)
[[5.1 3.5 1.4 0.2]
[4.9 3. 1.4 0.2]
[4.7 3.2 1.3 0.2]
[4.6 3.1 1.5 0.2]
[5. 3.6 1.4 0.2]
[5.4 3.9 1.7 0.4]
[4.6 3.4 1.4 0.3]
[5. 3.4 1.5 0.2]
[4.4 2.9 1.4 0.2]
[4.9 3.1 1.5 0.1]
[5.4 3.7 1.5 0.2]
[4.8 3.4 1.6 0.2]
[4.8 3. 1.4 0.1]
[4.3 3. 1.1 0.1]
[5.8 4. 1.2 0.2]
[5.7 4.4 1.5 0.4]
[5.4 3.9 1.3 0.4]
[5.1 3.5 1.4 0.3]
[5.7 3.8 1.7 0.3]
[5.1 3.8 1.5 0.3]
[5.4 3.4 1.7 0.2]
[5.1 3.7 1.5 0.4]
[4.6 3.6 1. 0.2]
[5.1 3.3 1.7 0.5]
[4.8 3.4 1.9 0.2]
[5. 3. 1.6 0.2]
[5. 3.4 1.6 0.4]
[5.2 3.5 1.5 0.2]
[5.2 3.4 1.4 0.2]
[4.7 3.2 1.6 0.2]
[4.8 3.1 1.6 0.2]
[5.4 3.4 1.5 0.4]
[5.2 4.1 1.5 0.1]
[5.5 4.2 1.4 0.2]
[4.9 3.1 1.5 0.2]
[5. 3.2 1.2 0.2]
[5.5 3.5 1.3 0.2]
[4.9 3.6 1.4 0.1]
[4.4 3. 1.3 0.2]
[5.1 3.4 1.5 0.2]
[5. 3.5 1.3 0.3]
[4.5 2.3 1.3 0.3]
[4.4 3.2 1.3 0.2]
[5. 3.5 1.6 0.6]
[5.1 3.8 1.9 0.4]
[4.8 3. 1.4 0.3]
[5.1 3.8 1.6 0.2]
[4.6 3.2 1.4 0.2]
[5.3 3.7 1.5 0.2]
[5. 3.3 1.4 0.2]
[7. 3.2 4.7 1.4]
[6.4 3.2 4.5 1.5]
[6.9 3.1 4.9 1.5]
[5.5 2.3 4. 1.3]
[6.5 2.8 4.6 1.5]
[5.7 2.8 4.5 1.3]
[6.3 3.3 4.7 1.6]
[4.9 2.4 3.3 1. ]
[6.6 2.9 4.6 1.3]
[5.2 2.7 3.9 1.4]
[5. 2. 3.5 1. ]
[5.9 3. 4.2 1.5]
[6. 2.2 4. 1. ]
[6.1 2.9 4.7 1.4]
[5.6 2.9 3.6 1.3]
[6.7 3.1 4.4 1.4]
[5.6 3. 4.5 1.5]
[5.8 2.7 4.1 1. ]
[6.2 2.2 4.5 1.5]
[5.6 2.5 3.9 1.1]
[5.9 3.2 4.8 1.8]
[6.1 2.8 4. 1.3]
[6.3 2.5 4.9 1.5]
[6.1 2.8 4.7 1.2]
[6.4 2.9 4.3 1.3]
[6.6 3. 4.4 1.4]
[6.8 2.8 4.8 1.4]
[6.7 3. 5. 1.7]
[6. 2.9 4.5 1.5]
[5.7 2.6 3.5 1. ]
[5.5 2.4 3.8 1.1]
[5.5 2.4 3.7 1. ]
[5.8 2.7 3.9 1.2]
[6. 2.7 5.1 1.6]
[5.4 3. 4.5 1.5]
[6. 3.4 4.5 1.6]
[6.7 3.1 4.7 1.5]
[6.3 2.3 4.4 1.3]
[5.6 3. 4.1 1.3]
[5.5 2.5 4. 1.3]
[5.5 2.6 4.4 1.2]
[6.1 3. 4.6 1.4]
[5.8 2.6 4. 1.2]
[5. 2.3 3.3 1. ]
[5.6 2.7 4.2 1.3]
[5.7 3. 4.2 1.2]
[5.7 2.9 4.2 1.3]
[6.2 2.9 4.3 1.3]
[5.1 2.5 3. 1.1]
[5.7 2.8 4.1 1.3]
[6.3 3.3 6. 2.5]
[5.8 2.7 5.1 1.9]
[7.1 3. 5.9 2.1]
[6.3 2.9 5.6 1.8]
[6.5 3. 5.8 2.2]
[7.6 3. 6.6 2.1]
[4.9 2.5 4.5 1.7]
[7.3 2.9 6.3 1.8]
[6.7 2.5 5.8 1.8]
[7.2 3.6 6.1 2.5]
[6.5 3.2 5.1 2. ]
[6.4 2.7 5.3 1.9]
[6.8 3. 5.5 2.1]
[5.7 2.5 5. 2. ]
[5.8 2.8 5.1 2.4]
[6.4 3.2 5.3 2.3]
[6.5 3. 5.5 1.8]
[7.7 3.8 6.7 2.2]
[7.7 2.6 6.9 2.3]
[6. 2.2 5. 1.5]
[6.9 3.2 5.7 2.3]
[5.6 2.8 4.9 2. ]
[7.7 2.8 6.7 2. ]
[6.3 2.7 4.9 1.8]
[6.7 3.3 5.7 2.1]
[7.2 3.2 6. 1.8]
[6.2 2.8 4.8 1.8]
[6.1 3. 4.9 1.8]
[6.4 2.8 5.6 2.1]
[7.2 3. 5.8 1.6]
[7.4 2.8 6.1 1.9]
[7.9 3.8 6.4 2. ]
[6.4 2.8 5.6 2.2]
[6.3 2.8 5.1 1.5]
[6.1 2.6 5.6 1.4]
[7.7 3. 6.1 2.3]
[6.3 3.4 5.6 2.4]
[6.4 3.1 5.5 1.8]
[6. 3. 4.8 1.8]
[6.9 3.1 5.4 2.1]
[6.7 3.1 5.6 2.4]
[6.9 3.1 5.1 2.3]
[5.8 2.7 5.1 1.9]
[6.8 3.2 5.9 2.3]
[6.7 3.3 5.7 2.5]
[6.7 3. 5.2 2.3]
[6.3 2.5 5. 1.9]
[6.5 3. 5.2 2. ]
[6.2 3.4 5.4 2.3]
[5.9 3. 5.1 1.8]]
print(iris.target)
print(len(iris.target))
print(iris.data.shape)
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2]
150
(150, 4)
The iris object contains iris.data and iris.target. iris.data is a 150x4 matrix holding the four features (sepal length/width and petal length/width), and iris.target is a length-150 vector holding each flower's class label (0, 1, or 2). The three classes have 50 flowers each, occupying the first, middle, and last 50 rows of the dataset.
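The Bunch object returned by load_iris also carries metadata that makes this layout explicit; a quick sketch for orientation:

# Feature and class names stored alongside the data
print(iris.feature_names)   # ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
print(iris.target_names)    # ['setosa' 'versicolor' 'virginica']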
# Sepal length (x[0]) vs. sepal width (x[1])
DD = iris.data
X = [x[0] for x in DD]
print(X)
Y = [x[1] for x in DD]
print(Y)
plt.scatter(X[:50], Y[:50], color='red', marker='o', label='setosa')
plt.scatter(X[50:100], Y[50:100], color='blue', marker='x', label='versicolor')
plt.scatter(X[100:], Y[100:], color='green', marker='+', label='virginica')
plt.legend(loc=2)
plt.show()
[5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9, 5.4, 4.8, 4.8, 4.3, 5.8, 5.7, 5.4, 5.1, 5.7, 5.1, 5.4, 5.1, 4.6, 5.1, 4.8, 5.0, 5.0, 5.2, 5.2, 4.7, 4.8, 5.4, 5.2, 5.5, 4.9, 5.0, 5.5, 4.9, 4.4, 5.1, 5.0, 4.5, 4.4, 5.0, 5.1, 4.8, 5.1, 4.6, 5.3, 5.0, 7.0, 6.4, 6.9, 5.5, 6.5, 5.7, 6.3, 4.9, 6.6, 5.2, 5.0, 5.9, 6.0, 6.1, 5.6, 6.7, 5.6, 5.8, 6.2, 5.6, 5.9, 6.1, 6.3, 6.1, 6.4, 6.6, 6.8, 6.7, 6.0, 5.7, 5.5, 5.5, 5.8, 6.0, 5.4, 6.0, 6.7, 6.3, 5.6, 5.5, 5.5, 6.1, 5.8, 5.0, 5.6, 5.7, 5.7, 6.2, 5.1, 5.7, 6.3, 5.8, 7.1, 6.3, 6.5, 7.6, 4.9, 7.3, 6.7, 7.2, 6.5, 6.4, 6.8, 5.7, 5.8, 6.4, 6.5, 7.7, 7.7, 6.0, 6.9, 5.6, 7.7, 6.3, 6.7, 7.2, 6.2, 6.1, 6.4, 7.2, 7.4, 7.9, 6.4, 6.3, 6.1, 7.7, 6.3, 6.4, 6.0, 6.9, 6.7, 6.9, 5.8, 6.8, 6.7, 6.7, 6.3, 6.5, 6.2, 5.9]
[3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1, 3.7, 3.4, 3.0, 3.0, 4.0, 4.4, 3.9, 3.5, 3.8, 3.8, 3.4, 3.7, 3.6, 3.3, 3.4, 3.0, 3.4, 3.5, 3.4, 3.2, 3.1, 3.4, 4.1, 4.2, 3.1, 3.2, 3.5, 3.6, 3.0, 3.4, 3.5, 2.3, 3.2, 3.5, 3.8, 3.0, 3.8, 3.2, 3.7, 3.3, 3.2, 3.2, 3.1, 2.3, 2.8, 2.8, 3.3, 2.4, 2.9, 2.7, 2.0, 3.0, 2.2, 2.9, 2.9, 3.1, 3.0, 2.7, 2.2, 2.5, 3.2, 2.8, 2.5, 2.8, 2.9, 3.0, 2.8, 3.0, 2.9, 2.6, 2.4, 2.4, 2.7, 2.7, 3.0, 3.4, 3.1, 2.3, 3.0, 2.5, 2.6, 3.0, 2.6, 2.3, 2.7, 3.0, 2.9, 2.9, 2.5, 2.8, 3.3, 2.7, 3.0, 2.9, 3.0, 3.0, 2.5, 2.9, 2.5, 3.6, 3.2, 2.7, 3.0, 2.5, 2.8, 3.2, 3.0, 3.8, 2.6, 2.2, 3.2, 2.8, 2.8, 2.7, 3.3, 3.2, 2.8, 3.0, 2.8, 3.0, 2.8, 3.8, 2.8, 2.8, 2.6, 3.0, 3.4, 3.1, 3.0, 3.1, 3.1, 3.1, 2.7, 3.2, 3.3, 3.0, 2.5, 3.0, 3.4, 3.0]
The problem with this pair is apparent: sepal length and width alone cannot cleanly separate the blue and green points, so these two features discriminate poorly between versicolor and virginica.
# Petal length (x[2]) vs. petal width (x[3])
DD = iris.data
X = [x[2] for x in DD]
print(X)
Y = [x[3] for x in DD]
print(Y)
plt.scatter(X[:50], Y[:50], color='red', marker='o', label='setosa')
plt.scatter(X[50:100], Y[50:100], color='blue', marker='x', label='versicolor')
plt.scatter(X[100:], Y[100:], color='green', marker='+', label='virginica')
plt.legend(loc=2)
plt.show()
[1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5, 1.5, 1.6, 1.4, 1.1, 1.2, 1.5, 1.3, 1.4, 1.7, 1.5, 1.7, 1.5, 1.0, 1.7, 1.9, 1.6, 1.6, 1.5, 1.4, 1.6, 1.6, 1.5, 1.5, 1.4, 1.5, 1.2, 1.3, 1.4, 1.3, 1.5, 1.3, 1.3, 1.3, 1.6, 1.9, 1.4, 1.6, 1.4, 1.5, 1.4, 4.7, 4.5, 4.9, 4.0, 4.6, 4.5, 4.7, 3.3, 4.6, 3.9, 3.5, 4.2, 4.0, 4.7, 3.6, 4.4, 4.5, 4.1, 4.5, 3.9, 4.8, 4.0, 4.9, 4.7, 4.3, 4.4, 4.8, 5.0, 4.5, 3.5, 3.8, 3.7, 3.9, 5.1, 4.5, 4.5, 4.7, 4.4, 4.1, 4.0, 4.4, 4.6, 4.0, 3.3, 4.2, 4.2, 4.2, 4.3, 3.0, 4.1, 6.0, 5.1, 5.9, 5.6, 5.8, 6.6, 4.5, 6.3, 5.8, 6.1, 5.1, 5.3, 5.5, 5.0, 5.1, 5.3, 5.5, 6.7, 6.9, 5.0, 5.7, 4.9, 6.7, 4.9, 5.7, 6.0, 4.8, 4.9, 5.6, 5.8, 6.1, 6.4, 5.6, 5.1, 5.6, 6.1, 5.6, 5.5, 4.8, 5.4, 5.6, 5.1, 5.1, 5.9, 5.7, 5.2, 5.0, 5.2, 5.4, 5.1]
[0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.4, 0.4, 0.3, 0.3, 0.3, 0.2, 0.4, 0.2, 0.5, 0.2, 0.2, 0.4, 0.2, 0.2, 0.2, 0.2, 0.4, 0.1, 0.2, 0.2, 0.2, 0.2, 0.1, 0.2, 0.2, 0.3, 0.3, 0.2, 0.6, 0.4, 0.3, 0.2, 0.2, 0.2, 0.2, 1.4, 1.5, 1.5, 1.3, 1.5, 1.3, 1.6, 1.0, 1.3, 1.4, 1.0, 1.5, 1.0, 1.4, 1.3, 1.4, 1.5, 1.0, 1.5, 1.1, 1.8, 1.3, 1.5, 1.2, 1.3, 1.4, 1.4, 1.7, 1.5, 1.0, 1.1, 1.0, 1.2, 1.6, 1.5, 1.6, 1.5, 1.3, 1.3, 1.3, 1.2, 1.4, 1.2, 1.0, 1.3, 1.2, 1.3, 1.3, 1.1, 1.3, 2.5, 1.9, 2.1, 1.8, 2.2, 2.1, 1.7, 1.8, 1.8, 2.5, 2.0, 1.9, 2.1, 2.0, 2.4, 2.3, 1.8, 2.2, 2.3, 1.5, 2.3, 2.0, 2.0, 1.8, 2.1, 1.8, 1.8, 1.8, 2.1, 1.6, 1.9, 2.0, 2.2, 1.5, 1.4, 2.3, 2.4, 1.8, 1.8, 2.1, 2.4, 2.3, 1.9, 2.3, 2.5, 2.3, 1.9, 2.0, 2.3, 1.8]
# Sepal length (x[0]) vs. petal width (x[3])
DD = iris.data
X = [x[0] for x in DD]
print(X)
Y = [x[3] for x in DD]
print(Y)
plt.scatter(X[:50], Y[:50], color='red', marker='o', label='setosa')
plt.scatter(X[50:100], Y[50:100], color='blue', marker='x', label='versicolor')
plt.scatter(X[100:], Y[100:], color='green', marker='+', label='virginica')
plt.legend(loc=2)
plt.show()
[5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9, 5.4, 4.8, 4.8, 4.3, 5.8, 5.7, 5.4, 5.1, 5.7, 5.1, 5.4, 5.1, 4.6, 5.1, 4.8, 5.0, 5.0, 5.2, 5.2, 4.7, 4.8, 5.4, 5.2, 5.5, 4.9, 5.0, 5.5, 4.9, 4.4, 5.1, 5.0, 4.5, 4.4, 5.0, 5.1, 4.8, 5.1, 4.6, 5.3, 5.0, 7.0, 6.4, 6.9, 5.5, 6.5, 5.7, 6.3, 4.9, 6.6, 5.2, 5.0, 5.9, 6.0, 6.1, 5.6, 6.7, 5.6, 5.8, 6.2, 5.6, 5.9, 6.1, 6.3, 6.1, 6.4, 6.6, 6.8, 6.7, 6.0, 5.7, 5.5, 5.5, 5.8, 6.0, 5.4, 6.0, 6.7, 6.3, 5.6, 5.5, 5.5, 6.1, 5.8, 5.0, 5.6, 5.7, 5.7, 6.2, 5.1, 5.7, 6.3, 5.8, 7.1, 6.3, 6.5, 7.6, 4.9, 7.3, 6.7, 7.2, 6.5, 6.4, 6.8, 5.7, 5.8, 6.4, 6.5, 7.7, 7.7, 6.0, 6.9, 5.6, 7.7, 6.3, 6.7, 7.2, 6.2, 6.1, 6.4, 7.2, 7.4, 7.9, 6.4, 6.3, 6.1, 7.7, 6.3, 6.4, 6.0, 6.9, 6.7, 6.9, 5.8, 6.8, 6.7, 6.7, 6.3, 6.5, 6.2, 5.9]
[0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.4, 0.4, 0.3, 0.3, 0.3, 0.2, 0.4, 0.2, 0.5, 0.2, 0.2, 0.4, 0.2, 0.2, 0.2, 0.2, 0.4, 0.1, 0.2, 0.2, 0.2, 0.2, 0.1, 0.2, 0.2, 0.3, 0.3, 0.2, 0.6, 0.4, 0.3, 0.2, 0.2, 0.2, 0.2, 1.4, 1.5, 1.5, 1.3, 1.5, 1.3, 1.6, 1.0, 1.3, 1.4, 1.0, 1.5, 1.0, 1.4, 1.3, 1.4, 1.5, 1.0, 1.5, 1.1, 1.8, 1.3, 1.5, 1.2, 1.3, 1.4, 1.4, 1.7, 1.5, 1.0, 1.1, 1.0, 1.2, 1.6, 1.5, 1.6, 1.5, 1.3, 1.3, 1.3, 1.2, 1.4, 1.2, 1.0, 1.3, 1.2, 1.3, 1.3, 1.1, 1.3, 2.5, 1.9, 2.1, 1.8, 2.2, 2.1, 1.7, 1.8, 1.8, 2.5, 2.0, 1.9, 2.1, 2.0, 2.4, 2.3, 1.8, 2.2, 2.3, 1.5, 2.3, 2.0, 2.0, 1.8, 2.1, 1.8, 1.8, 1.8, 2.1, 1.6, 1.9, 2.0, 2.2, 1.5, 1.4, 2.3, 2.4, 1.8, 1.8, 2.1, 2.4, 2.3, 1.9, 2.3, 2.5, 2.3, 1.9, 2.0, 2.3, 1.8]
# Sepal length (x[0]) vs. petal length (x[2])
DD = iris.data
X = [x[0] for x in DD]
print(X)
Y = [x[2] for x in DD]
print(Y)
plt.scatter(X[:50], Y[:50], color='red', marker='o', label='setosa')
plt.scatter(X[50:100], Y[50:100], color='blue', marker='x', label='versicolor')
plt.scatter(X[100:], Y[100:], color='green', marker='+', label='virginica')
plt.legend(loc=2)
plt.show()
[5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9, 5.4, 4.8, 4.8, 4.3, 5.8, 5.7, 5.4, 5.1, 5.7, 5.1, 5.4, 5.1, 4.6, 5.1, 4.8, 5.0, 5.0, 5.2, 5.2, 4.7, 4.8, 5.4, 5.2, 5.5, 4.9, 5.0, 5.5, 4.9, 4.4, 5.1, 5.0, 4.5, 4.4, 5.0, 5.1, 4.8, 5.1, 4.6, 5.3, 5.0, 7.0, 6.4, 6.9, 5.5, 6.5, 5.7, 6.3, 4.9, 6.6, 5.2, 5.0, 5.9, 6.0, 6.1, 5.6, 6.7, 5.6, 5.8, 6.2, 5.6, 5.9, 6.1, 6.3, 6.1, 6.4, 6.6, 6.8, 6.7, 6.0, 5.7, 5.5, 5.5, 5.8, 6.0, 5.4, 6.0, 6.7, 6.3, 5.6, 5.5, 5.5, 6.1, 5.8, 5.0, 5.6, 5.7, 5.7, 6.2, 5.1, 5.7, 6.3, 5.8, 7.1, 6.3, 6.5, 7.6, 4.9, 7.3, 6.7, 7.2, 6.5, 6.4, 6.8, 5.7, 5.8, 6.4, 6.5, 7.7, 7.7, 6.0, 6.9, 5.6, 7.7, 6.3, 6.7, 7.2, 6.2, 6.1, 6.4, 7.2, 7.4, 7.9, 6.4, 6.3, 6.1, 7.7, 6.3, 6.4, 6.0, 6.9, 6.7, 6.9, 5.8, 6.8, 6.7, 6.7, 6.3, 6.5, 6.2, 5.9]
[1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5, 1.5, 1.6, 1.4, 1.1, 1.2, 1.5, 1.3, 1.4, 1.7, 1.5, 1.7, 1.5, 1.0, 1.7, 1.9, 1.6, 1.6, 1.5, 1.4, 1.6, 1.6, 1.5, 1.5, 1.4, 1.5, 1.2, 1.3, 1.4, 1.3, 1.5, 1.3, 1.3, 1.3, 1.6, 1.9, 1.4, 1.6, 1.4, 1.5, 1.4, 4.7, 4.5, 4.9, 4.0, 4.6, 4.5, 4.7, 3.3, 4.6, 3.9, 3.5, 4.2, 4.0, 4.7, 3.6, 4.4, 4.5, 4.1, 4.5, 3.9, 4.8, 4.0, 4.9, 4.7, 4.3, 4.4, 4.8, 5.0, 4.5, 3.5, 3.8, 3.7, 3.9, 5.1, 4.5, 4.5, 4.7, 4.4, 4.1, 4.0, 4.4, 4.6, 4.0, 3.3, 4.2, 4.2, 4.2, 4.3, 3.0, 4.1, 6.0, 5.1, 5.9, 5.6, 5.8, 6.6, 4.5, 6.3, 5.8, 6.1, 5.1, 5.3, 5.5, 5.0, 5.1, 5.3, 5.5, 6.7, 6.9, 5.0, 5.7, 4.9, 6.7, 4.9, 5.7, 6.0, 4.8, 4.9, 5.6, 5.8, 6.1, 6.4, 5.6, 5.1, 5.6, 6.1, 5.6, 5.5, 4.8, 5.4, 5.6, 5.1, 5.1, 5.9, 5.7, 5.2, 5.0, 5.2, 5.4, 5.1]
# Sepal width (x[1]) vs. petal length (x[2])
DD = iris.data
X = [x[1] for x in DD]
print(X)
Y = [x[2] for x in DD]
print(Y)
plt.scatter(X[:50], Y[:50], color='red', marker='o', label='setosa')
plt.scatter(X[50:100], Y[50:100], color='blue', marker='x', label='versicolor')
plt.scatter(X[100:], Y[100:], color='green', marker='+', label='virginica')
plt.legend(loc=2)
plt.show()
[3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1, 3.7, 3.4, 3.0, 3.0, 4.0, 4.4, 3.9, 3.5, 3.8, 3.8, 3.4, 3.7, 3.6, 3.3, 3.4, 3.0, 3.4, 3.5, 3.4, 3.2, 3.1, 3.4, 4.1, 4.2, 3.1, 3.2, 3.5, 3.6, 3.0, 3.4, 3.5, 2.3, 3.2, 3.5, 3.8, 3.0, 3.8, 3.2, 3.7, 3.3, 3.2, 3.2, 3.1, 2.3, 2.8, 2.8, 3.3, 2.4, 2.9, 2.7, 2.0, 3.0, 2.2, 2.9, 2.9, 3.1, 3.0, 2.7, 2.2, 2.5, 3.2, 2.8, 2.5, 2.8, 2.9, 3.0, 2.8, 3.0, 2.9, 2.6, 2.4, 2.4, 2.7, 2.7, 3.0, 3.4, 3.1, 2.3, 3.0, 2.5, 2.6, 3.0, 2.6, 2.3, 2.7, 3.0, 2.9, 2.9, 2.5, 2.8, 3.3, 2.7, 3.0, 2.9, 3.0, 3.0, 2.5, 2.9, 2.5, 3.6, 3.2, 2.7, 3.0, 2.5, 2.8, 3.2, 3.0, 3.8, 2.6, 2.2, 3.2, 2.8, 2.8, 2.7, 3.3, 3.2, 2.8, 3.0, 2.8, 3.0, 2.8, 3.8, 2.8, 2.8, 2.6, 3.0, 3.4, 3.1, 3.0, 3.1, 3.1, 3.1, 2.7, 3.2, 3.3, 3.0, 2.5, 3.0, 3.4, 3.0]
[1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5, 1.5, 1.6, 1.4, 1.1, 1.2, 1.5, 1.3, 1.4, 1.7, 1.5, 1.7, 1.5, 1.0, 1.7, 1.9, 1.6, 1.6, 1.5, 1.4, 1.6, 1.6, 1.5, 1.5, 1.4, 1.5, 1.2, 1.3, 1.4, 1.3, 1.5, 1.3, 1.3, 1.3, 1.6, 1.9, 1.4, 1.6, 1.4, 1.5, 1.4, 4.7, 4.5, 4.9, 4.0, 4.6, 4.5, 4.7, 3.3, 4.6, 3.9, 3.5, 4.2, 4.0, 4.7, 3.6, 4.4, 4.5, 4.1, 4.5, 3.9, 4.8, 4.0, 4.9, 4.7, 4.3, 4.4, 4.8, 5.0, 4.5, 3.5, 3.8, 3.7, 3.9, 5.1, 4.5, 4.5, 4.7, 4.4, 4.1, 4.0, 4.4, 4.6, 4.0, 3.3, 4.2, 4.2, 4.2, 4.3, 3.0, 4.1, 6.0, 5.1, 5.9, 5.6, 5.8, 6.6, 4.5, 6.3, 5.8, 6.1, 5.1, 5.3, 5.5, 5.0, 5.1, 5.3, 5.5, 6.7, 6.9, 5.0, 5.7, 4.9, 6.7, 4.9, 5.7, 6.0, 4.8, 4.9, 5.6, 5.8, 6.1, 6.4, 5.6, 5.1, 5.6, 6.1, 5.6, 5.5, 4.8, 5.4, 5.6, 5.1, 5.1, 5.9, 5.7, 5.2, 5.0, 5.2, 5.4, 5.1]
# Sepal width (x[1]) vs. petal width (x[3])
DD = iris.data
X = [x[1] for x in DD]
print(X)
Y = [x[3] for x in DD]
print(Y)
plt.scatter(X[:50], Y[:50], color='red', marker='o', label='setosa')
plt.scatter(X[50:100], Y[50:100], color='blue', marker='x', label='versicolor')
plt.scatter(X[100:], Y[100:], color='green', marker='+', label='virginica')
plt.legend(loc=2)
plt.show()
[3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1, 3.7, 3.4, 3.0, 3.0, 4.0, 4.4, 3.9, 3.5, 3.8, 3.8, 3.4, 3.7, 3.6, 3.3, 3.4, 3.0, 3.4, 3.5, 3.4, 3.2, 3.1, 3.4, 4.1, 4.2, 3.1, 3.2, 3.5, 3.6, 3.0, 3.4, 3.5, 2.3, 3.2, 3.5, 3.8, 3.0, 3.8, 3.2, 3.7, 3.3, 3.2, 3.2, 3.1, 2.3, 2.8, 2.8, 3.3, 2.4, 2.9, 2.7, 2.0, 3.0, 2.2, 2.9, 2.9, 3.1, 3.0, 2.7, 2.2, 2.5, 3.2, 2.8, 2.5, 2.8, 2.9, 3.0, 2.8, 3.0, 2.9, 2.6, 2.4, 2.4, 2.7, 2.7, 3.0, 3.4, 3.1, 2.3, 3.0, 2.5, 2.6, 3.0, 2.6, 2.3, 2.7, 3.0, 2.9, 2.9, 2.5, 2.8, 3.3, 2.7, 3.0, 2.9, 3.0, 3.0, 2.5, 2.9, 2.5, 3.6, 3.2, 2.7, 3.0, 2.5, 2.8, 3.2, 3.0, 3.8, 2.6, 2.2, 3.2, 2.8, 2.8, 2.7, 3.3, 3.2, 2.8, 3.0, 2.8, 3.0, 2.8, 3.8, 2.8, 2.8, 2.6, 3.0, 3.4, 3.1, 3.0, 3.1, 3.1, 3.1, 2.7, 3.2, 3.3, 3.0, 2.5, 3.0, 3.4, 3.0]
[0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.4, 0.4, 0.3, 0.3, 0.3, 0.2, 0.4, 0.2, 0.5, 0.2, 0.2, 0.4, 0.2, 0.2, 0.2, 0.2, 0.4, 0.1, 0.2, 0.2, 0.2, 0.2, 0.1, 0.2, 0.2, 0.3, 0.3, 0.2, 0.6, 0.4, 0.3, 0.2, 0.2, 0.2, 0.2, 1.4, 1.5, 1.5, 1.3, 1.5, 1.3, 1.6, 1.0, 1.3, 1.4, 1.0, 1.5, 1.0, 1.4, 1.3, 1.4, 1.5, 1.0, 1.5, 1.1, 1.8, 1.3, 1.5, 1.2, 1.3, 1.4, 1.4, 1.7, 1.5, 1.0, 1.1, 1.0, 1.2, 1.6, 1.5, 1.6, 1.5, 1.3, 1.3, 1.3, 1.2, 1.4, 1.2, 1.0, 1.3, 1.2, 1.3, 1.3, 1.1, 1.3, 2.5, 1.9, 2.1, 1.8, 2.2, 2.1, 1.7, 1.8, 1.8, 2.5, 2.0, 1.9, 2.1, 2.0, 2.4, 2.3, 1.8, 2.2, 2.3, 1.5, 2.3, 2.0, 2.0, 1.8, 2.1, 1.8, 1.8, 1.8, 2.1, 1.6, 1.9, 2.0, 2.2, 1.5, 1.4, 2.3, 2.4, 1.8, 1.8, 2.1, 2.4, 2.3, 1.9, 2.3, 2.5, 2.3, 1.9, 2.0, 2.3, 1.8]
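Since pandas and seaborn are already imported but unused, all of the pairwise views above can also be produced in a single call. A minimal sketch (the 'species' column name is chosen here for illustration):

# Build a DataFrame from the feature matrix and attach readable class names
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['species'] = [iris.target_names[t] for t in iris.target]
sns.pairplot(df, hue='species')   # one scatter panel per feature pair, colored by class
plt.show()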
3. Logistic Regression Classification Results
3.1 Classifying on Sepal Features
X = iris.data[:, 0:2]
Y = iris.target
lr = LogisticRegression(C=1e5)
lr.fit(X, Y)

# Evaluate the classifier on a dense mesh to draw the decision regions
h = .02
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
Z = lr.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

plt.figure(1, figsize=(8, 6))
plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired)
plt.scatter(X[:50, 0], X[:50, 1], color='red', marker='o', label='setosa')
plt.scatter(X[50:100, 0], X[50:100, 1], color='blue', marker='x', label='versicolor')
plt.scatter(X[100:, 0], X[100:, 1], color='green', marker='s', label='virginica')
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.xticks(())
plt.yticks(())
plt.legend(loc=2)
plt.show()
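To put a number on the separation the plot suggests, one can check the accuracy of this sepal-only model; a sketch (note this is resubstitution accuracy on the training data itself, not a held-out estimate):

# Mean accuracy of the boundary above on the same 150 samples it was fit on
print('Training accuracy (sepal features): {0}'.format(lr.score(X, Y)))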
3.2 Classifying on Petal Features
X = iris.data[:, 2:4]
Y = iris.target
lr = LogisticRegression(C=1e5)
lr.fit(X, Y)

h = .02
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
Z = lr.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

plt.figure(1, figsize=(8, 6))
plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired)
plt.scatter(X[:50, 0], X[:50, 1], color='red', marker='o', label='setosa')
plt.scatter(X[50:100, 0], X[50:100, 1], color='blue', marker='x', label='versicolor')
plt.scatter(X[100:, 0], X[100:, 1], color='green', marker='s', label='virginica')
plt.xlabel('Petal length')
plt.ylabel('Petal width')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.xticks(())
plt.yticks(())
plt.legend(loc=2)
plt.show()
4. Splitting the Dataset (Training Set + Test Set)
x = iris.data
y = iris.target
x_train, x_test, y_train, y_test = model_selection.train_test_split(x, y, random_state=101, test_size=0.3)
print("split_train_data 70%:", x_train.shape, "split_train_target 70%:", y_train.shape,
      "split_test_data 30%", x_test.shape, "split_test_target 30%", y_test.shape)
split_train_data 70%: (105, 4) split_train_target 70%: (105,) split_test_data 30% (45, 4) split_test_target 30% (45,)
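One caveat: train_test_split samples rows at random, so the class proportions in the 45-sample test set can drift away from 50/50/50. Passing stratify=y preserves the balance in both splits; a sketch for comparison (the suffixed variable names are mine, and the runs below do not use this split):

# Stratified variant of the same split: each class keeps its 1/3 share
x_train2, x_test2, y_train2, y_test2 = model_selection.train_test_split(
    x, y, random_state=101, test_size=0.3, stratify=y)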
5. Model Training and Validation
5.1 Logistic Regression
model = LogisticRegression()
model.fit(x_train, y_train)
prediction = model.predict(x_test)
print('The accuracy of the Logistic Regression is: {0}'.format(metrics.accuracy_score(y_test, prediction)))
The accuracy of the Logistic Regression is: 0.9555555555555556
5.2 Decision Tree
model = DecisionTreeClassifier()
model.fit(x_train, y_train)
prediction = model.predict(x_test)
print('The accuracy of the DecisionTreeClassifier is: {0}'.format(metrics.accuracy_score(y_test, prediction)))
The accuracy of the DecisionTreeClassifier is: 0.9555555555555556
5.3 K-Nearest Neighbors
model = KNeighborsClassifier(n_neighbors=3)
model.fit(x_train, y_train)
prediction = model.predict(x_test)
print('The accuracy of the K-Nearest Neighbours is: {0}'.format(metrics.accuracy_score(y_test, prediction)))
The accuracy of the K-Nearest Neighbours is: 1.0
5.4 Support Vector Machine
model = svm.SVC()
model.fit(x_train, y_train)
prediction = model.predict(x_test)
print('The accuracy of the SVM is: {0}'.format(metrics.accuracy_score(y_test, prediction)))
The accuracy of the SVM is: 1.0
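A single 45-sample test set makes these accuracy figures fairly noisy. Five-fold cross-validation over all 150 samples gives a steadier comparison of the four models; a sketch using the same classifiers as above:

# Average accuracy across 5 train/test folds for each model
for name, clf in [('LogisticRegression', LogisticRegression()),
                  ('DecisionTree', DecisionTreeClassifier()),
                  ('KNN', KNeighborsClassifier(n_neighbors=3)),
                  ('SVM', svm.SVC())]:
    scores = model_selection.cross_val_score(clf, x, y, cv=5)
    print(name, scores.mean())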
6. Neural Network
x = iris.data
y = iris.target
np.random.seed(seed=7)
# One-hot encode the integer labels so they match the 3-unit softmax output
y_Label = LabelBinarizer().fit_transform(y)
x_train, x_test, y_train, y_test = model_selection.train_test_split(x, y_Label, test_size=0.3, random_state=42)

from keras.models import Sequential
from keras.layers import Dense
model = Sequential()
model.add(Dense(4, activation='relu', input_shape=(4,)))
model.add(Dense(6, activation='relu'))
model.add(Dense(3, activation='softmax'))
model.summary()
Model: "sequential_23"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_68 (Dense) (None, 4) 20
_________________________________________________________________
dense_69 (Dense) (None, 6) 30
_________________________________________________________________
dense_70 (Dense) (None, 3) 21
=================================================================
Total params: 71
Trainable params: 71
Non-trainable params: 0
_________________________________________________________________
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
step = 25
history = model.fit(x_train, y_train, validation_data=(x_test, y_test), batch_size=10, epochs=step)
train_result = history.history
Train on 105 samples, validate on 45 samples
Epoch 1/25
105/105 [==============================] - 0s 2ms/step - loss: 0.1160 - accuracy: 0.9429 - val_loss: 0.0453 - val_accuracy: 1.0000
Epoch 2/25
105/105 [==============================] - 0s 123us/step - loss: 0.1121 - accuracy: 0.9524 - val_loss: 0.0486 - val_accuracy: 1.0000
Epoch 3/25
105/105 [==============================] - 0s 114us/step - loss: 0.1116 - accuracy: 0.9429 - val_loss: 0.0496 - val_accuracy: 1.0000
Epoch 4/25
105/105 [==============================] - 0s 114us/step - loss: 0.1129 - accuracy: 0.9524 - val_loss: 0.0479 - val_accuracy: 1.0000
Epoch 5/25
105/105 [==============================] - 0s 123us/step - loss: 0.1136 - accuracy: 0.9524 - val_loss: 0.0483 - val_accuracy: 1.0000
Epoch 6/25
105/105 [==============================] - 0s 114us/step - loss: 0.1112 - accuracy: 0.9524 - val_loss: 0.0518 - val_accuracy: 1.0000
Epoch 7/25
105/105 [==============================] - 0s 114us/step - loss: 0.1110 - accuracy: 0.9429 - val_loss: 0.0505 - val_accuracy: 1.0000
Epoch 8/25
105/105 [==============================] - 0s 123us/step - loss: 0.1099 - accuracy: 0.9524 - val_loss: 0.0467 - val_accuracy: 1.0000
Epoch 9/25
105/105 [==============================] - 0s 114us/step - loss: 0.1109 - accuracy: 0.9524 - val_loss: 0.0554 - val_accuracy: 0.9778
Epoch 10/25
105/105 [==============================] - 0s 123us/step - loss: 0.1097 - accuracy: 0.9524 - val_loss: 0.0535 - val_accuracy: 1.0000
Epoch 11/25
105/105 [==============================] - 0s 114us/step - loss: 0.1091 - accuracy: 0.9524 - val_loss: 0.0452 - val_accuracy: 1.0000
Epoch 12/25
105/105 [==============================] - 0s 123us/step - loss: 0.1040 - accuracy: 0.9619 - val_loss: 0.0742 - val_accuracy: 0.9778
Epoch 13/25
105/105 [==============================] - 0s 123us/step - loss: 0.1124 - accuracy: 0.9524 - val_loss: 0.0670 - val_accuracy: 0.9778
Epoch 14/25
105/105 [==============================] - 0s 133us/step - loss: 0.1110 - accuracy: 0.9429 - val_loss: 0.0746 - val_accuracy: 0.9778
Epoch 15/25
105/105 [==============================] - 0s 114us/step - loss: 0.1105 - accuracy: 0.9429 - val_loss: 0.0545 - val_accuracy: 0.9778
Epoch 16/25
105/105 [==============================] - 0s 123us/step - loss: 0.1077 - accuracy: 0.9524 - val_loss: 0.0519 - val_accuracy: 1.0000
Epoch 17/25
105/105 [==============================] - 0s 114us/step - loss: 0.1113 - accuracy: 0.9524 - val_loss: 0.0562 - val_accuracy: 0.9778
Epoch 18/25
105/105 [==============================] - 0s 123us/step - loss: 0.1085 - accuracy: 0.9524 - val_loss: 0.0491 - val_accuracy: 1.0000
Epoch 19/25
105/105 [==============================] - 0s 123us/step - loss: 0.1101 - accuracy: 0.9524 - val_loss: 0.0548 - val_accuracy: 0.9778
Epoch 20/25
105/105 [==============================] - 0s 114us/step - loss: 0.1060 - accuracy: 0.9429 - val_loss: 0.0451 - val_accuracy: 1.0000
Epoch 21/25
105/105 [==============================] - 0s 123us/step - loss: 0.1128 - accuracy: 0.9524 - val_loss: 0.0450 - val_accuracy: 1.0000
Epoch 22/25
105/105 [==============================] - 0s 123us/step - loss: 0.1086 - accuracy: 0.9524 - val_loss: 0.0534 - val_accuracy: 0.9778
Epoch 23/25
105/105 [==============================] - 0s 123us/step - loss: 0.1065 - accuracy: 0.9524 - val_loss: 0.0455 - val_accuracy: 1.0000
Epoch 24/25
105/105 [==============================] - 0s 114us/step - loss: 0.1062 - accuracy: 0.9524 - val_loss: 0.0434 - val_accuracy: 1.0000
Epoch 25/25
105/105 [==============================] - 0s 123us/step - loss: 0.1081 - accuracy: 0.9524 - val_loss: 0.0423 - val_accuracy: 1.0000
This Keras fully connected network needs only a handful of training epochs to reach essentially 100% validation accuracy, which makes it a solid choice for the Iris dataset.
acc = train_result['accuracy']
val_acc = train_result['val_accuracy']
epochs = range(1, step + 1)
plt.plot(epochs, acc, 'b-')
plt.plot(epochs, val_acc, 'r')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.show()
t = model.predict(x_test)
resultsss = model.evaluate(x_test, y_test)
resultsss
45/45 [==============================] - 0s 44us/step
[0.5691331187884013, 0.9111111164093018]
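Because the network emits one softmax probability per class while the sklearn models above report accuracy on integer labels, taking the argmax makes the two directly comparable; a sketch:

# Collapse softmax probabilities and one-hot targets back to class indices
pred_labels = np.argmax(model.predict(x_test), axis=1)
true_labels = np.argmax(y_test, axis=1)
print(metrics.accuracy_score(true_labels, pred_labels))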
7. Training on Sepal and Petal Features Separately
7.1 Training on Sepal Features
DD = iris.data
x = DD[:, 0:2]
print(x)
y = iris.target
[[5.1 3.5]
[4.9 3. ]
[4.7 3.2]
[4.6 3.1]
[5. 3.6]
[5.4 3.9]
[4.6 3.4]
[5. 3.4]
[4.4 2.9]
[4.9 3.1]
[5.4 3.7]
[4.8 3.4]
[4.8 3. ]
[4.3 3. ]
[5.8 4. ]
[5.7 4.4]
[5.4 3.9]
[5.1 3.5]
[5.7 3.8]
[5.1 3.8]
[5.4 3.4]
[5.1 3.7]
[4.6 3.6]
[5.1 3.3]
[4.8 3.4]
[5. 3. ]
[5. 3.4]
[5.2 3.5]
[5.2 3.4]
[4.7 3.2]
[4.8 3.1]
[5.4 3.4]
[5.2 4.1]
[5.5 4.2]
[4.9 3.1]
[5. 3.2]
[5.5 3.5]
[4.9 3.6]
[4.4 3. ]
[5.1 3.4]
[5. 3.5]
[4.5 2.3]
[4.4 3.2]
[5. 3.5]
[5.1 3.8]
[4.8 3. ]
[5.1 3.8]
[4.6 3.2]
[5.3 3.7]
[5. 3.3]
[7. 3.2]
[6.4 3.2]
[6.9 3.1]
[5.5 2.3]
[6.5 2.8]
[5.7 2.8]
[6.3 3.3]
[4.9 2.4]
[6.6 2.9]
[5.2 2.7]
[5. 2. ]
[5.9 3. ]
[6. 2.2]
[6.1 2.9]
[5.6 2.9]
[6.7 3.1]
[5.6 3. ]
[5.8 2.7]
[6.2 2.2]
[5.6 2.5]
[5.9 3.2]
[6.1 2.8]
[6.3 2.5]
[6.1 2.8]
[6.4 2.9]
[6.6 3. ]
[6.8 2.8]
[6.7 3. ]
[6. 2.9]
[5.7 2.6]
[5.5 2.4]
[5.5 2.4]
[5.8 2.7]
[6. 2.7]
[5.4 3. ]
[6. 3.4]
[6.7 3.1]
[6.3 2.3]
[5.6 3. ]
[5.5 2.5]
[5.5 2.6]
[6.1 3. ]
[5.8 2.6]
[5. 2.3]
[5.6 2.7]
[5.7 3. ]
[5.7 2.9]
[6.2 2.9]
[5.1 2.5]
[5.7 2.8]
[6.3 3.3]
[5.8 2.7]
[7.1 3. ]
[6.3 2.9]
[6.5 3. ]
[7.6 3. ]
[4.9 2.5]
[7.3 2.9]
[6.7 2.5]
[7.2 3.6]
[6.5 3.2]
[6.4 2.7]
[6.8 3. ]
[5.7 2.5]
[5.8 2.8]
[6.4 3.2]
[6.5 3. ]
[7.7 3.8]
[7.7 2.6]
[6. 2.2]
[6.9 3.2]
[5.6 2.8]
[7.7 2.8]
[6.3 2.7]
[6.7 3.3]
[7.2 3.2]
[6.2 2.8]
[6.1 3. ]
[6.4 2.8]
[7.2 3. ]
[7.4 2.8]
[7.9 3.8]
[6.4 2.8]
[6.3 2.8]
[6.1 2.6]
[7.7 3. ]
[6.3 3.4]
[6.4 3.1]
[6. 3. ]
[6.9 3.1]
[6.7 3.1]
[6.9 3.1]
[5.8 2.7]
[6.8 3.2]
[6.7 3.3]
[6.7 3. ]
[6.3 2.5]
[6.5 3. ]
[6.2 3.4]
[5.9 3. ]]
x_train, x_test, y_train, y_test = model_selection.train_test_split(x, y, random_state=101, test_size=0.3)
print("split_train_data 70%:", x_train.shape, "split_train_target 70%:", y_train.shape,
      "split_test_data 30%", x_test.shape, "split_test_target 30%", y_test.shape)
split_train_data 70%: (105, 2) split_train_target 70%: (105,) split_test_data 30% (45, 2) split_test_target 30% (45,)
model = KNeighborsClassifier(n_neighbors=3)
model.fit(x_train, y_train)
prediction = model.predict(x_test)
print('The accuracy of the K-Nearest Neighbours is: {0}'.format(metrics.accuracy_score(y_test, prediction)))
The accuracy of the K-Nearest Neighbours is: 0.6444444444444445
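The sepal-only score is also sensitive to the choice of k; scanning a few values shows roughly how far sepal features alone can be pushed (a sketch):

# Test accuracy of sepal-only KNN for several neighborhood sizes
for k in (1, 3, 5, 7, 9):
    m = KNeighborsClassifier(n_neighbors=k).fit(x_train, y_train)
    print(k, m.score(x_test, y_test))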
7.2 Training on Petal Features
DD = iris.data
x = DD[:, 2:4]
print(x)
[[1.4 0.2]
[1.4 0.2]
[1.3 0.2]
[1.5 0.2]
[1.4 0.2]
[1.7 0.4]
[1.4 0.3]
[1.5 0.2]
[1.4 0.2]
[1.5 0.1]
[1.5 0.2]
[1.6 0.2]
[1.4 0.1]
[1.1 0.1]
[1.2 0.2]
[1.5 0.4]
[1.3 0.4]
[1.4 0.3]
[1.7 0.3]
[1.5 0.3]
[1.7 0.2]
[1.5 0.4]
[1. 0.2]
[1.7 0.5]
[1.9 0.2]
[1.6 0.2]
[1.6 0.4]
[1.5 0.2]
[1.4 0.2]
[1.6 0.2]
[1.6 0.2]
[1.5 0.4]
[1.5 0.1]
[1.4 0.2]
[1.5 0.2]
[1.2 0.2]
[1.3 0.2]
[1.4 0.1]
[1.3 0.2]
[1.5 0.2]
[1.3 0.3]
[1.3 0.3]
[1.3 0.2]
[1.6 0.6]
[1.9 0.4]
[1.4 0.3]
[1.6 0.2]
[1.4 0.2]
[1.5 0.2]
[1.4 0.2]
[4.7 1.4]
[4.5 1.5]
[4.9 1.5]
[4. 1.3]
[4.6 1.5]
[4.5 1.3]
[4.7 1.6]
[3.3 1. ]
[4.6 1.3]
[3.9 1.4]
[3.5 1. ]
[4.2 1.5]
[4. 1. ]
[4.7 1.4]
[3.6 1.3]
[4.4 1.4]
[4.5 1.5]
[4.1 1. ]
[4.5 1.5]
[3.9 1.1]
[4.8 1.8]
[4. 1.3]
[4.9 1.5]
[4.7 1.2]
[4.3 1.3]
[4.4 1.4]
[4.8 1.4]
[5. 1.7]
[4.5 1.5]
[3.5 1. ]
[3.8 1.1]
[3.7 1. ]
[3.9 1.2]
[5.1 1.6]
[4.5 1.5]
[4.5 1.6]
[4.7 1.5]
[4.4 1.3]
[4.1 1.3]
[4. 1.3]
[4.4 1.2]
[4.6 1.4]
[4. 1.2]
[3.3 1. ]
[4.2 1.3]
[4.2 1.2]
[4.2 1.3]
[4.3 1.3]
[3. 1.1]
[4.1 1.3]
[6. 2.5]
[5.1 1.9]
[5.9 2.1]
[5.6 1.8]
[5.8 2.2]
[6.6 2.1]
[4.5 1.7]
[6.3 1.8]
[5.8 1.8]
[6.1 2.5]
[5.1 2. ]
[5.3 1.9]
[5.5 2.1]
[5. 2. ]
[5.1 2.4]
[5.3 2.3]
[5.5 1.8]
[6.7 2.2]
[6.9 2.3]
[5. 1.5]
[5.7 2.3]
[4.9 2. ]
[6.7 2. ]
[4.9 1.8]
[5.7 2.1]
[6. 1.8]
[4.8 1.8]
[4.9 1.8]
[5.6 2.1]
[5.8 1.6]
[6.1 1.9]
[6.4 2. ]
[5.6 2.2]
[5.1 1.5]
[5.6 1.4]
[6.1 2.3]
[5.6 2.4]
[5.5 1.8]
[4.8 1.8]
[5.4 2.1]
[5.6 2.4]
[5.1 2.3]
[5.1 1.9]
[5.9 2.3]
[5.7 2.5]
[5.2 2.3]
[5. 1.9]
[5.2 2. ]
[5.4 2.3]
[5.1 1.8]]
x_train, x_test, y_train, y_test = model_selection.train_test_split(x, y, random_state=101, test_size=0.3)
print("split_train_data 70%:", x_train.shape, "split_train_target 70%:", y_train.shape,
      "split_test_data 30%", x_test.shape, "split_test_target 30%", y_test.shape)
split_train_data 70%: (105, 2) split_train_target 70%: (105,) split_test_data 30% (45, 2) split_test_target 30% (45,)
model = KNeighborsClassifier(n_neighbors=3)
model.fit(x_train, y_train)
prediction = model.predict(x_test)
print('The accuracy of the K-Nearest Neighbours is: {0}'.format(metrics.accuracy_score(y_test, prediction)))
The accuracy of the K-Nearest Neighbours is: 0.9777777777777777
This shows that, even with the same KNN model that scored perfectly on all four features, training on sepal features alone and on petal features alone yields very different accuracy. Sepal features discriminate poorly (about 64% here), petal features discriminate well (about 98%), and feeding all four features into training gives the best result, reaching 100% on this test split.