pytorch Dataset進階用法,我們在使用這個class的時候會覆寫原本的__getitem__
方法
但是可能會遭遇到一些問題,例如原本未處理的資料經過處理,可能會變為多筆input
例如BERT model的輸入長度限制為512,倘若單筆輸入的文字量超過512就需要進行切割處理
這時候就無法僅依靠修改Dataset完成,需要再額外定義collate_fn
,告訴DataLoader如何取batch
下面簡單示範如何做到這件事情:
import torch
batch_size = 2
class MyDataset(torch.utils.data.Dataset):
def __init__(self,data):
super().__init__()
self.data = data
def __getitem__(self,index):
d = self.data[index] # 藉由index取出一筆未處理資料
d = d.split(',')
return d[0] if len(d) == 1 else d # 處理後,單筆資料產生多筆資料(如文字滑窗)
def __len__(self):
return len(self.data)
def collate_fn(batch=None):
if not '_flat_batch' in globals(): global _flat_batch; _flat_batch=[]
_flat_batch = _flat_batch + [item for sublist in batch for item in sublist]
while(len(_flat_batch)!=0):
yield _flat_batch[:batch_size]
_flat_batch = _flat_batch[batch_size:]
if __name__ == "__main__":
data = ['A,B,C','D','E,F','G,H','I','J','K,L','M'] # 模擬一筆資料進來會產生不定長度的input data
dataset = MyDataset(data=data)
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,collate_fn=collate_fn)
for batch_gen in dataloader:
for batch in batch_gen:
print(batch)
結果:
['A', 'B']
['C', 'D']
['E', 'F']
['G', 'H']
['I', 'J']
['K', 'L']
['M']