这是一个使用普通Python的相当简单的技术,可以减少所需的比较次数。
我们首先按X,Y或Z轴(在下面的代码中由axis
选择)对点进行排序。假设我们选择X轴。然后我们循环点对,就像你的代码一样,但是当我们发现一对距离大于masking_radius
的对时,我们测试它们的X坐标差是否也大于masking_radius
。如果是,那么我们可以从内部j
循环中退出,因为具有更大的j
的所有点都具有更大的X坐标。
我的dist2
函数计算平方距离。这比计算实际距离要快,因为计算平方根相对较慢。
我还包括了代码,它们的行为与您的代码相似,即测试每对点,以进行速度比较;它也用于检查快速代码是否正确。)
from random import seed, uniform
from operator import itemgetter
seed(42)
# Make some fake data
def make_point(hi=10.0):
return [uniform(-hi, hi) for _ in range(3)]
psize = 1000
points = [make_point() for _ in range(psize)]
masking_radius = 4.0
masking_radius2 = masking_radius ** 2
def dist2(p, q):
return (p[0] - q[0])**2 + (p[1] - q[1])**2 + (p[2] - q[2])**2
pair_count = 0
test_count = 0
do_fast = 1
if do_fast:
# Sort the points on one axis
axis = 0
points.sort(key=itemgetter(axis))
# Fast
for i, p in enumerate(points):
left, right = i - 1, i + 1
for j in range(i + 1, psize):
test_count += 1
q = points[j]
if dist2(p, q) < masking_radius2:
#interaction_pairs.append((i, j))
pair_count += 1
elif q[axis] - p[axis] >= masking_radius:
break
if i % 100 == 0:
print('\r {:3} '.format(i), flush=True, end='')
total_pairs = psize * (psize - 1) // 2
print('\r {}/{} tests'.format(test_count, total_pairs))
else:
# Slow
for i, p in enumerate(points):
for j in range(i+1, psize):
q = points[j]
if dist2(p, q) < masking_radius2:
#interaction_pairs.append((i, j))
pair_count += 1
if i % 100 == 0:
print('\r {:3} '.format(i), flush=True, end='')
print('\n', pair_count, 'pairs')
输出与do_fast = 1
181937/499500 tests
13295 pairs
输出与do_fast = 0
13295 pairs
当然,如果大部分的点对在彼此的masking_radius
,有在使用这种技术方面没有太大的好处神游。排序点增加了一点时间,但是Python的TimSort效率很高,尤其是在数据已经部分排序的情况下,所以如果数据量足够小,您应该看到速度的显着提高。
2维或3维的点?如果2你可以尝试一下Cormen等人的*算法简介*中给出的“寻找最接近的一对点”算法的变体。 –
你的算法的时间复杂度为'O(n ** 2)',而Cormen的'O(n * log(n))'有。你也可以尝试加速Python代码的[pypy](http://pypy.org/)。 –