Purpose

I noticed that transformers now ships a dedicated source file for attn_mask handling, containing the code that creates the 4D causal attention mask. How is it implemented? Let's walk through it.

Source code

from dataclasses import dataclass
import torch
from typing import List, Optional, Tuple, Union

dtype = torch.float32

# is_torchdynamo_compiling is used by _ignore_causal_mask_sdpa below; in the transformers source it
# comes from transformers.utils, with a simple fallback here so the snippet also runs standalone
try:
    from transformers.utils import is_torchdynamo_compiling
except ImportError:
    def is_torchdynamo_compiling():
        return False

@dataclass
class AttentionMaskConverter:
    """
    一个实用的注意力掩码类,允许你
        - 创建一个causal 4d mask
        - 创建一个带有滑动窗口的causal 4d mask
        - 将2d attention mask(batch_size, query_length)转换成可以与attention scores相乘的4d attention mask(batch_size, 1, query_length,key_value_length)
    input:
        is_causal(bool) 注意力掩码是单向(causal)还是双向
        sliding_window(int) 默认为None,如果这个参数为正整数,则可以创建滑动窗口掩码
    demo:
    ```python
    import torch
    converter = AttentionMaskConverter(True)
    converter
    # AttentionMaskConverter(is_causal=True, sliding_window=None)
    converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
    # tensor([[[[-, -, -, -, -],
    #           [-, -, -, -, -],
    #           [-, -, -, -, -],
    #           [-, -, -,  0, -],
    #           [-, -, -,  0,  0]]]])
    ```
    """
    is_causal: bool
    sliding_window: int
    def __init__(self, is_causal, sliding_window=None):
        self.is_causal = is_causal
        self.sliding_window = sliding_window# if set, it must be non-negative
    @staticmethod
    def _make_causal_mask(input_ids_shape, past_key_values_length=0, sliding_window=None):
        """
        为双向自注意力制作causal mask
        input:
            input_ids_shape 输入尺寸
            past_key_values_length 已有的key_values长度,默认为0
            sliding_window 滑动窗口的大小,默认为空
        output:
            单向注意力的causal mask
        demo:# 更多的demo在
            converter = AttentionMaskConverter(True)
            input_ids_shape = (1, 4)
            mask = converter._make_causal_mask(input_ids_shape)             
        """
        bsz, tgt_len = input_ids_shape
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)# (tgt_len, tgt_len), filled with the most negative value
        mask_cond = torch.arange(mask.size(-1))# tensor([0, 1, 2, 3])
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
        # mask_cond < (mask_cond + 1).view(mask.size(-1), 1) builds a (tgt_len, tgt_len) matrix that is True on the lower triangle and False elsewhere
        # tensor([[ True, False, False, False],
        #         [ True,  True, False, False],
        #         [ True,  True,  True, False],
        #         [ True,  True,  True,  True]])
        # mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) fills mask according to that condition: True positions become 0, False positions keep the most negative value
        # tensor([[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
        #         [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],
        #         [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38],
        #         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]])
        mask = mask.to(dtype)
        if past_key_values_length > 0:
            mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
            # tmp = torch.zeros(tgt_len, past_key_values_length, dtype=dtype) is an all-zeros (tgt_len, past_key_values_length) matrix
            # mask is [tmp, mask] concatenated along the last dimension, so its shape becomes (tgt_len, past_key_values_length + tgt_len)
        # add lower triangular sliding window mask if necessary
        if sliding_window is not None:
            diagonal = past_key_values_length - sliding_window - 1# diagonal offset for the sliding window
            context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
            # torch.ones_like(mask, dtype=torch.bool) is an all-True (tgt_len, past_key_values_length + tgt_len) matrix
            # context_mask keeps the lower triangle starting from that diagonal offset
            mask.masked_fill_(context_mask, torch.finfo(dtype).min)
            # entries where context_mask is True are overwritten with the most negative value
        # the final return value has shape (bsz, 1, tgt_len, tgt_len + past_key_values_length)
        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
    @staticmethod
    def _expand_mask(mask, tgt_len=None):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
        Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len]; positions that are 0 become the most negative value and positions that are 1 become 0.
        input:
            mask: the attention_mask to convert
            tgt_len: target length, i.e. how many times each row is repeated; defaults to None, in which case mask.size(1) is used
        demo:
            converter = AttentionMaskConverter(True)
            mask = torch.tensor([[0, 0, 0, 1, 1]])
            inverted_mask = converter._expand_mask(mask)
            # tensor([[[[-, -, -, 0, 0],
            #           [-, -, -, 0, 0],
            #           [-, -, -, 0, 0],
            #           [-, -, -, 0, 0],
            #           [-, -, -, 0, 0]]]])
        """
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len
        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)# (bsz, 1, tgt_len, src_len)
        inverted_mask = 1.0 - expanded_mask
        # positions that are 0 in mask become the most negative value, positions that are 1 become 0
        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)    
    def to_causal_4d(self, batch_size, query_length, key_value_length):
        """
        Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
        bias to upper right hand triangular matrix (causal mask).
        Create a causal 4D mask of shape (bsz, head_dim=1, query_length, key_value_length)
        whose upper triangle is filled with the most negative value.
        input:
            batch_size: int
            query_length: int, length of the remaining (not yet cached) tokens
            key_value_length: int, total key length
        output:
            causal mask for unidirectional attention
        demo:  # more detailed demos further below
            converter = AttentionMaskConverter(True)
            batch_size = 1
            query_length = 4
            key_value_length = 4
            mask = converter.to_causal_4d(batch_size, query_length, key_value_length)
        """
        assert self.is_causal  # this method only supports unidirectional (causal) attention
        input_shape = (batch_size, query_length)
        past_key_values_length = key_value_length - query_length# length of the cached prefix
        # create causal mask
        # [bsz, query_length] -> [bsz, 1, query_length, key_value_length]
        # if query_length = 1 (and there is no sliding window), None is returned directly
        causal_4d_mask = None
        if input_shape[-1] > 1 or self.sliding_window is not None:
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )
        return causal_4d_mask
    def to_4d(self, attention_mask_2d, query_length, key_value_length=None):
        """
        Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
        key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
        causal, a causal mask will be added.
        Expand the 2D attention mask into a 4D attention mask of shape (bsz, head_dim=1, query_length, key_value_length); positions that must not be attended to get a large negative bias
        (0 positions become the most negative value, 1 positions become 0).
        If the mask is causal, a causal mask is added on top.
        input:
            query_length: length of the remaining (not yet cached) part of the sentence
            key_value_length: total sentence length; it has to equal attention_mask_2d.shape[-1], so I don't see what an independent value would be useful for
            attention_mask_2d: the 2D mask of the text already padded to the maximum length, with the p-tuning v2 prefix tokens already prepended (if any)
        output:
            the 4D attention mask of shape (bsz, 1, query_length, key_value_length)
        demo:
            converter = AttentionMaskConverter(True)
            mask = torch.tensor([[0, 0, 0, 1, 1]])
            inverted_mask = converter.to_4d(mask, 5, 5)
        """
        input_shape = (attention_mask_2d.shape[0], query_length)
        # create causal mask
        # [bsz, query_length] -> [bsz, 1, query_length, key_value_length]
        causal_4d_mask = None
        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
            # causal attention, and either input_shape[-1] > 1 (more than one token remains, i.e. this is not the single last decoding step)
            # or self.sliding_window is not None (a sliding window is configured)
            # key_value_length must not be None here
            assert key_value_length is not None
            past_key_values_length = key_value_length - query_length
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )# create the causal (lower-triangular) mask for unidirectional self-attention # bsz, 1, query_length, key_value_length
        # sliding windows are only supported together with causal attention
        expanded_attn_mask = self._expand_mask(attention_mask_2d, tgt_len=query_length)# 0 positions become the most negative value, 1 positions become 0
        # positions that are 0 in attention_mask_2d are padding tokens, so later tokens must not attend to them
        if causal_4d_mask is not None:
            # overwrite the padding-token positions of the lower-triangular mask with the most negative value, since those tokens carry no information
            expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
        # computing expanded_attn_mask + causal_4d_mask directly could overflow, hence the masked_fill above
        expanded_4d_mask = expanded_attn_mask
        return expanded_4d_mask
    @staticmethod
    def _unmask_unattended(expanded_mask, min_dtype):
        # fmt: off
        """
        https://github.com/pytorch/pytorch/issues/110213 an issue introduced by a PyTorch version change
        Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
        using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        With left padding, for example, the first rows consist entirely of padding tokens; the memory-efficient path of F.scaled_dot_product_attention requires expanded_mask to be adjusted for such rows.
        `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
        input:
            expanded_mask: [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]
            min_dtype: the most negative value of the dtype
        output:
            rows that are entirely masked (all most-negative values) are reset so that every token in them can be attended to
        demo:
            converter = AttentionMaskConverter(True)
            attention_mask_2d = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]])
            query_length = 4
            causal_mask = converter.to_4d(attention_mask_2d, query_length, query_length)# batch_size, 1, query_length, query_length
            new_causal_mask = converter._unmask_unattended(causal_mask, torch.finfo(dtype).min)        
        ```
        [[[[0, 0, 0],
           [0, 0, 0],
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[0, 0, 0],
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        then the modified `expanded_mask` will be
        ```
        [[[[1, 1, 1],   <-- modified
           [1, 1, 1],   <-- modified
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[1, 1, 1],   <-- modified
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        """
        return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
    @staticmethod
    def _ignore_causal_mask_sdpa(attention_mask,inputs_embeds,past_key_values_length,sliding_window,is_training,) -> bool:
        """
        Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
        ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
        In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
        `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
        allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
        passed).
        
        """
        _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]# batch_size, query_length
        key_value_length = query_length + past_key_values_length# key_value_length is consistent with attention_mask.shape[-1]
        is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()
        ignore_causal_mask = False
        if attention_mask is None:
            # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
            # shape, thus SDPA's `is_causal` argument is rightfully updated
            # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
            # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
            # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
            # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
            # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
            #
            # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
            # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
            if (
                (is_training or not is_tracing)
                and (query_length == 1 or key_value_length == query_length)
                and (sliding_window is None or key_value_length < sliding_window)
            ):
                ignore_causal_mask = True
        elif sliding_window is None or key_value_length < sliding_window:
            if len(attention_mask.shape) == 4:
                return False
            elif not is_tracing and torch.all(attention_mask == 1):
                if query_length == 1 or key_value_length == query_length:
                    # For query_length == 1, causal attention and bi-directional attention are the same.
                    ignore_causal_mask = True# no token is masked in attention_mask, and query_length == 1 or key_value_length == query_length
                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
                # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
                # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
                # Reference: https://github.com/pytorch/pytorch/issues/108108
                # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
        return ignore_causal_mask
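
_ignore_causal_mask_sdpa never gets a standalone demo later in this post, so here is a minimal sketch of my own showing the two typical outcomes (it assumes is_torchdynamo_compiling resolves, e.g. via the fallback added near the imports above):

inputs_embeds = torch.randn(2, 4, 8)  # (batch_size, query_length, hidden_size)
# nothing is masked and key_value_length == query_length, so the causal mask can be left to SDPA's is_causal argument
print(AttentionMaskConverter._ignore_causal_mask_sdpa(torch.ones(2, 4), inputs_embeds, 0, None, False))  # True
# a padding token is present, so the explicit 4D mask is still needed
print(AttentionMaskConverter._ignore_causal_mask_sdpa(torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]]), inputs_embeds, 0, None, False))  # False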

Calling to_causal_4d

# to_causal_4d
converter = AttentionMaskConverter(True)

text = ['我爱你', '你也爱我']
batch_size = len(text)# batch_size
query_length = max([len(text[i]) for i in range(len(text))])# maximum sentence length in the batch

# normal forward (generate): each token can only see the tokens before it, so the mask is a lower-triangular matrix
key_value_length = query_length
causal_mask = converter.to_causal_4d(batch_size, query_length, key_value_length)# batch_size, 1, query_length, key_value_length
# [[[[0,-,-,-],[0,0,-,-],[0,0,0,-],[0,0,0,0]]],
#  [[[0,-,-,-],[0,0,-,-],[0,0,0,-],[0,0,0,0]]]]

query_length_short = 2# with 2 tokens remaining, all earlier (cached) tokens are visible
causal_mask_short = converter.to_causal_4d(batch_size, query_length_short, key_value_length)# batch_size, 1, query_length, key_value_length
# [[[[0,0,0,-],[0,0,0,0]]],
#  [[[0,0,0,-],[0,0,0,0]]]]

query_length_short = 1# with only 1 token remaining, the result is None
causal_mask_short = converter.to_causal_4d(batch_size, query_length_short, key_value_length)# batch_size, 1, query_length, key_value_length
# None

# with p-tuning v2, prompt tokens are prepended to the input, and these prompt tokens are visible to every later token
# the mask is len(prompt tokens) fully visible columns plus a lower-triangular matrix
key_value_length = query_length + 2
causal_mask_prefix = converter.to_causal_4d(batch_size, query_length, key_value_length)# batch_size, 1, query_length, key_value_length
# [[[[0,0,0,-,-,-],[0,0,0,0,-,-],[0,0,0,0,0,-],[0,0,0,0,0,0]]],
#  [[[0,0,0,-,-,-],[0,0,0,0,-,-],[0,0,0,0,0,-],[0,0,0,0,0,0]]]]

# forward with a sliding window: each token can only see the previous sliding_window tokens (plus itself)
# the mask is a banded matrix
sliding_window = 1
key_value_length = query_length + 2
converter = AttentionMaskConverter(True, sliding_window)
causal_mask_sliding_window = converter.to_causal_4d(batch_size, query_length, key_value_length)# batch_size, 1, query_length, key_value_length
# note that prompt tokens are included here, but because of the sliding window later tokens may no longer see those prompt tokens
# [[[[-,0,0,-,-,-],[-,-,0,0,-,-],[-,-,-,0,0,-],[-,-,-,-,0,0]]], 
#  [[[-,0,0,-,-,-],[-,-,0,0,-,-],[-,-,-,0,0,-],[-,-,-,-,0,0]]]] 

Calling to_4d

# to_4d
text = ['我爱你', '你也爱我']
batch_size = len(text)
query_length = max([len(text[i]) for i in range(len(text))])# maximum sentence length in the batch

# left-pad the texts that need padding and build the corresponding attention_mask matrix
attention_mask = []
for i in range(len(text)):
    attention_mask.append([0 for j in range(query_length - len(text[i]))] + [1 for j in range(len(text[i])) ])
    if len(text[i]) < query_length:
        text[i] = (query_length - len(text[i])) * '[mask]' + text[i]
        

# causal (unidirectional) attention
converter = AttentionMaskConverter(True)
attention_mask_2d = torch.tensor(attention_mask)
query_length = attention_mask_2d.shape[-1] # any value <= attention_mask_2d.shape[-1] works
key_value_length = attention_mask_2d.shape[-1] # the two masks built inside this method must have matching sizes, so key_value_length cannot take any other value
inverted_mask_ori = converter.to_4d(attention_mask_2d, query_length, key_value_length)# batch_size, 1, query_length, key_value_length
# [[[[-,-,-,-],[-,0,-,-],[-,0,0,-],[-,0,0,0]]],
#  [[[0,-,-,-],[0,0,-,-],[0,0,0,-],[0,0,0,0]]]]
# note that for the first sample the first column is always the most negative value, because it corresponds to a padding token


query_length = 1# special case: this is the last token to be predicted
inverted_mask_1 = converter.to_4d(attention_mask_2d, query_length, key_value_length)# batch_size, 1, query_length, key_value_length
# [[[[-,0,0,0]]],
#  [[[0,0,0,0]]]]
# note that the first sample still cannot see its first token, because it is a padding token


# p-tuning v2
prefix_num = 2
prefix_mask = [[1 for i in range(prefix_num)] + attention_mask[i] for i in range(len(attention_mask))]
prefix_attention_mask_2d = torch.tensor(prefix_mask)
prefix_query_length = prefix_attention_mask_2d.shape[-1] - prefix_num
prefix_key_value_length = prefix_attention_mask_2d.shape[-1]
inverted_mask_prefix = converter.to_4d(prefix_attention_mask_2d, prefix_query_length, prefix_key_value_length)# batch_size, 1, prefix_query_length, prefix_key_value_length
# [[[[0,0,-,-,-,-],[0,0,-,0,-,-],[0,0,-,0,0,-],[0,0,-,0,0,0]]],
#  [[[0,0,0,-,-,-],[0,0,0,0,-,-],[0,0,0,0,0,-],[0,0,0,0,0,0]]]]
# note that inverted_mask_prefix only has two extra visible columns compared with inverted_mask_ori; those columns are the prompt tokens


converter = AttentionMaskConverter(False)# bidirectional
# with bidirectional attention this method no longer depends on key_value_length, so whether you pass it makes no difference
query_length = 4
inverted_mask = converter.to_4d(attention_mask_2d, query_length, key_value_length)# batch_size, 1, query_length, key_value_length
# [[[[-,0,0,0],[-,0,0,0],[-,0,0,0],[-,0,0,0]]],
#  [[[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]]]]
# with bidirectional attention everything is visible except the padding tokens
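
Before moving on, here is a small sketch of my own showing how such an additive 4D mask is actually consumed: it is added to the raw attention scores before the softmax, so masked positions end up with essentially zero weight, while a row that is masked everywhere (the first row of the first sample in inverted_mask_ori) degenerates into a uniform distribution, which is exactly the situation _unmask_unattended below deals with.

scores = torch.randn(2, 1, 4, 4)  # stand-in for q @ k.transpose(-1, -2) / sqrt(head_dim)
probs = torch.softmax(scores + inverted_mask_ori, dim=-1)
print(probs[0, 0, 1])  # only the single visible position keeps probability mass, padded positions get ~0
print(probs[0, 0, 0])  # fully masked row: softmax over identical values gives a uniform 0.25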


_create_4d_causal_attention_mask

def _create_4d_causal_attention_mask(input_shape, past_key_values_length=0, sliding_window=None):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
    input:
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`): defines (batch_size, query_length)
        sliding_window (`int`, *optional*): size of the sliding window
    """
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
    key_value_length = past_key_values_length + input_shape[-1]
    attention_mask = attn_mask_converter.to_causal_4d(
        input_shape[0], input_shape[-1], key_value_length
    )
    return attention_mask
text = ['我爱你', '你也爱我']
batch_size = len(text)# batch_size
query_length = max([len(text[i]) for i in range(len(text))])# maximum sentence length in the batch

# normal forward (generate): each token can only see the tokens before it, so the mask is a lower-triangular matrix
query_length_normal = query_length
past_key_values_length = query_length - query_length_normal
causal_mask = _create_4d_causal_attention_mask((batch_size, query_length_normal), past_key_values_length=past_key_values_length, sliding_window=None)
# [[[[0,-,-,-],[0,0,-,-],[0,0,0,-],[0,0,0,0]]],
#  [[[0,-,-,-],[0,0,-,-],[0,0,0,-],[0,0,0,0]]]]

query_length_short = 2# with 2 tokens remaining, all earlier (cached) tokens are visible
past_key_values_length = query_length - query_length_short
causal_mask_short = _create_4d_causal_attention_mask((batch_size, query_length_short), past_key_values_length=past_key_values_length, sliding_window=None)
# [[[[0,0,0,-],[0,0,0,0]]],
#  [[[0,0,0,-],[0,0,0,0]]]]


query_length_short = 1# with only 1 token remaining, the result is None
past_key_values_length = query_length - query_length_short
causal_mask_short = _create_4d_causal_attention_mask((batch_size, query_length_short), past_key_values_length=past_key_values_length, sliding_window=None)
# None

# with p-tuning v2, prompt tokens are prepended to the input, and these prompt tokens are visible to every later token
# the mask is len(prompt tokens) fully visible columns plus a lower-triangular matrix
prefix_num = 2
query_length_normal = query_length
past_key_values_length = query_length + prefix_num - query_length_normal
causal_mask_prefix = _create_4d_causal_attention_mask((batch_size, query_length_normal), past_key_values_length=past_key_values_length, sliding_window=None)
# [[[[0,0,0,-,-,-],[0,0,0,0,-,-],[0,0,0,0,0,-],[0,0,0,0,0,0]]],
#  [[[0,0,0,-,-,-],[0,0,0,0,-,-],[0,0,0,0,0,-],[0,0,0,0,0,0]]]]

# forward with a sliding window: each token can only see the previous sliding_window tokens (plus itself)
# the mask is a banded matrix
sliding_window = 1
prefix_num = 2
query_length_normal = query_length
past_key_values_length = query_length + prefix_num - query_length_normal
causal_mask_sliding_window = _create_4d_causal_attention_mask((batch_size, query_length_normal), past_key_values_length=past_key_values_length, sliding_window=sliding_window)
# note that prompt tokens are included here, but because of the sliding window later tokens may no longer see those prompt tokens
# [[[[-,0,0,-,-,-],[-,-,0,0,-,-],[-,-,-,0,0,-],[-,-,-,-,0,0]]], 
#  [[[-,0,0,-,-,-],[-,-,0,0,-,-],[-,-,-,0,0,-],[-,-,-,-,0,0]]]] 
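
_create_4d_causal_attention_mask is only a thin wrapper around to_causal_4d; a quick check of my own that both paths produce the same tensor:

wrapper_mask = _create_4d_causal_attention_mask((batch_size, query_length), past_key_values_length=2)
direct_mask = AttentionMaskConverter(True).to_causal_4d(batch_size, query_length, query_length + 2)
print(torch.equal(wrapper_mask, direct_mask))  # True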

_prepare_4d_causal_attention_mask

def _prepare_4d_causal_attention_mask(attention_mask, input_shape, past_key_values_length, sliding_window=None,):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`
    input:
        attention_mask:A 2D attention mask of shape `(batch_size, key_value_length)`
        input_shape: The input shape should be a tuple that defines `(batch_size, query_length)`.
        past_key_values_length: The length of the key value cache.
        sliding_window: If the model uses windowed attention, a sliding window should be passed.
    """
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
    key_value_length = input_shape[-1] + past_key_values_length
    # 4d mask is passed through the layers
    if attention_mask is not None and len(attention_mask.shape) == 2:
        # use to_4d to convert the 2D attention mask into a 4D one of shape (bsz, head_dim=1, input_shape[-1], key_value_length)
        # 0 positions become the most negative value, 1 positions become 0
        attention_mask = attn_mask_converter.to_4d(attention_mask, input_shape[-1], key_value_length=key_value_length)
    elif attention_mask is not None and len(attention_mask.shape) == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        assert tuple(attention_mask.shape) == expected_shape 
        # if the 4D mask has correct shape - invert it and fill with negative infinity
        inverted_mask = 1.0 - attention_mask# same operation as _expand_mask: 0 positions become the most negative value, 1 positions become 0
        attention_mask = inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
    else:# normal forward (generate): each token can only see the tokens before it, so the mask is a lower-triangular matrix
        attention_mask = attn_mask_converter.to_causal_4d(input_shape[0], input_shape[-1], key_value_length)
    return attention_mask
text = ['我爱你', '你也爱我']
batch_size = len(text)
query_length = max([len(text[i]) for i in range(len(text))])# maximum sentence length in the batch

# left-pad the texts that need padding and build the corresponding attention_mask matrix
attention_mask = []
for i in range(len(text)):
    attention_mask.append([0 for j in range(query_length - len(text[i]))] + [1 for j in range(len(text[i])) ])
    if len(text[i]) < query_length:
        text[i] = (query_length - len(text[i])) * '[mask]' + text[i]
        


attention_mask_2d = torch.tensor(attention_mask)

# 2D-mask branch (to_4d): the result is identical to inverted_mask_ori
attention_mask = _prepare_4d_causal_attention_mask(
    attention_mask_2d,
    [2,4],
    0,
    None,
)
# (the version in transformers also takes an inputs_embeds argument, which the simplified function above dropped)

# no 2D mask passed, to_causal_4d branch: the result is identical to causal_mask
attention_mask = _prepare_4d_causal_attention_mask(
    None,
    [2,4],
    0,
    None,
)

# here the tokens visible at every step are spelled out explicitly for each position
attention_mask_4d = torch.tensor([[[[0, 0, 0, 0],[0, 1, 0, 0],[0, 1, 1, 0],[0, 1, 1, 1]]],
                                  [[[1, 0, 0, 0],[1, 1, 0, 0],[1, 1, 1, 0],[1, 1, 1, 1]]]])
# 4D-mask branch: the mask is inverted and filled with the most negative value; the result is identical to inverted_mask_ori
attention_mask = _prepare_4d_causal_attention_mask(
    attention_mask_4d,
    [2,4],
    0,
    None,
)
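
Since attention_mask_4d spells out the same visibility as attention_mask_2d, the 4D branch and the 2D branch give identical results; a small check of my own:

reference = AttentionMaskConverter(True).to_4d(attention_mask_2d, 4, key_value_length=4)
print(torch.equal(attention_mask, reference))  # True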

_unmask_unattended

# https://github.com/pytorch/pytorch/issues/110213 an issue introduced by a PyTorch version change

# here the tokens visible at every step are spelled out explicitly for each position
attention_mask_4d = torch.tensor([[[[0, 0, 0, 0],[0, 1, 0, 0],[0, 1, 1, 0],[0, 1, 1, 1]]],
                                  [[[1, 0, 0, 0],[1, 1, 0, 0],[1, 1, 1, 0],[1, 1, 1, 1]]]])
# 4D-mask branch: the mask is inverted and filled with the most negative value; the result is identical to inverted_mask_ori
attention_mask = _prepare_4d_causal_attention_mask(
    attention_mask_4d,
    [2,4],
    0,
    None,
)

converter = AttentionMaskConverter(is_causal=True)
attention_mask_new = converter._unmask_unattended(attention_mask, torch.finfo(dtype).min)
# this call resets the fully-masked first row of the first sample (all padding tokens) to 0
# [[[[-0,-0,-0,-0],[-,0,-,-],[-,0,0,-],[-,0,0,0]]],
#  [[[0,-,-,-],[0,0,-,-],[0,0,0,-],[0,0,0,0]]]]
import torch
torch.__version__# '2.2.1+cu121'
from torch.nn import functional as F

torch.manual_seed(0)

a = 3
b = 4

q = torch.randn(size=(1, 1, a, b))
k = torch.randn(size=(1, 1, a, b))
v = torch.randn(size=(1, 1, a, b))

def check(q, k, v, device):
    # the first row of the additive mask is entirely masked (the most negative value);
    # on some SDPA backends such a fully-masked row produces NaN (see the PyTorch issue above)
    q = q.to(device)
    k = k.to(device)
    v = v.to(device)
    neg_value = torch.finfo(q.dtype).min
    mask = [[neg_value, neg_value, neg_value], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
    mask = torch.tensor([[mask]]).to(device)
    o = F.scaled_dot_product_attention(q, k, v, mask, 0.0, is_causal=False)
    print(o)

check(q, k, v, "cpu")
check(q, k, v, "cuda")