2013-10-21 127 views
1

我想分析很多让我的程序运行缓慢的数据。 我正在读取从.txt文件到单元格数组的数据集。 我正在使用单元格数组来分类我的数据,这是两个属性的形式,我需要这些类是字符。提高嵌套循环的性能MATLAB

我想使用最近的均值分类器找到重新排列错误。 我有一个主要的外部循环,它遍历我的数据集的每一行(数以万计)。依次移除每一行,每次迭代一行。在删除线条的每次迭代中重新计算两个属性的平均值。主挂点似乎成为下一个部分,在那里我需要计算在我的数据集的每一行:

  • 的数据之间上线(2个属性值)和 的欧几里得距离各自的均值我类。
  • 然后我想记录其属性平均值最接近的类,这将是它的分配类。
  • 最后,我想检查这个分配的类是否是正确的 类。

目前这个循环看起来像这样。

errorCount = 0; 
for l = 1:20000 
    closest = 100; 
    class = 0; 
    attribute1 = d{2}(l); 
    attribute2 = d{3}(l); 
    for m = 1:numel(classes) 
     dist = sqrt((attribute1-meansattr1(m))*(attribute1-meansattr1(m)) + (attribute2-meansattr2(m))*(attribute2-meansattr2(m))); 
     if dist < closest 
      closest = dist; 
      class = m; 
     end 
    end 

    if strcmp(d{1}(l),classes(class)) 
     %correct 
    else 
     errorCount = errorCount + 1; 
    end 
end 

d是我的细胞阵列,其中d{2}是保持我的属性1值的列。我通过d{1}(1)获取了该列中第一行的这些值。

classes是我的数据集中的独特类,所以对于我的每个类,我计算它的欧几里得距离。

meansattr1meansattr2是包含我的每个属性的平均值的数组。当线被移除时,这些更新在外部循环的每次迭代中。

希望能帮助您理解我拥有的代码。非常感谢在优化和加速这些计算方面的任何帮助。

+2

最简单的速度改进是删除'sqrt'调用。查找最近距离的平方与最近距离完全相同。 – paddy

回答

1

您基本上正在优化k-means算法的迭代部分,因此您可以参考my previous solution来获得向量化的方法。但是,这里是你如何处理你的问题和数据格式。

采取随机数据集,像下面,

numClasses = 5; 
numPoints = 20e3; 
numDims = 2; 

classes = strsplit(num2str(1:numClasses)); 

% generate random data (expected error rate of (numClasses-1)/numClasses) 
d{1} = classes(randi(numClasses,numPoints,1)); 
d{2} = rand(numPoints,1); 
d{3} = rand(numPoints,1); 

% random initial class centers 
meansattr1 = rand(5,1); 
meansattr2 = rand(5,1); 

您的代码,压缩和存储每个点的最接近的类ID及该类的距离:

closestDistance = zeros(numPoints,1); nearestCluster = zeros(numPoints,1); 
errorCount = 0; 
for l = 1:numPoints 
    closest = 100; iclass = 0; 
    attribute1 = d{2}(l); attribute2 = d{3}(l); 

    for m = 1:numel(classes) 
     dist = sqrt((attribute1-meansattr1(m))*(attribute1-meansattr1(m)) + ... 
      (attribute2-meansattr2(m))*(attribute2-meansattr2(m))); 
     if dist < closest 
      closest = dist; closestDistance(l) = closest; 
      iclass = m; nearestCluster(l) = iclass; 
     end 
    end 

    if ~strcmp(d{1}(l),classes(iclass)) 
     errorCount = errorCount + 1; 
    end 
end 

的矢量版本那么上面是:

data = [d{2}(:) d{3}(:)]; 
meansattr = [meansattr1(:) meansattr2(:)]; 

kdiffs = bsxfun(@minus,data,permute(meansattr,[3 2 1])); 

allDistances = sqrt(sum(kdiffs.^2,2)); % no need to do sqrt 
allDistances = squeeze(allDistances); % Nx1xK => NxK 

[closestDistance,nearestCluster] = min(allDistances,[],2); % Nx1 

correctClassIds = str2num(char(d{1}(:))); 
errorCount = nnz(nearestCluster ~= correctClassIds); 

结果在errorCountclosestDistancenearestCluster等同于以前的解决方案。如代码注释所示,您可以删除sqrt并在errorCountnearestCluster中获得相同的结果。

说你想要做的更新meansattr1meansattr2下一步:

% Calculate the NEW cluster centers (mean the data) 
meansattr_new = zeros(numClasses,numDims); 
clustersizes = zeros(numClasses,1); 
for ii=1:numClasses, 
    indk = nearestCluster==ii; 
    clustersizes(ii) = nnz(indk); 
    meansattr_new(ii,:) = mean(data(indk,:))'; 
end 

meansattr1_next = meansattr_new(:,1); 
meansattr2_next = meansattr_new(:,2); 

把这一切都在while errorCount>THRESHfor jj = 1:MAXITER,你应该有你所追求的。

+1

谢谢,这样做会有明显的性能提升。我将不得不看看我的程序的其他方面,看看我可以做些类似的改进。 –

2

最简单的速度改进是删除sqrt呼叫。查找最近距离的平方与最近距离完全相同。

接下来,您可以矢量化内部循环。自从我做了任何MatLab以来,这已经很长时间了,所以我可能会弄错下面的代码,但是想法是将这两个属性变成长度为numel(classes)的向量。然后,您可以直接计算差异并对其进行平方。

事情是这样的:

d1 = attribute1 - meansattr1; 
d2 = attribute2 - meansattr2; 
[closest, class] = min(d1 .* d1 + d2 .* d2); 

顺便说一句,这不是用class作为变量一个伟大的想法(如果你甚至可以)。这是一个保留字。

+0

'距离最近的= strt(最近的)'在使用距离时丢失,变量包含平方距离。 – Daniel

+0

当然,但原始代码并未显示使用的地方。在所有循环之后很容易取得'sqrt' *。 – paddy

0

我开始用稻谷的解决方案,简单的更换变量名称:

[closest, cl] = min((d{2}(m) - meansattr1).^2 +(d{3}(m) - meansattr2).^ 
2); 

因此,我们有一个线环,共同的策略:做它的一个功能,并把它变成arrayfun:

[email protected](x)min((d{2}(x) - meansattr1).^2 +(d{3}(x) - meansattr2).^2) 
[sqclosest,cl]=arrayfun(f,1:numel(d{2})); 

%If necessary real distances could be calculated: 
%closest=sqrt(sqclosest) 

errorCount=sum(arrayfun(@(x,c)(1-strcmp(x,classes(c))),d{1},cl)) 

注意:请勿将“class”或任何其他保留字用于其他目的。