2017-04-21 32 views
3

我使用ALS预测评级,这是我的代码:为什么Spark ML ALS算法打印RMSE = NaN?

val als = new ALS() 
    .setMaxIter(5) 
    .setRegParam(0.01) 
    .setUserCol("user_id") 
    .setItemCol("business_id") 
    .setRatingCol("stars") 
val model = als.fit(training) 

// Evaluate the model by computing the RMSE on the test data 
val predictions = model.transform(testing) 
predictions.sort("user_id").show(1000) 
val evaluator = new RegressionEvaluator() 
    .setMetricName("rmse") 
    .setLabelCol("stars") 
    .setPredictionCol("prediction") 
val rmse = evaluator.evaluate(predictions) 
println(s"Root-mean-square error = $rmse") 

但得到一些负面的分数和RMSE是楠:

+-------+-----------+---------+------------+ 
|user_id|business_id| stars| prediction| 
+-------+-----------+---------+------------+ 
|  0|  2175|  4.0| 4.0388923| 
|  0|  5753|  3.0| 2.6875196| 
|  0|  9199|  4.0| 4.1753435| 
|  0|  16416|  2.0| -2.710618| 
|  0|  6063|  3.0|   NaN| 
|  0|  23076|  2.0| -0.8930751| 

Root-mean-square error = NaN 

如何取得好成绩?

+0

也许'predictions.where(“prediction!='NaN'”))''有效。 – IceMimosa

回答

3

由于RMSE首先将RMS值平方,因此负值无关紧要。可能你有空的预测值。你可以放弃他们:

predictions.na().drop(["prediction"]) 

虽然这可能有点误导,或者你可以用最低/最高/平均评分填写这些值。

我也建议将x < min_ratingx > max_rating舍入到最低/最高评分,这可以提高您的RMSE。

编辑:

一些额外的信息在这里:https://issues.apache.org/jira/browse/SPARK-14489

+0

我使用'predictions.na.drop()',但是窗台得到NaN。 –

+0

如果您有任何空值或NAs,您可以检查您的数据吗? – jamborta

+0

我检查了我的结果,它仍然contian“NaN”。 –

0

小幅盘整将解决这个问题:

prediction.na.drop()

0

由于星火版本2.2.0你可以将coldStartStrategy参数设置为drop,以便删除DataFrame中包含NaN值的预测中的任何行。然后评估指标将在非NaN数据上计算并且是有效的。

model.setColdStartStrategy("drop");