[PyTorch] Handling imbalanced data with a sampler




Setup and data

Import the libraries

import torch
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.data import TensorDataset, DataLoader

Consider a set of imbalanced labels

imbalanced_data = [0]*5+[1]*10+[2]*20

Computing the balanced probabilities

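Balanced probability = 1/label_count, where label_count is the number of samples carrying that label. This makes every class equally likely to be drawn, because the weights within each class sum to 1.
Here we compute that balanced probability for every individual sample in the dataset.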

from collections import Counter

def compute_balance_weights(all_gold_label_ids):
    # Count how many samples carry each label (labels need not be 0..n-1)
    label_counts = Counter(all_gold_label_ids)
    # Weight of each sample = 1 / size of its class
    dataset_element_weights = []
    for label_id in all_gold_label_ids:
        dataset_element_weights.append(1. / label_counts[label_id])
    return dataset_element_weights

balance_prob = compute_balance_weights(imbalanced_data)
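
As a quick sanity check of the claim that the weights within each class sum to 1, the sketch below recomputes the weights for the same label list using exact fractions. The helper name `class_weight_sums` is hypothetical and exists only for this illustration.

```python
from collections import Counter
from fractions import Fraction

labels = [0] * 5 + [1] * 10 + [2] * 20  # same label list as in the post

def class_weight_sums(labels):
    """Return {label: sum of per-sample weights}, where weight = 1 / class size."""
    counts = Counter(labels)
    sums = {}
    for label in labels:
        # Exact arithmetic avoids floating-point rounding in the check
        sums[label] = sums.get(label, 0) + Fraction(1, counts[label])
    return sums

sums = class_weight_sums(labels)
print(all(total == 1 for total in sums.values()))  # True
```

Since every class carries the same total weight, a weighted draw picks each of the three classes with probability 1/3, regardless of how many samples the class contains.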

Using the Sampler and DataLoader

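# Draw len(imbalanced_data) indices per epoch, with replacement, according to the per-sample weights
sampler = WeightedRandomSampler(weights=balance_prob, num_samples=len(imbalanced_data), replacement=True)
dataset = TensorDataset(torch.LongTensor(imbalanced_data))
# sampler and shuffle=True are mutually exclusive in DataLoader
with_sampler_dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
normal_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)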
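
To see what `num_samples` and `replacement=True` do, here is a minimal sketch that draws far more indices than the dataset contains and measures how often each class appears. The draw count of 30,000 and the seeded `torch.Generator` are assumptions made for this demo so the run is repeatable; they are not part of the original setup.

```python
import torch
from torch.utils.data import WeightedRandomSampler

labels = torch.tensor([0] * 5 + [1] * 10 + [2] * 20)
counts = torch.bincount(labels)            # samples per class: 5, 10, 20
weights = (1.0 / counts.float())[labels]   # per-sample weight = 1 / class size

# Demo-only assumptions: a large draw count and a fixed seed
num_draws = 30_000
generator = torch.Generator().manual_seed(0)
sampler = WeightedRandomSampler(weights, num_samples=num_draws,
                                replacement=True, generator=generator)

indices = torch.tensor(list(sampler))      # the sampler yields dataset indices
drawn_labels = labels[indices]             # map the indices back to labels
fractions = torch.bincount(drawn_labels, minlength=3).float() / num_draws
print(fractions)  # each entry should be close to 1/3
```

With `replacement=True` the same index can be drawn many times, which is how the five class-0 samples end up supplying roughly a third of the draws.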

Results

Balanced (with the sampler)

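outputs = []
for batch in with_sampler_dataloader:
    sample_data = batch[0].tolist()  # batch is a one-element list holding the label tensor
    outputs += sample_data

# Count how often each label was drawn in one epoch
[{label: outputs.count(label)} for label in sorted(set(imbalanced_data))]
# Example output (varies between runs): [{0: 10}, {1: 11}, {2: 14}]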
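
The counts are roughly equal but not identical, because one epoch contains only 35 random draws. As a rough worked check (a sketch, assuming each draw independently picks a class with probability 1/3, which follows from the equal per-class weight totals), the per-class count is binomial:

```python
import math

num_samples = 35   # len(imbalanced_data): draws in one epoch of the sampler
p_class = 1 / 3    # equal total weight per class, so each class has probability 1/3

# Mean and standard deviation of a Binomial(num_samples, p_class) count
expected = num_samples * p_class
std_dev = math.sqrt(num_samples * p_class * (1 - p_class))

print(f"expected per class: {expected:.2f}, std dev: {std_dev:.2f}")
# expected per class: 11.67, std dev: 2.79
```

The reported counts of 10, 11 and 14 all lie within one standard deviation of 11.67, so this spread is what a single small epoch would be expected to show.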

Imbalanced (without the sampler)

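outputs = []
for batch in normal_dataloader:
    sample_data = batch[0].tolist()  # batch is a one-element list holding the label tensor
    outputs += sample_data

# With shuffle=True every sample appears exactly once, so the original imbalance remains
[{label: outputs.count(label)} for label in sorted(set(imbalanced_data))]
# [{0: 5}, {1: 10}, {2: 20}]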

Source Code

Gist / Colab
