2017-08-12 19 views
0

我有一个N体模拟,可以生成一个粒子位置列表,用于模拟中的多个时间步长。对于给定的帧,我想要生成粒子索引对的列表(i, j),例如dist(p[i], p[j]) < masking_radius。基本上我创建了一个“交互”对的列表,其中这些对在彼此的一定距离内。我目前的实现看起来是这样的:因为大量颗粒的有效的粒子对相互作用计算

interaction_pairs = [] 

# going through each unique pair (order doesn't matter) 
for i in range(num_particles): 
    for j in range(i + 1, num_particles): 
     if dist(p[i], p[j]) < masking_radius: 
      interaction_pairs.append((i,j)) 

,这个过程需要很长的时间(每个测试> 1个小时),并且它严重地限制什么,我需要用做数据。我想知道是否有更有效的方式来构造数据,以便计算这些对将更有效,而不是比较每个可能的粒子组合。我正在研究KDTrees,但我无法想出一种方法来利用它们来更有效地计算这个值。任何帮助表示赞赏,谢谢!

+0

2维或3维的点?如果2你可以尝试一下Cormen等人的*算法简介*中给出的“寻找最接近的一对点”算法的变体。 –

+0

你的算法的时间复杂度为'O(n ** 2)',而Cormen的'O(n * log(n))'有。你也可以尝试加速Python代码的[pypy](http://pypy.org/)。 –

回答

0

既然你使用python,sklearn有最近的邻居多个实现发现: http://scikit-learn.org/stable/modules/neighbors.html

有KDTree和Balltree提供。

至于KDTree的主要观点是将你拥有的所有粒子推入KDTree,然后对每个粒子询问查询:“给我范围X内的所有粒子”。 KDtree通常比bruteforce搜索更快。 你可以阅读更多的例子在这里:https://www.cs.cmu.edu/~ckingsf/bioinfo-lectures/kdtrees.pdf

如果使用2D或3D空间,那么另一个选择是刚刚晋级的空间为大网格(屏蔽半径其单元尺寸),每个粒子分配到一个网格细胞。然后,您可以通过检查相邻单元来找到可能的相互作用的候选对象(但您也必须执行距离检查,但对于更少的粒子对)。

0

这是一个使用普通Python的相当简单的技术,可以减少所需的比较次数。

我们首先按X,Y或Z轴(在下面的代码中由axis选择)对点进行排序。假设我们选择X轴。然后我们循环点对,就像你的代码一样,但是当我们发现一对距离大于masking_radius的对时,我们测试它们的X坐标差是否也大于masking_radius。如果是,那么我们可以从内部j循环中退出,因为具有更大的j的所有点都具有更大的X坐标。

我的dist2函数计算平方距离。这比计算实际距离要快,因为计算平方根相对较慢。

我还包括了代码,它们的行为与您的代码相似,即测试每对点,以进行速度比较;它也用于检查快速代码是否正确。)

from random import seed, uniform 
from operator import itemgetter 

seed(42) 

# Make some fake data 
def make_point(hi=10.0): 
    return [uniform(-hi, hi) for _ in range(3)] 

psize = 1000 
points = [make_point() for _ in range(psize)] 

masking_radius = 4.0 
masking_radius2 = masking_radius ** 2 

def dist2(p, q): 
    return (p[0] - q[0])**2 + (p[1] - q[1])**2 + (p[2] - q[2])**2 

pair_count = 0 
test_count = 0 

do_fast = 1 
if do_fast: 
    # Sort the points on one axis 
    axis = 0 
    points.sort(key=itemgetter(axis)) 

    # Fast 
    for i, p in enumerate(points): 
     left, right = i - 1, i + 1 
     for j in range(i + 1, psize): 
      test_count += 1 
      q = points[j] 
      if dist2(p, q) < masking_radius2: 
       #interaction_pairs.append((i, j)) 
       pair_count += 1 
      elif q[axis] - p[axis] >= masking_radius: 
       break 

     if i % 100 == 0: 
      print('\r {:3} '.format(i), flush=True, end='') 

    total_pairs = psize * (psize - 1) // 2 
    print('\r {}/{} tests'.format(test_count, total_pairs)) 

else: 
    # Slow 
    for i, p in enumerate(points): 
     for j in range(i+1, psize): 
      q = points[j] 
      if dist2(p, q) < masking_radius2: 
       #interaction_pairs.append((i, j)) 
       pair_count += 1 

     if i % 100 == 0: 
      print('\r {:3} '.format(i), flush=True, end='') 

print('\n', pair_count, 'pairs') 

输出do_fast = 1

181937/499500 tests 

13295 pairs 

输出do_fast = 0

13295 pairs 

当然,如果大部分的点对在彼此的masking_radius,有在使用这种技术方面没有太大的好处神游。排序点增加了一点时间,但是Python的TimSort效率很高,尤其是在数据已经部分排序的情况下,所以如果数据量足够小,您应该看到速度的显着提高。