机器学习-这位课代表今后监督大家学习的重任就交给你啦

55次阅读

共计 9284 个字符,预计需要花费 24 分钟才能阅读完成。

在学生时代担任过课代表的童鞋,不妨发个评论示意一下。

帮助老师收作业、发卷子、交流传达与学习有关的各项事宜……课代表的存在确实帮助老师减轻了不小的负担。

当你走上工作岗位,尤其是,当你开始从事与机器学习有关的项目时,面对模型那复杂、繁琐、耗时的训练工作,有没有考虑过请一个课代表来分担分担?毕竟都是为了提高学习质量。呐,这个模型的训练过程存在异常,那个模型的结果数据有质量问题,这时候,如果有个「课代表」能及时告知你,是不是感觉轻松了不少。

可以胜任机器学习监督工作的「课代表」,这就来了!

为什么需要这样的「课代表」?

Amazon SageMaker 模型监控器,这是 Amazon SageMaker 的一项新功能,可以自动监控生产中的机器学习(ML)模型,并在出现数据质量问题时发出警报。

该功能的出现完全是为了帮助我们尽一切可能关注数据的质量。毕竟,如果在训练结束并出现异常后,花费数小时排查问题,最后发现竟然是意外的 NULL 值或不知怎么就到了数据库中的一个外来字符编码导致的,这该让人多郁闷。

由于模型实际上是根据大量数据构建的,因此也就不难理解,为什么 ML 从业人员会花费大量时间来维护数据集。尤其是,需要花费大量时间和精力来保证训练集(用于训练模型)和验证集(用于测量模型的准确性)中数据样本具有相同的统计属性。

但光这样还不够,尽管我们可以完全控制实验数据集,但对于模型将要接收的真实数据就不是那回事了。当然,这些数据将是未经清理的,但更令人担忧的问题是“数据漂移”,即所接收数据的统计性质发生渐变。最小值和最大值、平均值、中位数、方差等,所有这些都是决定模型训练期间做出的假设和决策的关键属性。

直觉告诉我们,这些值的任何重大变化都会影响预测的准确性。设想一下:要是由于输入特征出现漂移甚至缺失,导致一个贷款应用程序预测的金额升高,那多可怕!

然而这些条件非常难以检测:我们需要捕获模型接收的数据,运行各种统计分析,以将这些数据与训练集进行比较,定义规则以检测漂移,并在发生漂移时发出警报…… 更麻烦的是,以后每次更新模型后,上述所有工作都要重新来过。专家级 ML 从业人员当然知道如何构建这些复杂的工具,但却要花费大量时间并耗费大量资源。这不就是眉毛胡子一把抓么……

Amazon SageMaker 模型监控器,大家的「学业」今后就靠你啦!

为了帮助所有客户专注于创造价值,AWS 构建了 Amazon SageMaker 模型监控器。

让我们通过一个例子来展示一下该功能的价值吧。典型的监控会话如下:

首先,我们要从 SageMaker 的终端节点(可以使用现有终端节点,也可专门为了监控目的创建新终端节点)开始。我们可以在任何终端节点上使用 SageMaker 模型监控器,无论模型是从内置算法、内置框架,还是从自己的容器训练而来。
 
借助 SageMaker 开发工具包,我们可以捕获发送到终端节点的部分数据(可配置),也可根据需要捕获预测,并将这些数据存储在 Amazon Simple Storage Service(S3)存储桶中。捕获的数据会附加上元数据(内容类型、时间戳等),我们可以像使用任何 S3 对象一样保护和访问它。

然后,从用于训练在终端节点上部署的模型的数据集建立基线。当然,我们也可以选择使用已有的基线。这将启动 Amazon SageMaker 处理作业,其中 SageMaker 模型监控器将执行以下操作:

  • 推断输入数据的架构,即有关每个特征的类型和完整性的信息。我们应该检查这些内容,并在需要时更新。
  • (仅对于预构建的容器)使用 Deequ(基于由 Amazon 开发并在 Amazon 使用的 Apache Spark 的开源代码工具)来计算特征统计信息(博客文章和研究论文)。这些统计信息包括 KLL 草图,这是一种用于在数据流上计算准确分位数的高级技术,这也是我们最近对 Deequ 做出的一项贡献。

使用这些构件的下一步是启动监控计划,以使 SageMaker 模型监控器检查收集的数据和预测质量。无论使用的是内置容器还是自定义容器,都需要应用许多内置规则,并且报告会定期推送到 S3。这些报告包含在上一个时间段内接收到的数据的统计和架构信息以及检测到的任何违规情况。

最后但同样重要的是,SageMaker 模型监控器会向 Amazon CloudWatch 发出与特征相对应的指标,可用于设置控制面板和警报。CloudWatch 的摘要指标也可以在 Amazon SageMaker Studio 中看到,当然所有统计数据、监控结果和收集的数据都可以在笔记本中查看和进一步分析。

更多信息以及有关如何通过 AWS CloudFormation 使用 SageMaker 模型监控器的示例,请参阅开发人员指南。

接下来,让我们使用经过内置 XGBoost 算法训练的用户流失预测模型进行演示。

启用数据捕获

第一步需要创建终端节点配置以启用数据捕获。在这里,我们决定捕获 100% 的传入数据以及模型输出(即预测)。此外还传递了 CSV 和 JSON 数据的内容类型。

data_capture_configuration = {
    "EnableCapture": True,
    "InitialSamplingPercentage": 100,
    "DestinationS3Uri": s3_capture_upload_path,
    "CaptureOptions": [{ "CaptureMode": "Output"},
        {"CaptureMode": "Input"}
    ],
    "CaptureContentTypeHeader": {"CsvContentTypes": ["text/csv"],
       "JsonContentTypes": ["application/json"]
}

接下来使用常规的 CreateEndpoint API 创建终端节点。

create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName = endpoint_config_name,
    ProductionVariants=[{
        'InstanceType':'ml.m5.xlarge',
        'InitialInstanceCount':1,
        'InitialVariantWeight':1,
        'ModelName':model_name,
        'VariantName':'AllTrafficVariant'
    }],
    DataCaptureConfig = data_capture_configuration)

对于已有的终端节点,可以使用 UpdateEndpoint API 来无缝更新终端节点配置。

反复调用终端节点后,可以在 S3 中看到一些捕获的数据(为清晰起见,对输出进行了编辑)。

$ aws s3 ls --recursive s3://sagemaker-us-west-2-123456789012/sagemaker/DEMO-ModelMonitor/datacapture/DEMO-xgb-churn-pred-model-monitor-2019-11-22-07-59-33/
AllTrafficVariant/2019/11/22/08/24-40-519-9a9273ca-09c2-45d3-96ab-fc7be2402d43.jsonl
AllTrafficVariant/2019/11/22/08/25-42-243-3e1c653b-8809-4a6b-9d51-69ada40bc809.jsonl

这是其中一个文件中的一行:

 "endpointInput":{
        "observedContentType":"text/csv",
        "mode":"INPUT",
        "data":"132,25,113.2,96,269.9,107,229.1,87,7.1,7,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,1",
        "encoding":"CSV"
     },
     "endpointOutput":{
        "observedContentType":"text/csv; charset=utf-8",
        "mode":"OUTPUT",
        "data":"0.01076381653547287",
        "encoding":"CSV"}
     },
    "eventMetadata":{
        "eventId":"6ece5c74-7497-43f1-a263-4833557ffd63",
        "inferenceTime":"2019-11-22T08:24:40Z"},
        "eventVersion":"0"}

看上去没什么问题了。接下来让我们为此模型创建一个基线。

创建监控基线

这是一个非常简单的步骤:传递基线数据集的位置以及存储结果的位置。

from processingjob_wrapper import ProcessingJob
 
processing_job = ProcessingJob(sm_client, role).
   create(job_name, baseline_data_uri, baseline_results_uri)

完成这项工作后,可以在 S3 中看到两个新对象:一个用于统计信息,一个用于约束。

aws s3 ls s3://sagemaker-us-west-2-123456789012/sagemaker/DEMO-ModelMonitor/baselining/results/
constraints.json
statistics.json

constraints.json 文件告诉了我们推断出的训练数据集的架构(不要忘记检查它的准确性)。每个特征都分了类,此外还可以获得有关某个特征是否始终存在的信息(此处 1.0 表示 100%)。以下是该文件的前几行内容:

{
  "version" : 0.0,
  "features" : [ {
    "name" : "Churn",
    "inferred_type" : "Integral",
    "completeness" : 1.0
  }, {
    "name" : "Account Length",
    "inferred_type" : "Integral",
    "completeness" : 1.0
  }, {
    "name" : "VMail Message",
    "inferred_type" : "Integral",
    "completeness" : 1.0
  }, {
    "name" : "Day Mins",
    "inferred_type" : "Fractional",
    "completeness" : 1.0
  }, {
    "name" : "Day Calls",
    "inferred_type" : "Integral",
    "completeness" : 1.0

在该文件的末尾可以看到 CloudWatch 监控的配置信息:将其打开或关闭、设置漂移阈值等。

"monitoring_config" : {
    "evaluate_constraints" : "Enabled",
    "emit_metrics" : "Enabled",
    "distribution_constraints" : {
      "enable_comparisons" : true,
      "min_domain_mass" : 1.0,
      "comparison_threshold" : 1.0
    }
  }

statistics.json 文件显示了每个特征(平均值、中位数、分位数等)的不同统计信息,以及终端节点接收的唯一值。示例如下:

"name" : "Day Mins",
    "inferred_type" : "Fractional",
    "numerical_statistics" : {
      "common" : {
        "num_present" : 2333,
        "num_missing" : 0
      },
      "mean" : 180.22648949849963,
      "sum" : 420468.3999999996,
      "std_dev" : 53.987178959901556,
      "min" : 0.0,
      "max" : 350.8,
      "distribution" : {
        "kll" : {
          "buckets" : [ {
            "lower_bound" : 0.0,
            "upper_bound" : 35.08,
            "count" : 14.0
          }, {
            "lower_bound" : 35.08,
            "upper_bound" : 70.16,
            "count" : 48.0
          }, {
            "lower_bound" : 70.16,
            "upper_bound" : 105.24000000000001,
            "count" : 130.0
          }, {
            "lower_bound" : 105.24000000000001,
            "upper_bound" : 140.32,
            "count" : 318.0
          }, {
            "lower_bound" : 140.32,
            "upper_bound" : 175.4,
            "count" : 565.0
          }, {
            "lower_bound" : 175.4,
            "upper_bound" : 210.48000000000002,
            "count" : 587.0
          }, {
            "lower_bound" : 210.48000000000002,
            "upper_bound" : 245.56,
            "count" : 423.0
          }, {
            "lower_bound" : 245.56,
            "upper_bound" : 280.64,
            "count" : 180.0
          }, {
            "lower_bound" : 280.64,
            "upper_bound" : 315.72,
            "count" : 58.0
          }, {
            "lower_bound" : 315.72,
            "upper_bound" : 350.8,
            "count" : 10.0
          } ],
          "sketch" : {
            "parameters" : {
              "c" : 0.64,
              "k" : 2048.0
            },
            "data" : [[ 178.1, 160.3, 197.1, 105.2, 283.1, 113.6, 232.1, 212.7, 73.3, 176.9, 161.9, 128.6, 190.5, 223.2, 157.9, 173.1, 273.5, 275.8, 119.2, 174.6, 133.3, 145.0, 150.6, 220.2, 109.7, 155.4, 172.0, 235.6, 218.5, 92.7, 90.7, 162.3, 146.5, 210.1, 214.4, 194.4, 237.3, 255.9, 197.9, 200.2, 120, ...

现在,让我们开始监控终端节点。

监控终端节点

同样的,我们只需要调用一个 API:为此需为终端节点创建一个监控计划,并传递基线数据集的约束和统计文件。如果需要调整数据和预测,还可以传递处理前和处理后函数。

ms = MonitoringSchedule(sm_client, role)
schedule = ms.create(
   mon_schedule_name,
   endpoint_name,
   s3_report_path,
   # record_preprocessor_source_uri=s3_code_preprocessor_uri,
   # post_analytics_source_uri=s3_code_postprocessor_uri,
   baseline_statistics_uri=baseline_results_uri + '/statistics.json',
   baseline_constraints_uri=baseline_results_uri+ '/constraints.json'
)

然后开始将假的数据发送到终端节点,即根据随机值构造的样本,然后等待 SageMaker 模型监控器开始生成报告。

检查报告

很快,可以看到报告已经出现在 S3 中。

mon_executions = sm_client.list_monitoring_executions(MonitoringScheduleName=mon_schedule_name, MaxResults=3)
for execution_summary in mon_executions['MonitoringExecutionSummaries']:
    print("ProcessingJob: {}".format(execution_summary['ProcessingJobArn'].split('/')[1]))
    print('MonitoringExecutionStatus: {} \n'.format(execution_summary['MonitoringExecutionStatus']))
 
ProcessingJob: model-monitoring-201911221050-df2c7fc4
MonitoringExecutionStatus: Completed 
 
ProcessingJob: model-monitoring-201911221040-3a738dd7
MonitoringExecutionStatus: Completed 
 
ProcessingJob: model-monitoring-201911221030-83f15fb9
MonitoringExecutionStatus: Completed 

让我们找到其中一个监控作业的报告。

desc_analytics_job_result=sm_client.describe_processing_job(ProcessingJobName=job_name)
report_uri=desc_analytics_job_result['ProcessingOutputConfig']['Outputs'][0]['S3Output']['S3Uri']
print('Report Uri: {}'.format(report_uri))
 
Report Uri: s3://sagemaker-us-west-2-123456789012/sagemaker/DEMO-ModelMonitor/reports/2019112208-2019112209

然后看看这里面到底有什么:

aws s3 ls s3://sagemaker-us-west-2-123456789012/sagemaker/DEMO-ModelMonitor/reports/2019112208-2019112209/
 
constraint_violations.json
constraints.json
statistics.json

顾名思义,constraints.json 和 statistics.json 包含监控作业处理的数据样本的架构和统计信息。让我们直接打开第三个文件 constraints_violations.json!

violations": [{"feature_name":"State_AL","constraint_check_type":"data_type_check","description":"Value: 0.8 does not meet the constraint requirement! "}, {"feature_name":"Eve Mins","constraint_check_type":"baseline_drift_check","description":"Numerical distance: 0.2711598746081505 exceeds numerical threshold: 0"}, {"feature_name":"CustServ Calls","constraint_check_type":"baseline_drift_check","description":"Numerical distance: 0.6470588235294117 exceeds numerical threshold: 0"}

糟糕!我们好像给整数型特征分配了浮点值,结果可想而知……

一些特征还表现出了漂移现象,这也不是一个让人乐见的情况。也许数据提取过程出了点问题,也许数据的分布实际上已经改变而需要重新训练模型。由于所有这些信息都可以作为 CloudWatch 指标提供,因此我们可以定义阈值,设置警报,甚至自动触发新的训练作业。

Amazon SageMaker 模型监控器易于设置,可帮助我们快速了解 ML 模型中存在的质量问题。

目前,该功能已在提供 Amazon SageMaker 的所有商业区域推出。此功能还集成在了 Amazon SageMaker Studio(AWS 的 ML 项目工作台)中。并且别忘了:所有这些信息都可以在笔记本中查看并进一步分析。

敬请大家亲自尝试!

正文完
 0