Add parallelization support for T5EncoderModel (#9082)

* add model parallelism to T5EncoderModel

* remove decoder from T5EncoderModel parallelize

* update T5EncoderModel docs

* Extend T5ModelTest for T5EncoderModel

* fix T5Stack to use range for get_device_map

* fix style

Co-authored-by: Ahmed Elnaggar <elnaggar@rostlab.informatik.tu-muenchen.de>
Ahmed Elnaggar 2020-12-14 18:00:45 +01:00 committed by GitHub
parent b00eb4fb02
commit a9c8bff724
3 changed files with 22 additions and 8 deletions

@@ -131,7 +131,7 @@ T5EncoderModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: transformers.T5EncoderModel
-    :members: forward
+    :members: forward, parallelize, deparallelize
 
 TFT5Model
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -781,7 +781,7 @@ class T5Stack(T5PreTrainedModel):
     def parallelize(self, device_map=None):
         # Check validity of device_map
         self.device_map = (
-            get_device_map(len(self.block), torch.cuda.device_count()) if device_map is None else device_map
+            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
         )
         assert_device_map(self.device_map, len(self.block))
         self.model_parallel = True
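
The replaced call is the substance of the "fix T5Stack" item: get_device_map expects an iterable of device ids as its second argument, so passing the bare integer returned by torch.cuda.device_count() broke construction of the default device map. A minimal sketch of the corrected behaviour, assuming the helper module used by the modeling code (transformers.utils.model_parallel_utils) and an illustrative 4-block / 2-GPU split:

from transformers.utils.model_parallel_utils import get_device_map

# get_device_map(n_layers, devices) splits the layer indices evenly across the
# given device ids, e.g. 4 encoder blocks over 2 visible GPUs:
print(get_device_map(4, range(2)))  # -> {0: [0, 1], 1: [2, 3]}

# The old call passed the raw device count (an int) instead of an iterable of
# device ids, which is what the change above corrects:
# get_device_map(4, 2)  # fails: an int cannot be partitioned over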
@@ -1579,6 +1579,25 @@ class T5EncoderModel(T5PreTrainedModel):
         self.init_weights()
 
+    @add_start_docstrings(PARALLELIZE_DOCSTRING)
+    def parallelize(self, device_map=None):
+        self.device_map = (
+            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
+            if device_map is None
+            else device_map
+        )
+        assert_device_map(self.device_map, len(self.encoder.block))
+        self.encoder.parallelize(self.device_map)
+        self.model_parallel = True
+
+    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
+    def deparallelize(self):
+        self.encoder.deparallelize()
+        self.encoder = self.encoder.to("cpu")
+        self.model_parallel = False
+        self.device_map = None
+        torch.cuda.empty_cache()
+
     def get_input_embeddings(self):
         return self.shared
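
The new methods mirror the parallelization API that T5Model and T5ForConditionalGeneration already expose, just without a decoder to place. A minimal usage sketch, assuming a machine with two visible GPUs; the checkpoint name and the explicit device_map are illustrative, and calling parallelize() with no argument splits the blocks evenly instead:

from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small")

# t5-small has 6 encoder blocks; put the first three on GPU 0 and the rest on GPU 1.
device_map = {0: [0, 1, 2], 1: [3, 4, 5]}
model.parallelize(device_map)

# The embeddings sit on the first device, so the inputs go there as well.
input_ids = tokenizer("Studies have shown that owning a dog is good for you.", return_tensors="pt").input_ids
last_hidden_state = model(input_ids=input_ids.to("cuda:0")).last_hidden_state

# deparallelize() moves everything back to the CPU and frees the cached GPU memory.
model.deparallelize()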

@@ -485,12 +485,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
     all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
     all_parallelizable_model_classes = (
-        (
-            T5Model,
-            T5ForConditionalGeneration,
-        )
-        if is_torch_available()
-        else ()
+        (T5Model, T5ForConditionalGeneration, T5EncoderModel) if is_torch_available() else ()
     )
     test_pruning = False
     test_torchscript = True
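
Adding T5EncoderModel to all_parallelizable_model_classes is what opts it into the shared parallelization tests in ModelTesterMixin. The placement helpers the new parallelize() relies on can also be sanity-checked without any GPU by instantiating a tiny encoder; the config values below are purely illustrative:

from transformers import T5Config, T5EncoderModel
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map

# A deliberately small encoder so the check stays cheap.
config = T5Config(num_layers=4, d_model=32, d_kv=8, d_ff=64, num_heads=2)
model = T5EncoderModel(config)

# Build the default map parallelize() would build for two devices and verify
# that it assigns every encoder block exactly once.
device_map = get_device_map(len(model.encoder.block), range(2))
assert_device_map(device_map, len(model.encoder.block))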