diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py
index 8ad60a03e2..b80e9bd3a1 100644
--- a/src/transformers/quantizers/quantizer_bnb_8bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -171,7 +171,10 @@ class Bnb8BitHfQuantizer(HfQuantizer):
         import bitsandbytes as bnb
 
         fp16_statistics_key = param_name.replace("weight", "SCB")
+        fp16_weights_format_key = param_name.replace("weight", "weight_format")
+
         fp16_statistics = state_dict.get(fp16_statistics_key, None)
+        fp16_weights_format = state_dict.get(fp16_weights_format_key, None)
 
         module, tensor_name = get_module_from_name(model, param_name)
         if tensor_name not in module._parameters:
@@ -210,6 +213,11 @@ class Bnb8BitHfQuantizer(HfQuantizer):
             if unexpected_keys is not None:
                 unexpected_keys.remove(fp16_statistics_key)
 
+        # We just need to pop the `weight_format` keys from the state dict to remove unneeded
+        # messages. The correct format is correctly retrieved during the first forward pass.
+        if fp16_weights_format is not None and unexpected_keys is not None:
+            unexpected_keys.remove(fp16_weights_format_key)
+
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
         model.is_loaded_in_8bit = True
         model.is_8bit_serializable = self.is_serializable
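
For context, a minimal sketch of the scenario this patch addresses, assuming an illustrative model id ("facebook/opt-350m") and save path ("opt-350m-8bit") that are not taken from the PR: serializing a model loaded in 8-bit stores a `weight_format` entry next to each quantized `weight` and its `SCB` statistics, and reloading that checkpoint previously reported every `*.weight_format` key as unexpected.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Quantize on load; bitsandbytes serializes per-layer `SCB` statistics and a
# `weight_format` entry alongside each 8-bit `weight` in the state dict.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", quantization_config=quantization_config
)
model.save_pretrained("opt-350m-8bit")

# Reloading the serialized checkpoint used to warn about unexpected `*.weight_format`
# keys; with this patch they are popped from `unexpected_keys`, and the actual format
# is recovered during the first forward pass.
reloaded = AutoModelForCausalLM.from_pretrained("opt-350m-8bit")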