[PyTorch] 使用 collate_fn 處理 Dataset 產生不定長度資料




pytorch Dataset進階用法,我們在使用這個class的時候會覆寫原本的__getitem__方法

但是可能會遭遇到一些問題,例如原本未處理的資料經過處理,可能會變為多筆input
例如BERT model的輸入長度限制為512,倘若單筆輸入的文字量超過512就需要進行切割處理

這時候就無法僅依靠修改Dataset完成,需要再額外定義collate_fn,告訴DataLoader如何取batch

下面簡單示範如何做到這件事情:

import torch

batch_size = 2

class MyDataset(torch.utils.data.Dataset):
    def __init__(self,data):
        super().__init__()
        self.data = data

    def __getitem__(self,index):
        d = self.data[index] # 藉由index取出一筆未處理資料
        d = d.split(',')
        return d[0] if len(d) == 1 else d # 處理後,單筆資料產生多筆資料(如文字滑窗)

    def __len__(self):
        return len(self.data)

def collate_fn(batch=None):
    if not '_flat_batch' in globals(): global _flat_batch; _flat_batch=[]
    _flat_batch = _flat_batch + [item for sublist in batch for item in sublist]
    while(len(_flat_batch)!=0):
        yield _flat_batch[:batch_size]
        _flat_batch = _flat_batch[batch_size:]

if __name__ == "__main__":
    data = ['A,B,C','D','E,F','G,H','I','J','K,L','M'] # 模擬一筆資料進來會產生不定長度的input data
    dataset = MyDataset(data=data)
    dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,collate_fn=collate_fn)

    for batch_gen in dataloader:
        for batch in batch_gen:
            print(batch)

結果:

['A', 'B']
['C', 'D']
['E', 'F']
['G', 'H']
['I', 'J']
['K', 'L']
['M']

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *

這個網站採用 Akismet 服務減少垃圾留言。進一步瞭解 Akismet 如何處理網站訪客的留言資料