WIP XLM + refactoring
parent 288be7b7ea
commit c41f2bad69
@@ -14,8 +14,8 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm

 from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
-from pytorch_pretrained_bert.modeling import BertForPreTraining
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.modeling_bert import BertForPreTraining
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule

 InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next")
@@ -5,7 +5,7 @@ from tempfile import TemporaryDirectory
 import shelve

 from random import random, randrange, randint, shuffle, choice
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
 import numpy as np
 import json
 import collections
@@ -30,8 +30,8 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange

 from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
-from pytorch_pretrained_bert.modeling import BertForPreTraining
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.modeling_bert import BertForPreTraining
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule

 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -35,8 +35,8 @@ from torch.nn import CrossEntropyLoss, MSELoss
 from tensorboardX import SummaryWriter

 from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
-from pytorch_pretrained_bert.modeling import BertForSequenceClassification
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.modeling_bert import BertForSequenceClassification
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule

 from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
@@ -28,8 +28,8 @@ import torch
 from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler

-from pytorch_pretrained_bert.tokenization import BertTokenizer
-from pytorch_pretrained_bert.modeling import BertModel
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
+from pytorch_pretrained_bert.modeling_bert import BertModel

 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',
@@ -34,9 +34,9 @@ from tqdm import tqdm, trange
 from tensorboardX import SummaryWriter

 from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
-from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
+from pytorch_pretrained_bert.modeling_bert import BertForQuestionAnswering
 from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer

 from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
@@ -33,9 +33,9 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange

 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
-from pytorch_pretrained_bert.modeling import BertForMultipleChoice, BertConfig
+from pytorch_pretrained_bert.modeling_bert import BertForMultipleChoice, BertConfig
 from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
-from pytorch_pretrained_bert.tokenization import BertTokenizer
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer

 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',
@@ -24,7 +24,7 @@ import math
 import collections
 from io import open

-from pytorch_pretrained_bert.tokenization import BasicTokenizer, whitespace_tokenize
+from pytorch_pretrained_bert.tokenization_bert import BasicTokenizer, whitespace_tokenize

 logger = logging.getLogger(__name__)
@@ -1,5 +1,5 @@
-from pytorch_pretrained_bert.tokenization import BertTokenizer
-from pytorch_pretrained_bert.modeling import (
+from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
+from pytorch_pretrained_bert.modeling_bert import (
     BertModel,
     BertForNextSentencePrediction,
     BertForMaskedLM,
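The hunks above are the mechanical half of the refactor: the BERT-specific modules are renamed while the class names stay the same, so call sites only change their import paths. A minimal before/after sketch (assuming the rename is the only change, as in these hunks):

# before (pre-refactor module paths)
# from pytorch_pretrained_bert.modeling import BertForPreTraining
# from pytorch_pretrained_bert.tokenization import BertTokenizer

# after (module paths introduced by this commit)
from pytorch_pretrained_bert.modeling_bert import BertForPreTraining
from pytorch_pretrained_bert.tokenization_bert import BertTokenizer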
@@ -3997,9 +3997,9 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "11/16/2018 11:03:05 - INFO - pytorch_pretrained_bert.modeling - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /Users/thomaswolf/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba\n",
-     "11/16/2018 11:03:05 - INFO - pytorch_pretrained_bert.modeling - extracting archive file /Users/thomaswolf/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpaqgsm566\n",
-     "11/16/2018 11:03:08 - INFO - pytorch_pretrained_bert.modeling - Model config {\n",
+     "11/16/2018 11:03:05 - INFO - pytorch_pretrained_bert.modeling_bert - loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /Users/thomaswolf/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba\n",
+     "11/16/2018 11:03:05 - INFO - pytorch_pretrained_bert.modeling_bert - extracting archive file /Users/thomaswolf/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpaqgsm566\n",
+     "11/16/2018 11:03:08 - INFO - pytorch_pretrained_bert.modeling_bert - Model config {\n",
     " \"attention_probs_dropout_prob\": 0.1,\n",
     " \"hidden_act\": \"gelu\",\n",
     " \"hidden_dropout_prob\": 0.1,\n",
@@ -375,8 +375,8 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling - loading archive file ../../google_models/uncased_L-12_H-768_A-12/\n",
-     "11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling - Model config {\n",
+     "11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling_bert - loading archive file ../../google_models/uncased_L-12_H-768_A-12/\n",
+     "11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling_bert - Model config {\n",
     " \"attention_probs_dropout_prob\": 0.1,\n",
     " \"hidden_act\": \"gelu\",\n",
     " \"hidden_dropout_prob\": 0.1,\n",
@@ -1,11 +1,12 @@
 __version__ = "0.6.2"
-from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
+from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer
 from .tokenization_openai import OpenAIGPTTokenizer
 from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
 from .tokenization_gpt2 import GPT2Tokenizer
 from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
+from .tokenization_xlm import XLMTokenizer

-from .modeling import (BertConfig, BertModel, BertForPreTraining,
+from .modeling_bert import (BertConfig, BertModel, BertForPreTraining,
                             BertForMaskedLM, BertForNextSentencePrediction,
                             BertForSequenceClassification, BertForMultipleChoice,
                             BertForTokenClassification, BertForQuestionAnswering,
@@ -22,6 +23,9 @@ from .modeling_xlnet import (XLNetConfig,
                              XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
                              XLNetForSequenceClassification, XLNetForQuestionAnswering,
                              load_tf_weights_in_xlnet)
+from .modeling_xlm import (XLMConfig, XLMModel,
+                           XLMWithLMHeadModel, XLMForSequenceClassification,
+                           XLMForQuestionAnswering)

 from .optimization import BertAdam
 from .optimization_openai import OpenAIAdam
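With these exports the XLM entry points become importable from the package root like the other model families. A sketch of the resulting import surface (names exactly as exported above):

from pytorch_pretrained_bert import (XLMConfig, XLMModel, XLMWithLMHeadModel,
                                     XLMForSequenceClassification, XLMForQuestionAnswering,
                                     XLMTokenizer)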
@@ -25,7 +25,7 @@ import tensorflow as tf
 import torch
 import numpy as np

-from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert
+from pytorch_pretrained_bert.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert

 def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
     # Initialise PyTorch model
@@ -32,7 +32,7 @@ from torch.nn.parameter import Parameter

 from .file_utils import cached_path
 from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer
-from .modeling import BertLayerNorm as LayerNorm
+from .modeling_bert import BertLayerNorm as LayerNorm

 logger = logging.getLogger(__name__)
@@ -32,7 +32,7 @@ from torch.nn.parameter import Parameter

 from .file_utils import cached_path
 from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer
-from .modeling import BertLayerNorm as LayerNorm
+from .modeling_bert import BertLayerNorm as LayerNorm

 logger = logging.getLogger(__name__)
@@ -34,7 +34,7 @@ import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter

-from .modeling import BertLayerNorm as LayerNorm
+from .modeling_bert import BertLayerNorm as LayerNorm
 from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
 from .file_utils import cached_path
 from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
@@ -71,7 +71,7 @@ class XLMConfig(PretrainedConfig):
     pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP

     def __init__(self,
-                 vocab_size_or_config_json_file,
+                 vocab_size_or_config_json_file=30145,
                  n_special=0,
                  emb_dim=2048,
                  n_layers=12,
@@ -80,13 +80,20 @@ class XLMConfig(PretrainedConfig):
                  attention_dropout=0.1,
                  gelu_activation=True,
                  sinusoidal_embeddings=False,
+                 causal=False,
                  asm=False,
-                 id2lang={ 0: "en" },
-                 lang2id={ "en": 0 },
                  n_langs=1,
-                 n_words=30145,
                  max_position_embeddings=512,
-                 initializer_range=0.02,
+                 embed_init_std=2048 ** -0.5,
+                 init_std=0.02,
+                 summary_type="last",
+                 use_proj=True,
+                 bos_index=0,
+                 eos_index=1,
+                 pad_index=2,
+                 unk_index=3,
+                 mask_index=5,
+                 is_encoder=True,
                  **kwargs):
         """Constructs XLMConfig.

@@ -148,12 +155,20 @@ class XLMConfig(PretrainedConfig):
             self.attention_dropout = attention_dropout
             self.gelu_activation = gelu_activation
             self.sinusoidal_embeddings = sinusoidal_embeddings
+            self.causal = causal
             self.asm = asm
-            self.id2lang = id2lang
-            self.lang2id = lang2id
             self.n_langs = n_langs
+            self.summary_type = summary_type
+            self.use_proj = use_proj
+            self.bos_index = bos_index
+            self.eos_index = eos_index
+            self.pad_index = pad_index
+            self.unk_index = unk_index
+            self.mask_index = mask_index
+            self.is_encoder = is_encoder
             self.max_position_embeddings = max_position_embeddings
-            self.initializer_range = initializer_range
+            self.embed_init_std = embed_init_std
+            self.init_std = init_std
         else:
             raise ValueError("First argument must be either a vocabulary size (int)"
                              "or the path to a pretrained model config file (str)")
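Since every constructor argument now has a default, an XLMConfig can be built with no arguments or with a handful of overrides. A small sketch using only values defined in the two hunks above:

from pytorch_pretrained_bert import XLMConfig

config = XLMConfig()  # vocab_size_or_config_json_file defaults to 30145
tiny = XLMConfig(vocab_size_or_config_json_file=99, emb_dim=32, n_layers=2)
assert config.embed_init_std == 2048 ** -0.5
assert config.init_std == 0.02 and config.pad_index == 2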
@@ -175,37 +190,21 @@ class XLMConfig(PretrainedConfig):
         return self.n_layers


-try:
-    from apex.normalization.fused_layer_norm import FusedLayerNorm as XLMLayerNorm
-except ImportError:
-    logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
-    class XLMLayerNorm(nn.Module):
-        def __init__(self, d_model, eps=1e-12):
-            """Construct a layernorm module in the TF style (epsilon inside the square root).
-            """
-            super(XLMLayerNorm, self).__init__()
-            self.weight = nn.Parameter(torch.ones(d_model))
-            self.bias = nn.Parameter(torch.zeros(d_model))
-            self.variance_epsilon = eps
-
-        def forward(self, x):
-            u = x.mean(-1, keepdim=True)
-            s = (x - u).pow(2).mean(-1, keepdim=True)
-            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
-            return self.weight * x + self.bias
-
-
-def Embedding(num_embeddings, embedding_dim, padding_idx=None):
+def Embedding(num_embeddings, embedding_dim, padding_idx=None, config=None):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
-    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+    if config is not None and config.embed_init_std is not None:
+        nn.init.normal_(m.weight, mean=0, std=config.embed_init_std)
     if padding_idx is not None:
         nn.init.constant_(m.weight[padding_idx], 0)
     return m


-def Linear(in_features, out_features, bias=True):
+def Linear(in_features, out_features, bias=True, config=None):
     m = nn.Linear(in_features, out_features, bias)
-    # nn.init.normal_(m.weight, mean=0, std=1)
+    if config is not None and config.init_std is not None:
+        nn.init.normal_(m.weight, mean=0, std=config.init_std)
+    if bias:
+        nn.init.constant_(m.bias, 0.)
     # nn.init.xavier_uniform_(m.weight)
     # nn.init.constant_(m.bias, 0.)
     return m
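Weight initialization now happens inside the `Embedding`/`Linear` helpers, driven by the new `embed_init_std` / `init_std` config fields (which is why `init_weights` in a later hunk becomes a no-op for these modules). A sketch, assuming only a config-like object carrying those two fields:

class _InitCfg(object):
    # illustrative stand-in for an XLMConfig; only the two init fields are used here
    embed_init_std = 2048 ** -0.5
    init_std = 0.02

emb = Embedding(30145, 2048, padding_idx=2, config=_InitCfg())  # N(0, embed_init_std), padding row zeroed
lin = Linear(2048, 2048, config=_InitCfg())                     # N(0, init_std), bias zeroed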
@@ -233,14 +232,17 @@ def gelu(x):
     return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))


-def get_masks(slen, lengths, causal):
+def get_masks(slen, lengths, causal, padding_mask=None):
     """
     Generate hidden states mask, and optionally an attention mask.
     """
-    assert lengths.max().item() <= slen
     bs = lengths.size(0)
-    alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
-    mask = alen < lengths[:, None]
+    if padding_mask is not None:
+        mask = padding_mask
+    else:
+        assert lengths.max().item() <= slen
+        alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
+        mask = alen < lengths[:, None]

     # attention mask is the same as mask, or triangular inferior attention (causal)
     if causal:
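`get_masks` keeps its original behaviour when only `lengths` is given, and now also accepts a precomputed BERT-style padding mask. A sketch of the two call forms (shapes per the docstring above; values are illustrative):

import torch

lengths = torch.tensor([5, 3])
mask, attn_mask = get_masks(5, lengths, causal=False)  # mask derived from lengths

padding_mask = torch.tensor([[1, 1, 1, 1, 1],
                             [1, 1, 1, 0, 0]], dtype=torch.bool)
mask, attn_mask = get_masks(5, lengths, causal=False, padding_mask=padding_mask)  # mask taken as given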
@@ -259,21 +261,21 @@ class MultiHeadAttention(nn.Module):

     NEW_ID = itertools.count()

-    def __init__(self, n_heads, dim, dropout, output_attentions=False):
+    def __init__(self, n_heads, dim, config):
         super().__init__()
         self.layer_id = next(MultiHeadAttention.NEW_ID)
-        self.output_attentions = output_attentions
+        self.output_attentions = config.output_attentions
         self.dim = dim
         self.n_heads = n_heads
-        self.dropout = dropout
+        self.dropout = config.attention_dropout
         assert self.dim % self.n_heads == 0

-        self.q_lin = Linear(dim, dim)
-        self.k_lin = Linear(dim, dim)
-        self.v_lin = Linear(dim, dim)
-        self.out_lin = Linear(dim, dim)
+        self.q_lin = Linear(dim, dim, config=config)
+        self.k_lin = Linear(dim, dim, config=config)
+        self.v_lin = Linear(dim, dim, config=config)
+        self.out_lin = Linear(dim, dim, config=config)

-    def forward(self, input, mask, kv=None, cache=None):
+    def forward(self, input, mask, kv=None, cache=None, head_mask=None):
         """
         Self-attention (if kv is None) or attention over source sentence (provided by kv).
         """
@@ -323,6 +325,11 @@ class MultiHeadAttention(nn.Module):
         weights = F.softmax(scores.float(), dim=-1).type_as(scores)  # (bs, n_heads, qlen, klen)
         weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)

+        # Mask heads if we want to
+        if head_mask is not None:
+            weights = weights * head_mask
+
         context = torch.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
         context = unshape(context)  # (bs, qlen, dim)

@@ -334,12 +341,12 @@ class MultiHeadAttention(nn.Module):

 class TransformerFFN(nn.Module):

-    def __init__(self, in_dim, dim_hidden, out_dim, dropout, gelu_activation):
+    def __init__(self, in_dim, dim_hidden, out_dim, config):
         super().__init__()
-        self.dropout = dropout
-        self.lin1 = Linear(in_dim, dim_hidden)
-        self.lin2 = Linear(dim_hidden, out_dim)
-        self.act = gelu if gelu_activation else F.relu
+        self.dropout = config.dropout
+        self.lin1 = Linear(in_dim, dim_hidden, config=config)
+        self.lin2 = Linear(dim_hidden, out_dim, config=config)
+        self.act = gelu if config.gelu_activation else F.relu

     def forward(self, input):
         x = self.lin1(input)
@@ -365,12 +372,9 @@ class XLMPreTrainedModel(PreTrainedModel):
         """ Initialize the weights.
         """
         if isinstance(module, (nn.Linear, nn.Embedding)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
-        elif isinstance(module, XLMLayerNorm):
+            # Weights are initialized in module instantiation (see above)
+            pass
+        if isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
@@ -439,8 +443,10 @@ class XLMModel(XLMPreTrainedModel):
         self.output_hidden_states = config.output_hidden_states

         # encoder / decoder, output layer
-        # self.is_encoder = is_encoder
-        # self.is_decoder = not is_encoder
+        self.is_encoder = config.is_encoder
+        self.is_decoder = not config.is_encoder
+        if self.is_decoder:
+            raise NotImplementedError("Currently XLM can only be used as an encoder")
         # self.with_output = with_output
         self.causal = config.causal
@@ -450,10 +456,10 @@ class XLMModel(XLMPreTrainedModel):
         self.eos_index = config.eos_index
         self.pad_index = config.pad_index
         # self.dico = dico
-        self.id2lang = config.id2lang
-        self.lang2id = config.lang2id
+        # self.id2lang = config.id2lang
+        # self.lang2id = config.lang2id
         # assert len(self.dico) == self.n_words
-        assert len(self.id2lang) == len(self.lang2id) == self.n_langs
+        # assert len(self.id2lang) == len(self.lang2id) == self.n_langs

         # model parameters
         self.dim = config.emb_dim  # 512 by default
@@ -465,12 +471,12 @@ class XLMModel(XLMPreTrainedModel):
         assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'

         # embeddings
-        self.position_embeddings = Embedding(config.max_position_embeddings, self.dim)
+        self.position_embeddings = Embedding(config.max_position_embeddings, self.dim, config=config)
         if config.sinusoidal_embeddings:
             create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
         if config.n_langs > 1:
-            self.lang_embeddings = Embedding(self.n_langs, self.dim)
-        self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
+            self.lang_embeddings = Embedding(self.n_langs, self.dim, config=config)
+        self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index, config=config)
         self.layer_norm_emb = nn.LayerNorm(self.dim, eps=1e-12)

         # transformer layers
@@ -478,29 +484,31 @@ class XLMModel(XLMPreTrainedModel):
         self.layer_norm1 = nn.ModuleList()
         self.ffns = nn.ModuleList()
         self.layer_norm2 = nn.ModuleList()
-        if self.is_decoder:
-            self.layer_norm15 = nn.ModuleList()
-            self.encoder_attn = nn.ModuleList()
+        # if self.is_decoder:
+        #     self.layer_norm15 = nn.ModuleList()
+        #     self.encoder_attn = nn.ModuleList()

         for _ in range(self.n_layers):
-            self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
+            self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config))
             self.layer_norm1.append(nn.LayerNorm(self.dim, eps=1e-12))
-            if self.is_decoder:
-                self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12))
-                self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
-            self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, dropout=self.dropout, gelu_activation=config.gelu_activation))
+            # if self.is_decoder:
+            #     self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12))
+            #     self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
+            self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
             self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))

-    def forward(self, input_ids, lengths, positions=None, langs=None, cache=None, head_mask=None):  # src_enc=None, src_len=None,
+    def forward(self, input_ids, lengths=None, positions=None, langs=None,
+                token_type_ids=None, attention_mask=None, cache=None, head_mask=None):  # src_enc=None, src_len=None,
         """
         Inputs:
             `input_ids` LongTensor(bs, slen), containing word indices
             `lengths` LongTensor(bs), containing the length of each sentence
-            `causal` Boolean, if True, the attention is only done over previous hidden states
             `positions` LongTensor(bs, slen), containing word positions
             `langs` LongTensor(bs, slen), containing language IDs
+            `token_type_ids` LongTensor (bs, slen) same as `langs` used for compatibility
         """
-        # lengths = (input_ids != self.pad_index).float().sum(dim=1)
+        if lengths is None:
+            lengths = (input_ids != self.pad_index).float().sum(dim=1)
         # mask = input_ids != self.pad_index

         # check inputs
|
||||||
# assert src_enc.size(0) == bs
|
# assert src_enc.size(0) == bs
|
||||||
|
|
||||||
# generate masks
|
# generate masks
|
||||||
mask, attn_mask = get_masks(slen, lengths, self.causal)
|
mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
|
||||||
# if self.is_decoder and src_enc is not None:
|
# if self.is_decoder and src_enc is not None:
|
||||||
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
|
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
|
||||||
|
|
||||||
|
@@ -527,10 +535,28 @@ class XLMModel(XLMPreTrainedModel):
         # positions = positions.transpose(0, 1)

         # langs
+        assert langs is None or token_type_ids is None, "You can only use one among langs and token_type_ids"
+        if token_type_ids is not None:
+            langs = token_type_ids
         if langs is not None:
             assert langs.size() == (bs, slen)  # (slen, bs)
             # langs = langs.transpose(0, 1)

+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+                head_mask = head_mask.expand(self.n_layers, -1, -1, -1, -1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to fload if need + fp16 compatibility
+        else:
+            head_mask = [None] * self.n_layers
+
         # do not recompute cached elements
         if cache is not None:
             _slen = slen - cache['slen']
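The head-mask handling mirrors the other models in the library: a 1D mask is shared across layers, a 2D mask is per layer, and both are broadcast against the (bs, n_heads, qlen, klen) attention weights. Illustrative shapes (the layer/head counts below are examples, not taken from a real checkpoint):

import torch

n_layers, n_heads = 12, 16
shared = torch.ones(n_heads)                # 1D: same mask applied to every layer
per_layer = torch.ones(n_layers, n_heads)   # 2D: one mask per layer
# after the expansion in the hunk above, both end up with shape (n_layers, 1, n_heads, 1, 1)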
@@ -696,9 +722,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
     ```
     """
     def __init__(self, config):
-        super(XLMLMHeadModel, self).__init__(config)
-        self.attn_type = config.attn_type
-        self.same_length = config.same_length
+        super(XLMWithLMHeadModel, self).__init__(config)

         self.transformer = XLMModel(config)
         self.pred_layer = XLMPredLayer(config)
@@ -711,8 +735,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
         """
         self.pred_layer.proj.weight = self.transformer.embeddings.weight

-    def forward(self, input_ids, lengths, positions=None, langs=None, cache=None,
-                labels=None, head_mask=None):
+    def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
+                attention_mask=None, cache=None, labels=None, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -739,7 +763,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
             summary_type: str, "last", "first", "mean", or "attn". The method
                 to pool the input to get a vector representation.
         """
-        transformer_outputs = self.transformer(input_ids, lengths, positions=positions, langs=langs, cache=cache, head_mask=head_mask)
+        transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
+                                               langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)

         output = transformer_outputs[0]
         logits = self.pred_layer(output, labels)
@@ -759,14 +784,14 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):


 class XLMSequenceSummary(nn.Module):
-    def __init__(self, config, summary_type="last", use_proj=True):
+    def __init__(self, config):
         super(XLMSequenceSummary, self).__init__()
-        self.summary_type = summary_type
-        if use_proj:
+        self.summary_type = config.summary_type
+        if config.use_proj:
             self.summary = nn.Linear(config.d_model, config.d_model)
         else:
             self.summary = None
-        if summary_type == 'attn':
+        if config.summary_type == 'attn':
             # We should use a standard multi-head attention module with absolute positional embedding for that.
             # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
             # We can probably just use the multi-head attention module of PyTorch >=1.1.0
@@ -859,14 +884,13 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
         super(XLMForSequenceClassification, self).__init__(config)

         self.transformer = XLMModel(config)

         self.sequence_summary = XLMSequenceSummary(config)
-        self.logits_proj = nn.Linear(config.d_model, num_labels)
+        self.logits_proj = nn.Linear(config.d_model, config.num_labels)

         self.apply(self.init_weights)

-    def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
-                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                labels=None, head_mask=None):
+    def forward(self, input_ids, lengths=None, positions=None, langs=None, attention_mask=None,
+                cache=None, labels=None, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -894,8 +918,8 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
                 Only used during pretraining for two-stream attention.
                 Set to None during finetuning.
         """
-        transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                               mems, perm_mask, target_mapping, inp_q, head_mask)
+        transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
+                                               langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)

         output = transformer_outputs[0]
         output = self.sequence_summary(output)
@@ -974,7 +998,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
         start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, CONFIG_NAME):
+    def __init__(self, config):
         super(XLMForQuestionAnswering, self).__init__(config)

         self.transformer = XLMModel(config)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
|
def forward(self, input_ids, lengths=None, positions=None, langs=None, attention_mask=None, cache=None,
|
||||||
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
labels=None, head_mask=None):
|
||||||
start_positions=None, end_positions=None, head_mask=None):
|
|
||||||
|
|
||||||
transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
|
transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
|
||||||
mems, perm_mask, target_mapping, inp_q, head_mask)
|
langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
|
||||||
|
|
||||||
output = transformer_outputs[0]
|
output = transformer_outputs[0]
|
||||||
logits = self.qa_outputs(output)
|
logits = self.qa_outputs(output)
|
||||||
|
|
|
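Taken together, the XLM models now accept the BERT-style keyword arguments (`token_type_ids`, `attention_mask`) next to the original XLM ones (`lengths`, `langs`, `cache`). A sketch of a call to the bare encoder, mirroring the new test file further down and using a tiny assumed configuration (this is WIP code, so the exact behaviour is not guaranteed):

import torch
from pytorch_pretrained_bert import XLMConfig, XLMModel

config = XLMConfig(vocab_size_or_config_json_file=99, emb_dim=32, n_layers=2,
                   n_heads=4, dropout=0.1, attention_dropout=0.1, n_langs=2)
model = XLMModel(config)
model.eval()

input_ids = torch.randint(0, 99, (2, 7))
token_type_ids = torch.zeros_like(input_ids)  # interpreted as `langs`
attention_mask = torch.ones_like(input_ids)   # passed to get_masks as padding_mask
sequence_output = model(input_ids, token_type_ids=token_type_ids,
                        attention_mask=attention_mask)[0]  # (2, 7, 32)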
@@ -36,7 +36,9 @@ def _create_and_check_initialization(tester, model_classes, config, inputs_dict)
     for model_class in model_classes:
         model = model_class(config=configs_no_init)
         for name, param in model.named_parameters():
-            tester.parent.assertIn(param.data.mean().item(), [0.0, 1.0], msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
+            if param.requires_grad:
+                tester.parent.assertIn(param.data.mean().item(), [0.0, 1.0],
+                                       msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))

 def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
     configs_no_init = _config_zero_init(config)
@@ -26,7 +26,7 @@ import pytest
 import torch

 from pytorch_pretrained_bert import PretrainedConfig, PreTrainedModel
-from pytorch_pretrained_bert.modeling import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP
+from pytorch_pretrained_bert.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP


 class ModelUtilsTest(unittest.TestCase):
@@ -16,20 +16,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import os
 import unittest
-import json
-import random
 import shutil
 import pytest

-import torch

 from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
                                      BertForNextSentencePrediction, BertForPreTraining,
                                      BertForQuestionAnswering, BertForSequenceClassification,
                                      BertForTokenClassification, BertForMultipleChoice)
-from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_pretrained_bert.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP

 from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
@@ -0,0 +1,276 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+import shutil
+import pytest
+
+from pytorch_pretrained_bert import (XLMConfig, XLMModel, XLMForQuestionAnswering, XLMForSequenceClassification)
+from pytorch_pretrained_bert.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP
+
+from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
+
+
+class XLMModelTest(unittest.TestCase):
+    class XLMModelTester(object):
+
+        def __init__(self,
+                     parent,
+                     batch_size=13,
+                     seq_length=7,
+                     is_training=True,
+                     use_input_lengths=True,
+                     use_token_type_ids=True,
+                     use_labels=True,
+                     gelu_activation=True,
+                     sinusoidal_embeddings=False,
+                     causal=False,
+                     asm=False,
+                     n_langs=2,
+                     vocab_size=99,
+                     n_special=0,
+                     hidden_size=32,
+                     num_hidden_layers=5,
+                     num_attention_heads=4,
+                     hidden_dropout_prob=0.1,
+                     attention_probs_dropout_prob=0.1,
+                     max_position_embeddings=512,
+                     type_vocab_size=16,
+                     type_sequence_label_size=2,
+                     initializer_range=0.02,
+                     num_labels=3,
+                     num_choices=4,
+                     summary_type="last",
+                     use_proj=True,
+                     scope=None,
+                     all_model_classes = (XLMModel,), # , XLMForSequenceClassification, XLMForTokenClassification),
+                    ):
+            self.parent = parent
+            self.batch_size = batch_size
+            self.seq_length = seq_length
+            self.is_training = is_training
+            self.use_input_lengths = use_input_lengths
+            self.use_token_type_ids = use_token_type_ids
+            self.use_labels = use_labels
+            self.gelu_activation = gelu_activation
+            self.sinusoidal_embeddings = sinusoidal_embeddings
+            self.asm = asm
+            self.n_langs = n_langs
+            self.vocab_size = vocab_size
+            self.n_special = n_special
+            self.summary_type = summary_type
+            self.causal = causal
+            self.use_proj = use_proj
+            self.hidden_size = hidden_size
+            self.num_hidden_layers = num_hidden_layers
+            self.num_attention_heads = num_attention_heads
+            self.hidden_dropout_prob = hidden_dropout_prob
+            self.attention_probs_dropout_prob = attention_probs_dropout_prob
+            self.max_position_embeddings = max_position_embeddings
+            self.n_langs = n_langs
+            self.type_sequence_label_size = type_sequence_label_size
+            self.initializer_range = initializer_range
+            self.summary_type = summary_type
+            self.num_labels = num_labels
+            self.num_choices = num_choices
+            self.scope = scope
+            self.all_model_classes = all_model_classes
+
+        def prepare_config_and_inputs(self):
+            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+            input_lengths = None
+            if self.use_input_lengths:
+                input_lengths = ids_tensor([self.batch_size], vocab_size=self.seq_length-1)
+
+            token_type_ids = None
+            if self.use_token_type_ids:
+                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.n_langs)
+
+            sequence_labels = None
+            token_labels = None
+            choice_labels = None
+            if self.use_labels:
+                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+                choice_labels = ids_tensor([self.batch_size], self.num_choices)
+
+            config = XLMConfig(
+                vocab_size_or_config_json_file=self.vocab_size,
+                n_special=self.n_special,
+                emb_dim=self.hidden_size,
+                n_layers=self.num_hidden_layers,
+                n_heads=self.num_attention_heads,
+                dropout=self.hidden_dropout_prob,
+                attention_dropout=self.attention_probs_dropout_prob,
+                gelu_activation=self.gelu_activation,
+                sinusoidal_embeddings=self.sinusoidal_embeddings,
+                asm=self.asm,
+                causal=self.causal,
+                n_langs=self.n_langs,
+                max_position_embeddings=self.max_position_embeddings,
+                initializer_range=self.initializer_range,
+                summary_type=self.summary_type,
+                use_proj=self.use_proj)
+
+            return config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels
+
+        def check_loss_output(self, result):
+            self.parent.assertListEqual(
+                list(result["loss"].size()),
+                [])
+
+        def create_and_check_xlm_model(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
+            model = XLMModel(config=config)
+            model.eval()
+            outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
+            sequence_output = outputs[0]
+            result = {
+                "sequence_output": sequence_output,
+            }
+            self.parent.assertListEqual(
+                list(result["sequence_output"].size()),
+                [self.batch_size, self.seq_length, self.hidden_size])
+
+
+        # def create_and_check_xlm_for_masked_lm(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
+        #     model = XLMForMaskedLM(config=config)
+        #     model.eval()
+        #     loss, prediction_scores = model(input_ids, token_type_ids, input_lengths, token_labels)
+        #     result = {
+        #         "loss": loss,
+        #         "prediction_scores": prediction_scores,
+        #     }
+        #     self.parent.assertListEqual(
+        #         list(result["prediction_scores"].size()),
+        #         [self.batch_size, self.seq_length, self.vocab_size])
+        #     self.check_loss_output(result)
+
+
+        # def create_and_check_xlm_for_question_answering(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
+        #     model = XLMForQuestionAnswering(config=config)
+        #     model.eval()
+        #     loss, start_logits, end_logits = model(input_ids, token_type_ids, input_lengths, sequence_labels, sequence_labels)
+        #     result = {
+        #         "loss": loss,
+        #         "start_logits": start_logits,
+        #         "end_logits": end_logits,
+        #     }
+        #     self.parent.assertListEqual(
+        #         list(result["start_logits"].size()),
+        #         [self.batch_size, self.seq_length])
+        #     self.parent.assertListEqual(
+        #         list(result["end_logits"].size()),
+        #         [self.batch_size, self.seq_length])
+        #     self.check_loss_output(result)
+
+
+        # def create_and_check_xlm_for_sequence_classification(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
+        #     config.num_labels = self.num_labels
+        #     model = XLMForSequenceClassification(config)
+        #     model.eval()
+        #     loss, logits = model(input_ids, token_type_ids, input_lengths, sequence_labels)
+        #     result = {
+        #         "loss": loss,
+        #         "logits": logits,
+        #     }
+        #     self.parent.assertListEqual(
+        #         list(result["logits"].size()),
+        #         [self.batch_size, self.num_labels])
+        #     self.check_loss_output(result)
+
+
+        # def create_and_check_xlm_for_token_classification(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
+        #     config.num_labels = self.num_labels
+        #     model = XLMForTokenClassification(config=config)
+        #     model.eval()
+        #     loss, logits = model(input_ids, token_type_ids, input_lengths, token_labels)
+        #     result = {
+        #         "loss": loss,
+        #         "logits": logits,
+        #     }
+        #     self.parent.assertListEqual(
+        #         list(result["logits"].size()),
+        #         [self.batch_size, self.seq_length, self.num_labels])
+        #     self.check_loss_output(result)
+
+
+        # def create_and_check_xlm_for_multiple_choice(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
+        #     config.num_choices = self.num_choices
+        #     model = XLMForMultipleChoice(config=config)
+        #     model.eval()
+        #     multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        #     multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        #     multiple_choice_input_lengths = input_lengths.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        #     loss, logits = model(multiple_choice_inputs_ids,
+        #                          multiple_choice_token_type_ids,
+        #                          multiple_choice_input_lengths,
+        #                          choice_labels)
+        #     result = {
+        #         "loss": loss,
+        #         "logits": logits,
+        #     }
+        #     self.parent.assertListEqual(
+        #         list(result["logits"].size()),
+        #         [self.batch_size, self.num_choices])
+        #     self.check_loss_output(result)
+
+
+        def create_and_check_xlm_commons(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
+            inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_lengths}
+            create_and_check_commons(self, config, inputs_dict)
+
+    def test_default(self):
+        self.run_tester(XLMModelTest.XLMModelTester(self))
+
+    def test_config(self):
+        config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)
+        config_tester.run_common_tests()
+
+    @pytest.mark.slow
+    def test_model_from_pretrained(self):
+        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+            model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
+        shutil.rmtree(cache_dir)
+        self.assertIsNotNone(model)
+
+    def run_tester(self, tester):
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_xlm_model(*config_and_inputs)
+
+        # config_and_inputs = tester.prepare_config_and_inputs()
+        # tester.create_and_check_xlm_for_masked_lm(*config_and_inputs)
+
+        # config_and_inputs = tester.prepare_config_and_inputs()
+        # tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)
+
+        # config_and_inputs = tester.prepare_config_and_inputs()
+        # tester.create_and_check_xlm_for_question_answering(*config_and_inputs)
+
+        # config_and_inputs = tester.prepare_config_and_inputs()
+        # tester.create_and_check_xlm_for_sequence_classification(*config_and_inputs)
+
+        # config_and_inputs = tester.prepare_config_and_inputs()
+        # tester.create_and_check_xlm_for_token_classification(*config_and_inputs)
+
+        config_and_inputs = tester.prepare_config_and_inputs()
+        tester.create_and_check_xlm_commons(*config_and_inputs)
+
+if __name__ == "__main__":
+    unittest.main()
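The new test classes follow the existing unittest/pytest layout; assuming the usual tests/ directory implied by the relative imports above (and the tokenization test added below), they would be run with something like:

# assumed invocation, matching how the other test modules in this repository are run
# python -m pytest -sv ./tests/modeling_xlm_test.py
# python -m pytest -sv ./tests/tokenization_xlm_test.py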
@@ -20,7 +20,7 @@ from io import open
 import shutil
 import pytest

-from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
+from pytorch_pretrained_bert.tokenization_bert import (BasicTokenizer,
                                                         BertTokenizer,
                                                         WordpieceTokenizer,
                                                         _is_control, _is_punctuation,
@@ -0,0 +1,79 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import os
+import unittest
+import json
+import shutil
+import pytest
+
+from pytorch_pretrained_bert.tokenization_xlm import XLMTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
+
+
+class XLMTokenizationTest(unittest.TestCase):
+
+    def test_full_tokenizer(self):
+        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
+                 "w</w>", "r</w>", "t</w>",
+                 "lo", "low", "er</w>",
+                 "low</w>", "lowest</w>", "newer</w>", "wider</w>"]
+        vocab_tokens = dict(zip(vocab, range(len(vocab))))
+        merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
+        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
+            fp.write(json.dumps(vocab_tokens))
+            vocab_file = fp.name
+        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
+            fp.write("\n".join(merges))
+            merges_file = fp.name
+
+        tokenizer = XLMTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
+        os.remove(vocab_file)
+        os.remove(merges_file)
+
+        text = "lower"
+        bpe_tokens = ["low", "er</w>"]
+        tokens = tokenizer.tokenize(text)
+        self.assertListEqual(tokens, bpe_tokens)
+
+        input_tokens = tokens + ["<unk>"]
+        input_bpe_tokens = [14, 15, 20]
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+
+        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer_2 = XLMTokenizer.from_pretrained("/tmp/")
+        os.remove(vocab_file)
+        os.remove(merges_file)
+        os.remove(special_tokens_file)
+
+        self.assertListEqual(
+            [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
+             tokenizer.special_tokens, tokenizer.special_tokens_decoder],
+            [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
+             tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
+
+    @pytest.mark.slow
+    def test_tokenizer_from_pretrained(self):
+        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
+            tokenizer = XLMTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+        shutil.rmtree(cache_dir)
+        self.assertIsNotNone(tokenizer)
+
+
+if __name__ == '__main__':
+    unittest.main()
@@ -26,7 +26,7 @@ from io import open
 from tqdm import tqdm

 from .file_utils import cached_path
-from .tokenization import BasicTokenizer
+from .tokenization_bert import BasicTokenizer

 logger = logging.getLogger(__name__)
@@ -26,7 +26,7 @@ from io import open
 from tqdm import tqdm

 from .file_utils import cached_path
-from .tokenization import BasicTokenizer
+from .tokenization_bert import BasicTokenizer

 logger = logging.getLogger(__name__)