1

我为Strings创建了自定义Aggregator[]在多个列上应用自定义Spark聚合器(Spark 2.0)

我想将它应用于DataFrame的所有列,其中所有列都是字符串,但列号是任意的。

我被困在写正确的表达。我想写这样的:

df.agg(df.columns.map(c => myagg(df(c))) : _*) 

这显然是错误的给予各种接口。

我看了一下RelationalGroupedDataset.agg(expr: Column, exprs: Column*)的代码,但是我不熟悉表达式的操作。

有什么想法?

+3

请显示您的聚合器代码。并解释你正在尝试做什么。 –

+0

@AssafMendelson,实际上我们计划为各种数据类型提供各种统计数据的自定义聚合器。我从一个聚合器开始取得最短和最长的字符串:class ShortestLongestAggregator()扩展了Aggregator [String,(String,String),(String,String)]。现在我想为任意数据框的所有列(因为它只有字符串列)拥有所有(最短,最长)对。 – mathieu

回答

5

与在单个字段(列)上操作的UserDefinedAggregateFunctions相反,Aggregtors需要完整的 /值。

如果你想和Aggregator它可以在你的代码段中使用它必须通过列名称参数化并使用作为值类型。

import org.apache.spark.sql.expressions.Aggregator 
import org.apache.spark.sql.{Encoder, Encoders, Row} 

case class Max(col: String) 
    extends Aggregator[Row, Int, Int] with Serializable { 

    def zero = Int.MinValue 
    def reduce(acc: Int, x: Row) = 
    Math.max(acc, Option(x.getAs[Int](col)).getOrElse(zero)) 

    def merge(acc1: Int, acc2: Int) = Math.max(acc1, acc2) 
    def finish(acc: Int) = acc 

    def bufferEncoder: Encoder[Int] = Encoders.scalaInt 
    def outputEncoder: Encoder[Int] = Encoders.scalaInt 
} 

用法示例:

val df = Seq((1, None, 3), (4, Some(5), -6)).toDF("x", "y", "z") 

@transient val exprs = df.columns.map(c => Max(c).toColumn.alias(s"max($c)")) 

df.agg(exprs.head, exprs.tail: _*) 
+------+------+------+ 
|max(x)|max(y)|max(z)| 
+------+------+------+ 
|  4|  5|  3| 
+------+------+------+ 

当结合静态类型DatasetsDataset<Row>按理说Aggregators使更多的意义。

根据您的要求,您也可以使用Seq[_]累加器在单个传递中汇总多个列,并在单个merge调用中处理整个(记录)。