2015-03-31 60 views
12

假设我有一个DataFrame(我从HDFS上的csv读入),我想通过MLlib来训练一些算法。如何将行转换为LabeledPoints或以其他方式将MLlib用于此数据集?与MLlib一起使用DataFrame

+1

你没有提到你列的数据类型,但如果他们是数字(整数,双,等),你可以使用[VectorAssembler(HTTP://spark.apache .org/docs/latest/ml-features.html#vectorassembler)将特征列转换为[Vector](http://spark.apache.org/docs/latest/mllib-data-types.html) 。 – Ben 2016-01-25 18:19:19

回答

5

假设你正在使用的Scala:

比方说,你得到DataFrame如下:

val results : DataFrame = sqlContext.sql(...) 

步骤1:调用results.printSchema() - 这将显示你不仅在列DataFrame和(这很重要)它们的顺序,但也是Spark SQL认为它们的类型。一旦你看到这个输出,事情就会变得不那么神秘。

步骤2:获取一个RDD[Row]DataFrame的:

val rows: RDD[Row] = results.rdd 

步骤3:现在它只是一个拉动的事不管字段,你的兴趣了各行的。为此,您需要知道每个字段的基于0的位置以及它的类型,幸运的是,您已经获得了上述第1步中的所有内容。例如, 假设你做了一个SELECT x, y, z, w FROM ...和打印模式产生

root 
|-- x double (nullable = ...) 
|-- y string (nullable = ...) 
|-- z integer (nullable = ...) 
|-- w binary (nullable = ...) 

而且我们说,你想用xz所有。你可以将它们拉出来为RDD[(Double, Integer)]如下:

rows.map(row => { 
    // x has position 0 and type double 
    // z has position 2 and type integer 
    (row.getDouble(0), row.getInt(2)) 
}) 

从这里,你只需要使用核心星火创建相关MLlib对象。如果您的SQL返回数组类型的列,事情可能会变得更复杂一些,在这种情况下,您必须为该列调用getList(...)

2

假设你正在使用JAVA(火花1.6.2版):

下面是使用数据框机器学习的JAVA代码一个简单的例子。

  • 它加载具有以下结构的JSON,

    [{ “标签”:1, “ATT2”:5.037089672359123 “ATT1”:2.4100883023159456},...]

  • 将数据分成训练和测试,

  • 列车使用列车数据模型,
  • 模型应用到测试数据和
  • STOR结果。

此外根据official documentation“基于DataFrame的API是主要API”为MLlib自2.0.0以来。所以你可以使用DataFrame找到几个例子。

代码:

SparkConf conf = new SparkConf().setAppName("MyApp").setMaster("local[2]"); 
SparkContext sc = new SparkContext(conf); 
String path = "F:\\SparkApp\\test.json"; 
String outputPath = "F:\\SparkApp\\justTest"; 

System.setProperty("hadoop.home.dir", "C:\\winutils\\"); 

SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); 

DataFrame df = sqlContext.read().json(path); 
df.registerTempTable("tmp"); 
DataFrame newDF = df.sqlContext().sql("SELECT att1, att2, label FROM tmp"); 
DataFrame dataFixed = newDF.withColumn("label", newDF.col("label").cast("Double")); 

VectorAssembler assembler = new VectorAssembler().setInputCols(new String[]{"att1", "att2"}).setOutputCol("features"); 
StringIndexer indexer = new StringIndexer().setInputCol("label").setOutputCol("labelIndexed"); 

// Split the data into training and test 
DataFrame[] splits = dataFixed.randomSplit(new double[] {0.7, 0.3}); 
DataFrame trainingData = splits[0]; 
DataFrame testData = splits[1]; 

DecisionTreeClassifier dt = new DecisionTreeClassifier().setLabelCol("labelIndexed").setFeaturesCol("features"); 
Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {assembler, indexer, dt}); 
// Train model 
PipelineModel model = pipeline.fit(trainingData); 

// Make predictions 
DataFrame predictions = model.transform(testData); 
predictions.rdd().coalesce(1,true,null).saveAsTextFile("justPlay.txt" +System.currentTimeMillis());