1 要点

题目:Top-down attention-based multiple instance learning for whole slide image analysis

代码:https://github.com/agentdr1/TDA_MIL

研究动机:
目前,MIL是WSI分析中的主流框架,但仍面临两大挑战:

  1. 实例注意力的局限性:传统的实例级注意力只关注单个patch的重要性,难以捕获它们之间的上下文关系;
  2. 自注意力的不足:虽然自注意力能够建模全局依赖,但它学习到的往往是任务无关的普遍特征,而非与特定病理任务(如肿瘤分级或分子状态预测)相关的区域。

研究目的:
临床上,病理医生的诊断通常遵循先整体再聚焦的认知过程,即先观察全局组织形态,再聚焦任务相关区域。本文受到这种自上而下注意力机制的启发,提出一种能够自动执行类似诊断流程的 MIL 框架。

  1. 第一阶段:建模全局上下文信息;
  2. 第二阶段:基于任务信号聚焦于关键区域;

关键技术方法:

  1. 双阶段推理结构
    • 自底向上推理:利用自注意力在所有patch间建立上下文联系,以得到初步的全局表示;
    • 自顶向下推理:基于任务相关性筛选出最关键的patch,并将这些任务相关特征重新注入自注意力层中,实现再聚焦式信息整合;
  2. 特征选择模块
    • 任务相关性计算:计算每个patch与可学习任务toekn T T T的余弦相似度,得到其与任务的相关性分数;
    • 通道重缩放:利用线性变换矩阵 C C C对选中的特征通道加权,强调与任务相关的语义维度。输出的patch权重被用于筛除不相关区域,实现模型的可解释特征聚焦。
  3. 自顶向下注意力注入:将筛选后的任务相关特征加入到自注意力的Value向量中,这样模型在第二次注意力聚合时,重点关注任务特征强化的区域,实现更精准的分类。

数据集

数据集名称 样本数量(N) 描述
CAMELYON17 500(182阳性 / 318阴性) 乳腺癌淋巴结转移检测
TCGA-CRCCPTAC-COADTCGA-STADTCGA-UCEC CRC: 457、COAD: 105、STAD: 361、UCEC: 545 结直肠癌、胃癌、子宫癌MSI状态预测
TCGA-BRCABCNB BRCA: 693、BCNB: 1,058 乳腺癌HER2分子状态分类

注:大概看看就行,作为对比算法即可

2 方法

图1中展示了TDA-MIL的整体流程。该框架由两个主要阶段组成:

  1. 特征压缩阶段;
  2. 特征聚合阶段:结合任务特定的自顶向下注意力策略以及特征选择模块。
图1:TDA-MIL总览

2.1 特征压缩阶段

在标准预处理流程之后,首先对WSI进行背景去除。具体地,利用Otsu阈值法分割出组织区域,然后在20×放大倍率下将切片划分为若干不重叠的小图像块 p i ∈ R 512 × 512 × 3 p_i \in \mathbb{R}^{512 \times 512 \times 3} piR512×512×3。随后,每个图像块都会输入到病理学基础模型中进行编码,并以离线方式提取patch级特征嵌入。经过这一阶段,WSI被压缩为一个可处理的patch特征集合 { x i } i = 1 n ∈ R n × D \{x_i\}_{i=1}^{n} \in \mathbb{R}^{n \times D} {xi}i=1nRn×D,其中 D D D表示基础模型输出的特征维度。

2.2 TDA-MIL 模型结构**

在在线阶段,TDA-MIL通过两个连续的推理步骤来处理这些特征:

  1. 给定特征序列 { x i } i = 1 n \{x_i\}_{i=1}^{n} {xi}i=1n,首先将每个特征从维度 D D D映射到较低的潜在维度 d d d,这一过程通过一个全连接层实现;
  2. 在序列前拼接一个分类token CLS ∈ R 1 × D \text{CLS} \in \mathbb{R}^{1 \times D} CLSR1×D。在后续计算中,CLS 被视为与其他 token一致的输入,为简化符号仍记整个序列长度为 n n n。此外,该输入序列被记为自底向上序列 { x i , B U } i = 1 n \{x_{i,BU}\}_{i=1}^{n} {xi,BU}i=1n

2.2.1 推理步骤I

该阶段使用多层自注意力模块对输入序列进行上下文建模:

  1. 自注意力定义为:
    S A ( Q , K , V ) = softmax ( Q K T d k ) V , (1) SA(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V, \tag{1} SA(Q,K,V)=softmax(dk QKT)V,(1)其中 Q , K , V Q, K, V Q,K,V分别表示查询、键,以及值,并由输入特征 x x x线性映射得到:
    Q = W Q ⋅ x , K = W K ⋅ x , V = W V ⋅ x , (2) Q = W_Q \cdot x, \quad K = W_K \cdot x, \quad V = W_V \cdot x, \tag{2} Q=WQx,K=WKx,V=WVx,(2)其中 W Q , W K ∈ R d × d k , W V ∈ R d × d v W_Q, W_K \in \mathbb{R}^{d \times d_k}, W_V \in \mathbb{R}^{d \times d_v} WQ,WKRd×dk,WVRd×dv为可学习参数。
  2. 多头自注意力 h h h个并行头中执行注意力计算,然后拼接结果并进行线性投影:
    M S A = concat ( head 1 , . . . , head h ) ⋅ W O , MSA = \text{concat}(\text{head}_1, ..., \text{head}_h) \cdot W_O, MSA=concat(head1,...,headh)WO, head j = S A ( Q ( j ) , K ( j ) , V ( j ) ) , j ∈ 1 , . . . , h , \text{head}_j = SA(Q^{(j)}, K^{(j)}, V^{(j)}), \quad j \in {1, ..., h}, headj=SA(Q(j),K(j),V(j)),j1,...,h,其中 W O ∈ R h d v × d W_O \in \mathbb{R}^{hd_v \times d} WORhdv×d也是可学习参数。
  3. 每一层自注意力结构包括以下顺序:
  • 层归一化;
  • 注意力模块;
  • 多层感知机。

2.2.2 特征选择模块

该模块对输出序列${x_i}_{i=1}^{n}$进行任务相关性筛选通道重标定,如图2。该模块的主要作用是:

  1. 在patch维度,即 n n n方向上选择最相关的图像块;
  2. 在通道维度,即 d d d方向上重新加权特征通道。

具体步骤如下:

  1. 计算任务相关性得分:将每个patch特征与一个可学习的任务相关性token T ∈ R d T \in \mathbb{R}^d TRd相乘得到余弦相似度,并以此作为任务相关性权重
    x ^ i , B U = clamp ( sim ( x i , B U , T ) ) , (3) \hat{x}_{i,BU} = \text{clamp}(\text{sim}(x_{i,BU}, T)), \tag{3} x^i,BU=clamp(sim(xi,BU,T)),(3)其中 clamp ( ⋅ ) \text{clamp}(\cdot) clamp()将结果限制在 [ 0 , 1 ] [0, 1] [0,1]区间内;
  1. 特征通道重缩放:对每个patch特征执行线性通道加权:
    x i , T D = C ⋅ x ^ i , B U ⋅ x i , B U , i ∈ 1 , . . . , n , (4) x_{i,TD} = C \cdot \hat{x}_{i,BU} \cdot x_{i,BU}, \quad i \in {1, ..., n}, \tag{4} xi,TD=Cx^i,BUxi,BU,i1,...,n,(4)其中 C ∈ R d × d C \in \mathbb{R}^{d \times d} CRd×d为可学习的通道变换矩阵。如图2所示,参数 T T T充当任务嵌入,以加权方式过滤掉不相关的图像块,矩阵 C C C则在通道级别上进行任务特征强化。
  2. 一个多层感知机对得到的特征序列 x i , T D i = 1 n {x_{i,TD}}_{i=1}^{n} xi,TDi=1n进行解码,并将其输入到下一阶段的推理过程。

2.2.3 推理步骤II

该阶段用于将选择的任务相关patch x i , T D x_{i,TD} xi,TD重新输入到自注意力模块,以执行第二次推理:

  1. 将任务特征输入到自注意力结构的Value向量中:
    V = W V ⋅ ( x B U + x T D ) , V = W_V \cdot (x_{BU} + x_{TD}), V=WV(xBU+xTD),其中 x B U x_{BU} xBU表示第一阶段的自底向上序列。这使模型在第二次注意力计算时,能够聚焦于经任务特征增强的区域,从而更有效地学习任务特定信号。
  2. 完成自注意力操作后,分类token将被输入到一个全连接层中,映射为最终的类别预测结果
    r = W F C ⋅ C L S ∈ R c , r = W_{FC} \cdot CLS \in \mathbb{R}^c, r=WFCCLSRc,其中 c c c为类别数。
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐