2017-08-25 54 views
0

我试图在 TensorFlow 中对一维张量进行插值(实际上我想要的是与 np.interp 等价的功能)。由于找不到类似的 TensorFlow 操作,我只好自己实现插值。(标题:TensorFlow 中的二分查找与插值)

第一步是在已排序的 x 值列表中找到对应 y 值的索引,即执行二分查找。我尝试用 while 循环来实现,但遇到了一个令人费解的运行时错误。下面是相关代码:

# Question code (TF 1.x graph mode): attempted binary search for the index i
# with xaxis[i] < query < xaxis[i + 1]. As posted it fails while building the
# graph; the NOTE(review) comments in body() mark the offending lines.
xaxis = tf.placeholder(tf.float32, shape=100, name='xaxis') 
query = tf.placeholder(tf.float32, name='query') 

with tf.name_scope("binsearch"): 
    # NOTE(review): these four Variables are never read -- cond/body take
    # parameters with the same names, and the tf.while_loop call at the end
    # rebinds up/mid/down/done to the loop outputs.
    up = tf.Variable(0, dtype=tf.int32, name='up') 
    mid = tf.Variable(0, dtype=tf.int32, name='mid') 
    down = tf.Variable(0, dtype=tf.int32, name='down') 
    done = tf.Variable(-1, dtype=tf.int32, name='done')   

    def cond(up, down, mid, done): 
     # Keep looping while no index has been found (done < 0) and the
     # [down, up] bracket is still wider than one step.
     return tf.logical_and(done<0,up-down>1) 

    def body(up, down, mid, done): 
     val = tf.gather(xaxis, mid) 
     # NOTE(review): tf.cond expects callables for its two branch arguments,
     # but the 2nd and 3rd arguments here are already-built tf.cond results
     # (tensors); a later TF version reports "true_fn must be callable".
     # NOTE(review): the inner `lambda: -1` branches return a plain Python
     # int; in TF 1.1 that is what raises "AttributeError: 'int' object has
     # no attribute 'name'" (see the traceback). The branch must return a
     # tensor, e.g. tf.constant(-1).
     done = tf.cond(val>query, 
         tf.cond(tf.gather(xaxis, mid-1)<query, lambda:mid-1, lambda: -1), 
         tf.cond(tf.gather(xaxis, mid+1)>query, lambda:mid, lambda: -1)) 
     # Shrink the bracket: move `up` down to mid if xaxis[mid] > query,
     # move `down` up to mid if xaxis[mid] < query.
     up = tf.cond(val>query, lambda: mid, lambda: up) 
     down = tf.cond(val<query, lambda: mid, lambda: down) 

     # NOTE(review): the returned tensors are already computed from
     # done/up/down, so this control_dependencies block appears redundant.
     with tf.control_dependencies([done, up, down]): 
      return up, down, (up+down)//2, done 

    # Initial loop state: up = last index, down = 0, mid = their midpoint,
    # done = -1 ("not found"); xaxis.shape[0] is the static length (100).
    up, down, mid, done = tf.while_loop(cond, body, (xaxis.shape[0]-1, 0, (xaxis.shape[0]-1)//2, -1)) 

这会导致如下错误:

AttributeError: 'int' object has no attribute 'name' 

我使用的是 Windows 7 上的 Python 3.6,以及支持 GPU 的 TensorFlow 1.1。有谁知道问题出在哪里吗?谢谢。

下面是完整的堆栈跟踪:

AttributeError       Traceback (most recent call last) 
<ipython-input-185-693d3873919c> in <module>() 
    19    return up, down, (up+down)//2, done 
    20 
---> 21  up, down, mid, done = tf.while_loop(cond, body, (xaxis.shape[0]-1, 0, (xaxis.shape[0]-1)//2, -1)) 

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name) 
    2621  context = WhileContext(parallel_iterations, back_prop, swap_memory, name) 
    2622  ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, context) 
-> 2623  result = context.BuildLoop(cond, body, loop_vars, shape_invariants) 
    2624  return result 
    2625 

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in BuildLoop(self, pred, body, loop_vars, shape_invariants) 
    2454  self.Enter() 
    2455  original_body_result, exit_vars = self._BuildLoop(
-> 2456   pred, body, original_loop_vars, loop_vars, shape_invariants) 
    2457  finally: 
    2458  self.Exit() 

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in _BuildLoop(self, pred, body, original_loop_vars, loop_vars, shape_invariants) 
    2404   structure=original_loop_vars, 
    2405   flat_sequence=vars_for_body_with_tensor_arrays) 
-> 2406  body_result = body(*packed_vars_for_body) 
    2407  if not nest.is_sequence(body_result): 
    2408  body_result = [body_result] 

<ipython-input-185-693d3873919c> in body(up, down, mid, done) 
    11   val = tf.gather(xaxis, mid) 
    12   done = tf.cond(val>query, 
---> 13      tf.cond(tf.gather(xaxis, mid-1)<query, lambda:mid-1, lambda: -1), 
    14      tf.cond(tf.gather(xaxis, mid+1)>query, lambda:mid, lambda: -1)) 
    15   up = tf.cond(val>query, lambda: mid, lambda: up) 

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in cond(pred, fn1, fn2, name) 
    1746  context_f = CondContext(pred, pivot_2, branch=0) 
    1747  context_f.Enter() 
-> 1748  _, res_f = context_f.BuildCondBranch(fn2) 
    1749  context_f.ExitResult(res_f) 
    1750  context_f.Exit() 

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in BuildCondBranch(self, fn) 
    1666    real_v = sparse_tensor.SparseTensor(indices, values, dense_shape) 
    1667   else: 
-> 1668    real_v = self._ProcessOutputTensor(v) 
    1669   result.append(real_v) 
    1670  return original_r, result 

c:\program files\python36\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in _ProcessOutputTensor(self, val) 
    1624  """Process an output tensor of a conditional branch.""" 
    1625  real_val = val 
-> 1626  if val.name not in self._values: 
    1627  # Handle the special case of lambda: x 
    1628  self._values.add(val.name) 

AttributeError: 'int' object has no attribute 'name' 
+1

你能贴出完整的堆栈跟踪吗?这真的是产生该错误的代码吗?运行你贴出的代码时,我得到的是 'TypeError: true_fn must be callable'。 – user3080953

回答

1

我不知道你这个错误的原因,但可以告诉你,tf.while_loop 很可能会很慢。你可以像下面这样不使用循环来实现线性插值:

import numpy as np
import tensorflow as tf

# Loop-free linear interpolation (equivalent of np.interp for one query).
# Assumes xaxis is sorted ascending and query is a scalar: the padding below
# wraps the query-dependent endpoints in a length-1 list.
xaxis = tf.placeholder(tf.float32, shape=100, name='xaxis')
yaxis = tf.placeholder(tf.float32, shape=100, name='yaxis')
query = tf.placeholder(tf.float32, name='query')

# Add one sentinel point at each end so an out-of-range query still falls
# inside some interval. The sentinels repeat the first/last y value, so such
# a query is clamped to yaxis[0] / yaxis[-1] (np.interp's default behaviour).
xaxis_pad = tf.concat([[tf.minimum(query - 1, xaxis[0])], xaxis, [tf.maximum(query + 1, xaxis[-1])]], axis=0)
yaxis_pad = tf.concat([yaxis[:1], yaxis, yaxis[-1:]], axis=0)

# Find the index of the interval containing query. For a sorted axis the
# comparison reads 1,...,1,0,...,0, so its first difference is -1 exactly at
# the last point <= query, and argmin returns that position.
cmp = tf.cast(query >= xaxis_pad, dtype=tf.int32)
diff = cmp[1:] - cmp[:-1]
idx = tf.argmin(diff)

# Interpolate linearly between the two bracketing points
alpha = (query - xaxis_pad[idx])/(xaxis_pad[idx + 1] - xaxis_pad[idx])
res = alpha * yaxis_pad[idx + 1] + (1 - alpha) * yaxis_pad[idx]

# Test with f(x) = 2 * x
q = 5.4
x = np.arange(100)
y = 2 * x
with tf.Session() as sess:
    q_interp = sess.run(res, feed_dict={xaxis: x, yaxis: y, query: q})
print(q_interp)
# prints: 10.8

填充部分只是为了在传入超出范围的值时避免麻烦;除此之外,整个过程只是做一次比较,然后找出值从哪个位置开始大于 query。

0

找到问题了:TensorFlow 不接受 Python 整数作为 cond 分支的返回值——必须先把它包装成常量(tf.constant)。下面的代码可以正常运行:

# Binary search over the `xaxis` / `query` placeholders defined in the
# question. Result: done == i means xaxis[i] <= query < xaxis[i + 1];
# done == -1 means query lies outside [xaxis[0], xaxis[-1]) (assuming xaxis
# is sorted ascending).
with tf.name_scope("binsearch"):
    # tf.cond branches must return tensors, not Python ints, so the
    # "not found" value is a constant tensor rather than a literal -1.
    m_one = tf.constant(-1, dtype=tf.int32, name='minus_one')

    def cond(up, down, mid, done):
        """Continue while nothing is found and [down, up] can still shrink."""
        return tf.logical_and(done < 0, up - down > 1)

    def body(up, down, mid, done):
        """Perform one bisection step and return the new (up, down, mid, done).

        mid is always (up + down) // 2 and cond guarantees up - down > 1, so
        down < mid < up and both mid - 1 and mid + 1 are valid indices.
        """

        def fn1():
            # xaxis[mid] > query: the interval lies left of mid.
            # `<=` (not `<`) keeps the same convention as fn2, so an exact
            # hit on xaxis[mid - 1] (e.g. query == xaxis[0]) is not skipped.
            return mid, down, (mid + down) // 2, tf.cond(
                tf.gather(xaxis, mid - 1) <= query,
                lambda: mid - 1,
                lambda: m_one)

        def fn2():
            # xaxis[mid] <= query: the interval starts at mid or right of it.
            return up, mid, (up + mid) // 2, tf.cond(
                tf.gather(xaxis, mid + 1) > query,
                lambda: mid,
                lambda: m_one)

        return tf.cond(tf.gather(xaxis, mid) > query, fn1, fn2)

    # Initial state: up = last index, down = 0, mid = midpoint, done = -1.
    up, down, mid, done = tf.while_loop(cond, body, (xaxis.shape[0]-1, 0, (xaxis.shape[0]-1)//2, -1))