2017-10-04 67 views
0

我想用Tensorflow解决人工神经网络模型。目前,我能够将该程序作为一长串文本运行。但是,现在我想将我的代码转换为更易于使用的代码。所以我将我的代码转换为一个类。这是我做的。 (基本上是复制整组代码的一类tensorflow内部类变量与外部变量不同

import os 
import tensorflow as tf 

class NNmodel: 

    def __init__(self, 
       layers, inpShape, outShape, 
       features, 
       learning_rate=0.1, nSteps = 100, 
       saveFolder='models'): 

     self.layers = layers 
     self.features = features 
     self.learning_rate = learning_rate 
     self.saveFolder = saveFolder 
     self.nSteps = 100 

     self.d = tf.placeholder(shape = inpShape, dtype = tf.float32, name='d') # input layer 
     self.dOut = tf.placeholder(shape = outShape, dtype = tf.float32, name='dOut') # output layer 

     self.weights = [] 
     self.biases = [] 
     self.compute = [self.d] 

     layerSizes = [self.features] + [l['size'] for l in self.layers] 

     for i, (v1, v2) in enumerate(zip(layerSizes, layerSizes[1:])): 
      self.weights.append( 
       tf.Variable(np.random.randn(v1, v2)*0.1, dtype = tf.float32, name='W{}'.format(i))) 

      self.biases.append(
       tf.Variable(np.zeros((1,1)), dtype = tf.float32, name='b{}'.format(i))) 

      self.compute.append(tf.matmul( 
       self.compute[-1], self.weights[i]) + self.biases[i]) 

      if self.layers[i]['activation'] == 'tanh': 
       self.compute.append(tf.tanh(self.compute[-1])) 

      if self.layers[i]['activation'] == 'relu': 
       self.compute.append(tf.nn.relu(self.compute[-1])) 

      if self.layers[i]['activation'] == 'sigmoid': 
       self.compute.append(tf.sigmoid (self.compute[-1])) 

     self.result = self.compute[-1] 
     self.delta = self.dOut - self.result 
     self.cost = tf.reduce_mean(self.delta**2) 

     self.optimizer = tf.train.AdamOptimizer(
      learning_rate = self.learning_rate).minimize(self.cost) 
     return 

    def findVal(self, func, inpDict, restorePt=None): 

     saver = tf.train.Saver() 
     sess = tf.Session() 

     init = tf.global_variables_initializer() 
     sess.run(init) 

     if restorePt is not None: 
      try: 
       saver.restore(sess, tf.train.latest_checkpoint(restorePt)) 
       print('Session restored') 
      except Exception as e: 
       print('Unable to restore the session ...') 
       return None 
     else: 
      print('Warning, no restore point selected ...') 

     result = sess.run(func, feed_dict = inpDict) 
     sess.close() 
     return result 

    def optTF(self, inpDict, printSteps=50, modelFile=None): 

     cost = [] 
     saver = tf.train.Saver() 

     sess = tf.Session() 
     init = tf.global_variables_initializer() 
     sess.run(init) 

     print('x'*100) 

     for i in range(self.nSteps): 

      # First run the optimizer ... 
      sess.run(self.optimizer, feed_dict = inpDict) 

      # Save all the data you want to save 
      c = sess.run(self.cost, feed_dict = inpDict) 
      cost.append(c) 

      if (i%printSteps) == 0: 
       print('{:5d}'.format(i)) 

     result = self.run(self.result, feed_dict = inpDict) 

     if modelFile is not None: 
      path = saver.save(sess, os.path.join( 
       self.saveFolder, modelFile)) 
      print('Model saved in: {}'.format(path)) 
     else: 
      print('Warning! model not saved') 
     sess.close() 

     return cost, result 

当我使用这个模型中,我看到有似乎是一个问题:

N  = 500 
features = 2 
nSteps = 1000 

X = [ (np.random.random(N))*np.random.randint(1000, 2000) for i in range(features)] 
X = np.array([np.random.random(N), np.random.random(N)]) 
data = [X.T, X[0].reshape(-1, 1)] 

layers = [ 
    {'name':'6', 'size': 10, 'activation':'tanh'}, 
    {'name':'7', 'size': 1, 'activation':'linear'}, 
] 
m1 = NNmodel(layers, inpShape=np.shape(data[0]), outShape = np.shape(data[1]), 
      features=features, 
      learning_rate=0.1, nSteps = 100, 
      saveFolder='models1') 

d = tf.placeholder(shape = np.shape(data[0]), dtype = tf.float32, name='d_4') 
dOut = tf.placeholder(shape = np.shape(data[1]), dtype = tf.float32, name='dOut') 

m1.findVal(m1.result, {d: data[0], dOut:data[1]}) 

现在看来,存在不匹配我使用ddOut我对外提供形式,占位符,并且已经在模型self.dself.dOut中存在的那些之间。我怎么解决这个问题呢?

回答

1

为什么不只是使用在模型中声明占位符?

m1.findVal(m1.result, {m1.d: data[0], m1.dOut:data[1]}) 
+0

我最终只是发送数据,并在类中创建字典。但是您的信息实际上解释了如何解决问题! – ssm