search buffers for dtype (#23159)

This commit is contained in:
cyy 2023-05-06 23:41:08 +08:00 committed by GitHub
parent 312b104ff6
commit ef42c2c487
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 13 deletions

View File

@ -207,21 +207,29 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
# if no floating dtype was found return whatever the first dtype is
return last_dtype
else:
# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
for t in parameter.buffers():
last_dtype = t.dtype
if t.is_floating_point():
return t.dtype
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
last_tuple = None
for tuple in gen:
last_tuple = tuple
if tuple[1].is_floating_point():
return tuple[1].dtype
if last_dtype is not None:
# if no floating dtype was found return whatever the first dtype is
return last_dtype
# fallback to the last dtype
return last_tuple[1].dtype
# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
last_tuple = None
for tuple in gen:
last_tuple = tuple
if tuple[1].is_floating_point():
return tuple[1].dtype
# fallback to the last dtype
return last_tuple[1].dtype
def get_state_dict_float_dtype(state_dict):