The previous part walked through the basic theory; this part gets into the code.
All of the code used here is available on my GitHub:
https://github.com/p208p2002/taipei-QA-BERT
Converting to the BERT input format
First, the input is converted into WordPiece ids, and the model's prediction also comes back as an id, so we need a way to convert between text and ids. Here we build one class for the questions (Q) and one for the answers (A) to handle this (in practice only the reverse conversion for A is actually used).
The makeDataset function performs the conversion required by PyTorch (which needs tensor inputs) and DataLoader (which accepts TensorDataset objects); the DataLoader then takes care of batching and shuffling for us.
```python
# core.py
import torch
from torch.utils.data import TensorDataset

def makeDataset(input_ids, input_masks, input_segment_ids, answer_lables):
    all_input_ids = torch.tensor([input_id for input_id in input_ids], dtype=torch.long)
    all_input_masks = torch.tensor([input_mask for input_mask in input_masks], dtype=torch.long)
    all_input_segment_ids = torch.tensor([input_segment_id for input_segment_id in input_segment_ids], dtype=torch.long)
    all_answer_lables = torch.tensor([answer_lable for answer_lable in answer_lables], dtype=torch.long)
    full_dataset = TensorDataset(all_input_ids, all_input_masks, all_input_segment_ids, all_answer_lables)

    # split into training and test sets (80/20)
    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
    return train_dataset, test_dataset

class AnsDic(object):
    def __init__(self, answers):
        self.answers = answers  # all answers (duplicates included)
        self.answers_norepeat = sorted(list(set(answers)))  # unique answers
        self.answers_types = len(self.answers_norepeat)  # number of classes
        self.ans_list = []  # list used to look up an id or a text
        self._make_dic()  # build the dictionary

    def _make_dic(self):
        for index_a, a in enumerate(self.answers_norepeat):
            if a is not None:
                self.ans_list.append((index_a, a))

    def to_id(self, text):
        for ans_id, ans_text in self.ans_list:
            if text == ans_text:
                return ans_id

    def to_text(self, id):
        for ans_id, ans_text in self.ans_list:
            if id == ans_id:
                return ans_text

    @property
    def types(self):
        return self.answers_types

    @property
    def data(self):
        return self.answers

    def __len__(self):
        return len(self.answers)

class QuestionDic(AnsDic):
    def __init__(self, questions):
        super().__init__(answers=questions)
```
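To make AnsDic more concrete, here is a minimal usage sketch with made-up labels (the names below are placeholders, not entries from the real dataset):

```python
from core import AnsDic

answers = ['dept_a', 'dept_b', 'dept_a']  # 3 answers, 2 unique classes
ans_dic = AnsDic(answers)

print(len(ans_dic))             # 3 -> all answers, duplicates included
print(ans_dic.types)            # 2 -> number of classes
print(ans_dic.to_id('dept_b'))  # 1 -> ids follow the sorted unique order
print(ans_dic.to_text(0))       # 'dept_a'
```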
After the input is converted into WordPiece ids we have the Token Embeddings; next we prepare the Segment Embeddings and the attention masks.
```python
# preprocess_data.py
import os
import pickle
from core import AnsDic, QuestionDic
from transformers import BertTokenizer

def make_ans_dic(answers):
    ansdic = AnsDic(answers)
    print("全部答案:", len(ansdic))
    print("全部答案種類:", ansdic.types)
    # sanity-check the text <-> id conversion
    a_id = ansdic.to_id('臺北市信義區公所')
    a_text = ansdic.to_text(a_id)
    assert a_text == '臺北市信義區公所'
    assert ansdic.to_id(a_text) == a_id
    return ansdic

def make_question_dic(questions):
    return QuestionDic(questions)

def convert_data_to_feature():
    with open('Taipei_QA_new.txt', 'r', encoding='utf-8') as f:
        data = f.read()
    qa_pairs = data.split("\n")

    questions = []
    answers = []
    for qa_pair in qa_pairs:
        qa_pair = qa_pair.split()
        try:
            a, q = qa_pair
            questions.append(q)
            answers.append(a)
        except:
            continue
    assert len(answers) == len(questions)

    ans_dic = make_ans_dic(answers)
    question_dic = make_question_dic(questions)
    tokenizer = BertTokenizer(vocab_file='bert-base-chinese-vocab.txt')

    q_tokens = []
    max_seq_len = 0
    for q in question_dic.data:
        bert_ids = tokenizer.build_inputs_with_special_tokens(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(q)))
        if len(bert_ids) > max_seq_len:
            max_seq_len = len(bert_ids)
        q_tokens.append(bert_ids)
    print("最長問句長度:", max_seq_len)
    assert max_seq_len <= 512  # stay within the BERT-base length limit

    # pad every question to max_seq_len
    for q in q_tokens:
        while len(q) < max_seq_len:
            q.append(0)

    a_labels = []
    for a in ans_dic.data:
        a_labels.append(ans_dic.to_id(a))

    # BERT input embeddings
    answer_lables = a_labels
    input_ids = q_tokens
    input_masks = [[1] * max_seq_len for i in range(len(question_dic))]
    input_segment_ids = [[0] * max_seq_len for i in range(len(question_dic))]
    assert len(input_ids) == len(question_dic) and len(input_ids) == len(input_masks) and len(input_ids) == len(input_segment_ids)

    data_features = {'input_ids': input_ids,
                     'input_masks': input_masks,
                     'input_segment_ids': input_segment_ids,
                     'answer_lables': answer_lables,
                     'question_dic': question_dic,
                     'answer_dic': ans_dic}

    os.makedirs('trained_model', exist_ok=True)  # make sure the output directory exists
    with open('trained_model/data_features.pkl', 'wb') as output:
        pickle.dump(data_features, output)
    return data_features

if __name__ == "__main__":
    feature = convert_data_to_feature()
```
input_segment_ids corresponds to the Segment Embeddings (all zeros here, since each example is a single sentence).
input_masks is the attention mask, marking which positions are real tokens; the Position Embeddings are added inside BERT itself, so we do not need to build them ourselves.
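To see what these lists look like for a single question, here is a minimal sketch (the question string and the padded length are made up for illustration; it assumes bert-base-chinese-vocab.txt from the repo is present):

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer(vocab_file='bert-base-chinese-vocab.txt')
max_seq_len = 10  # hypothetical padded length, just for this illustration

q = '垃圾車幾點來'  # hypothetical question
ids = tokenizer.build_inputs_with_special_tokens(
    tokenizer.convert_tokens_to_ids(tokenizer.tokenize(q)))
print(tokenizer.convert_ids_to_tokens(ids))  # ['[CLS]', '垃', '圾', '車', '幾', '點', '來', '[SEP]']

ids = ids + [0] * (max_seq_len - len(ids))  # pad the WordPiece ids with 0
input_mask = [1] * max_seq_len              # this repo keeps every position as 1
segment_ids = [0] * max_seq_len             # single sentence -> all segment ids are 0
```

Note that the preprocessing above sets every position of input_masks to 1, including padding; a stricter attention mask would set the padded positions to 0, but we keep the original behaviour here.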
Note that data_features is saved to disk so that, at prediction time, answer_dic can be loaded back to convert the predicted label id into its answer text.
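Roughly, that later step looks like the sketch below (the label id is a hypothetical value; core.py must be importable so pickle can rebuild the AnsDic instance):

```python
import pickle

# reload the features saved by preprocess_data.py
with open('trained_model/data_features.pkl', 'rb') as pkl_file:
    data_features = pickle.load(pkl_file)
answer_dic = data_features['answer_dic']

predicted_label = 5  # hypothetical label id coming out of the model
print(answer_dic.to_text(predicted_label))  # the corresponding answer text
```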
At this point the conversion to BERT's input format is done, and we can feed the data into BERT for fine-tuning.
BERT Fine-Tune
From here on it is pretty much a standard PyTorch training setup, so let's go straight to the code.
```python
# train.py
from preprocess_data import convert_data_to_feature
from core import makeDataset
from torch.utils.data import DataLoader
from transformers import BertConfig, BertForSequenceClassification, BertTokenizer, AdamW
import torch

def compute_accuracy(y_pred, y_target):
    _, y_pred_indices = y_pred.max(dim=1)
    n_correct = torch.eq(y_pred_indices, y_target).sum().item()
    return n_correct / len(y_pred_indices) * 100

if __name__ == "__main__":
    bert_config, bert_class, bert_tokenizer = (BertConfig, BertForSequenceClassification, BertTokenizer)

    # setting device
    device = torch.device('cuda')

    data_feature = convert_data_to_feature()
    input_ids = data_feature['input_ids']
    input_masks = data_feature['input_masks']
    input_segment_ids = data_feature['input_segment_ids']
    answer_lables = data_feature['answer_lables']

    train_dataset, test_dataset = makeDataset(input_ids=input_ids, input_masks=input_masks, input_segment_ids=input_segment_ids, answer_lables=answer_lables)
    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True)

    config = bert_config.from_pretrained('bert-base-chinese', num_labels=149)
    model = bert_class.from_pretrained('bert-base-chinese', from_tf=bool('.ckpt' in 'bert-base-chinese'), config=config)
    model.to(device)

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=5e-6, eps=1e-8)
    # scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

    model.zero_grad()
    for epoch in range(30):
        running_loss_val = 0.0
        running_acc = 0.0
        for batch_index, batch_dict in enumerate(train_dataloader):
            model.train()
            batch_dict = tuple(t.to(device) for t in batch_dict)
            outputs = model(
                batch_dict[0],
                # attention_mask=batch_dict[1],
                labels=batch_dict[3]
            )
            loss, logits = outputs[:2]
            loss.sum().backward()
            optimizer.step()
            # scheduler.step()  # Update learning rate schedule
            model.zero_grad()

            # compute the loss
            loss_t = loss.item()
            running_loss_val += (loss_t - running_loss_val) / (batch_index + 1)

            # compute the accuracy
            acc_t = compute_accuracy(logits, batch_dict[3])
            running_acc += (acc_t - running_acc) / (batch_index + 1)

            # log
            print("epoch:%2d batch:%4d train_loss:%2.4f train_acc:%3.4f" % (epoch + 1, batch_index + 1, running_loss_val, running_acc))

        running_loss_val = 0.0
        running_acc = 0.0
        for batch_index, batch_dict in enumerate(test_dataloader):
            model.eval()
            batch_dict = tuple(t.to(device) for t in batch_dict)
            outputs = model(
                batch_dict[0],
                # attention_mask=batch_dict[1],
                labels=batch_dict[3]
            )
            loss, logits = outputs[:2]

            # compute the loss
            loss_t = loss.item()
            running_loss_val += (loss_t - running_loss_val) / (batch_index + 1)

            # compute the accuracy
            acc_t = compute_accuracy(logits, batch_dict[3])
            running_acc += (acc_t - running_acc) / (batch_index + 1)

            # log
            print("epoch:%2d batch:%4d test_loss:%2.4f test_acc:%3.4f" % (epoch + 1, batch_index + 1, running_loss_val, running_acc))

    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
    model_to_save.save_pretrained('trained_model')
```
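A small side note: save_pretrained writes config.json and pytorch_model.bin into trained_model/, so the fine-tuned model can also be reloaded by pointing from_pretrained at that directory (a minimal sketch, equivalent to what predict.py below does with the individual files):

```python
from transformers import BertForSequenceClassification

# load the fine-tuned config and weights from the output directory
model = BertForSequenceClassification.from_pretrained('trained_model')
model.eval()
```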
Prediction
Convert the input into BERT's input format, feed it through BERT, and finally map the predicted id back to its corresponding text.
```python
# predict.py
import pickle
import torch
from transformers import BertConfig, BertForSequenceClassification, BertTokenizer

def toBertIds(q_input):
    return tokenizer.build_inputs_with_special_tokens(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(q_input)))

if __name__ == "__main__":
    # load and init
    tokenizer = BertTokenizer(vocab_file='bert-base-chinese-vocab.txt')
    pkl_file = open('trained_model/data_features.pkl', 'rb')
    data_features = pickle.load(pkl_file)
    answer_dic = data_features['answer_dic']

    bert_config, bert_class, bert_tokenizer = (BertConfig, BertForSequenceClassification, BertTokenizer)
    config = bert_config.from_pretrained('trained_model/config.json')
    model = bert_class.from_pretrained('trained_model/pytorch_model.bin', from_tf=bool('.ckpt' in 'bert-base-chinese'), config=config)
    model.eval()

    q_inputs = ['為何路邊停車格有編號的要收費,無編號的不用收費', '債權人可否向稅捐稽徵處申請查調債務人之財產、所得資料', '想做大腸癌篩檢,不知如何辨理']
    for q_input in q_inputs:
        bert_ids = toBertIds(q_input)
        assert len(bert_ids) <= 512
        input_ids = torch.LongTensor(bert_ids).unsqueeze(0)

        # predict
        outputs = model(input_ids)
        predicts = outputs[:2]
        predicts = predicts[0]
        max_val = torch.max(predicts)
        label = (predicts == max_val).nonzero().numpy()[0][1]
        ans_label = answer_dic.to_text(label)
        print(q_input)
        print(ans_label)
        print()
```
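As a side note, the max/nonzero trick at the end can be written more directly with torch.argmax, and wrapping the forward pass in torch.no_grad() avoids building the autograd graph during inference. A minimal alternative sketch, reusing model, input_ids and answer_dic from above:

```python
import torch

with torch.no_grad():           # no gradients needed at inference time
    outputs = model(input_ids)
logits = outputs[0]             # shape: (1, num_labels)
label = torch.argmax(logits, dim=1).item()
print(answer_dic.to_text(label))
```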