我们可以在索引方面稍微聪明些,节省大约4倍的成本。
首先让建立正确的形状的一些数据:
seed = np.random.randint(0, 100, (200,206))
data = np.random.randint(0, 100, (4e5,206))
seed[:, 0] = np.arange(200)
data[:, 0] = np.random.randint(0, 200, 4e5)
diam = np.empty(200)
原答复的时间:
%%timeit
for i in range(200):
diam[i] = spd.cdist(seed[np.newaxis, i, 1:], data[data[:, 0]==i][:,1:]).max()
1 loops, best of 3: 1.35 s per loop
moarningsun的回答是:
%%timeit
seed_repeated = seed[data[:,0]]
dist_to_center = np.sqrt(np.sum((data[:,1:]-seed_repeated[:,1:])**2, axis=1))
diam = np.zeros(len(seed))
np.maximum.at(diam, data[:,0], dist_to_center)
1 loops, best of 3: 1.33 s per loop
Divakar的回答是:
%%timeit
data_sorted = data[data[:, 0].argsort()]
seed_ext = np.repeat(seed,np.bincount(data_sorted[:,0]),axis=0)
dists = np.sqrt(((data_sorted[:,1:] - seed_ext[:,1:])**2).sum(1))
shift_idx = np.append(0,np.nonzero(np.diff(data_sorted[:,0]))[0]+1)
diam_out = np.maximum.reduceat(dists,shift_idx)
1 loops, best of 3: 1.65 s per loop
正如我们所看到的,除了更大的内存占用之外,还没有真正获得任何矢量化解决方案。为了避免这种情况,我们需要返回到原来的答案,这是真的做这些事情的正确方法,而是试图减少索引量:
%%timeit
idx = data[:,0].argsort()
bins = np.bincount(data[:,0])
counter = 0
for i in range(200):
data_slice = idx[counter: counter+bins[i]]
diam[i] = spd.cdist(seed[None, i, 1:], data[data_slice, 1:]).max()
counter += bins[i]
1 loops, best of 3: 281 ms per loop
仔细检查答案:
np.allclose(diam, dam_out)
True
这是假设python循环不好的问题。他们往往是,但不是在所有情况下。
这实际上是相当合理的代码。你的for循环相对于'cdist'内完成的计算量相对较小。由于'cdist'是一个相当优化的速度,收益不会很大。 – Daniel
@Ophion - 虽然可以避免重复的线性搜索data [:, 0] == i,从O(n ** 2)到O(n log(n))甚至O(n )。 – 2015-11-06 20:19:23
@moarningsun是的,但是可能的和可用的是两个不同的东西,特别是考虑到O(n * m)而不是O(n^2)和n << m。到目前为止,没有任何解决方案比OP更快,并且所有解决方案都有更多的内存开销。 – Daniel