Pytorch 基礎學習2_Dataset與DateLoader

► 前言

上篇文章「Pytorch 基礎學習1_張量與自動微分」,講解Pytorch中Tensor的用法以及Tensor中自動微分透過PyTorch建構一個簡單的線性回歸網路,本篇主要講解Pytorch中Dataset與DataLoader用法, 透過Dataset類別可以輕鬆地建立一個自定義的Dataset,再通過DataLoader就能夠在訓練模型時載入資料進行模型訓練。

 

► Dataset與DataLoader

Pytorch中提供torch.utils.data.Dataset類別將資料包裝起來,定義每次訓練迭代資料資訊,再透過torch.utils.data.DataLoader定義取樣資訊,將定義好的資料集進行包裝並設置參數調整每次迭代所需的資料數量(batchsize)。

自定義數據集框架如下圖所示,主要透過__init__進行資料定義、__getitem__()進行資料讀取與前處理與__len__來獲取資料長度,進行Dataset定義與操作。

 

 

除了採用自定義數據集以外,常見的Dataset創建方法有torch.utils.data.TensorDataset 根據Tensor創建數據集與torchvision.datasets.ImageFolder根據圖片目錄創建圖片數據集,本篇文章會透過kaggle貓狗數據集進行範例,使用自定義數據集進行資料讀取再藉由torchvision.transfrom來進行影像增強來提高資料的多樣性,最後透過torch.utils.data.DataLoader類別來定義Dataset取樣。更多的相關內容可以參考官方提供的 FaceLandmarksDataset。

首先進行訓練數據集上傳至Colab,再進行套件載入,進行後續的操作。

 








► 小結

透過以上講解,能夠透過Pytorch Dataset類別建立自定義數據集,並透過DataLoader進行數據加載迭代,能夠方便地對於batchsize、 shuffle等參數進行參數設置。

下一章節就要進行模型訓練啦,期待下一篇博文吧!


► 參考資料

Pytorch 官網文檔data_loading

Pytorch 官網文檔data_tutorial

★博文內容均由個人提供,與平台無關,如有違法或侵權,請與網站管理員聯繫。

★文明上網,請理性發言。內容一周內被舉報5次,發文人進小黑屋喔~

評論