2016-07-16

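Spark SQL UDF with complex input parameter

I am trying to use a UDF whose input type is an array of structs. I have the following data structure; this is only the relevant part of a larger structure: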

    |-- investments: array (nullable = true) 
    | |-- element: struct (containsNull = true) 
    | | |-- funding_round: struct (nullable = true) 
    | | | |-- company: struct (nullable = true) 
    | | | | |-- name: string (nullable = true) 
    | | | | |-- permalink: string (nullable = true) 
    | | | |-- funded_day: long (nullable = true) 
    | | | |-- funded_month: long (nullable = true) 
    | | | |-- funded_year: long (nullable = true) 
    | | | |-- raised_amount: long (nullable = true) 
    | | | |-- raised_currency_code: string (nullable = true) 
    | | | |-- round_code: string (nullable = true) 
    | | | |-- source_description: string (nullable = true) 
    | | | |-- source_url: string (nullable = true) 

I declared the case classes:

case class Company(name: String, permalink: String) 
case class FundingRound(company: Company, funded_day: Long, funded_month: Long, funded_year: Long, raised_amount: Long, raised_currency_code: String, round_code: String, source_description: String, source_url: String) 
case class Investments(funding_round: FundingRound) 

The UDF declaration:

sqlContext.udf.register("total_funding", (investments:Seq[Investments]) => { 
    val totals = investments.map(r => r.funding_round.raised_amount) 
    totals.sum 
}) 

When I execute the following transformation, the result is as expected:

scala> sqlContext.sql("""select total_funding(investments) from companies""") 
res11: org.apache.spark.sql.DataFrame = [_c0: bigint] 

But when an action such as collect is executed, I get an error:

Executor: Exception in task 0.0 in stage 4.0 (TID 10) 
java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to $line33.$read$$iwC$$iwC$Investments 

Thanks for any help.

Answer

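The error you see should be fairly self-explanatory. There is a strict mapping between Catalyst / SQL types and Scala types, which can be found in the relevant section of the Spark SQL, DataFrames and Datasets Guide.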

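In particular, struct types are converted to org.apache.spark.sql.Row (in your particular case the data will be exposed as Seq[Row]), which is why the cast to the Investments case class fails at runtime.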
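To make the mapping concrete, here is a minimal sketch of the question's total_funding UDF rewritten to accept Seq[Row]. It assumes Spark 2.x or later (SparkSession instead of the sqlContext used in the question). The sample data, the case classes SampleRound, SampleInvestment and SampleCompany, and the object name are hypothetical and exist only to build a small companies table with the same field names as the schema above.

```scala
import org.apache.spark.sql.{Row, SparkSession}

// Hypothetical case classes, used only to create sample data.
// Field names mirror the relevant part of the schema in the question.
case class SampleRound(raised_amount: Long)
case class SampleInvestment(funding_round: SampleRound)
case class SampleCompany(name: String, investments: Seq[SampleInvestment])

object TotalFundingSketch {
  // Each array element is a struct, so it arrives as a Row, not as a case class.
  def totalFunding(investments: Seq[Row]): Long =
    investments.map(_.getAs[Row]("funding_round").getAs[Long]("raised_amount")).sum

  def run(): Map[String, Long] = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("total-funding-sketch")
      .getOrCreate()
    import spark.implicits._

    Seq(
      SampleCompany("a", Seq(SampleInvestment(SampleRound(100L)), SampleInvestment(SampleRound(250L)))),
      SampleCompany("b", Seq(SampleInvestment(SampleRound(40L))))
    ).toDF().createOrReplaceTempView("companies")

    // Register the UDF with a Seq[Row] parameter instead of Seq[Investments].
    spark.udf.register("total_funding", (investments: Seq[Row]) => totalFunding(investments))

    // collect() is an action, so the UDF really runs here.
    val result = spark.sql("select name, total_funding(investments) as total from companies")
      .collect()
      .map(r => r.getString(0) -> r.getLong(1))
      .toMap

    spark.stop()
    result
  }

  def main(args: Array[String]): Unit = println(run())
}
```

The sketch does not handle a null investments array or null nested fields; real data with nullable columns would need explicit checks.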

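There are different methods that can be used to expose data as specific types:

Only the former approach is applicable in this particular case.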

If you want to use a UDF, you need to access investments.funding_round.raised_amount like this:

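import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf
// Each struct arrives as a Row; Try(...).toOption returns None (SQL NULL) if extraction fails.
val getRaisedAmount = udf((investments: Seq[Row]) => scala.util.Try(
    investments.map(_.getAs[Row]("funding_round").getAs[Long]("raised_amount"))
).toOption)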

A simple select, however, should be safer and cleaner:

df.select($"investments.funding_round.raised_amount")
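The select above returns one array of amounts per row. If the goal is the per-company total that the question's UDF was computing, one possible approach without a UDF is to explode that array and aggregate it. The following is a sketch under assumptions: Spark 2.x or later, a hypothetical name column used as the grouping key, and hypothetical sample data and class names. Note that explode drops rows whose array is null or empty.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{explode, sum}

// Hypothetical case classes, used only to create sample data.
case class DemoRound(raised_amount: Long)
case class DemoInvestment(funding_round: DemoRound)
case class DemoCompany(name: String, investments: Seq[DemoInvestment])

object SelectTotalSketch {
  def run(): Map[String, Long] = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("select-total-sketch")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(
      DemoCompany("a", Seq(DemoInvestment(DemoRound(100L)), DemoInvestment(DemoRound(250L)))),
      DemoCompany("b", Seq(DemoInvestment(DemoRound(40L))))
    ).toDF()

    // Nested field access through the array yields an array<bigint> column;
    // explode turns it into one row per amount.
    val amounts = df.select(
      $"name",
      explode($"investments.funding_round.raised_amount").as("raised_amount")
    )

    // Sum the amounts per company using built-in aggregation only.
    val result = amounts
      .groupBy($"name")
      .agg(sum($"raised_amount").as("total"))
      .collect()
      .map(r => r.getString(0) -> r.getLong(1))
      .toMap

    spark.stop()
    result
  }

  def main(args: Array[String]): Unit = println(run())
}
```

Staying with built-in column expressions keeps the work inside Catalyst, which is one reason the select route tends to be less error-prone than a UDF over Row objects.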