2017-10-17 49 views
0

我有以下DataFrame例如:星火斯卡拉:计数连续两个月

Provider Patient Date 
Smith  John  2016-01-23 
Smith  John  2016-02-20 
Smith  John  2016-03-21 
Smith  John  2016-06-25 
Smith  Jill  2016-02-01 
Smith  Jill  2016-03-10 
James  Jill  2017-04-10 
James  Jill  2017-05-11 

我希望以编程方式增加一列,表示有多少个连续月,病人看病。新DataFrame是这样的:

Provider Patient Date   consecutive_id 
Smith  John  2016-01-23 3 
Smith  John  2016-02-20 3 
Smith  John  2016-03-21 3 
Smith  John  2016-06-25 1 
Smith  Jill  2016-02-01 2 
Smith  Jill  2016-03-10 2 
James  Jill  2017-04-10 2 
James  Jill  2017-05-11 2 

我假设有一种方法用Window函数来实现这一点,但我一直没能推算出来呢,我很期待到社区可以提供的洞察力。谢谢。

回答

1

至少有3种方法得到的结果

  1. 在SQL
  2. 使用星火API实现逻辑为窗口函数的最大数量 - .over(windowSpec)
  3. 使用直接.rdd.mapPartitions

Introducing Window Functions in Spark SQL

对于所有解决方案,您可以调用.toDebugString来查看引擎盖下的操作。

SQL溶液低于

val my_df = List(
    ("Smith", "John", "2016-01-23"), 
    ("Smith", "John", "2016-02-20"), 
    ("Smith", "John", "2016-03-21"), 
    ("Smith", "John", "2016-06-25"), 
    ("Smith", "Jill", "2016-02-01"), 
    ("Smith", "Jill", "2016-03-10"), 
    ("James", "Jill", "2017-04-10"), 
    ("James", "Jill", "2017-05-11") 
).toDF(Seq("Provider", "Patient", "Date"): _*) 

my_df.createOrReplaceTempView("tbl") 

val q = """ 
select t2.*, count(*) over (partition by provider, patient, grp) consecutive_id 
    from (select t1.*, sum(x) over (partition by provider, patient order by yyyymm) grp 
      from (select t0.*, 
         case 
          when cast(yyyymm as int) - 
           cast(lag(yyyymm) over (partition by provider, patient order by yyyymm) as int) = 1 
          then 0 
          else 1 
         end x 
        from (select tbl.*, substr(translate(date, '-', ''), 1, 6) yyyymm from tbl) t0) t1) t2 
""" 

sql(q).show 
sql(q).rdd.toDebugString 

输出

scala> sql(q).show 
+--------+-------+----------+------+---+---+--------------+ 
|Provider|Patient|  Date|yyyymm| x|grp|consecutive_id| 
+--------+-------+----------+------+---+---+--------------+ 
| Smith| Jill|2016-02-01|201602| 1| 1|    2| 
| Smith| Jill|2016-03-10|201603| 0| 1|    2| 
| James| Jill|2017-04-10|201704| 1| 1|    2| 
| James| Jill|2017-05-11|201705| 0| 1|    2| 
| Smith| John|2016-01-23|201601| 1| 1|    3| 
| Smith| John|2016-02-20|201602| 0| 1|    3| 
| Smith| John|2016-03-21|201603| 0| 1|    3| 
| Smith| John|2016-06-25|201606| 1| 2|    1| 
+--------+-------+----------+------+---+---+--------------+ 

更新

.mapPartitions的混合+ .over(windowSpec)

import org.apache.spark.sql.Row 
import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType} 

val schema = new StructType().add(
      StructField("provider", StringType, true)).add(
      StructField("patient", StringType, true)).add(
      StructField("date", StringType, true)).add(
      StructField("x", IntegerType, true)).add(
      StructField("grp", IntegerType, true)) 

def f(iter: Iterator[Row]) : Iterator[Row] = { 
    iter.scanLeft(Row("_", "_", "000000", 0, 0)) 
    { 
    case (x1, x2) => 

    val x = 
    if (x2.getString(2).replaceAll("-", "").substring(0, 6).toInt == 
     x1.getString(2).replaceAll("-", "").substring(0, 6).toInt + 1) 
    (0) else (1); 

    val grp = x1.getInt(4) + x; 

    Row(x2.getString(0), x2.getString(1), x2.getString(2), x, grp); 
    }.drop(1) 
} 

val df_mod = spark.createDataFrame(my_df.repartition($"provider", $"patient") 
             .sortWithinPartitions($"date") 
             .rdd.mapPartitions(f, true), schema) 

import org.apache.spark.sql.expressions.Window 
val windowSpec = Window.partitionBy($"provider", $"patient", $"grp") 
df_mod.withColumn("consecutive_id", count(lit("1")).over(windowSpec) 
    ).orderBy($"provider", $"patient", $"date").show 

输出

scala> df_mod.withColumn("consecutive_id", count(lit("1")).over(windowSpec) 
    |  ).orderBy($"provider", $"patient", $"date").show 
+--------+-------+----------+---+---+--------------+ 
|provider|patient|  date| x|grp|consecutive_id| 
+--------+-------+----------+---+---+--------------+ 
| James| Jill|2017-04-10| 1| 1|    2| 
| James| Jill|2017-05-11| 0| 1|    2| 
| Smith| Jill|2016-02-01| 1| 1|    2| 
| Smith| Jill|2016-03-10| 0| 1|    2| 
| Smith| John|2016-01-23| 1| 1|    3| 
| Smith| John|2016-02-20| 0| 1|    3| 
| Smith| John|2016-03-21| 0| 1|    3| 
| Smith| John|2016-06-25| 1| 2|    1| 
+--------+-------+----------+---+---+--------------+ 
+0

这适用于我提供的示例数据,这就是为什么我很乐意给出复选标记。我只是试图通过一个'java.lang.ArrayIndexOutOfBoundsException:2'来试图展示最终的'df_mod'转换。 – bshelt141

0

,你可以:

  1. 格式化日期整数(2016-01 = 1, 2016-02 = 2, 2017-01 = 13 ...等)
  2. 把所有的日期到一个数组有一个窗口,collect_list:

    val winSpec = Window.partitionBy("Provider","Patient").orderBy("Date") df.withColumn("Dates", collect_list("Date").over(winSpec))

  3. 将数组传递到@marios的修改版本solutionspark.udf.register一个UDF获得的连续三个月