Dilated causal convolution:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super().__init__()
        # Pad so that each output position depends only on past inputs.
        self.pad = (kernel_size - 1) * dilation
        self.conv1 = nn.Conv1d(in_channels,
                               out_channels,
                               kernel_size,
                               padding=self.pad,
                               dilation=dilation)

    def forward(self, x):
        x = self.conv1(x)
        if self.pad > 0:  # guard: x[..., :-0] would return an empty tensor
            x = x[..., :-self.pad]  # trim the right padding so no future information leaks in
        return x
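
As a quick sanity check, the layer preserves the sequence length while only looking backwards. A minimal sketch (batch size, channel count, and sequence length here are arbitrary):

# Sanity check: output length matches input length, and each output
# position depends only on current and past input positions.
conv = CausalConv1d(in_channels=1, out_channels=1, kernel_size=2, dilation=4)
x = torch.randn(8, 1, 100)  # [batch, channels, seq_len]
y = conv(x)
print(y.shape)              # torch.Size([8, 1, 100])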

Since the residual-plus-skip-connection layer appears many times, we implement it as a reusable class:

class ResidualLayer(nn.Module):
    def __init__(self, residual_size, skip_size, dilation):
        super(ResidualLayer, self).__init__()
        self.conv_filter = CausalConv1d(residual_size, residual_size,
                                        kernel_size=2, dilation=dilation)
        self.conv_gate = CausalConv1d(residual_size, residual_size,
                                      kernel_size=2, dilation=dilation)
        self.resconv1_1 = nn.Conv1d(residual_size, residual_size, kernel_size=1)
        self.skipconv1_1 = nn.Conv1d(residual_size, skip_size, kernel_size=1)

    def forward(self, x):
        conv_filter = self.conv_filter(x)
        conv_gate = self.conv_gate(x)
        # Gated activation unit from the WaveNet paper:
        # z = tanh(W_f * x) ⊙ sigmoid(W_g * x)
        fx = torch.tanh(conv_filter) * torch.sigmoid(conv_gate)
        fx = self.resconv1_1(fx)
        skip = self.skipconv1_1(fx)
        residual = fx + x
        # residual = [batch, residual_size, seq_len]; skip = [batch, skip_size, seq_len]
        return skip, residual
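
A one-off forward pass shows the two outputs and their shapes. A minimal sketch (the sizes are arbitrary):

layer = ResidualLayer(residual_size=32, skip_size=64, dilation=2)
x = torch.randn(8, 32, 100)      # [batch, residual_size, seq_len]
skip, residual = layer(x)
print(skip.shape)                # torch.Size([8, 64, 100])
print(residual.shape)            # torch.Size([8, 32, 100])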

Here a DilatedStack represents one stack of layers with dilations [1, 2, 4, 8]:

class DilatedStack(nn.Module):
    def __init__(self, residual_size, skip_size, dilation_depth):
        super(DilatedStack, self).__init__()
        # Dilations double at each layer: 1, 2, 4, ..., 2**(dilation_depth-1)
        residual_stack = [ResidualLayer(residual_size, skip_size, 2**layer)
                          for layer in range(dilation_depth)]
        self.residual_stack = nn.ModuleList(residual_stack)

    def forward(self, x):
        skips = []
        for layer in self.residual_stack:
            skip, x = layer(x)
            skips.append(skip.unsqueeze(0))  # skip = [1, batch, skip_size, seq_len]
        return torch.cat(skips, dim=0), x    # [layers, batch, skip_size, seq_len]
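
With dilation_depth=4 the stack returns one skip tensor per layer, stacked along a new leading dimension. A minimal sketch:

stack = DilatedStack(residual_size=32, skip_size=64, dilation_depth=4)
x = torch.randn(8, 32, 100)
skips, residual = stack(x)
print(skips.shape)     # torch.Size([4, 8, 64, 100])
print(residual.shape)  # torch.Size([8, 32, 100])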

Assembling the WaveNet:

In PyTorch, an input time series has shape [batch_size, seq_len, feature_dim]. Since Conv1d convolves along the last dimension, i.e. the sequence length, the input must first be permuted to [batch_size, feature_dim, seq_len]. Following the original WaveNet paper, the input first passes through a causal convolution and then through two stacks of [1, 2, 4, 8] dilated convolutions. Each layer emits a skip output that is kept for the final computation; the skip outputs are summed to produce the result, which then enters the decoder stage, whose design can be adapted as needed.
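
Before assembling the model it is worth checking how much history this configuration sees. With kernel size 2 throughout, a layer with dilation d extends the receptive field by d; a back-of-the-envelope sketch for the two cycles of depth 4 described above:

# Receptive-field check: initial causal conv (dilation 1) plus two
# cycles of dilations [1, 2, 4, 8], kernel size 2 everywhere.
dilation_cycles, dilation_depth = 2, 4
field = 1 + 1  # the current timestep + the initial causal convolution
for _ in range(dilation_cycles):
    field += sum(2**layer for layer in range(dilation_depth))  # 1+2+4+8 = 15
print(field)  # 32: each output sees 32 timesteps of input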

class WaveNet(nn.Module):
    def __init__(self, input_size, out_size, residual_size, skip_size,
                 dilation_cycles, dilation_depth):
        super(WaveNet, self).__init__()
        self.input_conv = CausalConv1d(input_size, residual_size, kernel_size=2)
        self.dilated_stacks = nn.ModuleList(
            [DilatedStack(residual_size, skip_size, dilation_depth)
             for cycle in range(dilation_cycles)]
        )
        self.convout_1 = nn.Conv1d(skip_size, out_size, kernel_size=1)
        self.convout_2 = nn.Conv1d(out_size, out_size, kernel_size=1)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # [batch, input_feature_dim, seq_len]
        x = self.input_conv(x)  # [batch, residual_size, seq_len]

        skip_connections = []
        for cycle in self.dilated_stacks:
            skips, x = cycle(x)
            skip_connections.append(skips)

        # skip_connections = [total_layers, batch, skip_size, seq_len]
        skip_connections = torch.cat(skip_connections, dim=0)

        # Sum all skip connections to generate the output; the final
        # residual output is discarded.
        out = skip_connections.sum(dim=0)  # [batch, skip_size, seq_len]

        out = F.relu(out)
        out = self.convout_1(out)  # [batch, out_size, seq_len]
        out = F.relu(out)
        out = self.convout_2(out)

        out = out.permute(0, 2, 1)  # [batch, seq_len, out_size]
        return out
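
Putting it all together, a forward pass on random data confirms the end-to-end shapes. A minimal sketch (all hyperparameters here are arbitrary):

model = WaveNet(input_size=1, out_size=1, residual_size=32,
                skip_size=64, dilation_cycles=2, dilation_depth=4)
x = torch.randn(8, 100, 1)  # [batch, seq_len, feature_dim]
y = model(x)
print(y.shape)              # torch.Size([8, 100, 1])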