我已经构建了一个DL4j项目。一切都很好,如果我使用MNIST数据集如下:DeepLearning4j和DataVec使用标签读取csv文件
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
不过,我想切换到我自己的CSV文件格式如下:
A | B | C | X | Y
-------------------------
1 | 100 | 5 | 15 | 6
...
X
和Y
是结果(或标签)。由于我打算执行回归分析,因此X
和Y
都是实数。所以,我用下面的代码读取CSV文件:
RecordReader recordReaderTrain = new CSVRecordReader(1, ",");
recordReaderTrain.initialize(new FileSplit(new File("src/main/resources/data/Data.csv")));
DataSetIterator dataIterTrain = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 3, 2);
3
代码意味着index of the labels
和2
意味着number of possible labels
。对这两个参数没有太多的解释。我想他们的意思是标签从第四栏开始,有两个标签。
当我运行的代码,它显示了以下异常:
Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 14
我想是因为dl4j不承认15
为标签。
所以我的问题是:如何正确读取csv文件进行回归分析?
非常感谢。
谢谢您的回复。它工作正常。 请问另外一个问题。如何将DataSetIterator中的DataSet转换为'3'图形,如输入,以便我可以使用Convolution网络? –
记住第二个问题的更多细节?谢谢! –