2016-03-15 150 views
1

我设计了以下功能与任何数值类型的阵列的工作:SparkSQL功能需要类型十进制

def array_sum[T](item:Traversable[T])(implicit n:Numeric[T]) = item.sum 
// Registers a function as a UDF so it can be used in SQL statements. 
sqlContext.udf.register("array_sumD", array_sum(_:Seq[Float])) 

但是想要通过类型的数组浮箱以下错误:

// Now we can use our function directly in SparkSQL. 
sqlContext.sql("SELECT array_sumD(array(5.0,1.0,2.0)) as array_sum").show 

错误:

cannot resolve 'UDF(array(5.0,1.0,2.0))' due to data type mismatch: argument 1 requires array<double> type, however, 'array(5.0,1.0,2.0)' is of array<decimal(2,1)> type; 

回答

1

Spark-SQL中的十进制值的默认数据类型是,十进制。如果您你的文字查询到花车,并使用相同的UDF,它的工作原理:

sqlContext.sql(
    """SELECT array_sumD(array(
    | CAST(5.0 AS FLOAT), 
    | CAST(1.0 AS FLOAT), 
    | CAST(2.0 AS FLOAT) 
    |)) as array_sum""".stripMargin).show 

结果,符合市场预期:

+---------+ 
|array_sum| 
+---------+ 
|  8.0| 
+---------+ 

或者,如果你想要使用小数(避免浮点问题),你会仍然必须使用铸造得到正确的精度,再加上你将不会是abl e使用Scala的不错Numericsum,因为小数被读作java.math.BigDecimal。所以 - 你的代码是:

def array_sum(item:Traversable[java.math.BigDecimal]) = item.reduce((a, b) => a.add(b)) 

// Registers a function as a UDF so it can be used in SQL statements. 
sqlContext.udf.register("array_sumD", array_sum(_:Seq[java.math.BigDecimal])) 

sqlContext.sql(
    """SELECT array_sumD(array(
    | CAST(5.0 AS DECIMAL(38,18)), 
    | CAST(1.0 AS DECIMAL(38,18)), 
    | CAST(2.0 AS DECIMAL(38,18)) 
    |)) as array_sum""".stripMargin).show