2017-07-31 14 views
0

我使用 Spark MLlib 训练朴素贝叶斯(Naive Bayes)分类器模型:我创建了一个管道,先对字符串特征做索引,然后进行归一化并应用 PCA 降维,最后训练朴素贝叶斯模型。运行管道时,PCA 输出的主成分向量中出现了负值。在 Google 上搜索后发现,我需要使用 NMF(非负矩阵分解)来得到非负向量,并且 ALS 可以通过 .setNonnegative(true) 方法实现 NMF,但我不知道如何在 PCA 之后把 ALS 集成到我的管道中。非常感谢任何帮助,谢谢。如何在我的 Spark 管道中集成 ALS 来实现非负矩阵分解?

这里是代码

import org.apache.spark.SparkConf; 
import org.apache.spark.SparkContext; 
import org.apache.spark.api.java.JavaSparkContext; 
import org.apache.spark.ml.Pipeline; 
import org.apache.spark.ml.PipelineModel; 
import org.apache.spark.ml.PipelineStage; 
import org.apache.spark.ml.classification.NaiveBayes; 
import org.apache.spark.ml.feature.IndexToString; 
import org.apache.spark.ml.feature.Normalizer; 
import org.apache.spark.ml.feature.PCA; 
import org.apache.spark.ml.feature.StringIndexer; 
import org.apache.spark.ml.feature.StringIndexerModel; 
import org.apache.spark.ml.feature.VectorAssembler; 
import org.apache.spark.ml.recommendation.ALS; 
import org.apache.spark.sql.DataFrame; 
import org.apache.spark.sql.SQLContext; 

/**
 * Spark driver that trains a Naive Bayes classifier predicting {@code user_email}
 * from string-valued device/network/browser attributes and saves the fitted pipeline.
 *
 * <p>Pipeline stages, in order: label indexer ({@code user_email}), one
 * {@link StringIndexer} per feature column, {@link VectorAssembler}, L1
 * {@link Normalizer}, {@link PCA} (k = 5), {@link NaiveBayes}, and an
 * {@link IndexToString} that maps the numeric prediction back to the e-mail label.
 *
 * <p>Usage: {@code NBTrainPCA <training-parquet-path> <model-output-path>}
 */
public class NBTrainPCA {

    /** Row filter applied to the training data before any stage is fitted. */
    private static final String TRAINING_ROW_FILTER =
        "user_email!='NA' and user_email!='00' and user_email!='0ed709b5bec77b6bff96ea5b5e334a8e5' and user_email is not null and ip is not null and region_code is not null and city is not null and browser_name is not null and os_name is not null";

    /**
     * Feature columns as {raw input column, indexed output column} pairs.
     * The order here is also the order of the components in the assembled vector.
     */
    private static final String[][] FEATURE_COLUMNS = {
        {"user_device", "udev_index"},
        {"ip", "ip_index"},
        {"country_code", "ccode_index"},
        {"region_code", "rcode_index"},
        {"city", "cy_index"},
        {"zip_code", "zp_index"},
        {"time_zone", "tz_index"},
        {"browser_name", "bn_index"},
        {"browser_manf", "bm_index"},
        {"browser_version", "bv_index"},
        {"os_name", "on_index"},
        {"os_manf", "om_index"},
    };

    /**
     * Entry point: fits the pipeline on the Parquet data at {@code args[0]} and
     * writes the fitted model to {@code args[1]}, overwriting any existing model.
     *
     * @param args {@code args[0]} = training Parquet path, {@code args[1]} = model output path
     * @throws IllegalArgumentException if fewer than two arguments are supplied
     * @throws IllegalStateException if training or saving fails; the original failure is the cause
     */
    public static void main(String args[]) {
        if (args == null || args.length < 2) {
            throw new IllegalArgumentException(
                "Usage: NBTrainPCA <training-parquet-path> <model-output-path>");
        }

        SparkConf conf = new SparkConf().setAppName("NBTrain");
        SparkContext scc = new SparkContext(conf);
        JavaSparkContext sc = new JavaSparkContext(scc);
        try {
            scc.setLogLevel("ERROR");
            SQLContext sqlc = new SQLContext(scc);

            DataFrame traindata = sqlc.read().format("parquet").load(args[0]).filter(TRAINING_ROW_FILTER);
            traindata.registerTempTable("master");

            PipelineModel model = buildPipeline(traindata).fit(traindata);
            model.write().overwrite().save(args[1]);
        } catch (Exception e) {
            // Surface the failure instead of swallowing it, so the job exits non-zero
            // and the root cause (e.g. a stage rejecting its input) is visible.
            throw new IllegalStateException(
                "Training failed for input '" + args[0] + "', model output '" + args[1] + "'", e);
        } finally {
            // Release the Spark context on both the success and the failure path.
            sc.close();
        }
    }

    /** Assembles the unfitted training pipeline; only the label indexer is fitted eagerly. */
    private static Pipeline buildPipeline(DataFrame traindata) {
        // Fitted up front because IndexToString needs the label vocabulary.
        StringIndexerModel emailIndexer = newIndexer("user_email", "email_index").fit(traindata);

        int featureCount = FEATURE_COLUMNS.length;
        String[] indexedColumns = new String[featureCount];
        // Label indexer + feature indexers + assembler, normalizer, PCA, classifier, label decoder.
        PipelineStage[] stages = new PipelineStage[featureCount + 6];
        stages[0] = emailIndexer;
        for (int i = 0; i < featureCount; i++) {
            stages[i + 1] = newIndexer(FEATURE_COLUMNS[i][0], FEATURE_COLUMNS[i][1]);
            indexedColumns[i] = FEATURE_COLUMNS[i][1];
        }

        VectorAssembler assembler = new VectorAssembler()
            .setInputCols(indexedColumns)
            .setOutputCol("ffeatures");
        Normalizer normalizer = new Normalizer()
            .setInputCol("ffeatures")
            .setOutputCol("sfeatures")
            .setP(1.0);
        PCA pca = new PCA()
            .setInputCol("sfeatures")
            .setOutputCol("pcafeatures")
            .setK(5);
        // NOTE(review): PCA projections can contain negative components, and Spark's
        // NaiveBayes is documented to require non-negative feature values, so fitting
        // this stage may fail on real data -- confirm and rescale/replace PCA if so.
        NaiveBayes nbcl = new NaiveBayes()
            .setFeaturesCol("pcafeatures")
            .setLabelCol("email_index")
            .setSmoothing(1.0);
        IndexToString is = new IndexToString()
            .setInputCol("prediction")
            .setOutputCol("op")
            .setLabels(emailIndexer.labels());

        stages[featureCount + 1] = assembler;
        stages[featureCount + 2] = normalizer;
        stages[featureCount + 3] = pca;
        stages[featureCount + 4] = nbcl;
        stages[featureCount + 5] = is;
        return new Pipeline().setStages(stages);
    }

    /** Creates a StringIndexer for one column that skips rows with labels unseen at fit time. */
    private static StringIndexer newIndexer(String inputCol, String outputCol) {
        return new StringIndexer()
            .setInputCol(inputCol)
            .setOutputCol(outputCol)
            .setHandleInvalid("skip");
    }
}

回答

0

我建议你阅读一些关于 PCA 的资料,以便更好地理解它在做什么。这里有一些链接:

https://stats.stackexchange.com/questions/26352/interpreting-positive-and-negative-signs-of-the-elements-of-pca-eigenvectors

https://stats.stackexchange.com/questions/2691/making-sense-of-principal-component-analysis-eigenvectors-eigenvalues

关于把 ALS 集成到你的管道:看起来你只是想把一个组件直接接在另一个组件后面。最好先弄清楚它们各自的作用和用途:ALS 和 PCA 是完全不同的东西。ALS 使用交替最小二乘法(Alternating Least Squares)进行矩阵分解以最小化误差,它并不寻找主成分,也不用于对数据做变换或降维。

顺便说一句:我认为 PCA 主成分向量中出现负值并没有什么问题,你可以在上面的链接中确认这一点。你是在对数据做线性变换,所以新的向量就是变换后的结果。希望对你有帮助。

+0

PCA 主成分向量中出现负值确实是个问题:朴素贝叶斯不接受特征中的负值。这正是问题所在。 –

+0

请参考此链接:https://stackoverflow.com/questions/36491852/using-pca-before-bayes-classificition/36491982 –

+0

请阅读那里的评论:“Spark 中实现的 NMF 在分解原始矩阵时不考虑正交性,因此它可能不适合你的应用。” ALS 矩阵分解与 PCA 没有任何关系。 –