人工智慧-Tensorflow 遷移學習

從零開始訓練一個深度神經網絡有時需要海量的數據才能得到較好的效果。如果你手頭的數據有限,又想採用神經網絡作為解決方案,可以嘗試一下遷移學習。       
       在實際的應用中,預訓練好的模型的輸入輸出可能並不能滿足我們的需求,另外,訓練上百萬甚至上千萬張圖片,可能需要花費好幾天的時間,那有沒有辦法只使用訓練好的模型的一部分呢?訓練好的模型的前幾層對特徵提取有非常好的效果,如果可以直接使用,那就事半功倍了。這種方法被稱之為遷移學習(transfer learning)。       
       深度神經網絡的結構存在層級,對於卷積神經網絡CNN來說,不同層級的捲積層所表現出的特徵提取也呈現層級性。具體來說,底層的捲積層對於低階特徵較為敏感,例如邊緣、團塊等;隨著層級的升高,提取的特徵越來越抽象。這種隨層級變化的特徵提取能力是遷移學習的基礎,它保證了當任務具備相似性時,例如分類1024種不同的自然物體與分類10種不同的零件,已經訓練好的神經網絡的特徵提取層可以“遷移”到新的分類任務中來繼續承擔特徵提取的功能。拿CNN 來舉個例子,我訓練好了一個區分男人和女人的CNN,接著來了個任務,說我下個任務是區分照片中人的年齡。這看似完全不相干的兩個模型,但是我們卻可以運用到遷移學習,讓之前那個CNN 當我們的初始模型因為區分男女的CNN 已經對人類有了理解,基於這個理解開始訓練,總比完全重新開始訓練強。
       遷移一個圖片分類的CNN (VGG),這個VGG 在1000個類別中訓練過,我們提取這個VGG 前面的Conv layers,重新組建後面的fully connected layers,讓它做一個和分類完全不相干的事,在網上下載那1000個分類數據中的貓和老虎的圖片,然後偽造一些貓和老虎長度的數據,最後做到讓遷移後的網絡分辨出貓和老虎的長度(regressor)。為了達到目的,我們不需要下載所有的1000個分類的所有圖片,只要找到自己感興趣的類就好 (老虎和貓)。
                                                                                                             
                                                                                                                                                                   圖1 VGG Code

tf.layers.dense() 建立的 layers 是可以被 train 的,再定义一个 Saver 来保存由 tf.layers.dense() 建立的 parameters。
                                                                                                             
                                                                                                                                                                   圖2 Saver Code


因為有了訓練好了的 VGG16, 你就能將 VGG16 的 Conv 層想像成是一個 feature extractor, 提取或壓縮圖片中的特徵. 和 Autoencoder 中的 encoder 類似. 用這些提取的特徵來訓練後面的 regressor。
                                                                                    
                                                                                                                                                              圖3 Train Code

                                                                                                         
                                                                                                                                                                           圖4 預測結果
輸入了一張貓、一張老虎的圖,這個 VGG 給我預測除了他們的長度,VGG必須懂得區分哪些是貓, 哪些是老虎, 而這個認知, 在原始的 VGG conv 層中就已經學出來了。

★博文內容均由個人提供,與平台無關,如有違法或侵權,請與網站管理員聯繫。

★文明上網,請理性發言。內容一周內被舉報5次,發文人進小黑屋喔~

評論