# BiLSTM+CRF sequence-labelling model, its Viterbi decoder, and the CRF training loss.
import torch
import torch.nn as nn
from modelgraph.BILSTM import BiLSTM
from itertools import zip_longest
class BiLSTM_CRF(nn.Module):
    """BiLSTM emission scorer combined with a learned CRF transition matrix.

    ``forward`` produces, for every position of every sentence, a
    ``[out_size, out_size]`` table whose entry ``[i, j]`` is the score of
    moving from tag ``i`` (previous position) to tag ``j`` (current position).
    ``test`` runs Viterbi decoding on those tables.
    """

    def __init__(self, vocab_size, emb_size, hidden_size, out_size):
        """Build the BiLSTM encoder and the tag-transition parameter.

        Args:
            vocab_size: forwarded to ``BiLSTM``.
            emb_size: forwarded to ``BiLSTM``.
            hidden_size: forwarded to ``BiLSTM``.
            out_size: number of tags; also the side of the square
                transition matrix.
        """
        super().__init__()
        self.bilstm = BiLSTM(vocab_size, emb_size, hidden_size, out_size)
        # Every transition starts with the same score, 1 / out_size.
        self.transition = nn.Parameter(
            torch.ones(out_size, out_size) * 1 / out_size
        )

    def forward(self, sents_tensor, lengths):
        """Return CRF scores of shape ``[B, L, out_size, out_size]``.

        ``scores[b, t, i, j] = emission[b, t, j] + transition[i, j]`` where
        ``emission`` is the ``[B, L, out_size]`` output of the BiLSTM.
        """
        emission = self.bilstm(sents_tensor, lengths)
        _, _, out_size = emission.size()
        # Broadcast the emission over the "previous tag" axis, then add the
        # transition matrix (broadcast over batch and time).
        crf_scores = (
            emission.unsqueeze(2).expand(-1, -1, out_size, -1)
            + self.transition.unsqueeze(0)
        )
        return crf_scores

    def test(self, test_sents_tensor, lengths, tag2id):
        """Decode the best tag sequence for each sentence with Viterbi.

        Args:
            test_sents_tensor: batch of sentences, passed to ``forward``.
            lengths: per-sentence lengths. The ``[:batch_size_t]`` slicing
                used below relies on the batch being sorted by decreasing
                length.
            tag2id: mapping that must contain ``'<start>'``, ``'<end>'`` and
                ``'<pad>'``.

        Returns:
            A CPU ``LongTensor`` of shape ``[B, L - 1]`` holding the decoded
            tag ids for every position except the last one of each sentence
            (the last position is taken to be ``<end>``); shorter sentences
            are right-padded with the ``<pad>`` id. When ``L == 1`` the
            result is an empty tensor.

        Raises:
            KeyError: if one of the special tags is missing from ``tag2id``.
            RuntimeError: propagated from the back-pointer lookup instead of
                dropping into a debugger.
        """
        start_id = tag2id['<start>']
        end_id = tag2id['<end>']
        pad = tag2id['<pad>']

        crf_scores = self.forward(test_sents_tensor, lengths)
        # Work on whatever device the scores live on, rather than guessing
        # from CUDA availability (which breaks a CPU model on a GPU host).
        device = crf_scores.device
        B, L, T, _ = crf_scores.size()

        # viterbi[b, t, k]: best score of any path ending in tag k at step t.
        viterbi = torch.zeros(B, L, T, device=device)
        # backpointer[b, t, k]: tag at step t-1 on that best path.
        backpointer = torch.zeros(B, L, T, dtype=torch.long, device=device)
        lengths = torch.as_tensor(lengths, dtype=torch.long, device=device)

        # Forward pass.
        for step in range(L):
            # Number of sentences that still have a token at this step.
            batch_size_t = (lengths > step).sum().item()
            if step == 0:
                # The only possible predecessor of the first token is <start>.
                viterbi[:batch_size_t, step, :] = \
                    crf_scores[:batch_size_t, step, start_id, :]
                backpointer[:batch_size_t, step, :] = start_id
            else:
                max_scores, prev_tags = torch.max(
                    viterbi[:batch_size_t, step - 1, :].unsqueeze(2)
                    + crf_scores[:batch_size_t, step, :, :],
                    dim=1,
                )
                viterbi[:batch_size_t, step, :] = max_scores
                backpointer[:batch_size_t, step, :] = prev_tags

        # Backward pass: flatten to [B, L * T] so entry (step, tag) sits at
        # flat index step * T + tag.
        backpointer = backpointer.view(B, -1)
        tagids = []
        tags_t = None
        for step in range(L - 1, 0, -1):
            batch_size_t = (lengths > step).sum().item()
            base = torch.full(
                (batch_size_t,), step * T, dtype=torch.long, device=device
            )
            if step == L - 1:
                # Longest sentences: their final tag is <end>.
                index = base + end_id
            else:
                # Sentences whose last token is at this step start from <end>;
                # the others continue from the tag recovered one step later.
                new_in_batch = torch.full(
                    (batch_size_t - len(tags_t),),
                    end_id,
                    dtype=torch.long,
                    device=device,
                )
                offset = torch.cat([tags_t, new_in_batch], dim=0)
                index = base + offset.long()

            tags_t = backpointer[:batch_size_t].gather(
                dim=1, index=index.unsqueeze(1)
            ).squeeze(1)
            tagids.append(tags_t.tolist())

        # tagids is ordered last step -> first step, one list per step;
        # reverse it and transpose to one row per sentence, padding short rows.
        tagids = list(zip_longest(*reversed(tagids), fillvalue=pad))
        return torch.tensor(tagids, dtype=torch.long)
def cal_lstm_crf_loss(crf_scores, targets, tag2id):
    """Compute the batch-averaged CRF negative log-likelihood.

    The loss is ``(sum of log-partition scores - sum of gold-path scores)
    / batch_size``.

    Args:
        crf_scores: tensor of shape ``[B, L, T, T]`` where entry
            ``[b, t, i, j]`` scores moving from tag ``i`` to tag ``j`` at
            step ``t``.
        targets: ``[B, L]`` integer tensor of gold tag ids, padded with the
            ``<pad>`` id. The ``[:batch_size_t]`` slicing below relies on the
            batch being sorted by decreasing length. The tensor is not
            modified.
        tag2id: mapping that must contain ``'<pad>'``, ``'<start>'`` and
            ``'<end>'``; its size must equal ``T``.

    Returns:
        A scalar tensor holding the loss.

    Raises:
        KeyError: if one of the special tags is missing from ``tag2id``.
    """
    pad_id = tag2id['<pad>']
    start_id = tag2id['<start>']
    end_id = tag2id['<end>']

    # Follow the scores' device instead of guessing from CUDA availability.
    device = crf_scores.device

    batch_size, max_len = targets.size()
    target_size = len(tag2id)

    # Mask of real (non-pad) positions and the resulting sentence lengths.
    mask = (targets != pad_id)
    lengths = mask.sum(dim=1)

    # --- Gold path score -------------------------------------------------
    # `indexed` rewrites its argument in place, so hand it a copy to keep the
    # caller's targets intact. Afterwards each entry is
    # prev_tag * T + cur_tag, i.e. a flat index into one [T, T] score table.
    flat_targets = indexed(targets.clone(), target_size, start_id)
    flat_targets = flat_targets.masked_select(mask)

    flatten_scores = crf_scores.masked_select(
        mask.view(batch_size, max_len, 1, 1).expand_as(crf_scores)
    ).view(-1, target_size * target_size).contiguous()

    golden_scores = flatten_scores.gather(
        dim=1, index=flat_targets.unsqueeze(1)
    ).sum()

    # --- Log-partition over all paths (forward algorithm) ----------------
    # scores_upto_t[b, k]: log-sum-exp of all partial paths of sentence b
    # that end in tag k at the current step.
    scores_upto_t = torch.zeros(batch_size, target_size, device=device)
    for t in range(max_len):
        # Sentences that still have a token at step t.
        batch_size_t = (lengths > t).sum().item()
        if t == 0:
            scores_upto_t[:batch_size_t] = \
                crf_scores[:batch_size_t, t, start_id, :]
        else:
            scores_upto_t[:batch_size_t] = torch.logsumexp(
                crf_scores[:batch_size_t, t, :, :]
                + scores_upto_t[:batch_size_t].unsqueeze(2),
                dim=1,
            )
    # Finished sentences are no longer updated, so each row still holds the
    # scores from that sentence's own last step; read the <end> column.
    all_path_scores = scores_upto_t[:, end_id].sum()

    loss = (all_path_scores - golden_scores) / batch_size
    return loss
def indexed(targets: torch.Tensor, tagset_size: int, start_id: int) -> torch.Tensor:
    """Turn tag ids into flat ``prev_tag * tagset_size + cur_tag`` indices.

    Column 0 uses ``start_id`` as its previous tag; every other column uses
    the original value of the column to its left.

    NOTE: ``targets`` is modified IN PLACE and the same tensor is returned.
    Pass ``targets.clone()`` if the original ids are still needed.

    Args:
        targets: ``[batch_size, max_len]`` integer tensor of tag ids.
        tagset_size: number of tags (the multiplier for the previous tag).
        start_id: tag id used as the predecessor of the first column.

    Returns:
        ``targets`` itself, now holding the flat indices.
    """
    # The right-hand side is materialised from the untouched ids before the
    # in-place add runs, so every column sees its left neighbour's ORIGINAL
    # value -- the same result as updating columns from right to left.
    targets[:, 1:] += targets[:, :-1] * tagset_size
    targets[:, 0] += start_id * tagset_size
    return targets