Fix bug for T5x to PyTorch convert script with varying encoder and decoder layers (#27448)
* Fix bug in handling varying encoder and decoder layers This commit resolves an issue where the script failed to convert T5x models to PyTorch models when the number of decoder layers differed from the number of encoder layers. I've addressed this issue by passing an additional 'num_decoder_layers' parameter to the relevant function. * Fix bug in handling varying encoder and decoder layers
This commit is contained in:
parent
2e72bbab2c
commit
b71c38a094
|
@ -69,7 +69,7 @@ def t5x_layer_norm_lookup(params, i, prefix, layer_name):
|
|||
return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
|
||||
|
||||
|
||||
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, is_encoder_only: bool):
|
||||
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, num_decoder_layers: int, is_encoder_only: bool):
|
||||
"""Converts the parameters from T5X-Flax to Transformers-PyTorch."""
|
||||
old = traverse_util.flatten_dict(variables["target"])
|
||||
old = {"/".join(k): v for k, v in old.items()}
|
||||
|
@ -112,7 +112,7 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, is_encoder_only:
|
|||
|
||||
if not is_encoder_only:
|
||||
# Decoder.
|
||||
for i in range(num_layers):
|
||||
for i in range(num_decoder_layers):
|
||||
# Block i, layer 0 (Self Attention).
|
||||
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
|
||||
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
|
||||
|
@ -177,7 +177,12 @@ def make_state_dict(converted_params, is_encoder_only: bool):
|
|||
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only):
|
||||
"""Replaces the params in model witht the T5X converted params."""
|
||||
variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
|
||||
converted = convert_t5x_to_pytorch(variables, num_layers=config.num_layers, is_encoder_only=is_encoder_only)
|
||||
converted = convert_t5x_to_pytorch(
|
||||
variables,
|
||||
num_layers=config.num_layers,
|
||||
num_decoder_layers=config.num_decoder_layers,
|
||||
is_encoder_only=is_encoder_only,
|
||||
)
|
||||
state_dict = make_state_dict(converted, is_encoder_only)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue