2017-08-02 29 views
1

我一直在努力实现的Python生成到使用Keras.js图书馆网站基本Keras模型。现在,我的模型中训练,并出口到model.jsonmodel_weights.bufmodel_metadata.json文件。现在,我基本上从github页面复制并粘贴测试代码,以查看模型是否会在浏览器中加载,但不幸的是我收到错误。这是测试代码。 (编辑:我修正了一些错误,请参阅下面其余的)实施Keras型号为网站与Keras.js

var model = new KerasJS.Model({ 
    filepaths: { 
     model: 'dist/model.json', 
     weights: 'dist/model_weights.buf', 
     metadata: 'dist/model_metadata.json' 
    }, 
    gpu: true 
}); 

    model.ready() 
    .then(function() { 
    console.log("1"); 
    // input data object keyed by names of the input layers 
    // or `input` for Sequential models 
    // values are the flattened Float32Array data 
    // (input tensor shapes are specified in the model config) 
    var inputData = { 
     'input_1': new Float32Array(data) 
    }; 
    console.log("2 " + inputData); 
    // make predictions 
    return model.predict(inputData); 
    }) 
    .then(function(outputData) { 
    // outputData is an object keyed by names of the output layers 
    // or `output` for Sequential models 
    // e.g., 
    // outputData['fc1000'] 
    console.log("3 " + outputData); 
    }) 
    .catch(function(err) { 
    console.log(err); 
    // handle error 
    }); 

编辑:所以我改变了我的计划围绕一点与JS 5对应的(这是我的一个愚蠢的错误),并现在我遇到了一个不同的错误。该错误被捕获并记录。我得到的错误是:Error: predict() must take an object where the keys are the named inputs of the model: input.我相信这个问题是因为我data变量是不正确的格式。我想,如果我的模型参加了号的28x28阵列,然后data也应该是一个28x28阵列,以便能够正确地“预测”正确的输出。但是,我相信我错过了一些东西,这就是错误被抛出的原因。 This问题与我的非常相似,但是它在python中而不是JS。再次,任何帮助将不胜感激。

回答

0

好了,我想通了,为什么这是怎么回事。有两个问题。首先,data数组需要变平,所以我编写了一个快速函数来获取2D输入并将其“变平”为一个长度为784的1D数组。然后,因为我使用了Sequential模型,数据的键名不应该是'input_1',而只是'input'。这摆脱了所有的错误。

现在,要获取输出信息,我们可以将它存储在如下所示的数组中:var out = outputData['output']。因为我用MNIST数据集,out是长度为10的一维阵列,其包含每个数字是所述用户编写的位的概率。从那里,你可以简单地找到具有最高概率的数字,并将其用作模型的预测。