2013-07-12

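I'm trying to modify the code below so that it takes a third Point object as a parameter. Should I be using the reduceLeft method in this case? This line: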

val cumulative = points.reduceLeft((a: Point, b: Point, c: Point) => 

causes this compile-time error:

Multiple markers at this line 
    - type mismatch; found : (scala.algorithms.Point, scala.algorithms.Point, scala.algorithms.Point) => 
    scala.algorithms.Point required: (?, scala.algorithms.Point) => ? 
    - type mismatch; found : (scala.algorithms.Point, scala.algorithms.Point, scala.algorithms.Point) => 
    scala.algorithms.Point required: (?, scala.algorithms.Point) => ? 
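The expected type in the error, `(?, scala.algorithms.Point) => ?`, is a two-argument function. reduceLeft is declared as `def reduceLeft[B >: A](op: (B, A) => B): B`, so its operator receives the value accumulated so far and the next element, while the lambda on the failing line takes three arguments. Summing three points does not need a three-argument function. The sketch below uses a stand-in Point class inside a hypothetical object BinaryOpSketch and shows the binary operator applied twice, which is what reduceLeft does over a three-element list.

```scala
object BinaryOpSketch {
  // Stand-in for the question's Point class (same three fields)
  class Point(val x: Double, val y: Double, val z: Double)

  // The shape reduceLeft expects: (accumulator, next element) => new accumulator
  val add: (Point, Point) => Point =
    (a, b) => new Point(a.x + b.x, a.y + b.y, a.z + b.z)

  val p1 = new Point(0, 0, 1)
  val p2 = new Point(1, 0, 1)
  val p3 = new Point(0, 1, 0)

  // Both compute add(add(p1, p2), p3)
  val byHand: Point = add(add(p1, p2), p3)
  val viaReduce: Point = List(p1, p2, p3).reduceLeft(add)
}
```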

Entire code:

package scala.algorithms

/**
 * Modified from http://garysieling.com/blog/implementing-k-means-in-scala
 *
 */

class Point(val x: Double, val y: Double, val z : Double) {

    override def toString(): String = {
        "(" + x + ", " + y + ")"
    }

    def dist(p: Point): Double = {
        x * x + y * y + z * z
    }
}

object kmeans extends App {

    val NUMBER_OF_CLUSTERS = 5;

    val k: Int = 2

    val points: List[Point] = List(
        new Point(0, 0, 1),
        new Point(1, 0, 1),
        new Point(0, 1, 0)).sortBy(
            p => (p.x + " " + p.y).hashCode())

    def clusterMean(points: List[Point]): Point = {
        val cumulative = points.reduceLeft((a: Point, b: Point, c: Point) =>
            new Point(a.x + b.x + c.x, a.y + b.y + c.y , a.z + b.z + c.z))

        new Point(cumulative.x/points.length, cumulative.y/points.length
            , cumulative.z/points.length)
    }

    def render(points: Map[Int, List[Point]]) {
        for (clusterNumber <- points.keys.toSeq.sorted) {
            println(" Cluster " + clusterNumber)

            val meanPoint = clusterMean(points(clusterNumber))
            println(" Mean: " + meanPoint)

            for (j <- 0 to points(clusterNumber).length - 1) {
                System.out.println(" " + points(clusterNumber)(j) + ")")
            }
        }
    }

    val clusters =
        points.zipWithIndex.groupBy(
            x => x._2 % k) transform (
            (i: Int, p: List[(Point, Int)]) => for (x <- p) yield x._1)

    println("Initial State: ")
    render(clusters)

    def iterate(clusters: Map[Int, List[Point]]): Map[Int, List[Point]] = {
        val unzippedClusters =
            (clusters: Iterator[(Point, Int)]) => clusters.map(cluster => cluster._1)

        // find cluster means
        val means =
            (clusters: Map[Int, List[Point]]) =>
                for (clusterIndex <- clusters.keys)
                    yield clusterMean(clusters(clusterIndex))

        // find the closest index
        def closest(p: Point, means: Iterable[Point]): Int = {
            val distances = for (center <- means) yield p.dist(center)
            distances.zipWithIndex.min._2
        }

        // assignment step
        val newClusters =
            points.groupBy(
                (p: Point) => closest(p, means(clusters)))

        render(newClusters)

        newClusters
    }

    var clusterToTest = clusters
    for (i <- 0 to NUMBER_OF_CLUSTERS) {
        System.out.println("Iteration: " + i)
        clusterToTest = iterate(clusterToTest)
    }
}

Reading the documentation for the reduceLeft method at http://www.scala-lang.org/api/current/index.html#index.index-r:

Applies a binary operator to all elements of this sequence, going left to right. 
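"Left to right" here means reduceLeft always hands the operator exactly two values: the result accumulated so far and the next element. A small sketch with strings instead of Points (ReduceLeftOrder is just an illustrative name) makes the grouping visible:

```scala
object ReduceLeftOrder {
  // reduceLeft starts with the first element and calls op(acc, next) for each remaining one
  val letters = List("a", "b", "c", "d")
  val grouped: String = letters.reduceLeft((acc, next) => "(" + acc + "+" + next + ")")

  def main(args: Array[String]): Unit =
    println(grouped) // prints (((a+b)+c)+d)
}
```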

Do I need to use a different method here?

Also, reduceLeft is defined in several traits:

IndexedSeqOptimized LinearSeqOptimized TraversableOnce TraversableProxyLike TraversableForwarder Stream ParIterableLike 

How can I tell which trait's reduceLeft implementation is actually being used?


Isn't sorting by hash code pointless? – senia

Answer


reduceLeft takes a function of two arguments as its parameter, so you should use it like this:

points.reduce((a, b) => new Point(a.x + b.x, a.y + b.y, a.z + b.z)) 
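Applied to the question's clusterMean, the two-argument operator just keeps a running sum; each further point is added on the next step, so no three-argument function is needed. A minimal sketch, using a stand-in Point class with the same fields as the question's, inside a hypothetical object ClusterMeanSketch:

```scala
object ClusterMeanSketch {
  // Stand-in for the question's Point class (same three fields)
  class Point(val x: Double, val y: Double, val z: Double) {
    override def toString: String = "(" + x + ", " + y + ", " + z + ")"
  }

  // Sum all points pairwise with a two-argument function, then divide by the count.
  // Throws UnsupportedOperationException if `points` is empty.
  def clusterMean(points: List[Point]): Point = {
    val cumulative = points.reduceLeft((a, b) =>
      new Point(a.x + b.x, a.y + b.y, a.z + b.z))
    val n = points.length
    new Point(cumulative.x / n, cumulative.y / n, cumulative.z / n)
  }

  def main(args: Array[String]): Unit = {
    val pts = List(new Point(0, 0, 1), new Point(1, 0, 1), new Point(0, 1, 0))
    println(clusterMean(pts))
  }
}
```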

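Note that you will get an exception (an UnsupportedOperationException) if points is empty. You can use reduceOption or a fold to avoid the exception: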

points.fold(new Point(0, 0, 0))((a, b) => new Point(a.x + b.x, a.y + b.y, a.z + b.z)) 
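A sketch of the reduceOption route, assuming the caller would rather get None for an empty cluster than an exception (ClusterMeanOption and its stand-in Point are hypothetical names). The fold version also avoids the exception, but if its result is then divided by points.length as in clusterMean, an empty list gives 0.0 / 0, i.e. NaN in every coordinate.

```scala
object ClusterMeanOption {
  // Stand-in for the question's Point class (same three fields)
  class Point(val x: Double, val y: Double, val z: Double)

  // None for an empty cluster instead of throwing; otherwise the component-wise mean
  def clusterMean(points: List[Point]): Option[Point] =
    points
      .reduceOption((a, b) => new Point(a.x + b.x, a.y + b.y, a.z + b.z))
      .map(total => new Point(
        total.x / points.length,
        total.y / points.length,
        total.z / points.length))
}
```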

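You can use the documentation to find out where a method is implemented: in the description of reduceLeft, look at the "Definition Classes" entry, which shows

Definition Classes: TraversableOnce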
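One practical complement (a sketch, not something the answer above states): the implementation that runs is chosen by the runtime class of the collection, so you can print that class and then look up its Scaladoc page, where each member lists its "Definition Classes". For a non-empty List the runtime class is `::`, encoded on the JVM as `$colon$colon`. RuntimeClassSketch is an illustrative name.

```scala
object RuntimeClassSketch {
  // The concrete class of a non-empty List; its linearization decides which reduceLeft runs
  val className: String = List(1, 2, 3).getClass.getName

  def main(args: Array[String]): Unit =
    println(className) // prints scala.collection.immutable.$colon$colon
}
```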
