Fix low cpu mem usage tests (#30808)
* Fix tests * fix udop failing test * remove skip * style
This commit is contained in:
parent
be2fb4fb8f
commit
bb17199cd2
|
@ -1297,7 +1297,7 @@ class UdopStack(UdopPreTrainedModel):
|
|||
# get weights from encoder position bias
|
||||
self.relative_bias = self._get_relative_bias(config)
|
||||
|
||||
# tie weights of original position bias of encoder
|
||||
def _tie_weights(self):
|
||||
for bias in self.relative_bias.biases:
|
||||
if isinstance(bias, RelativePositionBias1D):
|
||||
self._tie_or_clone_weights(
|
||||
|
|
|
@ -21,7 +21,6 @@ import os.path
|
|||
import random
|
||||
import re
|
||||
import tempfile
|
||||
import unittest
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple
|
||||
|
@ -444,7 +443,6 @@ class ModelTesterMixin:
|
|||
@slow
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
@unittest.skip("Need to fix since we have a device mismatch")
|
||||
def test_save_load_low_cpu_mem_usage(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
with tempfile.TemporaryDirectory() as saved_model_path:
|
||||
|
@ -457,7 +455,6 @@ class ModelTesterMixin:
|
|||
@slow
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
@unittest.skip("Need to fix since we have a device mismatch")
|
||||
def test_save_load_low_cpu_mem_usage_checkpoints(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
with tempfile.TemporaryDirectory() as saved_model_path:
|
||||
|
@ -471,7 +468,6 @@ class ModelTesterMixin:
|
|||
@slow
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
@unittest.skip("Need to fix since we have a device mismatch")
|
||||
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
|
||||
with tempfile.TemporaryDirectory() as saved_model_path:
|
||||
for model_class in self.all_model_classes:
|
||||
|
@ -482,6 +478,8 @@ class ModelTesterMixin:
|
|||
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
|
||||
|
||||
def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
|
||||
from accelerate.utils.modeling import named_module_tensors
|
||||
|
||||
# Load the low usage and the normal models.
|
||||
model_low_usage, loading_info = model_class.from_pretrained(
|
||||
saved_model_path,
|
||||
|
@ -496,16 +494,13 @@ class ModelTesterMixin:
|
|||
# The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then
|
||||
# subsequently loaded with the correct values and onto the correct device. We check if there are any
|
||||
# remaining params that were not properly loaded.
|
||||
for name, param in model_low_usage.named_parameters():
|
||||
for name, tensor in named_module_tensors(model_low_usage, recurse=True):
|
||||
self.assertNotEqual(
|
||||
param.device,
|
||||
tensor.device,
|
||||
torch.device("meta"),
|
||||
"Parameter '" + name + "' has not been properly loaded and has device=meta.",
|
||||
"Tensor '" + name + "' has not been properly loaded and has device=meta.",
|
||||
)
|
||||
|
||||
# Tests moving the model to a device other than meta.
|
||||
model_low_usage.to(torch_device)
|
||||
|
||||
# Check that the parameters are equal.
|
||||
for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()):
|
||||
self.assertEquals(p1.data.ne(p2.data).sum(), 0)
|
||||
|
|
Loading…
Reference in New Issue