查看相同栏目文章
Tensorflow2 基础-Tensor索引和切片
1、C语言风格,通过多层下标进行索引。
array= tf.random.uniform([3,4,5,6],maxval=100,dtype=tf.int32)
print(array[0][0])
print(array[0][0][1])
print(array[0][0][1][-1])#负号下标表示逆向索引
tf.Tensor(
[[19 33 53 52 16 71]
[91 18 67 40 47 53]
[82 86 13 10 62 45]
[ 0 98 96 54 76 23]
[51 72 21 16 10 11]], shape=(5, 6), dtype=int32)
tf.Tensor([91 18 67 40 47 53], shape=(6,), dtype=int32)
tf.Tensor(53, shape=(), dtype=int32)
2、numpy风格,通过多层下标索引,写在一个中括号内,使用逗号分隔。
#numpy风格
print(array[0,0])
print(array[0,0,1])
print(array[0,0,1,-1])
tf.Tensor(
[[19 33 53 52 16 71]
[91 18 67 40 47 53]
[82 86 13 10 62 45]
[ 0 98 96 54 76 23]
[51 72 21 16 10 11]], shape=(5, 6), dtype=int32)
tf.Tensor([91 18 67 40 47 53], shape=(6,), dtype=int32)
tf.Tensor(53, shape=(), dtype=int32)
3、selective index
tf.gather(a, axis, indices)
axis表示指定的收集维度,indices表示该维度上收集那些序号。
tf.gather_nd(a, indices)
indices可以是多维的,按照指定维度索引。
tf.boolean_mask(a, mask, axis)
按照布尔型的mask,对为True的对应取索引(支持多层维度)。
a = tf.random.uniform([2, 5,3])
print(a)
print('--------------')
print(tf.gather(a, axis=0,indices=[1]))
print('--------------')
print(tf.gather_nd(a,[[0,1,2],[1,2,0]]))
print('--------------')
print(tf.boolean_mask(a,mask=[True,False],axis=0))
tf.Tensor(
[[[0.21262014 0.6784439 0.7892691 ]
[0.9623302 0.09823263 0.5842774 ]
[0.18734527 0.18486834 0.4955622 ]
[0.5475397 0.86478496 0.71788895]
[0.19401169 0.35921252 0.26015413]]
[[0.18015862 0.4910288 0.01969993]
[0.26937723 0.8827827 0.36195683]
[0.9493141 0.1741085 0.59741175]
[0.614969 0.50290656 0.60260296]
[0.2909646 0.8177142 0.2065022 ]]], shape=(2, 5, 3), dtype=float32)
--------------
tf.Tensor(
[[[0.18015862 0.4910288 0.01969993]
[0.26937723 0.8827827 0.36195683]
[0.9493141 0.1741085 0.59741175]
[0.614969 0.50290656 0.60260296]
[0.2909646 0.8177142 0.2065022 ]]], shape=(1, 5, 3), dtype=float32)
--------------
tf.Tensor([0.5842774 0.9493141], shape=(2,), dtype=float32)
--------------
tf.Tensor(
[[[0.21262014 0.6784439 0.7892691 ]
[0.9623302 0.09823263 0.5842774 ]
[0.18734527 0.18486834 0.4955622 ]
[0.5475397 0.86478496 0.71788895]
[0.19401169 0.35921252 0.26015413]]], shape=(1, 5, 3), dtype=float32)