🚨 Add training compatibility for Musicgen-like models (#29802)

* first modeling code

* make repository

* still WIP

* update model

* add tests

* add latest change

* clean docstrings and copied from

* update docstrings md and readme

* correct chroma function

* correct copied from and remove unrelated test

* add doc to toctree

* correct imports

* add convert script to notdoctested

* Add suggestion from Sanchit

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* correct get_unconditional_inputs docstrings

* modify README according to Sanchit's feedback

* add chroma to audio utils

* clean librosa and torchaudio hard dependencies

* fix FE

* refactor audio decoder -> audio encoder for consistency with previous musicgen

* refactor conditional -> encoder

* modify sampling rate logic

* modify license at the beginning

* refactor all_self_attns->all_attentions

* remove ignore copy from causallm generate

* add copied from for from_sub_models

* fix make copies

* add warning if audio is truncated

* add copied from where relevant

* remove artefact

* fix convert script

* fix torchaudio and FE

* modify chroma method according to feedback -> better naming

* refactor input_values->input_features

* refactor input_values->input_features and fix import fe

* add input_features to docstrings

* correct inputs_embeds logic

* remove dtype conversion

* refactor _prepare_conditional_hidden_states_kwargs_for_generation -> _prepare_encoder_hidden_states_kwargs_for_generation

* change warning for chroma length

* Update src/transformers/models/musicgen_melody/convert_musicgen_melody_transformers.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* change way to save wav, using soundfile

* correct docs and change to soundfile

* fix import

* fix init proj layers

* add draft training

* fix cross entropy

* clean loss computation

* fix labels

* remove line breaks from md

* fix issue with docstrings

* add FE suggestions

* improve `is in` logic and remove useless imports

* remove custom from_pretrained

* simplify docstring code

* add suggestions for modeling tests

* make style

* update conversion script with sanity check

* remove encoder attention mask from conditional generation

* replace musicgen melody checkpoints with official org

* rename ylacombe->facebook in checkpoints

* fix copies

* remove unnecessary warning

* add shape in code docstrings

* add files to slow doc tests

* fix md bug and add md to not_tested

* make fix-copies

* fix hidden states test and batching

* update training code

* add training tests for melody

* add training for o.g. musicgen

* fix copied from

* remove final todos

* make style

* fix style

* add suggestions from review

* add ref to the original loss computation code

* rename method + fix labels in tests

* make style

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Yoach Lacombe 2024-04-25 12:51:19 +02:00 committed by GitHub
parent ce5ae5a434
commit 90cb55bf77
5 changed files with 333 additions and 52 deletions


@@ -161,6 +161,8 @@ MODEL_MAPPING_NAMES = OrderedDict(
("mpt", "MptModel"),
("mra", "MraModel"),
("mt5", "MT5Model"),
("musicgen", "MusicgenModel"),
("musicgen_melody", "MusicgenMelodyModel"),
("mvp", "MvpModel"),
("nat", "NatModel"),
("nezha", "NezhaModel"),


@@ -104,16 +104,17 @@ class MusicgenUnconditionalInput(ModelOutput):
guidance_scale: float = None
# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
# transpose to get (bsz, num_codebooks, seq_len)
input_ids = input_ids.transpose(1, 2)
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
if decoder_start_token_id is None:
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
shifted_input_ids[:, 0] = decoder_start_token_id
shifted_input_ids[..., 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
@@ -909,6 +910,10 @@ MUSICGEN_INPUTS_DOCSTRING = r"""
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
of `inputs_embeds`.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
@@ -1340,15 +1345,18 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
Returns:
Returns:
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (labels is not None) and (input_ids is None and inputs_embeds is None):
input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id)
outputs = self.model(
input_ids,
attention_mask=attention_mask,
@@ -1370,7 +1378,25 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
loss = None
if labels is not None:
raise NotImplementedError("Training is not implemented for Musicgen.")
# since encoder hidden states have been concatenated to the decoder hidden states,
# we take the last timesteps, which correspond to the labels
logits = lm_logits[:, :, -labels.shape[1] :]
loss_fct = CrossEntropyLoss()
loss = torch.zeros([], device=self.device)
# -100 labels are ignored
labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
# per codebook cross-entropy
# ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243
for codebook in range(self.config.num_codebooks):
codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
codebook_labels = labels[..., codebook].contiguous().view(-1)
loss += loss_fct(codebook_logits, codebook_labels)
loss = loss / self.config.num_codebooks
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
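The training path above computes one cross-entropy term per codebook and averages them, following the AudioCraft reference linked in the comment. A self-contained sketch on random tensors (shapes and the `loss` accumulator mirror the diff; all names and sizes are illustrative):

```python
import torch
from torch.nn import CrossEntropyLoss

bsz, num_codebooks, seq_len, vocab_size = 2, 4, 8, 99
lm_logits = torch.randn(bsz, num_codebooks, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (bsz, seq_len, num_codebooks))

# keep only the last `seq_len` timesteps of the logits, matching the labels
logits = lm_logits[:, :, -labels.shape[1] :]

loss_fct = CrossEntropyLoss()  # ignores -100 targets by default
loss = torch.zeros([])
for codebook in range(num_codebooks):
    codebook_logits = logits[:, codebook].reshape(-1, vocab_size)  # (bsz * seq_len, vocab_size)
    codebook_labels = labels[..., codebook].reshape(-1)            # (bsz * seq_len,)
    loss += loss_fct(codebook_logits, codebook_labels)
loss = loss / num_codebooks  # average over codebooks, as in the AudioCraft reference
```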
@@ -2235,7 +2261,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id
)
elif decoder_input_ids is None and decoder_inputs_embeds is None:
@@ -2270,23 +2296,15 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
use_cache=use_cache,
past_key_values=past_key_values,
return_dict=return_dict,
labels=labels,
**kwargs_decoder,
)
loss = None
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
if loss is not None:
return (loss,) + decoder_outputs + encoder_outputs
else:
return decoder_outputs + encoder_outputs
return decoder_outputs + encoder_outputs
return Seq2SeqLMOutput(
loss=loss,
loss=decoder_outputs.loss,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
@@ -2524,7 +2542,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
return model_kwargs
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id)
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
@@ -2533,6 +2551,22 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
" model.decoder.resize_token_embeddings(...))"
)
def freeze_audio_encoder(self):
"""
Freeze the audio encoder weights.
"""
for param in self.audio_encoder.parameters():
param.requires_grad = False
self.audio_encoder._requires_grad = False
def freeze_text_encoder(self):
"""
Freeze the text encoder weights.
"""
for param in self.text_encoder.parameters():
param.requires_grad = False
self.text_encoder._requires_grad = False
def _maybe_initialize_input_ids_for_generation(
self,
inputs: Optional[torch.Tensor] = None,
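Together with the new `labels` plumbing, the freeze helpers added above enable decoder-only fine-tuning of the composite model. A minimal sketch of one training step, assuming the `facebook/musicgen-small` checkpoint and random illustrative labels (real labels would be EnCodec audio codes of shape `(batch_size, sequence_length, num_codebooks)` for the target audio):

```python
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
model.freeze_text_encoder()
model.freeze_audio_encoder()  # unused in the forward pass, but must be frozen to train
model.train()

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
inputs = processor(text=["80s pop track with bassy drums"], padding=True, return_tensors="pt")

# illustrative random labels; in practice, encode the target audio with the audio encoder
labels = torch.randint(
    0, model.config.decoder.vocab_size, (1, 16, model.config.decoder.num_codebooks)
)

optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-5)
loss = model(**inputs, labels=labels).loss  # labels are shifted right inside the model
loss.backward()
optimizer.step()
```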


@@ -116,16 +116,18 @@ class MusicgenMelodyOutputWithPast(ModelOutput):
encoder_hidden_states: Optional[torch.FloatTensor] = None
# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
# Copied from transformers.models.musicgen.modeling_musicgen.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
# transpose to get (bsz, num_codebooks, seq_len)
input_ids = input_ids.transpose(1, 2)
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
if decoder_start_token_id is None:
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
shifted_input_ids[:, 0] = decoder_start_token_id
shifted_input_ids[..., 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
@@ -864,7 +866,7 @@ MUSICGEN_MELODY_INPUTS_DOCSTRING = r"""
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
of `inputs_embeds`.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
@@ -1269,7 +1271,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
labels: Optional[torch.LongTensor] = None,
) -> Union[Tuple, MusicgenMelodyOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
@@ -1278,6 +1280,9 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (labels is not None) and (input_ids is None and inputs_embeds is None):
input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id)
outputs = self.model(
input_ids,
attention_mask=attention_mask,
@@ -1298,7 +1303,25 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
loss = None
if labels is not None:
raise NotImplementedError("Training is not implemented for MusicgenMelody.")
# since encoder hidden states have been concatenated to the decoder hidden states,
# we take the last timesteps, which correspond to the labels
logits = lm_logits[:, :, -labels.shape[1] :]
loss_fct = CrossEntropyLoss()
loss = torch.zeros([], device=self.device)
# -100 labels are ignored
labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
# per codebook cross-entropy
# ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243
for codebook in range(self.config.num_codebooks):
codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
codebook_labels = labels[..., codebook].contiguous().view(-1)
loss += loss_fct(codebook_logits, codebook_labels)
loss = loss / self.config.num_codebooks
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
@@ -2156,7 +2179,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id
)
# Decode
@@ -2170,23 +2193,15 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
use_cache=use_cache,
past_key_values=past_key_values,
return_dict=return_dict,
labels=labels,
**kwargs_decoder,
)
loss = None
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
if loss is not None:
return (loss,) + decoder_outputs + (encoder_hidden_states,)
else:
return decoder_outputs + (encoder_hidden_states,)
return decoder_outputs + (encoder_hidden_states,)
return MusicgenMelodyOutputWithPast(
loss=loss,
loss=decoder_outputs.loss,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
hidden_states=decoder_outputs.hidden_states,
@@ -2397,7 +2412,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
return model_kwargs
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id)
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
@@ -2428,6 +2443,22 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
break
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
def freeze_audio_encoder(self):
"""
Freeze the audio encoder weights.
"""
for param in self.audio_encoder.parameters():
param.requires_grad = False
self.audio_encoder._requires_grad = False
def freeze_text_encoder(self):
"""
Freeze the text encoder weights.
"""
for param in self.text_encoder.parameters():
param.requires_grad = False
self.text_encoder._requires_grad = False
@torch.no_grad()
def generate(
self,


@@ -110,8 +110,7 @@ class MusicgenDecoderTester:
parent,
batch_size=4, # need batch_size != num_hidden_layers
seq_length=7,
is_training=False,
use_labels=False,
is_training=True,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
@@ -129,7 +128,6 @@ class MusicgenDecoderTester:
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
@@ -149,7 +147,9 @@
config = self.get_config()
inputs_dict = prepare_musicgen_decoder_inputs_dict(
config, input_ids, encoder_hidden_states=encoder_hidden_states
config,
input_ids,
encoder_hidden_states=encoder_hidden_states,
)
return config, inputs_dict
@@ -190,6 +190,45 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
def test_config(self):
self.config_tester.run_common_tests()
# special case for labels
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
dtype=torch.long,
device=torch_device,
)
return inputs_dict
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
if not self.model_tester.is_training:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.use_cache = False
config.return_dict = True
model = MusicgenForCausalLM(config)
model.to(torch_device)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
model.train()
# Contrary to the original method, we don't unfreeze frozen parameters:
# sinusoidal position embeddings have frozen weights that should stay frozen.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs = self._prepare_for_class(inputs_dict, MusicgenForCausalLM, return_labels=True)
loss = model(**inputs).loss
loss.backward()
optimizer.step()
for k, v in model.named_parameters():
if v.requires_grad:
self.assertTrue(v.grad is not None, f"{k} in {MusicgenForCausalLM.__name__} has no gradient!")
# override since we have to compute the input embeddings over codebooks
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -897,6 +936,7 @@ def prepare_musicgen_inputs_dict(
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
labels=None,
):
if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.reshape(
@@ -923,6 +963,7 @@
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"labels": labels,
}
@@ -932,8 +973,7 @@ class MusicgenTester:
parent,
batch_size=4, # need batch_size != num_hidden_layers
seq_length=7,
is_training=False,
use_labels=False,
is_training=True,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
@@ -953,7 +993,6 @@
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
@@ -1027,6 +1066,47 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def setUp(self):
self.model_tester = MusicgenTester(self)
# special case for labels
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
dtype=torch.long,
device=torch_device,
)
return inputs_dict
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
if not self.model_tester.is_training:
return
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.use_cache = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
model.train()
# The audio encoder weights are not used during the forward pass (only during generation),
# so we need to freeze them to be able to train.
model.freeze_audio_encoder()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
optimizer.step()
for k, v in model.named_parameters():
if v.requires_grad:
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids):
text_encoder_config = config.text_encoder
decoder_config = config.decoder
@@ -1518,6 +1598,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
self.assertNotIn(config.pad_token_id, output_generate)
@unittest.skip("MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model")
def test_save_load_fast_init_from_base(self):
pass
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -2151,6 +2235,27 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
self.assertTrue(torch.allclose(res_eager, res_sdpa))
def test_requires_grad_with_frozen_encoders(self):
config = self.model_tester.get_config()
for model_class in self.all_model_classes:
model = model_class(config)
model.freeze_audio_encoder()
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
self.assertFalse(all(audio_encoder_grads))
self.assertTrue(all(text_encoder_grads))
model = model_class(config)
model.freeze_text_encoder()
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
self.assertTrue(all(audio_encoder_grads))
self.assertFalse(all(text_encoder_grads))
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
"""Produces a series of 'bip bip' sounds at a given frequency."""


@@ -109,8 +109,7 @@ class MusicgenMelodyDecoderTester:
parent,
batch_size=3, # need batch_size != num_hidden_layers because of #29297
seq_length=7,
is_training=False,
use_labels=False,
is_training=True,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
@@ -129,7 +128,6 @@ class MusicgenMelodyDecoderTester:
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
@@ -151,7 +149,9 @@
config = self.get_config()
inputs_dict = prepare_musicgen_melody_decoder_inputs_dict(
config, input_ids, encoder_hidden_states=encoder_hidden_states
config,
input_ids,
encoder_hidden_states=encoder_hidden_states,
)
return config, inputs_dict
@@ -191,6 +191,47 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
def test_config(self):
self.config_tester.run_common_tests()
# special case for labels
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest._prepare_for_class
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
dtype=torch.long,
device=torch_device,
)
return inputs_dict
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.check_training_gradient_checkpointing with Musicgen->MusicgenMelody
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
if not self.model_tester.is_training:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.use_cache = False
config.return_dict = True
model = MusicgenMelodyForCausalLM(config)
model.to(torch_device)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
model.train()
# Contrary to the original method, we don't unfreeze frozen parameters:
# sinusoidal position embeddings have frozen weights that should stay frozen.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs = self._prepare_for_class(inputs_dict, MusicgenMelodyForCausalLM, return_labels=True)
loss = model(**inputs).loss
loss.backward()
optimizer.step()
for k, v in model.named_parameters():
if v.requires_grad:
self.assertTrue(v.grad is not None, f"{k} in {MusicgenMelodyForCausalLM.__name__} has no gradient!")
# override since we have to compute the input embeddings over codebooks
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -896,6 +937,7 @@ def prepare_musicgen_melody_inputs_dict(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
labels=None,
):
if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.reshape(
@@ -917,6 +959,7 @@
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"labels": labels,
}
@@ -926,8 +969,7 @@ class MusicgenMelodyTester:
parent,
batch_size=3, # need batch_size != num_hidden_layers because of #29297
seq_length=7,
is_training=False,
use_labels=False,
is_training=True,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
@@ -949,7 +991,6 @@
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
@@ -1029,6 +1070,47 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def setUp(self):
self.model_tester = MusicgenMelodyTester(self)
# special case for labels
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
dtype=torch.long,
device=torch_device,
)
return inputs_dict
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
if not self.model_tester.is_training:
return
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.use_cache = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
model.train()
# The audio encoder weights are not used during the forward pass (only during generation),
# so we need to freeze them to be able to train.
model.freeze_audio_encoder()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
optimizer.step()
for k, v in model.named_parameters():
if v.requires_grad:
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
# Ignore copy
def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids):
decoder_config = config.decoder
@@ -1500,6 +1582,12 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
self.assertNotIn(config.pad_token_id, output_generate)
@unittest.skip(
"MusicgenMelodyModel is actually not the base of MusicgenMelodyForCausalLM as the latter is a composit model"
)
def test_save_load_fast_init_from_base(self):
pass
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -2133,6 +2221,27 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
self.assertTrue(torch.allclose(res_eager, res_sdpa))
def test_requires_grad_with_frozen_encoders(self):
config = self.model_tester.get_config()
for model_class in self.all_model_classes:
model = model_class(config)
model.freeze_audio_encoder()
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
self.assertFalse(all(audio_encoder_grads))
self.assertTrue(all(text_encoder_grads))
model = model_class(config)
model.freeze_text_encoder()
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
self.assertTrue(all(audio_encoder_grads))
self.assertFalse(all(text_encoder_grads))
# Copied from tests.models.musicgen.test_modeling_musicgen.get_bip_bip
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):