当前位置: 代码迷 >> 综合 >> bert --> 文本分类
  详细解决方案

bert --> 文本分类

热度:59   发布时间:2023-12-06 23:40:31.0

文档地址

https://github.com/google-research/bert

准备工作

## Required parameters

# Data set directory: holds the train, dev and prediction files (downloadable attachment).
flags.DEFINE_string(
    "data_dir", "数据文件/tmp",
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

# BERT configuration file (downloadable attachment).
flags.DEFINE_string(
    "bert_config_file", "数据文件/chinese_L-12_H-768_A-12/bert_config.json",
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.")

# Task to launch — the key piece: a rewritten DataProcessor that prepares the data set.
flags.DEFINE_string("task_name", "csv", "The name of the task to train.")

# Vocabulary file.
flags.DEFINE_string(
    "vocab_file", "数据文件/chinese_L-12_H-768_A-12/vocab.txt",
    "The vocabulary file that the BERT model was trained on.")

# Output directory.
flags.DEFINE_string(
    "output_dir", "output",
    "The output directory where the model checkpoints will be written.")

## Other parameters

# Pre-trained base model checkpoint.
flags.DEFINE_string(
    "init_checkpoint", "数据文件/chinese_L-12_H-768_A-12/bert_model.ckpt",
    "Initial checkpoint (usually from a pre-trained BERT model).")

# Whether to lower-case the input text.
flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

# Maximum sequence length.
flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

# Whether to train.
flags.DEFINE_bool("do_train", True, "Whether to run training.")

# Whether to evaluate.
flags.DEFINE_bool("do_eval", True, "Whether to run eval on the dev set.")

# Whether to predict.
flags.DEFINE_bool(
    "do_predict", True,
    "Whether to run the model in inference mode on the test set.")

关键实现代码

bert 自带文本分类脚本 run_classifier.py。只需新增一个继承 DataProcessor 的数据集处理类来处理自己的数据,再把它注册到处理器字典 processors 中即可。

处理器

# Registry mapping --task_name values to their DataProcessor implementations.
processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mrpc": MrpcProcessor,
    "xnli": XnliProcessor,
    "csv": CsvProcessor,
}

基本的数据集处理器DataProcessor

class DataProcessor(object):"""Base class for data converters for sequence classification data sets."""def get_train_examples(self, data_dir):"""Gets a collection of `InputExample`s for the train set."""raise NotImplementedError()def get_dev_examples(self, data_dir):"""Gets a collection of `InputExample`s for the dev set."""raise NotImplementedError()def get_test_examples(self, data_dir):"""Gets a collection of `InputExample`s for prediction."""raise NotImplementedError()def get_labels(self):"""Gets the list of labels for this data set."""raise NotImplementedError()@classmethoddef _read_tsv(cls, input_file, quotechar=None):"""Reads a tab separated value file."""with tf.gfile.Open(input_file, "r") as f:reader = csv.reader(f, delimiter="\t", quotechar=quotechar)lines = []for line in reader:lines.append(line)return lines

自定义数据集处理器CsvProcessor

class CsvProcessor(DataProcessor):"""Processor for the CoLA data set (GLUE version)."""## 获取训练数据集def get_train_examples(self, data_dir):"""See base class."""return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.csv")), "train")## 获取验证数据集def get_dev_examples(self, data_dir):"""See base class."""return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.csv")), "dev")## 获取预测数据集def get_test_examples(self, data_dir):"""See base class."""return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.csv")), "test")## 分类的数组## 以train.csv为例子,训练模型的时候主要有"O"和"CS"两大类def get_labels(self):"""See base class."""return ["O", "CS"]## 具体的数据处理逻辑def _create_examples(self, lines, set_type):"""Creates examples for the training and dev sets."""examples = []# i是每一行的索引# line是每一行内容for (i, line) in enumerate(lines):## 获取到第一格的内容并且以逗号分隔字符串data = line[0].split(",")# 第一行是字段名称,忽略处理if i == 0:continue## 每条数据的idguid = "%s-%s" % (set_type, i)## 忽略验证集的空行if set_type == "dev" and data[0] == "":continue## 忽略预测集的空行if set_type == "test" and data[1] == "":continue## 如果是预测集,label指定使用"CS"进行预测if set_type == "test":text_a = tokenization.convert_to_unicode(data[1])label = "CS"## 第一个是文本分类,第二个是文本内容else:text_a = tokenization.convert_to_unicode(data[1])label = tokenization.convert_to_unicode(data[0])## 构建对象并且加进列表examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))return examples

计算PRF

            def metric_fn(per_example_loss, label_ids, logits, is_real_example):
                """Builds the eval metric dict (accuracy, AUC, P, R, mean loss)."""
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=predictions,
                    weights=is_real_example)
                # AUC over the binary predictions.
                auc = tf.metrics.auc(
                    labels=label_ids, predictions=predictions,
                    weights=is_real_example)
                # P of PRF.
                precision = tf.metrics.precision(
                    labels=label_ids, predictions=predictions,
                    weights=is_real_example)
                # R of PRF.
                recall = tf.metrics.recall(
                    labels=label_ids, predictions=predictions,
                    weights=is_real_example)
                # F would be (2 * P * R) / (P + R); it is not computed here —
                # derive it from eval_precision / eval_recall downstream.
                loss = tf.metrics.mean(
                    values=per_example_loss, weights=is_real_example)
                return {
                    "eval_accuracy": accuracy,
                    "eval_auc": auc,
                    "eval_precision": precision,
                    "eval_recall": recall,
                    "eval_loss": loss,
                }

源码地址

https://download.csdn.net/download/rainbowBear/16747704
  相关解决方案