2017-01-26 24 views
1

我刚刚开始学习cntk。但是,我有一个基本问题阻碍了我的进步。我有以下的测试通过:为什么不一致的形状numpy vs cntk?

import numpy as np 
from cntk import input_variable, plus 

def test_simple(self): 

    x_input = np.asarray([[1, 2, 2]], dtype=np.int64) 
    assert (1, 3) == x_input.shape 

    y_input = np.asarray([[5, 3, 3]], dtype=np.int64) 
    assert (1, 3) == y_input.shape 

    x = input_variable(x_input.shape[1]) 
    assert (3,) == x.shape 

    y = input_variable(y_input.shape[1]) 
    assert (3,) == y.shape 

    x_plus_y = plus(x, y) 
    assert (3,) == x_plus_y.shape 

    res = x_plus_y.eval({x: x_input, y: y_input}) 

    assert 6 == res[0, 0, 0] 
    assert 5 == res[0, 0, 1] 
    assert 5 == res[0, 0, 2] 

据我所知,输出的形状为(1,1,3)作为第一和第二轴线是分批和分别缺省的动态轴。

但是,为什么我需要将输入变量的形状设置为(3,)而不是(1,3)。使用(1,3)失败。

为什么图中输入节点的形状与用作该节点输入的numpy数据之间存在不一致?

谢谢 水稻

回答

2

这是为Function.forward解释一点点的“论据”的说明。另一种描述是here。你们混淆的原因可能是CNTK做了一些“有用的”转换。

如果您将输入指定为(1,3),则需要在没有序列轴的小批次或(x,1,3)数组列表的情况下提供(1,3)数组列表如果是具有序列轴的小批次(其中x对于小批次中的每个序列而言可能不同)。同样,如果您将输入指定为(3,),则需要提供(3,)向量列表或(x,3)向量列表。

混乱可能是由于没有提供列表的情况而引起的。在那种情况下,CNTK在提供的张量的引导轴上迭代并且创建这些元素的列表,例如, (5,1,3)张量变成每个具有(1,3)形状的一批5个元素。