人工智慧-LSTM RNN 循環神經網路(分類例子)

    這次我們會使用RNN來進行分類的訓練(分類)。會繼續使用到手寫數字MNIST數據集。讓RNN從每張圖片的第一行變為讀到最後一行,然後再進行分類判斷。首先我們可以先初始化一些變數,如學習率、節點單元數、RNN 層數等:
                                                                                                                       
                                                                                                                                                                               圖1 初始化變數

還需要宣告一下 MNIST 資料生成器:
                                                                                        

                                                                                                                                                                            圖2 MNIST 資料生成器

宣告一下輸入的資料,輸入資料用 x 表示,標註資料用 y_label 表示:
                                                                                                                 

                                                                                                                                                                         圖3 宣告輸入資料與標註資料

這裡輸入的 x 維度是 [None, 784],代表 batch_size 不確定,輸入維度 784,y_label 同理。

接下來我們需要對輸入的 x 進行 reshape 操作,因為我們需要將一張圖分為多個 time_step 來輸入,這樣才能構建一個 RNN 序列,所以這裡直接將 time_step 設成 28,這樣一來 input_size 就變為了 28,batch_size 不變,所以reshape 的結果是一個三維的矩陣:
                                                                                                               

                                                                                                                                                                          圖4 對輸入進行reshape操作

再來我們需要構建一個 RNN 模型了,這裡我們使用的 RNN Cell 是 LSTMCell,而且要搭建一個三層的 RNN,所以這裡還需要用到 MultiRNNCell,它的輸入引數是 LSTMCell 的列表。

所以我們可以先宣告一個方法用於建立 LSTMCell,也加入了 Dropout,來減少訓練過程中的過擬合。方法如下:
                                                                                                          
                                                                                                                                                                              圖5 建構一個RNN模型

我們再利用它來構建多層的 RNN:
                                             

                                                                                                                                                                                   圖6 建立多層RNN



這裡使用了 for 迴圈,每迴圈一次新生成一個 LSTMCell,而不是直接使用乘法來擴充套件列表,因為這樣會導致 LSTMCell 是同一個物件,導致構建完 MultiRNNCell 之後出現維度不匹配的問題。我們需要宣告一個初始狀態:
                                                                                                           

                                                                                                                                                                                          圖7 宣告初始狀態

接下來呼叫 dynamic_rnn() 方法即可完成模型的構建了:
                                                                       
                                                                                                                                                                                          圖8 建構RNN模型

這裡inputs的輸入就是 x 做了 reshape 之後的結果,初始狀態通過 initial_state 傳入,其返回結果有兩個,一個 output 是所有 time_step 的輸出結果,賦值為 output,它是三維的,第一維長度等於 batch_size,第二維長度等於 time_step,第三維長度等於 num_units。另一個 hs 是隱含狀態,是元組形式,長度即 RNN 的層數 3,每一個元素都包含了 c 和 h,即 LSTM 的兩個隱含狀態。
                        
                                                                                                                                                                  圖9 output layer

這裡的 Loss 直接呼叫了 softmax_cross_entropy_with_logits 先計算了 Softmax,然後計算了cross entropy。

 

最後再定義訓練和評估的流程即可,在訓練過程中每隔一定的 step 就輸出 Train Accuracy 和 Test Accuracy:

                                                                                                                                                                    圖10 定義訓練精度流程

                                                                                                                                                   
                                                                                                                                                                             圖11 訓練結果

可以看出來 LSTM 在做 MNIST 字元分類的任務上還是比較有效的。





★博文內容均由個人提供,與平台無關,如有違法或侵權,請與網站管理員聯繫。

★文明上網,請理性發言。內容一周內被舉報5次,發文人進小黑屋喔~

評論