{ "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import pandas as pd #数据分析库,核心是DataFrame对象\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.ensemble import RandomForestClassifier #随机森林\n", "from sklearn.model_selection import train_test_split,cross_val_score,GridSearchCV #训练集和测试集的划分,交叉验证评估模型,网格搜索优化超参数\n", "from sklearn.metrics import accuracy_score,roc_auc_score,roc_curve,auc #模型准确度 \n", "import matplotlib.pyplot as pyt #绘制ROC曲线" ] }, { "cell_type": "code", "execution_count": 3, "id": "59086eb3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PassengerIdSurvivedPclassNameSexAgeSibSpParchTicketFareCabinEmbarked
0103Braund, Mr. Owen Harrismale22.010A/5 211717.2500NaNS
1211Cumings, Mrs. John Bradley (Florence Briggs Th...female38.010PC 1759971.2833C85C
2313Heikkinen, Miss. Lainafemale26.000STON/O2. 31012827.9250NaNS
3411Futrelle, Mrs. Jacques Heath (Lily May Peel)female35.01011380353.1000C123S
4503Allen, Mr. William Henrymale35.0003734508.0500NaNS
\n", "
" ], "text/plain": [ " PassengerId Survived Pclass \\\n", "0 1 0 3 \n", "1 2 1 1 \n", "2 3 1 3 \n", "3 4 1 1 \n", "4 5 0 3 \n", "\n", " Name Sex Age SibSp \\\n", "0 Braund, Mr. Owen Harris male 22.0 1 \n", "1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n", "2 Heikkinen, Miss. Laina female 26.0 0 \n", "3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n", "4 Allen, Mr. William Henry male 35.0 0 \n", "\n", " Parch Ticket Fare Cabin Embarked \n", "0 0 A/5 21171 7.2500 NaN S \n", "1 0 PC 17599 71.2833 C85 C \n", "2 0 STON/O2. 3101282 7.9250 NaN S \n", "3 0 113803 53.1000 C123 S \n", "4 0 373450 8.0500 NaN S " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df=pd.read_csv(\"C:\\Learning\\MachineLearning\\RandomForest\\\\train.csv\",header=0)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 4, "id": "4b0650bf", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\" subset = df[df['Survived'] == 0]\\n\\n# 过采样倍数\\noversampling_factor = 3\\n\\n# 使用 sample 函数对 subset 进行过采样,将采样结果追加到原始数据集中\\noversampled_df = pd.concat([df] + [subset.sample(frac=1, replace=True)] * (oversampling_factor - 1), axis=0)\\n\\n# 输出过采样后的数据集\\nprint(oversampled_df) \"" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\" subset = df[df['Survived'] == 0]\n", "\n", "# 过采样倍数\n", "oversampling_factor = 3\n", "\n", "# 使用 sample 函数对 subset 进行过采样,将采样结果追加到原始数据集中\n", "oversampled_df = pd.concat([df] + [subset.sample(frac=1, replace=True)] * (oversampling_factor - 1), axis=0)\n", "\n", "# 输出过采样后的数据集\n", "print(oversampled_df) \"\"\"" ] }, { "cell_type": "code", "execution_count": 5, "id": "2345dd24", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "891" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df[\"Age\"]=df[\"Age\"].fillna(df[\"Age\"].median())\n", "df[\"Cabin\"]=df[\"Cabin\"].bfill() #用相邻后面(back)特征填充前面缺失值\n", "df[\"Embarked\"]=df[\"Embarked\"].fillna(\"S\")\n", "output=df\n", "output.to_csv(\"final.csv\")\n", "len(df)" ] }, { "cell_type": "code", "execution_count": 6, "id": "7d0107fd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Survived\n", "0 549\n", "1 342\n", "Name: count, dtype: int64" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#看样本是否均衡,下采样(从多的样本中抽取跟少的样本中一样的数量),过采样(把数据量少的样本重复多次使之数量与多的差不多)解决样本不均衡的情况\n", "df.Survived.value_counts()" ] }, { "cell_type": "code", "execution_count": 7, "id": "e43fd930", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 0\n", "1 1\n", "2 1\n", "3 1\n", "4 0\n", "Name: Survived, dtype: int64" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y=df[\"Survived\"]\n", "y.head()" ] }, { "cell_type": "code", "execution_count": 8, "id": "f241e0d5", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PassengerIdPclassNameSexAgeSibSpParchTicketFareCabinEmbarked
013Braund, Mr. Owen Harrismale22.010A/5 211717.2500C85S
121Cumings, Mrs. John Bradley (Florence Briggs Th...female38.010PC 1759971.2833C85C
233Heikkinen, Miss. Lainafemale26.000STON/O2. 31012827.9250C123S
341Futrelle, Mrs. Jacques Heath (Lily May Peel)female35.01011380353.1000C123S
453Allen, Mr. William Henrymale35.0003734508.0500E46S
\n", "
" ], "text/plain": [ " PassengerId Pclass Name \\\n", "0 1 3 Braund, Mr. Owen Harris \n", "1 2 1 Cumings, Mrs. John Bradley (Florence Briggs Th... \n", "2 3 3 Heikkinen, Miss. Laina \n", "3 4 1 Futrelle, Mrs. Jacques Heath (Lily May Peel) \n", "4 5 3 Allen, Mr. William Henry \n", "\n", " Sex Age SibSp Parch Ticket Fare Cabin Embarked \n", "0 male 22.0 1 0 A/5 21171 7.2500 C85 S \n", "1 female 38.0 1 0 PC 17599 71.2833 C85 C \n", "2 female 26.0 0 0 STON/O2. 3101282 7.9250 C123 S \n", "3 female 35.0 1 0 113803 53.1000 C123 S \n", "4 male 35.0 0 0 373450 8.0500 E46 S " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#删掉Survived列\n", "x=df.drop(\"Survived\",axis=1)\n", "x.head()" ] }, { "cell_type": "code", "execution_count": 9, "id": "5a377d7f", "metadata": {}, "outputs": [], "source": [ "features=[\"Pclass\",\"Sex\",\"Age\",\"Ticket\",\"Fare\",\"Cabin\"]\n", "from FuncToNumberClass import FuncToNumber\n", "x=FuncToNumber.ToNumber(x.loc[:,features])\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "6dd05460", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Pclass Sex Age Ticket Fare Cabin\n", "887 1 0 24 14 153 30\n", "416 2 0 45 224 161 115\n", "479 3 0 6 251 79 31\n", "134 2 1 33 552 85 107\n", "588 3 1 28 79 43 96\n", ".. ... ... ... ... ... ...\n", "400 3 1 52 663 41 78\n", "118 1 1 31 585 244 36\n", "701 1 1 47 579 141 120\n", "206 3 1 42 247 106 8\n", "867 1 1 41 590 184 6\n", "\n", "[712 rows x 6 columns]\n" ] } ], "source": [ "\n", "xtrain,xtest,ytrain,ytest=train_test_split(x,y,test_size=0.2,random_state=5)\n", "\n", "print(xtrain)" ] }, { "cell_type": "code", "execution_count": 11, "id": "7106567b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Pclass Sex Age Ticket Fare Cabin\n", "126 3 1 36 463 30 138\n", "354 3 1 36 182 16 123\n", "590 3 1 47 656 14 96\n", "509 3 1 34 80 193 121\n", "769 3 1 42 511 48 136\n", ".. ... ... ... ... ... ...\n", "732 2 1 36 137 0 15\n", "42 3 1 36 391 40 101\n", "179 3 1 48 574 0 144\n", "123 2 0 43 219 85 116\n", "890 3 1 42 466 30 147\n", "\n", "[179 rows x 6 columns]\n" ] } ], "source": [ "print(xtest)\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "f62148bd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "126 0\n", "354 0\n", "590 0\n", "509 1\n", "769 0\n", " ..\n", "732 0\n", "42 0\n", "179 0\n", "123 1\n", "890 0\n", "Name: Survived, Length: 179, dtype: int64\n" ] } ], "source": [ "print(ytest)" ] }, { "cell_type": "markdown", "id": "e21ce014", "metadata": {}, "source": [ "n_estimators : 随机森林中树的个数,即学习器的个数。 \n", "max_features : 划分叶子节点,选择的最大特征数目 \n", "n_features:在寻找最佳分割时要考虑的特征数量\n", "max_depth : 树的最大深度,如果选择default=None,树就一致扩展,直到所有的叶子节点都是同一类样本,或者达到最小样本划分(min_samples_split)的数目。\n", "min_samples_split : 最小样本划分的数目,就是样本的数目少于等于这个值,就不能继续划分当前节点了\n", "min_samples_leaf : 叶子节点最少样本数,如果某叶子节点数目小于这个值,就会和兄弟节点一起被剪枝。\n", "min_weight_fraction_leaf:叶子节点最小的样本权重和\n", "max_leaf_nodes: 最大叶子节点数,默认是”None”,即不限制最大的叶子节点数\n", "min_impurity_split:节点划分的最小不纯度,是结束树增长的一个阈值,如果不纯度超过这个阈值,那么该节点就会继续划分,否则不划分,成为一个叶子节点。\n", "min_impurity_decrease : 最小不纯度减少的阈值,如果对该节点进行划分,使得不纯度的减少大于等于这个值,那么该节点就会划分,否则,不划分。\n", "bootstrap :自助采样,又放回的采样,大量采样的结果就是初始样本的63.2%作为训练集。默认选择自助采样法。\n", "oob_score : bool (default=False) \n", "out-of-bag estimate,包外估计;是否选用包外样本(即bootstrap采样剩下的36.8%的样本)作为验证集,对训练结果进行验证,默认不采用。\n", "n_jobs : 并行使用的进程数,默认1个,如果设置为-1,该值为总的核数。\n", "random_state :随机状态,默认由np.numpy生成\n", "verbose:显示输出的一些参数,默认不输出。" ] }, { "cell_type": "code", "execution_count": 13, "id": "7a230210", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
RandomForestClassifier()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "RandomForestClassifier()" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rfc=RandomForestClassifier() #实例化\n", "rfc.fit(xtrain,ytrain) " ] }, { "cell_type": "code", "execution_count": 14, "id": "124e234e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1,\n", " 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0,\n", " 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0,\n", " 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1,\n", " 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0,\n", " 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0,\n", " 0, 1, 0], dtype=int64)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predicts=rfc.predict(xtest)\n", "predicts\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "72b345bf", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.91, 0.09],\n", " [0.86, 0.14],\n", " [0.68, 0.32],\n", " [0.24, 0.76],\n", " [0.82, 0.18],\n", " [0.93, 0.07],\n", " [0.01, 0.99],\n", " [0.87, 0.13],\n", " [0.04, 0.96],\n", " [0.96, 0.04],\n", " [0.67, 0.33],\n", " [0.6 , 0.4 ],\n", " [0.33, 0.67],\n", " [0.02, 0.98],\n", " [1. , 0. ],\n", " [0.15, 0.85],\n", " [0.97, 0.03],\n", " [0.47, 0.53],\n", " [0.96, 0.04],\n", " [1. , 0. ],\n", " [0.81, 0.19],\n", " [0.41, 0.59],\n", " [0.03, 0.97],\n", " [0.18, 0.82],\n", " [0.85, 0.15],\n", " [0.55, 0.45],\n", " [0.77, 0.23],\n", " [0.67, 0.33],\n", " [0.65, 0.35],\n", " [0.88, 0.12],\n", " [0.93, 0.07],\n", " [0.53, 0.47],\n", " [0.2 , 0.8 ],\n", " [0.98, 0.02],\n", " [0.09, 0.91],\n", " [0.45, 0.55],\n", " [0.97, 0.03],\n", " [0.77, 0.23],\n", " [0.82, 0.18],\n", " [0.86, 0.14],\n", " [0.98, 0.02],\n", " [0.15, 0.85],\n", " [0.91, 0.09],\n", " [0.96, 0.04],\n", " [0.92, 0.08],\n", " [0.98, 0.02],\n", " [0.09, 0.91],\n", " [0.04, 0.96],\n", " [0.68, 0.32],\n", " [0.97, 0.03],\n", " [0.8 , 0.2 ],\n", " [0.03, 0.97],\n", " [0.04, 0.96],\n", " [0.84, 0.16],\n", " [0.94, 0.06],\n", " [0.91, 0.09],\n", " [0.34, 0.66],\n", " [0.98, 0.02],\n", " [0.77, 0.23],\n", " [0.22, 0.78],\n", " [0.9 , 0.1 ],\n", " [0.69, 0.31],\n", " [0.81, 0.19],\n", " [0.91, 0.09],\n", " [0.44, 0.56],\n", " [0.68, 0.32],\n", " [0.99, 0.01],\n", " [0.81, 0.19],\n", " [0.86, 0.14],\n", " [0.55, 0.45],\n", " [0.13, 0.87],\n", " [0.9 , 0.1 ],\n", " [0.1 , 0.9 ],\n", " [0.94, 0.06],\n", " [0.27, 0.73],\n", " [0.8 , 0.2 ],\n", " [0.93, 0.07],\n", " [0.87, 0.13],\n", " [0.03, 0.97],\n", " [0.13, 0.87],\n", " [0.87, 0.13],\n", " [0.95, 0.05],\n", " [0.93, 0.07],\n", " [0.97, 0.03],\n", " [0.05, 0.95],\n", " [0.06, 0.94],\n", " [0.86, 0.14],\n", " [0.16, 0.84],\n", " [0.87, 0.13],\n", " [0.42, 0.58],\n", " [0.88, 0.12],\n", " [1. , 0. ],\n", " [0.01, 0.99],\n", " [0.81, 0.19],\n", " [0.02, 0.98],\n", " [0.9 , 0.1 ],\n", " [0.06, 0.94],\n", " [0.82, 0.18],\n", " [0.92, 0.08],\n", " [0.93, 0.07],\n", " [0.58, 0.42],\n", " [0.11, 0.89],\n", " [0.93, 0.07],\n", " [0.8 , 0.2 ],\n", " [0.21, 0.79],\n", " [0.07, 0.93],\n", " [0.95, 0.05],\n", " [0.77, 0.23],\n", " [0.99, 0.01],\n", " [0.5 , 0.5 ],\n", " [0.97, 0.03],\n", " [0.71, 0.29],\n", " [0.83, 0.17],\n", " [1. , 0. ],\n", " [0.09, 0.91],\n", " [0.44, 0.56],\n", " [0.37, 0.63],\n", " [0.88, 0.12],\n", " [0.05, 0.95],\n", " [0.56, 0.44],\n", " [0.48, 0.52],\n", " [0.72, 0.28],\n", " [0.69, 0.31],\n", " [0.05, 0.95],\n", " [0.84, 0.16],\n", " [0.03, 0.97],\n", " [0.72, 0.28],\n", " [0.89, 0.11],\n", " [0.72, 0.28],\n", " [0.97, 0.03],\n", " [0.94, 0.06],\n", " [0.94, 0.06],\n", " [0.96, 0.04],\n", " [0.81, 0.19],\n", " [0.11, 0.89],\n", " [0.94, 0.06],\n", " [0.98, 0.02],\n", " [0.84, 0.16],\n", " [0.95, 0.05],\n", " [0.9 , 0.1 ],\n", " [0.87, 0.13],\n", " [0.98, 0.02],\n", " [0.96, 0.04],\n", " [0.54, 0.46],\n", " [0.33, 0.67],\n", " [0.06, 0.94],\n", " [0.82, 0.18],\n", " [0.01, 0.99],\n", " [0.86, 0.14],\n", " [0.23, 0.77],\n", " [0.09, 0.91],\n", " [0.79, 0.21],\n", " [0.59, 0.41],\n", " [0.9 , 0.1 ],\n", " [0.86, 0.14],\n", " [0.38, 0.62],\n", " [0.14, 0.86],\n", " [0.82, 0.18],\n", " [0.92, 0.08],\n", " [0.9 , 0.1 ],\n", " [0.89, 0.11],\n", " [0.92, 0.08],\n", " [0.85, 0.15],\n", " [0.65, 0.35],\n", " [0.35, 0.65],\n", " [0.7 , 0.3 ],\n", " [0.12, 0.88],\n", " [0.91, 0.09],\n", " [0.94, 0.06],\n", " [0.95, 0.05],\n", " [0.02, 0.98],\n", " [0.92, 0.08],\n", " [0.32, 0.68],\n", " [0.02, 0.98],\n", " [1. , 0. ],\n", " [0.96, 0.04],\n", " [0.76, 0.24],\n", " [0.1 , 0.9 ],\n", " [0.53, 0.47]])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predicts_proba=rfc.predict_proba(xtest)[:,:]\n", "predicts_proba" ] }, { "cell_type": "code", "execution_count": 16, "id": "660f8572", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8491620111731844" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "row=df.loc[xtest.index]\n", "output = pd.DataFrame({'PassengerId':row.PassengerId, \"Name\":row.Name, 'ActualSurvived': row.Survived, 'PredictSurvived': predicts})\n", "\n", "output.to_csv('result.csv', index=False)\n", "outputresult=output[output.ActualSurvived==output.PredictSurvived]\n", "len(outputresult)\n", "len(outputresult)/len(output)" ] }, { "cell_type": "code", "execution_count": 17, "id": "5a3eaa96", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8579093799682036" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "'''percent = accuracy_score(ytest,predicts)\n", "print(\"准确率:\", round(percent, 3))'''\n", "\n", "roc_auc_score(ytest,rfc.predict_proba(xtest)[:,1])" ] }, { "cell_type": "code", "execution_count": 18, "id": "db10f2f8", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "predict_validation=rfc.predict_proba(xtest)[:,1]\n", "fpr,tpr,_=roc_curve(ytest,predict_validation)\n", "roc_auc=auc(fpr,tpr)\n", "pyt.title(\"Roc validation\")\n", "pyt.plot(fpr,tpr,label=\"AUC=%0.4f\" %roc_auc)\n", "pyt.legend(loc='lower right')\n", "pyt.plot([0,1],[0,1],'r--')\n", "pyt.xlim([0,1])\n", "pyt.ylim([0,1])\n", "pyt.ylabel(\"True positive rate\")\n", "pyt.xlabel(\"False positive rate\")\n", "pyt.show()" ] }, { "cell_type": "code", "execution_count": 19, "id": "bf61bd49", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'bootstrap': True,\n", " 'ccp_alpha': 0.0,\n", " 'class_weight': None,\n", " 'criterion': 'gini',\n", " 'max_depth': None,\n", " 'max_features': 'sqrt',\n", " 'max_leaf_nodes': None,\n", " 'max_samples': None,\n", " 'min_impurity_decrease': 0.0,\n", " 'min_samples_leaf': 1,\n", " 'min_samples_split': 2,\n", " 'min_weight_fraction_leaf': 0.0,\n", " 'n_estimators': 100,\n", " 'n_jobs': None,\n", " 'oob_score': False,\n", " 'random_state': None,\n", " 'verbose': 0,\n", " 'warm_start': False}" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "rfc.get_params()\n" ] }, { "cell_type": "code", "execution_count": 20, "id": "fad5d90c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.88688017 0.91869835 0.86290491 0.87920585 0.80944865]\n", "0.8714275856791531\n" ] } ], "source": [ "rfc=RandomForestClassifier() #实例化\n", "\n", "\n", "score=cross_val_score(rfc,xtrain,ytrain,scoring='roc_auc',cv=5)\n", "print(score)\n", "print(score.mean())" ] }, { "cell_type": "code", "execution_count": 21, "id": "959954bb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.86956522 0.88086691 0.8565508 0.76604278]\n", "0.843256426568887\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "c:\\Python\\Python311\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n" ] } ], "source": [ "rfc1=LogisticRegression() #实例化\n", "score=cross_val_score(rfc1,xtrain,ytrain,scoring='roc_auc',cv=4)\n", "print(score)\n", "print(score.mean())" ] }, { "cell_type": "code", "execution_count": 22, "id": "901c12e5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'n_estimators': 450} 0.8753070939277835\n" ] } ], "source": [ "#调整超参数\n", "params_test1 = {'n_estimators' :range(25,500,25)}\n", "gsearch1=GridSearchCV(estimator=RandomForestClassifier(min_samples_split=2,\n", " min_samples_leaf=1,\n", " max_depth=8,random_state=2),\n", " param_grid=params_test1,\n", " scoring='roc_auc',\n", " cv=5)\n", "gsearch1.fit(xtrain,ytrain)\n", "print(gsearch1.best_params_,gsearch1.best_score_)\n", " \n", " " ] }, { "cell_type": "code", "execution_count": 23, "id": "24704a5e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'max_depth': 11} 0.8767195729499804\n" ] } ], "source": [ "params_test2 = {'max_depth' :range(2,30,3)}\n", "gsearch2=GridSearchCV(estimator=RandomForestClassifier(n_estimators=225,\n", " min_samples_split=2,\n", " min_samples_leaf=1,random_state=2),\n", " param_grid=params_test2,\n", " scoring='roc_auc',\n", " cv=5)\n", "gsearch2.fit(xtrain,ytrain)\n", "print(gsearch2.best_params_,gsearch2.best_score_)" ] }, { "cell_type": "code", "execution_count": 24, "id": "ac90a0e3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
RandomForestClassifier(max_depth=11, n_estimators=225, random_state=2)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "RandomForestClassifier(max_depth=11, n_estimators=225, random_state=2)" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gsearch2.best_estimator_" ] }, { "cell_type": "code", "execution_count": 25, "id": "fd6d6a1b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8657260201377848" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#accuracy_score(ytest,gsearch3.best_estimator_.predict_proba(xtest)[:,1])\n", "roc_auc_score(ytest,gsearch2.best_estimator_.predict_proba(xtest)[:,1])" ] }, { "cell_type": "code", "execution_count": 26, "id": "d3415abc", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PassengerIdPclassNameSexAgeSibSpParchTicketFareCabinEmbarked
08923Kelly, Mr. Jamesmale34.5003309117.8292NaNQ
18933Wilkes, Mrs. James (Ellen Needs)female47.0103632727.0000NaNS
28942Myles, Mr. Thomas Francismale62.0002402769.6875NaNQ
38953Wirz, Mr. Albertmale27.0003151548.6625NaNS
48963Hirvonen, Mrs. Alexander (Helga E Lindqvist)female22.011310129812.2875NaNS
\n", "
" ], "text/plain": [ " PassengerId Pclass Name Sex \\\n", "0 892 3 Kelly, Mr. James male \n", "1 893 3 Wilkes, Mrs. James (Ellen Needs) female \n", "2 894 2 Myles, Mr. Thomas Francis male \n", "3 895 3 Wirz, Mr. Albert male \n", "4 896 3 Hirvonen, Mrs. Alexander (Helga E Lindqvist) female \n", "\n", " Age SibSp Parch Ticket Fare Cabin Embarked \n", "0 34.5 0 0 330911 7.8292 NaN Q \n", "1 47.0 1 0 363272 7.0000 NaN S \n", "2 62.0 0 0 240276 9.6875 NaN Q \n", "3 27.0 0 0 315154 8.6625 NaN S \n", "4 22.0 1 1 3101298 12.2875 NaN S " ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df1=pd.read_csv(\"C:\\Learning\\MachineLearning\\RandomForest\\\\test.csv\",header=0)\n", "df1.head()" ] }, { "cell_type": "code", "execution_count": 27, "id": "c283c83f", "metadata": {}, "outputs": [], "source": [ "from FuncToNumberClass import FuncToNumber\n", "xtest2=FuncToNumber.ToNumber(df1.loc[:,features])" ] }, { "cell_type": "code", "execution_count": 28, "id": "371558e8", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
RandomForestClassifier(max_depth=14, n_estimators=230, random_state=2)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "RandomForestClassifier(max_depth=14, n_estimators=230, random_state=2)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rfc=RandomForestClassifier(max_depth=14, min_samples_leaf=1, min_samples_split=2,\n", " n_estimators=230, random_state=2) #实例化\n", "rfc.fit(x,y) " ] }, { "cell_type": "code", "execution_count": 29, "id": "76baeb7f", "metadata": {}, "outputs": [], "source": [ "predictresult=rfc.predict(xtest2)\n", "row=df1.loc[xtest2.index]\n", "output = pd.DataFrame({'PassengerId': row.PassengerId,\"Name\":row.Name,'PredictSurvived': predictresult})\n", "output.to_csv(\"ResultPredict.csv\",index=0)" ] }, { "cell_type": "code", "execution_count": 31, "id": "2d404f2e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0 0 0 0 1 0 1 0 1 0 0 0 1 0 1 1 0 0 1 1 0 0 1 0 1 0 1 0 0 0 0 0 0 0 0 0 1\n", " 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 1 1 0 1 0 0 0 1 0 0 0 1 1 1 1 0 0 1 1 0 1 0\n", " 1 0 0 1 0 1 1 0 0 0 0 0 1 1 0 1 1 0 1 0 0 0 1 0 1 0 1 0 0 0 1 0 0 0 0 0 0\n", " 1 1 1 1 0 0 1 0 1 1 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0\n", " 0 0 1 0 0 0 1 0 1 1 0 0 1 1 1 0 0 1 0 0 1 1 0 0 0 0 0 1 1 0 1 1 0 0 1 0 1\n", " 0 1 0 0 0 0 0 0 0 1 0 1 1 0 0 0 1 0 1 0 0 0 0 1 0 0 0 0 1 1 0 1 0 1 0 1 0\n", " 1 0 1 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 1 1 0 0 0 0 1 0 1 1 1 0 1 0 0 0 0 0 1\n", " 0 0 0 1 1 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 1 1 0 1 1 0 0 1 0 0 0 1 0 0 0 0\n", " 1 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 0 1 0 0 0 1 0 0\n", " 1 0 0 0 0 0 0 0 0 0 1 0 1 0 1 0 1 1 0 0 0 1 0 1 0 0 0 0 1 1 0 1 0 0 1 1 0\n", " 0 1 0 0 1 1 1 0 0 1 0 0 0 0 0 1 0 0 0 1 1 1 0 0 0 1 0 1 0 0 1 0 1 0 0 0 0\n", " 0 1 1 0 1 1 0 1 0 0 0]\n" ] } ], "source": [ "import pickle\n", "\n", "# 保存模型\n", "with open('model.pkl', 'wb') as f:\n", " pickle.dump(rfc, f)\n", "\n", "# 加载模型\n", "with open('model.pkl', 'rb') as f:\n", " loaded_model = pickle.load(f)\n", "\n", "# 使用加载后的模型进行预测或其他操作\n", "result = loaded_model.predict(xtest2)\n", "print(result)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }