本文提供了一个应用 Hugging Face 🤗 Transformers 在任意多语种语音辨认 (ASR) 数据集上微调 Whisper 的分步指南。同时,咱们还深刻解释了 Whisper 模型、Common Voice 数据集以及微调等理论知识,并提供了数据筹备和微调的相干代码。如果你想要一个全副是代码,仅有大量解释的 Notebook,能够参阅这个 Google Colab。
目录
- 简介
-
在 Google Colab 中微调 Whisper
- 筹备环境
- 加载数据集
- 筹备特征提取器、分词器和数据
- 训练与评估
- 构建演示利用
- 结束语
简介
Whisper 是一系列用于主动语音辨认 (automatic speech recognition,ASR) 的预训练模型,它由来自于 OpenAI 的 Alec Radford 等人于 2022 年 9 月 公布。与 Wav2Vec 2.0 等前作不同,以往的模型都是在未标注的音频数据上预训练的,而 Whisper 是在大量的 已标注 音频转录数据上预训练的。其用于训练的标注音频时长高达 68 万小时,比 Wav2Vec 2.0 应用的未标注训练数据 (6 万小时) 还多一个数量级。更妙的是,该预训练数据中还含有 11.7 万小时的多语种数据。因而,Whisper 训得的 checkpoint 可利用于超过 96 种语言,这其中蕴含不少 数据匮乏 的小语种。
这么多的标注数据使得咱们能够间接在 有监督 语音辨认工作上预训练 Whisper,从标注音频转录数据 ${}^1$ 中间接习得语音到文本的映射。因而,Whisper 简直不须要额定的微调就曾经是高性能的 ASR 模型了。这让 Wav2Vec 2.0 黯然失色,因为 Wav2Vec 2.0 是在 无监督 掩码预测工作上预训练的,所以其训得的模型仅从未标注的纯音频数据中习得了从语音到隐含状态的两头映射。尽管无监督预训练能产生高质量的语音表征,但它 学不到 语音到文本的映射,要学到语音到文本的映射只能靠微调。因而,Wav2Vec 2.0 须要更多的微调能力取得较有竞争力的性能。
在 68 万小时标注数据的加持下,预训练 Whisper 模型体现出了弱小的泛化到多种数据集和畛域的能力。其预训练 checkpoint 体现出了与最先进的 ASR 零碎旗鼓相当的性能: 在 LibriSpeech ASR 的无噪测试子集上的单词错误率 (word error rate,WER) 仅为约 3%,另外它还在 TED-LIUM 上创下了新的记录 – 4.7% 的 WER (详见 Whisper 论文 的表 8)。Whisper 在预训练期间取得的宽泛的多语种 ASR 常识对一些数据匮乏的小语种特地有用。稍稍微调一下,预训练 checkpoint 就能够进一步适配特定的数据集和语种,从而进一步改良在这些语种上的辨认成果。
Whisper 是一个基于 transformer 的编码器 – 解码器模型 (也称为 序列到序列 模型),它将音频的频谱图特色 序列 映射到文本的词 _序列_。首先,通过特征提取器将原始音频输出变换为对数梅尔声谱图 (log-Mel spectrogram)。而后,transformer 编码器对声谱图进行编码,生成一系列编码器隐含状态。最初,解码器基于先前输入的词以及编码器隐含状态,自回归地预测下一个输入词。图 1 是 Whisper 模型的示意图。
图 1: Whisper 模型,该模型是规范的基于 transformer 的编码器 - 解码器架构。首先将对数梅尔声谱图输出到编码器,而后将编码器生成的最终隐含状态通过穿插留神机制输出给解码器。最初,解码器基于编码器隐含状态和先前的输入词,自回归地预测下一个输入词。图源: OpenAI Whisper 博客。
在序列到序列模型中,编码器负责从语音中提取出重要特色,将输出转换为一组隐含状态表征。解码器表演语言模型的角色,解决隐含状态表征并生成对应的文本。咱们把在模型架构 外部 集成语言模型的做法称为 深度交融_。与之绝对的是 _浅交融_,此时,语言模型在 内部 与编码器组合,如 CTC + $n$-gram (_详见 Internal Language Model Estimation 一文)。通过深度交融,能够用同一份训练数据和损失函数对整个零碎进行端到端训练,从而取得更大的灵活性和更优越的性能 ( 详见 ESB Benchmark)。
Whisper 应用穿插熵指标函数进行预训练和微调,穿插熵指标函数是训练序列标注模型的规范指标函数。经过训练,模型能够正确地对指标词进行分类,从而从预约义的词汇表中选出输入词。
Whisper 有五种不同尺寸的 checkpoint。其中,四个小尺寸 checkpoint 又各有两个版本: 英语版和多语种版,而最大的 checkpoint 只有多语种版。所有九个预训练 checkpoints 都能够在 Hugging Face Hub 上找到。下表总结了这些 checkpoint 的信息及其 Hub 链接:
尺寸 | 层数 | 宽 | 多头注意力的头数 | 参数量 | 英语 checkpoint | 多语种 checkpoint |
---|---|---|---|---|---|---|
tiny | 4 | 384 | 6 | 39 M | ✓ | ✓ |
base | 6 | 512 | 8 | 74 M | ✓ | ✓ |
small | 12 | 768 | 12 | 244 M | ✓ | ✓ |
medium | 24 | 1024 | 16 | 769 M | ✓ | ✓ |
large | 32 | 1280 | 20 | 1550 M | x | ✓ |
上面,咱们将以多语种版的 small
checkpoint (参数量 244M (~= 1GB)) 为例,带大家走一遍微调模型的全过程。咱们将应用 Common Voice 数据集里的小语种数据来训练和评估咱们的零碎。通过这个例子,咱们将证实,仅需 8 小时的训练数据就能够微调出一个在该语种上体现弱小的语音辨认模型。
${}^1$ Whisper 的名称来自于“Web-scale Supervised Pre-training for Speech Recognition (网络规模的有监督语音辨认预训练模型)”的首字母缩写“WSPSR”。
在 Google Colab 中微调 Whisper
筹备环境
在微调 Whisper 模型时,咱们会用到几个风行的 Python 包。咱们应用 datasets
来下载和筹备训练数据,应用 transformers
来加载和训练 Whisper 模型。另外,咱们还须要 soundfile
包来预处理音频文件,evaluate
和 jiwer
来评估模型的性能。最初,咱们用 gradio
来为微调后的模型构建一个亮闪闪的演示利用。
!pip install datasets>=2.6.1
!pip install git+https://github.com/huggingface/transformers
!pip install librosa
!pip install evaluate>=0.30
!pip install jiwer
!pip install gradio
咱们强烈建议你间接将训得的模型 checkpoint 上传到 Hugging Face Hub。Hub 提供了以下性能:
- 集成版本控制: 确保在训练期间不会失落任何模型 checkpoint。
- Tensorboard 日志: 跟踪训练过程中的重要指标。
- 模型卡: 记录模型的用法及其利用场景。
- 社区: 轻松与社区进行分享和合作!
将 Python notebook 连上 Hub 非常简单 – 只需依据提醒输出你的 Hub 身份验证令牌即可。你能够在 此处 找到你本人的 Hub 身份验证令牌:
from huggingface_hub import notebook_login
notebook_login()
打印输出:
Login successful
Your token has been saved to /root/.huggingface/token
加载数据集
Common Voice 由一系列众包数据集组成,其中蕴含了用各种语言录制的维基百科文本。本文应用的是最新版本的 Common Voice 数据集 (版本号为 11)。语种上,咱们抉择用 印地语 来微调咱们的模型。印地语是一种在印度北部、中部、东部和西部应用的印度 – 雅利安语。Common Voice 11.0 中有大概 12 小时的标注印地语数据,其中 4 小时是测试数据。
咱们先看下 Hub 上的 Common Voice 数据集页面: mozilla-foundation/common_voice_11_0。如果你是首次查看此页面,零碎会要求你承受其应用条款,批准后就能够拜访数据集了。
一旦身份验证胜利,你就会看到数据集预览。数据集预览展现了数据集的前 100 个样本。更重要的是,它还加载了可供实时收听的音频。咱们能够在下拉菜单抉择 hi
来抉择 Common Voice 的印地语子集 (hi
是印地语的语言标识符代码):
点击第一个音频的播放按钮,你就能够收听音频并看到相应的文本了。你还能够滚动浏览训练集和测试集中的样本,以更好地理解待处理音频和文本数据。从语调和格调能够看出,这些音频是旁白录音。你可能还会留神到录音者和录音品质的微小差别,这是众包数据的一个独特特色。
应用 🤗 Datasets 来下载和筹备数据非常简单。仅需一行代码即可实现 Common Voice 数据集的下载和筹备工作。因为印地语数据十分匮乏,咱们把 训练集
和 验证集
合并成约 8 小时的训练数据,而测试则基于 4 小时的 测试集
:
from datasets import load_dataset, DatasetDict
common_voice = DatasetDict()
common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="train+validation", use_auth_token=True)
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test", use_auth_token=True)
print(common_voice)
打印输出:
DatasetDict({
train: Dataset({features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
num_rows: 6540
})
test: Dataset({features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
num_rows: 2894
})
})
大多数 ASR 数据集仅蕴含输出音频样本 (audio
) 和相应的转录文本 (sentence
)。Common Voice 还蕴含额定的元信息,例如 accent
和 locale
,在 ASR 场景中,咱们能够疏忽这些信息。为了使代码尽可能通用,咱们只思考基于输出音频和转录文本进行微调,而不应用额定的元信息:
common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"])
除了 Common Voice,Hub 上还有不少其余多语种 ASR 数据集可供使用,你能够点击链接: Hub 上的 ASR 数据集 理解更多。
筹备特征提取器、分词器和数据
ASR 的流水线次要蕴含三个模块:
- 对原始音频输出进行预处理的特征提取器
- 执行序列到序列映射的模型
- 将模型输入转换为文本的分词器
在 🤗 Transformers 中,Whisper 模型有本人的特征提取器和分词器,即 WhisperFeatureExtractor 和 WhisperTokenizer。
上面,咱们逐个具体介绍特征提取器和分词器!
加载 WhisperFeatureExtractor
语音可示意为随工夫变动的一维数组,给定时刻的数组值即示意信号在该时刻的 _幅度_,而咱们能够仅从幅度信息重建音频的频谱并复原其所有声学特色。
因为语音是间断的,因而它蕴含无数个幅度值,而计算机只能示意并存储无限个值。因而,咱们须要通过对语音信号进行离散化,即以固定的工夫距离对间断信号进行 采样_。咱们将每秒采样的次数称为 _采样率_,通常以样本数 / 秒或 _赫兹 (Hz) 为单位。高采样率能够更好地迫近间断语音信号,但同时每秒所需的存储量也更大。
须要特地留神的是,输出音频的采样率须要与模型冀望的采样率相匹配,因为不同采样率的音频信号的散布是不同的。解决音频时,须要应用正确的采样率,否则可能会引起意想不到的后果!例如,以 16kHz 的采样率采集音频但以 8kHz 的采样率收听它,会使音频听起来如同是半速的。同样地,向一个须要某一采样率的 ASR 模型馈送一个谬误采样率的音频也会影响模型的性能。Whisper 特征提取器须要采样率为 16kHz 的音频输出,因而输出的采样率要与之相匹配。咱们不想无心中用慢速语音来训练 ASR!
Whisper 特征提取器执行两个操作。首先,填充或截断一批音频样本,将所有样本的输出长度对立至 30 秒。通过在序列开端增加零 (音频信号中的零对应于无信号或静音),将短于 30 秒的样本填充到 30 秒。而对超过 30 秒的样本,间接截断为 30 秒就好了。因为这一批数据中的所有样本都被填充或截断到对立长度 (即 30 s) 了,因而将音频馈送给 Whisper 模型时就不须要注意力掩码了。这是 Whisper 的独门个性,其余大多数音频模型都须要用户提供一个注意力掩码,具体阐明填充地位,这样模型能力在自注意力机制中疏忽填充局部。经过训练的 Whisper 模型能够间接从语音信号中推断出应该疏忽哪些局部,因而无需注意力掩码。
Whisper 特征提取器执行的第二个操作是将第一步所得的音频变换为对数梅尔声谱图。这些频谱图是信号频率的直观示意,相似于傅里叶变换。图 2 展现了一个声谱图的例子,其中 $y$ 轴示意梅尔频段 (Mel channel),对应于特定的频段,$x$ 轴示意工夫,色彩对应于给定时刻该频段的对数强度。Whisper 模型要求输出为对数梅尔声谱图。
梅尔频段是语音解决的规范办法,钻研人员用它来近似示意人类的听觉范围。对于 Whisper 微调这个工作而言,咱们只须要晓得声谱图是语音信号中频率的直观示意。更多无关梅尔频段的详细信息,请参阅 梅尔倒谱 一文。
图 2: 将音频信号变换为对数梅尔声谱图。左图:一维音频离散信号。右图:对应的对数梅尔声谱图。图源:谷歌 SpecAugment 博文.
侥幸的是,🤗 Transformers Whisper 特征提取器仅用一行代码即可执行填充和声谱图变换两个操作!咱们应用以下代码从预训练的 checkpoint 中加载特征提取器,为音频数据处理做好筹备:
from transformers import WhisperFeatureExtractor
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
加载 WhisperTokenizer
当初咱们加载 Whisper 分词器。Whisper 模型会输入词元,这些词元示意预测文本在词典中的索引。分词器负责将这一系列词元映射为最终的文本字符串 (例如 [1169, 3797, 3332] ->“the cat sat”)。
过来,当应用编码器模型进行 ASR 时,咱们需应用 连贯时序分类法 (Connectionist Temporal Classification,CTC) 进行解码。在应用 CTC 进行解码时,咱们须要为每个数据集训练一个 CTC 分词器。但应用编码器 – 解码器架构的一个劣势是咱们能够间接应用预训练模型的分词器。
Whisper 分词器在 96 种语种数据上预训练而得,因而,其 字节对 (byte-pair) 覆盖面很广,简直蕴含了所有语种。就印地语而言,咱们能够加载分词器并将其间接用于微调。仅需指定一下指标语种和工作,分词器就会依据这些参数将语种和工作标记增加为输入序列的前缀:
from transformers import WhisperTokenizer
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")
咱们能够通过对 Common Voice 数据集的第一个样本进行编解码来验证分词器是否正确编码了印地语字符。在对转录文本进行编码时,分词器在序列的结尾和结尾增加“非凡标记”,其中包含文本的开始 / 结尾、语种标记和工作标记 (由上一步中的参数指定)。在解码时,咱们能够抉择“跳过”这些非凡标记,从而保障输入是纯文本模式的:
input_str = common_voice["train"][0]["sentence"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)
print(f"Input: {input_str}")
print(f"Decoded w/ special: {decoded_with_special}")
print(f"Decoded w/out special: {decoded_str}")
print(f"Are equal: {input_str == decoded_str}")
打印输出:
Input: खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई
Decoded w/ special: <|startoftranscript|><|hi|><|transcribe|><|notimestamps|>खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई<|endoftext|>
Decoded w/out special: खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई
Are equal: True
组装一个 WhisperProcessor
为了简化应用,咱们能够将特征提取器和分词器 包进 到一个 WhisperProcessor
类,该类继承自 WhisperFeatureExtractor
及 WhisperTokenizer
,可依据须要用于音频解决和模型预测。有了它,咱们在训练期间只须要保留两个对象: processor
和 model
就好了。
from transformers import WhisperProcessor
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")
筹备数据
咱们把 Common Voice 数据集的第一个样本打印进去,看看数据长什么样:
print(common_voice["train"][0])
打印输出:
{'audio': {'path': '/home/sanchit_huggingface_co/.cache/huggingface/datasets/downloads/extracted/607848c7e74a89a3b5225c0fa5ffb9470e39b7f11112db614962076a847f3abf/cv-corpus-11.0-2022-09-21/hi/clips/common_voice_hi_25998259.mp3',
'array': array([0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 9.6724887e-07,
1.5334779e-06, 1.0415988e-06], dtype=float32),
'sampling_rate': 48000},
'sentence': 'खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई'}
能够看到,样本含有一个一维音频数组及其对应的转录文本。上文曾经屡次谈及采样率,以及将音频的采样率与 Whisper 模型所需的采样率 (16kHz) 相匹配的重要性。因为当初输出音频的采样率为 48kHz,所以在将其馈送给 Whisper 特征提取器之前,咱们须要将其 _下采样_至 16kHz。
咱们将应用 dataset
的 cast_column
办法将输出音频转换至所需的采样率。该办法仅批示 datasets
让其在首次加载音频时 _即时地_对数据进行重采样,因而并不会扭转原音频数据:
from datasets import Audio
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
从新打印下 Common Voice 数据集中的第一个音频样本,能够看到其已被重采样:
print(common_voice["train"][0])
打印输出:
{'audio': {'path': '/home/sanchit_huggingface_co/.cache/huggingface/datasets/downloads/extracted/607848c7e74a89a3b5225c0fa5ffb9470e39b7f11112db614962076a847f3abf/cv-corpus-11.0-2022-09-21/hi/clips/common_voice_hi_25998259.mp3',
'array': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-3.4206650e-07, 3.2979898e-07, 1.0042874e-06], dtype=float32),
'sampling_rate': 16000},
'sentence': 'खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई'}
酷!咱们能够看到音频已被下采样到 16kHz 了。数组外面的值也变了,当初的 1 个幅度值大抵对应于之前的 3 个幅度值。
当初咱们编写一个函数来为模型筹备数据:
- 调用
batch["audio"]
加载和重采样音频数据。如上所述,🤗 Datasets 会即时执行任何必要的重采样操作。 - 应用特征提取器将一维音频数组变换为对数梅尔声谱图特色。
- 应用分词器将录音文本编码为 ID。
def prepare_dataset(batch):
# load and resample audio data from 48 to 16kHz
audio = batch["audio"]
# compute log-Mel input features from input audio array
batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
# encode target text to label ids
batch["labels"] = tokenizer(batch["sentence"]).input_ids
return batch
咱们能够用 dataset
的 .map
办法在所有训练样本上利用上述函数:
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=4)
好了!训练数据筹备结束!咱们持续看看如何应用这些数据来微调 Whisper。
留神: 目前 datasets
次要应用 torchaudio
和 librosa
来进行音频加载和重采样。如果你本人定制一个数据加载 / 采样函数的话,你齐全能够间接通过 "path"
列获取音频文件门路而不必管 "audio"
列。
训练与评估
至此,数据已筹备结束,能够开始训练了。训练的大部分沉重的工作都会由 🤗 Trainer 来实现。咱们要做的次要有:
- 定义数据整顿器 (data collator): 数据整顿器获取预处理后的数据并将其转换为 PyTorch 张量。
- 评估指标: 咱们应用 单词错误率 (word error rate,WER) 指标来评估模型,因而须要定义一个
compute_metrics
函数来计算它。 - 加载预训练 checkpoint: 咱们须要加载预训练 checkpoint 并正确配置它以进行训练。
- 定义训练参数: 🤗 Trainer 在制订训练打算时须要用到这些参数。
微调完后,咱们须要应用测试数据对其进行评估,以验证最终模型在印地语上的语音辨认成果。
定义数据整顿器
序列到序列语音模型的数据整顿器与其余工作有所不同,因为 input_features
和 labels
的解决办法是不同的: input_features
必须由特征提取器解决,而 labels
由分词器解决。
input_features
曾经填充至 30s 并转换为固定维度的对数梅尔声谱图,咱们所要做的只剩将其转换为 PyTorch 张量。咱们用特征提取器的 .pad
办法来实现这一性能,且将其入参设为 return_tensors=pt
。请留神,这里不须要额定的填充,因为输出维度曾经固定了,所以咱们只须要简略地将 input_features
转换为 PyTorch 张量就好了。
另一方面,labels
数据之前并未填充。所以,咱们首先要应用分词器的 .pad
办法将序列填充至本 batch 的最大长度。而后将填充标记替换为 -100
,这样它们就能够 不 用参加损失的计算了。而后咱们把 SOT
从序列的结尾去掉,稍后训练的时候咱们再把它加回来。
咱们能够利用之前定义的 WhisperProcessor
来执行特征提取和分词操作:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
processor: Any
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need different padding methods
# first treat the audio inputs by simply returning torch tensors
input_features = [{"input_features": feature["input_features"]} for feature in features]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
# get the tokenized label sequences
label_features = [{"input_ids": feature["labels"]} for feature in features]
# pad the labels to max length
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
# if bos token is appended in previous tokenization step,
# cut bos token here as it's append later anyways
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
labels = labels[:, 1:]
batch["labels"] = labels
return batch
咱们初始化一下刚刚定义的数据整顿器:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
评估指标
接下来要定义评估指标。咱们将应用词错误率 (WER) 指标,它是评估 ASR 零碎的“规范”指标。无关其详细信息,请参阅 WER 文档。上面,咱们从 🤗 Evaluate 中加载 WER 指标:
import evaluate
metric = evaluate.load("wer")
而后咱们只须要定义一个函数来承受模型输入并返回 WER 指标。这个名为 compute_metrics
的函数首先将 -100
替换为 label_ids
中的 pad_token_id
(以便在计算损失时将其疏忽)。而后,将预测到的 ID 和 label_ids
解码为字符串文本。最初,计算输入文本和实在文本之间的 WER:
def compute_metrics(pred):
pred_ids = pred.predictions
label_ids = pred.label_ids
# replace -100 with the pad_token_id
label_ids[label_ids == -100] = tokenizer.pad_token_id
# we do not want to group tokens when computing the metrics
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
加载预训练 checkpoint
当初咱们加载预训练 Whisper small
模型的 checkpoint。同样,能够通过应用 🤗 transformers 很轻松地实现这一步!
from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
原始 Whisper 模型在自回归生成开始之前强制增加了若干前缀词元 ID (forced_decoder_ids
)。这些词元 ID 次要用于在零样本 ASR 工作中标识语种和工作。因为咱们当初是对已知语种 (印地语) 和工作 (转录) 进行微调,所以咱们要将 forced_decoder_ids
设置为 None
。另外,模型还克制了一些词元 (suppress_tokens
),这些词元的对数概率被强置为 -inf
,以保障它们永远不会被采样到。咱们会用一个空列表笼罩 suppress_tokens
,即咱们不克制任何词元:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
定义训练参数
最初一步是定义与训练相干的所有参数,上面对其中一部分参数进行了解释:
output_dir
: 保留模型权重的本地目录,它也会是 Hugging Face Hub 上的模型存储库名称。generation_max_length
: 评估阶段,自回归生成的最大词元数。save_steps
: 训练期间,每save_steps
步保留一次两头 checkpoint 并异步上传到 Hub。eval_steps
: 训练期间,每eval_steps
步对两头 checkpoint 进行一次评估。report_to
: 训练日志的保留地位,反对azure_ml
、comet_ml
、mlflow
、neptune
、tensorboard
以及wandb
这些平台。你能够依照本人的偏好进行抉择,也能够间接应用缺省的tensorboard
保留至 Hub。
如需更多其余训练参数的详细信息,请参阅 Seq2SeqTrainingArguments 文档。
from transformers import Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
output_dir="./whisper-small-hi", # change to a repo name of your choice
per_device_train_batch_size=16,
gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
learning_rate=1e-5,
warmup_steps=500,
max_steps=4000,
gradient_checkpointing=True,
fp16=True,
evaluation_strategy="steps",
per_device_eval_batch_size=8,
predict_with_generate=True,
generation_max_length=225,
save_steps=1000,
eval_steps=1000,
logging_steps=25,
report_to=["tensorboard"],
load_best_model_at_end=True,
metric_for_best_model="wer",
greater_is_better=False,
push_to_hub=True,
)
留神: 如果不想将模型 checkpoint 上传到 Hub,你须要设置 push_to_hub=False
。
咱们能够将训练参数以及模型、数据集、数据整顿器和 compute_metrics
函数一起传给 🤗 Trainer:
from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=common_voice["train"],
eval_dataset=common_voice["test"],
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=processor.feature_extractor,
)
有了这些,就能够开始训练了!
训练
要启动训练,只需执行:
trainer.train()
训练大概须要 5-10 个小时,具体取决于你的 GPU 或 Google Colab 后端的 GPU。依据 GPU 的状况,你可能会在开始训练时遇到 CUDA 内存耗尽
谬误。此时,你能够将 per_device_train_batch_size
逐次缩小 2 倍,同时减少 gradient_accumulation_steps
进行弥补。
打印输出:
步数 | 训练损失 | 轮数 | 验证损失 | WER |
---|---|---|---|---|
1000 | 0.1011 | 2.44 | 0.3075 | 34.63 |
2000 | 0.0264 | 4.89 | 0.3558 | 33.13 |
3000 | 0.0025 | 7.33 | 0.4214 | 32.59 |
4000 | 0.0006 | 9.78 | 0.4519 | 32.01 |
5000 | 0.0002 | 12.22 | 0.4679 | 32.10 |
最佳 WER 是 32.0% —— 对 8 小时的训练数据来说还不错!那与其余 ASR 零碎相比,这个体现到底处于什么程度?为此,咱们能够查看 hf-speech-bench
,这是一个按语种和数据集对模型别离进行 WER 排名的排行榜。
微调后的模型显著进步了 Whisper small
checkpoint 的零样本性能,也突出展现了 Whisper 弱小的迁徙学习能力。
当将训练后果推送到 Hub 时,只需配置适当的关键字参数 (key-word arguments,kwargs) 就能够主动将 checkpoint 提交到排行榜。如需适配本人的数据集、语种和模型名称,仅需对下述代码作出相应的批改即可:
kwargs = {
"dataset_tags": "mozilla-foundation/common_voice_11_0",
"dataset": "Common Voice 11.0", # a 'pretty' name for the training dataset
"dataset_args": "config: hi, split: test",
"language": "hi",
"model_name": "Whisper Small Hi - Sanchit Gandhi", # a 'pretty' name for your model
"finetuned_from": "openai/whisper-small",
"tasks": "automatic-speech-recognition",
"tags": "hf-asr-leaderboard",
}
当初,只需执行 push_to_hub
命令就能够将训练后果上传到 Hub 了:
trainer.push_to_hub(**kwargs)
任何人能够用你的模型的 Hub 链接拜访它。他们还能够应用标识符 "your-username/the-name-you-picked"
加载它,例如:
from transformers import WhisperForConditionalGeneration, WhisperProcessor
model = WhisperForConditionalGeneration.from_pretrained("sanchit-gandhi/whisper-small-hi")
processor = WhisperProcessor.from_pretrained("sanchit-gandhi/whisper-small-hi")
尽管微调后的模型在 Common Voice Hindi 测试数据上的成果还不错,但其成果远算不上最优。本文的目标仅为演示如何在任意多语种 ASR 数据集上微调预训练的 Whisper checkpoint,对成果并未做太多深究。如需晋升成果,你还能够尝试更多技巧,如优化训练超参 (例如 learning rate 和 _dropout_)、应用更大的预训练 checkpoint (medium
或 large
) 等。
构建演示利用
当初模型曾经微调完结,咱们开始构建一个演示利用来展现其 ASR 性能!咱们将应用 🤗 Transformers pipeline
来实现整个 ASR 流水线: 从对音频输出进行预处理始终到对模型输入进行解码。咱们应用 Gradio 来构建咱们的交互式演示。Gradio 提供了最含糊其辞的构建机器学习演示利用的办法,咱们能够用它在几分钟内构建一个演示利用!
运行以下代码会生成一个 Gradio 演示利用,它用计算机的麦克风录制语音并将其馈送给微调后的 Whisper 模型以转录出相应的文本:
from transformers import pipeline
import gradio as gr
pipe = pipeline(model="sanchit-gandhi/whisper-small-hi") # change to "your-username/the-name-you-picked"
def transcribe(audio):
text = pipe(audio)["text"]
return text
iface = gr.Interface(
fn=transcribe,
inputs=gr.Audio(source="microphone", type="filepath"),
outputs="text",
title="Whisper Small Hindi",
description="Realtime demo for Hindi speech recognition using a fine-tuned Whisper small model.",
)
iface.launch()
结束语
通过本文,咱们介绍了如何应用 🤗 Datasets、Transformers 和 Hugging Face Hub 一步步为多语种 ASR 微调一个 Whisper 模型。如果你想本人尝试微调一个,请参阅 Google Colab。如果你有趣味针对英语和多语种 ASR 微调一个其它的 Transformers 模型,请务必参考下 examples/pytorch/speech-recognition。
英文原文: https://hf.co/blog/fine-tune-whisper
原文作者: Sanchit Gandhi
译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的利用及大规模模型的训练推理。
审校 / 排版: zhongdongy (阿东)