一. valid卷積的梯度
我們分兩種不同的情況討論valid卷積的梯度:第一種情況,在已知卷積核的情況下,對未知張量求導(即對張量中每一個變量求導);第二種情況,在已知張量的情況下,對未知卷積核求導(即對卷積核中每一個變量求導)
1.已知卷積核,對未知張量求導
我們用一個簡單的例子理解valid卷積的梯度反向傳播。假設有一個3x3的未知張量x,以及已知的2x2的卷積核K
Tensorflow提供函數tf.nn.conv2d_backprop_input實現了valid卷積中對未知變量的求導,以上示例對應的代碼如下:
import tensorflow as tf# 卷積核kernel=tf.constant( [ [[[3]],[[4]]], [[[5]],[[6]]] ] ,tf.float32)# 某一函數針對sigma的導數out=tf.constant( [ [ [[-1],[1]], [[2],[-2]] ] ] ,tf.float32)# 針對未知變量的導數的方向計算inputValue=tf.nn.conv2d_backprop_input((1,3,3,1),kernel,out,[1,1,1,1],'VALID')session=tf.Session()print(session.run(inputValue))[[[[ -3.] [ -1.] [ 4.]] [[ 1.] [ 1.] [ -2.]] [[ 10.] [ 2.] [-12.]]]]
2.已知輸入張量,對未知卷積核求導
假設已知3行3列的張量x和未知的2行2列的卷積核K
Tensorflow提供函數tf.nn.conv2d_backprop_filter實現valid卷積對未知卷積核的求導,以上示例的代碼如下:
import tensorflow as tf# 輸入張量x=tf.constant( [ [ [[1],[2],[3]], [[4],[5],[6]], [[7],[8],[9]] ] ] ,tf.float32)# 某一個函數F對sigma的導數partial_sigma=tf.constant( [ [ [[-1],[-2]], [[-3],[-4]] ] ] ,tf.float32)# 某一個函數F對卷積核k的導數partial_sigma_k=tf.nn.conv2d_backprop_filter(x,(2,2,1,1),partial_sigma,[1,1,1,1],'VALID')session=tf.Session()print(session.run(partial_sigma_k))[[[[-37.]] [[-47.]]] [[[-67.]] [[-77.]]]]
二. same卷積的梯度
1.已知卷積核,對輸入張量求導
假設有3行3列的已知張量x,2行2列的未知卷積核K
import tensorflow as tf# 卷積核kernel=tf.constant( [ [[[3]],[[4]]], [[[5]],[[6]]] ] ,tf.float32)# 某一函數針對sigma的導數partial_sigma=tf.constant( [ [ [[-1],[1],[3]], [[2],[-2],[-4]], [[-3],[4],[1]] ] ] ,tf.float32)# 針對未知變量的導數的方向計算partial_x=tf.nn.conv2d_backprop_input((1,3,3,1),kernel,partial_sigma,[1,1,1,1],'SAME')session=tf.Session()print(session.run(inputValue))[[[[ -3.] [ -1.] [ 4.]] [[ 1.] [ 1.] [ -2.]] [[ 10.] [ 2.] [-12.]]]]
2.已知輸入張量,對未知卷積核求導
假設已知3行3列的張量x和未知的2行2列的卷積核K
import tensorflow as tf# 卷積核x=tf.constant( [ [ [[1],[2],[3]], [[4],[5],[6]], [[7],[8],[9]] ] ] ,tf.float32)# 某一函數針對sigma的導數partial_sigma=tf.constant( [ [ [[-1],[-2],[1]], [[-3],[-4],[2]], [[-2],[1],[3]] ] ] ,tf.float32)# 針對未知變量的導數的方向計算partial_sigma_k=tf.nn.conv2d_backprop_filter(x,(2,2,1,1),partial_sigma,[1,1,1,1],'SAME')session=tf.Session()print(session.run(partial_sigma_k))[[[[ -1.]] [[-54.]]] [[[-43.]] [[-77.]]]]
新聞熱點
疑難解答