Fix low cpu mem usage tests (#30808)

* Fix tests

* fix udop failing test

* remove skip

* style
This commit is contained in:
Marc Sun 2024-05-22 14:09:01 +02:00 committed by Ita Zaporozhets
parent be2fb4fb8f
commit bb17199cd2
2 changed files with 6 additions and 11 deletions

View File

@ -1297,7 +1297,7 @@ class UdopStack(UdopPreTrainedModel):
# get weights from encoder position bias
self.relative_bias = self._get_relative_bias(config)
# tie weights of original position bias of encoder
def _tie_weights(self):
for bias in self.relative_bias.biases:
if isinstance(bias, RelativePositionBias1D):
self._tie_or_clone_weights(

View File

@ -21,7 +21,6 @@ import os.path
import random
import re
import tempfile
import unittest
import warnings
from collections import defaultdict
from typing import Dict, List, Tuple
@ -444,7 +443,6 @@ class ModelTesterMixin:
@slow
@require_accelerate
@mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
with tempfile.TemporaryDirectory() as saved_model_path:
@ -457,7 +455,6 @@ class ModelTesterMixin:
@slow
@require_accelerate
@mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage_checkpoints(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
with tempfile.TemporaryDirectory() as saved_model_path:
@ -471,7 +468,6 @@ class ModelTesterMixin:
@slow
@require_accelerate
@mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
with tempfile.TemporaryDirectory() as saved_model_path:
for model_class in self.all_model_classes:
@ -482,6 +478,8 @@ class ModelTesterMixin:
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
from accelerate.utils.modeling import named_module_tensors
# Load the low usage and the normal models.
model_low_usage, loading_info = model_class.from_pretrained(
saved_model_path,
@ -496,16 +494,13 @@ class ModelTesterMixin:
# The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then
# subsequently loaded with the correct values and onto the correct device. We check if there are any
# remaining params that were not properly loaded.
for name, param in model_low_usage.named_parameters():
for name, tensor in named_module_tensors(model_low_usage, recurse=True):
self.assertNotEqual(
param.device,
tensor.device,
torch.device("meta"),
"Parameter '" + name + "' has not been properly loaded and has device=meta.",
"Tensor '" + name + "' has not been properly loaded and has device=meta.",
)
# Tests moving the model to a device other than meta.
model_low_usage.to(torch_device)
# Check that the parameters are equal.
for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()):
self.assertEquals(p1.data.ne(p2.data).sum(), 0)