Add support for custom checkpoints in MusicGen (#30011)
* feat: support custom checkpoint * update: revert changes and add TODO * update: docs and exception handling * fix: ah, extra space
This commit is contained in:
parent
1360801a69
commit
99e16120ab
|
@ -88,24 +88,24 @@ def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict,
|
|||
|
||||
|
||||
def decoder_config_from_checkpoint(checkpoint: str) -> MusicgenDecoderConfig:
|
||||
if checkpoint == "small" or checkpoint == "facebook/musicgen-stereo-small":
|
||||
if checkpoint.endswith("small"):
|
||||
# default config values
|
||||
hidden_size = 1024
|
||||
num_hidden_layers = 24
|
||||
num_attention_heads = 16
|
||||
elif checkpoint == "medium" or checkpoint == "facebook/musicgen-stereo-medium":
|
||||
elif checkpoint.endswith("medium"):
|
||||
hidden_size = 1536
|
||||
num_hidden_layers = 48
|
||||
num_attention_heads = 24
|
||||
elif checkpoint == "large" or checkpoint == "facebook/musicgen-stereo-large":
|
||||
elif checkpoint.endswith("large"):
|
||||
hidden_size = 2048
|
||||
num_hidden_layers = 48
|
||||
num_attention_heads = 32
|
||||
else:
|
||||
raise ValueError(
|
||||
"Checkpoint should be one of `['small', 'medium', 'large']` for the mono checkpoints, "
|
||||
"or `['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
|
||||
f"for the stereo checkpoints, got {checkpoint}."
|
||||
"`['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
|
||||
f"for the stereo checkpoints, or a custom checkpoint with the checkpoint size as a suffix, got {checkpoint}."
|
||||
)
|
||||
|
||||
if "stereo" in checkpoint:
|
||||
|
@ -208,9 +208,9 @@ if __name__ == "__main__":
|
|||
default="small",
|
||||
type=str,
|
||||
help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: "
|
||||
"`['small', 'medium', 'large']` for the mono checkpoints, or "
|
||||
"`['small', 'medium', 'large']` for the mono checkpoints, "
|
||||
"`['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
|
||||
"for the stereo checkpoints.",
|
||||
"for the stereo checkpoints, or a custom checkpoint with the checkpoint size as a suffix.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder",
|
||||
|
|
Loading…
Reference in New Issue