added compiled model support for inference (#25124)

* added compiled model support for inference

* linter

* Fix tests

* linter

* linter

* remove inference mode from pipelines

* Linter

---------

Co-authored-by: amarkov <alexander@inworld.ai>
This commit is contained in:
Alexander Markov 2023-07-28 13:28:04 +01:00 committed by GitHub
parent afa96fffdf
commit 3cbc560d03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 8 deletions

View File

@ -27,8 +27,6 @@ from contextlib import contextmanager
from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from packaging import version
from ..dynamic_module_utils import custom_object_save
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..image_processing_utils import BaseImageProcessor
@ -1015,12 +1013,7 @@ class Pipeline(_ScikitCompat):
raise NotImplementedError("postprocess not implemented")
def get_inference_context(self):
inference_context = (
torch.inference_mode
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.9.0")
else torch.no_grad
)
return inference_context
return torch.no_grad
def forward(self, model_inputs, **forward_params):
with self.device_placement():