2017-02-27

Tensor data type to string?

I have the following TensorFlow code:

import datetime 
import matplotlib.pyplot as plt 
import numpy as np 
import os 
from PIL import Image 
import tensorflow as tf 

image_width = 202 
image_height = 180 
num_channels = 3 

filenames = tf.train.match_filenames_once("./train/Resized/*.jpg") 

def label(label_string): 
    if label_string == 'cat': label = [1,0] 
    if label_string == 'dog': label = [0,1] 

    return label 

def read_image(filename_queue): 
    image_reader = tf.WholeFileReader() 
    key, image_filename = image_reader.read(filename_queue) 
    image = tf.image.decode_jpeg(image_filename) 
    image.set_shape((image_height, image_width, 3)) 

    name = os.path.basename(image_filename) # example "dog.2148.jpg" 
    s = name.split('.') 
    label_string = s[0] 
    label = label(label_string) 

    return image, label 

def input_pipeline(filenames, batch_size, num_epochs=None): 
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True) 
    image, label = read_image(filename_queue) 
    min_after_dequeue = 1000 
    capacity = min_after_dequeue + 3 * batch_size 
    image_batch, label_batch = tf.train.shuffle_batch(
     [image, label], batch_size=batch_size, capacity=capacity, 
     min_after_dequeue=min_after_dequeue) 
    return image_batch, label_batch 

image_batch, label_batch = input_pipeline(filenames, 10) 

The last statement fails with the following error:

--------------------------------------------------------------------------- 
TypeError         Traceback (most recent call last) 
<ipython-input-21-0224ec735c33> in <module>() 
----> 1 image_batch, label_batch = input_pipeline(filenames, 10) 

<ipython-input-20-277e29dc1ae3> in input_pipeline(filenames, batch_size, num_epochs) 
     1 def input_pipeline(filenames, batch_size, num_epochs=None): 
     2  filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True) 
----> 3  image, label = read_image(filename_queue) 
     4  min_after_dequeue = 1000 
     5  capacity = min_after_dequeue + 3 * batch_size 

<ipython-input-19-ffe4ec8c3e25> in read_image(filename_queue) 
     5  image.set_shape((image_height, image_width, 3)) 
     6 
----> 7  name = os.path.basename(image_filename) # example "dog.2148.jpg" 
     8  s = name.split('.') 
     9  label_string = s[0] 

C:\local\Anaconda3-4.1.1-Windows-x86_64\envs\cntk-py35\lib\ntpath.py in basename(p) 
    230 def basename(p): 
    231  """Returns the final component of a pathname""" 
--> 232  return split(p)[1] 
    233 
    234 

C:\local\Anaconda3-4.1.1-Windows-x86_64\envs\cntk-py35\lib\ntpath.py in split(p) 
    202 
    203  seps = _get_bothseps(p) 
--> 204  d, p = splitdrive(p) 
    205  # set i to index beyond p's last slash 
    206  i = len(p) 

C:\local\Anaconda3-4.1.1-Windows-x86_64\envs\cntk-py35\lib\ntpath.py in splitdrive(p) 
    137 
    138  """ 
--> 139  if len(p) >= 2: 
    140   if isinstance(p, bytes): 
    141    sep = b'\\' 

TypeError: object of type 'Tensor' has no len() 

I think the problem is related to the Tensor data type versus the string data type. How do I correctly tell the os.path.basename function that image_filename is a string?

Answer


The problem is that match_filenames_once returns

A variable that is initialized to the list of files matching pattern.

(see: https://www.tensorflow.org/api_docs/python/tf/train/match_filenames_once).

os.path.basename and str.split are functions that operate on Python strings, not on tensors.
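To illustrate, here is the question's label logic as a plain Python function. It only works once the path is a concrete str or bytes value (for example, a string tensor fetched with sess.run, which comes back as bytes), never on a Tensor itself. The name label_from_filename is hypothetical. It also raises an error for unknown prefixes instead of leaving the result unbound, and it avoids reusing the name label for both the function and its result: in the question's read_image, `label = label(label_string)` would raise UnboundLocalError even after the basename problem is fixed.

```python
import os


def label_from_filename(path):
    """Return a one-hot label [cat, dog] for names like 'dog.2148.jpg'."""
    # Values fetched from a TensorFlow string tensor arrive as bytes,
    # so decode them before using str-based helpers.
    if isinstance(path, bytes):
        path = path.decode("utf-8")
    name = os.path.basename(path)  # e.g. "dog.2148.jpg"
    prefix = name.split(".")[0]    # e.g. "dog"
    if prefix == "cat":
        return [1, 0]
    if prefix == "dog":
        return [0, 1]
    # The original label() left its result unbound in this case.
    raise ValueError("unexpected file name prefix: %r" % prefix)


print(label_from_filename("./train/Resized/dog.2148.jpg"))
print(label_from_filename(b"./train/Resized/cat.7.jpg"))
```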

What I suggest is resolving the file list outside the TensorFlow pipeline, which I think also makes your labeling easier.
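A minimal sketch of that suggestion, assuming the cat/dog naming scheme from the question: expand the glob and derive the labels in plain Python, where os.path.basename works normally. The helper list_files_and_labels and the LABELS table are hypothetical names. The two resulting lists could then be handed to the pipeline together, for example with tf.train.slice_input_producer([filenames, labels], shuffle=True) in the TF 1.x queue API, so that each dequeued filename arrives paired with its label.

```python
import glob
import os

# Hypothetical mapping from file-name prefix to one-hot label [cat, dog].
LABELS = {"cat": [1, 0], "dog": [0, 1]}


def list_files_and_labels(pattern):
    """Resolve the glob in plain Python and derive one label per file."""
    filenames = sorted(glob.glob(pattern))
    labels = []
    for path in filenames:
        prefix = os.path.basename(path).split(".")[0]  # "dog.2148.jpg" -> "dog"
        labels.append(LABELS[prefix])  # KeyError for unexpected prefixes
    return filenames, labels


# Returns two empty lists if the directory does not exist.
filenames, labels = list_files_and_labels("./train/Resized/*.jpg")
```

Sorting keeps filenames and labels aligned and deterministic; shuffling can be left to the input producer.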


OK, but surely it must also be possible to do some simple processing on that kind of variable inside the TensorFlow pipeline? – OlavT
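Simple string processing can indeed stay inside the graph by using TensorFlow string ops instead of os.path. The sketch below assumes TensorFlow 2.x, where these ops live under tf.strings; the question's TF 1.x code would need the older equivalents (e.g. tf.string_split). Note also that tf.WholeFileReader.read returns (key, value), where key is the file name and value is the file contents, so in the question's read_image the path is key, not image_filename. The names label_from_path and make_dataset are hypothetical, and any prefix other than "dog" is treated as cat here.

```python
import tensorflow as tf


def label_from_path(path):
    """In-graph label: path is a scalar tf.string tensor."""
    # Basename via RE2: drop everything up to the last '/' or '\'.
    name = tf.strings.regex_replace(path, r"^.*[/\\]", "")
    # Text before the first '.', e.g. b"dog.2148.jpg" -> b"dog".
    prefix = tf.strings.split(name, sep=".")[0]
    is_dog = tf.cast(tf.equal(prefix, "dog"), tf.int32)
    return tf.stack([1 - is_dog, is_dog])  # [1, 0] = cat, [0, 1] = dog


def make_dataset(pattern, batch_size=10):
    """Files -> shuffled (image, label) batches, similar to input_pipeline."""
    files = tf.data.Dataset.list_files(pattern, shuffle=True)

    def load(path):
        image = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
        image.set_shape((180, 202, 3))  # image_height, image_width from the question
        return image, label_from_path(path)

    return files.map(load).shuffle(1000).batch(batch_size)
```

Usage would be along the lines of make_dataset("./train/Resized/*.jpg"); the tf.data API replaces the queue runners and needs no session or variable initialization.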