2023 年 10 月,咱们发表了一篇对于 TimeGPT 的文章,TimeGPT 是工夫序列预测的第一个根底模型之一,具备零样本推理、异样检测和共形预测能力。
尽管 TimeGPT 是一个专有模型,只能通过 API 拜访。然而它还是引发了对工夫序列根底模型的更多钻研。到了 2024 年 2 月,曾经有了一个用于工夫序列预测的开源根底模型:laglllama。
在原论文《Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting》中,模型作为单变量概率预测的通用根底模型提出。它是由来自不同机构的大型团队开发的,这些机构包含 Morgan Stanley, ServiceNow, Université de Montréal, Mila-Quebec, 和 McGill University.
在本文中,咱们将探讨 Lag-Llama 的架构、性能以及训练形式。还会将 lagllama 利用于一个预测我的项目中,并将其与其余深度学习办法 Temporal Fusion Transformer (TFT) 和 DeepAR 进行性能比拟。
Lag-Llama
lagllama 是为单变量概率预测而构建的。它应用不依赖于频率的通用办法来标记工夫序列数据。这样模型能够很好地推广到不可见的频率。
它利用 Transformer 体系结构和散布头来解析输出令牌,并将它们映射到具备置信区间的将来预测。
1、具备滞后特色的标记
laglllama 的标记策略是应用一组指定的滞后来结构序列的滞后特色。
它将从这个列表中为给定的数据集抉择所有适合的频率:
季度、月、周、天、小时、秒
也就是说,如果以每日频率提供数据集,lag – llama 将尝试应用每日滞后 (t-1),每周滞后(t-7),每月滞后(t-30) 等构建特色。
策略如下图所示。
从上图中,咱们还能够看到模型构建了其余动态协变量,例如秒 / 分、小时 / 天等等,直到季度 / 年。尽管这能够很好地推广到所有类型的工夫序列,但它有一个致命的毛病:因为固定的滞后指数列表,输出令牌可能会变得十分大。
例如,查看每小时数据的每月频率须要 730 个工夫步。这意味着除了所有动态协变量之外,输出令牌的长度至多为 730。
2、Lag-Llama 架构
Lag-Llama 是一个基于 transformer 的纯解码器模型,其灵感来自大型语言模型 LLaMA 的体系结构。
从图中能够看到输出标记是滞后工夫步长和动态协变量的拼接。输出序列通过线性投影层将特色映射到解码器外部注意力模块的暗藏维度。另外就是在最初的输入,序列被发送到一个散布头负责输入一个概率分布。
在推理过程中,输出序列生成下一个工夫点的散布。而后通过自回归,模型一一生成残余的预测序列,直到达到设置的长度。
生成预测的自回归过程无效地容许模型为其预测生成不确定性区间。然而这里的问题就是如果序列很长,自回归的形式会将谬误扩充。
3、Lag-Llama 散布头
Lag-Llama 的散布头负责输入概率分布。这样模型就可能生成预测区间。
在模型的迭代中,最初一层应用 Student ‘s t 散布来结构不确定性区间。从实践上讲不同的散布头能够组合在一起,然而论文并没有做这样的试验,可能是想在当前在做吧。
4、Lag-Llama 的训练
作为一个根底模型,Lag-Llama 显然是在大量的工夫序列数据语料库上训练的,因而该模型能够很好地泛化未见过的工夫序列并进行零样本预测。
论文中说:Lag-Llama 在来自不同畛域的 27 个工夫序列数据集上进行了训练,如能源、交通、经济等。
数据蕴含 7965 个单变量工夫序列,总计约 3.52 亿个令牌。
所有数据集都是开源的,包含 ethth, Exchange 和 Weather 等。
Lag-Llama 测试
因为代码曾经开源,所以咱们能够间接测试,咱们首先应用 Lag-Llama 的零样本预测能力,并将其性能与特定数据模型 (如 TFT 和 DeepAR) 进行比拟。
Lag-Llama 的实现是建设在 GluonTS 之上的,所以咱们还须要装置这个库。试验应用了澳大利亚电力需要数据集,该数据集蕴含五个单变量工夫序列,以半小时的频率跟踪能源需求。
这里有个阐明:Lag-Llama 目前的实现是初期阶段。并且存还在踊跃开发中,前面可能还会有很大的调整,因为目前还没退出微调的性能。
1、环境设置
!git clone https://github.com/time-series-foundation-models/lag-llama/
cd lag-llama
pip install -r requirements.txt --quiet
而后须要咱们从 HuggingFace 下载模型的权重。
!huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir /content/lag-llama
2、加载数据集
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import torch
from itertools import islice
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset
from lag_llama.gluon.estimator import LagLlamaEstimator
能够间接从 GluonTS 加载数据集。
dataset = get_dataset("australian_electricity_demand")
backtest_dataset = dataset.test prediction_length = dataset.metadata.prediction_length
context_length = 3 * prediction_length
3、应用 Lag-Llama 预测
简略地初始化模型并应用 LagLlamaEstimator 对象。
ckpt = torch.load("lag-llama.ckpt", map_location=torch.device('cuda:0'))
estimator_args = ckpt["hyper_parameters"]["model_kwargs"]
estimator = LagLlamaEstimator( ckpt_path="lag-llama.ckpt",
prediction_length=prediction_length,
context_length=context_length,
input_size=estimator_args["input_size"],
n_layer=estimator_args["n_layer"],
n_embd_per_head=estimator_args["n_embd_per_head"],
n_head=estimator_args["n_head"],
scaling=estimator_args["scaling"],
time_feat=estimator_args["time_feat"])
lightning_module = estimator.create_lightning_module()
transformation = estimator.create_transformation()
predictor = estimator.create_predictor(transformation, lightning_module)
应用 make_evaluation_predictions 函数生成零样本的预测。
forecast_it, ts_it = make_evaluation_predictions(
dataset=backtest_dataset,
predictor=predictor)
这个函数返回生成器。咱们须要把它们转换成列表。
forecasts = list(forecast_it)
tss = list(ts_it)
4、评估
GluonTS 能够应用 Evaluator 对象不便地计算不同的性能指标。
evaluator = Evaluator()
agg_metrics, ts_metrics = evaluator(iter(tss), iter(forecasts))
RMSE 为 481.57。
咱们还能够随便地将预测可视化。
plt.figure(figsize=(20, 15))
date_formater = mdates.DateFormatter('%b, %d')
plt.rcParams.update({'font.size': 15})
for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 4):
ax = plt.subplot(2, 2, idx+1)
plt.plot(ts[-4 * dataset.metadata.prediction_length:].to_timestamp(), label="target")
forecast.plot(color='g')
plt.xticks(rotation=60)
ax.xaxis.set_major_formatter(date_formater)
ax.set_title(forecast.item_id)
plt.gcf().tight_layout()
plt.legend()
plt.show()
上图能够看到模型对数据做出了正当的预测,只管它在第四个序列 (图的右下角) 上的确存在问题。
另外因为 Lag-Llama 实现了概率预测,能够失去预测的不确定性区间。
5、与 TFT 和 DeepAR 相比
咱们在数据集上训练 TFT 和 DeepAR 模型,看看它们是否能体现得更好。
为了节省时间,咱们将训练设置为 5 个 epoch。
from gluonts.torch import TemporalFusionTransformerEstimator, DeepAREstimator
tft_estimator = TemporalFusionTransformerEstimator(
prediction_length=prediction_length,
context_length=context_length,
freq="30min",
trainer_kwargs={"max_epochs": 5})
deepar_estimator = DeepAREstimator(
prediction_length=prediction_length,
context_length=context_length,
freq="30min",
trainer_kwargs={"max_epochs": 5})
训练过程。
tft_predictor = tft_estimator.train(dataset.train)
deepar_predictor = deepar_estimator.train(dataset.train)
训练实现后,生成预测并计算 RMSE。
tft_forecast_it, tft_ts_it = make_evaluation_predictions(
dataset=backtest_dataset,
predictor=tft_predictor)
deepar_forecast_it, deepar_ts_it = make_evaluation_predictions(
dataset=backtest_dataset,
predictor=deepar_predictor)
tft_forecasts = list(tft_forecast_it)
tft_tss = list(tft_ts_it)
deepar_forecasts = list(deepar_forecast_it)
deepar_tss = list(deepar_ts_it)
# Get evaluation metrics
tft_agg_metrics, tft_ts_metrics = evaluator(iter(tft_tss), iter(tft_forecasts))
deepar_agg_metrics, deepar_ts_metrics = evaluator(iter(deepar_tss), iter(deepar_forecasts))
下表突出显示了性能最好的模型。
能够看到 TFT 是目前体现最好的模型,DeepAR 的体现也优于 laglama。
尽管 laglllama 的体现仿佛不尽如人意,但该模型没有通过微调,而且零样本测自身就比拟艰难。
乏味的是,只训练了 5 个 epoch 这两个模型都获得了比 Lag-Llama 更好的后果。尽管样本预测能够节省时间,但训练五个 epoch 在工夫和计算能力方面的要求应该不是很刻薄。所以目前可能零样本学习方面还须要很大的晋升。
总结
在尝试了 TimeGPT 和 Lag-Llama 之后,Lag-Llama 算是构建开源预测模型的第一步,但与 TimeGPT 相比,它在性能方面存在有余。
TimeGPT 能够解决多变量工夫序列、不规则工夫戳,并实现共形预测,与应用 laglama 等固定散布相比,这是一种更持重的量化不确定性的形式。
laglllama 是一个开源的根底模型,只用于单变量概率预测,并且我感觉它训练的数据有点少了。我置信在不久的未来会看到更多的开源预测模型呈现。他们的体现可能会失去改善,这代表了该畛域的一个重大转变。
最初论文地址:
Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting by K. Rasul, A. Ashok, A. Williams, H. Ghonia, R. Bhagwatkar, A. Khorasani, M. Bayazi, G. Adamopoulos, R. Riachi, N. Hassen, M. Bilos, S. Garg, A. Schneider, N. Chapados, A. Drouin, V. Zantedeschi, Y. Nevmyvaka, I. Rish
https://avoid.overfit.cn/post/8a9120d3cf074c1ba0de0a7a247993c9
作者:Marco Peixeiro