Tensorflow2 基础-Tensor维度变换

1、tf.reshape(a, shape)将Tensor调整为新的合法shape,不会改变数据,只是改变数据的理解方式。(reshape中维度指定为-1表示自动推导,类似numpy)


a = tf.random.uniform([2,3,4,5])
print(a.shape,'',a.ndim) 
print(tf.reshape(a,[2,3*4,5]).shape) 
print(tf.reshape(a,[2,-1,5]).shape) 
print(tf.reshape(a,[2,3*4*5]).shape) 
print(tf.reshape(a,[2,-1]).shape)
b=tf.random.uniform([2,3])
print(b)
print(tf.reshape(b,[1,-1]))
(2, 3, 4, 5)  4
(2, 12, 5)
(2, 12, 5)
(2, 60)
(2, 60)
tf.Tensor(
[[0.84430206 0.29833543 0.6193876 ]
 [0.27115643 0.21803117 0.28303194]], shape=(2, 3), dtype=float32)
tf.Tensor([[0.84430206 0.29833543 0.6193876  0.27115643 0.21803117 0.28303194]], shape=(1, 6), dtype=float32)

2、tf.transpose(a, perm)将原来Tensor按照perm指定的维度顺序进行转置。


a= tf.random.uniform([16,28,28,3])
print(tf.transpose(a,[0,3,1,2]).shape)
(16, 3, 28, 28)

3、tf.expand_dims(a, axis)在指定维度的前面(axis为正数)或者后面(axis为负数)增加一个新的空维度。


a = tf.random.normal([4,35,8]) 
print(tf.expand_dims(a,axis=0).shape)
print(tf.expand_dims(a,axis=3).shape)
(1, 4, 35, 8)
(4, 35, 8, 1)

4、tf.squeeze(a, axis)消去指定的可以去掉的维度(该维度值为1)。


a= tf.random.normal([1,28,28,1]) 
print(tf.squeeze(a,axis=0).shape) 
print(tf.squeeze(a,axis=3).shape) 
print(tf.squeeze(a,axis=-1).shape)
(28, 28, 1)
(1, 28, 28)
(1, 28, 28)