效果:

将一个张量值value切分成子张量列表。

tf.split(
    value, num_or_size_splits, axis=0, num=None, name='split'
)

如果num_or_size_splits为整数,则将张量value其沿维度axis拆分成大小为num_or_size_splits较小的张量。这就要求 value.shape[axis] 能被 num_or_size_splits整除。

如果num_or_size_splits为一维张量(或列表),则将value其拆分为 len(num_or_size_splits)个元素。第i个元素的形状与value的相同,除了沿维度轴的大小为num_or_size_splitting [i]

参数含义

参数名称 具体含义
value 需要切分的张量
num_or_size_splits 表示沿轴分割数的整数一维整数张量 包含沿轴每个输出张量大小的Python列表。如果是标量,则必须均匀分割value.shape[axis];否则,沿拆分轴的大小之和必须与值的大小之和匹配。
axis 一个整数或int32类型张量。表示切分的维度。必须在[rank(value), rank(value)]范围内。默认值为0。
num 可选,用于指定,当不能从size_split的形状推断输出的数量。
name 操作的名称(可选)。

返回值

如果num_or_size_splitting是标量,则返回num_or_size_splitting个张量对象的列表;如果num_or_size_splitting是一维张量,则返回num_or_size_splitting.get_shape[0]个分割value得到的张量对象。

实例

value = tf.Variable(tf.random.uniform([5, 6], -1, 1))

产生的value的值

#沿着axis=1将value切分为3个张量
s0, s1 = tf.split(value, num_or_size_splits=2, axis=1)

s0的值:
在这里插入图片描述

#将value按照尺寸[1,2,1]切分,在轴axis=1上
split0, split1, split2 = tf.split(value, [1, 2, 1], 1)
split0

split0:
在这里插入图片描述
split1:
在这里插入图片描述

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐