Add parallelization support for T5EncoderModel (#9082)
* add model parallelism to T5EncoderModel * remove decoder from T5EncoderModel parallelize * update T5EncoderModel docs * Extend T5ModelTest for T5EncoderModel * fix T5Stack using range for get_device_map * fix style Co-authored-by: Ahmed Elnaggar <elnaggar@rostlab.informatik.tu-muenchen.de>
This commit is contained in:
parent
b00eb4fb02
commit
a9c8bff724
|
@ -131,7 +131,7 @@ T5EncoderModel
|
|||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.T5EncoderModel
|
||||
:members: forward
|
||||
:members: forward, parallelize, deparallelize
|
||||
|
||||
TFT5Model
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
|
|
@ -781,7 +781,7 @@ class T5Stack(T5PreTrainedModel):
|
|||
def parallelize(self, device_map=None):
|
||||
# Check validity of device_map
|
||||
self.device_map = (
|
||||
get_device_map(len(self.block), torch.cuda.device_count()) if device_map is None else device_map
|
||||
get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
|
||||
)
|
||||
assert_device_map(self.device_map, len(self.block))
|
||||
self.model_parallel = True
|
||||
|
@ -1579,6 +1579,25 @@ class T5EncoderModel(T5PreTrainedModel):
|
|||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
    """Spread the encoder's transformer blocks across the available CUDA devices.

    Args:
        device_map (dict, optional): Mapping of device index to the list of
            encoder block indices it should host. When ``None``, a balanced
            map over ``range(torch.cuda.device_count())`` is generated.
    """
    # Fall back to an auto-generated, evenly-balanced map when the caller
    # did not supply one.
    if device_map is None:
        device_map = get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
    self.device_map = device_map
    # Fail fast if the map does not cover exactly the encoder's blocks.
    assert_device_map(self.device_map, len(self.encoder.block))
    self.encoder.parallelize(self.device_map)
    self.model_parallel = True
|
||||
|
||||
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
    """Undo model parallelism: gather the encoder back onto CPU."""
    self.encoder.deparallelize()
    self.encoder = self.encoder.to("cpu")
    # Reset the parallelism bookkeeping before releasing cached GPU memory.
    self.device_map = None
    self.model_parallel = False
    torch.cuda.empty_cache()
|
||||
|
||||
def get_input_embeddings(self):
    # Return the token-embedding module; presumably the embedding matrix
    # shared with the encoder (self.shared is defined outside this hunk) —
    # TODO confirm against the full class definition.
    return self.shared
|
||||
|
||||
|
|
|
@ -485,12 +485,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_parallelizable_model_classes = (
|
||||
(
|
||||
T5Model,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
(T5Model, T5ForConditionalGeneration, T5EncoderModel) if is_torch_available() else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_torchscript = True
|
||||
|
|
Loading…
Reference in New Issue