2017-02-26 21 views
2

给出两条曲线的交点我有两个数据集:(X,Y 1)和(x,y2)上。我想找到这两条曲线相互交叉的位置。我们的目标是类似这样的问题:Intersection of two graphs in Python, find the x value:查找(X,Y)数据以高精度在Python

然而,所描述的方法只发现交点到最接近的数据点。我想找到比原始数据间距具有更高精度的曲线的交点。一种选择是简单地重新插值到更精细的网格。这是有效的,但是然后精度由我选择用于重新插值的点的数量决定,这是任意的,并且需要在精度和效率之间进行权衡。

可替换地,我可以使用scipy.optimize.fsolve查找数据集的所述两个花键插补的确切交集。这很好,但它不容易找到多个交点,要求我为交点提供合理的猜测,并且可能不能很好地缩放。 (最后,我想找到的几千套(X,Y1,Y2)的交叉点,所以一个有效的算法将是很好的。)

这是我到目前为止所。任何改进想法?

import numpy as np 
import matplotlib.pyplot as plt 
import scipy.interpolate, scipy.optimize 

x = np.linspace(1, 4, 20) 
y1 = np.sin(x) 
y2 = 0.05*x 

plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1') 
plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2') 

idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0) 

plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method') 

interp1 = scipy.interpolate.InterpolatedUnivariateSpline(x, y1) 
interp2 = scipy.interpolate.InterpolatedUnivariateSpline(x, y2) 

new_x = np.linspace(x.min(), x.max(), 100) 
new_y1 = interp1(new_x) 
new_y2 = interp2(new_x) 
idx = np.argwhere(np.diff(np.sign(new_y1 - new_y2)) != 0) 
plt.plot(new_x[idx], new_y1[idx], 'ro', ms=7, label='Nearest data-point method, with re-interpolated data') 

def difference(x): 
    return np.abs(interp1(x) - interp2(x)) 

x_at_crossing = scipy.optimize.fsolve(difference, x0=3.0) 
plt.plot(x_at_crossing, interp1(x_at_crossing), 'cd', ms=7, label='fsolve method') 

plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left') 

plt.savefig('curve crossing.png', dpi=200) 
plt.show() 

enter image description here

+0

是不是总有精度和效率之间的权衡?您可以继续插入更细的网格,直到您的答案收敛到可容忍的数量范围内。 – Crispin

+1

是不是来自网格交叉点的近似信息正是您设置样条交叉问题所需的信息?我能看到的唯一问题就是如果在单个网格单元中有多个交点。我将运行网格交叉点,然后使用其中的答案来解决样条交集,使用样条线限制到发现的网格交点附近的几个单元格。 – mcdowella

+0

waterboy5281,我想你是对的,在给定相同算法的情况下,通常会在效率和精度之间进行权衡。但是,更好的算法通常既快速又精确。 @mcdowella,我喜欢通过“最近的数据点”法求交点的近似位置,然后使用该信息以使其更容易找到精确交叉点的想法。我会尽力实现这一点。 – DanHickstein

回答

1

最好的(也是最高效的)答案很可能取决于数据集以及它们是如何采样。但是,对于许多数据集来说,一个很好的近似值是它们在数据点之间几乎是线性的。因此,我们可以通过原始文章中显示的“最近的数据点”方法找到交集的大概位置。然后,我们可以使用线性插值来细化最近两个数据点之间的交点位置。

这种方法是非常快,并与2D numpy的数组,你想同时计算多条曲线的交叉工作,万一(我想在我的应用程序执行)。

(我借用“How do I compute the intersection point of two lines in Python?”代码的线性插值。)

from __future__ import division 
import numpy as np 
import matplotlib.pyplot as plt 

def interpolated_intercept(x, y1, y2): 
    """Find the intercept of two curves, given by the same x data""" 

    def intercept(point1, point2, point3, point4): 
     """find the intersection between two lines 
     the first line is defined by the line between point1 and point2 
     the first line is defined by the line between point3 and point4 
     each point is an (x,y) tuple. 

     So, for example, you can find the intersection between 
     intercept((0,0), (1,1), (0,1), (1,0)) = (0.5, 0.5) 

     Returns: the intercept, in (x,y) format 
     """  

     def line(p1, p2): 
      A = (p1[1] - p2[1]) 
      B = (p2[0] - p1[0]) 
      C = (p1[0]*p2[1] - p2[0]*p1[1]) 
      return A, B, -C 

     def intersection(L1, L2): 
      D = L1[0] * L2[1] - L1[1] * L2[0] 
      Dx = L1[2] * L2[1] - L1[1] * L2[2] 
      Dy = L1[0] * L2[2] - L1[2] * L2[0] 

      x = Dx/D 
      y = Dy/D 
      return x,y 

     L1 = line([point1[0],point1[1]], [point2[0],point2[1]]) 
     L2 = line([point3[0],point3[1]], [point4[0],point4[1]]) 

     R = intersection(L1, L2) 

     return R 

    idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0) 
    xc, yc = intercept((x[idx], y1[idx]),((x[idx+1], y1[idx+1])), ((x[idx], y2[idx])), ((x[idx+1], y2[idx+1]))) 
    return xc,yc 

def main(): 
    x = np.linspace(1, 4, 20) 
    y1 = np.sin(x) 
    y2 = 0.05*x 

    plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1') 
    plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2') 

    idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0) 

    plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method') 

    # new method! 
    xc, yc = interpolated_intercept(x,y1,y2) 
    plt.plot(xc, yc, 'co', ms=5, label='Nearest data-point, with linear interpolation') 


    plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left') 

    plt.savefig('curve crossing.png', dpi=200) 
    plt.show() 

if __name__ == '__main__': 
    main() 

Curve crossing

相关问题