BERT (3): Classification in Practice


    1. Introduction

    Required data and the four source files we need to write (utils.py, bert.py, Train_eval.py, run.py). class.txt lists one class name per line. Each line of train.txt holds the text content and its label, separated by a tab character. dev.txt and test.txt are the validation and test sets, in the same format as train.txt.
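    For illustration only (these file contents are made up and are not the author's actual dataset), the files could look like this, with <TAB> standing for a real tab character:

    class.txt (one class name per line; the line number, starting from 0, is the label id):
        finance
        sports
        science

    train.txt / dev.txt / test.txt (text<TAB>label id):
        股市今日大幅上涨<TAB>0
        国足公布最新集训名单<TAB>1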

    2. Code

    2.1 utils.py

    This module loads the data and builds mini-batches for training.

import torch
#from tqdm import tqdm
# timing utilities
import time
from datetime import timedelta
from torch.utils.data import Dataset, DataLoader
from pytorch_pretrained_bert import BertTokenizer

CLS = "[CLS]"
PAD = "[PAD]"

# Path to a small sample file, mainly used to verify that the data-loading code is correct
data_path = r"K:\NLP\cs224n_2019\cs224n_2019\pytorch-bert-code\bert分类模型·\data\aaa.txt"
# Load the BERT vocabulary
bert_vocab_path = r"D:\Bert\bert-base-chinese\vocab.txt"
bert_vocab = BertTokenizer(bert_vocab_path, do_lower_case=True)


class BertDataset(Dataset):
    def __init__(self, path, pad_size=32):
        self.max_seq_len = pad_size  # fixed max length so sequences can be batched
        self.contents = []
        with open(path, "r", encoding="UTF-8") as f:
            for line in f:           # read the file line by line (str)
                line = line.strip()
                if not line:         # skip empty lines
                    continue
                content, label = line.split("\t")  # text and label are separated by a tab
                # tokenize
                tokenized_content = bert_vocab.tokenize(content)  # list
                tokenized_content = [CLS] + tokenized_content + [PAD]
                # convert tokens to ids
                tokenized_content_ids = bert_vocab.convert_tokens_to_ids(tokenized_content)  # list
                # handle padding / truncation
                input_seq_len = len(tokenized_content)
                if self.max_seq_len:
                    if input_seq_len <= self.max_seq_len:
                        # padding
                        tokenized_content_ids += [0] * (self.max_seq_len - input_seq_len)
                        input_mask = [1] * input_seq_len + [0] * (self.max_seq_len - input_seq_len)
                    else:
                        # truncation
                        tokenized_content_ids = tokenized_content_ids[:self.max_seq_len]  # list
                        input_mask = [1] * self.max_seq_len                               # list
                        input_seq_len = self.max_seq_len                                  # int
                self.contents.append((tokenized_content_ids, int(label), input_seq_len, input_mask))
        # print("contents: ", len(self.contents))  # number of sentences

    def __getitem__(self, item):
        contents_tensor = self.list_to_tensor(self.contents, item)  # ((x, input_seq_len, input_mask), y)
        # content is a tuple with 3 elements: x, input_seq_len and input_mask;
        # after batching by the DataLoader each of them gains a batch dimension
        content = contents_tensor[0]
        label = contents_tensor[1]  # y
        return content, label

    def __len__(self):
        return len(self.contents)

    @staticmethod
    def list_to_tensor(contents, idx):
        '''Convert one sample from plain lists into LongTensors'''
        x = torch.LongTensor([contents[idx][0]])
        # print("x", x.size())                    # 1 * max_seq_len
        y = torch.LongTensor([contents[idx][1]])
        input_seq_len = torch.LongTensor([contents[idx][2]])
        # print("input_seq_len", input_seq_len)   # e.g. tensor([18])
        input_mask = torch.LongTensor([contents[idx][3]])
        # print("input_mask", input_mask.size())  # 1 * max_seq_len
        return (x, input_seq_len, input_mask), y


# Wrap the dataset in a DataLoader
# drop_last=True discards the final batch when it is smaller than batch_size
Bert_dataloader = DataLoader(BertDataset(data_path), batch_size=2, shuffle=True, drop_last=True)

if __name__ == "__main__":
    bertdata = BertDataset(data_path)
    bertdata.list_to_tensor(bertdata.contents, 0)
    for key, (content, label) in enumerate(Bert_dataloader):
        print("key: ", key)
        print("content: ", content, type(content), len(content))  # content is a list, len(content): 3
        print("content[0]: ", content[0].size())    # [2, 1, 32]
        # print("content[1]: ", content[1].size())  # [2, 1]
        # print("content[2]: ", content[2].size())  # [2, 1, 32]
        print("label: ", label, label.size())       # e.g. tensor([[3], [2]]), shape 2 x 1
        break
    print(len(bertdata))          # 19 sentences
    print(len(Bert_dataloader))   # 9 batches; the last sample is dropped
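    To make the padding and mask construction above concrete, here is a minimal standalone sketch. It uses hand-made token ids instead of BertTokenizer, so it runs without the vocabulary file, and pad_size is shortened to 8:

import torch

max_seq_len = 8
# pretend these ids came from bert_vocab.convert_tokens_to_ids([CLS] + tokens + [PAD]); values are made up
tokenized_content_ids = [101, 2769, 4263, 872, 0]   # 5 ids
input_seq_len = len(tokenized_content_ids)

if input_seq_len <= max_seq_len:
    # pad with 0 ([PAD]) up to max_seq_len; the mask marks real tokens with 1
    tokenized_content_ids += [0] * (max_seq_len - input_seq_len)
    input_mask = [1] * input_seq_len + [0] * (max_seq_len - input_seq_len)
else:
    # truncate to max_seq_len
    tokenized_content_ids = tokenized_content_ids[:max_seq_len]
    input_mask = [1] * max_seq_len
    input_seq_len = max_seq_len

x = torch.LongTensor([tokenized_content_ids])  # shape [1, 8]
mask = torch.LongTensor([input_mask])          # shape [1, 8]
print(x)     # tensor([[ 101, 2769, 4263,  872,    0,    0,    0,    0]])
print(mask)  # tensor([[1, 1, 1, 1, 1, 0, 0, 0]])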

    2.2 bert.py

    This module only needs to add a linear classifier on top of the pre-trained BERT model. As explained in BERT (1): Introduction and Usage, classification just requires feeding pooled_output, i.e. the vector that BERT's last layer outputs at the [CLS] position, into the linear layer.

import torch
import torch.nn as nn
from pytorch_pretrained_bert import BertTokenizer, BertConfig, BertModel
import os

path = r"K:\NLP\cs224n_2019\cs224n_2019\pytorch-bert-code\bert分类模型·\data"


class Config():
    '''Configuration parameters'''
    def __init__(self, path):
        self.bert_config_path = r"D:\Bert\bert-base-chinese\bert_config.json"
        self.bert_model_path = r"D:\Bert\bert-base-chinese"
        self.bert_vocab_path = r"D:\Bert\bert-base-chinese\vocab.txt"
        # training data path
        self.train_data_path = path + r"\train.txt"
        # validation set path
        self.dev_data_path = path + r"\dev.txt"
        # test set path
        self.test_data_path = path + r"\test.txt"
        # list of class names
        self.class_list = [x.strip() for x in open(path + r"\class.txt").readlines()]
        # where to save the fine-tuned model
        self.trained_bert_path = path + "\\saved_dict\\" + "Bert" + ".ckpt"
        # stop training early if there is no improvement for more than 1000 batches
        self.require_improvement = 1000
        # number of classes
        self.num_classes = len(self.class_list)
        # number of epochs
        self.num_epochs = 3
        # batch size
        self.batch_size = 128
        # maximum sentence length
        self.max_seq_len = 32
        # learning rate
        self.lr = 0.0001
        # hidden size of BERT's output (768 for bert-base)
        self.hidden_size = 768
        # load the Chinese BERT config and vocabulary
        self.bert_config = BertConfig.from_json_file(self.bert_config_path)
        self.bert_vocab = BertTokenizer(self.bert_vocab_path, do_lower_case=True)
        # load the pre-trained BERT model
        self.bert_model = BertModel.from_pretrained(self.bert_model_path)
        # print("bert parm:", self.bert_model.parameter())  # this errors: the method is parameters(), not parameter()


class MyBertModel(nn.Module):
    '''BERT plus a linear classification head'''
    def __init__(self, config):
        super().__init__()
        self.bert = config.bert_model  # BertModel inherits from nn.Module, so it has parameters()
        # print("bert1 param: ", self.bert.parameters())
        # <generator object Module.parameters at 0x...>  -- parameters() is a generator, iterate it to see the values
        for param in self.bert.parameters():
            param.requires_grad = True  # fine-tune the BERT parameters (the attribute is requires_grad, not required_grad)
        # the key linear layer
        self.fc = nn.Linear(in_features=config.hidden_size, out_features=config.num_classes)

    # nn.Module requires a forward function; it is invoked via my_bert_model(x)
    def forward(self, x):
        # x -> (tokenized_content_ids, input_seq_len, input_mask)
        # x is a list with 3 elements
        # print("x[0].size: ", x[0].size())  # batch x 1 x 32 (e.g. 128 x 1 x 32)
        input_ids = x[0].squeeze_()
        input_mask = x[2].squeeze_()
        all_encoder_layer, pooled_output = self.bert(input_ids, input_mask)
        # all_encoder_layer: a list with the output of every transformer encoder layer
        #   each layer's output has shape [batch, seq_length, hidden_size]
        #   (12 layers for bert-base, 24 for bert-large)
        # print("all_encoder_layer:\n ", type(all_encoder_layer), len(all_encoder_layer), all_encoder_layer[0].size())
        #   <class 'list'> 12 torch.Size([128, 32, 768])
        # pooled_output: the hidden state at the first token [CLS] of the last encoder layer,
        #   shape [batch, hidden_size]; it can be treated as a summary of the input sentence
        # print("pooled_output: \n", type(pooled_output), len(pooled_output), pooled_output.size())
        #   <class 'torch.Tensor'> 128 torch.Size([128, 768])
        # for classification we only need pooled_output
        out = self.fc(pooled_output)
        return out


if __name__ == "__main__":
    config = Config(path)
    my_bert_model = MyBertModel(config)
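    As a quick sanity check of the classification head on its own, the following standalone sketch feeds a random tensor in place of BERT's pooled_output through a linear layer (hidden_size = 768 and num_classes = 10 are example values matching the Config above):

import torch
import torch.nn as nn

batch_size, hidden_size, num_classes = 4, 768, 10
pooled_output = torch.randn(batch_size, hidden_size)   # stand-in for BERT's [CLS] output
fc = nn.Linear(in_features=hidden_size, out_features=num_classes)
out = fc(pooled_output)
print(out.size())            # torch.Size([4, 10])
print(torch.max(out, 1)[1])  # predicted class index for each sample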

    2.3 Train_eval.py

    This module trains and evaluates the BERT classification model.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
# AdamW (Adam with decoupled weight decay) usually works better than plain Adam here;
# it is imported from the transformers package because older PyTorch versions do not ship it
from transformers.optimization import AdamW


def get_time(start_time):
    end_time = time.time()
    time_df = end_time - start_time
    return time_df


# Weight initialization, xavier by default
def weight_init(model, method="xavier", exclude="embedding", seed=111):
    # print("model: ", [n for n, p in list(model.named_parameters())])
    # The commented-out print above lists the parameter names of the network; for bert-base they are:
    #   bert.embeddings.word_embeddings.weight, bert.embeddings.position_embeddings.weight,
    #   bert.embeddings.token_type_embeddings.weight, bert.embeddings.LayerNorm.{weight,bias},
    #   bert.encoder.layer.0.attention.self.{query,key,value}.{weight,bias},
    #   bert.encoder.layer.0.attention.output.dense.{weight,bias},
    #   bert.encoder.layer.0.attention.output.LayerNorm.{weight,bias},
    #   bert.encoder.layer.0.intermediate.dense.{weight,bias},
    #   bert.encoder.layer.0.output.dense.{weight,bias},
    #   bert.encoder.layer.0.output.LayerNorm.{weight,bias},
    #   ... the same pattern repeats for encoder layers 1 to 11 ...
    #   bert.pooler.dense.{weight,bias}, fc.weight, fc.bias
    for name, params in model.named_parameters():
        if exclude not in name:
            if len(params.size()) < 2:
                # tensors with fewer than 2 dimensions are not initialized
                continue
            if "weight" in name:
                if method == "xavier":
                    nn.init.xavier_normal_(params)
                elif method == "Kaiming":
                    nn.init.kaiming_normal_(params)
                else:
                    nn.init.normal_(params)
            elif "bias" in name:
                nn.init.constant_(params, 0)  # initialize biases with the constant 0
            else:
                pass


def train(config, my_bert_model, train_iter, dev_iter, test_iter):
    start_time = time.time()
    # put the model into training mode
    my_bert_model.train()
    params = list(my_bert_model.named_parameters())
    no_decay = ["bias", "LayerNorm", "LayerNorm.weight"]
    # nd in n returns True/False; any() is True as soon as one element is True
    # weight_decay is the L2 regularization coefficient lambda
    optimizer_grouped_params = [
        {"params": [p for n, p in params if not any(nd in n for nd in no_decay)],  # names containing none of the no_decay strings
         "weight_decay": 0.01},
        {"params": [p for n, p in params if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0},
    ]
    # optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    # (alternatively, skip optimizer_grouped_params and use PyTorch's built-in optimizer directly)
    optimizer = AdamW(optimizer_grouped_params, lr=config.lr)

    # how many batches have been processed so far
    total_batch = 0
    # best validation loss seen so far
    dev_best_loss = float("inf")
    # the batch index at which the validation loss last decreased
    last_batch = 0
    # whether training has gone a long time without improvement
    flag = False

    # start training
    my_bert_model.train()
    # NOTE: weight_init re-initializes every non-embedding parameter, including the pre-trained
    # BERT encoder weights; when fine-tuning a pre-trained model this call is usually omitted
    weight_init(my_bert_model)
    for epoch in range(config.num_epochs):
        print("Epoch [{}/{}]".format(epoch + 1, config.num_epochs))
        for key, (contents, label) in enumerate(train_iter):
            # contents is a list: [ids, seq_len, mask]
            # print("contents", type(contents), len(contents))
            outputs = my_bert_model(contents)
            # print("outputs.size():", outputs.size())  # 128 x 10 classes
            my_bert_model.zero_grad()  # clear gradients
            # print("label: ", label.size())  # 128 x 1
            loss_func = nn.CrossEntropyLoss()
            loss = loss_func(outputs, label.reshape(config.batch_size,).long())  # label: 128 x 1
            loss.backward()   # backpropagation
            optimizer.step()
            # every 100 batches, report performance on the training and validation sets
            if total_batch % 100 == 0:
                true = label.data
                # outputs holds the scores for each class, shape [batch, num_classes]
                # torch.max(input, dim) returns a namedtuple (values, indices);
                # we keep only the indices of the per-row maxima
                predic = torch.max(outputs.data, 1)[1]
                train_acc = metrics.accuracy_score(true, predic)
                # validation accuracy and loss, computed by the evaluate function below
                dev_acc, dev_loss = evaluate(config, my_bert_model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    # save the current model
                    torch.save(my_bert_model.state_dict(), config.trained_bert_path)
                    improve = "*"
                    last_batch = total_batch
                else:
                    improve = ""
                time_dif = get_time(start_time)
                msg = "iter:{0:>6}, Train Loss:{1:5.2}, Train Acc:{2:>6.2%}, " \
                      "Val Loss:{3:>5.2}, Val Acc:{4:>6.2%}, Time:{5}{6}"
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                my_bert_model.train()
            total_batch += 1
            # if the validation loss has not decreased for more than require_improvement batches, stop training
            if total_batch - last_batch > config.require_improvement:
                print(" No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    test(config, my_bert_model, test_iter)


def test(config, my_bert_model, test_iter):
    # loading a saved model takes two steps: torch.load and load_state_dict
    my_bert_model.load_state_dict(torch.load(config.trained_bert_path))
    # switch to evaluation mode; parameters are not updated
    my_bert_model.eval()
    start_time = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(config, my_bert_model, test_iter, test=True)
    msg = "Test Loss: {0:>5.2}, Test Acc :{1:>6.2%}"
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall, F1_score")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    time_df = get_time(start_time)
    print("Time usage...: ", time_df)


def evaluate(config, my_bert_model, data_iter, test=False):
    # switch to evaluation mode; parameters are not updated
    my_bert_model.eval()
    loss_total = 0
    predic_all = np.array([], dtype=int)
    label_all = np.array([], dtype=int)
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = my_bert_model(texts)
            loss_func = nn.CrossEntropyLoss()
            loss = loss_func(outputs, labels.reshape(config.batch_size,).long())  # labels: 128 x 1
            loss_total += loss
            labels = labels.data.numpy()  # 128 x 1
            # print(labels, labels.size, type(labels))
            predic = torch.max(outputs.data, 1)[1].numpy()
            label_all = np.append(label_all, labels)    # 1-d np.array, size 128
            predic_all = np.append(predic_all, predic)  # 1-d np.array, size 128
            # print("label_all.size", label_all.size, predic_all.size, type(label_all))
    acc = metrics.accuracy_score(label_all, predic_all)
    if test:
        report = metrics.classification_report(label_all, predic_all)  # y_true: 1d array-like
        confusion = metrics.confusion_matrix(label_all, predic_all)
        # len(data_iter) is the number of batches
        return acc, loss_total / len(data_iter), report, confusion
    return acc, loss_total / len(data_iter)
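    The optimizer_grouped_params construction is the densest part of train(). This standalone sketch applies the same no_decay rule to a tiny made-up module (the Toy class exists only for this illustration) so the grouping becomes visible:

import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.LayerNorm = nn.LayerNorm(4)   # named like the LayerNorm modules inside BERT

toy = Toy()
params = list(toy.named_parameters())
no_decay = ["bias", "LayerNorm", "LayerNorm.weight"]

decay_group = [n for n, p in params if not any(nd in n for nd in no_decay)]
no_decay_group = [n for n, p in params if any(nd in n for nd in no_decay)]
print(decay_group)     # ['fc.weight']                                      -> weight_decay = 0.01
print(no_decay_group)  # ['fc.bias', 'LayerNorm.weight', 'LayerNorm.bias']  -> weight_decay = 0.0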

    2.4 run.py

    Program entry point.

import time
import torch
import numpy as np
from Train_eval import train, weight_init, get_time
from torch.utils.data import DataLoader, Dataset
import argparse  # for parsing command-line arguments
from utils import BertDataset
import bert

parser = argparse.ArgumentParser(description="Chinese Text Classification")
parser.add_argument("--model", type=str, required=False, help="choose a model: Bert")
args = parser.parse_args()

if __name__ == "__main__":
    # dataset directory
    path = r"K:\NLP\cs224n_2019\cs224n_2019\pytorch-bert-code\bert分类模型·\data"
    config = bert.Config(path)

    start_time = time.time()
    print("Loading data...")
    train_iter = DataLoader(BertDataset(config.train_data_path), config.batch_size, shuffle=True, drop_last=True)
    dev_iter = DataLoader(BertDataset(config.dev_data_path), config.batch_size, shuffle=True, drop_last=True)
    test_iter = DataLoader(BertDataset(config.test_data_path), config.batch_size, shuffle=True, drop_last=True)
    time_df = get_time(start_time)
    print("Time usage...: ", time_df)

    # train
    my_bert_model = bert.MyBertModel(config)
    train(config, my_bert_model, train_iter, dev_iter, test_iter)
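    Assuming the data directory and the local bert-base-chinese files configured above are in place, training is started from the console with a plain call such as

    python run.py --model Bert

    Note that the --model argument is parsed but never read in the main block, so it can also be omitted.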

    My machine nearly died after a single epoch, so I will not post the results here.
