2017-06-03 16 views
0

我想实现一个KNN 1D估计:有超过一个元素的数组的Pyplot真值是不确定

# nearest neighbors estimate 
def nearest_n(x, k, data): 
    # Order dataset 
    #data = np.sort(data, kind='mergesort') 
    nnb = [] 
    # iterate over all data and get k nearest neighbours around x 
    for n in data: 
     if nnb.__len__()<k: 
      nnb.append(n) 
     else: 
      for nb in np.arange(0,k): 
       if np.abs(x-n) < np.abs(x-nnb[nb]): 
        nnb[nb] = n 
        break 

    nnb = np.array(nnb) 
    # get volume(distance) v of k nearest neighbours around x 
    v = nnb.max() - nnb.min() 
    v = k/(data.__len__()*v) 

    return v 

interval = np.arange(-4.0, 8.0, 0.1) 
plt.figure() 
for k in (2,8,35): 
    plt.plot(interval, nearest_n(interval, k,train_data), label=str(o)) 
plt.legend() 
plt.show() 

会抛出:

File "x", line 55, in nearest_n 
    if np.abs(x-n) < np.abs(x-nnb[nb]): 
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() 

我知道错误来自数组输入在plot()中,但我不知道如何在运算符中避免这种情况>/==/<

'data'来自包含浮点数的1D txt文件。

我尝试使用矢量化:

nearest_n = np.vectorize(nearest_n) 

导致:

line 50, in nearest_n 
    for n in data: 
TypeError: 'numpy.float64' object is not iterable 

下面是一个例子,让我们说:

data = [0.5,1.7,2.3,1.2,0.2,2.2] 
k = 2 

nearest_n(1.5)应该然后导致

nbb=[1.2,1.7] 
v = 0.5 

并返回2 /(6 * 0.5)= 2/3

该函数运行例如neares_n(2.0,4,数据),并给出0.0741586011463

+1

你能否包括预期的输出(如果你必须手工完成,你可能需要使用较小的输入) 。 :) – MSeifert

+0

输出将是3个不同的概率密度图(k = 2,8,35),s.th.来自数组[-4,8]的每个值将映射到概率[0,1] –

+0

不,我的意思是调用'nearest_n'的字面结果。例如,'nearest_n(np.arange(-4.0,8.0,0.1),2,np.array([1,2,3]))''应该返回什么?我已经或多或少地选择了这些值,如果需要的话插入更合适的值(如果没有参考实现,则更容易手动计算)。 – MSeifert

回答

0

你在np.arange(-4, 8, .01)传递作为x ,这是一组值。所以x - n是一个长度与x相同的数组,在这种情况下是120个元素,因为减去一个数组和一个标量确实是逐元素减法。与nnb[nb]一样。因此,比较的结果是一个长度为120的数组,其布尔值取决于np.abs(x-n)的每个元素是否小于np.abs(x-nnb[nb])的对应元素。这不能直接用作条件,你需要将这些值合并为一个布尔值(使用all()any()或者只是重新考虑代码)。

+0

嗨,感谢我的回答,请看我的。这只是我期望pyplot工作有点不同 –

0
plt.figure() 
X = np.arange(-4.0,8.0,0.1) 
for k in [2,8,35]: 
    Y = [] 
    for n in X: 
     Y.append(nearest_n(n,k,train_data)) 
    plt.plot(X,Y,label=str(k)) 
plt.show() 

工作正常。我认为pyplot.plot会为我做这件事情,但我想它不会...

+0

这不是'pyplot'的问题,我不知道你为什么认为它可能是?你写了'nearest_n'来为'x'参数取一个标量,所以如果不重写你的代码就不能传入一个向量。在这里,你正在遍历一个向量,并且每次都将一个标量传递给你的函数。 – spruceb

+0

我以为pyplot会像这样处理矢量输入,但我错了 –

+0

我只是想澄清一下,因为我不确定你是否理解了问题的根源。这个错误并没有出现在plt中。plot'功能,是不是因为你的投入'pyplot',误差在'nearest_n'抛出是由于传递给函数的参数。 – spruceb

相关问题