2016-03-27 77 views
5

我正在重构我的代码以充分利用DataFrames, Estimators, and Pipelines。我最初在RDD[LabeledPoint]上使用MLlib Multiclass LogisticRegressionWithLBFGS。我很享受学习和使用新的API,但我不确定如何保存我的新模型并将其应用于新数据。Spark ML - 保存OneVsRestModel

目前,LogisticRegression的ML实现仅支持二进制分类。我,而不是使用OneVsRest像这样:

val lr = new LogisticRegression().setFitIntercept(true) 
val ovr = new OneVsRest() 
ovr.setClassifier(lr) 
val ovrModel = ovr.fit(training) 

现在,我想救我OneVsRestModel,但这似乎并没有受到API的支持。我曾尝试过:

ovrModel.save("my-ovr") // Cannot resolve symbol save 
ovrModel.models.foreach(_.save("model-" + _.uid)) // Cannot resolve symbol save 

有没有办法保存这个,所以我可以加载它在一个新的应用程序进行新的预测?

回答

5

星火2.0.0

OneVsRestModel实现MLWritable所以应该是可以直接保存。下面显示的方法对单独保存各个模型仍然有用。

火花< 2.0.0

这里的问题是,models返回ClassificationModel[_, _]]不是LogisticRegressionModel(或MLWritable)的Array一个Array。为了使它工作,你就必须要具体说明的类型:

import org.apache.spark.ml.classification.LogisticRegressionModel 

ovrModel.models.zipWithIndex.foreach { 
    case (model: LogisticRegressionModel, i: Int) => 
    model.save(s"model-${model.uid}-$i") 
} 

或更通用的:

import org.apache.spark.ml.util.MLWritable 

ovrModel.models.zipWithIndex.foreach { 
    case (model: MLWritable, i: Int) => 
    model.save(s"model-${model.uid}-$i") 
} 

不幸的是,作为现在(星火1.6)OneVsRestModel没有实现MLWritable所以它不能单独保存。

注意

所有型号的int OneVsRest似乎使用相同的uid因此,我们需要一个明确的指标。稍后确定模型也很有用。

+1

我希望我能+2这个。这不仅仅是我所需要的,它使计算原始概率的工作变得更容易。我以为我将不得不自定义src。谢谢! –

+0

@ zero323是否有您的答案的pyspark版本?试图找到一种方法来保存pyspark.ml模型 – ajkl

+0

@AjinkyaKale在1.6? – zero323