🚨 Add training compatibility for Musicgen-like models (#29802)
* first modeling code * make repository * still WIP * update model * add tests * add latest change * clean docstrings and copied from * update docstrings md and readme * correct chroma function * correct copied from and remove unreleated test * add doc to toctree * correct imports * add convert script to notdoctested * Add suggestion from Sanchit Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * correct get_uncoditional_inputs docstrings * modify README according to SANCHIT feedback * add chroma to audio utils * clean librosa and torchaudio hard dependencies * fix FE * refactor audio decoder -> audio encoder for consistency with previous musicgen * refactor conditional -> encoder * modify sampling rate logics * modify license at the beginning * refactor all_self_attns->all_attentions * remove ignore copy from causallm generate * add copied from for from_sub_models * fix make copies * add warning if audio is truncated * add copied from where relevant * remove artefact * fix convert script * fix torchaudio and FE * modify chroma method according to feedback-> better naming * refactor input_values->input_features * refactor input_values->input_features and fix import fe * add input_features to docstrigs * correct inputs_embeds logics * remove dtype conversion * refactor _prepare_conditional_hidden_states_kwargs_for_generation ->_prepare_encoder_hidden_states_kwargs_for_generation * change warning for chroma length * Update src/transformers/models/musicgen_melody/convert_musicgen_melody_transformers.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * change way to save wav, using soundfile * correct docs and change to soundfile * fix import * fix init proj layers * add draft training * fix cross entropy * clean loss computation * fix labels * remove line breaks from md * fix issue with docstrings * add FE suggestions * improve is in logics and remove useless imports * remove custom from_pretrained * simplify docstring code * add suggestions for modeling tests * make style * update converting script with sanity check * remove encoder attention mask from conditional generation * replace musicgen melody checkpoints with official orga * rename ylacombe->facebook in checkpoints * fix copies * remove unecessary warning * add shape in code docstrings * add files to slow doc tests * fix md bug and add md to not_tested * make fix-copies * fix hidden states test and batching * update training code * add training tests for melody * add training for o.g musicgen * fix copied from * remove final todos * make style * fix style * add suggestions from review * add ref to the original loss computation code * rename method + fix labels in tests * make style --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
parent
ce5ae5a434
commit
90cb55bf77
|
@ -161,6 +161,8 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||
("mpt", "MptModel"),
|
||||
("mra", "MraModel"),
|
||||
("mt5", "MT5Model"),
|
||||
("musicgen", "MusicgenModel"),
|
||||
("musicgen_melody", "MusicgenMelodyModel"),
|
||||
("mvp", "MvpModel"),
|
||||
("nat", "NatModel"),
|
||||
("nezha", "NezhaModel"),
|
||||
|
|
|
@ -104,16 +104,17 @@ class MusicgenUnconditionalInput(ModelOutput):
|
|||
guidance_scale: float = None
|
||||
|
||||
|
||||
# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
|
||||
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||
"""
|
||||
Shift input ids one token to the right.
|
||||
"""
|
||||
# transpose to get (bsz, num_codebooks, seq_len)
|
||||
input_ids = input_ids.transpose(1, 2)
|
||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
|
||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||
if decoder_start_token_id is None:
|
||||
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
|
||||
shifted_input_ids[:, 0] = decoder_start_token_id
|
||||
shifted_input_ids[..., 0] = decoder_start_token_id
|
||||
|
||||
if pad_token_id is None:
|
||||
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
|
||||
|
@ -909,6 +910,10 @@ MUSICGEN_INPUTS_DOCSTRING = r"""
|
|||
|
||||
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
|
||||
of `inputs_embeds`.
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
|
@ -1340,15 +1345,18 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
|||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
Returns:
|
||||
Returns:
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (labels is not None) and (input_ids is None and inputs_embeds is None):
|
||||
input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id)
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
|
@ -1370,7 +1378,25 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
|||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
raise NotImplementedError("Training is not implemented for Musicgen.")
|
||||
# since encoder hidden states have been concatenated to the decoder hidden states,
|
||||
# we take the last timestamps corresponding to labels
|
||||
logits = lm_logits[:, :, -labels.shape[1] :]
|
||||
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = torch.zeros([], device=self.device)
|
||||
|
||||
# per codebook cross-entropy
|
||||
# -100 labels are ignored
|
||||
labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
|
||||
|
||||
# per codebook cross-entropy
|
||||
# ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243
|
||||
for codebook in range(self.config.num_codebooks):
|
||||
codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
|
||||
codebook_labels = labels[..., codebook].contiguous().view(-1)
|
||||
loss += loss_fct(codebook_logits, codebook_labels)
|
||||
|
||||
loss = loss / self.config.num_codebooks
|
||||
|
||||
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
|
||||
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
|
||||
|
@ -2235,7 +2261,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||
|
||||
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id
|
||||
)
|
||||
|
||||
elif decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
|
@ -2270,23 +2296,15 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||
use_cache=use_cache,
|
||||
past_key_values=past_key_values,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
**kwargs_decoder,
|
||||
)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
if loss is not None:
|
||||
return (loss,) + decoder_outputs + encoder_outputs
|
||||
else:
|
||||
return decoder_outputs + encoder_outputs
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
return Seq2SeqLMOutput(
|
||||
loss=loss,
|
||||
loss=decoder_outputs.loss,
|
||||
logits=decoder_outputs.logits,
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
|
@ -2524,7 +2542,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||
return model_kwargs
|
||||
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||
return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id)
|
||||
|
||||
def resize_token_embeddings(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
|
@ -2533,6 +2551,22 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||
" model.decoder.resize_token_embeddings(...))"
|
||||
)
|
||||
|
||||
def freeze_audio_encoder(self):
|
||||
"""
|
||||
Freeze the audio encoder weights.
|
||||
"""
|
||||
for param in self.audio_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
self.audio_encoder._requires_grad = False
|
||||
|
||||
def freeze_text_encoder(self):
|
||||
"""
|
||||
Freeze the text encoder weights.
|
||||
"""
|
||||
for param in self.text_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
self.text_encoder._requires_grad = False
|
||||
|
||||
def _maybe_initialize_input_ids_for_generation(
|
||||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
|
|
|
@ -116,16 +116,18 @@ class MusicgenMelodyOutputWithPast(ModelOutput):
|
|||
encoder_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
|
||||
# Copied from transformers.models.musicgen.modeling_musicgen.shift_tokens_right
|
||||
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||
"""
|
||||
Shift input ids one token to the right.
|
||||
"""
|
||||
# transpose to get (bsz, num_codebooks, seq_len)
|
||||
input_ids = input_ids.transpose(1, 2)
|
||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
|
||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||
if decoder_start_token_id is None:
|
||||
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
|
||||
shifted_input_ids[:, 0] = decoder_start_token_id
|
||||
shifted_input_ids[..., 0] = decoder_start_token_id
|
||||
|
||||
if pad_token_id is None:
|
||||
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
|
||||
|
@ -864,7 +866,7 @@ MUSICGEN_MELODY_INPUTS_DOCSTRING = r"""
|
|||
|
||||
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
|
||||
of `inputs_embeds`.
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
|
@ -1269,7 +1271,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
|
|||
labels: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, MusicgenMelodyOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
|
@ -1278,6 +1280,9 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
|
|||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (labels is not None) and (input_ids is None and inputs_embeds is None):
|
||||
input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id)
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
|
@ -1298,7 +1303,25 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
|
|||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
raise NotImplementedError("Training is not implemented for MusicgenMelody.")
|
||||
# since encoder hidden states have been concatenated to the decoder hidden states,
|
||||
# we take the last timestamps corresponding to labels
|
||||
logits = lm_logits[:, :, -labels.shape[1] :]
|
||||
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = torch.zeros([], device=self.device)
|
||||
|
||||
# per codebook cross-entropy
|
||||
# ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243
|
||||
# -100 labels are ignored
|
||||
labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
|
||||
|
||||
# per codebook cross-entropy
|
||||
for codebook in range(self.config.num_codebooks):
|
||||
codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
|
||||
codebook_labels = labels[..., codebook].contiguous().view(-1)
|
||||
loss += loss_fct(codebook_logits, codebook_labels)
|
||||
|
||||
loss = loss / self.config.num_codebooks
|
||||
|
||||
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
|
||||
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
|
||||
|
@ -2156,7 +2179,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
|||
|
||||
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id
|
||||
)
|
||||
|
||||
# Decode
|
||||
|
@ -2170,23 +2193,15 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
|||
use_cache=use_cache,
|
||||
past_key_values=past_key_values,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
**kwargs_decoder,
|
||||
)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
if loss is not None:
|
||||
return (loss,) + decoder_outputs + (encoder_hidden_states,)
|
||||
else:
|
||||
return decoder_outputs + (encoder_hidden_states,)
|
||||
return decoder_outputs + (encoder_hidden_states,)
|
||||
|
||||
return MusicgenMelodyOutputWithPast(
|
||||
loss=loss,
|
||||
loss=decoder_outputs.loss,
|
||||
logits=decoder_outputs.logits,
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
hidden_states=decoder_outputs.hidden_states,
|
||||
|
@ -2397,7 +2412,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
|||
return model_kwargs
|
||||
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||
return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id)
|
||||
|
||||
def resize_token_embeddings(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
|
@ -2428,6 +2443,22 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
|||
break
|
||||
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
|
||||
|
||||
def freeze_audio_encoder(self):
|
||||
"""
|
||||
Freeze the audio encoder weights.
|
||||
"""
|
||||
for param in self.audio_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
self.audio_encoder._requires_grad = False
|
||||
|
||||
def freeze_text_encoder(self):
|
||||
"""
|
||||
Freeze the text encoder weights.
|
||||
"""
|
||||
for param in self.text_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
self.text_encoder._requires_grad = False
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
|
|
|
@ -110,8 +110,7 @@ class MusicgenDecoderTester:
|
|||
parent,
|
||||
batch_size=4, # need batch_size != num_hidden_layers
|
||||
seq_length=7,
|
||||
is_training=False,
|
||||
use_labels=False,
|
||||
is_training=True,
|
||||
vocab_size=99,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
|
@ -129,7 +128,6 @@ class MusicgenDecoderTester:
|
|||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
|
@ -149,7 +147,9 @@ class MusicgenDecoderTester:
|
|||
|
||||
config = self.get_config()
|
||||
inputs_dict = prepare_musicgen_decoder_inputs_dict(
|
||||
config, input_ids, encoder_hidden_states=encoder_hidden_states
|
||||
config,
|
||||
input_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
return config, inputs_dict
|
||||
|
||||
|
@ -190,6 +190,45 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
# special case for labels
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
return inputs_dict
|
||||
|
||||
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.use_cache = False
|
||||
config.return_dict = True
|
||||
model = MusicgenForCausalLM(config)
|
||||
|
||||
model.to(torch_device)
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
|
||||
model.train()
|
||||
|
||||
# Contrarily to the initial method, we don't unfreeze freezed parameters.
|
||||
# Indeed, sinusoidal position embeddings have frozen weights that should stay frozen.
|
||||
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, MusicgenForCausalLM, return_labels=True)
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
for k, v in model.named_parameters():
|
||||
if v.requires_grad:
|
||||
self.assertTrue(v.grad is not None, f"{k} in {MusicgenForCausalLM.__name__} has no gradient!")
|
||||
|
||||
# override since we have to compute the input embeddings over codebooks
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
@ -897,6 +936,7 @@ def prepare_musicgen_inputs_dict(
|
|||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
labels=None,
|
||||
):
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.reshape(
|
||||
|
@ -923,6 +963,7 @@ def prepare_musicgen_inputs_dict(
|
|||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
|
@ -932,8 +973,7 @@ class MusicgenTester:
|
|||
parent,
|
||||
batch_size=4, # need batch_size != num_hidden_layers
|
||||
seq_length=7,
|
||||
is_training=False,
|
||||
use_labels=False,
|
||||
is_training=True,
|
||||
vocab_size=99,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
|
@ -953,7 +993,6 @@ class MusicgenTester:
|
|||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
|
@ -1027,6 +1066,47 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||
def setUp(self):
|
||||
self.model_tester = MusicgenTester(self)
|
||||
|
||||
# special case for labels
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
return inputs_dict
|
||||
|
||||
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.use_cache = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
|
||||
model.to(torch_device)
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
|
||||
model.train()
|
||||
|
||||
# The audio encoder weights are not used during the forward pass (only during the generate pass)
|
||||
# So we need to freeze it to be able to train.
|
||||
model.freeze_audio_encoder()
|
||||
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
for k, v in model.named_parameters():
|
||||
if v.requires_grad:
|
||||
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
|
||||
|
||||
def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids):
|
||||
text_encoder_config = config.text_encoder
|
||||
decoder_config = config.decoder
|
||||
|
@ -1518,6 +1598,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
@unittest.skip("MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
|
@ -2151,6 +2235,27 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||
|
||||
self.assertTrue(torch.allclose(res_eager, res_sdpa))
|
||||
|
||||
def test_requires_grad_with_frozen_encoders(self):
|
||||
config = self.model_tester.get_config()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.freeze_audio_encoder()
|
||||
|
||||
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
|
||||
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
|
||||
|
||||
self.assertFalse(all(audio_encoder_grads))
|
||||
self.assertTrue(all(text_encoder_grads))
|
||||
|
||||
model = model_class(config)
|
||||
model.freeze_text_encoder()
|
||||
|
||||
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
|
||||
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
|
||||
|
||||
self.assertTrue(all(audio_encoder_grads))
|
||||
self.assertFalse(all(text_encoder_grads))
|
||||
|
||||
|
||||
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
|
||||
"""Produces a series of 'bip bip' sounds at a given frequency."""
|
||||
|
|
|
@ -109,8 +109,7 @@ class MusicgenMelodyDecoderTester:
|
|||
parent,
|
||||
batch_size=3, # need batch_size != num_hidden_layers because of #29297
|
||||
seq_length=7,
|
||||
is_training=False,
|
||||
use_labels=False,
|
||||
is_training=True,
|
||||
vocab_size=99,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
|
@ -129,7 +128,6 @@ class MusicgenMelodyDecoderTester:
|
|||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
|
@ -151,7 +149,9 @@ class MusicgenMelodyDecoderTester:
|
|||
|
||||
config = self.get_config()
|
||||
inputs_dict = prepare_musicgen_melody_decoder_inputs_dict(
|
||||
config, input_ids, encoder_hidden_states=encoder_hidden_states
|
||||
config,
|
||||
input_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
return config, inputs_dict
|
||||
|
||||
|
@ -191,6 +191,47 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
|||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
# special case for labels
|
||||
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest._prepare_for_class
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
return inputs_dict
|
||||
|
||||
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.check_training_gradient_checkpointing with Musicgen->MusicgenMelody
|
||||
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.use_cache = False
|
||||
config.return_dict = True
|
||||
model = MusicgenMelodyForCausalLM(config)
|
||||
|
||||
model.to(torch_device)
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
|
||||
model.train()
|
||||
|
||||
# Contrarily to the initial method, we don't unfreeze freezed parameters.
|
||||
# Indeed, sinusoidal position embeddings have frozen weights that should stay frozen.
|
||||
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, MusicgenMelodyForCausalLM, return_labels=True)
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
for k, v in model.named_parameters():
|
||||
if v.requires_grad:
|
||||
self.assertTrue(v.grad is not None, f"{k} in {MusicgenMelodyForCausalLM.__name__} has no gradient!")
|
||||
|
||||
# override since we have to compute the input embeddings over codebooks
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
@ -896,6 +937,7 @@ def prepare_musicgen_melody_inputs_dict(
|
|||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
labels=None,
|
||||
):
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.reshape(
|
||||
|
@ -917,6 +959,7 @@ def prepare_musicgen_melody_inputs_dict(
|
|||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
|
@ -926,8 +969,7 @@ class MusicgenMelodyTester:
|
|||
parent,
|
||||
batch_size=3, # need batch_size != num_hidden_layers because of #29297
|
||||
seq_length=7,
|
||||
is_training=False,
|
||||
use_labels=False,
|
||||
is_training=True,
|
||||
vocab_size=99,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
|
@ -949,7 +991,6 @@ class MusicgenMelodyTester:
|
|||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
|
@ -1029,6 +1070,47 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||
def setUp(self):
|
||||
self.model_tester = MusicgenMelodyTester(self)
|
||||
|
||||
# special case for labels
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_codebooks),
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
return inputs_dict
|
||||
|
||||
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.use_cache = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
|
||||
model.to(torch_device)
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
|
||||
model.train()
|
||||
|
||||
# The audio encoder weights are not used during the forward pass (only during the generate pass)
|
||||
# So we need to freeze it to be able to train.
|
||||
model.freeze_audio_encoder()
|
||||
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
for k, v in model.named_parameters():
|
||||
if v.requires_grad:
|
||||
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
|
||||
|
||||
# Ignore copy
|
||||
def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids):
|
||||
decoder_config = config.decoder
|
||||
|
@ -1500,6 +1582,12 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
@unittest.skip(
|
||||
"MusicgenMelodyModel is actually not the base of MusicgenMelodyForCausalLM as the latter is a composit model"
|
||||
)
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
|
@ -2133,6 +2221,27 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||
|
||||
self.assertTrue(torch.allclose(res_eager, res_sdpa))
|
||||
|
||||
def test_requires_grad_with_frozen_encoders(self):
|
||||
config = self.model_tester.get_config()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.freeze_audio_encoder()
|
||||
|
||||
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
|
||||
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
|
||||
|
||||
self.assertFalse(all(audio_encoder_grads))
|
||||
self.assertTrue(all(text_encoder_grads))
|
||||
|
||||
model = model_class(config)
|
||||
model.freeze_text_encoder()
|
||||
|
||||
audio_encoder_grads = [param.requires_grad for param in model.audio_encoder.parameters()]
|
||||
text_encoder_grads = [param.requires_grad for param in model.text_encoder.parameters()]
|
||||
|
||||
self.assertTrue(all(audio_encoder_grads))
|
||||
self.assertFalse(all(text_encoder_grads))
|
||||
|
||||
|
||||
# Copied from tests.models.musicgen.test_modeling_musicgen.get_bip_bip
|
||||
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
|
||||
|
|
Loading…
Reference in New Issue