Fix bug in the T5X-to-PyTorch convert script for varying encoder and decoder layer counts (#27448)

* Fix bug in handling varying encoder and decoder layers

This commit resolves an issue where the script failed to convert T5X models to PyTorch when the number of decoder layers differed from the number of encoder layers. I've addressed this by passing an additional `num_decoder_layers` argument to `convert_t5x_to_pytorch`, so the decoder loop no longer reuses the encoder's layer count.
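
For context, here is a minimal sketch of the failure mode. The layer counts and zero placeholders are illustrative; the key layout mirrors the flattened T5X `target` dict that `t5x_layer_norm_lookup` indexes.

```python
# Illustrative only: a model whose decoder is shallower than its encoder.
num_layers = 12          # encoder depth (config.num_layers)
num_decoder_layers = 6   # decoder depth (config.num_decoder_layers)

# The flattened checkpoint only contains decoder entries for layers 0..5.
old = {
    f"decoder/layers_{i}/pre_self_attention_layer_norm/scale": 0.0
    for i in range(num_decoder_layers)
}

# Before this fix the decoder loop reused the encoder's layer count, so the
# lookup below raises KeyError as soon as i reaches num_decoder_layers.
for i in range(num_layers):
    _ = old[f"decoder/layers_{i}/pre_self_attention_layer_norm/scale"]  # KeyError at i == 6
```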

JiangZhongqing 2023-11-16 04:00:22 +09:00 committed by GitHub
parent 2e72bbab2c
commit b71c38a094
1 changed file with 8 additions and 3 deletions


@@ -69,7 +69,7 @@ def t5x_layer_norm_lookup(params, i, prefix, layer_name):
     return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
 
 
-def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, is_encoder_only: bool):
+def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, num_decoder_layers: int, is_encoder_only: bool):
     """Converts the parameters from T5X-Flax to Transformers-PyTorch."""
     old = traverse_util.flatten_dict(variables["target"])
     old = {"/".join(k): v for k, v in old.items()}
@@ -112,7 +112,7 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, is_encoder_only: bool):
 
     if not is_encoder_only:
         # Decoder.
-        for i in range(num_layers):
+        for i in range(num_decoder_layers):
             # Block i, layer 0 (Self Attention).
             layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
             k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
@@ -177,7 +177,12 @@ def make_state_dict(converted_params, is_encoder_only: bool):
 def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only):
     """Replaces the params in model witht the T5X converted params."""
     variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
-    converted = convert_t5x_to_pytorch(variables, num_layers=config.num_layers, is_encoder_only=is_encoder_only)
+    converted = convert_t5x_to_pytorch(
+        variables,
+        num_layers=config.num_layers,
+        num_decoder_layers=config.num_decoder_layers,
+        is_encoder_only=is_encoder_only,
+    )
     state_dict = make_state_dict(converted, is_encoder_only)
     model.load_state_dict(state_dict, strict=True)
 
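
With the change applied, a checkpoint whose config declares different encoder and decoder depths converts end to end. A minimal usage sketch follows; the paths and layer counts are placeholders, `load_t5x_weights_in_t5` is the function defined in this script (so this would run from within the module), and the surrounding setup mirrors the typical flow of the script's CLI entry point.

```python
from transformers import T5Config, T5ForConditionalGeneration

# Placeholder paths: point these at a real T5X checkpoint and its matching config,
# e.g. one with num_layers=12 and num_decoder_layers=6.
config = T5Config.from_json_file("path/to/config.json")
model = T5ForConditionalGeneration(config)

# Convert the T5X params and load them into the PyTorch model, then export.
load_t5x_weights_in_t5(model, config, "path/to/t5x_checkpoint", is_encoder_only=False)
model.save_pretrained("path/to/pytorch_dump")
```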