準備與資料
引入函式庫
import torch
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.data import TensorDataset, DataLoader
考量一組不平衡資料標籤
# Toy label list with a 5 : 10 : 20 class imbalance: each (label, count)
# pair contributes `count` copies of `label`, in label order.
imbalanced_data = [
    label
    for label, count in ((0, 5), (1, 10), (2, 20))
    for _ in range(count)
]
計算平衡機率
平衡機率 = 1 / label_count(label_count 為該類別的樣本數),讓每個類別被抽中的機率相等(每個類別內所有樣本的權重總和為 1)
這邊我們計算每一筆資料被抽中的平衡機率
def balance_prob(all_gold_lablel_ids):
    """Return one sampling weight per element so that every class is equally likely.

    Each element receives the weight ``1 / (number of elements sharing its
    label)``, so the weights of any single class sum to 1.

    Labels are looked up by value, so they do not have to be the contiguous
    integers ``0..k-1``; any hashable labels (sparse ids, strings, ...) work.

    Args:
        all_gold_lablel_ids: Iterable of hashable labels, one per dataset element.

    Returns:
        A list of floats, aligned with the input order, holding the weight of
        each element. An empty input yields an empty list.
    """
    # Function-scope import so the block does not rely on a file-level import.
    from collections import Counter

    # Materialize once: the labels are traversed twice (count, then map).
    label_ids = list(all_gold_lablel_ids)
    # Single pass to count each label instead of one list.count() scan per class.
    label_counts = Counter(label_ids)
    return [1.0 / label_counts[label_id] for label_id in label_ids]
# NOTE: this rebinds the name `balance_prob` from the function to its return
# value (the list of per-sample weights), so the function is no longer
# reachable under that name after this line.
balance_prob = balance_prob(imbalanced_data)
使用 Sampler 和 DataLoader
# Wrap the raw label list in a TensorDataset of int64 labels.
dataset = TensorDataset(torch.LongTensor(imbalanced_data))

# Draw as many indices as there are samples, with replacement, using the
# per-sample weights held in `balance_prob`.
sampler = WeightedRandomSampler(
    weights=balance_prob,
    num_samples=len(imbalanced_data),
    replacement=True,
)

# Two loaders over the same dataset: one driven by the weighted sampler,
# one using plain shuffling for comparison.
with_sampler_dataloader = DataLoader(dataset, sampler=sampler, batch_size=4)
normal_dataloader = DataLoader(dataset, shuffle=True, batch_size=4)
數據結果
資料平衡
# Collect every label produced by one pass over the weighted-sampler loader.
outpus = []
for batch in with_sampler_dataloader:
    # batch[0] holds the labels of this mini-batch.
    sample_data = batch[0].tolist()
    outpus += sample_data
# Per-label counts. The labels are taken from `imbalanced_data` (the original
# `unique_labels` name was never defined), so a class the sampler happened to
# miss is still reported with a count of 0.
[{label: outpus.count(label)} for label in sorted(set(imbalanced_data))]
# Example output (sampling is random, so counts vary between runs):
# [{0: 10}, {1: 11}, {2: 14}]
不平衡
# Collect every label produced by one pass over the plain shuffled loader.
outpus = []
for batch in normal_dataloader:
    # batch[0] holds the labels of this mini-batch.
    sample_data = batch[0].tolist()
    outpus += sample_data
# Per-label counts. The labels are taken from `imbalanced_data` because the
# original `unique_labels` name was never defined.
[{label: outpus.count(label)} for label in sorted(set(imbalanced_data))]
# Output: the original class imbalance is reproduced.
# [{0: 5}, {1: 10}, {2: 20}]