2015-10-05 40 views
4

如何将一些自定义字段(即用户标识)添加到预测结果中?将自定义字段添加到Spark ML LabeldPoint

 List<org.apache.spark.mllib.regression.LabeledPoint> localTesting = ... ;// 
     // I want to add some identifier to each LabeledPoint 

     DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), LabeledPoint.class); 
     DataFrame predictions = model.transform(localTestDF); 
     Row[] collect = predictions.select("label", "probability", "prediction").collect(); 
     for (Row r : collect) { 
      // and want to return identifier here. 
      // so do I save I to database. 
      int userNo = Integer.parseInt(r.get(0).toString()); 
      double prob = Double.parseDouble(r.get(1).toString()); 
      int prediction = Integer.parseInt(r.get(2).toString()); 
      log.debug(userNo + "," + prob + ", " + prediction); 
     } 

但是当我使用这个类localTesting代替LabeledPoint,

class NoLabeledPoint extends LabeledPoint implements Serializable { 
    private static final long serialVersionUID = -2488661810406135403L; 
    int userNo; 
    public NoLabeledPoint(double label, Vector features) { 
     super(label, features); 
    } 

    public int getUserNo() { 
     return userNo; 
    } 

    public void setUserNo(int userNo) { 
     this.userNo = userNo; 
    } 
} 

     List<NoLabeledPoint> localTesting = ... ;// set every user'no to the field userNo 
     // I want to add some identifier to each LabeledPoint 

     DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), LabeledPoint.class); 
     DataFrame predictions = model.transform(localTestDF); 
     Row[] collect = predictions.select("userNo", "probability", "prediction").collect(); 
     for (Row r : collect) { 
      // and want to return identifier here. 
      // so do I save I to database. 
      int userNo = Integer.parseInt(r.get(0).toString()); 
      double prob = Double.parseDouble(r.get(1).toString()); 
      int prediction = Integer.parseInt(r.get(2).toString()); 
      log.debug(userNo + "," + prob + ", " + prediction); 
     } 

的异常抛出

org.apache.spark.sql.AnalysisException: cannot resolve 'userNo' given input columns rawPrediction, probability, features, label, prediction; 
     at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) 
     at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:63) 
     at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:52) 
     at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:286) 
     at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:286) 
     at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:51) 

我的意思是我希望得到的不仅是预测数据(功能,标签,概率..),但也是我想要的一些自定义字段。例如userNo,USER_ID等 从结果:predictions.select( “......”)

更新

解决。一行应固定

  DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), LabeledPoint.class); 

  DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), NoLabeledPoint.class); 
+0

我还没有找到一个可靠的方法来做到这一点。到目前为止,我已经将相关元数据存储在我的验证子集的'label'对象中,通过黑客将它看起来像一个float(在我的例子中,它看起来像'datetime.primary_key',例如1月1日的'150101.12345', 2015年,主键12345)。据我所知,没有内置的系统来存储有关'LabeledPoint'对象的元数据。 –

+0

我们可以尝试RDD中的.zip函数https://spark.apache.org/docs/latest/api/java/org/apache/spark/rdd/RDD.html#zip(org.apache.spark.rdd。 RDD,%20scala.reflect.ClassTag),并将其与userId,actualLabel和predictedLabel进行映射。 http://spark.apache.org/docs/latest/mllib-decision-tree.html另外,Java示例使用平面地图与预测结合。 zip函数假设两个RDD的每个分区具有*相同数量的分区*和*相同数量的元素*(例如,一个是通过另一个地图制作的)。 –

+0

@AnchitChoudhry无法与spark.ml? (RDD使用高级功能,因此不直接与RDDS进行交易。) –

回答

1

既然你不使用低级别MLlib API就没有必要使用LabeledPoint可言。创建DataFrame之后,您只能看到具有某些值的,所有重要*都是与管道中的参数匹配的类型和列名称。

在Scala中,你可以使用任何情况下,类

org.apache.spark.mllib.linalg.Vector; case class 

case class LabeledPointWithMeta(userNo: String, label: Double, features: Vector) 

val rdd: RDD[LabeledPointWithMeta] = ??? 
val df = rdd.toDF 

为了能够从你使用它也许应该补充@BeanInfo注释:

import scala.beans.BeanInfo 

@BeanInfo 
case class LabeledPointWithMeta(...) 

基于一个Spark SQL and DataFrame Guide它看起来像普通的Java你可以这样做**:

​​

,之后:

JavaRDD<LabeledPointWithMeta> myPoints = ...; 

DataFrame df = sqlContext.createDataFrame(myPoints LabeledPointWithMeta.class); 

我想在你的代码的简单变化应该工作以及:

DataFrame localTestDF = jsql.createDataFrame(
    jsc.parallelize(studyData.localTesting), 
    NoLabeledPoint.class 
); 

,如果你想使用MLlib它不会帮你,但是这部分可以易于使用简单的RDD转换,如zip


*一些元数据,但你不会得到来自一个LabeledPoint

**我没有上面的代码测试,因此它可以包含一些错误。

+0

太棒了!真的有帮助。我会根据这些代码尝试一下! :) –

相关问题