2017-05-24 17 views
1

我在C++中定义了一个新的Op,它采用tensor类型的单个属性,大致在these instructions之后。操作码的剥离版本如下:TensorFlow接受类型为“张​​量”的Attr的Python类型是什么?

#include "tensorflow/core/framework/op.h" 
#include "tensorflow/core/framework/op_kernel.h" 

using namespace tensorflow; 

REGISTER_OP("DoStuff") 
    .Attr("attr: tensor = { dtype: DT_FLOAT }") 
    .Input("in: float") 
    .Output("out: float"); 

class DoStuffOp : public OpKernel { 
public: 
    explicit DoStuffOp(OpKernelConstruction *context) : OpKernel(context) { 
     OP_REQUIRES_OK(context, context->GetAttr("attr", &attr_)); 
     // ... 
    } 

    void Compute(OpKernelContext *context) override { 
     // ... 
    } 

private: 
    Tensor attr_; 
}; 

REGISTER_KERNEL_BUILDER(Name("DoStuff").Device(DEVICE_CPU), DoStuffOp); 

我可以编译运到.so文件罚款。但是,我无法弄清楚如何成功传入attr的值。当我在Python中运行以下代码:

import tensorflow as tf 
dostufflib = tf.load_op_library('build/do_stuff.so') 
sess = tf.InteractiveSession() 

A = [[1.0, 2.0, 3.0], 
    [4.0, 5.0, 6.0]] 
X = tf.Variable(tf.constant(1.0)) 

Y = dostufflib.do_stuff(X, A) 

我得到TypeError: Don't know how to convert [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] to a TensorProto for argument 'attr'。我所做的任何事似乎都不能满足类型转换:list,numpy array,tf.Tensor,tf.Variable等。如何将Python变量作为张量属性传递到Op中?

回答

2

经过更多的狩猎之后,我找到了tf.contrib.util.make_tensor_proto,这是一个将python标量,python列表,numpy ndarray或numpy标量转换为tf.TensorProto对象的函数。以下作品:

A = tf.contrib.util.make_tensor_proto([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0]]) 
X = tf.Variable(tf.constant(1.0)) 

Y = dostufflib.do_stuff(X, A) 
相关问题