FlaxGPTJ (#14396)
* add flax gptj
* no bias in attention dense
* no wpe
* fix rotary embeddings
* fix rotary embeds
* fix rotray embeds
* quality
* doc and quality
* fix equivalence tests
parent 70996a5420
commit 4c0dd199c8
@@ -425,7 +425,7 @@ Flax), PyTorch, and/or TensorFlow.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | GPT Neo                     | ❌             | ❌             | ✅              | ❌                 | ✅           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| GPT-J                       | ❌             | ❌             | ✅              | ❌                 | ❌           |
+| GPT-J                       | ❌             | ❌             | ✅              | ❌                 | ✅           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | Hubert                      | ❌             | ❌             | ✅              | ✅                 | ❌           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
@@ -119,3 +119,17 @@ GPTJForSequenceClassification

 .. autoclass:: transformers.GPTJForSequenceClassification
     :members: forward
+
+
+FlaxGPTJModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxGPTJModel
+    :members: __call__
+
+
+FlaxGPTJForCausalLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxGPTJForCausalLM
+    :members: __call__

@@ -1994,6 +1994,7 @@ if is_flax_available():
     _import_structure["models.gpt_neo"].extend(
         ["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
     )
+    _import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"])
     _import_structure["models.marian"].extend(
         [
             "FlaxMarianModel",

@@ -3691,6 +3692,7 @@ if TYPE_CHECKING:
         from .models.encoder_decoder import FlaxEncoderDecoderModel
         from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
         from .models.gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel
+        from .models.gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel
         from .models.marian import FlaxMarianModel, FlaxMarianMTModel, FlaxMarianPreTrainedModel
         from .models.mbart import (
             FlaxMBartForConditionalGeneration,

@@ -39,6 +39,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
         ("bart", "FlaxBartModel"),
         ("gpt2", "FlaxGPT2Model"),
         ("gpt_neo", "FlaxGPTNeoModel"),
+        ("gptj", "FlaxGPTJModel"),
         ("electra", "FlaxElectraModel"),
         ("clip", "FlaxCLIPModel"),
         ("vit", "FlaxViTModel"),

@@ -114,6 +115,7 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
         # Model for Causal LM mapping
         ("gpt2", "FlaxGPT2LMHeadModel"),
         ("gpt_neo", "FlaxGPTNeoForCausalLM"),
+        ("gptj", "FlaxGPTJForCausalLM"),
     ]
 )
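
With these two mappings registered, the Flax auto classes resolve a GPT-J checkpoint to the new classes. A hedged sketch of what that enables (the checkpoint id follows the one used in this PR's tests):

import jax.numpy as jnp
from transformers import FlaxAutoModelForCausalLM

# the mapping above routes GPTJConfig to FlaxGPTJForCausalLM
model = FlaxAutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
print(type(model).__name__)  # FlaxGPTJForCausalLM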

@@ -17,7 +17,7 @@
 # limitations under the License.
 from typing import TYPE_CHECKING

-from ...file_utils import _LazyModule, is_torch_available
+from ...file_utils import _LazyModule, is_flax_available, is_torch_available


 _import_structure = {

@@ -33,6 +33,13 @@ if is_torch_available():
         "GPTJPreTrainedModel",
     ]

+if is_flax_available():
+    _import_structure["modeling_flax_gptj"] = [
+        "FlaxGPTJForCausalLM",
+        "FlaxGPTJModel",
+        "FlaxGPTJPreTrainedModel",
+    ]
+
 if TYPE_CHECKING:
     from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig

@@ -46,6 +53,9 @@ if TYPE_CHECKING:
             GPTJPreTrainedModel,
         )

+    if is_flax_available():
+        from .modeling_flax_gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel
+
 else:
     import sys
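
Once this wiring is in place, the new symbols resolve lazily from the top-level package whenever JAX/Flax is installed; a small sketch of the guard a caller might use:

from transformers import is_flax_available

if is_flax_available():
    # resolved on first access through the _LazyModule machinery above
    from transformers import FlaxGPTJForCausalLM, FlaxGPTJModel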
@@ -0,0 +1,714 @@
# coding=utf-8
# Copyright 2021 The EleutherAI and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional, Tuple

import numpy as np

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from jax import lax

from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
from ...utils import logging
from .configuration_gptj import GPTJConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "gptj"
_CONFIG_FOR_DOC = "GPTJConfig"
_TOKENIZER_FOR_DOC = "GPTJTokenizer"


GPTJ_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
    generic methods the library implements for all its models (such as downloading or saving, resizing the input
    embeddings, pruning heads etc.)

    This model is also a Flax Linen `flax.nn.Module
    <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
    Module and refer to the Flax documentation for all matters related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
    - `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
    - `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
    - `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__

    Parameters:
        config (:class:`~transformers.GPTJConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
            model weights.
        dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
            The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
            GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified, all the computation will be performed with the given ``dtype``.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see
            :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""

GPTJ_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, input_ids_length)`):
            :obj:`input_ids_length` = ``sequence_length``. Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.GPTJTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range ``[0,
            config.max_position_embeddings - 1]``.
        past_key_values (:obj:`Dict[str, np.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``):
            Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
            auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


def create_sinusoidal_positions(num_pos, dim):
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    sinusoid_inp = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
    sin, cos = np.sin(sinusoid_inp), np.cos(sinusoid_inp)

    sentinel = dim // 2 + dim % 2
    out = np.zeros((num_pos, dim))
    out[:, 0:sentinel] = sin
    out[:, sentinel:] = cos

    return jnp.array(out)
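
The table stores sin(p * inv_freq_j) in the first half of each row and cos(p * inv_freq_j) in the second half, where inv_freq_j = 10000 ** (-2j / dim). A quick numeric check of that reading (an editorial sketch; assumes the function above is in scope):

import numpy as np

dim = 4
inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))  # array([1.0, 0.01])
table = np.asarray(create_sinusoidal_positions(3, dim))   # one row per position 0..2
assert np.allclose(table[2, :2], np.sin(2 * inv_freq))    # sin half at position 2
assert np.allclose(table[2, 2:], np.cos(2 * inv_freq))    # cos half at position 2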


def rotate_every_two(tensor):
    # interleaved rotation: (x1, x2) -> (-x2, x1); the minus sign is required for the
    # rotary formulation and for equivalence with the PyTorch implementation
    rotate_half_tensor = jnp.stack((-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1)
    rotate_half_tensor = rotate_half_tensor.reshape(rotate_half_tensor.shape[:-2] + (-1,))
    return rotate_half_tensor


def apply_rotary_pos_emb(tensor, sincos):
    sin_pos, cos_pos = sincos
    sin_pos = sin_pos[:, :, None, :].repeat(2, 3)
    cos_pos = cos_pos[:, :, None, :].repeat(2, 3)
    return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos)
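
A minimal sketch of how the helpers combine (editorial; toy shapes): the per-position sin/cos rows are gathered by position id, split into their two halves, and applied to a (batch, seq, heads, head_dim) tensor, exactly as the attention module below does:

import jax.numpy as jnp

embed_positions = create_sinusoidal_positions(8, 4)       # (8, 4): [sin | cos]
position_ids = jnp.arange(4)[None, :]                     # (1, 4)
sincos = jnp.take(embed_positions, position_ids, axis=0)  # (1, 4, 4)
sincos = jnp.split(sincos, 2, axis=-1)                    # [sin (1, 4, 2), cos (1, 4, 2)]

query = jnp.ones((1, 4, 1, 4))                            # (batch, seq, heads, head_dim)
rotated = apply_rotary_pos_emb(query, sincos)             # same shape, positions mixed in
assert rotated.shape == query.shape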


class FlaxGPTJAttention(nn.Module):
    config: GPTJConfig
    dtype: jnp.dtype = jnp.float32
    causal: bool = True
    is_cross_attention: bool = False

    def setup(self):
        config = self.config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.rotary_dim = config.rotary_dim

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense()

        self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)

        self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")

        pos_embd_dim = self.rotary_dim or self.embed_dim
        self.embed_positions = create_sinusoidal_positions(config.max_position_embeddings, pos_embd_dim)

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to
        cached states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only
            # attend to those key positions that have already been generated and cached, not the
            # remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask
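
The cache update above is a shifted write into a preallocated buffer. The same lax.dynamic_update_slice pattern in isolation (an editorial sketch with toy sizes), for a cache of length 5 that already holds two positions:

import jax.numpy as jnp
from jax import lax

cached_key = jnp.zeros((1, 5, 1, 2))  # (batch, max_length, num_heads, head_dim)
cur_index = 2                         # two positions already written
new_key = jnp.ones((1, 1, 1, 2))      # projected key for a single decoding step

cached_key = lax.dynamic_update_slice(cached_key, new_key, (0, cur_index, 0, 0))
print(cached_key[0, :, 0, 0])         # [0. 0. 1. 0. 0.]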

    def __call__(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        sincos = jnp.take(self.embed_positions, position_ids, axis=0)
        sincos = jnp.split(sincos, 2, axis=-1)
        if self.rotary_dim is not None:
            k_rot = key[:, :, :, : self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim :]

            q_rot = query[:, :, :, : self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim :]

            k_rot = apply_rotary_pos_emb(k_rot, sincos)
            q_rot = apply_rotary_pos_emb(q_rot, sincos)

            key = jnp.concatenate([k_rot, k_pass], axis=-1)
            query = jnp.concatenate([q_rot, q_pass], axis=-1)
        else:
            key = apply_rotary_pos_emb(key, sincos)
            query = apply_rotary_pos_emb(query, sincos)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.has_variable("cache", "cached_key"):
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
            )
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        batch_size = hidden_states.shape[0]
        causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)

        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)

        # transform boolean mask into float mask
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
        )

        # usual dot product attention
        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output, deterministic=deterministic)

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
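
The interaction of the padding mask and the causal mask in the __call__ above can be seen in isolation; an editorial sketch for a single length-4 sequence whose last position is padding:

import jax.numpy as jnp
from flax.linen import combine_masks, make_causal_mask
from jax import lax

causal_mask = make_causal_mask(jnp.ones((1, 4), dtype="bool"), dtype="bool")  # (1, 1, 4, 4)
padding_mask = jnp.array([[1, 1, 1, 0]])                                      # last token is padding
padding_mask = jnp.broadcast_to(padding_mask[:, None, None, :], causal_mask.shape)
mask = combine_masks(padding_mask, causal_mask)                               # logical AND of both

# boolean mask -> additive float bias, as in the module above
bias = lax.select(mask > 0, jnp.full(mask.shape, 0.0), jnp.full(mask.shape, -1e9))
print(bias[0, 0])  # row i attends to positions <= i, and never to padded position 3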


class FlaxGPTJMLP(nn.Module):
    config: GPTJConfig
    intermediate_size: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        embed_dim = self.config.hidden_size
        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)

        self.fc_in = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init)
        self.fc_out = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init)

        self.act = ACT2FN[self.config.activation_function]
        self.dropout = nn.Dropout(rate=self.config.resid_pdrop)

    def __call__(self, hidden_states, deterministic: bool = True):
        hidden_states = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc_out(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states


class FlaxGPTJBlock(nn.Module):
    config: GPTJConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        hidden_size = self.config.hidden_size
        inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
        self.attn = FlaxGPTJAttention(self.config, dtype=self.dtype)

        self.mlp = FlaxGPTJMLP(self.config, inner_dim, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]

        feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic)
        # residual connection
        hidden_states = attn_output + feed_forward_hidden_states + residual

        return (hidden_states,) + attn_outputs[1:]


class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPTJConfig
    base_model_prefix = "transformer"
    module_class: nn.Module = None

    def __init__(
        self,
        config: GPTJConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        if self.config.add_cross_attention:
            encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
            encoder_attention_mask = attention_mask
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                position_ids,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)

        return module_init_outputs["params"]

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (:obj:`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (:obj:`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length))
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return init_variables["cache"]

    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")

            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized and the private flag init_cache
        # has to be passed down to ensure the cache is used. The cache must also be marked as mutable so
        # that it can be changed by the FlaxGPTJAttention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            False,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        return outputs
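
A compact editorial sketch of the cache plumbing exposed by this __call__: init_cache preallocates the buffers, the first call runs the prompt and returns an updated cache in past_key_values, and later calls feed one token at a time with explicit position_ids (toy config; values are illustrative):

import jax.numpy as jnp
from transformers import GPTJConfig
from transformers.models.gptj.modeling_flax_gptj import FlaxGPTJForCausalLM

config = GPTJConfig(vocab_size=99, n_embd=32, n_layer=2, n_head=4, rotary_dim=4, n_positions=64)
model = FlaxGPTJForCausalLM(config)  # randomly initialized, just to show the call pattern

prompt = jnp.array([[1, 2, 3]], dtype="i4")
max_length = 8
cache = model.init_cache(batch_size=1, max_length=max_length)
attention_mask = jnp.ones((1, max_length), dtype="i4")

# prefill: run the whole prompt once, collecting the updated cache
position_ids = jnp.arange(prompt.shape[-1])[None, :]
out = model(prompt, attention_mask=attention_mask, past_key_values=cache, position_ids=position_ids)

# one decoding step: feed only the newest token plus the cache from the previous call
next_token = out.logits[:, -1].argmax(-1)[:, None]
position_ids = jnp.array([[prompt.shape[-1]]], dtype="i4")
out = model(next_token, attention_mask=attention_mask, past_key_values=out.past_key_values, position_ids=position_ids)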


class FlaxGPTJBlockCollection(nn.Module):
    config: GPTJConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.blocks = [
            FlaxGPTJBlock(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for block in self.blocks:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = block(
                hidden_states,
                attention_mask,
                position_ids=position_ids,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        # this contains possible `None` values - `FlaxGPTJModule` will filter them out
        outputs = (hidden_states, all_hidden_states, all_attentions)

        return outputs


class FlaxGPTJModule(nn.Module):
    config: GPTJConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embed_dim = self.config.hidden_size

        self.wte = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
        self.h = FlaxGPTJBlockCollection(self.config, dtype=self.dtype)
        self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        deterministic=True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        input_embeds = self.wte(input_ids.astype("i4"))

        hidden_states = self.dropout(input_embeds, deterministic=deterministic)

        outputs = self.h(
            hidden_states,
            attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = outputs[1] + (hidden_states,)
            outputs = (hidden_states, all_hidden_states) + outputs[2:]
        else:
            outputs = (hidden_states,) + outputs[1:]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )


@add_start_docstrings(
    "The bare GPTJ Model transformer outputting raw hidden-states without any specific head on top.",
    GPTJ_START_DOCSTRING,
)
class FlaxGPTJModel(FlaxGPTJPreTrainedModel):
    module_class = FlaxGPTJModule


append_call_sample_docstring(
    FlaxGPTJModel,
    _TOKENIZER_FOR_DOC,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutput,
    _CONFIG_FOR_DOC,
)


class FlaxGPTJForCausalLMModule(nn.Module):
    config: GPTJConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.transformer = FlaxGPTJModule(self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        outputs = self.transformer(
            input_ids,
            attention_mask,
            position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return (lm_logits,) + outputs[1:]

        return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
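
When tie_word_embeddings is set, the branch above reuses the input embedding matrix as the LM head: nn.Dense stores its kernel as (features_in, features_out), so the (vocab_size, hidden_size) embedding table is transposed before being handed to lm_head.apply. A shape-level sketch (editorial; hypothetical sizes):

import jax.numpy as jnp

vocab_size, hidden_size = 99, 32
embedding = jnp.zeros((vocab_size, hidden_size))  # the wte embedding table
shared_kernel = embedding.T                       # (hidden_size, vocab_size), a Dense kernel
hidden_states = jnp.zeros((1, 7, hidden_size))
logits = hidden_states @ shared_kernel            # (1, 7, vocab_size), what lm_head produces
assert logits.shape == (1, 7, vocab_size)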


@add_start_docstrings(
    """
    The GPTJ Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    GPTJ_START_DOCSTRING,
)
class FlaxGPTJForCausalLM(FlaxGPTJPreTrainedModel):
    module_class = FlaxGPTJForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1]
        # and x < cache_length. But since GPTJ uses a causal mask, those positions are masked anyway.
        # Thus, we can create a single static attention_mask here, which is more efficient for compilation.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs


append_call_sample_docstring(
    FlaxGPTJForCausalLM,
    _TOKENIZER_FOR_DOC,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutput,
    _CONFIG_FOR_DOC,
)
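
With generate() wired up through the two methods above, end-to-end usage looks roughly like this (a hedged sketch mirroring the batch-generation test below; GPT-J reuses the GPT-2 tokenizer):

import jax
from transformers import GPT2Tokenizer, FlaxGPTJForCausalLM

tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>", padding_side="left")
model = FlaxGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

inputs = tokenizer(["Hello, my name is"], return_tensors="np", padding=True)
jit_generate = jax.jit(model.generate)  # static shapes keep the generation loop jit-friendly
sequences = jit_generate(
    inputs["input_ids"], attention_mask=inputs["attention_mask"], pad_token_id=tokenizer.pad_token_id
).sequences
print(tokenizer.batch_decode(sequences, skip_special_tokens=True))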

@@ -1004,6 +1004,42 @@ class FlaxGPTNeoPreTrainedModel:
         requires_backends(self, ["flax"])


+class FlaxGPTJForCausalLM:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+    def __call__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
+class FlaxGPTJModel:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+    def __call__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
+class FlaxGPTJPreTrainedModel:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+    def __call__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
 class FlaxMarianModel:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])

@@ -0,0 +1,328 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import tempfile
import unittest

import numpy as np

import transformers
from transformers import GPT2Tokenizer, GPTJConfig, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, tooslow

from .test_generation_flax_utils import FlaxGenerationTesterMixin
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask


if is_flax_available():
    import jax
    import jax.numpy as jnp

    from transformers.modeling_flax_pytorch_utils import (
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )
    from transformers.models.gptj.modeling_flax_gptj import FlaxGPTJForCausalLM, FlaxGPTJModel

if is_torch_available():
    import torch


class FlaxGPTJModelTester:
    def __init__(
        self,
        parent,
        batch_size=14,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        rotary_dim=4,
        num_hidden_layers=4,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.rotary_dim = rotary_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = None
        self.bos_token_id = vocab_size - 1
        self.eos_token_id = vocab_size - 1
        self.pad_token_id = vocab_size - 1

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        config = GPTJConfig(
            vocab_size=self.vocab_size,
            n_embd=self.hidden_size,
            n_layer=self.num_hidden_layers,
            n_head=self.num_attention_heads,
            n_positions=self.max_position_embeddings,
            use_cache=False,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            rotary_dim=self.rotary_dim,
        )

        return (config, input_ids, input_mask)

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
        return config, inputs_dict

    def check_use_cache_forward(self, model_class_name, config, input_ids, attention_mask):
        max_decoder_length = 20
        model = model_class_name(config)

        past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
        attention_mask = jnp.ones((input_ids.shape[0], max_decoder_length), dtype="i4")

        position_ids = jnp.broadcast_to(
            jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
        )
        outputs_cache = model(
            input_ids[:, :-1],
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
        outputs_cache_next = model(
            input_ids[:, -1:],
            attention_mask=attention_mask,
            past_key_values=outputs_cache.past_key_values,
            position_ids=position_ids,
        )

        outputs = model(input_ids)

        diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
        self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")

    def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input_ids, attention_mask):
        max_decoder_length = 20
        model = model_class_name(config)

        attention_mask_cache = jnp.concatenate(
            [attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))],
            axis=-1,
        )

        past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
        position_ids = jnp.broadcast_to(
            jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
        )

        outputs_cache = model(
            input_ids[:, :-1],
            attention_mask=attention_mask_cache,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )
        position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
        outputs_cache_next = model(
            input_ids[:, -1:],
            past_key_values=outputs_cache.past_key_values,
            attention_mask=attention_mask_cache,
            position_ids=position_ids,
        )

        outputs = model(input_ids, attention_mask=attention_mask)

        diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
        self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")


@require_flax
class FlaxGPTJModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):

    all_model_classes = (FlaxGPTJModel, FlaxGPTJForCausalLM) if is_flax_available() else ()
    all_generative_model_classes = (FlaxGPTJForCausalLM,) if is_flax_available() else ()

    def setUp(self):
        self.model_tester = FlaxGPTJModelTester(self)

    def test_use_cache_forward(self):
        for model_class_name in self.all_model_classes:
            config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
            self.model_tester.check_use_cache_forward(model_class_name, config, input_ids, attention_mask)

    def test_use_cache_forward_with_attn_mask(self):
        for model_class_name in self.all_model_classes:
            config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
            self.model_tester.check_use_cache_forward_with_attn_mask(
                model_class_name, config, input_ids, attention_mask
            )

    @tooslow
    def test_batch_generation(self):
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>", padding_side="left")
        inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)

        model = FlaxGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
        model.do_sample = False
        model.config.pad_token_id = model.config.eos_token_id

        jit_generate = jax.jit(model.generate)

        output_sequences = jit_generate(
            inputs["input_ids"], attention_mask=inputs["attention_mask"], pad_token_id=tokenizer.pad_token_id
        ).sequences

        output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)

        expected_string = [
            "Hello this is a long string of text.\n\nI'm trying to get the text of the",
            "Hey, I'm a little late to the party. I'm going to",
        ]

        self.assertListEqual(output_string, expected_string)

    # overwrite from common since `attention_mask` in combination
    # with `causal_mask` behaves slightly differently
    @is_pt_flax_cross_test
    def test_equivalence_pt_to_flax(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

                # load corresponding PyTorch class
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                batch_size, seq_length = pt_inputs["input_ids"].shape
                rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
                for batch_idx, start_index in enumerate(rnd_start_indices):
                    pt_inputs["attention_mask"][batch_idx, :start_index] = 0
                    pt_inputs["attention_mask"][batch_idx, start_index:] = 1
                    prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
                    prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
                pt_model = pt_model_class(config).eval()
                fx_model = model_class(config, dtype=jnp.float32)

                fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
                fx_model.params = fx_state

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                    self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    pt_model.save_pretrained(tmpdirname)
                    fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)

                fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
                self.assertEqual(
                    len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
                )
                for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
                    self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2)

    # overwrite from common since `attention_mask` in combination
    # with `causal_mask` behaves slightly differently
    @is_pt_flax_cross_test
    def test_equivalence_flax_to_pt(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

                # load corresponding PyTorch class
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                pt_model = pt_model_class(config).eval()
                fx_model = model_class(config, dtype=jnp.float32)

                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
                batch_size, seq_length = pt_inputs["input_ids"].shape
                rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
                for batch_idx, start_index in enumerate(rnd_start_indices):
                    pt_inputs["attention_mask"][batch_idx, :start_index] = 0
                    pt_inputs["attention_mask"][batch_idx, start_index:] = 1
                    prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
                    prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1

                # make sure weights are tied in PyTorch
                pt_model.tie_weights()

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                    self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    fx_model.save_pretrained(tmpdirname)
                    pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)

                with torch.no_grad():
                    pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

                self.assertEqual(
                    len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
                )
                for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
                    self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

    @tooslow
    def test_model_from_pretrained(self):
        for model_class_name in self.all_model_classes:
            model = model_class_name.from_pretrained("EleutherAI/gpt-j-6B")
            outputs = model(np.ones((1, 1)))
            self.assertIsNotNone(outputs)

@@ -50,6 +50,7 @@ class GPTJModelTester:
         use_mc_token_ids=True,
         vocab_size=99,
         hidden_size=32,
+        rotary_dim=4,
         num_hidden_layers=5,
         num_attention_heads=4,
         intermediate_size=37,

@@ -73,6 +74,7 @@ class GPTJModelTester:
         self.use_mc_token_ids = use_mc_token_ids
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
+        self.rotary_dim = rotary_dim
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.intermediate_size = intermediate_size

@@ -149,6 +151,7 @@ class GPTJModelTester:
             bos_token_id=self.bos_token_id,
             eos_token_id=self.eos_token_id,
             pad_token_id=self.pad_token_id,
+            rotary_dim=self.rotary_dim,
         )

     def prepare_config_and_inputs_for_decoder(self):