
Employee Attrition Prediction with Logistic Regression

程序员文章站 (Programmer Article Site), 2024-03-22 09:19:46
...
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder  # encode string (object) columns as integer codes
import seaborn as sns  # heatmap
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, GridSearchCV  # dataset splitting, hyperparameter search
from sklearn.metrics import accuracy_score  # evaluation metric
from sklearn.linear_model import LogisticRegression  # logistic regression model

Load the data

train=pd.read_csv('./pfm_train.csv')
test=pd.read_csv('./pfm_test.csv')

Inspect the data

train.head()
| | Age | Attrition | BusinessTravel | Department | DistanceFromHome | Education | EducationField | EmployeeNumber | EnvironmentSatisfaction | Gender | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 37 | 0 | Travel_Rarely | Research & Development | 1 | 4 | Life Sciences | 77 | 1 | Male | ... | 3 | 80 | 1 | 7 | 2 | 4 | 7 | 5 | 0 | 7 |
| 1 | 54 | 0 | Travel_Frequently | Research & Development | 1 | 4 | Life Sciences | 1245 | 4 | Female | ... | 1 | 80 | 1 | 33 | 2 | 1 | 5 | 4 | 1 | 4 |
| 2 | 34 | 1 | Travel_Frequently | Research & Development | 7 | 3 | Life Sciences | 147 | 1 | Male | ... | 4 | 80 | 0 | 9 | 3 | 3 | 9 | 7 | 0 | 6 |
| 3 | 39 | 0 | Travel_Rarely | Research & Development | 1 | 1 | Life Sciences | 1026 | 4 | Female | ... | 3 | 80 | 1 | 21 | 3 | 3 | 21 | 6 | 11 | 8 |
| 4 | 28 | 1 | Travel_Frequently | Research & Development | 1 | 3 | Medical | 1111 | 1 | Male | ... | 1 | 80 | 2 | 1 | 2 | 3 | 1 | 0 | 0 | 0 |

5 rows × 31 columns

Encode the non-numeric columns with LabelEncoder

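dtypes_list=train.dtypes.values  # data type of each column
columns__list=train.columns  # column names
# Loop over the columns and encode every non-numeric (object) column
for i in range(len(columns__list)):
    if dtypes_list[i]=='object':  # check the column type
        lb=LabelEncoder()  # a fresh encoder for this column
        lb.fit(train[columns__list[i]])  # learn the categories from the training set
        train[columns__list[i]]=lb.transform(train[columns__list[i]])  # encode the training column
        test[columns__list[i]]=lb.transform(test[columns__list[i]])  # apply the same mapping to the test set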
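LabelEncoder assigns integer codes in sorted order of the distinct values it saw during fit (stored in classes_). A minimal sketch on a toy column in the style of BusinessTravel, assuming scikit-learn is installed; the sample values are made up. It also shows that transform raises ValueError for a category missing from the fitted data, so the loop above would fail if the test set contained a value the training set never had.

```python
from sklearn.preprocessing import LabelEncoder

# Toy values (illustrative only), in the style of the BusinessTravel column
train_values = ["Travel_Rarely", "Non-Travel", "Travel_Frequently", "Travel_Rarely"]

lb = LabelEncoder()
codes = lb.fit_transform(train_values)

# classes_ holds the distinct values in sorted order; a value's code is its index there
print(lb.classes_.tolist())  # ['Non-Travel', 'Travel_Frequently', 'Travel_Rarely']
print(codes.tolist())        # [2, 0, 1, 2]

# A category that was never seen during fit cannot be transformed
try:
    lb.transform(["Travel_Weekly"])  # hypothetical unseen value
except ValueError as err:
    print("unseen label:", err)
```

Because the codes follow alphabetical order, they imply an ordering that nominal columns such as Department or JobRole do not really have.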

After LabelEncoder encoding, every column is numeric (test set shown):

test.head()
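| | Age | BusinessTravel | Department | DistanceFromHome | Education | EducationField | EmployeeNumber | EnvironmentSatisfaction | Gender | JobInvolvement | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 40 | 0 | 1 | 9 | 4 | 4 | 1449 | 3 | 1 | 3 | ... | 3 | 80 | 2 | 11 | 2 | 4 | 8 | 7 | 0 | 7 |
| 1 | 53 | 2 | 1 | 7 | 2 | 3 | 1201 | 4 | 0 | 3 | ... | 2 | 80 | 1 | 26 | 6 | 3 | 7 | 7 | 4 | 7 |
| 2 | 42 | 2 | 1 | 2 | 4 | 4 | 477 | 1 | 1 | 2 | ... | 2 | 80 | 0 | 14 | 6 | 3 | 1 | 0 | 0 | 0 |
| 3 | 34 | 1 | 0 | 11 | 3 | 1 | 1289 | 3 | 1 | 2 | ... | 4 | 80 | 2 | 14 | 5 | 4 | 10 | 9 | 1 | 8 |
| 4 | 32 | 2 | 1 | 1 | 1 | 1 | 134 | 4 | 1 | 3 | ... | 4 | 80 | 0 | 1 | 2 | 3 | 1 | 0 | 0 | 0 |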

5 rows × 30 columns

Check whether any feature takes only a single value; such constant columns carry no information, so drop them.

unique_list=[]
# Collect the names of columns that contain only one distinct value
for i in range(train.shape[1]):
    if train[columns__list[i]].nunique()==1:  # nunique() counts distinct values; 1 means the column is constant
        unique_list.append(columns__list[i])
unique_list
['Over18', 'StandardHours']
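The same check can be written without an explicit loop, because DataFrame.nunique() returns the distinct-value count of every column at once. A minimal sketch on a made-up three-row frame:

```python
import pandas as pd

# Toy frame (illustrative values): 'Over18' and 'StandardHours' are constant
df = pd.DataFrame({
    "Age": [37, 54, 34],
    "Over18": ["Y", "Y", "Y"],
    "StandardHours": [80, 80, 80],
})

# Boolean mask over the columns: True where the column has exactly one distinct value
constant_cols = df.columns[df.nunique() == 1].tolist()
print(constant_cols)  # ['Over18', 'StandardHours']

df = df.drop(columns=constant_cols)
print(df.columns.tolist())  # ['Age']
```

Note that nunique() ignores missing values by default, so a column that is entirely NaN reports 0 rather than 1.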

Drop the constant columns

train.drop(unique_list,axis=1,inplace=True)
test.drop(unique_list,axis=1,inplace=True)

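Drop EmployeeNumber as well: it is an arbitrary ID with no relationship to attrition, and keeping it would only add noise to the classifier.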

train.drop('EmployeeNumber',axis=1,inplace=True)
test.drop('EmployeeNumber',axis=1,inplace=True)

Compute the pairwise correlations between the columns

corr=train.corr()
corr.head()
| | Age | Attrition | BusinessTravel | Department | DistanceFromHome | Education | EducationField | EnvironmentSatisfaction | Gender | JobInvolvement | ... | PerformanceRating | RelationshipSatisfaction | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Age | 1.000000 | -0.175393 | 0.024270 | -0.017185 | 0.007081 | 0.198558 | -0.010160 | 0.011803 | -0.029794 | 0.066528 | ... | -0.029613 | 0.063489 | -0.002413 | 0.682879 | -0.051702 | -0.001042 | 0.328651 | 0.231842 | 0.230587 | 0.212540 |
| Attrition | -0.175393 | 1.000000 | 0.015483 | 0.053364 | 0.088563 | -0.046494 | 0.009994 | -0.097003 | 0.016750 | -0.122722 | ... | 0.046762 | -0.051749 | -0.138498 | -0.187922 | -0.043395 | -0.048794 | -0.143697 | -0.163059 | -0.071760 | -0.158558 |
| BusinessTravel | 0.024270 | 0.015483 | 1.000000 | -0.040937 | -0.040339 | -0.041300 | 0.027817 | -0.013328 | -0.033992 | 0.019014 | ... | -0.033435 | -0.036396 | 0.001352 | 0.031168 | 0.030902 | -0.003512 | -0.010005 | 0.002011 | -0.041603 | -0.021887 |
| Department | -0.017185 | 0.053364 | -0.040937 | 1.000000 | 0.013349 | -0.006280 | 0.012247 | -0.021766 | -0.034053 | -0.040219 | ... | -0.020124 | -0.025962 | -0.007854 | -0.027084 | 0.051085 | 0.048179 | 0.017601 | 0.059544 | 0.019319 | 0.032234 |
| DistanceFromHome | 0.007081 | 0.088563 | -0.040339 | 0.013349 | 1.000000 | 0.011437 | -0.016882 | -0.010308 | 0.023493 | 0.012333 | ... | 0.021042 | 0.018112 | 0.050356 | 0.001287 | -0.041208 | -0.050950 | 0.000044 | 0.019317 | -0.002760 | 0.008852 |

5 rows × 28 columns

Heatmap of the pairwise correlations

sns.heatmap(corr,xticklabels=corr.columns.values,yticklabels=corr.columns.values)
plt.show()  # with seaborn's default colormap, darker cells mean more negative correlation

[Figure: correlation heatmap of the encoded features]

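The heatmap shows that StockOptionLevel and MaritalStatus are negatively correlated, so they carry overlapping information; drop one of them (StockOptionLevel):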

train.drop('StockOptionLevel',axis=1,inplace=True)
test.drop('StockOptionLevel',axis=1,inplace=True)
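Picking redundant pairs by eye from a 28 × 28 heatmap is error-prone. A minimal sketch of a hypothetical helper, correlated_pairs (not a pandas function), that lists every column pair whose absolute Pearson correlation reaches a threshold; the 0.5 default and the toy data are assumptions for illustration. Keep in mind that correlations involving LabelEncoder codes of nominal columns such as MaritalStatus depend on the arbitrary alphabetical code order.

```python
from itertools import combinations

import pandas as pd


def correlated_pairs(df: pd.DataFrame, threshold: float = 0.5):
    """List (col_a, col_b, r) for every column pair with |Pearson r| >= threshold."""
    corr = df.corr()
    pairs = []
    # combinations() visits each unordered pair once and skips the diagonal
    for a, b in combinations(corr.columns, 2):
        r = corr.loc[a, b]
        if abs(r) >= threshold:
            pairs.append((a, b, float(r)))
    # Strongest relationships first
    return sorted(pairs, key=lambda p: abs(p[2]), reverse=True)


# Toy data (illustrative): y is the exact reverse of x, z only loosely follows x
toy = pd.DataFrame({"x": [1, 2, 3, 4, 5], "y": [5, 4, 3, 2, 1], "z": [2, 1, 4, 3, 5]})
print(correlated_pairs(toy, threshold=0.9))  # only the (x, y) pair, with r close to -1
```

Run on the encoded training frame before the drop above, something like correlated_pairs(train) would list the candidate pairs; which member of a pair to drop remains a judgment call.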

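Discretize MonthlyIncome by binning it into income bands, so that it enters the classifier as a categorical feature; then separate the label from the features.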

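_, bins = pd.cut(train.MonthlyIncome, bins=10, retbins=True)  # learn 10 equal-width bin edges on train only
bins[0], bins[-1] = -np.inf, np.inf  # open the outer bins so out-of-range test incomes still get a bin
train.MonthlyIncome = pd.cut(train.MonthlyIncome, bins=bins)  # same edges for train and test,
test.MonthlyIncome = pd.cut(test.MonthlyIncome, bins=bins)  # so both get identical categories
label = train.Attrition  # extract the attrition label
train.drop('Attrition', inplace=True, axis=1)  # remove the label from the feature table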
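When pd.cut receives an integer bins, it derives the edges from the minimum and maximum of the data it is given, so binning the training and test sets in separate calls can give "bin 3" a different income range in each set. A toy sketch of the effect, using made-up income values, and of reusing the training edges returned by retbins=True:

```python
import numpy as np
import pandas as pd

# Toy income samples (illustrative values only)
train_income = pd.Series([1009, 2500, 4800, 7300, 12000, 19999])
test_income = pd.Series([1500, 3000, 9000, 15000])

# Cutting each set separately: each call computes its own edges from its own min/max
_, train_edges = pd.cut(train_income, bins=3, retbins=True)
_, test_edges = pd.cut(test_income, bins=3, retbins=True)
print(np.allclose(train_edges, test_edges))  # False: the "same" bin covers different ranges

# Reusing the training edges gives both sets identical categories
train_binned = pd.cut(train_income, bins=train_edges)
test_binned = pd.cut(test_income, bins=train_edges)
print(test_binned.cat.categories.equals(train_binned.cat.categories))  # True
```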

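One-hot encoding turns each distinct value of a discrete feature into its own 0/1 indicator column. pd.get_dummies only expands object and categorical columns, so here it expands the binned MonthlyIncome; the LabelEncoder-encoded columns are integers and pass through unchanged.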

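train = pd.get_dummies(train)
test = pd.get_dummies(test)
test = test.reindex(columns=train.columns, fill_value=0)  # make the test columns match the training columns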
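A small sketch of how pd.get_dummies treats mixed column types, and of aligning two separately encoded frames with reindex. The frames, the column name Band, and its values are made up for illustration.

```python
import pandas as pd

# Toy frames (illustrative): one integer-coded column, one categorical column
train_df = pd.DataFrame({"Gender": [0, 1, 1], "Band": pd.Categorical(["low", "mid", "high"])})
test_df = pd.DataFrame({"Gender": [1, 0], "Band": pd.Categorical(["low", "low"])})

train_oh = pd.get_dummies(train_df)
test_oh = pd.get_dummies(test_df)

# Only the categorical column is expanded; the integer column passes through unchanged
print(train_oh.columns.tolist())  # ['Gender', 'Band_high', 'Band_low', 'Band_mid']
print(test_oh.columns.tolist())   # ['Gender', 'Band_low']

# Align test to the training columns; dummy columns absent from test are filled with 0
test_oh = test_oh.reindex(columns=train_oh.columns, fill_value=0)
print(test_oh.columns.tolist())   # ['Gender', 'Band_high', 'Band_low', 'Band_mid']
```

Without the reindex, a model fitted on train_oh would receive a test matrix with fewer columns and refuse to predict.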

Split the labeled training data into a training part and a validation part

x_train, x_test, y_train, y_test = train_test_split(train, label)
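train_test_split returns a (train, test) pair for each array in the order the arrays were passed, which is why the variable order on the left must mirror the argument order. A sketch on toy data; stratify=y and random_state are suggested additions not used in the original code: stratify keeps the class ratio the same in both parts, which matters for an imbalanced label such as Attrition, and random_state makes the split reproducible.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy features and an imbalanced 0/1 label (illustrative only)
X = np.arange(40).reshape(20, 2)
y = np.array([0] * 16 + [1] * 4)

# For train_test_split(X, y) the outputs are X_train, X_test, y_train, y_test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0
)
print(X_train.shape, X_test.shape)  # (15, 2) (5, 2)
print(y_test.mean())                # 0.2, the same positive rate as in y
```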

Baseline model with default hyperparameters

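log = LogisticRegression(solver='liblinear')  # liblinear was the default solver in the scikit-learn version used here
log.fit(x_train, y_train)  # train
accuracy_score(y_test, log.predict(x_test))  # accuracy on the validation split: (y_true, y_pred)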
0.8327272727272728
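On an imbalanced label, an accuracy figure is easier to judge next to the accuracy of always predicting the majority class. scikit-learn's DummyClassifier provides that reference; a minimal sketch on made-up labels with an 80/20 split (not the attrition data):

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score

# Toy imbalanced labels (illustrative): 80% zeros, 20% ones
y = np.array([0] * 80 + [1] * 20)
X = np.zeros((len(y), 1))  # features are ignored by the dummy model

baseline = DummyClassifier(strategy="most_frequent")
baseline.fit(X, y)
acc = accuracy_score(y, baseline.predict(X))
print(acc)  # 0.8: accuracy earned without learning anything
```

In this article the analogous check would be DummyClassifier(strategy="most_frequent").fit(x_train, y_train).score(x_test, y_test); a model is only doing real work if it beats that number.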

Model tuned with a grid search

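model = LogisticRegression(solver='liblinear')  # liblinear supports both the l1 and l2 penalties
C = [0.01, 0.1, 1, 10, 100, 1000]  # inverse of regularization strength: smaller C = stronger regularization
penalty = ['l1', 'l2']  # type of regularization penalty
param_grid = dict(C=C, penalty=penalty)  # GridSearchCV takes the search space as a dict
model_gcv = GridSearchCV(model, param_grid, cv=3, scoring='roc_auc')  # 3-fold CV, scored by ROC AUC
model_gcv.fit(x_train, y_train)  # try every combination, then refit the best one on all of x_train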
GridSearchCV(cv=3, error_score='raise',
       estimator=LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False),
       fit_params=None, iid=True, n_jobs=1,
       param_grid={'C': [0.01, 0.1, 1, 10, 100, 1000], 'penalty': ['l1', 'l2']},
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring='roc_auc', verbose=0)
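In the grid above, penalty='l1' can drive some coefficients to exactly zero, and a smaller C means stronger regularization. A minimal sketch on synthetic data from make_classification (not the attrition data; sizes and seed are arbitrary) showing the number of non-zero coefficients shrink as C decreases. Because the penalty acts on coefficient magnitudes, it is also sensitive to feature scale, and the attrition features in this article are not standardized.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic data (illustrative only): 20 features, of which 3 are informative
X, y = make_classification(n_samples=300, n_features=20, n_informative=3,
                           n_redundant=0, random_state=0)

nonzero = {}
for C in [0.01, 1, 100]:
    # liblinear handles both 'l1' and 'l2'; C is the inverse of the regularization strength
    clf = LogisticRegression(penalty="l1", C=C, solver="liblinear").fit(X, y)
    nonzero[C] = int((clf.coef_ != 0).sum())  # how many feature weights survived
    print(f"C={C}: {nonzero[C]} of {X.shape[1]} coefficients are non-zero")
```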

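Inspect the best parameters. Note that best_score_ is the mean 3-fold cross-validated ROC AUC (the scoring metric passed to GridSearchCV), not an accuracy, so it is not directly comparable with the accuracy printed after it.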

model_gcv.best_params_
{'C': 1, 'penalty': 'l1'}
model_gcv.best_score_
0.8265055902793318
accuracy_score(y_test, model_gcv.predict(x_test))
0.8327272727272728
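To compare like with like, the validation split can be scored with the same ROC AUC metric the grid search optimized. ROC AUC needs a continuous score for the positive class, such as predict_proba(...)[:, 1], while accuracy uses the hard 0/1 predictions. A sketch on synthetic, imbalanced data (an assumption, not the attrition data):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split

# Synthetic imbalanced data (illustrative only): roughly 80% class 0, 20% class 1
X, y = make_classification(n_samples=400, weights=[0.8, 0.2], random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, stratify=y, random_state=0)

clf = LogisticRegression(solver="liblinear").fit(X_tr, y_tr)
# Accuracy uses hard predictions (probability threshold 0.5)
acc = accuracy_score(y_te, clf.predict(X_te))
# ROC AUC ranks the predicted probability of class 1
auc = roc_auc_score(y_te, clf.predict_proba(X_te)[:, 1])
print(f"accuracy={acc:.3f}  roc_auc={auc:.3f}")
```

For the tuned model here, roc_auc_score(y_test, model_gcv.predict_proba(x_test)[:, 1]) would give the validation AUC, since GridSearchCV forwards predict_proba to the refitted best estimator.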

Predict on the test set and save the results

predict_test = model_gcv.predict(test)
pd.DataFrame(predict_test, columns=['result']).to_csv('./test_pre.csv',
                                                      index=False)
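Before calling predict on the test set it is worth confirming that its columns match the training features exactly, because separately encoded frames can drift apart. A small sketch with a hypothetical helper, feature_mismatches (not part of pandas or scikit-learn), and made-up column names:

```python
import pandas as pd


def feature_mismatches(train_df: pd.DataFrame, test_df: pd.DataFrame) -> dict:
    """Hypothetical helper: report how test_df's columns differ from train_df's."""
    train_cols, test_cols = list(train_df.columns), list(test_df.columns)
    missing = [c for c in train_cols if c not in test_cols]  # expected but absent
    extra = [c for c in test_cols if c not in train_cols]    # present but unexpected
    # Same set of columns but a different order still breaks positional models
    order_differs = not missing and not extra and train_cols != test_cols
    return {"missing": missing, "extra": extra, "order_differs": order_differs}


a = pd.DataFrame(columns=["Age", "MonthlyIncome_low", "MonthlyIncome_high"])
b = pd.DataFrame(columns=["Age", "MonthlyIncome_low", "MonthlyIncome_mid"])
print(feature_mismatches(a, b))
# {'missing': ['MonthlyIncome_high'], 'extra': ['MonthlyIncome_mid'], 'order_differs': False}
```

In this article's pipeline the call would be feature_mismatches(x_train, test); an all-empty result means the test frame is safe to pass to model_gcv.predict.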

Field descriptions

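(1) Age: the employee's age;
(2) Attrition: whether the employee has left the company; 1 means the employee has left, 0 means they have not. This is the prediction target;
(3) BusinessTravel: frequency of business travel; Non-Travel means no travel, Travel_Rarely means infrequent travel, Travel_Frequently means frequent travel;
(4) Department: the employee's department; Sales is the sales department, Research & Development is R&D, Human Resources is HR;
(5) DistanceFromHome: distance between the company and the employee's home, from 1 to 29; 1 is the closest, 29 the farthest;
(6) Education: the employee's education level, from 1 to 5; 5 is the highest;
(7) EducationField: the employee's field of study; Life Sciences, Medical, Marketing, Technical Degree, Human Resources, or Other;
(8) EmployeeNumber: employee ID;
(9) EnvironmentSatisfaction: satisfaction with the work environment, from 1 to 4; 1 is the lowest, 4 the highest;
(10) Gender: the employee's gender, Male or Female;
(11) JobInvolvement: job involvement, from 1 to 4; 1 is the lowest, 4 the highest;
(12) JobLevel: job level, from 1 to 5; 1 is the lowest level, 5 the highest;
(13) JobRole: job role; Sales Executive, Research Scientist, Laboratory Technician, Manufacturing Director, Healthcare Representative, Manager, Sales Representative, Research Director, or Human Resources;
(14) JobSatisfaction: job satisfaction, from 1 to 4; 1 is the lowest, 4 the highest;
(15) MaritalStatus: marital status; Single, Married, or Divorced;
(16) MonthlyIncome: monthly income, ranging from 1009 to 19999;
(17) NumCompaniesWorked: number of companies the employee has worked for;
(18) Over18: whether the employee is over 18;
(19) OverTime: whether the employee works overtime, Yes or No;
(20) PercentSalaryHike: percentage salary increase;
(21) PerformanceRating: performance rating;
(22) RelationshipSatisfaction: relationship satisfaction, from 1 to 4; 1 is the lowest, 4 the highest;
(23) StandardHours: standard working hours;
(24) StockOptionLevel: stock option level;
(25) TotalWorkingYears: total years of working experience;
(26) TrainingTimesLastYear: amount of training received last year, from 0 to 6; 0 means no training, 6 the most;
(27) WorkLifeBalance: work-life balance, from 1 to 4; 1 is the lowest, 4 the highest;
(28) YearsAtCompany: years at the current company;
(29) YearsInCurrentRole: years in the current role;
(30) YearsSinceLastPromotion: years since the last promotion;
(31) YearsWithCurrManager: years with the current manager.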

Related tags: data mining