除 FlashAttention-2 与模型量化外的大模型显存优化方案
传统的 ReLU + 线性层组合会产生额外显存开销,而Gated Linear Units(GLU)家族中的 GeGLU(Gated Linear Unit with Gaussian Error Linear Units)等结构更高效。将长输入序列划分为多个块进行处理,每次仅加载一个块的信息,处理完后释放显存再加载下一块。通过上述多种方案的组合运用,可进一步优化大模型的显存占用,满足在有限显存硬
1. 激活函数替换:GeGLU 等高效结构
传统的 ReLU + 线性层组合会产生额外显存开销,而Gated Linear Units(GLU)家族中的 GeGLU(Gated Linear Unit with Gaussian Error Linear Units)等结构更高效。GeGLU 将输入分为两部分,通过门控机制控制信息流动,减少中间张量的存储,例如:
python
import torch
import torch.nn as nn
class GeGLU(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.dense_1 = nn.Linear(in_features, out_features * 2)
self.dense_2 = nn.Linear(out_features, out_features)
def forward(self, x):
x = self.dense_1(x)
x1, x2 = x.chunk(2, dim=-1)
return self.dense_2(x1 * torch.sigmoid(x2)) # GeGLU的核心计算
2. 梯度检查点(Gradient Checkpointing)
梯度检查点通过延迟计算非当前层的梯度,仅保存关键中间结果来节省显存。其原理是在反向传播时重新计算前向传播的部分中间结果,而不是全程存储。示例代码(PyTorch 中使用torch.utils.checkpoint):
python
from torch.utils.checkpoint import checkpoint
def checkpointed_forward(model, x):
def custom_forward(*inputs):
return model(*inputs)
return checkpoint(custom_forward, x)
# 在模型前向传播中使用
for layer in model.layers:
x = checkpointed_forward(layer, x)
3. 模型剪枝(Model Pruning)
通过移除模型中不重要的参数来减小模型规模,降低显存占用。例如结构化剪枝(按层、按通道等),示例:
python
import torch.nn.utils.prune as prune
# 对模型中的线性层进行L1剪枝
module = model.linear_layer
prune.l1_unstructured(module, name="weight", amount=0.5) # 剪去50%的不重要权重
prune.remove(module, "weight") # 应用剪枝
4. 动态序列长度处理
对于输入序列长度变化的情况,采用动态截断或动态填充策略。例如,仅处理当前批次中实际有效的序列长度,而非固定为最大可能长度。示例:
python
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("your_model")
texts = ["短文本1", "较长文本2"]
encoded = tokenizer(texts, padding='longest', truncation=True, return_tensors='pt')
# 仅处理实际有效长度的部分,减少不必要的显存占用
5. 分块处理输入(Chunked Input Processing)
将长输入序列划分为多个块进行处理,每次仅加载一个块的信息,处理完后释放显存再加载下一块。以长文本为例:
python
def chunked_processing(model, text, chunk_size=512):
tokenized = tokenizer(text, return_tensors='pt')
input_ids = tokenized['input_ids'][0]
chunks = input_ids.split(chunk_size)
outputs = []
for chunk in chunks:
chunk = chunk.unsqueeze(0).cuda() # 加载单个块到显存
with torch.no_grad():
output = model(chunk)
outputs.append(output)
return torch.cat(outputs, dim=1) # 合并各块输出
6. 知识蒸馏(Knowledge Distillation)
利用小模型模仿大模型的输出,训练时仅需存储小模型的显存,而大模型可在训练后卸载。例如:
python
from transformers import AutoModelForCausalLM
import torch.nn as nn
# 加载大模型和小模型
large_model = AutoModelForCausalLM.from_pretrained("large_model")
small_model = AutoModelForCausalLM.from_pretrained("small_model")
# 定义蒸馏损失函数
class DistillationLoss(nn.Module):
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def forward(self, large_output, small_output):
return self.mse(large_output, small_output)
# 训练小模型模仿大模型
optimizer = torch.optim.Adam(small_model.parameters(), lr=1e-4)
for inputs, targets in dataloader:
large_output = large_model(inputs).detach() # 大模型输出,仅保留显存
small_output = small_model(inputs)
loss = DistillationLoss()(large_output, small_output)
optimizer.zero_grad()
loss.backward()
optimizer.step()
7. 稀疏注意力(Sparse Attention)
仅计算部分注意力头或注意力位置的注意力,减少计算量与显存占用。例如局部注意力(只关注当前词附近的若干词):
python
class LocalAttention(nn.Module):
def __init__(self, window_size=128):
super().__init__()
self.window_size = window_size
def forward(self, x):
batch, seq_len, dim = x.size()
attn_output = torch.zeros_like(x)
for i in range(seq_len):
start = max(0, i - self.window_size // 2)
end = min(seq_len, i + self.window_size // 2 + 1)
window = x[:, start:end, :]
attn = torch.matmul(window, window.transpose(1, 2)) / dim**0.5
attn = torch.softmax(attn, dim=1)
attn_output[:, i, :] = torch.matmul(attn, window).squeeze(1)
return attn_output
8. 内存池化管理(Memory Pooling)
通过自定义内存池,预先分配显存块并重复利用,避免频繁申请和释放显存带来的开销。示例:
python
import torch
class MemoryPool:
def __init__(self, total_size):
self.total_size = total_size
self.free_memory = total_size
self.allocated = []
def allocate(self, size):
if self.free_memory >= size:
self.free_memory -= size
# 模拟分配显存,实际可结合PyTorch内存管理
return torch.empty(size, device='cuda')
else:
raise MemoryError("Insufficient memory")
def free(self, tensor):
self.free_memory += tensor.numel()
tensor = None # 标记为可回收
通过上述多种方案的组合运用,可进一步优化大模型的显存占用,满足在有限显存硬件上运行大规模模型的需求。
更多推荐
所有评论(0)