
I've been working with TensorFlow object detection, playing around with the real-time object detection code that Dat Tran provides on GitHub:

https://medium.com/towards-data-science/building-a-real-time-object-recognition-app-with-tensorflow-and-opencv-b7a2b4ebdc32

I'm trying to figure out how to print a message to the console based on which object gets classified/detected. I tried:

if classes == 'Louis':
    print('Hello Louis')

but that doesn't seem to work. I also tried the following:

for label in classes:
    if ('{Name}'.format(**label) == 'louis'):
        print('Hello Louis')

but I get the following error:

File "object_detection_app.py", line 61, in detect_objects 
    if ('{Name}'.format(**label) == 'person'): 
TypeError: format() argument after ** must be a mapping, not numpy.ndarray 

It works if I add an else statement that prints a hello message, but then it keeps printing whether or not an object is detected. Here is my code:

import argparse
import multiprocessing
import os
import time

import cv2
import numpy as np
import tensorflow as tf

from utils import FPS, WebcamVideoStream
from multiprocessing import Queue, Pool
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

CWD_PATH = os.getcwd() 

# Path to frozen detection graph. This is the actual model that is used for the object detection. 
MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 
PATH_TO_CKPT = os.path.join(CWD_PATH, 'object_detection', MODEL_NAME, 'frozen_inference_graph.pb') 

# List of the strings that are used to add the correct label for each box.
PATH_TO_LABELS = os.path.join(CWD_PATH, 'object_detection', 'data', 'myHousePets_label_map.pbtxt') 

NUM_CLASSES = 90 

# Loading label map 
label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, 
                  use_display_name=True) 
category_index = label_map_util.create_category_index(categories) 


def detect_objects(image_np, sess, detection_graph):
    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
    image_np_expanded = np.expand_dims(image_np, axis=0)
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')

    # Each box represents a part of the image where a particular object was detected.
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')

    # Each score represents the level of confidence for each of the objects.
    # The score is shown on the result image, together with the class label.
    scores = detection_graph.get_tensor_by_name('detection_scores:0')
    classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')

    # Actual detection.
    (boxes, scores, classes, num_detections) = sess.run(
        [boxes, scores, classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})

    # Visualization of the results of a detection.
    vis_util.visualize_boxes_and_labels_on_image_array(
        image_np,
        np.squeeze(boxes),
        np.squeeze(classes).astype(np.int32),
        np.squeeze(scores),
        category_index,
        use_normalized_coordinates=True,
        line_thickness=8)

    if classes == 'Louis':
        print('Hello Louis')
    return image_np


def worker(input_q, output_q):
    # Load a (frozen) Tensorflow model into memory.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

        sess = tf.Session(graph=detection_graph)

    fps = FPS().start()
    while True:
        fps.update()
        frame = input_q.get()
        output_q.put(detect_objects(frame, sess, detection_graph))
    fps.stop()
    sess.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-src', '--source', dest='video_source', type=int,
                        default=0, help='Device index of the camera.')
    parser.add_argument('-wd', '--width', dest='width', type=int,
                        default=480, help='Width of the frames in the video stream.')
    parser.add_argument('-ht', '--height', dest='height', type=int,
                        default=360, help='Height of the frames in the video stream.')
    parser.add_argument('-num-w', '--num-workers', dest='num_workers', type=int,
                        default=2, help='Number of workers.')
    parser.add_argument('-q-size', '--queue-size', dest='queue_size', type=int,
                        default=5, help='Size of the queue.')
    args = parser.parse_args()

    logger = multiprocessing.log_to_stderr()
    logger.setLevel(multiprocessing.SUBDEBUG)

    input_q = Queue(maxsize=args.queue_size)
    output_q = Queue(maxsize=args.queue_size)
    pool = Pool(args.num_workers, worker, (input_q, output_q))

    video_capture = WebcamVideoStream(src=args.video_source,
                                      width=args.width,
                                      height=args.height).start()
    fps = FPS().start()

    while True:  # fps._numFrames < 120
        frame = video_capture.read()
        input_q.put(frame)

        t = time.time()

        cv2.imshow('Video', output_q.get())
        fps.update()

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    fps.stop()
    print('[INFO] elapsed time (total): {:.2f}'.format(fps.elapsed()))
    print('[INFO] approx. FPS: {:.2f}'.format(fps.fps()))
    pool.terminate()
    video_capture.stop()
    cv2.destroyAllWindows()

Answer

From the error message, classes is a multi-dimensional array, so you can't simply compare it to the constant 'Louis'.
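
To see why the string comparison can't work, here's a tiny illustration with made-up IDs (the real array holds whatever numeric IDs your label map defines):

import numpy as np

# detection_classes is a float ndarray of numeric class IDs, not strings
classes = np.array([[18., 1., 1.]])  # made-up IDs, purely illustrative

print(type(classes))       # <class 'numpy.ndarray'>
print(classes == 'Louis')  # falsy for every element: a float ID never equals a string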

I don't know its exact structure, so here are my suggestions.

Add print(classes.shape) before the visualization. This will log the array's dimensions to the console and help you understand what it holds.

To get more out of the printout, I'd also add print(classes) right below the shape, so you can see the array's actual contents and make further sense of its shape.
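
Dropped into detect_objects, right after the sess.run call, the diagnostics would look something like this:

# inside detect_objects, after sess.run and before the visualization
print(classes.shape)  # e.g. (1, 100): one batch row, one slot per detection
print(classes)        # the numeric class IDs, ordered to match scores and boxes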

After that, you can go ahead and iterate over it accordingly to look for Louis. np.nditer() is an efficient way to do so; the numpy docs list some other iteration routines that may interest you.
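
Putting it together, a sketch of what the check inside detect_objects could look like (it assumes 'Louis' appears as a display name in your myHousePets label map, and the 0.5 confidence threshold is just illustrative):

# inside detect_objects, after sess.run -- replaces `if classes == 'Louis':`
for class_id, score in np.nditer([classes, scores]):
    if score < 0.5:  # illustrative confidence cutoff; tune as needed
        continue
    # category_index maps numeric IDs back to label-map display names
    name = category_index.get(int(class_id), {}).get('name', '')
    if name == 'Louis':
        print('Hello Louis')

Note that category_index is the dict your script already builds with label_map_util.create_category_index, so no new lookups or imports are needed beyond numpy.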