WaveNet PyTorch Reimplementation
Dilated causal convolution:
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super().__init__()
        # Pad only on the left so the convolution never sees future time steps.
        self.pad = (kernel_size - 1) * dilation
        self.conv1 = nn.Conv1d(in_channels,
                               out_channels,
                               kernel_size,
                               padding=self.pad,
                               dilation=dilation)

    def forward(self, x):
        x = self.conv1(x)
        if self.pad > 0:
            x = x[..., :-self.pad]  # drop the right-side padding, which would leak future information
        return x
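As a quick sanity check (the shapes and dilation below are arbitrary, chosen only for illustration), the layer preserves the sequence length and the output at a given step is unaffected by changes to later inputs:

layer = CausalConv1d(8, 16, kernel_size=2, dilation=4)
x = torch.randn(1, 8, 32)                          # [batch, channels, seq_len]
y = layer(x)
print(y.shape)                                     # torch.Size([1, 16, 32]) -- length preserved

# Perturb only future time steps; earlier outputs must stay identical.
x2 = x.clone()
x2[..., 20:] += 1.0
y2 = layer(x2)
print(torch.allclose(y[..., :20], y2[..., :20]))   # True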
Since the residual layer with a skip connection appears many times, it is implemented as a reusable class:
class ResidualLayer(nn.Module):
    def __init__(self, residual_size, skip_size, dilation):
        super(ResidualLayer, self).__init__()
        self.conv_filter = CausalConv1d(residual_size, residual_size,
                                        kernel_size=2, dilation=dilation)
        self.conv_gate = CausalConv1d(residual_size, residual_size,
                                      kernel_size=2, dilation=dilation)
        self.resconv1_1 = nn.Conv1d(residual_size, residual_size, kernel_size=1)
        self.skipconv1_1 = nn.Conv1d(residual_size, skip_size, kernel_size=1)

    def forward(self, x):
        conv_filter = self.conv_filter(x)
        conv_gate = self.conv_gate(x)
        # Gated activation unit: tanh(filter) * sigmoid(gate)
        fx = torch.tanh(conv_filter) * torch.sigmoid(conv_gate)
        fx = self.resconv1_1(fx)
        skip = self.skipconv1_1(fx)
        residual = fx + x
        # residual = [batch, residual_size, seq_len], skip = [batch, skip_size, seq_len]
        return skip, residual
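A minimal shape check (the sizes are arbitrary, for illustration only): the residual output keeps the input shape so layers can be chained, while the skip output is projected to skip_size channels:

layer = ResidualLayer(residual_size=32, skip_size=64, dilation=2)
x = torch.randn(4, 32, 100)            # [batch, residual_size, seq_len]
skip, residual = layer(x)
print(skip.shape)                       # torch.Size([4, 64, 100])
print(residual.shape)                   # torch.Size([4, 32, 100])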
A DilatedStack class represents one stack of layers with dilations [1, 2, 4, 8]:
class DilatedStack(nn.Module):
    def __init__(self, residual_size, skip_size, dilation_depth):
        super(DilatedStack, self).__init__()
        residual_stack = [ResidualLayer(residual_size, skip_size, 2 ** layer)
                          for layer in range(dilation_depth)]
        self.residual_stack = nn.ModuleList(residual_stack)

    def forward(self, x):
        skips = []
        for layer in self.residual_stack:
            skip, x = layer(x)
            skips.append(skip.unsqueeze(0))
            # skip = [1, batch, skip_size, seq_len]
        return torch.cat(skips, dim=0), x  # [layers, batch, skip_size, seq_len]
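The receptive field grows exponentially with the stack depth. A rough calculation for kernel_size=2 (this helper is not part of the model, only an illustration; the initial causal convolution in the full model adds one more step):

def receptive_field(dilation_depth, dilation_cycles, kernel_size=2):
    # Each causal conv with dilation d adds (kernel_size - 1) * d steps of history.
    dilations = [2 ** i for i in range(dilation_depth)] * dilation_cycles
    return 1 + sum((kernel_size - 1) * d for d in dilations)

print(receptive_field(dilation_depth=4, dilation_cycles=2))  # 31: two stacks of [1, 2, 4, 8]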
Assembling WaveNet:
In PyTorch, a time-series input has shape [batch_size, seq_len, feature_dim]. Since Conv1d convolves along the last dimension, i.e. the sequence length, the input is first permuted to [batch_size, feature_dim, seq_len]. Following the WaveNet paper, the input goes through an initial causal convolution and then through two stacks of dilated convolutions with dilations [1, 2, 4, 8]. Every layer emits a skip output for later use; all skips are summed to form the result, which then enters the decoder (output) layers, whose design can be adapted as needed.
class WaveNet(nn.Module):
    def __init__(self, input_size, out_size, residual_size, skip_size,
                 dilation_cycles, dilation_depth):
        super(WaveNet, self).__init__()
        self.input_conv = CausalConv1d(input_size, residual_size, kernel_size=2)
        self.dilated_stacks = nn.ModuleList(
            [DilatedStack(residual_size, skip_size, dilation_depth)
             for cycle in range(dilation_cycles)]
        )
        self.convout_1 = nn.Conv1d(skip_size, out_size, kernel_size=1)
        self.convout_2 = nn.Conv1d(out_size, out_size, kernel_size=1)

    def forward(self, x):
        x = x.permute(0, 2, 1)           # [batch, input_feature_dim, seq_len]
        x = self.input_conv(x)           # [batch, residual_size, seq_len]
        skip_connections = []
        for cycle in self.dilated_stacks:
            skips, x = cycle(x)
            skip_connections.append(skips)
        # skip_connections = [total_layers, batch, skip_size, seq_len]
        skip_connections = torch.cat(skip_connections, dim=0)
        # Sum all skip outputs to generate the output; the last residual output is discarded.
        out = skip_connections.sum(dim=0)    # [batch, skip_size, seq_len]
        out = F.relu(out)
        out = self.convout_1(out)            # [batch, out_size, seq_len]
        out = F.relu(out)
        out = self.convout_2(out)
        out = out.permute(0, 2, 1)           # [batch, seq_len, out_size]
        return out
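An end-to-end sketch of how the model might be used for sequence-to-sequence regression (all hyperparameter values below are arbitrary examples, not taken from the original text):

model = WaveNet(input_size=10, out_size=1, residual_size=32, skip_size=64,
                dilation_cycles=2, dilation_depth=4)
x = torch.randn(8, 200, 10)             # [batch, seq_len, feature_dim]
y_hat = model(x)
print(y_hat.shape)                       # torch.Size([8, 200, 1])

# One illustrative training step with an MSE objective.
target = torch.randn(8, 200, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = F.mse_loss(y_hat, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()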