1. Activation Function Replacement: GeGLU and Other Efficient Structures

The conventional linear-layer + ReLU combination stores full-size intermediate activations, whereas members of the Gated Linear Unit (GLU) family such as GeGLU (a GLU variant gated with GELU) are more efficient. GeGLU splits the projected input into two halves and uses a gating mechanism to control information flow, reducing the intermediate tensors that need to be stored. For example:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeGLU(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # One projection produces both the value half and the gate half
        self.dense_1 = nn.Linear(in_features, out_features * 2)
        self.dense_2 = nn.Linear(out_features, out_features)

    def forward(self, x):
        x = self.dense_1(x)
        x1, x2 = x.chunk(2, dim=-1)
        # Core GeGLU computation: the GELU-activated gate scales the value half
        return self.dense_2(x1 * F.gelu(x2))
```
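A minimal usage sketch (the dimensions below are illustrative, not from the original):

```python
layer = GeGLU(in_features=768, out_features=3072)
x = torch.randn(4, 128, 768)   # (batch, seq_len, hidden)
y = layer(x)                   # -> (4, 128, 3072)
```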

2. Gradient Checkpointing

Gradient checkpointing saves memory by keeping only selected intermediate activations during the forward pass and recomputing the rest on the fly during backpropagation, instead of storing every activation for the whole forward pass. Example code (using torch.utils.checkpoint in PyTorch):

```python
from torch.utils.checkpoint import checkpoint

def checkpointed_forward(model, x):
    def custom_forward(*inputs):
        return model(*inputs)
    # use_reentrant=False is the recommended mode in recent PyTorch releases
    return checkpoint(custom_forward, x, use_reentrant=False)

# Use it layer by layer inside the model's forward pass
for layer in model.layers:
    x = checkpointed_forward(layer, x)
```
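For a stack of sequential layers, PyTorch also ships torch.utils.checkpoint.checkpoint_sequential, which splits the stack into segments and stores activations only at segment boundaries. A minimal sketch, with an illustrative toy model:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Illustrative deep stack; real models would use transformer blocks
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(16)]).cuda()
x = torch.randn(8, 1024, device='cuda', requires_grad=True)

# Split 16 layers into 4 segments; activations inside each segment are
# recomputed during backward instead of being stored
out = checkpoint_sequential(model, 4, x, use_reentrant=False)
out.sum().backward()
```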

3. Model Pruning

Pruning removes unimportant parameters to shrink the model and lower its memory footprint. It comes in unstructured form (individual weights) and structured form (whole channels or layers). An unstructured L1 example (a structured sketch follows below):

```python
import torch.nn.utils.prune as prune

# L1 unstructured pruning on a linear layer of the model
module = model.linear_layer
prune.l1_unstructured(module, name="weight", amount=0.5)  # zero out the 50% smallest-magnitude weights
prune.remove(module, "weight")  # make the pruning permanent by removing the reparameterization
```
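Since the text also mentions structured pruning, here is a hedged sketch using prune.ln_structured, which zeroes entire output channels ranked by L2 norm (the layer below is illustrative). Note that zeroed channels only reduce memory if the layer is later rebuilt without them; masking alone keeps the tensor the same size:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

linear = nn.Linear(512, 256)  # illustrative layer
# Remove 30% of output channels (rows of the weight matrix) by L2 norm
prune.ln_structured(linear, name="weight", amount=0.3, n=2, dim=0)
prune.remove(linear, "weight")
```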

4. Dynamic Sequence Length Handling

When input sequence lengths vary, use dynamic truncation or dynamic padding: process only the lengths actually present in the current batch instead of padding everything to a fixed maximum length. Example:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your_model")
texts = ["short text 1", "a somewhat longer text 2"]
# padding='longest' pads only to the longest sequence in this batch,
# avoiding memory spent on padding to a fixed maximum length
encoded = tokenizer(texts, padding='longest', truncation=True, return_tensors='pt')
```
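Dynamic padding pays off most when batches contain sequences of similar length. A minimal sketch of length-bucketed batching (the batch size is illustrative; tokenizer is the object created above):

```python
def length_bucketed_batches(texts, tokenizer, batch_size=8):
    # Sort by token count so sequences of similar length share a batch
    ordered = sorted(texts, key=lambda t: len(tokenizer(t)['input_ids']))
    for i in range(0, len(ordered), batch_size):
        batch = ordered[i:i + batch_size]
        # Each batch now pads only to its own, shorter, 'longest' member
        yield tokenizer(batch, padding='longest', truncation=True, return_tensors='pt')
```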

5. Chunked Input Processing

Split a long input sequence into chunks and process them one at a time, freeing memory after each chunk before loading the next. Note that in this naive scheme attention does not cross chunk boundaries, so some context is lost at the seams. Taking long text as an example:

```python
import torch

def chunked_processing(model, tokenizer, text, chunk_size=512):
    input_ids = tokenized = tokenizer(text, return_tensors='pt')['input_ids'][0]
    chunks = input_ids.split(chunk_size)
    outputs = []
    for chunk in chunks:
        chunk = chunk.unsqueeze(0).cuda()  # move a single chunk to GPU memory
        with torch.no_grad():
            # .logits assumes a Hugging Face model returning a ModelOutput
            output = model(chunk).logits
        outputs.append(output.cpu())  # move results off the GPU to free memory
    return torch.cat(outputs, dim=1)  # concatenate chunk outputs along the sequence axis
```
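A hedged usage sketch, assuming a Hugging Face causal LM (the model name is a placeholder):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your_model")
model = AutoModelForCausalLM.from_pretrained("your_model").cuda().eval()

long_text = "A document far longer than a single chunk. " * 500
logits = chunked_processing(model, tokenizer, long_text, chunk_size=512)
```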

6. Knowledge Distillation

Train a small model to imitate a large model's outputs. Only the small student needs gradients and optimizer state during training; the large teacher runs forward-only and can be unloaded entirely once distillation is done. For example:

```python
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

# Load the large teacher and the small student
large_model = AutoModelForCausalLM.from_pretrained("large_model").eval()
small_model = AutoModelForCausalLM.from_pretrained("small_model")

# Distillation loss: here a simple MSE between teacher and student logits
class DistillationLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, large_output, small_output):
        return self.mse(large_output, small_output)

# Train the student to imitate the teacher
distill_loss = DistillationLoss()
optimizer = torch.optim.Adam(small_model.parameters(), lr=1e-4)
for inputs, targets in dataloader:
    with torch.no_grad():  # teacher is forward-only: no gradients or activation storage
        large_output = large_model(inputs).logits
    small_output = small_model(inputs).logits
    loss = distill_loss(large_output, small_output)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
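The MSE objective above is a simplification; the more common distillation loss is a temperature-scaled KL divergence between softened teacher and student distributions. A hedged sketch, assuming the two models share a vocabulary:

```python
import torch.nn.functional as F

def kd_kl_loss(teacher_logits, student_logits, temperature=2.0):
    t = temperature
    teacher_probs = F.softmax(teacher_logits / t, dim=-1)
    student_log_probs = F.log_softmax(student_logits / t, dim=-1)
    # The t**2 factor keeps gradient magnitudes comparable across temperatures
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (t * t)
```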

7. Sparse Attention

Compute attention over only a subset of positions (or heads) to cut both computation and memory. For example, local attention, where each token attends only to a window of nearby tokens:

```python
import torch
import torch.nn as nn

class LocalAttention(nn.Module):
    def __init__(self, window_size=128):
        super().__init__()
        self.window_size = window_size

    def forward(self, x):
        batch, seq_len, dim = x.size()
        attn_output = torch.zeros_like(x)
        for i in range(seq_len):
            # Each position attends only to a window centred on itself
            start = max(0, i - self.window_size // 2)
            end = min(seq_len, i + self.window_size // 2 + 1)
            window = x[:, start:end, :]                                    # (batch, win, dim)
            query = x[:, i:i + 1, :]                                       # (batch, 1, dim)
            attn = torch.matmul(query, window.transpose(1, 2)) / dim**0.5  # (batch, 1, win)
            attn = torch.softmax(attn, dim=-1)
            attn_output[:, i, :] = torch.matmul(attn, window).squeeze(1)   # (batch, dim)
        return attn_output
```
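The per-position loop above is written for clarity, not speed; practical implementations vectorize the windows or use block-sparse kernels. A quick shape check with illustrative sizes:

```python
attn = LocalAttention(window_size=16)
x = torch.randn(2, 64, 32)   # (batch, seq_len, dim)
out = attn(x)                # -> (2, 64, 32); each position sees at most 17 neighbours
```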

8. Memory Pooling

Pre-allocate blocks of GPU memory in a custom pool and reuse them, avoiding the overhead of repeatedly allocating and freeing memory. Example:

```python
import torch

class MemoryPool:
    def __init__(self, total_size):
        self.total_size = total_size
        self.free_memory = total_size  # remaining budget, in number of elements
        self.allocated = []

    def allocate(self, size):
        if self.free_memory >= size:
            self.free_memory -= size
            # Simplified simulation: a real pool would hand out slices of
            # pre-allocated buffers instead of calling torch.empty each time
            tensor = torch.empty(size, device='cuda')
            self.allocated.append(tensor)
            return tensor
        raise MemoryError("Insufficient memory in pool")

    def free(self, tensor):
        self.free_memory += tensor.numel()
        # Drop our reference so PyTorch's caching allocator can reclaim the block
        self.allocated = [t for t in self.allocated if t is not tensor]
```
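PyTorch's own caching allocator already performs this kind of pooling internally, so the class above is illustrative only. The real allocator can be inspected and nudged with standard APIs:

```python
import torch

x = torch.randn(1024, 1024, device='cuda')
print(torch.cuda.memory_allocated())  # bytes currently held by live tensors
print(torch.cuda.memory_reserved())   # bytes cached by the allocator

del x
torch.cuda.empty_cache()              # return unused cached blocks to the driver
```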

Combining the techniques above can further reduce the GPU memory footprint of large models, making it feasible to run large-scale models on hardware with limited memory.
