[tests] add `torch.use_deterministic_algorithms` for XPU (#30774)
* add xpu check * add marker * add documentation * update doc * fix ci * remove from global init * fix
This commit is contained in:
parent
8366b57241
commit
21339a5213
|
@ -116,6 +116,7 @@ from .utils import (
|
|||
is_torch_bf16_available_on_device,
|
||||
is_torch_bf16_cpu_available,
|
||||
is_torch_bf16_gpu_available,
|
||||
is_torch_deterministic,
|
||||
is_torch_fp16_available_on_device,
|
||||
is_torch_neuroncore_available,
|
||||
is_torch_npu_available,
|
||||
|
@ -943,6 +944,15 @@ def require_torch_bf16_cpu(test_case):
|
|||
)(test_case)
|
||||
|
||||
|
||||
def require_deterministic_for_xpu(test_case):
|
||||
if is_torch_xpu_available():
|
||||
return unittest.skipUnless(is_torch_deterministic(), "test requires torch to use deterministic algorithms")(
|
||||
test_case
|
||||
)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_torch_tf32(test_case):
|
||||
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
|
||||
return unittest.skipUnless(
|
||||
|
|
|
@ -188,6 +188,7 @@ from .import_utils import (
|
|||
is_torch_bf16_gpu_available,
|
||||
is_torch_compile_available,
|
||||
is_torch_cuda_available,
|
||||
is_torch_deterministic,
|
||||
is_torch_fp16_available_on_device,
|
||||
is_torch_fx_available,
|
||||
is_torch_fx_proxy,
|
||||
|
|
|
@ -296,6 +296,18 @@ def is_torch_available():
|
|||
return _torch_available
|
||||
|
||||
|
||||
def is_torch_deterministic():
|
||||
"""
|
||||
Check whether pytorch uses deterministic algorithms by looking if torch.set_deterministic_debug_mode() is set to 1 or 2"
|
||||
"""
|
||||
import torch
|
||||
|
||||
if torch.get_deterministic_debug_mode() == 0:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def is_hqq_available():
|
||||
return _hqq_available
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ import tempfile
|
|||
import unittest
|
||||
|
||||
from transformers import is_torch_available, logging
|
||||
from transformers.testing_utils import CaptureLogger, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import CaptureLogger, require_deterministic_for_xpu, require_torch, slow, torch_device
|
||||
|
||||
from ...test_modeling_common import ids_tensor
|
||||
from ..bart.test_modeling_bart import BartStandaloneDecoderModelTester
|
||||
|
@ -639,6 +639,7 @@ class EncoderDecoderMixin:
|
|||
loss.backward()
|
||||
|
||||
@slow
|
||||
@require_deterministic_for_xpu
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2 = self.get_pretrained_model()
|
||||
model_2.to(torch_device)
|
||||
|
|
|
@ -18,7 +18,7 @@ import tempfile
|
|||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.testing_utils import require_deterministic_for_xpu, require_torch, slow, torch_device
|
||||
|
||||
from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from ..bert.test_modeling_bert import BertModelTester
|
||||
|
@ -422,6 +422,7 @@ class EncoderDecoderMixin:
|
|||
loss.backward()
|
||||
|
||||
@slow
|
||||
@require_deterministic_for_xpu
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2, inputs = self.get_pretrained_model_and_inputs()
|
||||
model_2.to(torch_device)
|
||||
|
@ -578,6 +579,7 @@ class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||
def test_save_and_load_from_pretrained(self):
|
||||
pass
|
||||
|
||||
@require_deterministic_for_xpu
|
||||
# all published pretrained models are Speech2TextModel != Speech2TextEncoder
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
pass
|
||||
|
|
|
@ -22,6 +22,7 @@ import unittest
|
|||
from transformers import SpeechT5Config, SpeechT5HifiGanConfig
|
||||
from transformers.testing_utils import (
|
||||
is_torch_available,
|
||||
require_deterministic_for_xpu,
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
|
@ -1071,6 +1072,7 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
|
|||
"Shape mismatch between generate_speech and generate methods.",
|
||||
)
|
||||
|
||||
@require_deterministic_for_xpu
|
||||
def test_one_to_many_generation(self):
|
||||
model = self.default_model
|
||||
processor = self.default_processor
|
||||
|
|
Loading…
Reference in New Issue