1

我正在TensorFlow中训练自主驱动卷积神经网络。这是一个简单的回归网络,可以拍摄图像并输出单个值(转向角度)。训练的张量流模型总是输出零

这是该网络所定义的函数:

def cnn_model_fn(features, labels, mode): 
    conv1 = tf.layers.conv2d(
     inputs=features, 
     filters=32, 
     kernel_size=5, 
     padding="same", 
     activation=tf.nn.relu 
    ) 

    pool1 = tf.layers.max_pooling2d(
     inputs=conv1, 
     pool_size=2, 
     strides=2 
    ) 

    pool1_flat = tf.reshape(pool1, [-1, 2764800]) 

    dense1 = tf.layers.dense(
     inputs=pool1_flat, 
     units=128, 
     activation=tf.nn.relu 
    ) 

    dropout = tf.layers.dropout(
     inputs=dense1, 
     rate=0.4, 
     training=mode == learn.ModeKeys.TRAIN 
    ) 

    dense2 = tf.layers.dense(
     inputs=dropout, 
     units=1, 
     activation=tf.nn.relu 
    ) 

    predictions = tf.reshape(dense2, [-1]) 

    loss = None 
    train_op = None 

    if mode != learn.ModeKeys.INFER: 
     loss = tf.losses.mean_squared_error(
      labels=labels, 
      predictions=predictions 
     ) 

    if mode == learn.ModeKeys.TRAIN: 
     train_op = tf.contrib.layers.optimize_loss(
      loss=loss, 
      global_step=tf.contrib.framework.get_global_step(), 
      learning_rate=0.001, 
      optimizer="SGD" 
     ) 

    return model_fn_lib.ModelFnOps(
     mode=mode, 
     predictions=predictions, 
     loss=loss, 
     train_op=train_op 
    ) 

在亚洲其他节目,我开始分类的培训,像这样:

def main(_): 
    # Gather data 
    images, labels = get_data("./data/labels.csv") 

    # Create the estimator 
    classifier = learn.Estimator(
     model_fn=cnn_model_fn, 
     model_dir="/tmp/network2" 
    ) 

    # Train the model 
    classifier.fit(
     x=images, 
     y=labels, 
     batch_size=10, 
     steps=20 
    ) 

    for v in tf.trainable_variables(): 
     print(v) 

labels是一个简单的一维NumPy的包含训练样例的所有转向角。它们正在从CSV文件中读取。文件中的值非常接近0,平均值为零。

当它们直接从文件中读取或乘以标量时,网络收敛合理,并实现低损耗功能。当我添加一个常量时,它不会收敛或发散。我怀疑网络的所有权重都收敛于零。

有没有人看到我的方法有问题?

回答

0

即辍学转正可能是罪魁祸首:

dropout = tf.layers.dropout(
     inputs=dense1, 
     rate=0.4, 
     training=mode == learn.ModeKeys.TRAIN 
    ) 

你所描述的,未能充分收敛或下降接近零的权重,是高度描述一个高偏问题。删除或降低正规化程度,向您的网络添加更多参数或以其他方式增加差异是解决此类问题的常用方法。

相关问题