2017-01-05 137 views
4

我工作的这个教程:如何从小批次获得标签?

https://github.com/Microsoft/CNTK/blob/master/Tutorials/CNTK_201B_CIFAR-10_ImageHandsOn.ipynb

测试/火车数据文件包含图像文件名和正确的标签像这样简单的制表符分隔文本文件:

...\data\CIFAR-10\test\00000.png 3 
...\data\CIFAR-10\test\00001.png 8 
...\data\CIFAR-10\test\00002.png 8 

我怎样才能从minibatch中提取原始标签?

我曾尝试使用此代码:

reader_test = MinibatchSource(ImageDeserializer('test_map.txt', StreamDefs(
    features = StreamDef(field='image', transforms=transforms), # first column in map file is referred to as 'image' 
    labels = StreamDef(field='label', shape=num_classes)  # and second as 'label' 
))) 

test_minibatch = reader_test.next_minibatch(10) 
labels_stream_info = reader_test['labels'] 
orig_label = test_minibatch[labels_stream_info].value 
print(orig_label) 

<cntk.cntk_py.Value; proxy of <Swig Object of type 'CNTK::ValuePtr *' at 0x0000000007A32C00> > 

但是,正如你看到上面的结果并不与标签的数组。

什么是正确的代码去标签?

此代码有效,但它使用不同的文件格式,而不是ImageDeserializer。

文件格式:

|labels 0 0 1 0 0 0 |features 0 
|labels 1 0 0 0 0 0 |features 457 

工作代码:

mb_source = text_format_minibatch_source('test_map2.txt', [ 
    StreamConfiguration('features', 1), 
    StreamConfiguration('labels', num_classes)]) 

test_minibatch = mb_source.next_minibatch(2) 

labels_stream_info = mb_source['labels'] 
orig_label = test_minibatch[labels_stream_info].value 
print(orig_label) 

[[[ 0. 0. 1. 0. 0. 0.]] 
[[ 1. 0. 0. 0. 0. 0.]]] 

我怎样才能使用ImageDeserializer当输入的标签?

回答

2

你可以尝试使用:

orig_label = test_minibatch[labels_stream_info].value 
+0

我试过你的建议,但我仍然得到相同的结果。 – OlavT

1

我只是想重播 - 我想这里有一些奇怪的bug。我的直觉是,事实上labels对象不会作为有效的numpy数组返回。我插入下面的调试输出到train_and_evaluate功能教程CNTK_201B

for epoch in range(max_epochs):  # loop over epochs 
    sample_count = 0 
    while sample_count < epoch_size: # loop over minibatches in the epoch 
     data = reader_train.next_minibatch(min(minibatch_size, epoch_size - sample_count), input_map=input_map) # fetch minibatch. 
     print("Features:") 
     print(data[input_var].shape) 
     print(data[input_var].value.shape) 
     print("Labels:") 
     print(data[label_var].shape) 
     print(data[label_var].value.shape) 

输出:

Training 116906 parameters in 10 parameter tensors. 
Features: 
(64, 1, 3, 32, 32) 
(64, 1, 3, 32, 32) 
Labels: 
(64, 1, 10) 
() 

标签出来的似乎是一个numpy.ndarray,但它不具有有效shape

我会打电话给一个错误。