2016-04-16 77 views
22

我有形状[None, 9, 2](其中None是批量)的张量流的输入。张量流中的拼合批量

要执行进一步的操作(例如matmul),我需要将其转换为[None, 18]形状。怎么做?

回答

2

您可以使用动态整形在运行期间通过获取批量维度的值,并将整个新维度集合计算为tf.reshape。以下是将平面列表重新整形为方形矩阵而不知道列表长度的示例。

tf.reset_default_graph() 
sess = tf.InteractiveSession("") 
a = tf.placeholder(dtype=tf.int32) 
# get [9] 
ashape = tf.shape(a) 
# slice the list from 0th to 1st position 
ashape0 = tf.slice(ashape, [0], [1]) 
# reshape list to scalar, ie from [9] to 9 
ashape0_flat = tf.reshape(ashape0,()) 
# tf.sqrt doesn't support int, so cast to float 
ashape0_flat_float = tf.to_float(ashape0_flat) 
newshape0 = tf.sqrt(ashape0_flat_float) 
# convert [3, 3] Python list into [3, 3] Tensor 
newshape = tf.pack([newshape0, newshape0]) 
# tf.reshape doesn't accept float, so convert back to int 
newshape_int = tf.to_int32(newshape) 
a_reshaped = tf.reshape(a, newshape_int) 
sess.run(a_reshaped, feed_dict={a: np.ones((9))}) 

您应该看到

array([[1, 1, 1], 
     [1, 1, 1], 
     [1, 1, 1]], dtype=int32) 
+0

我没有看到任何方法'tf.batch'在此解决方案或Tensorflow ... – Muneeb

34

你可以用tf.reshape()轻松地做到这一点不知道批量大小。

x = tf.placeholder(tf.float32, shape=[None, 9,2]) 
shape = x.get_shape().as_list()  # a list: [None, 9, 2] 
dim = numpy.prod(shape[1:])   # dim = prod(9,2) = 18 
x2 = tf.reshape(x, [-1, dim])   # -1 means "all" 

在最后一行的-1意味着整个列无论batchsize是在运行什么。你可以在tf.reshape()看到它。


更新:形状= [无,3,无]

由于@kbrose。对于未定义多于1维的情况,我们可以使用tf.shape()tf.reduce_prod()

x = tf.placeholder(tf.float32, shape=[None, 3, None]) 
dim = tf.reduce_prod(tf.shape(x)[1:]) 
x2 = tf.reshape(x, [-1, dim]) 

tf.shape()返回一个可以在运行时评估的形状张量。 tf.get_shape()和tf.shape()之间的区别可以看出in the doc

我也试过tf.contrib.layers.flatten()在另一个。第一种情况最简单,但不能处理第二种情况。

+0

这种运作良好,如果你知道所有的其他尺寸的大小,但不会,如果其他方面有不明尺寸。例如。 'x = tf.placeholder(tf.float32,shape = [None,9,None])' – kbrose

+1

thanks @kbrose。我已经更新了案例的答案。 – weitang114

+0

@ weitang114太棒了! – kbrose

8
flat_inputs = tf.contrib.layers.flatten(inputs)