关于自然语言处理:使用-Transformers-在你自己的数据集上训练文本分类模型

3次阅读

共计 1809 个字符,预计需要花费 5 分钟才能阅读完成。

最近切实是有点忙,没啥工夫写博客了。趁着周末水一文,把最近用 huggingface transformers 训练文本分类模型时遇到的一个小问题说下。

背景

之前只闻 transformers 超厉害超好用,然而没有理论用过。之前波及到 bert 类模型都是间接手写或是在他人的根底上批改。但这次因为某些起因,须要疾速训练一个简略的文本分类模型。其实这种场景应该挺多的,例如简略的 POC 或是长期测试某些模型。

我的需要很简略:用咱们 本人的 数据集,疾速 训练一个文本分类模型,验证想法。

我感觉如此简略的一个需要,应该有模板代码。但理论去搜的时候发现,官网文档什么时候变得这么多这么宏大了?还多了个 Trainer API?霎时让我想起了 Pytorch Lightning 那个坑人的同名 API。但可能是工夫起因,找了一圈没找到实用于自定义数据集的代码,都是用的官网、预约义的数据集。

所以弄完后,我决定简略写一个文章,来说下这本来应该极其容易解决的事件。

数据

假如咱们数据的格局如下:

0 第一个句子
1 第二个句子
0 第三个句子

即每一行都是 label sentence 的格局,两头空格分隔。并且咱们已将数据集分成了 train.txtval.txt

代码

加载数据集

首先应用 datasets 加载数据集:

from datasets import load_dataset
dataset = load_dataset('text', data_files={'train': 'data/train_20w.txt', 'test': 'data/val_2w.txt'})

加载后的 dataset 是一个 DatasetDict 对象:

DatasetDict({
    train: Dataset({features: ['text'],
        num_rows: 3
    })
    test: Dataset({features: ['text'],
        num_rows: 3
    })
})

相似 tf.data,尔后咱们须要对其进行 map,对每一个句子进行 tokenize、padding、batch、shuffle:

def tokenize_function(examples):
    labels = []
    texts = []
    for example in examples['text']:
        split = example.split(' ', maxsplit=1)
        labels.append(int(split[0]))
        texts.append(split[1])
    tokenized = tokenizer(texts, padding='max_length', truncation=True, max_length=32)
    tokenized['labels'] = labels
    return tokenized

tokenized_datasets = dataset.map(tokenize_function, batched=True)
train_dataset = tokenized_datasets["train"].shuffle(seed=42)
eval_dataset = tokenized_datasets["test"].shuffle(seed=42)

依据数据集格局不同,咱们能够在 tokenize_function 中随便自定义处理过程,以失去 text 和 labels。留神 batch_sizemax_length 也是在此处指定。解决完咱们便失去了能够输出给模型的训练集和测试集。

训练

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2, cache_dir='data/pretrained')
training_args = TrainingArguments('ckpts', per_device_train_batch_size=256, num_train_epochs=5)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)
trainer.train()

你能够依据状况批改训练 batchsize per_device_train_batch_size

残缺代码

残缺代码见 GitHub。

END

正文完
 0