Add support for custom checkpoints in MusicGen (#30011)

* feat: support custom checkpoint

* update: revert changes and add TODO

* update: docs and exception handling

* fix: ah, extra space
This commit is contained in:
Jacky Lee 2024-05-14 23:30:33 -07:00 committed by GitHub
parent 1360801a69
commit 99e16120ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 7 additions and 7 deletions

View File

@ -88,24 +88,24 @@ def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict,
def decoder_config_from_checkpoint(checkpoint: str) -> MusicgenDecoderConfig:
if checkpoint == "small" or checkpoint == "facebook/musicgen-stereo-small":
if checkpoint.endswith("small"):
# default config values
hidden_size = 1024
num_hidden_layers = 24
num_attention_heads = 16
elif checkpoint == "medium" or checkpoint == "facebook/musicgen-stereo-medium":
elif checkpoint.endswith("medium"):
hidden_size = 1536
num_hidden_layers = 48
num_attention_heads = 24
elif checkpoint == "large" or checkpoint == "facebook/musicgen-stereo-large":
elif checkpoint.endswith("large"):
hidden_size = 2048
num_hidden_layers = 48
num_attention_heads = 32
else:
raise ValueError(
"Checkpoint should be one of `['small', 'medium', 'large']` for the mono checkpoints, "
"or `['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
f"for the stereo checkpoints, got {checkpoint}."
"`['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
f"for the stereo checkpoints, or a custom checkpoint with the checkpoint size as a suffix, got {checkpoint}."
)
if "stereo" in checkpoint:
@ -208,9 +208,9 @@ if __name__ == "__main__":
default="small",
type=str,
help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: "
"`['small', 'medium', 'large']` for the mono checkpoints, or "
"`['small', 'medium', 'large']` for the mono checkpoints, "
"`['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
"for the stereo checkpoints.",
"for the stereo checkpoints, or a custom checkpoint with the checkpoint size as a suffix.",
)
parser.add_argument(
"--pytorch_dump_folder",