2017-06-07 87 views
1

我有训练DNN网络的代码。我不想每次都训练这个网络,因为它使用了太多的时间。我该如何保存模型?如何保存张量流的DNN模型

def train_model(filename, validation_ratio=0.): 
    # define model to be trained 
    columns = [tf.contrib.layers.real_valued_column(str(col), 
                dtype=tf.int8) 
       for col in FEATURE_COLS] 
    classifier = tf.contrib.learn.DNNClassifier(
     feature_columns=columns, 
     hidden_units=[100, 100], 
     n_classes=N_LABELS, 
     dropout=0.3) 

    # load and split data 
    print('Loading training data.') 
    data = load_batch(filename) 
    overall_size = data.shape[0] 
    learn_size = int(overall_size * (1 - validation_ratio)) 
    learn, validation = np.array_split(data, [learn_size]) 
    print('Finished loading data. Samples count = {}'.format(overall_size)) 

    # learning 
    print('Training using batch of size {}'.format(learn_size)) 
    classifier.fit(input_fn=lambda: pipeline(learn), 
        steps=learn_size) 

    if validation_ratio > 0: 
     validate_model(classifier, learn, validation) 

    return classifier 

运行此功能后,我得到一个DNNClassifier我想要保存。

+0

没有你得到的答案?你能分享解决方案吗? –

回答