2017-05-09

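Loading a Word2Vec model in Spark

Is it possible to load a pre-trained (binary) model into Spark (using Scala)? I tried to load one of the binary models published by Google like this: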

import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} 


    val model = Word2VecModel.load(sc, "GoogleNews-vectors-negative300.bin") 

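But it fails because it cannot find the metadata directory. I also created that folder and put the binary file inside it, but then the file cannot be parsed. I have not found any wrapper that handles this.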

Answer

I wrote a quick function that loads the Google News pre-trained model into a Spark Word2VecModel. Enjoy.

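import java.io.{BufferedInputStream, ByteArrayOutputStream, DataInputStream, FileInputStream}
import java.nio.charset.StandardCharsets
import java.util.zip.GZIPInputStream
import org.apache.spark.mllib.feature.Word2VecModel

// Loads a gzip-compressed word2vec binary file (for example a .bin.gz download).
// For an uncompressed .bin file, remove the GZIPInputStream wrapper.
def loadBin(file: String): Word2VecModel = {
  // Reads bytes up to (not including) the terminator and decodes them as UTF-8,
  // so that words containing non-ASCII characters survive intact.
  def readUntil(inputStream: DataInputStream, term: Char, maxLength: Int = 1024 * 8): String = {
    val bytes = new ByteArrayOutputStream
    var b = inputStream.readByte()
    while (b != term.toByte) {
      bytes.write(b)
      require(bytes.size < maxLength, s"token longer than $maxLength bytes")
      b = inputStream.readByte()
    }
    new String(bytes.toByteArray, StandardCharsets.UTF_8)
  }
  val inputStream = new DataInputStream(
    new BufferedInputStream(new GZIPInputStream(new FileInputStream(file))))
  try {
    // Header line: "<number of words> <vector size>"
    val (records, dimensions) = readUntil(inputStream, '\n').split(" ") match {
      case Array(r, d) => (r.toInt, d.toInt)
    }
    new Word2VecModel((0 until records).map { _ =>
      // The word2vec tool ends each record with '\n', so every word after the
      // first would otherwise start with a newline; strip it.
      val word = readUntil(inputStream, ' ').stripPrefix("\n")
      // Floats are stored little-endian; readInt is big-endian, so swap the bytes.
      val vector = Array.fill(dimensions)(
        java.lang.Float.intBitsToFloat(java.lang.Integer.reverseBytes(inputStream.readInt())))
      word -> vector
    }.toMap)
  } finally {
    inputStream.close()
  }
}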
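
To make the file layout easier to see, here is a minimal sketch that writes a tiny model in the same binary layout and reads it back using only the JDK (no Spark needed). It assumes the layout produced by the original word2vec tool: a text header with the word count and vector size, then for each word the word itself, a space, the vector as little-endian 32-bit floats, and a trailing newline. The object name `Word2VecBinFormatDemo` and its `encode`/`decode` helpers are made up for illustration; they are not part of Spark or word2vec.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream}
import java.nio.{ByteBuffer, ByteOrder}
import java.nio.charset.StandardCharsets

// Hypothetical helper object, for illustration only.
object Word2VecBinFormatDemo {
  // Writes "<records> <dimensions>\n", then per word: "<word> ",
  // the vector as little-endian floats, and a trailing '\n'.
  def encode(vectors: Seq[(String, Array[Float])]): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    val dims = vectors.head._2.length
    out.write(s"${vectors.size} $dims\n".getBytes(StandardCharsets.UTF_8))
    for ((word, vec) <- vectors) {
      out.write(s"$word ".getBytes(StandardCharsets.UTF_8))
      val buf = ByteBuffer.allocate(4 * dims).order(ByteOrder.LITTLE_ENDIAN)
      vec.foreach(f => buf.putFloat(f))
      out.write(buf.array())
      out.write('\n'.toInt)
    }
    out.toByteArray
  }

  // Reads bytes up to (not including) the terminator, decoded as UTF-8.
  private def readToken(in: DataInputStream, term: Char): String = {
    val bytes = new ByteArrayOutputStream()
    var b = in.readByte()
    while (b != term.toByte) {
      bytes.write(b)
      b = in.readByte()
    }
    new String(bytes.toByteArray, StandardCharsets.UTF_8)
  }

  def decode(bytes: Array[Byte]): Map[String, Array[Float]] = {
    val in = new DataInputStream(new ByteArrayInputStream(bytes))
    val Array(records, dims) = readToken(in, '\n').split(" ").map(_.toInt)
    (0 until records).map { _ =>
      // Drop the newline left over from the end of the previous record.
      val word = readToken(in, ' ').stripPrefix("\n")
      val raw = new Array[Byte](4 * dims)
      in.readFully(raw)
      val buf = ByteBuffer.wrap(raw).order(ByteOrder.LITTLE_ENDIAN)
      word -> Array.fill(dims)(buf.getFloat())
    }.toMap
  }
}
```

Calling `Word2VecBinFormatDemo.decode(Word2VecBinFormatDemo.encode(...))` on a couple of short vectors should hand back the same words and values. Here a little-endian `ByteBuffer` does the byte-order conversion, which is an alternative to swapping the bytes of each `readInt` result by hand.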