Ensure input tensor are on device. (#11874)

The feature extractor does not create tensors on the appropriate device,
so we call `ensure_tensor_on_device` before feeding the processed inputs
to the model.
This commit is contained in:
francescorubbo 2021-05-26 01:19:37 -07:00 committed by GitHub
parent a9c797f93d
commit 0b0a598452
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 0 deletions

View File

@ -136,6 +136,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
processed = self.ensure_tensor_on_device(**processed)
name = self.model.__class__.__name__
if name.endswith("ForConditionalGeneration"):