章节导语

“如果说上一章的线性回归是在教计算机‘算命’(预测房价是多少),那么本章我们要教计算机‘做选择题’(预测是A还是B)。”

在现实世界中,我们面临的很多问题并不是“有多少”,而是“是不是”。

  • 这封邮件垃圾邮件还是正常邮件?

  • 这张CT片子良性还是恶性?

  • 这个用户点击广告的概率还是低?

这就是机器学习的另一大支柱——分类(Classification)

本章我们将挑战数据科学领域最著名的入门竞赛——泰坦尼克号生存预测。我们将穿越回1912年,通过分析乘客的名单,训练一个AI模型来寻找逃生背后的规律。在这个过程中,你将掌握处理“非数值型数据”的技巧,并学会如何评估一个分类模型的优劣。


5.1 学习目标

在学完本章后,你将能够:

  1. 理解分类与回归的区别:明白为什么不能用线性回归来做分类。

  2. 掌握逻辑回归(Logistic Regression):理解Sigmoid函数如何将输出压缩到0和1之间(概率)。

  3. 处理类别特征:学会使用独热编码(One-Hot Encoding)将“男/女”这样的文字转化为计算机能懂的数字。

  4. 工程思维进阶:理解数据清洗中对缺失值(Missing Values)的填充策略。

  5. 多维评估:不再只看“准确率(Accuracy)”,学会看混淆矩阵(Confusion Matrix)

  6. 实战落地:构建一个预测泰坦尼克号乘客生存概率的完整模型。


5.2 从预测数值到预测概率

5.2.1 为什么要用“逻辑回归”?

首先要澄清一个名字上的误会:逻辑回归(Logistic Regression)虽然名字里有“回归”二字,但它实际上是一个标准的“分类算法”。

在第4章中,线性回归输出的是一个具体的数字(比如房价500万)。如果你用它来预测“是否患癌”(0代表健康,1代表患癌),它可能会算出一个0.8,也可能算出一个1.5,甚至负数。这让我们很难办:1.5代表什么?超级患癌?

我们需要一个函数,能把任意大小的输入,强行压缩到 01 之间。这样,输出结果就变成了“概率”。 比如:输出0.8,代表有80%的概率是1类。

5.2.2 神奇的S型曲线:Sigmoid函数

逻辑回归的核心就在于这个数学魔法——Sigmoid函数

S(x) = \frac{1}{1 + e^{-x}}

别被公式吓到。你只需要记住它的形状:像一个拉长的字母 "S"

  • 不管输入 x 是几千还是几万,输出 y 永远无限逼近1。

  • 不管输入 x 是负几千,输出 y 永远无限逼近0。

  • 当输入 x=0 时,输出 y=0.5(这就是我们的决策边界)。

机器的决策逻辑:

  1. 模型算出概率p

  2. 如果 p>0.5,判断为“是”(1)。

  3. 如果 p\leq 0.5,判断为“否”(0)。


5.3 实战案例:泰坦尼克号生存预测

这个案例是Kaggle(全球最大的数据科学竞赛平台)的“Hello World”项目。我们的任务是根据乘客的年龄、性别、船舱等级等信息,预测他们是否能在沉船事故中幸存。

5.3.1 第一步:加载数据与初步侦察

为了方便起见,我们直接使用 Seaborn 内置的泰坦尼克数据集(这与 Kaggle 的数据结构一致)。

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

# 1. 加载数据
# 如果下载慢,可配置代理或手动下载 csv 读取
df = sns.load_dataset('titanic')

# 2. 查看前几行
print("--- 乘客名单预览 ---")
print(df.head())

# 3. 检查数据完整性
print("\n--- 数据体检 ---")
print(df.info())

【数据字典解读】

  • survived: 0=遇难, 1=生还(这是我们要预测的标签 y

  • pclass: 船舱等级 (1=头等舱, 2=二等舱, 3=三等舱)

  • sex: 性别

  • age: 年龄

  • sibsp: 兄弟姐妹/配偶数量

  • parch: 父母/子女数量

  • fare: 票价

  • embarked: 登船港口 (S, C, Q)

初步发现的问题: 运行 df.info() 后,你会发现 age 列只有约700个非空值(总共891行),这意味着有近200个人的年龄是未知的。此外,deck(甲板号)缺失极其严重。

5.3.2 第二步:探索性数据分析 (EDA)

在做模型前,我们先用统计图来验证一下电影里的桥段:“女士和孩子优先”。

# 设置绘图风格
sns.set_style("whitegrid")
​
# 1. 看看性别对生存率的影响
plt.figure(figsize=(6, 4))
sns.barplot(data=df, x='sex', y='survived', palette='pastel')
plt.title("性别与生存率的关系")
plt.show()
​
# 2. 看看船舱等级对生存率的影响
plt.figure(figsize=(6, 4))
sns.barplot(data=df, x='pclass', y='survived', palette='muted')
plt.title("船舱等级与生存率的关系")
plt.show()

洞察

  • 女性的生存率(约74%)远高于男性(约18%)。

  • 头等舱的生存率(约60%)远高于三等舱(约24%)。

  • 结论sexpclass 是极其重要的特征,模型必须包含它们。

5.3.3 第三步:数据清洗与特征工程

这是本章最核心的工程化环节。计算机看不懂 "male", "female" 这种单词,也处理不了空值 NaN

# --- 1. 处理缺失值 ---
​
# 策略:年龄缺失,用全船人的平均年龄(或中位数)填充
# 现实中可能用更复杂的方法,但作为基准模型,均值填充足够了
df['age'].fillna(df['age'].median(), inplace=True)
​
# 策略:登船港口缺失很少(只有2个),直接填充众数(出现最多的港口)
df['embarked'].fillna(df['embarked'].mode()[0], inplace=True)
​
# 策略:deck(甲板)缺失太多,直接删掉这一列,不参与训练
# embark_town, alive 等列是重复信息,也删掉
drop_columns = ['deck', 'embark_town', 'alive', 'who', 'adult_male', 'class']
df.drop(columns=drop_columns, inplace=True, errors='ignore')
​
# --- 2. 独热编码 (One-Hot Encoding) ---
​
# 计算机不懂 "male"/"female"。
# 我们可以把 "sex" 列变成两列:"sex_male" (是男性吗?) 和 "sex_female" (是女性吗?)
# drop_first=True 的意思是:只保留一列。如果是男性(1),那就不是女性(0)。避免冗余。
​
df = pd.get_dummies(df, columns=['sex', 'embarked'], drop_first=True)
​
print("\n--- 清洗后的数据 ---")
print(df.head())

【专业提示】为什么不用 0, 1, 2 代表港口? 如果如果你把 S港口变成1,C港口变成2,Q港口变成3,模型会误以为 C(2) 比 S(1) “大”,或者 Q(3) 是 S(1) 的三倍。但这三个港口是平等的类别关系,没有大小之分。 独热编码(One-Hot) 就是为了解决这个问题:它把每个类别都变成一个独立的开关。

5.3.4 第四步:构建并训练模型

数据准备好了,剩下的就是标准流程了(Fit-Predict)。

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
​
# 1. 准备数据
X = df.drop('survived', axis=1) # 特征
y = df['survived']              # 标签
​
# 2. 切分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
​
# 3. 训练模型
# max_iter=1000 是为了防止数据没归一化导致模型很难收敛(报错提示)
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)
​
print("模型训练完毕!")

5.3.5 第五步:模型评估——不要被准确率骗了

先看一眼基础的准确率(Accuracy)。

from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
​
# 预测测试集
y_pred = model.predict(X_test)
​
# 计算准确率
acc = accuracy_score(y_test, y_pred)
print(f"模型准确率: {acc:.2f}")

通常结果在 0.80 (80%) 左右。这是不是意味着模型很完美? 别急。假设船上有100个人,只有1个人幸存。如果我的模型是个“傻子”,闭着眼全部预测“死亡”,它的准确率依然高达99%!但它没能找出那个幸存者。

对于分类问题,我们更关心混淆矩阵(Confusion Matrix)

# 绘制混淆矩阵
cm = confusion_matrix(y_test, y_pred)
​
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('混淆矩阵')
plt.ylabel('真实标签')
plt.xlabel('预测标签')
plt.show()

如何看图:

  • 左上 (TN):真实死亡,预测死亡(预测对了)。

  • 右下 (TP):真实存活,预测存活(预测对了)。

  • 右上 (FP):真实死亡,却预测存活(误报)。

  • 左下 (FN):真实存活,却预测死亡(漏报)。

【小白避坑】 在医疗AI(预测癌症)中,我们最怕FN(漏报)——病人有病你却说没事。 在垃圾邮件识别中,我们最怕FP(误报)——把老板的重要邮件当成垃圾扔了。 根据业务场景的不同,我们要关注不同的指标(Recall 或 Precision)。

最后,打印一份详细的体检报告:

print(classification_report(y_test, y_pred))

5.4 章节小结

本章我们从“预测数值”跨越到了“预测类别”,掌握了分类任务的核心逻辑。

  1. 逻辑回归:用Sigmoid函数把数值变成概率。

  2. 数据清洗:对于“脏”数据,我们学会了填充年龄缺失值。

  3. 特征工程:通过get_dummies进行独热编码,解决了文本特征无法计算的问题。

  4. 模型评估:不仅看准确率,更学会了通过混淆矩阵看透模型的“偏科”问题。

工程化的思考: 我们在处理数据时,删除了 Name(姓名)列。但仔细想想,姓名真的没用吗?"Mr.", "Mrs.", "Master"(少爷)这些称谓其实隐藏了年龄和身份地位的信息。一个优秀的AI工程师,会想办法从这些看似无用的文本中提取出黄金特征。

下一章,我们将深入“特征工程”的腹地,学习如何像侦探一样,挖掘出那些让模型准确率从80%提升到90%的秘密线索。


5.5 思考与扩展练习

  1. 特征重要性分析: 逻辑回归模型也有系数(Coefficients)。请尝试打印 model.coef_,并结合列名 X.columns,看看哪个特征的系数最大(正数代表利于生存,负数代表不利)。 提示:你可以用 pd.Series(model.coef_[0], index=X.columns).sort_values().plot(kind='barh') 画出来。

  2. 模型对比: 尝试把 LogisticRegression 换成 DecisionTreeClassifier(决策树分类器)。你会发现代码几乎不用变(除了初始化那一行),这就是 Scikit-learn 接口统一的魅力。看看决策树的效果会不会更好?

  3. 生存预测器: 写一个小程序,让用户输入自己的性别、年龄、票价和舱位,调用训练好的模型,计算一下如果用户在泰坦尼克号上,生存概率是多少?

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐