2015-10-06 41 views
2

我有一个单一的3D数值数据文件,我从块中读取(因为在块中读取比单个索引快)。例如说有一个MxNx30阵列中“文件”,我会创建一个这样的RDD:在Pyspark的RDD分区中分割数组

def read(ind): 
    f = customFileOpener(file) 
    return f['data'][:,:,ind[0]:ind[-1]+1] 

indices = [[0,9],[10,19],[20,29]] 
rdd = sc.parallelize(indices,3).map(lambda v:read(v)) 
rdd.count() 

所以各3个分区的大小为MxNx10的numpy.ndarray元件。

现在,我想分割每个分区中的每个元素,我有10个元素,每个元素是一个MxN数组。我试着用flatMap()用于此目的,但得到的错误“NoneType对象不是可迭代”:

def splitArr(arr): 
    Nmid = arr.shape[-1] 
    out = [] 
    for i in range(0,Nmid): 
     out.append(arr[...,i]) 
    return out 

rdd2 = rdd.flatMap(lambda v: splitArr(v)) 
rdd2.count() 

什么是做这种正确的方法是什么?关键点是(a)我需要从文件中以块读取数据和(b)拆分数据,因此元素的大小为MxN(最好保留分区结构)。

回答

2

据我了解你的描述是这样的应该做的伎俩:

rdd.flatMap(lambda arr: (x for x in np.rollaxis(arr, 2))) 

或者如果你喜欢一个单独的函数:

def splitArr(arr): 
    for x in np.rollaxis(arr, 2): 
     yield x 

rdd.flatMap(splitArr) 
+0

我明白了,我应该从可迭代猜测错误。将尺寸移动到数组的前面并用rollaxis分割,然后迭代这些元素。正是我想要的,非常感谢。 – Michael