论文136:Top-down attention-based multiple instance learning for whole slide image analysis (MICCAI‘25)
临床上,病理医生的诊断通常遵循先整体再聚焦的认知过程,即先观察全局组织形态,再聚焦任务相关区域。本文受到这种自上而下注意力机制的启发,提出一种能够自动执行类似诊断流程的 MIL 框架。
1 要点
题目:Top-down attention-based multiple instance learning for whole slide image analysis
代码:https://github.com/agentdr1/TDA_MIL
研究动机:
目前,MIL是WSI分析中的主流框架,但仍面临两大挑战:
- 实例注意力的局限性:传统的实例级注意力只关注单个patch的重要性,难以捕获它们之间的上下文关系;
- 自注意力的不足:虽然自注意力能够建模全局依赖,但它学习到的往往是任务无关的普遍特征,而非与特定病理任务(如肿瘤分级或分子状态预测)相关的区域。
研究目的:
临床上,病理医生的诊断通常遵循先整体再聚焦的认知过程,即先观察全局组织形态,再聚焦任务相关区域。本文受到这种自上而下注意力机制的启发,提出一种能够自动执行类似诊断流程的 MIL 框架。
- 第一阶段:建模全局上下文信息;
- 第二阶段:基于任务信号聚焦于关键区域;
关键技术方法:
- 双阶段推理结构:
- 自底向上推理:利用自注意力在所有patch间建立上下文联系,以得到初步的全局表示;
- 自顶向下推理:基于任务相关性筛选出最关键的patch,并将这些任务相关特征重新注入自注意力层中,实现再聚焦式信息整合;
- 特征选择模块:
- 任务相关性计算:计算每个patch与可学习任务toekn T T T的余弦相似度,得到其与任务的相关性分数;
- 通道重缩放:利用线性变换矩阵 C C C对选中的特征通道加权,强调与任务相关的语义维度。输出的patch权重被用于筛除不相关区域,实现模型的可解释特征聚焦。
- 自顶向下注意力注入:将筛选后的任务相关特征加入到自注意力的Value向量中,这样模型在第二次注意力聚合时,重点关注任务特征强化的区域,实现更精准的分类。
数据集:
| 数据集名称 | 样本数量(N) | 描述 |
|---|---|---|
| CAMELYON17 | 500(182阳性 / 318阴性) | 乳腺癌淋巴结转移检测 |
| TCGA-CRC、CPTAC-COAD、TCGA-STAD、TCGA-UCEC | CRC: 457、COAD: 105、STAD: 361、UCEC: 545 | 结直肠癌、胃癌、子宫癌MSI状态预测 |
| TCGA-BRCA、BCNB | BRCA: 693、BCNB: 1,058 | 乳腺癌HER2分子状态分类 |
注:大概看看就行,作为对比算法即可
2 方法
图1中展示了TDA-MIL的整体流程。该框架由两个主要阶段组成:
- 特征压缩阶段;
- 特征聚合阶段:结合任务特定的自顶向下注意力策略以及特征选择模块。

2.1 特征压缩阶段
在标准预处理流程之后,首先对WSI进行背景去除。具体地,利用Otsu阈值法分割出组织区域,然后在20×放大倍率下将切片划分为若干不重叠的小图像块 p i ∈ R 512 × 512 × 3 p_i \in \mathbb{R}^{512 \times 512 \times 3} pi∈R512×512×3。随后,每个图像块都会输入到病理学基础模型中进行编码,并以离线方式提取patch级特征嵌入。经过这一阶段,WSI被压缩为一个可处理的patch特征集合 { x i } i = 1 n ∈ R n × D \{x_i\}_{i=1}^{n} \in \mathbb{R}^{n \times D} {xi}i=1n∈Rn×D,其中 D D D表示基础模型输出的特征维度。
2.2 TDA-MIL 模型结构**
在在线阶段,TDA-MIL通过两个连续的推理步骤来处理这些特征:
- 给定特征序列 { x i } i = 1 n \{x_i\}_{i=1}^{n} {xi}i=1n,首先将每个特征从维度 D D D映射到较低的潜在维度 d d d,这一过程通过一个全连接层实现;
- 在序列前拼接一个分类token CLS ∈ R 1 × D \text{CLS} \in \mathbb{R}^{1 \times D} CLS∈R1×D。在后续计算中,CLS 被视为与其他 token一致的输入,为简化符号仍记整个序列长度为 n n n。此外,该输入序列被记为自底向上序列 { x i , B U } i = 1 n \{x_{i,BU}\}_{i=1}^{n} {xi,BU}i=1n。
2.2.1 推理步骤I
该阶段使用多层自注意力模块对输入序列进行上下文建模:
- 自注意力定义为:
S A ( Q , K , V ) = softmax ( Q K T d k ) V , (1) SA(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V, \tag{1} SA(Q,K,V)=softmax(dkQKT)V,(1)其中 Q , K , V Q, K, V Q,K,V分别表示查询、键,以及值,并由输入特征 x x x线性映射得到:
Q = W Q ⋅ x , K = W K ⋅ x , V = W V ⋅ x , (2) Q = W_Q \cdot x, \quad K = W_K \cdot x, \quad V = W_V \cdot x, \tag{2} Q=WQ⋅x,K=WK⋅x,V=WV⋅x,(2)其中 W Q , W K ∈ R d × d k , W V ∈ R d × d v W_Q, W_K \in \mathbb{R}^{d \times d_k}, W_V \in \mathbb{R}^{d \times d_v} WQ,WK∈Rd×dk,WV∈Rd×dv为可学习参数。 - 多头自注意力在 h h h个并行头中执行注意力计算,然后拼接结果并进行线性投影:
M S A = concat ( head 1 , . . . , head h ) ⋅ W O , MSA = \text{concat}(\text{head}_1, ..., \text{head}_h) \cdot W_O, MSA=concat(head1,...,headh)⋅WO, head j = S A ( Q ( j ) , K ( j ) , V ( j ) ) , j ∈ 1 , . . . , h , \text{head}_j = SA(Q^{(j)}, K^{(j)}, V^{(j)}), \quad j \in {1, ..., h}, headj=SA(Q(j),K(j),V(j)),j∈1,...,h,其中 W O ∈ R h d v × d W_O \in \mathbb{R}^{hd_v \times d} WO∈Rhdv×d也是可学习参数。 - 每一层自注意力结构包括以下顺序:
- 层归一化;
- 注意力模块;
- 多层感知机。
2.2.2 特征选择模块
该模块对输出序列${x_i}_{i=1}^{n}$进行任务相关性筛选和通道重标定,如图2。该模块的主要作用是:
- 在patch维度,即 n n n方向上选择最相关的图像块;
- 在通道维度,即 d d d方向上重新加权特征通道。

具体步骤如下:
- 计算任务相关性得分:将每个patch特征与一个可学习的任务相关性token T ∈ R d T \in \mathbb{R}^d T∈Rd相乘得到余弦相似度,并以此作为任务相关性权重:
x ^ i , B U = clamp ( sim ( x i , B U , T ) ) , (3) \hat{x}_{i,BU} = \text{clamp}(\text{sim}(x_{i,BU}, T)), \tag{3} x^i,BU=clamp(sim(xi,BU,T)),(3)其中 clamp ( ⋅ ) \text{clamp}(\cdot) clamp(⋅)将结果限制在 [ 0 , 1 ] [0, 1] [0,1]区间内;
- 特征通道重缩放:对每个patch特征执行线性通道加权:
x i , T D = C ⋅ x ^ i , B U ⋅ x i , B U , i ∈ 1 , . . . , n , (4) x_{i,TD} = C \cdot \hat{x}_{i,BU} \cdot x_{i,BU}, \quad i \in {1, ..., n}, \tag{4} xi,TD=C⋅x^i,BU⋅xi,BU,i∈1,...,n,(4)其中 C ∈ R d × d C \in \mathbb{R}^{d \times d} C∈Rd×d为可学习的通道变换矩阵。如图2所示,参数 T T T充当任务嵌入,以加权方式过滤掉不相关的图像块,矩阵 C C C则在通道级别上进行任务特征强化。 - 一个多层感知机对得到的特征序列 x i , T D i = 1 n {x_{i,TD}}_{i=1}^{n} xi,TDi=1n进行解码,并将其输入到下一阶段的推理过程。
2.2.3 推理步骤II
该阶段用于将选择的任务相关patch x i , T D x_{i,TD} xi,TD重新输入到自注意力模块,以执行第二次推理:
- 将任务特征输入到自注意力结构的Value向量中:
V = W V ⋅ ( x B U + x T D ) , V = W_V \cdot (x_{BU} + x_{TD}), V=WV⋅(xBU+xTD),其中 x B U x_{BU} xBU表示第一阶段的自底向上序列。这使模型在第二次注意力计算时,能够聚焦于经任务特征增强的区域,从而更有效地学习任务特定信号。 - 完成自注意力操作后,分类token将被输入到一个全连接层中,映射为最终的类别预测结果:
r = W F C ⋅ C L S ∈ R c , r = W_{FC} \cdot CLS \in \mathbb{R}^c, r=WFC⋅CLS∈Rc,其中 c c c为类别数。
更多推荐


所有评论(0)