2017-06-17 21 views
1

我正在尝试使用MXNet中的RNN进行分类。我的数据大致看起来像我创建的矩阵m0和m1。 m 0代表例如一个设备的能耗随着时间的推移,而m1是我的标签,告诉设备如何被分类(例如在这种情况下是二进制)。 我的目标是通过查看一段时间内的能耗来检测设备的类别。 我不断收到关于形状不匹配的错误,并且通过更改输入参数找不到解决方案。您可以在下面看到我的代码和错误消息。 我很感激任何关于如何处理这个问题的建议。在MXNet中使用RNN的形状不匹配 - R

require(mxnet) 

m0 <- matrix(runif(200*100), 100, 200) 
m1 <- matrix(round(runif(1*200)), 1, 200) 

num.round  <- 10 
update.period <- 1 
num.rnn.layer <- 1 
seq.len  <- 100 
num.hidden  <- 1 
num.embed  <- 1 
num.label  <- 1 
batch.size  <- 1 
input.size  <- 1 
learning.rate <- 0.1 

X.train <- list(data = m0, label = m1) 

model <- mx.rnn(train.data = X.train, 
       eval.data = NULL, 
       num.rnn.layer = num.rnn.layer, 
       seq.len = seq.len, 
       num.hidden = num.hidden, 
       num.embed = num.embed, 
       num.label = num.label, 
       batch.size = batch.size, 
       input.size = input.size, 
       ctx = mx.cpu(), 
       num.round = num.round, 
       update.period = update.period, 
       initializer = mx.init.uniform(0.1), 
       learning.rate = learning.rate) 

[16时07分02秒] d:\程序文件 (86)\詹金斯\工作空间\ mxnet \ mxnet \ SRC \操作\ tensor./matrix_op-inl.h:144: 使用target_shape将被弃用。

[16时07分02秒] d:\程序文件 (86)\詹金斯\工作空间\ mxnet \ mxnet \ SRC \操作\ tensor./matrix_op-inl.h:144: 使用target_shape将被弃用。

[16时07分02秒] d:\程序文件 (86)\詹金斯\工作空间\ mxnet \ mxnet \ SRC \操作\ tensor./matrix_op-inl.h:144: 使用target_shape将被弃用。

[16时07分02秒] d:\ Program Files文件 (86)\詹金斯\工作空间\ mxnet \ mxnet \ DMLC-芯\包括\ DMLC/logging.h:304:

[16: 07:02] D:\ Program Files (x86)\ Jenkins \ workspace \ mxnet \ mxnet \ src \ ndarray \ ndarray.cc:299:Check failed:from.shape()== to> shape()operands形状 mismatchfrom.shape =(1,1)= to.shape(1100)中的错误 EXEC $ update.arg.arrays(arg.arrays,match.name,skip.null):

[16:07 :02] D:\ Program Files (x86)\ Jenkins \ workspace \ mxnet \ mxnet \ src \ ndarray \ ndarr ay.cc:299:检查 失败:from.shape()== TO->形状()操作数形状 mismatchfrom.shape =(1,1)= to.shape(1100)

回答

1

原因对于尺寸不匹配的情况,您传递的尺寸与label的尺寸不匹配。一个RNN为序列的每个抽头都有一个输出,所以如果你的长度是100,它将有100个输出,每个时间步一个。您可以通过将m1设置为matrix(round(runif(100*200)), 100, 200)来解决此错误,但是您无法使用简化的mx.rnn()接口完成您想要的操作(即预测整个序列的一个数字)。您需要根据代码here实施您自己的网络。为了实现您正在寻找的单输出,您可以放弃除上一次输出以外的所有输出,并通过Softmax图层运行该输出。