[Tests, GPU, SLOW] fix a bunch of GPU hardcoded tests in Pytorch (#4468)

* fix gpu slow tests in pytorch

* change model to device syntax
This commit is contained in:
Patrick von Platen 2020-05-19 21:35:04 +02:00 committed by GitHub
parent 5856999a9f
commit aa925a52fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 28 additions and 11 deletions

View File

@ -80,7 +80,7 @@ def main():
# Load a pre-trained model
model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
model = model.to(device)
model.to(device)
logger.info(
"Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".format(

View File

@ -770,7 +770,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
import torch_xla.core.xla_model as xm
model = xm.send_cpu_data_to_device(model, xm.xla_device())
model = model.to(xm.xla_device())
model.to(xm.xla_device())
return model

View File

@ -219,6 +219,7 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_ctrl(self):
model = CTRLLMHeadModel.from_pretrained("ctrl")
model.to(torch_device)
input_ids = torch.tensor(
[[11859, 0, 1611, 8]], dtype=torch.long, device=torch_device
) # Legal the president is

View File

@ -329,5 +329,5 @@ class EncoderDecoderModelTest(unittest.TestCase):
@slow
def test_real_bert_model_from_pretrained(self):
model = EncoderDecoderModel.from_pretrained("bert-base-uncased", "bert-base-uncased")
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
self.assertIsNotNone(model)

View File

@ -343,6 +343,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_gpt2(self):
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.to(torch_device)
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
expected_output_ids = [
464,
@ -372,6 +373,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_distilgpt2(self):
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
model.to(torch_device)
input_ids = torch.tensor([[464, 1893]], dtype=torch.long, device=torch_device) # The president
expected_output_ids = [
464,

View File

@ -214,32 +214,39 @@ class LongformerModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head(self):
model = LongformerModel.from_pretrained("longformer-base-4096")
model.to(torch_device)
# 'Hello world! ' repeated 1000 times
input_ids = torch.tensor([[0] + [20920, 232, 328, 1437] * 1000 + [2]]) # long input
input_ids = torch.tensor(
[[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device
) # long input
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
attention_mask[:, [1, 4, 21]] = 2 # Set global attention on a few random positions
output = model(input_ids, attention_mask=attention_mask)[0]
expected_output_sum = torch.tensor(74585.8594)
expected_output_mean = torch.tensor(0.0243)
expected_output_sum = torch.tensor(74585.8594, device=torch_device)
expected_output_mean = torch.tensor(0.0243, device=torch_device)
self.assertTrue(torch.allclose(output.sum(), expected_output_sum, atol=1e-4))
self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4))
@slow
def test_inference_masked_lm(self):
model = LongformerForMaskedLM.from_pretrained("longformer-base-4096")
model.to(torch_device)
# 'Hello world! ' repeated 1000 times
input_ids = torch.tensor([[0] + [20920, 232, 328, 1437] * 1000 + [2]]) # long input
input_ids = torch.tensor(
[[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device
) # long input
loss, prediction_scores = model(input_ids, masked_lm_labels=input_ids)
expected_loss = torch.tensor(0.0620)
expected_prediction_scores_sum = torch.tensor(-6.1599e08)
expected_prediction_scores_mean = torch.tensor(-3.0622)
expected_loss = torch.tensor(0.0620, device=torch_device)
expected_prediction_scores_sum = torch.tensor(-6.1599e08, device=torch_device)
expected_prediction_scores_mean = torch.tensor(-3.0622, device=torch_device)
input_ids = input_ids.to(torch_device)
self.assertTrue(torch.allclose(loss, expected_loss, atol=1e-4))
self.assertTrue(torch.allclose(prediction_scores.sum(), expected_prediction_scores_sum, atol=1e-4))

View File

@ -227,6 +227,7 @@ class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_openai_gpt(self):
model = OpenAIGPTLMHeadModel.from_pretrained("openai-gpt")
model.to(torch_device)
input_ids = torch.tensor([[481, 4735, 544]], dtype=torch.long, device=torch_device) # the president is
expected_output_ids = [
481,

View File

@ -444,6 +444,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
)
input_ids = tok.encode(model.config.prefix + original_input, return_tensors="pt")
input_ids = input_ids.to(torch_device)
output = model.generate(
input_ids=input_ids,
@ -471,6 +472,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
expected_translation = "Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre un « portrait familial » de générations innombrables de étoiles : les plus anciennes sont observées sous forme de pointes bleues, alors que les « nouveau-nés » de couleur rose dans la salle des accouchements doivent être plus difficiles "
input_ids = tok.encode(model.config.prefix + original_input, return_tensors="pt")
input_ids = input_ids.to(torch_device)
output = model.generate(
input_ids=input_ids,
@ -498,6 +500,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022."
input_ids = tok.encode(model.config.prefix + original_input, return_tensors="pt")
input_ids = input_ids.to(torch_device)
output = model.generate(
input_ids=input_ids,

View File

@ -223,6 +223,7 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_transfo_xl_wt103(self):
model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
model.to(torch_device)
input_ids = torch.tensor(
[
[

View File

@ -434,6 +434,7 @@ class XLMModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_xlm_mlm_en_2048(self):
model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
model.to(torch_device)
input_ids = torch.tensor([[14, 447]], dtype=torch.long, device=torch_device) # the president
expected_output_ids = [
14,
@ -459,4 +460,4 @@ class XLMModelLanguageGenerationTest(unittest.TestCase):
] # the president the president the president the president the president the president the president the president the president the president
# TODO(PVP): this and other input_ids I tried for generation give pretty bad results. Not sure why. Model might just not be made for auto-regressive inference
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
self.assertListEqual(output_ids[0].cpu().numpy().tolist(), expected_output_ids)

View File

@ -517,6 +517,7 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_xlnet_base_cased(self):
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")
model.to(torch_device)
input_ids = torch.tensor(
[
[