1 sklearn的学习笔记--决策树( 三 )


如下是数据集的展示:
# 输出数据的头部数据 , 大概5个 , 也可以在head中写入需要展示的数据个数taitanic_data.head()# 输出数据的信息taitanic_data.info()# info中的信息更为具体 , 可以看见对应的列数和其格式还有是否为空 , 对后续的数据集操作都有比较好的效果和帮助
在上图中可以发现 , 其中存在一定的缺失值 , age确实的内容较少且age确实会对是否存活有着比较大的影响 。类似于Cabin这种无效的且缺失严重的内容可以将其抛弃 。类似对于是否存活无关的数据还有Name和等信息 。
# 删除一些无用的数值来进行 , 可以提前判断某些内容是否影响对应的结果taitanic_data.drop(['Cabin','Name','Ticket'],inplace=True,axis=1)
如上图就是经过drop操作后的数据 , 除去了一些无用的特征信息来使得数据更为精纯 。
# 处理缺失值 , 对缺失值较多的列进行填补 , 如果有些特征确实缺少太多数据 , 完全可以删除对应的记录taitanic_data['Age'] = taitanic_data['Age'].fillna(taitanic_data['Age'].mean())taitanic_data = http://www.kingceram.com/post/taitanic_data.dropna()# 将分类变量转换为数值型变量# 方法一:主要针对于二分类特征0~1 , 可以定义某一个数值为True , 另一种就是False , 然后通过astype()taitanic_data['Sex'] = (taitanic_data['Sex']=='male').astype('int')# 方法二:将三分类变量转换为数值型变量# 将对应的数据转换成为 , 可以理解为set操作 , 将重复的数值去掉 , 转换成list后续可以用来获取对应的下标用来代替标签labels = taitanic_data['Embarked'].unique().tolist()taitanic_data['Embarked'] = taitanic_data['Embarked'].apply(lambda x: labels.index(x))
执行上述代码的时候可以发现 , 会有的提醒跳出 , 如下所示

1  sklearn的学习笔记--决策树

文章插图
这里提醒我们最好使用 loc()方法来帮助实现刚才对应任务 。
x = taitanic_data.iloc[:,taitanic_data.columns != 'Survived']y = taitanic_data.iloc[:,taitanic_data.columns == 'Survived']# 需要注意一下train_test_split的划分数据集 , 四个数据集的顺序是否正确Xtrain,Xtest,Ytrain,Ytest = train_test_split(x,y,test_size=0.3)# 针对一些比较复杂的数据 , 需要修正对应的索引 , 并且可以帮助后期的查询辅助等任务for i in [Xtrain,Xtest,Ytrain,Ytest]:i.index = range(i.shape[0])Xtrain.head()
# 训练模型clf = DecisionTreeClassifier(random_state=25)clf = clf.fit(Xtrain,Ytrain)score_ = clf.score(Xtest,Ytest)# 0.7715355805243446score = cross_val_socre(clf,x,y,cv=10).mean()# 0.7212391248121231# 效果还是太差了 , 想办法提升一下对应的效果,针对于通道的max_depth进行调整tr = []te = []n = 10for i in range(n):clf = DecisionTreeClassifier(random_state=3,max_depth=i+1,criterion='entropy')clf.fit(Xtrain,Ytrain)score_tr = clf.score(Xtest,Ytest)score_te = cross_val_score(clf,x,y,cv=10).mean()tr.append(score_tr)te.append(score_te)print(max(te))plt.figure()plt.plot(range(1,n+1),tr,color='red',label='train')plt.plot(range(1,n+1),te,color='blue',label='test')plt.xticks(range(1,n+1))plt.legend()plt.show()# 输出的图片如下所示
可以看出 , 在=3的时候 , 效果是最好的 。