# coding=utf-8 # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Extract pre-computed feature vectors from a PyTorch BERT model.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import collections import logging import json import re import torch from torch.utils.data import TensorDataset, DataLoader, SequentialSampler from torch.utils.data.distributed import DistributedSampler from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.modeling import BertModel logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', level = logging.INFO) logger = logging.getLogger(__name__) class InputExample(object): def __init__(self, unique_id, text_a, text_b): self.unique_id = unique_id self.text_a = text_a self.text_b = text_b class InputFeatures(object): """A single set of features of data.""" def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): self.unique_id = unique_id self.tokens = tokens self.input_ids = input_ids self.input_mask = input_mask self.input_type_ids = input_type_ids def convert_examples_to_features(examples, seq_length, tokenizer): """Loads a data file into a list of `InputFeature`s.""" features = [] for (ex_index, example) in enumerate(examples): tokens_a = tokenizer.tokenize(example.text_a) tokens_b = None if example.text_b: tokens_b = tokenizer.tokenize(example.text_b) if tokens_b: # Modifies `tokens_a` and `tokens_b` in place so that the total # length is less than the specified length. # Account for [CLS], [SEP], [SEP] with "- 3" _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) else: # Account for [CLS] and [SEP] with "- 2" if len(tokens_a) > seq_length - 2: tokens_a = tokens_a[0:(seq_length - 2)] # The convention in BERT is: # (a) For sequence pairs: # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 # (b) For single sequences: # tokens: [CLS] the dog is hairy . [SEP] # type_ids: 0 0 0 0 0 0 0 # # Where "type_ids" are used to indicate whether this is the first # sequence or the second sequence. The embedding vectors for `type=0` and # `type=1` were learned during pre-training and are added to the wordpiece # embedding vector (and position vector). This is not *strictly* necessary # since the [SEP] token unambigiously separates the sequences, but it makes # it easier for the model to learn the concept of sequences. # # For classification tasks, the first vector (corresponding to [CLS]) is # used as as the "sentence vector". Note that this only makes sense because # the entire model is fine-tuned. tokens = [] input_type_ids = [] tokens.append("[CLS]") input_type_ids.append(0) for token in tokens_a: tokens.append(token) input_type_ids.append(0) tokens.append("[SEP]") input_type_ids.append(0) if tokens_b: for token in tokens_b: tokens.append(token) input_type_ids.append(1) tokens.append("[SEP]") input_type_ids.append(1) input_ids = tokenizer.convert_tokens_to_ids(tokens) # The mask has 1 for real tokens and 0 for padding tokens. Only real # tokens are attended to. input_mask = [1] * len(input_ids) # Zero-pad up to the sequence length. while len(input_ids) < seq_length: input_ids.append(0) input_mask.append(0) input_type_ids.append(0) assert len(input_ids) == seq_length assert len(input_mask) == seq_length assert len(input_type_ids) == seq_length if ex_index < 5: logger.info("*** Example ***") logger.info("unique_id: %s" % (example.unique_id)) logger.info("tokens: %s" % " ".join([str(x) for x in tokens])) logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) logger.info( "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) features.append( InputFeatures( unique_id=example.unique_id, tokens=tokens, input_ids=input_ids, input_mask=input_mask, input_type_ids=input_type_ids)) return features def _truncate_seq_pair(tokens_a, tokens_b, max_length): """Truncates a sequence pair in place to the maximum length.""" # This is a simple heuristic which will always truncate the longer sequence # one token at a time. This makes more sense than truncating an equal percent # of tokens from each, since if one sequence is very short then each token # that's truncated likely contains more information than a longer sequence. while True: total_length = len(tokens_a) + len(tokens_b) if total_length <= max_length: break if len(tokens_a) > len(tokens_b): tokens_a.pop() else: tokens_b.pop() def read_examples(input_file): """Read a list of `InputExample`s from an input file.""" examples = [] unique_id = 0 with open(input_file, "r", encoding='utf-8') as reader: while True: line = reader.readline() if not line: break line = line.strip() text_a = None text_b = None m = re.match(r"^(.*) \|\|\| (.*)$", line) if m is None: text_a = line else: text_a = m.group(1) text_b = m.group(2) examples.append( InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) unique_id += 1 return examples def main(): parser = argparse.ArgumentParser() ## Required parameters parser.add_argument("--input_file", default=None, type=str, required=True) parser.add_argument("--output_file", default=None, type=str, required=True) parser.add_argument("--bert_model", default=None, type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, " "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") ## Other parameters parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") parser.add_argument("--layers", default="-1,-2,-3,-4", type=str) parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after WordPiece tokenization. Sequences longer " "than this will be truncated, and sequences shorter than this will be padded.") parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.") parser.add_argument("--local_rank", type=int, default=-1, help = "local_rank for distributed training on gpus") parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") args = parser.parse_args() if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl') logger.info("device: {} n_gpu: {} distributed training: {}".format(device, n_gpu, bool(args.local_rank != -1))) layer_indexes = [int(x) for x in args.layers.split(",")] tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) examples = read_examples(args.input_file) features = convert_examples_to_features( examples=examples, seq_length=args.max_seq_length, tokenizer=tokenizer) unique_id_to_feature = {} for feature in features: unique_id_to_feature[feature.unique_id] = feature model = BertModel.from_pretrained(args.bert_model) model.to(device) if args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) elif n_gpu > 1: model = torch.nn.DataParallel(model) all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index) if args.local_rank == -1: eval_sampler = SequentialSampler(eval_data) else: eval_sampler = DistributedSampler(eval_data) eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size) model.eval() with open(args.output_file, "w", encoding='utf-8') as writer: for input_ids, input_mask, example_indices in eval_dataloader: input_ids = input_ids.to(device) input_mask = input_mask.to(device) all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask) all_encoder_layers = all_encoder_layers for b, example_index in enumerate(example_indices): feature = features[example_index.item()] unique_id = int(feature.unique_id) # feature = unique_id_to_feature[unique_id] output_json = collections.OrderedDict() output_json["linex_index"] = unique_id all_out_features = [] for (i, token) in enumerate(feature.tokens): all_layers = [] for (j, layer_index) in enumerate(layer_indexes): layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy() layer_output = layer_output[b] layers = collections.OrderedDict() layers["index"] = layer_index layers["values"] = [ round(x.item(), 6) for x in layer_output[i] ] all_layers.append(layers) out_features = collections.OrderedDict() out_features["token"] = token out_features["layers"] = all_layers all_out_features.append(out_features) output_json["features"] = all_out_features writer.write(json.dumps(output_json) + "\n") if __name__ == "__main__": main()