图解TensorFlow op:tf.slice
与tf.strided_slice比,tf.slice相对更简单些,在各维度上切分指定起始点和尺寸的数据。本文用图文的方式来解释TensorFlow中slice算子运算的方式。
田海立@CSDN 2020-11-15
与tf.strided_slice比,tf.slice相对更简单些,仅在各维度上切分指定起始点和尺寸的数据,不支持stride选择。本文用图文的方式来解释TensorFlow中slice算子运算的方式。
一、slice原型
slice在各维度上切分指定起始点和尺寸的数据。
原型如下:
tf.slice(
input_, begin, size, name=None
)
其中:begin/size指定各个维度上的起始与尺寸的vector,长度是rank。
也就是在axis#0上切片[begin[0], begin[0]+size-1];在axis#1上切片[begin[1], begin[1]+size-1];...
二、slice对数据的处理
下面以一个3D Tensor [3, 2, 3]做begin[1, 0, 0], size[1, 2, 3]为例,看slice操作对数据的处理就是:
- axis#0维上,切片[1, 1],也就是切片#1;
- axis#1维上,切片[0, 1],也就是切片#0,#1;
- axis#2维上,切片[0, 2],也就是切片#0,#1,#2;也就是该维上不变。
上述的处理过程,一张图展示就是这样:
再以上一个3D Tensor [3, 2, 3]做begin[1, 0, 0], size[2, 1, 3]为例,看slice操作对数据的处理就是:
- axis#0维上,切片[1, 2],也就是切片#1,#2;
- axis#1维上,切片[0, 0],也就是切片#0;
- axis#2维上,切片[0, 2],也就是切片#0,#1,#2;也就是该维上不变。
上述的处理过程,一张图展示就是这样:
三、slice程序实现
上述过程用程序实现,如下:
定义一个[3, 2, 3]的Tensor:
>>>
>>> t = tf.range(3*2*3)
>>> t = tf.reshape(t, [3, 2, 3])
>>> t
<tf.Tensor: shape=(3, 2, 3), dtype=int32, numpy=
array([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]],
[[12, 13, 14],
[15, 16, 17]]], dtype=int32)>
>>>
执行slice(begin = [1, 0, 0], size = [1, 2, 3]之后:
>>>
>>> t1 = tf.slice(t, [1, 0, 0], [1, 2, 3])
>>> t1
<tf.Tensor: shape=(1, 2, 3), dtype=int32, numpy=
array([[[ 6, 7, 8],
[ 9, 10, 11]]], dtype=int32)>
>>>
执行slice(begin = [1, 0, 0], size = [1, 2, 3]之后:
>>>
>>> t2 = tf.slice(t, [1, 0, 0], [2, 1, 3])
>>> t2
<tf.Tensor: shape=(2, 1, 3), dtype=int32, numpy=
array([[[ 6, 7, 8]],
[[12, 13, 14]]], dtype=int32)>
>>>
总结
本文以图示和程序分析了tf.slices对Tensor的处理,在各个维度上做切片。
更多推荐
所有评论(0)