HiDDeN论文解读与代码实现
用于在CIFAR-10图像中嵌入和恢复二进制水印消息,并验证其在多种噪声条件下的鲁棒性,以下文件为main.py文件完整程序代码。要求:消息可以从图像中被接收方解码出来,但攻击者很难区分哪些图像包含信息。要求:即便图像经过压缩、裁剪、模糊等破坏,仍能正确恢复水印信息。:近似真实 JPEG 压缩的可微方法,保证训练过程中梯度可传播。将消息向量扩展成与图像相同空间维度的“消息体积”,与特征拼接。通过对
论文:HiDDeN: Hiding Data With Deep Networks
作者:Jiren Zhu, Russell Kaplan, Justin Johnson, Li Fei-Fei
一、研究背景
在图像信息隐藏领域,通常有两类典型的应用场景:
-
隐写 (Steganography)
- 目标:实现秘密通信。
- 要求:消息可以从图像中被接收方解码出来,但攻击者很难区分哪些图像包含信息。
- 关键点:隐蔽性,难以被检测。
-
数字水印 (Digital Watermarking)
- 目标:主要用于版权保护和身份认证。
- 要求:即便图像经过压缩、裁剪、模糊等破坏,仍能正确恢复水印信息。
- 关键点:鲁棒性,保证信息可恢复。
传统方法多依赖人工设计特征,如:
- 修改像素的最低有效位 (LSB);
- 在频域的低频部分嵌入信息。
这些方法在特定场景下有效,但适应性较差。 HiDDeN则提出了全新的思路:利用 端到端可训练的卷积神经网络 替代传统手工特征,实现更强的灵活性和鲁棒性。
二、核心思想
HiDDeN将数据隐藏任务设计为一个可微分的端到端管道,通过深度学习来自动学习嵌入策略:
1. 编码器 (Encoder)
- 使用多层卷积提取封面图的特征。
- 将消息向量扩展成与图像相同空间维度的“消息体积”,与特征拼接。
- 最终生成含密图,保证与封面图在视觉上接近。
2. 噪声层 (Noise Layer)
在训练过程中引入失真,模拟现实场景:
- Dropout / Cropout:随机替换像素或区域。
- Crop:保留图像的一部分,裁剪其余部分。
- Gaussian Blur:模拟图像模糊。
- JPEG Mask / JPEG Drop:近似真实 JPEG 压缩的可微方法,保证训练过程中梯度可传播。
3. 解码器 (Decoder)
- 多层卷积提取失真图像特征。
- 使用全局平均池化获取消息相关信息。
- 最终通过线性层输出消息位。
- 解码器支持输入大小变化,因此对裁剪等操作具有适应性。
4. 对抗判别器 (Adversary)
- 结构类似解码器。
- 输出一个二分类概率:判断图像是否含密。
- 通过对抗训练提升含密图的隐蔽性,降低被检测概率。
这种设计使模型能够在容量(信息嵌入量)、隐蔽性(难以检测)、鲁棒性(抗失真能力) =三方面取得平衡。
三、代码实现
实现了一个简化版 HiDDeN框架在CIFAR-10上的水印嵌入实验,用于在CIFAR-10图像中嵌入和恢复二进制水印消息,并验证其在多种噪声条件下的鲁棒性,以下文件为main.py文件完整程序代码。
# hidden_cifar10_adv.py
# HiDDeN-style watermarking on CIFAR-10 with adversarial loss (L_A)
# Encoder + Decoder + Noise + Discriminator
# Loss = message BCE + lambda_img * MSE(image) + lambda_adv * adv_loss
import argparse, math, os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
# ------------------------------
# Utils
# ------------------------------
def psnr(img1, img2, eps=1e-8):
mse = F.mse_loss(img1, img2, reduction="mean").item()
if mse < eps: return 99.0
return 10.0 * math.log10(1.0 / mse)
def make_gaussian_kernel(ks=5, sigma=1.0, device="cpu"):
ax = torch.arange(ks, dtype=torch.float32) - (ks - 1) / 2.0
xx, yy = torch.meshgrid(ax, ax, indexing="ij")
kernel = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2))
kernel = kernel / kernel.sum()
return kernel.to(device)
# ------------------------------
# Noise layers
# ------------------------------
class NoiseLayer(nn.Module):
def __init__(self, kind="identity", p=0.3, gs_ks=5, gs_sigma=1.0):
super().__init__()
self.kind = kind
self.p = p
self.gs_ks = gs_ks
self.gs_sigma = gs_sigma
self.register_buffer("gs_kernel", torch.empty(0))
def forward(self, ico, ien):
if self.kind == "identity":
return ien
elif self.kind == "dropout":
mask = (torch.rand_like(ien[:, :1, :, :]) < self.p).float()
return mask * ico + (1.0 - mask) * ien
elif self.kind == "gaussian":
if self.gs_kernel.numel() == 0 or self.gs_kernel.device != ien.device:
k = make_gaussian_kernel(self.gs_ks, self.gs_sigma, ien.device)
self.gs_kernel = k[None, None, :, :]
C = ien.size(1)
weight = self.gs_kernel.expand(C, 1, self.gs_ks, self.gs_ks)
return F.conv2d(ien, weight, padding=self.gs_ks//2, groups=C)
elif self.kind == "combined":
choice = torch.randint(0, 3, (1,), device=ien.device).item()
if choice == 0:
return ien
elif choice == 1:
p = float(torch.empty(1).uniform_(self.p*0.8, min(0.95, self.p*1.2)))
mask = (torch.rand_like(ien[:, :1, :, :]) < p).float()
return mask * ico + (1.0 - mask) * ien
else:
sigma = float(torch.empty(1).uniform_(max(0.5, self.gs_sigma*0.6),
self.gs_sigma*1.5))
k = make_gaussian_kernel(self.gs_ks, sigma, ien.device)
weight = k[None, None, :, :].expand(ien.size(1), 1, self.gs_ks, self.gs_ks)
return F.conv2d(ien, weight, padding=self.gs_ks//2, groups=ien.size(1))
else:
return ien
# ------------------------------
# Encoder / Decoder
# ------------------------------
class ConvBNReLU(nn.Module):
def __init__(self, c_in, c_out, k=3, s=1, p=1):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(c_in, c_out, k, s, p, bias=False),
nn.BatchNorm2d(c_out),
nn.ReLU(inplace=True),
)
def forward(self, x): return self.block(x)
class Encoder(nn.Module):
def __init__(self, L, img_ch=3, base=64):
super().__init__()
self.L = L
self.stem = nn.Sequential(
ConvBNReLU(img_ch, base),
ConvBNReLU(base, base),
ConvBNReLU(base, base),
ConvBNReLU(base, base),
)
self.fuse = ConvBNReLU(base + L + img_ch, base)
self.to_img = nn.Conv2d(base, img_ch, 1)
def forward(self, ico, m_bits):
B, C, H, W = ico.size()
feat = self.stem(ico)
m = m_bits.view(B, self.L, 1, 1).float().expand(B, self.L, H, W)
x = torch.cat([feat, m, ico], dim=1)
ien = torch.clamp(self.to_img(self.fuse(x)), 0.0, 1.0)
return ien
class Decoder(nn.Module):
def __init__(self, L, img_ch=3, base=64):
super().__init__()
self.L = L
self.body = nn.Sequential(
ConvBNReLU(img_ch, base),
ConvBNReLU(base, base),
ConvBNReLU(base, base),
ConvBNReLU(base, base),
ConvBNReLU(base, base),
ConvBNReLU(base, base),
)
self.head = ConvBNReLU(base, L)
self.fc = nn.Linear(L, L)
def forward(self, ino):
x = self.body(ino)
x = self.head(x)
x = F.adaptive_avg_pool2d(x, 1).view(x.size(0), self.L)
return self.fc(x)
# ------------------------------
# Discriminator
# ------------------------------
class Discriminator(nn.Module):
def __init__(self, img_ch=3, base=64):
super().__init__()
self.net = nn.Sequential(
ConvBNReLU(img_ch, base),
ConvBNReLU(base, base),
nn.Conv2d(base, base*2, 3, stride=2, padding=1), nn.ReLU(),
nn.Conv2d(base*2, base*4, 3, stride=2, padding=1), nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
)
self.fc = nn.Linear(base*4, 1)
def forward(self, x):
feat = self.net(x).view(x.size(0), -1)
return self.fc(feat)
# ------------------------------
# Training / Evaluation
# ------------------------------
def train(args):
device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
tfm = transforms.ToTensor()
train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=tfm)
test_set = datasets.CIFAR10(root="./data", train=False, download=True, transform=tfm)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=2)
encoder, decoder = Encoder(args.L).to(device), Decoder(args.L).to(device)
noise = NoiseLayer(args.noise, args.drop_p, args.gs_ks, args.gs_sigma).to(device)
discriminator = Discriminator().to(device)
opt_EG = torch.optim.Adam(list(encoder.parameters())+list(decoder.parameters()), lr=args.lr)
opt_D = torch.optim.Adam(discriminator.parameters(), lr=args.lr)
bce_logits = nn.BCEWithLogitsLoss()
bce = nn.BCEWithLogitsLoss()
mse = nn.MSELoss()
os.makedirs(args.out_dir, exist_ok=True)
for epoch in range(1, args.epochs+1):
encoder.train(); decoder.train(); noise.train(); discriminator.train()
run_acc = 0.0
for imgs, _ in train_loader:
imgs = imgs.to(device)
B = imgs.size(0)
m_bits = torch.randint(0, 2, (B, args.L), device=device)
ien = encoder(imgs, m_bits)
ino = noise(imgs, ien)
logits = decoder(ino)
# 1. Train Discriminator
logits_real = discriminator(imgs)
logits_fake = discriminator(ien.detach())
loss_D = bce_logits(logits_real, torch.zeros_like(logits_real)) + \
bce_logits(logits_fake, torch.ones_like(logits_fake))
opt_D.zero_grad(); loss_D.backward(); opt_D.step()
# 2. Train Encoder+Decoder
msg_loss = bce(logits, m_bits.float())
img_loss = mse(ien, imgs)
logits_fake_for_G = discriminator(ien)
adv_loss = bce_logits(logits_fake_for_G, torch.zeros_like(logits_fake_for_G))
loss = msg_loss + args.lambda_img*img_loss + args.lambda_adv*adv_loss
opt_EG.zero_grad();
loss.backward();
opt_EG.step()
with torch.no_grad():
pred = (torch.sigmoid(logits) > 0.5).long()
run_acc += (pred == m_bits).float().mean().item() * B
print(f"[Epoch {epoch}] bit_acc={run_acc/len(train_loader.dataset):.4f}")
# quick eval on test set + save a visualization
if epoch % args.eval_every == 0:
test_bit_acc, test_psnr = evaluate(encoder, decoder, noise, test_loader, device)
print(f" -> Test bit_acc={test_bit_acc:.4f} PSNR(cover,encoded)={test_psnr:.2f} dB")
dump_examples(encoder, decoder, noise, test_loader, device, args.out_dir, epoch)
torch.save({
"encoder": encoder.state_dict(),
"decoder": decoder.state_dict(),
"discriminator": discriminator.state_dict()
}, os.path.join(args.out_dir, "ckpt.pt"))
print("Training done. Checkpoints & samples saved to:", args.out_dir)
@torch.no_grad()
def evaluate(encoder, decoder, noise, loader, device):
encoder.eval(); decoder.eval(); noise.eval()
acc_sum, psnr_sum, cnt = 0.0, 0.0, 0
for imgs, _ in loader:
imgs = imgs.to(device)
B = imgs.size(0)
m_bits = torch.randint(0, 2, (B, decoder.L), device=device)
ien = encoder(imgs, m_bits)
ino = noise(imgs, ien)
logits = decoder(ino)
pred = (torch.sigmoid(logits) > 0.5).long()
acc_sum += (pred == m_bits).float().mean().item() * B
psnr_sum += psnr(imgs, ien) * B
cnt += B
return acc_sum/cnt, psnr_sum/cnt
@torch.no_grad()
def dump_examples(encoder, decoder, noise, loader, device, out_dir, epoch):
encoder.eval(); decoder.eval(); noise.eval()
imgs, _ = next(iter(loader))
imgs = imgs.to(device)[:8]
B = imgs.size(0)
m_bits = torch.randint(0, 2, (B, decoder.L), device=device)
ien = encoder(imgs, m_bits)
ino = noise(imgs, ien)
utils.save_image(imgs, os.path.join(out_dir, f"epoch{epoch:03d}_cover.png"), nrow=4)
utils.save_image(ien, os.path.join(out_dir, f"epoch{epoch:03d}_encoded.png"), nrow=4)
utils.save_image(ino, os.path.join(out_dir, f"epoch{epoch:03d}_noised.png"), nrow=4)
logits = decoder(ino)
pred = (torch.sigmoid(logits) > 0.5).long()
print("[viz] sample#0 GT bits:", m_bits[0].tolist())
print("[viz] sample#0 PR bits:", pred[0].tolist())
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--L", type=int, default=30)
parser.add_argument("--noise", type=str, default="combined",
choices=["identity","dropout","gaussian","combined"])
parser.add_argument("--drop-p", type=float, default=0.3)
parser.add_argument("--gs-ks", type=int, default=5)
parser.add_argument("--gs-sigma", type=float, default=1.0)
parser.add_argument("--lambda-img", type=float, default=0.7)
parser.add_argument("--lambda-adv", type=float, default=0.03)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--eval-every", type=int, default=1)
parser.add_argument("--out-dir", type=str, default="runs_hidden_cifar10_adv")
parser.add_argument("--cpu", action="store_true")
args = parser.parse_args()
train(args)
if __name__ == "__main__":
main()
目录结构为:
Hidden
├── data
│ └── cifar-10-python.tar.gz
└── main.py
运行命令为:
# 仅保真,不加噪声(容量/保密取向)
python main.py --epochs 5 --noise identity
# 针对像素级 Dropout 的鲁棒水印训练
python main.py --epochs 5 --noise dropout --drop-p 0.3
# 混合扰动(Dropout/高斯 模糊 随机采样),更通用
python main.py --epochs 5 --noise combined
四、关键代码与论文公式对应
解析论文的核心公式,并将其与 main.py
代码实现逐一对应。
1. Encoder 对应公式
论文公式 (1):
Ien=Eθ(Ico,M) I_{en} = E_\theta(I_{co}, M) Ien=Eθ(Ico,M)
-
含义:Encoder EθE_\thetaEθ 接收原始图像 IcoI_{co}Ico 与比特消息 MMM,输出带水印的图像 IenI_{en}Ien.
-
代码对应:
ien = encoder(imgs, m_bits) # (公式1)
2. 噪声层对应公式
论文公式 (2):
Ino=N(Ico,Ien) I_{no} = N(I_{co}, I_{en}) Ino=N(Ico,Ien)
-
含义:在 Encoder 和 Decoder 之间插入噪声层 NNN,输出失真后的图像 InoI_{no}Ino.
-
代码对应:
ino = noise(imgs, ien) # (公式2)
3. Decoder 对应公式
论文公式 (3):
M′=Dϕ(Ino) M' = D_\phi(I_{no}) M′=Dϕ(Ino)
-
含义:Decoder DϕD_\phiDϕ 接收 InoI_{no}Ino,输出预测消息 M′M'M′.
-
代码对应:
logits = decoder(ino) # (公式3)
4. 损失函数对应公式
论文总损失公式(4):
L=λI⋅LI(Ico,Ien)+λM⋅LM(M,M′)+λA⋅LA \mathcal{L} = \lambda_I \cdot \mathcal{L}_I(I_{co}, I_{en}) + \lambda_M \cdot \mathcal{L}_M(M, M') + \lambda_A \cdot \mathcal{L}_A L=λI⋅LI(Ico,Ien)+λM⋅LM(M,M′)+λA⋅LA
- 代码对应:
msg_loss = bce(logits, m_bits.float()) # L_M
img_loss = mse(ien, imgs) # L_I
loss = msg_loss + args.lambda_img * img_loss # 总损失 (λ_A=0)
5. 评价指标对应公式
- Bit Accuracy:
Acc=1L∑i=1L1(Mi=Mi′) Acc = \frac{1}{L} \sum_{i=1}^L 1(M_i = M'_i) Acc=L1i=1∑L1(Mi=Mi′)
代码:
acc = (pred == m_bits).float().mean().item()
- PSNR:
PSNR(Ico,Ien)=10⋅log10(1/MSE(Ico,Ien)) \text{PSNR}(I_{co}, I_{en}) = 10 \cdot \log_{10}(1/\text{MSE}(I_{co}, I_{en})) PSNR(Ico,Ien)=10⋅log10(1/MSE(Ico,Ien))
代码:
def psnr(img1, img2, eps=1e-8):
mse = F.mse_loss(img1, img2).item()
return 10.0 * math.log10(1.0 / mse)
- 公式 (1) Ien=Eθ(Ico,M)I_{en} = E_\theta(I_{co}, M)Ien=Eθ(Ico,M) →
encoder.forward
- 公式 (2) Ino=N(Ico,Ien)I_{no} = N(I_{co}, I_{en})Ino=N(Ico,Ien) →
noise.forward
- 公式 (3) M′=Dϕ(Ino)M' = D_\phi(I_{no})M′=Dϕ(Ino) →
decoder.forward
- 公式 (4) 总损失 →
msg_loss + λ * img_loss
- 指标 Bit Accuracy & PSNR →
acc
,psnr()
更多推荐
所有评论(0)