Biu懂AI:Pytorch架構的CV訓練數據輸入

       Bui~ 新系列博文將專注AI相關領域,想要學習高通藍牙相關知識請查看之前的系列或關注大博主聲波電波就看今朝

       Pytorch是我們常用的訓練框架之一,也是當下主流框架之一。框架提供許多有用的工具,也提供了獲取Coco、Cifar、MNIST等知名數據集的快捷獲取方法。但如果是自己的數據集,方法就不一樣了。下面我們來了解一下,設置自己的數據集需要用到的模塊。

       Pytorch提供的兩個模塊(或者說是python包)torch.utils.data.DataLoader 和torch.utils.data.Dataset。Dataset的作用是將源數據整理打包,方便使用。在這個class中有三個必要函數:

  1. 實例化時傳入必要參數的__init__ :這裡面會傳入數據及標籤存放的文件夾路徑,如果模型對輸入數據有要求,這裡面會傳入轉換數據用的transform,包括標籤的target_transform(都可在transforms 包中找到方法),後續會在DataLoader 調用 __getitem__函數裡面使用。此外還會傳入一些別的設置參數,但都是先保存在本地,供後續使用。這函數裡面主要是為了將數據整理好,數據和標籤一一對應形成列表,方便獲取,也方便__len__函數統計數目。

  2. 例如數據和標籤分別放在多個不同的文件夾下面保存,我們可以用他們的路徑組隊放在列表。__getitem__的時候就根據路徑去提取內容;再例如標籤只保存放在一個文件裡面,我們就需要讀取裡面的內容,並映射到對應的數據文件;或者說數據文件名和標籤文件名不一樣,需要靠第三個索引表去找對應關係,我們也要讀出去將他們匹配好。

  3. 總的來說就是將訓練數據整理好成一個個訓練樣本,以供__getitem__快速提取,加快訓練時間



  4. 獲取數據集總數的__len__:獲取可用的樣本個數



  5. 和獲取一個指定索引樣本的__getitem__:根據索引去找對應的樣本,如果有transform就需要先將數據或標籤轉換成指定的數據格式。例如resize成同樣大小的圖片;或者normalized數據;或者轉tensor格式。標籤可能會轉換一下數據格式、坐標格式、獨熱編碼(one-hot)等等。這些轉換都是為了滿足模型輸入要求或者提供模型性能。



       除了上面說的這些基本功能之外,有些衍生框架會加入很多調節功能,例如不夠訓練樣本的就多添加一些增強數據;將數據或標籤放在ram中,提供訓練數據;數據和標籤沒成對的就拋棄;只取其中個別類型做樣本數據的等等這些個性化的調節功能。

       DataLoader 的作用是管理數據的加載,並提供模型需要的數據格式輸入。為了讓訓練速度加快,DataLoader 會將多個數據樣本打包成一個batch(相當於把數據、標籤合併成對應的大張量),多個數據樣本一次送入模型中訓練。(這裡打包是將多個張量堆疊在一起,獲取時是使用疊代器一一獲取。但如果樣本的張量不一致,例如object detect的label有多個輸出,這時直接堆疊就會出錯。因此DataLoader 提供了collate_fn函數,用來自定義打包功能的,但是要注意堆疊出來的張量要符合模型輸入要求,要能在DataLoader 的疊代器中每次都能返回所需的樣本數據和標籤)



       打包的數量不能設置太小,不然模型很難收斂。但是也不能設太大,每次打包的數量要取決於內存容量,太大的話,有可能訓練中途內存撐不住。(tips:GPU對2的冪次數的batch可以發揮更佳的性能,因此設置成16、32、64、128等數字性能更好,但是大部分數據集沒辦法被這些數字整除,所以常常有剩下的數據湊不夠一個batch,因此你可以選擇不使用剩下這些數據)

       另外,如果數據排序是按分類排的話,直接按序列取也會影響模型收斂速度,所以DataLoader 會提供打亂數據的功能;除了這些基本功能外,還有別的方法加快訓練速度,例如將數據放在頁存儲中,加快讀取速度,多進程運行DataLoader 等等,這些都能直接配置DataLoader 去實現,前提是需要有足夠的資源。

 

        附件上傳了小編自定義dataset和dataloader的一種使用場景例程,有需要可以下載參考下。以上是本期博文的全部內容,如有疑問請在博文下方評論留言,我會儘快解答(o´ω`o)و。謝謝大家瀏覽,我們下期再見。

 

 

簡單是長期努力的結果,而不是起點
                                                 —— 不是我說的

 

 

FAQ 1:Keras能用這種方法嗎?

A1:不行的,Keras是基於TensorFlow的框架

 

FAQ 2:label格式有哪些?

A2:看前一篇博文-Biu懂AI:Object Detection訓練數據的Label格式

 

FAQ 3:可以獲取dataloader的數據出來看嗎?

A3:可以的,但是如果有transform的話,會轉換成對應的張量,這時數據就不能直接顯示出來了,需要進行逆轉換

 

FAQ 4:train和val要用一樣的dataloader嗎?

A4: 不能說完全一樣,因為train和val的任務不一樣,val不需要考慮模型性能等問題,所以在數據transform時,可以不用考慮數據增強問題。

 

FAQ 5:圖像數據用別的格式可以嗎?

A5: 可以的,但是需要transform,變成RBG格式就行了

 

技術文檔

類型標題檔案
硬件ipynb

★博文內容均由個人提供,與平台無關,如有違法或侵權,請與網站管理員聯繫。

★文明上網,請理性發言。內容一周內被舉報5次,發文人進小黑屋喔~

評論