diff --git a/examples/lm_finetuning/finetune_on_pregenerated.py b/examples/lm_finetuning/finetune_on_pregenerated.py index 2a5783c261..8eda2aa5c5 100644 --- a/examples/lm_finetuning/finetune_on_pregenerated.py +++ b/examples/lm_finetuning/finetune_on_pregenerated.py @@ -14,8 +14,8 @@ from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME -from pytorch_pretrained_bert.modeling import BertForPreTraining -from pytorch_pretrained_bert.tokenization import BertTokenizer +from pytorch_pretrained_bert.modeling_bert import BertForPreTraining +from pytorch_pretrained_bert.tokenization_bert import BertTokenizer from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next") diff --git a/examples/lm_finetuning/pregenerate_training_data.py b/examples/lm_finetuning/pregenerate_training_data.py index 8bed1e54d4..c2211c88e6 100644 --- a/examples/lm_finetuning/pregenerate_training_data.py +++ b/examples/lm_finetuning/pregenerate_training_data.py @@ -5,7 +5,7 @@ from tempfile import TemporaryDirectory import shelve from random import random, randrange, randint, shuffle, choice -from pytorch_pretrained_bert.tokenization import BertTokenizer +from pytorch_pretrained_bert.tokenization_bert import BertTokenizer import numpy as np import json import collections diff --git a/examples/lm_finetuning/simple_lm_finetuning.py b/examples/lm_finetuning/simple_lm_finetuning.py index 368d6825c7..bcfd138442 100644 --- a/examples/lm_finetuning/simple_lm_finetuning.py +++ b/examples/lm_finetuning/simple_lm_finetuning.py @@ -30,8 +30,8 @@ from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm, trange from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME -from pytorch_pretrained_bert.modeling import BertForPreTraining -from pytorch_pretrained_bert.tokenization import BertTokenizer +from pytorch_pretrained_bert.modeling_bert import BertForPreTraining +from pytorch_pretrained_bert.tokenization_bert import BertTokenizer from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', diff --git a/examples/run_bert_classifier.py b/examples/run_bert_classifier.py index d987b35321..233a7ee5d1 100644 --- a/examples/run_bert_classifier.py +++ b/examples/run_bert_classifier.py @@ -35,8 +35,8 @@ from torch.nn import CrossEntropyLoss, MSELoss from tensorboardX import SummaryWriter from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME -from pytorch_pretrained_bert.modeling import BertForSequenceClassification -from pytorch_pretrained_bert.tokenization import BertTokenizer +from pytorch_pretrained_bert.modeling_bert import BertForSequenceClassification +from pytorch_pretrained_bert.tokenization_bert import BertTokenizer from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics diff --git a/examples/run_bert_extract_features.py b/examples/run_bert_extract_features.py index 13384a9d69..2a550c431a 100644 --- a/examples/run_bert_extract_features.py +++ b/examples/run_bert_extract_features.py @@ -28,8 +28,8 @@ import torch from torch.utils.data import TensorDataset, DataLoader, SequentialSampler from torch.utils.data.distributed import DistributedSampler -from 
pytorch_pretrained_bert.tokenization import BertTokenizer -from pytorch_pretrained_bert.modeling import BertModel +from pytorch_pretrained_bert.tokenization_bert import BertTokenizer +from pytorch_pretrained_bert.modeling_bert import BertModel logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', diff --git a/examples/run_bert_squad.py b/examples/run_bert_squad.py index 54eceb36f7..f8eee9c8eb 100644 --- a/examples/run_bert_squad.py +++ b/examples/run_bert_squad.py @@ -34,9 +34,9 @@ from tqdm import tqdm, trange from tensorboardX import SummaryWriter from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME -from pytorch_pretrained_bert.modeling import BertForQuestionAnswering +from pytorch_pretrained_bert.modeling_bert import BertForQuestionAnswering from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule -from pytorch_pretrained_bert.tokenization import BertTokenizer +from pytorch_pretrained_bert.tokenization_bert import BertTokenizer from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions diff --git a/examples/run_bert_swag.py b/examples/run_bert_swag.py index 28fd323c73..3e45225891 100644 --- a/examples/run_bert_swag.py +++ b/examples/run_bert_swag.py @@ -33,9 +33,9 @@ from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm, trange from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME -from pytorch_pretrained_bert.modeling import BertForMultipleChoice, BertConfig +from pytorch_pretrained_bert.modeling_bert import BertForMultipleChoice, BertConfig from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule -from pytorch_pretrained_bert.tokenization import BertTokenizer +from pytorch_pretrained_bert.tokenization_bert import BertTokenizer logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', diff --git a/examples/utils_squad.py b/examples/utils_squad.py index e4e43eff9d..0dfecd202c 100644 --- a/examples/utils_squad.py +++ b/examples/utils_squad.py @@ -24,7 +24,7 @@ import math import collections from io import open -from pytorch_pretrained_bert.tokenization import BasicTokenizer, whitespace_tokenize +from pytorch_pretrained_bert.tokenization_bert import BasicTokenizer, whitespace_tokenize logger = logging.getLogger(__name__) diff --git a/hubconfs/bert_hubconf.py b/hubconfs/bert_hubconf.py index 3769c2567f..94c7a18a30 100644 --- a/hubconfs/bert_hubconf.py +++ b/hubconfs/bert_hubconf.py @@ -1,5 +1,5 @@ -from pytorch_pretrained_bert.tokenization import BertTokenizer -from pytorch_pretrained_bert.modeling import ( +from pytorch_pretrained_bert.tokenization_bert import BertTokenizer +from pytorch_pretrained_bert.modeling_bert import ( BertModel, BertForNextSentencePrediction, BertForMaskedLM, diff --git a/notebooks/Comparing-TF-and-PT-models-MLM-NSP.ipynb b/notebooks/Comparing-TF-and-PT-models-MLM-NSP.ipynb index 67c56ead38..ea7271df96 100644 --- a/notebooks/Comparing-TF-and-PT-models-MLM-NSP.ipynb +++ b/notebooks/Comparing-TF-and-PT-models-MLM-NSP.ipynb @@ -3997,9 +3997,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "11/16/2018 11:03:05 - INFO - pytorch_pretrained_bert.modeling - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at 
/Users/thomaswolf/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba\n", - "11/16/2018 11:03:05 - INFO - pytorch_pretrained_bert.modeling - extracting archive file /Users/thomaswolf/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpaqgsm566\n", - "11/16/2018 11:03:08 - INFO - pytorch_pretrained_bert.modeling - Model config {\n", + "11/16/2018 11:03:05 - INFO - pytorch_pretrained_bert.modeling_bert - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /Users/thomaswolf/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba\n", + "11/16/2018 11:03:05 - INFO - pytorch_pretrained_bert.modeling_bert - extracting archive file /Users/thomaswolf/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpaqgsm566\n", + "11/16/2018 11:03:08 - INFO - pytorch_pretrained_bert.modeling_bert - Model config {\n", " \"attention_probs_dropout_prob\": 0.1,\n", " \"hidden_act\": \"gelu\",\n", " \"hidden_dropout_prob\": 0.1,\n", diff --git a/notebooks/Comparing-TF-and-PT-models.ipynb b/notebooks/Comparing-TF-and-PT-models.ipynb index 5e724a710a..3e438e2f55 100644 --- a/notebooks/Comparing-TF-and-PT-models.ipynb +++ b/notebooks/Comparing-TF-and-PT-models.ipynb @@ -375,8 +375,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling - loading archive file ../../google_models/uncased_L-12_H-768_A-12/\n", - "11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling - Model config {\n", + "11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling_bert - loading archive file ../../google_models/uncased_L-12_H-768_A-12/\n", + "11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling_bert - Model config {\n", " \"attention_probs_dropout_prob\": 0.1,\n", " \"hidden_act\": \"gelu\",\n", " \"hidden_dropout_prob\": 0.1,\n", diff --git a/pytorch_pretrained_bert/__init__.py b/pytorch_pretrained_bert/__init__.py index 7d823a045d..e14b8b27a9 100644 --- a/pytorch_pretrained_bert/__init__.py +++ b/pytorch_pretrained_bert/__init__.py @@ -1,11 +1,12 @@ __version__ = "0.6.2" -from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer +from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer from .tokenization_openai import OpenAIGPTTokenizer from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) from .tokenization_gpt2 import GPT2Tokenizer from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE +from .tokenization_xlm import XLMTokenizer -from .modeling import (BertConfig, BertModel, BertForPreTraining, +from .modeling_bert import (BertConfig, BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction, BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering, @@ -22,6 +23,9 @@ from .modeling_xlnet import (XLNetConfig, XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering, 
load_tf_weights_in_xlnet) +from .modeling_xlm import (XLMConfig, XLMModel, + XLMWithLMHeadModel, XLMForSequenceClassification, + XLMForQuestionAnswering) from .optimization import BertAdam from .optimization_openai import OpenAIAdam diff --git a/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py b/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py index 13d96384fd..42f7380969 100755 --- a/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py +++ b/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py @@ -25,7 +25,7 @@ import tensorflow as tf import torch import numpy as np -from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert +from pytorch_pretrained_bert.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): # Initialise PyTorch model diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling_bert.py similarity index 100% rename from pytorch_pretrained_bert/modeling.py rename to pytorch_pretrained_bert/modeling_bert.py diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py index fef4937400..774ba68509 100644 --- a/pytorch_pretrained_bert/modeling_gpt2.py +++ b/pytorch_pretrained_bert/modeling_gpt2.py @@ -32,7 +32,7 @@ from torch.nn.parameter import Parameter from .file_utils import cached_path from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer -from .modeling import BertLayerNorm as LayerNorm +from .modeling_bert import BertLayerNorm as LayerNorm logger = logging.getLogger(__name__) diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index f4fe09110a..7948a070bf 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -32,7 +32,7 @@ from torch.nn.parameter import Parameter from .file_utils import cached_path from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer -from .modeling import BertLayerNorm as LayerNorm +from .modeling_bert import BertLayerNorm as LayerNorm logger = logging.getLogger(__name__) diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py index 871f699b1a..9a882bce96 100644 --- a/pytorch_pretrained_bert/modeling_transfo_xl.py +++ b/pytorch_pretrained_bert/modeling_transfo_xl.py @@ -34,7 +34,7 @@ import torch.nn.functional as F from torch.nn import CrossEntropyLoss from torch.nn.parameter import Parameter -from .modeling import BertLayerNorm as LayerNorm +from .modeling_bert import BertLayerNorm as LayerNorm from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits from .file_utils import cached_path from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel diff --git a/pytorch_pretrained_bert/modeling_xlm.py b/pytorch_pretrained_bert/modeling_xlm.py index fa196215a5..66a0b0b1ed 100644 --- a/pytorch_pretrained_bert/modeling_xlm.py +++ b/pytorch_pretrained_bert/modeling_xlm.py @@ -71,7 +71,7 @@ class XLMConfig(PretrainedConfig): pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP def __init__(self, - vocab_size_or_config_json_file, + vocab_size_or_config_json_file=30145, n_special=0, emb_dim=2048, n_layers=12, @@ -80,13 +80,20 @@ class XLMConfig(PretrainedConfig): 
attention_dropout=0.1, gelu_activation=True, sinusoidal_embeddings=False, + causal=False, asm=False, - id2lang={ 0: "en" }, - lang2id={ "en": 0 }, n_langs=1, - n_words=30145, max_position_embeddings=512, - initializer_range=0.02, + embed_init_std=2048 ** -0.5, + init_std=0.02, + summary_type="last", + use_proj=True, + bos_index=0, + eos_index=1, + pad_index=2, + unk_index=3, + mask_index=5, + is_encoder=True, **kwargs): """Constructs XLMConfig. @@ -148,12 +155,20 @@ class XLMConfig(PretrainedConfig): self.attention_dropout = attention_dropout self.gelu_activation = gelu_activation self.sinusoidal_embeddings = sinusoidal_embeddings + self.causal = causal self.asm = asm - self.id2lang = id2lang - self.lang2id = lang2id self.n_langs = n_langs + self.summary_type = summary_type + self.use_proj = use_proj + self.bos_index = bos_index + self.eos_index = eos_index + self.pad_index = pad_index + self.unk_index = unk_index + self.mask_index = mask_index + self.is_encoder = is_encoder self.max_position_embeddings = max_position_embeddings - self.initializer_range = initializer_range + self.embed_init_std = embed_init_std + self.init_std = init_std else: raise ValueError("First argument must be either a vocabulary size (int)" "or the path to a pretrained model config file (str)") @@ -175,37 +190,21 @@ class XLMConfig(PretrainedConfig): return self.n_layers -try: - from apex.normalization.fused_layer_norm import FusedLayerNorm as XLMLayerNorm -except ImportError: - logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .") - class XLMLayerNorm(nn.Module): - def __init__(self, d_model, eps=1e-12): - """Construct a layernorm module in the TF style (epsilon inside the square root). - """ - super(XLMLayerNorm, self).__init__() - self.weight = nn.Parameter(torch.ones(d_model)) - self.bias = nn.Parameter(torch.zeros(d_model)) - self.variance_epsilon = eps - - def forward(self, x): - u = x.mean(-1, keepdim=True) - s = (x - u).pow(2).mean(-1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.variance_epsilon) - return self.weight * x + self.bias - - -def Embedding(num_embeddings, embedding_dim, padding_idx=None): +def Embedding(num_embeddings, embedding_dim, padding_idx=None, config=None): m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) - nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + if config is not None and config.embed_init_std is not None: + nn.init.normal_(m.weight, mean=0, std=config.embed_init_std) if padding_idx is not None: nn.init.constant_(m.weight[padding_idx], 0) return m -def Linear(in_features, out_features, bias=True): +def Linear(in_features, out_features, bias=True, config=None): m = nn.Linear(in_features, out_features, bias) - # nn.init.normal_(m.weight, mean=0, std=1) + if config is not None and config.init_std is not None: + nn.init.normal_(m.weight, mean=0, std=config.init_std) + if bias: + nn.init.constant_(m.bias, 0.) # nn.init.xavier_uniform_(m.weight) # nn.init.constant_(m.bias, 0.) return m @@ -233,14 +232,17 @@ def gelu(x): return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0))) -def get_masks(slen, lengths, causal): +def get_masks(slen, lengths, causal, padding_mask=None): """ Generate hidden states mask, and optionally an attention mask. 
""" - assert lengths.max().item() <= slen bs = lengths.size(0) - alen = torch.arange(slen, dtype=torch.long, device=lengths.device) - mask = alen < lengths[:, None] + if padding_mask is not None: + mask = padding_mask + else: + assert lengths.max().item() <= slen + alen = torch.arange(slen, dtype=torch.long, device=lengths.device) + mask = alen < lengths[:, None] # attention mask is the same as mask, or triangular inferior attention (causal) if causal: @@ -259,21 +261,21 @@ class MultiHeadAttention(nn.Module): NEW_ID = itertools.count() - def __init__(self, n_heads, dim, dropout, output_attentions=False): + def __init__(self, n_heads, dim, config): super().__init__() self.layer_id = next(MultiHeadAttention.NEW_ID) - self.output_attentions = output_attentions + self.output_attentions = config.output_attentions self.dim = dim self.n_heads = n_heads - self.dropout = dropout + self.dropout = config.attention_dropout assert self.dim % self.n_heads == 0 - self.q_lin = Linear(dim, dim) - self.k_lin = Linear(dim, dim) - self.v_lin = Linear(dim, dim) - self.out_lin = Linear(dim, dim) + self.q_lin = Linear(dim, dim, config=config) + self.k_lin = Linear(dim, dim, config=config) + self.v_lin = Linear(dim, dim, config=config) + self.out_lin = Linear(dim, dim, config=config) - def forward(self, input, mask, kv=None, cache=None): + def forward(self, input, mask, kv=None, cache=None, head_mask=None): """ Self-attention (if kv is None) or attention over source sentence (provided by kv). """ @@ -323,6 +325,11 @@ class MultiHeadAttention(nn.Module): weights = F.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen) weights = F.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen) + + # Mask heads if we want to + if head_mask is not None: + weights = weights * head_mask + context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) context = unshape(context) # (bs, qlen, dim) @@ -334,12 +341,12 @@ class MultiHeadAttention(nn.Module): class TransformerFFN(nn.Module): - def __init__(self, in_dim, dim_hidden, out_dim, dropout, gelu_activation): + def __init__(self, in_dim, dim_hidden, out_dim, config): super().__init__() - self.dropout = dropout - self.lin1 = Linear(in_dim, dim_hidden) - self.lin2 = Linear(dim_hidden, out_dim) - self.act = gelu if gelu_activation else F.relu + self.dropout = config.dropout + self.lin1 = Linear(in_dim, dim_hidden, config=config) + self.lin2 = Linear(dim_hidden, out_dim, config=config) + self.act = gelu if config.gelu_activation else F.relu def forward(self, input): x = self.lin1(input) @@ -365,12 +372,9 @@ class XLMPreTrainedModel(PreTrainedModel): """ Initialize the weights. 
""" if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, XLMLayerNorm): + # Weights are initialized in module instantiation (see above) + pass + if isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) @@ -439,8 +443,10 @@ class XLMModel(XLMPreTrainedModel): self.output_hidden_states = config.output_hidden_states # encoder / decoder, output layer - # self.is_encoder = is_encoder - # self.is_decoder = not is_encoder + self.is_encoder = config.is_encoder + self.is_decoder = not config.is_encoder + if self.is_decoder: + raise NotImplementedError("Currently XLM can only be used as an encoder") # self.with_output = with_output self.causal = config.causal @@ -450,10 +456,10 @@ class XLMModel(XLMPreTrainedModel): self.eos_index = config.eos_index self.pad_index = config.pad_index # self.dico = dico - self.id2lang = config.id2lang - self.lang2id = config.lang2id + # self.id2lang = config.id2lang + # self.lang2id = config.lang2id # assert len(self.dico) == self.n_words - assert len(self.id2lang) == len(self.lang2id) == self.n_langs + # assert len(self.id2lang) == len(self.lang2id) == self.n_langs # model parameters self.dim = config.emb_dim # 512 by default @@ -465,12 +471,12 @@ class XLMModel(XLMPreTrainedModel): assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads' # embeddings - self.position_embeddings = Embedding(config.max_position_embeddings, self.dim) + self.position_embeddings = Embedding(config.max_position_embeddings, self.dim, config=config) if config.sinusoidal_embeddings: create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight) if config.n_langs > 1: - self.lang_embeddings = Embedding(self.n_langs, self.dim) - self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index) + self.lang_embeddings = Embedding(self.n_langs, self.dim, config=config) + self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index, config=config) self.layer_norm_emb = nn.LayerNorm(self.dim, eps=1e-12) # transformer layers @@ -478,29 +484,31 @@ class XLMModel(XLMPreTrainedModel): self.layer_norm1 = nn.ModuleList() self.ffns = nn.ModuleList() self.layer_norm2 = nn.ModuleList() - if self.is_decoder: - self.layer_norm15 = nn.ModuleList() - self.encoder_attn = nn.ModuleList() + # if self.is_decoder: + # self.layer_norm15 = nn.ModuleList() + # self.encoder_attn = nn.ModuleList() for _ in range(self.n_layers): - self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout)) + self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config)) self.layer_norm1.append(nn.LayerNorm(self.dim, eps=1e-12)) - if self.is_decoder: - self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12)) - self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout)) - self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, dropout=self.dropout, gelu_activation=config.gelu_activation)) + # if self.is_decoder: + # self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12)) + # self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, 
dropout=self.attention_dropout)) + self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config)) self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12)) - def forward(self, input_ids, lengths, positions=None, langs=None, cache=None, head_mask=None): # src_enc=None, src_len=None, + def forward(self, input_ids, lengths=None, positions=None, langs=None, + token_type_ids=None, attention_mask=None, cache=None, head_mask=None): # src_enc=None, src_len=None, """ Inputs: `input_ids` LongTensor(bs, slen), containing word indices `lengths` LongTensor(bs), containing the length of each sentence - `causal` Boolean, if True, the attention is only done over previous hidden states `positions` LongTensor(bs, slen), containing word positions `langs` LongTensor(bs, slen), containing language IDs + `token_type_ids` LongTensor (bs, slen) same as `langs` used for compatibility """ - # lengths = (input_ids != self.pad_index).float().sum(dim=1) + if lengths is None: + lengths = (input_ids != self.pad_index).float().sum(dim=1) # mask = input_ids != self.pad_index # check inputs @@ -514,7 +522,7 @@ class XLMModel(XLMPreTrainedModel): # assert src_enc.size(0) == bs # generate masks - mask, attn_mask = get_masks(slen, lengths, self.causal) + mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask) # if self.is_decoder and src_enc is not None: # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None] @@ -527,10 +535,28 @@ class XLMModel(XLMPreTrainedModel): # positions = positions.transpose(0, 1) # langs + assert langs is None or token_type_ids is None, "You can only use one among langs and token_type_ids" + if token_type_ids is not None: + langs = token_type_ids if langs is not None: assert langs.size() == (bs, slen) # (slen, bs) # langs = langs.transpose(0, 1) + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen] + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(self.n_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer + head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.n_layers + # do not recompute cached elements if cache is not None: _slen = slen - cache['slen'] @@ -696,9 +722,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): ``` """ def __init__(self, config): - super(XLMLMHeadModel, self).__init__(config) - self.attn_type = config.attn_type - self.same_length = config.same_length + super(XLMWithLMHeadModel, self).__init__(config) self.transformer = XLMModel(config) self.pred_layer = XLMPredLayer(config) @@ -711,8 +735,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): """ self.pred_layer.proj.weight = self.transformer.embeddings.weight - def forward(self, input_ids, lengths, positions=None, langs=None, cache=None, - labels=None, head_mask=None): + def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None, + attention_mask=None, cache=None, labels=None, head_mask=None): """ Args: inp_k: int32 Tensor in shape [bsz, len], the 
input token IDs. @@ -739,7 +763,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): summary_type: str, "last", "first", "mean", or "attn". The method to pool the input to get a vector representation. """ - transformer_outputs = self.transformer(input_ids, lengths, positions=positions, langs=langs, cache=cache, head_mask=head_mask) + transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids, + langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask) output = transformer_outputs[0] logits = self.pred_layer(output, labels) @@ -759,14 +784,14 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): class XLMSequenceSummary(nn.Module): - def __init__(self, config, summary_type="last", use_proj=True): + def __init__(self, config): super(XLMSequenceSummary, self).__init__() - self.summary_type = summary_type - if use_proj: + self.summary_type = config.summary_type + if config.use_proj: self.summary = nn.Linear(config.d_model, config.d_model) else: self.summary = None - if summary_type == 'attn': + if config.summary_type == 'attn': # We should use a standard multi-head attention module with absolute positional embedding for that. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 # We can probably just use the multi-head attention module of PyTorch >=1.1.0 @@ -859,14 +884,13 @@ class XLMForSequenceClassification(XLMPreTrainedModel): super(XLMForSequenceClassification, self).__init__(config) self.transformer = XLMModel(config) - self.sequence_summary = XLMSequenceSummary(config) - self.logits_proj = nn.Linear(config.d_model, num_labels) + self.logits_proj = nn.Linear(config.d_model, config.num_labels) + self.apply(self.init_weights) - def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None, - mems=None, perm_mask=None, target_mapping=None, inp_q=None, - labels=None, head_mask=None): + def forward(self, input_ids, lengths=None, positions=None, langs=None, attention_mask=None, + cache=None, labels=None, head_mask=None): """ Args: inp_k: int32 Tensor in shape [bsz, len], the input token IDs. @@ -894,8 +918,8 @@ class XLMForSequenceClassification(XLMPreTrainedModel): Only used during pretraining for two-stream attention. Set to None during finetuning. 
""" - transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, - mems, perm_mask, target_mapping, inp_q, head_mask) + transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids, + langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask) output = transformer_outputs[0] output = self.sequence_summary(output) @@ -974,7 +998,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel): start_logits, end_logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, CONFIG_NAME): + def __init__(self, config): super(XLMForQuestionAnswering, self).__init__(config) self.transformer = XLMModel(config) @@ -982,12 +1006,11 @@ class XLMForQuestionAnswering(XLMPreTrainedModel): self.apply(self.init_weights) - def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None, - mems=None, perm_mask=None, target_mapping=None, inp_q=None, - start_positions=None, end_positions=None, head_mask=None): + def forward(self, input_ids, lengths=None, positions=None, langs=None, attention_mask=None, cache=None, + labels=None, head_mask=None): - transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, - mems, perm_mask, target_mapping, inp_q, head_mask) + transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids, + langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask) output = transformer_outputs[0] logits = self.qa_outputs(output) diff --git a/pytorch_pretrained_bert/tests/model_tests_commons.py b/pytorch_pretrained_bert/tests/model_tests_commons.py index da5d0f8b8a..1179b75368 100644 --- a/pytorch_pretrained_bert/tests/model_tests_commons.py +++ b/pytorch_pretrained_bert/tests/model_tests_commons.py @@ -36,7 +36,9 @@ def _create_and_check_initialization(tester, model_classes, config, inputs_dict) for model_class in model_classes: model = model_class(config=configs_no_init) for name, param in model.named_parameters(): - tester.parent.assertIn(param.data.mean().item(), [0.0, 1.0], msg="Parameter {} of model {} seems not properly initialized".format(name, model_class)) + if param.requires_grad: + tester.parent.assertIn(param.data.mean().item(), [0.0, 1.0], + msg="Parameter {} of model {} seems not properly initialized".format(name, model_class)) def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict): configs_no_init = _config_zero_init(config) diff --git a/pytorch_pretrained_bert/tests/model_utils_test.py b/pytorch_pretrained_bert/tests/model_utils_test.py index 76585453c8..59f076fa00 100644 --- a/pytorch_pretrained_bert/tests/model_utils_test.py +++ b/pytorch_pretrained_bert/tests/model_utils_test.py @@ -26,7 +26,7 @@ import pytest import torch from pytorch_pretrained_bert import PretrainedConfig, PreTrainedModel -from pytorch_pretrained_bert.modeling import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP +from pytorch_pretrained_bert.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP class ModelUtilsTest(unittest.TestCase): diff --git a/pytorch_pretrained_bert/tests/modeling_test.py b/pytorch_pretrained_bert/tests/modeling_bert_test.py similarity index 99% rename from pytorch_pretrained_bert/tests/modeling_test.py rename to pytorch_pretrained_bert/tests/modeling_bert_test.py index 2219ee7589..be5c3e090d 100644 --- 
a/pytorch_pretrained_bert/tests/modeling_test.py +++ b/pytorch_pretrained_bert/tests/modeling_bert_test.py @@ -16,20 +16,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import unittest -import json -import random import shutil import pytest -import torch - from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM, BertForNextSentencePrediction, BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, BertForMultipleChoice) -from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP +from pytorch_pretrained_bert.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor) diff --git a/pytorch_pretrained_bert/tests/modeling_xlm_test.py b/pytorch_pretrained_bert/tests/modeling_xlm_test.py new file mode 100644 index 0000000000..d2cf8235d4 --- /dev/null +++ b/pytorch_pretrained_bert/tests/modeling_xlm_test.py @@ -0,0 +1,276 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +import shutil +import pytest + +from pytorch_pretrained_bert import (XLMConfig, XLMModel, XLMForQuestionAnswering, XLMForSequenceClassification) +from pytorch_pretrained_bert.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP + +from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor) + + +class XLMModelTest(unittest.TestCase): + class XLMModelTester(object): + + def __init__(self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_lengths=True, + use_token_type_ids=True, + use_labels=True, + gelu_activation=True, + sinusoidal_embeddings=False, + causal=False, + asm=False, + n_langs=2, + vocab_size=99, + n_special=0, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + summary_type="last", + use_proj=True, + scope=None, + all_model_classes = (XLMModel,), # , XLMForSequenceClassification, XLMForTokenClassification), + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_lengths = use_input_lengths + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.gelu_activation = gelu_activation + self.sinusoidal_embeddings = sinusoidal_embeddings + self.asm = asm + self.n_langs = n_langs + self.vocab_size = vocab_size + self.n_special = n_special + self.summary_type = summary_type + self.causal = causal + self.use_proj = use_proj + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads 
= num_attention_heads + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.n_langs = n_langs + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.summary_type = summary_type + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.all_model_classes = all_model_classes + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_lengths = None + if self.use_input_lengths: + input_lengths = ids_tensor([self.batch_size], vocab_size=self.seq_length-1) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.n_langs) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = XLMConfig( + vocab_size_or_config_json_file=self.vocab_size, + n_special=self.n_special, + emb_dim=self.hidden_size, + n_layers=self.num_hidden_layers, + n_heads=self.num_attention_heads, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + gelu_activation=self.gelu_activation, + sinusoidal_embeddings=self.sinusoidal_embeddings, + asm=self.asm, + causal=self.causal, + n_langs=self.n_langs, + max_position_embeddings=self.max_position_embeddings, + initializer_range=self.initializer_range, + summary_type=self.summary_type, + use_proj=self.use_proj) + + return config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels + + def check_loss_output(self, result): + self.parent.assertListEqual( + list(result["loss"].size()), + []) + + def create_and_check_xlm_model(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels): + model = XLMModel(config=config) + model.eval() + outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids) + sequence_output = outputs[0] + result = { + "sequence_output": sequence_output, + } + self.parent.assertListEqual( + list(result["sequence_output"].size()), + [self.batch_size, self.seq_length, self.hidden_size]) + + + # def create_and_check_xlm_for_masked_lm(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels): + # model = XLMForMaskedLM(config=config) + # model.eval() + # loss, prediction_scores = model(input_ids, token_type_ids, input_lengths, token_labels) + # result = { + # "loss": loss, + # "prediction_scores": prediction_scores, + # } + # self.parent.assertListEqual( + # list(result["prediction_scores"].size()), + # [self.batch_size, self.seq_length, self.vocab_size]) + # self.check_loss_output(result) + + + # def create_and_check_xlm_for_question_answering(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels): + # model = XLMForQuestionAnswering(config=config) + # model.eval() + # loss, start_logits, end_logits = model(input_ids, token_type_ids, input_lengths, sequence_labels, sequence_labels) + # result = { + # "loss": loss, + # "start_logits": start_logits, + # "end_logits": end_logits, + # } + # self.parent.assertListEqual( + # 
list(result["start_logits"].size()), + # [self.batch_size, self.seq_length]) + # self.parent.assertListEqual( + # list(result["end_logits"].size()), + # [self.batch_size, self.seq_length]) + # self.check_loss_output(result) + + + # def create_and_check_xlm_for_sequence_classification(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels): + # config.num_labels = self.num_labels + # model = XLMForSequenceClassification(config) + # model.eval() + # loss, logits = model(input_ids, token_type_ids, input_lengths, sequence_labels) + # result = { + # "loss": loss, + # "logits": logits, + # } + # self.parent.assertListEqual( + # list(result["logits"].size()), + # [self.batch_size, self.num_labels]) + # self.check_loss_output(result) + + + # def create_and_check_xlm_for_token_classification(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels): + # config.num_labels = self.num_labels + # model = XLMForTokenClassification(config=config) + # model.eval() + # loss, logits = model(input_ids, token_type_ids, input_lengths, token_labels) + # result = { + # "loss": loss, + # "logits": logits, + # } + # self.parent.assertListEqual( + # list(result["logits"].size()), + # [self.batch_size, self.seq_length, self.num_labels]) + # self.check_loss_output(result) + + + # def create_and_check_xlm_for_multiple_choice(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels): + # config.num_choices = self.num_choices + # model = XLMForMultipleChoice(config=config) + # model.eval() + # multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + # multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + # multiple_choice_input_lengths = input_lengths.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + # loss, logits = model(multiple_choice_inputs_ids, + # multiple_choice_token_type_ids, + # multiple_choice_input_lengths, + # choice_labels) + # result = { + # "loss": loss, + # "logits": logits, + # } + # self.parent.assertListEqual( + # list(result["logits"].size()), + # [self.batch_size, self.num_choices]) + # self.check_loss_output(result) + + + def create_and_check_xlm_commons(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels): + inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_lengths} + create_and_check_commons(self, config, inputs_dict) + + def test_default(self): + self.run_tester(XLMModelTest.XLMModelTester(self)) + + def test_config(self): + config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37) + config_tester.run_common_tests() + + @pytest.mark.slow + def test_model_from_pretrained(self): + cache_dir = "/tmp/pytorch_pretrained_bert_test/" + for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir) + shutil.rmtree(cache_dir) + self.assertIsNotNone(model) + + def run_tester(self, tester): + config_and_inputs = tester.prepare_config_and_inputs() + tester.create_and_check_xlm_model(*config_and_inputs) + + # config_and_inputs = tester.prepare_config_and_inputs() + # tester.create_and_check_xlm_for_masked_lm(*config_and_inputs) + + # config_and_inputs = tester.prepare_config_and_inputs() + # tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs) + + # config_and_inputs = 
tester.prepare_config_and_inputs() + # tester.create_and_check_xlm_for_question_answering(*config_and_inputs) + + # config_and_inputs = tester.prepare_config_and_inputs() + # tester.create_and_check_xlm_for_sequence_classification(*config_and_inputs) + + # config_and_inputs = tester.prepare_config_and_inputs() + # tester.create_and_check_xlm_for_token_classification(*config_and_inputs) + + config_and_inputs = tester.prepare_config_and_inputs() + tester.create_and_check_xlm_commons(*config_and_inputs) + +if __name__ == "__main__": + unittest.main() diff --git a/pytorch_pretrained_bert/tests/tokenization_test.py b/pytorch_pretrained_bert/tests/tokenization_bert_test.py similarity index 98% rename from pytorch_pretrained_bert/tests/tokenization_test.py rename to pytorch_pretrained_bert/tests/tokenization_bert_test.py index 249f71f984..e00771c1b1 100644 --- a/pytorch_pretrained_bert/tests/tokenization_test.py +++ b/pytorch_pretrained_bert/tests/tokenization_bert_test.py @@ -20,7 +20,7 @@ from io import open import shutil import pytest -from pytorch_pretrained_bert.tokenization import (BasicTokenizer, +from pytorch_pretrained_bert.tokenization_bert import (BasicTokenizer, BertTokenizer, WordpieceTokenizer, _is_control, _is_punctuation, diff --git a/pytorch_pretrained_bert/tests/tokenization_xlm_test.py b/pytorch_pretrained_bert/tests/tokenization_xlm_test.py new file mode 100644 index 0000000000..d288f2fe60 --- /dev/null +++ b/pytorch_pretrained_bert/tests/tokenization_xlm_test.py @@ -0,0 +1,79 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import unittest +import json +import shutil +import pytest + +from pytorch_pretrained_bert.tokenization_xlm import XLMTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP + + +class XLMTokenizationTest(unittest.TestCase): + + def test_full_tokenizer(self): + """ Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt """ + vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", + "w", "r", "t", + "lo", "low", "er", + "low", "lowest", "newer", "wider"] + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + merges = ["l o 123", "lo w 1456", "e r 1789", ""] + with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp: + fp.write(json.dumps(vocab_tokens)) + vocab_file = fp.name + with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp: + fp.write("\n".join(merges)) + merges_file = fp.name + + tokenizer = XLMTokenizer(vocab_file, merges_file, special_tokens=["", ""]) + os.remove(vocab_file) + os.remove(merges_file) + + text = "lower" + bpe_tokens = ["low", "er"] + tokens = tokenizer.tokenize(text) + self.assertListEqual(tokens, bpe_tokens) + + input_tokens = tokens + [""] + input_bpe_tokens = [14, 15, 20] + self.assertListEqual( + tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + + vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/") + tokenizer_2 = XLMTokenizer.from_pretrained("/tmp/") + os.remove(vocab_file) + os.remove(merges_file) + os.remove(special_tokens_file) + + self.assertListEqual( + [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks, + tokenizer.special_tokens, tokenizer.special_tokens_decoder], + [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks, + tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder]) + + @pytest.mark.slow + def test_tokenizer_from_pretrained(self): + cache_dir = "/tmp/pytorch_pretrained_bert_test/" + for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]: + tokenizer = XLMTokenizer.from_pretrained(model_name, cache_dir=cache_dir) + shutil.rmtree(cache_dir) + self.assertIsNotNone(tokenizer) + + +if __name__ == '__main__': + unittest.main() diff --git a/pytorch_pretrained_bert/tokenization.py b/pytorch_pretrained_bert/tokenization_bert.py similarity index 100% rename from pytorch_pretrained_bert/tokenization.py rename to pytorch_pretrained_bert/tokenization_bert.py diff --git a/pytorch_pretrained_bert/tokenization_openai.py b/pytorch_pretrained_bert/tokenization_openai.py index 52d735efa8..5b2bd31cd0 100644 --- a/pytorch_pretrained_bert/tokenization_openai.py +++ b/pytorch_pretrained_bert/tokenization_openai.py @@ -26,7 +26,7 @@ from io import open from tqdm import tqdm from .file_utils import cached_path -from .tokenization import BasicTokenizer +from .tokenization_bert import BasicTokenizer logger = logging.getLogger(__name__) diff --git a/pytorch_pretrained_bert/tokenization_xlm.py b/pytorch_pretrained_bert/tokenization_xlm.py index a4c1a61545..d6705954c0 100644 --- a/pytorch_pretrained_bert/tokenization_xlm.py +++ b/pytorch_pretrained_bert/tokenization_xlm.py @@ -26,7 +26,7 @@ from io import open from tqdm import tqdm from .file_utils import cached_path -from .tokenization import BasicTokenizer +from .tokenization_bert import BasicTokenizer logger = logging.getLogger(__name__)
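
Note on the module renames in this patch: modeling.py becomes modeling_bert.py and tokenization.py becomes tokenization_bert.py (100% similarity renames), so downstream scripts that import the old paths need the same one-line change applied to the examples above. A minimal migration sketch; the class names themselves are unchanged:

# before this patch
# from pytorch_pretrained_bert.modeling import BertForSequenceClassification
# from pytorch_pretrained_bert.tokenization import BertTokenizer

# after this patch
from pytorch_pretrained_bert.modeling_bert import BertForSequenceClassification
from pytorch_pretrained_bert.tokenization_bert import BertTokenizer

# package-level imports are unaffected because __init__.py is updated in this patch,
# e.g. `from pytorch_pretrained_bert import BertTokenizer` keeps working:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")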
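
Usage sketch for the reworked XLM API: XLMModel.forward now takes optional lengths, langs/token_type_ids and attention_mask arguments and infers lengths from pad_index when they are omitted. The snippet below mirrors create_and_check_xlm_model in the new modeling_xlm_test.py; the tensor and config sizes are illustrative test values, not the pretrained defaults (emb_dim=2048, n_layers=12, ...):

import torch
from pytorch_pretrained_bert import XLMConfig, XLMModel

# small test-sized config; unspecified fields keep the new defaults
# (causal=False, pad_index=2, summary_type="last", use_proj=True, is_encoder=True, ...)
config = XLMConfig(vocab_size_or_config_json_file=99,
                   emb_dim=32, n_layers=5, n_heads=4,
                   dropout=0.1, attention_dropout=0.1,
                   gelu_activation=True, causal=False, n_langs=2,
                   max_position_embeddings=512)

model = XLMModel(config)
model.eval()

input_ids = torch.randint(6, 99, (2, 7))       # (bs, slen), avoiding the reserved low indices
lengths = torch.tensor([7, 5])                 # optional; derived from pad_index when omitted
langs = torch.zeros(2, 7, dtype=torch.long)    # language ids; `token_type_ids` is accepted as an alias

with torch.no_grad():
    outputs = model(input_ids, lengths=lengths, langs=langs)
sequence_output = outputs[0]                   # (bs, slen, emb_dim)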
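
The padding handling in get_masks now has two paths: an attention_mask passed to XLMModel.forward is forwarded as padding_mask and used directly, otherwise the mask is rebuilt from lengths. A small standalone sketch of the lengths-based path (the helper name here is illustrative):

import torch

def lengths_to_padding_mask(lengths, slen):
    # what get_masks computes when no padding_mask is supplied
    alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
    return alen < lengths[:, None]             # (bs, slen), True on real tokens

mask = lengths_to_padding_mask(torch.tensor([7, 5]), slen=7)
# tensor([[True, True, True, True, True, True,  True],
#         [True, True, True, True, True, False, False]])
# An attention_mask given to XLMModel.forward is used in place of this lengths-based mask.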
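
Shape sketch for the new head_mask plumbing: XLMModel.forward accepts a mask of shape [n_heads] (shared across layers) or [n_layers, n_heads] (one per layer) and expands it so it broadcasts as [n_layers, bs, n_heads, qlen, klen], letting MultiHeadAttention simply multiply it into the attention weights. The expansion in isolation, with illustrative sizes:

import torch

n_layers, n_heads = 5, 4

head_mask = torch.ones(n_heads)                     # 1D: same heads masked in every layer
expanded = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
expanded = expanded.expand(n_layers, -1, -1, -1, -1)
print(expanded.shape)                               # torch.Size([5, 1, 4, 1, 1])

head_mask = torch.ones(n_layers, n_heads)           # 2D: a separate mask per layer
expanded = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
print(expanded.shape)                               # torch.Size([5, 1, 4, 1, 1])

# expanded[i] broadcasts against the layer-i attention weights of shape
# (bs, n_heads, qlen, klen), where it is applied as `weights = weights * head_mask`.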