[Flax Generation] Correct inconsistencies PyTorch/Flax (#12662)

* fix_torch_device_generate_test

* remove @

* correct greedy search

* save intermediate

* add final logits bias

* correct

* up

* add more tests

* fix another bug

* finish tests

* finish marian tests

* up

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
Patrick von Platen, 2021-07-13 18:53:30 +01:00, committed by GitHub
parent 7a22a02a70
commit cee2d2135f
7 changed files with 113 additions and 44 deletions

@@ -39,25 +39,6 @@ Implementation Notes
- Code to bulk convert models can be found in ``convert_marian_to_pytorch.py``.
- This model was contributed by `sshleifer <https://huggingface.co/sshleifer>`__.
Tips
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- In Flax, it is highly advised to pass `early_stopping=True` to `generate`. *E.g.*:
::
>>> from transformers import MarianTokenizer, FlaxMarianMTModel
>>> model = FlaxMarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-en-de')
>>> tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
>>> text = "My friends are cool but they eat too many carbs."
>>> input_ids = tokenizer(text, max_length=64, return_tensors='jax').input_ids
>>> # Marian has to make use of early_stopping=True
>>> sequences = model.generate(input_ids, early_stopping=True, max_length=64, num_beams=2).sequences
Naming
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -15,6 +15,7 @@
# limitations under the License.
from functools import partial
from typing import Dict, Optional
import numpy as np
@@ -447,8 +448,8 @@ class FlaxGenerationMixin:
next_token = jnp.argmax(logits, axis=-1)
next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
next_token = next_token[:, None]
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
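
Both versions of the masking line above rely on the same boolean-arithmetic idiom for forcing pad_token_id on finished sequences. A minimal, self-contained sketch (token ids and the finished mask are made up for illustration):

import jax.numpy as jnp

# Toy values: a finished sequence keeps emitting pad_token_id, while an
# unfinished sequence keeps its argmax token.
pad_token_id = 0
next_token = jnp.array([5, 7, 3])
is_sent_finished = jnp.array([False, True, False])

masked = next_token * ~is_sent_finished + pad_token_id * is_sent_finished
print(masked)  # [5 0 3]
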
@@ -687,7 +688,7 @@ class FlaxGenerationMixin:
return not_max_length_yet & still_open_beam & improvement_still_possible
def beam_search_body_fn(state):
def beam_search_body_fn(state, input_ids_length=1):
"""beam search state update fn."""
# 1. Forward current tokens
# Collect the current position slice along length to feed the fast
@@ -696,10 +697,15 @@ class FlaxGenerationMixin:
# unflatten beam dimension
# Unflatten beam dimension in attention cache arrays
input_token = flatten_beam_dim(
lax.dynamic_slice(state.running_sequences, (0, 0, state.cur_len - 1), (batch_size, num_beams, 1))
lax.dynamic_slice(
state.running_sequences,
(0, 0, state.cur_len - input_ids_length),
(batch_size, num_beams, input_ids_length),
)
)
model_outputs = model(input_token, params=params, **state.model_kwargs)
logits = unflatten_beam_dim(model_outputs.logits[:, 0], batch_size, num_beams)
logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
cache = jax.tree_map(
lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
)
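
The switch from logits[:, 0] to logits[:, -1] matters once more than one token is fed per call. A toy sketch with made-up shapes:

import jax.numpy as jnp

# The model returns one logits row per fed position; the next-token distribution
# is the last row. Index -1 is correct for any input length, index 0 only when a
# single token is fed.
batch_times_beams, input_ids_length, vocab_size = 4, 3, 11
logits = jnp.zeros((batch_times_beams, input_ids_length, vocab_size))
next_token_logits = logits[:, -1]
print(next_token_logits.shape)  # (4, 11)
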
@@ -747,14 +753,13 @@ class FlaxGenerationMixin:
# set of active beam search sequences, set their log probs to a very large
# negative value.
did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
# 5. Get running sequences scores for next
# Determine the top k beam indices (from top 2*k beams) from log probs
# and gather top k beams (from top 2*k beams).
next_topk_indices = jnp.flip(lax.top_k(topk_log_probs, k=num_beams)[1], axis=1)
next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
next_running_sequences, next_running_scores = gather_beams(
[topk_sequences, topk_log_probs], next_topk_indices, batch_size, num_beams
[topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
)
# 6. Process topk logits
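
A side note on steps 4-5 above: the renamed running_topk_log_probs is the penalized copy used to select the next running beams in the gather shown here, so the -1.0e7 penalty on just-finished beams does not overwrite topk_log_probs itself. A toy illustration with made-up scores:

import numpy as np
import jax.numpy as jnp

# Adding a large negative constant knocks a just-finished beam out of every
# subsequent top-k over the running scores.
topk_log_probs = jnp.array([[-1.2, -2.5, -0.3]])
did_topk_just_finished = jnp.array([[False, True, False]])

running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
print(running_topk_log_probs)  # the middle beam drops to roughly -1e7
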
@@ -801,7 +806,8 @@
)
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
state = beam_search_body_fn(state)
if input_ids.shape[-1] > 1:
state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)
if not trace:
state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
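
Running the first beam_search_body_fn call outside lax.while_loop, with input_ids_length bound via functools.partial, lets the whole prompt be fed at once; inside the loop only the newest token is sliced out. A minimal sketch of that slicing, with hypothetical shapes and token ids:

import jax.numpy as jnp
from jax import lax

batch_size, num_beams, max_length = 1, 2, 8
running_sequences = jnp.zeros((batch_size, num_beams, max_length), dtype=jnp.int32)
running_sequences = running_sequences.at[:, :, :3].set(jnp.array([4, 5, 6], dtype=jnp.int32))

# First step (outside the while loop): the prompt has 3 tokens, all are fed.
cur_len, input_ids_length = 3, 3
first_slice = lax.dynamic_slice(
    running_sequences, (0, 0, cur_len - input_ids_length), (batch_size, num_beams, input_ids_length)
)
print(first_slice.shape)  # (1, 2, 3) -> full prompt

# Later steps (inside lax.while_loop): only the newest token is fed.
cur_len, input_ids_length = 4, 1
later_slice = lax.dynamic_slice(
    running_sequences, (0, 0, cur_len - input_ids_length), (batch_size, num_beams, input_ids_length)
)
print(later_slice.shape)  # (1, 2, 1) -> single newest token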

@@ -1799,7 +1799,6 @@ class GenerationMixin:
continue # don't waste resources running the code we don't need
next_token_logits = outputs.logits[:, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
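
adjust_logits_during_generation is the hook the comment above refers to: a model can override it to forbid specific tokens at every step. A standalone sketch of the idea (simplified, free-standing signature rather than the library method):

import torch

def adjust_logits_during_generation(logits: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Push the pad token's logit to -inf so it can never be selected; -inf stays
    # -inf after log_softmax, so the ban holds before and after normalization.
    logits[:, pad_token_id] = float("-inf")
    return logits

next_token_logits = torch.randn(2, 10)  # (batch_size, vocab_size), dummy values
next_token_logits = adjust_logits_during_generation(next_token_logits, pad_token_id=3)
assert torch.isinf(next_token_logits[:, 3]).all()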

@@ -1213,6 +1213,7 @@ class FlaxMarianMTModule(nn.Module):
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
)
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
def _get_encoder_module(self):
return self.model.encoder
@@ -1254,6 +1255,8 @@
else:
lm_logits = self.lm_head(hidden_states)
lm_logits += self.final_logits_bias
if not return_dict:
output = (lm_logits,) + outputs[1:]
return output
@@ -1367,6 +1370,7 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
lm_logits = module.lm_head(hidden_states)
lm_logits += module.final_logits_bias
return lm_logits, outputs
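
The final_logits_bias parameter registered above is applied in both the training forward pass and the generation decode path. A minimal Flax sketch of the pattern (bias_init is not shown in this diff, so a zeros initializer is assumed here):

import flax.linen as nn
import jax
import jax.numpy as jnp

class TinyLMHead(nn.Module):
    vocab_size: int = 16

    def setup(self):
        self.lm_head = nn.Dense(self.vocab_size, use_bias=False)
        self.final_logits_bias = self.param(
            "final_logits_bias", jax.nn.initializers.zeros, (1, self.vocab_size)
        )

    def __call__(self, hidden_states):
        logits = self.lm_head(hidden_states)
        return logits + self.final_logits_bias  # broadcasts over batch and length

model = TinyLMHead()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 8)))
print(model.apply(params, jnp.ones((1, 4, 8))).shape)  # (1, 4, 16)
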
@@ -1465,8 +1469,7 @@ FLAX_MARIAN_MT_DOCSTRING = """
>>> text = "My friends are cool but they eat too many carbs."
>>> input_ids = tokenizer(text, max_length=64, return_tensors='jax').input_ids
>>> # Marian has to make use of early_stopping=True
>>> sequences = model.generate(input_ids, early_stopping=True, max_length=64, num_beams=2).sequences
>>> sequences = model.generate(input_ids, max_length=64, num_beams=2).sequences
>>> outputs = tokenizer.batch_decode(sequences, skip_special_tokens=True)
>>> # should give `Meine Freunde sind cool, aber sie essen zu viele Kohlenhydrate.`

@@ -16,8 +16,9 @@ import random
import numpy as np
from transformers import is_flax_available
from transformers.testing_utils import require_flax
import transformers
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
if is_flax_available():
@@ -26,10 +27,15 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from jax import jit
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
if is_torch_available():
import torch
def ids_tensor(shape, vocab_size, rng=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
@@ -78,6 +84,29 @@ class FlaxGenerationTesterMixin:
config.pad_token_id = config.eos_token_id
return config, input_ids, attention_mask, max_length
@is_pt_flax_cross_test
def test_greedy_generate_pt_fx(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = False
config.max_length = max_length
config.decoder_start_token_id = 0
for model_class in self.all_generative_model_classes:
flax_model = model_class(config)
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params)
flax_generation_outputs = flax_model.generate(input_ids).sequences
pt_generation_outputs = pt_model.generate(torch.tensor(input_ids, dtype=torch.long))
if flax_generation_outputs.shape[-1] > pt_generation_outputs.shape[-1]:
flax_generation_outputs = flax_generation_outputs[:, : pt_generation_outputs.shape[-1]]
self.assertListEqual(pt_generation_outputs.numpy().tolist(), flax_generation_outputs.tolist())
def test_greedy_generate(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = False

@@ -304,7 +304,7 @@ class FlaxMarianModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("Helsinki-NLP/opus-mt-en-de", from_pt=True)
model = model_class_name.from_pretrained("Helsinki-NLP/opus-mt-en-de")
# FlaxMarianForSequenceClassification expects eos token in input_ids
input_ids = np.ones((1, 1)) * model.config.eos_token_id
outputs = model(input_ids)
@@ -333,7 +333,7 @@ class MarianIntegrationTest(unittest.TestCase):
@cached_property
def model(self):
model: FlaxMarianMTModel = FlaxMarianMTModel.from_pretrained(self.model_name, from_pt=True)
model: FlaxMarianMTModel = FlaxMarianMTModel.from_pretrained(self.model_name)
self.assertEqual(model.config.decoder_start_token_id, model.config.pad_token_id)
return model
@@ -348,7 +348,6 @@
attention_mask=model_inputs.attention_mask,
num_beams=2,
max_length=128,
early_stopping=True,
).sequences
generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
return generated_words
@@ -385,7 +384,7 @@ class TestMarian_FR_EN(MarianIntegrationTest):
"Tom et Mary étaient assis à une table.", # Accents
]
expected_text = [
"Give me the microphone, please.",
"Give me the microphone.",
"Tom and Mary were sitting at a table.",
]
@@ -402,8 +401,8 @@ class TestMarian_MT_EN(MarianIntegrationTest):
src = "mt"
tgt = "en"
src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil"]
expected_text = ["By touching him gently, Jesus healed a man who was affected by"]
src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."]
expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."]
@slow
def test_batch_generation_mt_en(self):
@@ -425,14 +424,66 @@ class TestMarian_EN_DE(MarianIntegrationTest):
"Turn around and close your eyes.",
]
expected_text = [
"- Ich bin ein kleiner Frosch.",
"Ich bin ein kleiner Frosch.",
"Jetzt kann ich die 100 Wörter des Deutschen vergessen, die ich kenne.",
"Tom fragte seinen Lehrer um Rat.",
"Tom bat seinen Lehrer um Rat.",
"So würde ich das machen.",
"Tom bewunderte Marias Mut wirklich.",
"Dreh dich um und schließe deine Augen.",
"Drehen Sie sich um und schließen Sie die Augen.",
]
@slow
def test_batch_generation_en_de(self):
self._assert_generated_batch_equal_expected()
@require_flax
@require_sentencepiece
@require_tokenizers
class TestMarian_en_zh(MarianIntegrationTest):
src = "en"
tgt = "zh"
src_text = ["My name is Wolfgang and I live in Berlin"]
expected_text = ["我叫沃尔夫冈 我住在柏林"]
@slow
def test_batch_generation_eng_zho(self):
self._assert_generated_batch_equal_expected()
@require_flax
@require_sentencepiece
@require_tokenizers
class TestMarian_RU_FR(MarianIntegrationTest):
src = "ru"
tgt = "fr"
src_text = ["Он показал мне рукопись своей новой пьесы."]
expected_text = ["Il m'a montré le manuscrit de sa nouvelle pièce."]
@slow
def test_batch_generation_ru_fr(self):
self._assert_generated_batch_equal_expected()
@require_flax
@require_sentencepiece
@require_tokenizers
class TestMarian_en_ROMANCE(MarianIntegrationTest):
"""Multilingual on target side."""
src = "en"
tgt = "ROMANCE"
src_text = [
">>fr<< Don't spend so much time watching TV.",
">>pt<< Your message has been sent.",
">>es<< He's two years older than me.",
]
expected_text = [
"Ne passez pas autant de temps à regarder la télé.",
"A sua mensagem foi enviada.",
"Es dos años más viejo que yo.",
]
@slow
def test_batch_generation_en_ROMANCE_multi(self):
self._assert_generated_batch_equal_expected()
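
For readers who want to try the multilingual checkpoint exercised by this test, a rough usage sketch; the checkpoint name is assumed to follow the Helsinki-NLP/opus-mt-{src}-{tgt} convention used by these integration tests, and the ">>fr<<" prefix selects the target language as in the test inputs:

from transformers import MarianTokenizer, FlaxMarianMTModel

model_name = "Helsinki-NLP/opus-mt-en-ROMANCE"
tokenizer = MarianTokenizer.from_pretrained(model_name)
# If the hub checkpoint only ships PyTorch weights, add from_pt=True.
model = FlaxMarianMTModel.from_pretrained(model_name)

inputs = tokenizer(">>fr<< Don't spend so much time watching TV.", return_tensors="jax")
sequences = model.generate(
    inputs.input_ids, attention_mask=inputs.attention_mask, num_beams=2, max_length=64
).sequences
print(tokenizer.batch_decode(sequences, skip_special_tokens=True))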

@@ -482,8 +482,8 @@ class FlaxT5ModelIntegrationTests(unittest.TestCase):
ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
expected_summaries = [
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," the magazine says . all 150 on board the germanwings flight were killed .',
"the Palestinians become the 123rd member of the international criminal court . the accession was marked by a ceremony at the Hague, where the court is based . as members of the court, Palestinians may be subject to counter-charges as well .",
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says . all 150 on board were killed when germanwings flight 9525 crashed .',
"the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .",
"the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . he says the new framework would reduce Iran's low-enriched uranium stockpile and cut centrifuges . miller: if it had been, there would have been no Iranian team at the table .",
'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
]