以sign函數(shù)為例:
sign函數(shù)可以對數(shù)值進(jìn)行二值化,但在梯度反向傳播是不好處理,一般采用一個近似函數(shù)的梯度作為代替,如上圖的Htanh。在[-1,1]直接梯度為1,其他為0。
#使用修飾器,建立梯度反向傳播函數(shù)。其中op.input包含輸入值、輸出值,grad包含上層傳來的梯度@tf.RegisterGradient("QuantizeGrad")def sign_grad(op, grad): input = op.inputs[0] cond = (input>=-1)&(input<=1) zeros = tf.zeros_like(grad) return tf.where(cond, grad, zeros) #使用with上下文管理器覆蓋原始的sign梯度函數(shù)def binary(input): x = input with tf.get_default_graph().gradient_override_map({"Sign":'QuantizeGrad'}): x = tf.sign(x) return x #使用x = binary(x)
以上這篇tensorflow 實現(xiàn)自定義梯度反向傳播代碼就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持武林站長站。
新聞熱點
疑難解答