Skip to content

01-模型评估与优化

本章讲解机器学习模型的评估指标、交叉验证、超参数调优和模型保存。


实际场景

你训练了一个垃圾邮件分类模型,初始准确率只有 70%。你想提高模型的准确率:首先需要知道如何正确评估模型性能(准确率、精确率、召回率哪个更重要?),然后通过调整模型参数(如决策树深度、随机森林的树数量)来提升性能。这就是模型评估与优化的核心任务。

分类评估指标

混淆矩阵

              预测值
          正例 (1)  负例 (0)
真实  正例   TP       FN
值    负例   FP       TN

评估指标

python
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    classification_report
)
from numpy.typing import NDArray

y_true: NDArray
y_pred: NDArray
print(f"准确率:{accuracy_score(y_true, y_pred):.4f}")
print(f"精确率:{precision_score(y_true, y_pred):.4f}")
print(f"召回率:{recall_score(y_true, y_pred):.4f}")
print(f"F1 分数:{f1_score(y_true, y_pred):.4f}")

ROC 曲线与 AUC

python
from sklearn.metrics import roc_curve, roc_auc_score
from numpy.typing import NDArray

y_true: NDArray
y_proba: NDArray
fpr: NDArray
tpr: NDArray
thresholds: NDArray
fpr, tpr, thresholds = roc_curve(y_true, y_proba)
auc_score: float = roc_auc_score(y_true, y_proba)

回归评估指标

python
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import numpy as np
from numpy.typing import NDArray

y_true: NDArray
y_pred: NDArray
mse: float = mean_squared_error(y_true, y_pred)
rmse: float = np.sqrt(mse)
mae: float = mean_absolute_error(y_true, y_pred)
r2: float = r2_score(y_true, y_pred)

交叉验证

python
from sklearn.model_selection import cross_val_score
from numpy.typing import NDArray

scores: NDArray = cross_val_score(model, X, y, cv=5)
print(f"平均得分:{scores.mean():.4f}")

超参数调优

网格搜索

python
from sklearn.model_selection import GridSearchCV

param_grid: dict[str, list[int]] = {
    'n_estimators': [50, 100, 200],
    'max_depth': [5, 10, 20]
}

grid_search: GridSearchCV = GridSearchCV(model, param_grid, cv=5)
grid_search.fit(X_train, y_train)

print(f"最佳参数:{grid_search.best_params_}")

随机搜索

python
from sklearn.model_selection import RandomizedSearchCV

random_search: RandomizedSearchCV = RandomizedSearchCV(model, param_distributions, n_iter=50, cv=5)
random_search.fit(X_train, y_train)

模型保存

python
import joblib

joblib.dump(model, 'model.joblib')

loaded_model = joblib.load('model.joblib')

本章小结

┌─────────────────────────────────────────────────────────────┐
│                    模型评估与优化 知识要点                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   分类评估:                                                 │
│   ✓ 混淆矩阵、准确率、精确率、召回率、F1                     │
│   ✓ ROC 曲线与 AUC                                          │
│                                                             │
│   回归评估:                                                 │
│   ✓ MSE、RMSE、MAE、R²                                      │
│                                                             │
│   超参数调优:                                               │
│   ✓ 网格搜索:穷举所有组合                                  │
│   ✓ 随机搜索:随机采样、效率高                              │
│                                                             │
│   模型保存:                                                 │
│   ✓ joblib.dump() / joblib.load()                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘