2017-08-07 15 views
0

我已经构建了一个DL4j项目。一切都很好,如果我使用MNIST数据集如下:DeepLearning4j和DataVec使用标签读取csv文件

DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed); 
    DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed); 

不过,我想切换到我自己的CSV文件格式如下:

A | B | C | X | Y 
------------------------- 
1 | 100 | 5 | 15 | 6 
... 

XY是结果(或标签)。由于我打算执行回归分析,因此XY都是实数。所以,我用下面的代码读取CSV文件:

RecordReader recordReaderTrain = new CSVRecordReader(1, ","); 
    recordReaderTrain.initialize(new FileSplit(new File("src/main/resources/data/Data.csv"))); 
    DataSetIterator dataIterTrain = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 3, 2); 

3代码意味着index of the labels2意味着number of possible labels。对这两个参数没有太多的解释。我想他们的意思是标签从第四栏开始,有两个标签。

当我运行的代码,它显示了以下异常:

Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 14 

我想是因为dl4j不承认15为标签。

所以我的问题是:如何正确读取csv文件进行回归分析?

非常感谢。

回答

1

权,所以我们有例子的回归: https://github.com/deeplearning4j/dl4j-examples/tree/cc383de91bdf4e28e36859aa2e8749100cd63177/dl4j-examples/src/main/java/org/deeplearning4j/examples/feedforward/regression

您需要将回归真(它的构造函数的额外部件)与RecordReaderDataSetIterator。

+0

谢谢您的回复。它工作正常。 请问另外一个问题。如何将DataSetIterator中的DataSet转换为'3'图形,如输入,以便我可以使用Convolution网络? –

+0

记住第二个问题的更多细节?谢谢! –