2017-06-13 26 views
0

我目前正在学习神经网络背后的理论,并且我想学习如何对这些模型进行编码。所以我开始看TensorFlow。在Python中使用TensorFlow的XOR Neural Netowrk

我发现了一个非常有趣的应用程序,我想编程,但是我目前无法使其工作,并且我不知道为什么!

的例子来自Deep Learning, Goodfellow et al 2016第171 - 177

import tensorflow as tf 

T = 1. 
F = 0. 
train_in = [ 
    [T, T], 
    [T, F], 
    [F, T], 
    [F, F], 
] 
train_out = [ 
    [F], 
    [T], 
    [T], 
    [F], 
] 
w1 = tf.Variable(tf.random_normal([2, 2])) 
b1 = tf.Variable(tf.zeros([2])) 

w2 = tf.Variable(tf.random_normal([2, 1])) 
b2 = tf.Variable(tf.zeros([1])) 

out1 = tf.nn.relu(tf.matmul(train_in, w1) + b1) 
out2 = tf.nn.relu(tf.matmul(out1, w2) + b2) 

error = tf.subtract(train_out, out2) 
mse = tf.reduce_mean(tf.square(error)) 

train = tf.train.GradientDescentOptimizer(0.01).minimize(mse) 

sess = tf.Session() 
tf.global_variables_initializer() 

err = 1.0 
target = 0.01 
epoch = 0 
max_epochs = 1000 

while err > target and epoch < max_epochs: 
    epoch += 1 
    err, _ = sess.run([mse, train]) 

print("epoch:", epoch, "mse:", err) 
print("result: ", out2) 

运行的代码时,我得到Pycharm以下错误信息:Screenshot

回答

0

为了运行初始化运,你应写:

sess.run(tf.global_variables_initializer()) 

而不是:

tf.global_variables_initializer() 

这里是一个工作版本:

import tensorflow as tf 

T = 1. 
F = 0. 
train_in = [ 
    [T, T], 
    [T, F], 
    [F, T], 
    [F, F], 
] 
train_out = [ 
    [F], 
    [T], 
    [T], 
    [F], 
] 
w1 = tf.Variable(tf.random_normal([2, 2])) 
b1 = tf.Variable(tf.zeros([2])) 

w2 = tf.Variable(tf.random_normal([2, 1])) 
b2 = tf.Variable(tf.zeros([1])) 

out1 = tf.nn.relu(tf.matmul(train_in, w1) + b1) 
out2 = tf.nn.relu(tf.matmul(out1, w2) + b2) 

error = tf.subtract(train_out, out2) 
mse = tf.reduce_mean(tf.square(error)) 

train = tf.train.GradientDescentOptimizer(0.01).minimize(mse) 

sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 

err = 1.0 
target = 0.01 
epoch = 0 
max_epochs = 1000 

while err > target and epoch < max_epochs: 
    epoch += 1 
    err, _ = sess.run([mse, train]) 

print("epoch:", epoch, "mse:", err) 
print("result: ", out2) 
相关问题