《Python AI入门》第5章 分类的智慧——逻辑回归与泰坦尼克号生存预测
本章我们将挑战数据科学领域最著名的入门竞赛——泰坦尼克号生存预测。我们将穿越回1912年,通过分析乘客的名单,训练一个AI模型来寻找逃生背后的规律。在这个过程中,你将掌握处理“非数值型数据”的技巧,并学会如何评估一个分类模型的优劣
章节导语
“如果说上一章的线性回归是在教计算机‘算命’(预测房价是多少),那么本章我们要教计算机‘做选择题’(预测是A还是B)。”
在现实世界中,我们面临的很多问题并不是“有多少”,而是“是不是”。
-
这封邮件是垃圾邮件还是正常邮件?
-
这张CT片子是良性还是恶性?
-
这个用户点击广告的概率是高还是低?
这就是机器学习的另一大支柱——分类(Classification)。
本章我们将挑战数据科学领域最著名的入门竞赛——泰坦尼克号生存预测。我们将穿越回1912年,通过分析乘客的名单,训练一个AI模型来寻找逃生背后的规律。在这个过程中,你将掌握处理“非数值型数据”的技巧,并学会如何评估一个分类模型的优劣。
5.1 学习目标
在学完本章后,你将能够:
-
理解分类与回归的区别:明白为什么不能用线性回归来做分类。
-
掌握逻辑回归(Logistic Regression):理解Sigmoid函数如何将输出压缩到0和1之间(概率)。
-
处理类别特征:学会使用独热编码(One-Hot Encoding)将“男/女”这样的文字转化为计算机能懂的数字。
-
工程思维进阶:理解数据清洗中对缺失值(Missing Values)的填充策略。
-
多维评估:不再只看“准确率(Accuracy)”,学会看混淆矩阵(Confusion Matrix)。
-
实战落地:构建一个预测泰坦尼克号乘客生存概率的完整模型。
5.2 从预测数值到预测概率
5.2.1 为什么要用“逻辑回归”?
首先要澄清一个名字上的误会:逻辑回归(Logistic Regression)虽然名字里有“回归”二字,但它实际上是一个标准的“分类算法”。
在第4章中,线性回归输出的是一个具体的数字(比如房价500万)。如果你用它来预测“是否患癌”(0代表健康,1代表患癌),它可能会算出一个0.8,也可能算出一个1.5,甚至负数。这让我们很难办:1.5代表什么?超级患癌?
我们需要一个函数,能把任意大小的输入,强行压缩到 0 到 1 之间。这样,输出结果就变成了“概率”。 比如:输出0.8,代表有80%的概率是1类。
5.2.2 神奇的S型曲线:Sigmoid函数
逻辑回归的核心就在于这个数学魔法——Sigmoid函数。
别被公式吓到。你只需要记住它的形状:像一个拉长的字母 "S"。
-
不管输入
是几千还是几万,输出
永远无限逼近1。
-
不管输入
是负几千,输出
永远无限逼近0。
-
当输入
时,输出
(这就是我们的决策边界)。
机器的决策逻辑:
-
模型算出概率
。
-
如果
,判断为“是”(1)。
-
如果
,判断为“否”(0)。
5.3 实战案例:泰坦尼克号生存预测
这个案例是Kaggle(全球最大的数据科学竞赛平台)的“Hello World”项目。我们的任务是根据乘客的年龄、性别、船舱等级等信息,预测他们是否能在沉船事故中幸存。
5.3.1 第一步:加载数据与初步侦察
为了方便起见,我们直接使用 Seaborn 内置的泰坦尼克数据集(这与 Kaggle 的数据结构一致)。
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
# 1. 加载数据
# 如果下载慢,可配置代理或手动下载 csv 读取
df = sns.load_dataset('titanic')
# 2. 查看前几行
print("--- 乘客名单预览 ---")
print(df.head())
# 3. 检查数据完整性
print("\n--- 数据体检 ---")
print(df.info())
【数据字典解读】
-
survived: 0=遇难, 1=生还(这是我们要预测的标签 y) -
pclass: 船舱等级 (1=头等舱, 2=二等舱, 3=三等舱) -
sex: 性别 -
age: 年龄 -
sibsp: 兄弟姐妹/配偶数量 -
parch: 父母/子女数量 -
fare: 票价 -
embarked: 登船港口 (S, C, Q)
初步发现的问题: 运行 df.info() 后,你会发现 age 列只有约700个非空值(总共891行),这意味着有近200个人的年龄是未知的。此外,deck(甲板号)缺失极其严重。
5.3.2 第二步:探索性数据分析 (EDA)
在做模型前,我们先用统计图来验证一下电影里的桥段:“女士和孩子优先”。
# 设置绘图风格
sns.set_style("whitegrid")
# 1. 看看性别对生存率的影响
plt.figure(figsize=(6, 4))
sns.barplot(data=df, x='sex', y='survived', palette='pastel')
plt.title("性别与生存率的关系")
plt.show()
# 2. 看看船舱等级对生存率的影响
plt.figure(figsize=(6, 4))
sns.barplot(data=df, x='pclass', y='survived', palette='muted')
plt.title("船舱等级与生存率的关系")
plt.show()
洞察:
-
女性的生存率(约74%)远高于男性(约18%)。
-
头等舱的生存率(约60%)远高于三等舱(约24%)。
-
结论:
sex和pclass是极其重要的特征,模型必须包含它们。
5.3.3 第三步:数据清洗与特征工程
这是本章最核心的工程化环节。计算机看不懂 "male", "female" 这种单词,也处理不了空值 NaN。
# --- 1. 处理缺失值 ---
# 策略:年龄缺失,用全船人的平均年龄(或中位数)填充
# 现实中可能用更复杂的方法,但作为基准模型,均值填充足够了
df['age'].fillna(df['age'].median(), inplace=True)
# 策略:登船港口缺失很少(只有2个),直接填充众数(出现最多的港口)
df['embarked'].fillna(df['embarked'].mode()[0], inplace=True)
# 策略:deck(甲板)缺失太多,直接删掉这一列,不参与训练
# embark_town, alive 等列是重复信息,也删掉
drop_columns = ['deck', 'embark_town', 'alive', 'who', 'adult_male', 'class']
df.drop(columns=drop_columns, inplace=True, errors='ignore')
# --- 2. 独热编码 (One-Hot Encoding) ---
# 计算机不懂 "male"/"female"。
# 我们可以把 "sex" 列变成两列:"sex_male" (是男性吗?) 和 "sex_female" (是女性吗?)
# drop_first=True 的意思是:只保留一列。如果是男性(1),那就不是女性(0)。避免冗余。
df = pd.get_dummies(df, columns=['sex', 'embarked'], drop_first=True)
print("\n--- 清洗后的数据 ---")
print(df.head())
【专业提示】为什么不用
0, 1, 2代表港口? 如果如果你把 S港口变成1,C港口变成2,Q港口变成3,模型会误以为 C(2) 比 S(1) “大”,或者 Q(3) 是 S(1) 的三倍。但这三个港口是平等的类别关系,没有大小之分。 独热编码(One-Hot) 就是为了解决这个问题:它把每个类别都变成一个独立的开关。
5.3.4 第四步:构建并训练模型
数据准备好了,剩下的就是标准流程了(Fit-Predict)。
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
# 1. 准备数据
X = df.drop('survived', axis=1) # 特征
y = df['survived'] # 标签
# 2. 切分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 3. 训练模型
# max_iter=1000 是为了防止数据没归一化导致模型很难收敛(报错提示)
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)
print("模型训练完毕!")
5.3.5 第五步:模型评估——不要被准确率骗了
先看一眼基础的准确率(Accuracy)。
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
# 预测测试集
y_pred = model.predict(X_test)
# 计算准确率
acc = accuracy_score(y_test, y_pred)
print(f"模型准确率: {acc:.2f}")
通常结果在 0.80 (80%) 左右。这是不是意味着模型很完美? 别急。假设船上有100个人,只有1个人幸存。如果我的模型是个“傻子”,闭着眼全部预测“死亡”,它的准确率依然高达99%!但它没能找出那个幸存者。
对于分类问题,我们更关心混淆矩阵(Confusion Matrix)。
# 绘制混淆矩阵
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('混淆矩阵')
plt.ylabel('真实标签')
plt.xlabel('预测标签')
plt.show()
如何看图:
-
左上 (TN):真实死亡,预测死亡(预测对了)。
-
右下 (TP):真实存活,预测存活(预测对了)。
-
右上 (FP):真实死亡,却预测存活(误报)。
-
左下 (FN):真实存活,却预测死亡(漏报)。
【小白避坑】 在医疗AI(预测癌症)中,我们最怕FN(漏报)——病人有病你却说没事。 在垃圾邮件识别中,我们最怕FP(误报)——把老板的重要邮件当成垃圾扔了。 根据业务场景的不同,我们要关注不同的指标(Recall 或 Precision)。
最后,打印一份详细的体检报告:
print(classification_report(y_test, y_pred))
5.4 章节小结
本章我们从“预测数值”跨越到了“预测类别”,掌握了分类任务的核心逻辑。
-
逻辑回归:用Sigmoid函数把数值变成概率。
-
数据清洗:对于“脏”数据,我们学会了填充年龄缺失值。
-
特征工程:通过
get_dummies进行独热编码,解决了文本特征无法计算的问题。 -
模型评估:不仅看准确率,更学会了通过混淆矩阵看透模型的“偏科”问题。
工程化的思考: 我们在处理数据时,删除了 Name(姓名)列。但仔细想想,姓名真的没用吗?"Mr.", "Mrs.", "Master"(少爷)这些称谓其实隐藏了年龄和身份地位的信息。一个优秀的AI工程师,会想办法从这些看似无用的文本中提取出黄金特征。
下一章,我们将深入“特征工程”的腹地,学习如何像侦探一样,挖掘出那些让模型准确率从80%提升到90%的秘密线索。
5.5 思考与扩展练习
-
特征重要性分析: 逻辑回归模型也有系数(Coefficients)。请尝试打印
model.coef_,并结合列名X.columns,看看哪个特征的系数最大(正数代表利于生存,负数代表不利)。 提示:你可以用pd.Series(model.coef_[0], index=X.columns).sort_values().plot(kind='barh')画出来。 -
模型对比: 尝试把
LogisticRegression换成DecisionTreeClassifier(决策树分类器)。你会发现代码几乎不用变(除了初始化那一行),这就是 Scikit-learn 接口统一的魅力。看看决策树的效果会不会更好? -
生存预测器: 写一个小程序,让用户输入自己的性别、年龄、票价和舱位,调用训练好的模型,计算一下如果用户在泰坦尼克号上,生存概率是多少?
更多推荐




所有评论(0)