有人可以在TensorFlow中解释我的gradient_override_map
函数吗? 我无法准确理解它的用法。Tensorflow的gradient_override_map函数
我看到的代码用法:
with G.gradient_override_map({"Floor": "Identity"}):
return tf.reduce_mean(SomeVals) * SomeOtherVal
正是这里发生了什么?什么是Identity
?
有人可以在TensorFlow中解释我的gradient_override_map
函数吗? 我无法准确理解它的用法。Tensorflow的gradient_override_map函数
我看到的代码用法:
with G.gradient_override_map({"Floor": "Identity"}):
return tf.reduce_mean(SomeVals) * SomeOtherVal
正是这里发生了什么?什么是Identity
?
尽我所知,gradient_override_map允许你说“在这种情况下,任何时候你会使用X的渐变,而不是使用Y的渐变”。这意味着您仍然需要Y的梯度作为您要使用的渐变。
这是我见过的漂浮在寻找如何工作的一个例子:
@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
return 5.0 * grad
g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
output = tf.identity(input, name="Identity")
举:https://stackoverflow.com/a/43948872/1102705
RegisterGradient()
允许你注册你定义一个新的运算的梯度,从而允许你有一个你想要的梯度的运算符,然后你可以在梯度覆盖映射中使用该运算符。它有点笨重 - 你正在定义一个没有前锋的传球。
我不明白的是名称=“身份”是否真的有必要。
两个“楼”和“身份”是操作的类型的字符串,前者对应于tf.floor而后者tf.identity。 所以我猜你的代码的功能是用图G中的tf.floor运算代替tf.identity的BPD计算机制的后向传播梯度(BPG)计算机制,同时通过前向输出tf.reduce_mean。看起来有点奇怪,因为在我找到的gradient_override_map
的所有应用程序中,op_type_map的关键字始终与用于在上下文中生成输出的操作的类型字符串相同。通过这个我的意思是我更熟悉与tf.floor(SomeVals)
返回的场景,而不是tf.reduce_mean(SomeVals)
。
gradient_override_map({op_A_type: op_B_type})
做的是用op_B替换op_A的BPG计算机制,同时保留op_A_type的正向传播计算机制。 lahwran的答案中显示了gradient_override_map的常见应用。
@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
return 5.0 * grad
g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
output = tf.identity(input, name="Identity")
通过
@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
return 5.0 * grad
装饰,tf.RegisterGradient("CustomGrad")
登记由_const_mul_grad(unused_op, grad)
用于定制的运算式定义的梯度函数 - “CustomGrad”,
而
g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
output = tf.identity(input, name="Identity")
保证所有操作的输出(图g),字符串类型为“Id”实体“(tf.identity)与它们相同,而BPG计算机制tf。身份用字符串类型“CustomGrad”替换为BPG计算操作机制。
P.S.
运算的类型字符串对应于OpDef.name
字段定义该操作的原。为了找到一个运算的OpDef.name
,请参照明星的回答下this question
这是没有必要的,因为在tf.identity的ARG“名称”是可选操作申报tf.identity的名称。