BERT QA Bot in Practice – Part 2

Part 1 briefly covered the theory; in this part we get into the code.

All of the code used here is available on my GitHub:

https://github.com/p208p2002/taipei-QA-BERT

Converting to the BERT Input Format

First, the input is converted into WordPiece ids, and the model's prediction likewise comes back as an id that has to be mapped back to text, so we need to handle the conversion between text and ids in both directions. Here we create one class each for Q and A to deal with this (in practice, only A ever needs the reverse, id-to-text direction).
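As a quick sketch of that round trip (assuming the bert-base-chinese-vocab.txt file that ships with the repo; the sample sentence is arbitrary):

from transformers import BertTokenizer

tokenizer = BertTokenizer(vocab_file='bert-base-chinese-vocab.txt')
tokens = tokenizer.tokenize('路邊停車格如何收費')      # text -> WordPiece tokens
ids = tokenizer.convert_tokens_to_ids(tokens)          # tokens -> WordPiece ids
assert tokenizer.convert_ids_to_tokens(ids) == tokens  # ids map back to the same tokens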

The makeDataset function performs the conversion into the formats PyTorch (which requires tensor inputs) and DataLoader (which accepts TensorDataset objects) expect; DataLoader then takes care of batching, shuffling, and so on for us.

# core.py
import torch
from torch.utils.data import TensorDataset, random_split

def makeDataset(input_ids, input_masks, input_segment_ids, answer_lables):
    # lists of lists convert directly; no per-element comprehension is needed
    all_input_ids = torch.tensor(input_ids, dtype=torch.long)
    all_input_masks = torch.tensor(input_masks, dtype=torch.long)
    all_input_segment_ids = torch.tensor(input_segment_ids, dtype=torch.long)
    all_answer_lables = torch.tensor(answer_lables, dtype=torch.long)
    
    full_dataset = TensorDataset(all_input_ids, all_input_masks, all_input_segment_ids, all_answer_lables)
    
    # split into training and test sets (80/20)
    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

    return train_dataset, test_dataset
    

class AnsDic(object):
    def __init__(self, answers):
        self.answers = answers # all answers (duplicates included)
        self.answers_norepeat = sorted(list(set(answers))) # deduplicated
        self.answers_types = len(self.answers_norepeat) # number of answer classes
        self.ans_list = [] # (id, text) pairs used for lookups in both directions
        self._make_dic() # build the mapping
    
    def _make_dic(self):
        for index_a, a in enumerate(self.answers_norepeat):
            if a is not None:
                self.ans_list.append((index_a, a))

    def to_id(self, text):
        for ans_id, ans_text in self.ans_list:
            if text == ans_text:
                return ans_id

    def to_text(self, id):
        for ans_id, ans_text in self.ans_list:
            if id == ans_id:
                return ans_text

    @property
    def types(self):
        return self.answers_types
    
    @property
    def data(self):
        return self.answers

    def __len__(self):
        return len(self.answers)

class QuestionDic(AnsDic):
    def __init__(self, questions):
        super().__init__(answers = questions)
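For a quick feel of how the dictionary behaves, here is a hypothetical mini answer set (made-up data, not from the real dataset):

from core import AnsDic

dic = AnsDic(['臺北市政府衛生局', '臺北市政府工務局', '臺北市政府衛生局'])
assert len(dic) == 3   # all answers, duplicates included
assert dic.types == 2  # two distinct answer classes
assert dic.to_text(dic.to_id('臺北市政府衛生局')) == '臺北市政府衛生局'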

Converting the input into WordPiece ids gives us the input for the Token Embeddings; next we prepare the remaining inputs, the segment ids (for the Segment Embeddings) and the input masks.

# preprocess_data.py
from core import AnsDic, QuestionDic
from transformers import BertTokenizer
import pickle
import os

def make_ans_dic(answers):
    ansdic = AnsDic(answers)
    print("total answers:", len(ansdic))
    print("answer classes:", ansdic.types)
    
    # sanity-check the text <-> id conversion
    a_id = ansdic.to_id('臺北市信義區公所')
    a_text = ansdic.to_text(a_id)
    assert a_text == '臺北市信義區公所'
    assert ansdic.to_id(a_text) == a_id

    return ansdic

def make_question_dic(questions):
    return QuestionDic(questions)
    

def convert_data_to_feature():
    with open('Taipei_QA_new.txt','r',encoding='utf-8') as f:
        data = f.read()
    qa_pairs = data.split("\n")

    questions = []
    answers = []
    for qa_pair in qa_pairs:
        qa_pair = qa_pair.split()
        try:
            a,q = qa_pair
            questions.append(q)
            answers.append(a)
        except ValueError: # skip lines that do not split into exactly (answer, question)
            continue
    
    assert len(answers) == len(questions)
    
    ans_dic = make_ans_dic(answers)
    question_dic = make_question_dic(questions)
    
    tokenizer = BertTokenizer(vocab_file='bert-base-chinese-vocab.txt')

    q_tokens = []
    max_seq_len = 0
    for q in question_dic.data:
        bert_ids = tokenizer.build_inputs_with_special_tokens(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(q)))
        if len(bert_ids) > max_seq_len:
            max_seq_len = len(bert_ids)
        q_tokens.append(bert_ids)
        # print(tokenizer.convert_ids_to_tokens(tokenizer.build_inputs_with_special_tokens(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(q)))))
    
    print("最長問句長度:",max_seq_len)
    assert max_seq_len <= 512 # 小於BERT-base長度限制

    # pad each question to max_seq_len (id 0 is [PAD] in the BERT vocab)
    for q in q_tokens:
        while len(q)<max_seq_len:
            q.append(0)
    
    a_labels = []
    for a in ans_dic.data:
        a_labels.append(ans_dic.to_id(a))
        # print (ans_dic.to_id(a))
    
    # BERT input embeddings
    answer_lables = a_labels
    input_ids = q_tokens
    # attention mask: 1 for real tokens, 0 for padding ([PAD] has id 0, and no real token maps to 0)
    input_masks = [[1 if token_id != 0 else 0 for token_id in q] for q in q_tokens]
    input_segment_ids = [[0]*max_seq_len for i in range(len(question_dic))]
    assert len(input_ids) == len(question_dic) and len(input_ids) == len(input_masks) and len(input_ids) == len(input_segment_ids)

    data_features = {'input_ids':input_ids,
                    'input_masks':input_masks,
                    'input_segment_ids':input_segment_ids,
                    'answer_lables':answer_lables,
                    'question_dic':question_dic,
                    'answer_dic':ans_dic}
    
    os.makedirs('trained_model', exist_ok=True) # make sure the output directory exists
    with open('trained_model/data_features.pkl', 'wb') as output:
        pickle.dump(data_features, output)
    return data_features


if __name__ == "__main__":
    feature = convert_data_to_feature()

input_segment_ids corresponds to the Segment Embeddings; it is all zeros here because each input is a single sentence.

input_masks is the attention mask, marking real tokens with 1 and padding with 0 (the Position Embeddings need no extra input; BERT generates them internally from the token positions).
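As a toy illustration of what the three lists look like for a batch of two questions padded to length 6 (the middle ids are placeholders; 101 = [CLS], 102 = [SEP], and 0 = [PAD] in the bert-base-chinese vocab):

# hypothetical values for illustration only
input_ids         = [[101, 800, 801, 802, 102, 0],
                     [101, 900, 901, 102, 0, 0]]
input_masks       = [[1, 1, 1, 1, 1, 0],   # 1 = real token, 0 = padding
                     [1, 1, 1, 1, 0, 0]]
input_segment_ids = [[0, 0, 0, 0, 0, 0],   # all 0: single-sentence inputs
                     [0, 0, 0, 0, 0, 0]]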

Note that data_features is saved so that answer_dic can be loaded back at prediction time to convert the predicted id into its text.

At this point the conversion to the BERT input format is complete, and we can feed the data into BERT for fine-tuning.

BERT Fine-Tune

From here on it looks much like any ordinary PyTorch training setup, so let's go straight to the code.

# train.py
from preprocess_data import convert_data_to_feature
from core import makeDataset
from torch.utils.data import DataLoader
from transformers import BertConfig, BertForSequenceClassification, BertTokenizer, AdamW
import torch

def compute_accuracy(y_pred, y_target):
    _, y_pred_indices = y_pred.max(dim=1)
    n_correct = torch.eq(y_pred_indices, y_target).sum().item()
    return n_correct / len(y_pred_indices) * 100

if __name__ == "__main__":
    bert_config, bert_class, bert_tokenizer = (BertConfig, BertForSequenceClassification, BertTokenizer)
    
    # setting device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # fall back to CPU when no GPU is present

    data_feature = convert_data_to_feature()
    input_ids = data_feature['input_ids']
    input_masks = data_feature['input_masks']
    input_segment_ids = data_feature['input_segment_ids']
    answer_lables = data_feature['answer_lables']
    
    train_dataset, test_dataset = makeDataset(input_ids = input_ids, input_masks = input_masks, input_segment_ids = input_segment_ids, answer_lables = answer_lables)
    train_dataloader = DataLoader(train_dataset,batch_size=16,shuffle=True)
    test_dataloader = DataLoader(test_dataset,batch_size=16,shuffle=True)

    config = bert_config.from_pretrained('bert-base-chinese', num_labels = data_feature['answer_dic'].types) # one label per answer class (149 for this dataset)
    model = bert_class.from_pretrained('bert-base-chinese', from_tf=bool('.ckpt' in 'bert-base-chinese'), config=config)
    model.to(device)

    # Prepare optimizer and schedule (linear warmup and decay)
    # note: with weight_decay = 0.0 in both groups this split is a no-op;
    # set the first group to e.g. 0.01 to actually apply weight decay
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=5e-6, eps=1e-8)
    # scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

    model.zero_grad()
    for epoch in range(30):
        running_loss_val = 0.0
        running_acc = 0.0
        for batch_index, batch_dict in enumerate(train_dataloader):
            model.train()
            batch_dict = tuple(t.to(device) for t in batch_dict)
            outputs = model(
                batch_dict[0],
                # attention_mask=batch_dict[1],
                labels = batch_dict[3]
                )
            loss,logits = outputs[:2]
            loss.backward() # loss is already a scalar on a single device; .sum() is only needed for DataParallel
            optimizer.step()
            # scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            
            # compute the loss
            loss_t = loss.item()
            running_loss_val += (loss_t - running_loss_val) / (batch_index + 1)

            # compute the accuracy
            acc_t = compute_accuracy(logits, batch_dict[3])
            running_acc += (acc_t - running_acc) / (batch_index + 1)

            # log
            print("epoch:%2d batch:%4d train_loss:%2.4f train_acc:%3.4f"%(epoch+1, batch_index+1, running_loss_val, running_acc))
        
        running_loss_val = 0.0
        running_acc = 0.0
        model.eval()
        with torch.no_grad(): # no gradients needed during evaluation
            for batch_index, batch_dict in enumerate(test_dataloader):
                batch_dict = tuple(t.to(device) for t in batch_dict)
                outputs = model(
                    batch_dict[0],
                    # attention_mask=batch_dict[1],
                    labels = batch_dict[3]
                    )
                loss, logits = outputs[:2]
                
                # compute the loss
                loss_t = loss.item()
                running_loss_val += (loss_t - running_loss_val) / (batch_index + 1)

                # compute the accuracy
                acc_t = compute_accuracy(logits, batch_dict[3])
                running_acc += (acc_t - running_acc) / (batch_index + 1)

                # log
                print("epoch:%2d batch:%4d test_loss:%2.4f test_acc:%3.4f"%(epoch+1, batch_index+1, running_loss_val, running_acc))
    
    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
    model_to_save.save_pretrained('trained_model')
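A side note on the commented-out scheduler: WarmupLinearSchedule comes from the older pytorch-transformers package. With the transformers package used here, a linear warmup-and-decay schedule could be sketched as follows (the warmup_steps value is an arbitrary example, and availability depends on your transformers version):

from transformers import get_linear_schedule_with_warmup

t_total = len(train_dataloader) * 30  # total optimization steps over 30 epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=t_total)
# then call scheduler.step() right after each optimizer.step()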

Predicting Results

Convert the input into the BERT input format, feed it into BERT, and finally map the predicted id back to its corresponding text.

# predict.py
from transformers import BertConfig, BertForSequenceClassification, BertTokenizer
import torch
import pickle

def toBertIds(q_input):
    return tokenizer.build_inputs_with_special_tokens(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(q_input)))

if __name__ == "__main__":
    # load and init
    tokenizer = BertTokenizer(vocab_file='bert-base-chinese-vocab.txt')
    pkl_file = open('trained_model/data_features.pkl', 'rb')
    data_features = pickle.load(pkl_file)
    answer_dic = data_features['answer_dic']
    
    bert_config, bert_class, bert_tokenizer = (BertConfig, BertForSequenceClassification, BertTokenizer)
    config = bert_config.from_pretrained('trained_model/config.json')
    model = bert_class.from_pretrained('trained_model/pytorch_model.bin', from_tf=False, config=config)
    model.eval()

    # sample questions to classify
    q_inputs = ['為何路邊停車格有編號的要收費,無編號的不用收費','債權人可否向稅捐稽徵處申請查調債務人之財產、所得資料','想做大腸癌篩檢,不知如何辨理']
    for q_input in q_inputs:
        bert_ids = toBertIds(q_input)
        assert len(bert_ids) <= 512
        input_ids = torch.LongTensor(bert_ids).unsqueeze(0)

        # predict
        with torch.no_grad():
            outputs = model(input_ids)
        logits = outputs[0] # without labels the model returns only the logits
        label = torch.argmax(logits, dim=1).item() # index of the highest-scoring class
        ans_label = answer_dic.to_text(label)
        
        print(q_input)
        print(ans_label)
        print()
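Incidentally, since save_pretrained('trained_model') wrote both config.json and pytorch_model.bin into that directory, the model can also be reloaded from the directory in a single call; this is a small simplification of the loading code above, not what the original script does:

model = BertForSequenceClassification.from_pretrained('trained_model')
model.eval()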
