本文基于 iris 数据集介绍 scikit-learn (https://scikit-learn.org/stable/) 的简单用法,并提供 R 语言的等价表示。
scikit-learn 整合各类机器学习算法,以及整套流程(数据归一化、数据集切分、超参数寻优等),提供统一的接口,学习和使用上更加方便。 xgboost 的 R 包接口更接近 Python 编程的使用习惯。 想要对齐 Python 语言和 R 语言的结果是比较困难的,它们内部的实现方式不一致,只能在结果上是做到十分接近。
1 准备
导入必要的功能模块,加载 iris 数据集。
import numpy as np
# 数据集 iris
from sklearn.datasets import load_iris
# 分类模型
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
# 模型评估
from sklearn.metrics import accuracy_score, log_loss, roc_auc_score
# 加载 iris 数据
X, y = load_iris(return_X_y=True)
library(nnet) # 多项对数线性分类模型
library(glmnet) # 允许带惩罚项
library(randomForest) # 随机森林
library(xgboost) # 梯度提升
# pROC::aoc() # 计算 AUC
2 分类模型
2.1 不带惩罚项
R 内置的 glm() 函数不支持三分类及以上的多类分类模型,而 nnet 包支持。其拟合函数 multinom() 内部调用优化函数 optim() ,优化求解器是 BFGS 。拟合结果中,第一个类的回归系数都是 0 。
函数 LogisticRegression() 设置 penalty=None 表示对回归系数不加惩罚,参数估计要求解一个优化问题,solver='lbfgs' 设置优化求解器,C=1 默认值,设置正则项的系数,即惩罚强度,越小强度越大,C=1 表示不加强度。
clf = LogisticRegression(penalty=None, random_state=0, solver='lbfgs', C=1).fit(X, y)
np.hstack((clf.intercept_.reshape(3, 1), clf.coef_))
## array([[ 3.9723211 , 7.34741055, 20.36734847, -30.20952186,
## -14.1184832 ],
## [ 19.28077892, -2.43835607, -6.85009017, 10.39822031,
## -2.07485129],
## [-23.25310002, -4.90905448, -13.5172583 , 19.81130155,
## 16.19333449]])
# R 对数似然 = -150 * Python 对数损失
log_loss(y, clf.predict_proba(X))
## 0.039661957693336546
-150 * log_loss(y, clf.predict_proba(X))
## -5.949293654000482
# 150 个样本预测正确的有 148 个,返回预测的准确度 148/150
clf.score(X, y)
## 0.9866666666666667
# 与 clf.score 一样
accuracy_score(y, clf.predict(X))
## 0.9866666666666667
# 计算 AUC
roc_auc_score(y, clf.predict_proba(X), multi_class='ovr')
## 0.9990666666666668
library(nnet) # 多项逻辑回归
iris_nnet <- multinom(Species ~ ., data = iris, trace = FALSE)
summary(iris_nnet)
## Call:
## multinom(formula = Species ~ ., data = iris, trace = FALSE)
##
## Coefficients:
## (Intercept) Sepal.Length Sepal.Width Petal.Length Petal.Width
## versicolor 18.69 -5.458 -8.707 14.24 -3.098
## virginica -23.84 -7.924 -15.371 23.66 15.135
##
## Std. Errors:
## (Intercept) Sepal.Length Sepal.Width Petal.Length Petal.Width
## versicolor 34.97 89.89 157.0 60.19 45.49
## virginica 35.77 89.91 157.1 60.47 45.93
##
## Residual Deviance: 11.9
## AIC: 31.9
# 对数似然
logLik(iris_nnet)
## 'log Lik.' -5.95 (df=10)
# 预测,返回概率值
iris_nnet_pred <- predict(iris_nnet, iris[, -5], type = "prob")
# 混淆矩阵
table(iris$Species, predict(iris_nnet, iris, type = "class"))
##
## setosa versicolor virginica
## setosa 50 0 0
## versicolor 0 49 1
## virginica 0 1 49
# 计算 AUC
pROC::auc(pROC::multiclass.roc(iris[, 5], iris_nnet_pred))
## Multi-class area under the curve: 0.999
2.2 带 L2 (岭)惩罚
在函数 LogisticRegression() 中设置 penalty="l2" 表示添加 L2 (岭)惩罚。在 R 语言中,带惩罚项的回归模型,常用 glmnet 包来拟合。函数 glmnet() 设置 alpha = 0 表示添加 L2 (岭)惩罚,设置参数 standardize = FALSE 表示在拟合模型前数据不做标准化,这和下面的 Python 代码对齐。Python 函数 LogisticRegression() 中的参数 C=1 的作用相当于 R 函数 glmnet() 中 penalty.factor = 1 。
当设置 penalty="l2" 时,默认的迭代次数(max_iter=100)不够用了,最大迭代次数设置为 1000 。
clf = LogisticRegression(
penalty="l2", random_state=0,
solver='lbfgs', C=1,
max_iter=1000).fit(X, y)
# 截距和系数
np.hstack((clf.intercept_.reshape(3, 1), clf.coef_))
## array([[ 9.85443729, -0.42468213, 0.96714132, -2.51552827,
## -1.08228622],
## [ 2.23270878, 0.53509422, -0.32079311, -0.20713148,
## -0.94316337],
## [-12.08714607, -0.11041209, -0.64634822, 2.72265975,
## 2.02544959]])
# 预测的准确度
clf.score(X, y)
## 0.9733333333333334
iris_glmnet <- glmnet(
x = iris[, -5], y = iris[, 5],
alpha = 0, penalty.factor = rep(1, 4),
standardize = FALSE, family = "multinomial"
)
# 截距和系数
do.call("rbind", lapply(coef(iris_glmnet, s = iris_glmnet$lambda[60]), function(x) t(x)))
## 3 x 5 sparse Matrix of class "dgCMatrix"
## (Intercept) Sepal.Length Sepal.Width Petal.Length Petal.Width
## s=3.162 0.94122 -0.058892 0.031107 -0.17026 -0.070467
## s=3.162 -0.04831 0.004154 -0.027634 0.03725 0.006863
## s=3.162 -0.89291 0.054737 -0.003473 0.13301 0.063604
# 预测
iris_pred_glmnet <- predict(iris_glmnet, newx = as.matrix(iris[, -5]), s = iris_glmnet$lambda[60], type = "class")
# 混淆矩阵
table(iris_pred_glmnet, iris$Species)
##
## iris_pred_glmnet setosa versicolor virginica
## setosa 50 3 0
## versicolor 0 20 1
## virginica 0 27 49
2.3 超参数寻优
常用交叉验证 CV 来调模型中超参数
# CV 选超参数 C
clf = LogisticRegressionCV(
penalty="l2", random_state=0,
solver='lbfgs',
max_iter=1000).fit(X, y)
# 截距和系数
np.hstack((clf.intercept_.reshape(3, 1), clf.coef_))
## array([[ 12.43840368, 0.14282955, 2.6600351 , -4.85048989,
## -2.50437649],
## [ 5.16311988, 0.90443367, 0.02425054, -0.50620054,
## -2.91559719],
## [-17.60152356, -1.04726322, -2.68428563, 5.35669043,
## 5.41997368]])
# 预测的准确度
clf.score(X, y)
## 0.98
CV 选更好的 C 之后,准确度提升了。
# CV 选超参数 lambda
iris_cv_glmnet <- cv.glmnet(
x = as.matrix(iris[, -5]),
y = iris[, 5], alpha = 0, penalty.factor = rep(1, 4),
standardize = FALSE, nlambda = 500,
family = "multinomial"
)
# 最佳的 lambda
best_lambda <- iris_cv_glmnet$lambda.min
best_lambda
## [1] 0.07653
# 最佳的 lambda 代入模型
iris_glmnet2 <- glmnet(
x = iris[, -5], y = iris[, 5],
lambda = best_lambda,
alpha = 0,
standardize = FALSE, family = "multinomial"
)
# 最佳模型的回归系数
do.call("rbind", lapply(coef(iris_glmnet2), function(x) t(x)))
## 3 x 5 sparse Matrix of class "dgCMatrix"
## (Intercept) Sepal.Length Sepal.Width Petal.Length Petal.Width
## s0 5.074 -0.29109 0.322824 -1.11902 -0.4543
## s0 1.511 0.05574 -0.320606 0.08215 -0.2038
## s0 -6.585 0.23535 -0.002218 1.03686 0.6582
# 预测
iris_pred_glmnet2 <- predict(iris_glmnet2, newx = as.matrix(iris[, -5]), type = "class")
# 最佳模型的预测结果
table(iris_pred_glmnet2, iris$Species)
##
## iris_pred_glmnet2 setosa versicolor virginica
## setosa 50 0 0
## versicolor 0 46 2
## virginica 0 4 48
预测准确度 1-6/150 = 0.96 。
2.4 随机森林
数据集比较简单,随机森林和梯度提升算法很容易过拟合的。
clf = RandomForestClassifier(max_depth=2, random_state=0).fit(X, y)
clf.score(X, y)
## 0.9666666666666667
library(randomForest) # 随机森林
iris_rf <- randomForest(
Species ~ ., data = iris,
importance = TRUE, proximity = TRUE
)
# 分类结果
print(iris_rf)
##
## Call:
## randomForest(formula = Species ~ ., data = iris, importance = TRUE, proximity = TRUE)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 2
##
## OOB estimate of error rate: 4%
## Confusion matrix:
## setosa versicolor virginica class.error
## setosa 50 0 0 0.00
## versicolor 0 47 3 0.06
## virginica 0 3 47 0.06
预测准确度 1-7/150 = 0.9533 。
2.5 梯度提升
梯度提升算法很容易过拟合的。
clf = GradientBoostingClassifier(n_estimators=100, learning_rate=1.0,
max_depth=1, random_state=0).fit(X, y)
clf.score(X, y)
## 1.0
library(xgboost)
iris_xgb <- xgboost(
x = iris[, -5], y = iris[, 5],
objective = "multi:softprob",
max_depth = 2, # 树的深度
learning_rate = 0.3 # 默认值
)
# 预测
iris_pred_xgb <- predict(object = iris_xgb, newdata = iris[, -5])
# 预测结果
head(iris_pred_xgb)
## setosa versicolor virginica
## 1 0.9978 0.001964 0.0001928
## 2 0.9980 0.001688 0.0003081
## 3 0.9981 0.001688 0.0001928
## 4 0.9981 0.001688 0.0002232
## 5 0.9978 0.001964 0.0001928
## 6 0.9967 0.003153 0.0001926
每一个样本划分到各个类的概率。如果以 0.9 为分割点,则概率大于 0.9 的视为样本归属于该类。
apply(iris_pred_xgb, 2, cut, breaks = c(0, 0.9, 1)) |>
xtabs(formula = ~. )
## , , virginica = (0,0.9]
##
## versicolor
## setosa (0,0.9] (0.9,1]
## (0,0.9] 7 47
## (0.9,1] 50 0
##
## , , virginica = (0.9,1]
##
## versicolor
## setosa (0,0.9] (0.9,1]
## (0,0.9] 46 0
## (0.9,1] 0 0
属于 virginica 鸢尾的有 46 个,属于 setosa 鸢尾的有 50 个,属于 versicolor 鸢尾的有 47 个,还有 7 个样本归类错误。
# 预测结果是类别,而不是概率分布
iris_pred = predict(iris_xgb, iris[, -5], type = "class")
table(iris$Species, iris_pred)
## iris_pred
## setosa versicolor virginica
## setosa 50 0 0
## versicolor 0 50 0
## virginica 0 0 50
实际上,XGBoost 可以选择最佳的分割点 0.8,使得分类完全符合实际数据。
注意:函数 xgboost() 目前不支持设置 objective = "multi:softmax" 。函数 xgb.train() 支持,示例见《现代应用统计》分类问题中的 集成学习。
3 运行环境
xfun::session_info(packages = c(
"nnet", "glmnet", "randomForest", "xgboost", "pROC"))
## R version 4.5.2 (2025-10-31)
## Platform: aarch64-apple-darwin20
## Running under: macOS Tahoe 26.0.1
##
## Locale: en_US.UTF-8 / en_US.UTF-8 / en_US.UTF-8 / C / en_US.UTF-8 / en_US.UTF-8
##
## Package version:
## codetools_0.2.20 data.table_1.17.99 foreach_1.5.2
## glmnet_4.1-10 graphics_4.5.2 grDevices_4.5.2
## grid_4.5.2 iterators_1.0.14 jsonlite_2.0.0
## lattice_0.22.7 Matrix_1.7.4 methods_4.5.2
## nnet_7.3-20 pROC_1.19.0.1 randomForest_4.7-1.2
## Rcpp_1.1.0 RcppEigen_0.3.4.0.2 shape_1.4.6.1
## splines_4.5.2 stats_4.5.2 survival_3.8.3
## utils_4.5.2 xgboost_3.1.0.1
reticulate::py_config()
## python: /opt/.virtualenvs/r-tensorflow/bin/python3
## libpython: /opt/homebrew/opt/python@3.13/Frameworks/Python.framework/Versions/3.13/lib/python3.13/config-3.13-darwin/libpython3.13.dylib
## pythonhome: /opt/.virtualenvs/r-tensorflow:/opt/.virtualenvs/r-tensorflow
## virtualenv: /opt/.virtualenvs/r-tensorflow/bin/activate_this.py
## version: 3.13.9 (main, Oct 14 2025, 13:52:31) [Clang 17.0.0 (clang-1700.3.19.1)]
## numpy: /opt/.virtualenvs/r-tensorflow/lib/python3.13/site-packages/numpy
## numpy_version: 2.3.2
##
## NOTE: Python version was forced by RETICULATE_PYTHON
4 参考文献
- sklearn 文档 LogisticRegression