2017-06-23 64 views
0

描述问题的最佳方法是给出一个输入示例和我想要输出的内容。对Spark中出现的次数进行分类后的数据进行分类

输入

-------------------- 
|id|timestamp |count| 
| 1|2017-06-22| 1 | 
| 1|2017-06-23| 0 | 
| 1|2017-06-24| 1 | 
| 2|2017-06-22| 0 | 
| 2|2017-06-23| 1 | 

逻辑会是这样的,如果(的1 S IN计数总数等于或高于Y最后X天)

code = True 

其他

code = False 

比方说X = 5Y = 2然后输出应该看起来像

输出

--------------------- 
id | code | 
1 | True | 
2 | False | 

输入是一个SparkSQLdataframeorg.apache.spark.sql.DataFrame

不听起来像是一个非常复杂的问题,但我被困在第一步。我只能设法加载数据在dataframe

任何想法?

回答

1

看着你的要求,UDAFaggregation适合最好的。您可以结帐databricksragrawal以获得更好的理解。

我根据你提供指导,我的理解,我希望这是有帮助的

所有你需要定义UDAF第一。在您成功阅读上述链接后,您就可以做到这一点。

private class ManosAggregateFunction(daysToCheck: Int, countsToCheck: Int) extends UserDefinedAggregateFunction { 

    var referenceDate: String = _ 
    def inputSchema: StructType = new StructType().add("timestamp", StringType).add("count", IntegerType) 
    // the aggregation buffer can also have multiple values in general but 
    // this one just has one: the partial sum 
    def bufferSchema: StructType = new StructType().add("timestamp", StringType).add("count", IntegerType).add("days", IntegerType) 
    // returns just a double: the sum 
    def dataType: DataType = BooleanType 
    // always gets the same result 
    def deterministic: Boolean = true 

    def initialize(buffer: MutableAggregationBuffer): Unit = { 
    buffer.update(0, "") 
    buffer.update(1, 0) 
    buffer.update(2, 0) 
    referenceDate = "" 
    } 

    def update(buffer: MutableAggregationBuffer, input: Row): Unit = { 
    val nowDate = input.getString(0) 
    val count = input.getInt(1) 

    buffer.update(0, nowDate) 
    buffer.update(1, count) 
    } 

    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { 
    val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd") 
    val previousDate = buffer1.getString(0) 
    val nowDate = buffer2.getString(0) 
    if(previousDate != "") { 
     val oldDate = LocalDate.parse(previousDate, formatter) 
     val newDate = LocalDate.parse(nowDate, formatter) 
     buffer1.update(2, buffer1.getInt(2)+(oldDate.toEpochDay() - newDate.toEpochDay()).toInt) 
    } 
    buffer1.update(0, buffer2.getString(0)) 
    if(buffer1.getInt(2) < daysToCheck) { 
     buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1)) 
    } 
    } 

    def evaluate(buffer: Row): Any = { 
    countsToCheck <= buffer.getInt(1) 
    } 
} 

在上面UDAFdaysToCheckcountsToCheck是你的问题XY

您可以拨打定义UDAF如下

val manosAgg = new ManosAggregateFunction(5,2) 
    df.orderBy($"timestamp".desc).groupBy("id").agg(manosAgg(col("timestamp"), col("count")).as("code")).show 

最终输出

+---+-----+ 
| id| code| 
+---+-----+ 
| 1| true| 
| 2|false| 
+---+-----+ 

给定的输入

val df = Seq(
    (1, "2017-06-22", 1), 
    (1, "2017-06-23", 0), 
    (1, "2017-06-24", 1), 
    (2, "2017-06-28", 0), 
    (2, "2017-06-29", 1) 
).toDF("id","timestamp","count") 
+---+----------+-----+ 
|id |timestamp |count| 
+---+----------+-----+ 
|1 |2017-06-22|1 | 
|1 |2017-06-23|0 | 
|1 |2017-06-24|1 | 
|2 |2017-06-28|0 | 
|2 |2017-06-29|1 | 
+---+----------+-----+ 

我希望你已经得到了你的问题的想法。 :)

相关问题