* fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
2958b55fe5
commit
81a73fa638
|
@ -234,7 +234,7 @@ class OnnxConfig(ABC):
|
|||
if is_torch_available():
|
||||
from transformers.utils import get_torch_version
|
||||
|
||||
return get_torch_version() >= self.torch_onnx_minimum_version
|
||||
return version.parse(get_torch_version()) >= self.torch_onnx_minimum_version
|
||||
else:
|
||||
return False
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ from unittest import TestCase
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoConfig, PreTrainedTokenizerBase, is_tf_available, is_torch_available
|
||||
|
@ -321,7 +322,7 @@ class OnnxExportTestCaseV2(TestCase):
|
|||
if is_torch_available():
|
||||
from transformers.utils import get_torch_version
|
||||
|
||||
if get_torch_version() < onnx_config.torch_onnx_minimum_version:
|
||||
if version.parse(get_torch_version()) < onnx_config.torch_onnx_minimum_version:
|
||||
pytest.skip(
|
||||
"Skipping due to incompatible PyTorch version. Minimum required is"
|
||||
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
|
||||
|
@ -364,7 +365,7 @@ class OnnxExportTestCaseV2(TestCase):
|
|||
if is_torch_available():
|
||||
from transformers.utils import get_torch_version
|
||||
|
||||
if get_torch_version() < onnx_config.torch_onnx_minimum_version:
|
||||
if version.parse(get_torch_version()) < onnx_config.torch_onnx_minimum_version:
|
||||
pytest.skip(
|
||||
"Skipping due to incompatible PyTorch version. Minimum required is"
|
||||
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
|
||||
|
|
Loading…
Reference in New Issue