关于机器学习:利用-BERT-模型解析电子病历

54次阅读

共计 16925 个字符,预计需要花费 43 分钟才能阅读完成。

我的项目原始地址

我的项目地址
本我的项目改编自此 Github 我的项目,鸣谢作者。

问题形容

咱们心愿能从患者住院期间的临床记录来预测该患者将来 30 天内是否会再次入院,该预测能够辅助医生更好的抉择医治计划并对手术危险进行评估。在临床中医治伎俩常见而预后状况难以管制治理的状况不足为奇。比方关节置换手术作为医治老年骨性关节炎等疾病的最终办法在临床中获得了极大胜利,然而与手术相干的并发症以及由此导致的再入院状况也并不少见。患者的本身因素如心脏病、糖尿病、瘦削等状况也会减少关节置换术后的再入院危险。当承受关节置换手术的人群的年龄越来越大,健康状况越来越差的状况下,会呈现更多的并发症并且减少再次入院危险。
通过电子病历的相干记录,察看到对于某些疾病或者手术来说,30 天内再次入院的患者各方面的危险都明显增加。因而对与前次住院起因雷同,且前次入院与下次入院距离未超过 30 天的再一次住院视为同一次住院的状况进行了筛选标注,训练模型来尝试解决这个问题。

数据选取与数据荡涤

选取于 Medical Information Mart for Intensive Care III 数据集,也称 MIMIC-III,是在 NIH 赞助下,由 MIT、哈佛医学院 BID 医学中心、飞利浦医疗联合开发保护的多参数重症监护数据库。该数据集收费向钻研人员凋谢,然而须要进行申请。咱们在进行试验的时候将数据部署在 Postgre SQL 中。首先从 admission 表中取出所有数据,针对每一条记录计算同个 subject_id 下一次呈现时的工夫距离,若小于 30 天则给该条记录增加标签 Label=1,否则 Label=0。而后再计算该次住院的时长(入院日期 - 入院日期),并抽取其中住院时长 >2 的样本。将上述抽出的所有样本的 HADM_ID 依照 0.8:0.1:0.1 的比例随机调配造成训练集、验证集和测试集。之后再从 noteevents 表中依照之前调配好的 HADM_ID 获取各个数据集的文本内容(即表 noteevents 中的 TEXT 列)。整顿好的训练集、验证集和测试集均含有三列,别离为 TEXT(文本内容),ID(即 HADM_ID),Label(0 或 1)。

预训练模型

原我的项目应用的预训练模型。基于 BERT 训练。在 NLP(自然语言解决)畛域 BERT 模型有着里程碑式的意义。2018 年的 10 月 11 日,Google 公布的论文《Pre-training of Deep Bidirectional Transformers for Language Understanding》,胜利在 11 项 NLP 工作中获得 state of the art 的后果,博得自然语言解决学界的一片赞美之声。BERT 模型在文本分类、文本预测等多个畛域都获得了很好的成果。
更多对于 BERT 模型的内容可参考链接

BERT 算法的原理次要由两局部组成:

  • 第一步,通过对大量未标注的语料进行非监督的预训练,来学习其中的表达法。
  • 其次,应用大量标记的训练数据以监督形式微调(fine tuning)预训练模型以进行各种监督工作。

ClinicalBERT 模型依据含有标记的临床记录对 BERT 模型进行微调,从而失去一个能够用于医疗畛域文本剖析的模型。细节请参考原我的项目链接

环境装置

!pip install -U pytorch-pretrained-bert -i https://pypi.tuna.tsinghua.edu.cn/simple
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity='all'

数据查看

让咱们来看看被预测的数据是什么格局

import pandas as pd
sample = pd.read_csv('/home/input/MIMIC_note3519/BERT/sample.csv')
sample
TEXT ID Label
0 Nursing Progress Note 1900-0700 hours:\n** Ful… 176088 1
1 Nursing Progress Note 1900-0700 hours:\n** Ful… 135568 1
2 NPN:\n\nNeuro: Alert and oriented X2-3, Sleepi… 188180 0
3 RESPIRATORY CARE:\n\n35 yo m adm from osh for … 110655 0
4 NEURO: A+OX3 pleasant, mae, following commands… 139362 0
5 Nursing Note\nSee Flowsheet\n\nNeuro: Propofol… 176981 0

能够看到在 TEXT 字段下寄存了几条非构造的文本数据,让咱们来取出一条看看在说什么。

text = sample['TEXT'][0]
print(text)
Nursing Progress Note 1900-0700 hours:
** Full code

** allergy: nkda

** access: #18 piv to right FA, #18 piv to right FA.

** diagnosis: angioedema

In Brief: Pt is a 51yo F with pmh significant for: COPD, HTN, diabetes insipidus, hypothyroidism, OSA (on bipap at home), restrictive lung disease, pulm artery hypertension attributed to COPD/OSA, ASD with shunt, down syndrome, CHF with LVEF >60%. Also, 45pk-yr smoker (quit in [**2112**]).

Pt brought to [**Hospital1 2**] by EMS after family found with decreased LOC.  Pt presented with facial swelling and mental status changes. In [**Name (NI) **], pt with enlarged lips and with sats 99% on 2-4l.  Her pupils were pinpoint so given narcan.  She c/o LLQ abd pain and also developed a severe HA.  ABG with profound resp acidosis 7.18/108/71.  Given benadryl, nebs, solumedrol. Difficult intubation-req'd being taken to OR to have fiberoptic used.  Also found to have ARF.  On admit to ICU-denied pain in abdomen, denied HA.  Denied any pain. Pt understands basic english but also used [**Name (NI) **] interpretor to determine these findings. Head CT on [**Name6 (MD) **] [**Name8 (MD) 20**] md as pt was able to nod yes and no and follow commands.

NEURO: pt is sedate on fent at 50mcg/hr and versed at 0.5mg/hr-able to arouse on this level of sedation.  PEARL 2mm/brisk. Able to move all ext's, nod yes and no to questions.  Occasional cough.

CARDIAC: sb-nsr with hr high 50's to 70's.  Ace inhibitors  (pt takes at home) on hold right now as unclear as to what meds or other cause of angioedema.  no ectopy.  SBP >100 with MAPs > 60.

RESP: nasally intubated. #6.0 tube which is sutured in place.  Confirmed by xray for proper placement (5cm above carina). ** some resp events overnight: on 3 occasions thus far, pt noted to have vent alarm 'apnea' though on AC mode and then alarms 'pressure limited/not constant'.  At that time-pt appears comfortably sedate (not bucking vent) but dropping TV's into 100's (from 400's), MV to 3.0 and then desats to 60's and 70's with no chest rise and fall noted. Given 100% 02 first two times with immediate elevation of o2 sat to >92%.  The third time RT ambubagged to see if it was difficult-also climbed right back up to sat >93%.   Suctioned for scant sputum only.  ? as to whether tube was kinking off in trachea or occluding somehow.  RT also swapped out the vent for a new one in case [**Last Name **] problem.  Issue did occur again with new vent (so ruled out a [**Last Name **] problem). Several ABGs overnight (see carevue) which last abg stable. Current settings: 50%/ tv 400/ac 22/p5. Lungs with some rhonchi-received MDI's/nebs overnight. IVF infusing (some risk for chf) Sats have been >93% except for above events. cont to assess.

GI/GU: abd soft, distended, obese. two small bm's this shift-brown, soft, loose. Pt without FT and unlikely to have one placed [**3-3**] edema.  IVF started for ARF and [**3-3**] without nutrition. Foley in place draining clear, yellow 25-80cc/hr.

ID: initial wbc of 12. Pt spiked temp overnight to 102.1-given tylenol supp (last temp 101.3) and pan cx'd.  no abx at this time.

[**Month/Day (2) **]: fs wnl

文本内容

能够看到是一段 ICU 的护理日记,是一个 51 岁的女性,有慢性阻塞性肺疾病,高血压,甲减,唐氏综合征,先心房缺,慢性心衰,肺动脉低压,睡眠呼吸暂停综合症等多种疾病。被家人发现昏迷后送医,是重大的过敏反应,急性血管水肿。处于慌张状态有轻微意识。她在医治过的过程中产生过好凝,做过溶拴还产生过急性肾衰竭。

模型推理

批改当前工作门路

import os
os.chdir('/home/work/clinicalBERT')

根底类定义

每个类的阐明见正文

import csv
import pandas as pd


class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs a InputExample.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""
    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with open(input_file, "r") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for line in reader:
                lines.append(line)
            return lines

    @classmethod
    def _read_csv(cls, input_file):
        """Reads a comma separated value file."""
        file = pd.read_csv(input_file)
        lines = zip(file.ID, file.TEXT, file.Label)
        return lines

定义数据读取与解决类

继承自基类

def create_examples(lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
        guid = "%s-%s" % (set_type, i)
        text_a = line[1]
        label = str(int(line[2]))
        examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
    return examples


class ReadmissionProcessor(DataProcessor):
    def get_test_examples(self, data_dir):
        return create_examples(self._read_csv(os.path.join(data_dir, "sample.csv")), "test")

    def get_labels(self):
        return ["0", "1"]

定义脚手架函数

  • truncate_seq_pair
  • convert_examples_to_features
  • vote_score
  • pr_curve_plot
  • vote_pr_curve
def truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()
# 将文件载入,并且转换为张量
import logging
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
    """Loads a data file into a list of `InputBatch`s."""

    label_map = {}
    for (i, label) in enumerate(l        label_id = label_map[example.label]
        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
…                              label_id=label_id))
    return featuresabel_list):
        label_map[label] = i

    features = []
    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)

        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > max_seq_length - 2:
                tokens_a = tokens_a[0:(max_seq_length - 2)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids: 0   0   0   0  0     0 0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambigiously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the first vector (corresponding to [CLS]) is
        # used as as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.
        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        #print (example.label)
        label_id = label_map[example.label]
        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            logger.info("label: %s (id = %d)" % (example.label, label_id))

        features.append(
                InputFeatures(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              label_id=label_id))
    return features
# 准确率曲线与绘图
import numpy as np
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt


def vote_score(df, score, ax):
    df['pred_score'] = score
    df_sort = df.sort_values(by=['ID'])
    # score
    temp = (df_sort.groupby(['ID'])['pred_score'].agg(max) + df_sort.groupby(['ID'])['pred_score'].agg(sum) / 2) / (1 + df_sort.groupby(['ID'])['pred_score'].agg(len) / 2)
    x = df_sort.groupby(['ID'])['Label'].agg(np.min).values
    df_out = pd.DataFrame({'logits': temp.values, 'ID': x})

    fpr, tpr, thresholds = roc_curve(x, temp.values)
    auc_score = auc(fpr, tpr)

    ax.plot([0, 1], [0, 1], 'k--')
    ax.plot(fpr, tpr, label='Val (area = {:.3f})'.format(auc_score))
    ax.set_xlabel('False positive rate')
    ax.set_ylabel('True positive rate')
    ax.set_title('ROC curve')
    ax.legend(loc='best')
    return fpr, tpr, df_out
from sklearn.metrics import precision_recall_curve
from funcsigs import signature


def pr_curve_plot(y, y_score, ax):
    precision, recall, _ = precision_recall_curve(y, y_score)
    area = auc(recall, precision)
    step_kwargs = ({'step': 'post'}
                   if 'step' in signature(plt.fill_between).parameters
                   else {})

    ax.step(recall, precision, color='b', alpha=0.2,
             where='post')
    ax.fill_between(recall, precision, alpha=0.2, color='b', **step_kwargs)
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_ylim([0.0, 1.05])
    ax.set_xlim([0.0, 1.0])
    ax.set_title('Precision-Recall curve: AUC={0:0.2f}'.format(area))

def vote_pr_curve(df, score, ax):
    df['pred_score'] = score
    df_sort = df.sort_values(by=['ID'])
    # score
    temp = (df_sort.groupby(['ID'])['pred_score'].agg(max) + df_sort.groupby(['ID'])['pred_score'].agg(sum) / 2) / (1 + df_sort.groupby(['ID'])['pred_score'].agg(len) / 2)
    y = df_sort.groupby(['ID'])['Label'].agg(np.min).values

    precision, recall, thres = precision_recall_curve(y, temp)
    pr_thres = pd.DataFrame(data=list(zip(precision, recall, thres)), columns=['prec', 'recall', 'thres'])

    pr_curve_plot(y, temp, ax)

    temp = pr_thres[pr_thres.prec > 0.799999].reset_index()

    rp80 = 0
    if temp.size == 0:
        print('Test Sample too small or RP80=0')
    else:
        rp80 = temp.iloc[0].recall
        print(f'Recall at Precision of 80 is {rp80}')

    return rp80

配置推理参数

  • output_dir: 输入文件的目录
  • task_name: 工作名称
  • bert_model: 模型目录
  • data_dir: 数据目录,默认文件名称为 sample.csv
  • max_seq_length: 最大字符串序列长度
  • eval_batch_size: 推理批的大小,越大占内存越大
config = {
    "local_rank": -1,
    "no_cuda": False,
    "seed": 42,
    "output_dir": './result',
    "task_name": 'readmission',
    "bert_model": '/home/input/MIMIC_note3519/BERT/early_readmission',
    "fp16": False,
    "data_dir": '/home/input/MIMIC_note3519/BERT',
    "max_seq_length": 512,
    "eval_batch_size": 2,
}

执行推理

推理过程会产生大量日志,能够通过抉择以后 cell(抉择后 cell 左侧会变为蓝色),按下键盘上的“O”键来暗藏日志输入

import random
from tqdm import tqdm
from pytorch_pretrained_bert.tokenization import BertTokenizer
from modeling_readmission import BertForSequenceClassification
from torch.utils.data import TensorDataset, SequentialSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch


processors = {"readmission": ReadmissionProcessor}

if config['local_rank'] == -1 or config['no_cuda']:
    device = torch.device("cuda" if torch.cuda.is_available() and not config['no_cuda'] else "cpu")
    n_gpu = torch.cuda.device_count()
else:
    device = torch.device("cuda", config['local_rank'])
    n_gpu = 1
    # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    torch.distributed.init_process_group(backend='nccl')
logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(config['local_rank'] != -1))


random.seed(config['seed'])
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])
if n_gpu > 0:
    torch.cuda.manual_seed_all(config['seed'])


if os.path.exists(config['output_dir']):
    pass
else:
    os.makedirs(config['output_dir'], exist_ok=True)

task_name = config['task_name'].lower()

if task_name not in processors:
    raise ValueError(f"Task not found: {task_name}")

processor = processors[task_name]()
label_list = processor.get_labels()

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Prepare model
model = BertForSequenceClassification.from_pretrained(config['bert_model'], 1)
if config['fp16']:
    model.half()
model.to(device)
if config['local_rank'] != -1:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config['local_rank']],
                                                      output_device=config['local_rank'])
elif n_gpu > 1:
    model = torch.nn.DataParallel(model)

eval_examples = processor.get_test_examples(config['data_dir'])
eval_features = convert_examples_to_features(eval_examples, label_list, config['max_seq_length'], tokenizer)
logger.info("***** Running evaluation *****")
logger.info("Num examples = %d", len(eval_examples))
logger.info("Batch size = %d", config['eval_batch_size'])
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if config['local_rank'] == -1:
    eval_sampler = SequentialSampler(eval_data)
else:
    eval_sampler = DistributedSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=config['eval_batch_size'])
model.eval()
eval_loss, eval_accuracy = 0, 0
nb_eval_steps, nb_eval_examples = 0, 0
true_labels = []
pred_labels = []
logits_history = []
m = torch.nn.Sigmoid()
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader):
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    segment_ids = segment_ids.to(device)
    label_ids = label_ids.to(device)
    with torch.no_grad():
        tmp_eval_loss, temp_logits = model(input_ids, segment_ids, input_mask, label_ids)
        logits = model(input_ids, segment_ids, input_mask)

    logits = torch.squeeze(m(logits)).detach().cpu().numpy()
    label_ids = label_ids.to('cpu').numpy()

    outputs = np.asarray([1 if i else 0 for i in (logits.flatten() >= 0.5)])
    tmp_eval_accuracy = np.sum(outputs == label_ids)

    true_labels = true_labels + label_ids.flatten().tolist()
    pred_labels = pred_labels + outputs.flatten().tolist()
    logits_history = logits_history + logits.flatten().tolist()

    eval_loss += tmp_eval_loss.mean().item()
    eval_accuracy += tmp_eval_accuracy

    nb_eval_examples += input_ids.size(0)
    nb_eval_steps += 1

### 绘制精度评估曲线

df = pd.DataFrame({'logits': logits_history, 'pred_label': pred_labels, 'label': true_labels})
df_test = pd.read_csv(os.path.join(config['data_dir'], "sample.csv"))

fig = plt.figure(1)
ax1 = fig.add_subplot(1,2,1)
ax2 = fig.add_subplot(1,2,2)
fpr, tpr, df_out = vote_score(df_test, logits_history, ax1)
rp80 = vote_pr_curve(df_test, logits_history, ax2)

output_eval_file = os.path.join(config['output_dir'], "eval_results.txt")
plt.tight_layout()
plt.show()

将推理信息保留至输入目录

eval_loss = eval_loss / nb_eval_steps
eval_accuracy = eval_accuracy / nb_eval_examples
result = {'eval_loss': eval_loss,
          'eval_accuracy': eval_accuracy,
          'RP80': rp80}
with open(output_eval_file, "w") as writer:
    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info("%s = %s", key, str(result[key]))
      writer.write("%s = %s\n" % (key, str(result[key])))

小结

通过 ICU 的医疗日记,能够晓得患者的丰盛的体征、病史等信息。通过这个模型能够无效预测该患者是否还会住院.

代码已提交至 Github
更多内容请关注我的集体博客

正文完
 0