diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py
index 44c040ca6a..dddb08bbaf 100644
--- a/src/transformers/generation/stopping_criteria.py
+++ b/src/transformers/generation/stopping_criteria.py
@@ -481,6 +481,7 @@ class EosTokenCriteria(StoppingCriteria):

     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
+        self.eos_token_id = self.eos_token_id.to(input_ids.device)
         if input_ids.device.type == "mps":
             # https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
             is_done = (
@@ -492,7 +493,7 @@ class EosTokenCriteria(StoppingCriteria):
                 .squeeze()
             )
         else:
-            is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
+            is_done = torch.isin(input_ids[:, -1], self.eos_token_id)
         return is_done

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 1633e41021..fe3abed3af 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -16,6 +16,7 @@

 import copy
 import inspect
+import json
 import warnings
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -2471,6 +2472,13 @@ class GenerationMixin:
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
         ```"""
+        import datetime
+        from collections import OrderedDict
+        timing = OrderedDict()
+
+        torch.cuda.synchronize()
+        s_gen = datetime.datetime.now()
+        s = datetime.datetime.now()
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
@@ -2536,10 +2544,65 @@ class GenerationMixin:
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+        torch.cuda.synchronize()
+        t = datetime.datetime.now()
+        e = (t - s).total_seconds()
+        idx = -1
+        if idx not in timing:
+            timing[idx] = {"name": "before while loop", "timing": 0.0}
+        timing[idx]["timing"] += e
+
+        torch.cuda.synchronize()
+        s_while = datetime.datetime.now()
+        step = 0
+        while True:
+            step += 1
+            if step > 4095:
+                break
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
+            not_stop = self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device)
+
+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 0
+            if idx not in timing:
+                timing[idx] = {"name": "_has_unfinished_sequences", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
+            if not not_stop:
+                break
+
+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 0.1
+            if idx not in timing:
+                timing[idx] = {"name": "if not not_stop", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 1
+            if idx not in timing:
+                timing[idx] = {"name": "prepare_inputs_for_generation", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             # forward pass to get next token
             outputs = self(
                 **model_inputs,
@@ -2548,6 +2611,17 @@ class GenerationMixin:
                 output_hidden_states=output_hidden_states,
             )

+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 2
+            if idx not in timing:
+                timing[idx] = {"name": "model forward", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need

@@ -2556,6 +2630,17 @@ class GenerationMixin:
             # pre-process distribution
             next_tokens_scores = logits_processor(input_ids, next_token_logits)

+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 3
+            if idx not in timing:
+                timing[idx] = {"name": "logits_processor", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
                 if output_scores:
@@ -2576,6 +2661,17 @@ class GenerationMixin:
                         else (outputs.hidden_states,)
                     )

+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 4
+            if idx not in timing:
+                timing[idx] = {"name": "prepare outputs", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             # argmax
             next_tokens = torch.argmax(next_tokens_scores, dim=-1)

@@ -2589,15 +2685,70 @@ class GenerationMixin:
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
+
+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 5
+            if idx not in timing:
+                timing[idx] = {"name": "next_tokens", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs,
                 model_kwargs,
                 is_encoder_decoder=self.config.is_encoder_decoder,
             )

+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 6
+            if idx not in timing:
+                timing[idx] = {"name": "_update_model_kwargs_for_generation", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+
+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 7
+            if idx not in timing:
+                timing[idx] = {"name": "stopping_criteria", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             this_peer_finished = unfinished_sequences.max() == 0

+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 8
+            if idx not in timing:
+                timing[idx] = {"name": "final part in while", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+        e1 = (t - s_gen).total_seconds()
+        e2 = (t - s_while).total_seconds()
+        print(f"generation time: {e1}")
+        print(f"while time: {e2}")
+        print(json.dumps(timing, indent=4))
+
+        import transformers
+        o = transformers.models.gemma.modeling_gemma.timing
+        print(json.dumps(o, indent=4))
+
+        breakpoint()
+
         if streamer is not None:
             streamer.end()

diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index f221e74ddf..98ea73177c 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -15,6 +15,11 @@
 # limitations under the License.
 """ PyTorch Gemma model."""

+import json
+import datetime
+
+timing = {}
+
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
@@ -864,23 +869,81 @@ class GemmaModel(GemmaPreTrainedModel):
             )
             use_cache = False

+        torch.cuda.synchronize()
+        s = datetime.datetime.now()
+
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

+        torch.cuda.synchronize()
+        t = datetime.datetime.now()
+        e = (t - s).total_seconds()
+        idx = 1
+        if idx not in timing:
+            timing[idx] = {"name": "GemmaModel: embed_tokens", "timing": 0.0}
+        timing[idx]["timing"] += e
+
+        torch.cuda.synchronize()
+        s = datetime.datetime.now()
+
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)

+        torch.cuda.synchronize()
+        t = datetime.datetime.now()
+        e = (t - s).total_seconds()
+        idx = 2
+        if idx not in timing:
+            timing[idx] = {"name": "GemmaModel: past_seen_tokens", "timing": 0.0}
+        timing[idx]["timing"] += e
+
+        torch.cuda.synchronize()
+        s = datetime.datetime.now()
+
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )

+        torch.cuda.synchronize()
+        t = datetime.datetime.now()
+        e = (t - s).total_seconds()
+        idx = 3
+        if idx not in timing:
+            timing[idx] = {"name": "GemmaModel: cache_position", "timing": 0.0}
+        timing[idx]["timing"] += e
+
+        torch.cuda.synchronize()
+        s = datetime.datetime.now()
+
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

+        torch.cuda.synchronize()
+        t = datetime.datetime.now()
+        e = (t - s).total_seconds()
+        idx = 4
+        if idx not in timing:
+            timing[idx] = {"name": "GemmaModel: position_ids", "timing": 0.0}
+        timing[idx]["timing"] += e
+
+        torch.cuda.synchronize()
+        s = datetime.datetime.now()
+
         causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)

+        torch.cuda.synchronize()
+        t = datetime.datetime.now()
+        e = (t - s).total_seconds()
+        idx = 5
+        if idx not in timing:
+            timing[idx] = {"name": "GemmaModel: causal_mask", "timing": 0.0}
+        timing[idx]["timing"] += e
+
+        torch.cuda.synchronize()
+        s = datetime.datetime.now()
+
         # embed positions
         hidden_states = inputs_embeds

@@ -890,6 +953,17 @@ class GemmaModel(GemmaPreTrainedModel):
         normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
         hidden_states = hidden_states * normalizer

+        torch.cuda.synchronize()
+        t = datetime.datetime.now()
+        e = (t - s).total_seconds()
+        idx = 6
+        if idx not in timing:
+            timing[idx] = {"name": "GemmaModel: normalizer", "timing": 0.0}
+        timing[idx]["timing"] += e
+
+        torch.cuda.synchronize()
+        s = datetime.datetime.now()
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
@@ -929,8 +1003,30 @@ class GemmaModel(GemmaPreTrainedModel):
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

+        torch.cuda.synchronize()
+        t = datetime.datetime.now()
+        e = (t - s).total_seconds()
+        idx = 7
+        if idx not in timing:
+            timing[idx] = {"name": "GemmaModel: for decoder_layer in self.layers", "timing": 0.0}
+        timing[idx]["timing"] += e
+
+        torch.cuda.synchronize()
+        s = datetime.datetime.now()
+
         hidden_states = self.norm(hidden_states)

+        torch.cuda.synchronize()
+        t = datetime.datetime.now()
+        e = (t - s).total_seconds()
+        idx = 8
+        if idx not in timing:
+            timing[idx] = {"name": "GemmaModel: self.norm", "timing": 0.0}
+        timing[idx]["timing"] += e
+
+        torch.cuda.synchronize()
+        s = datetime.datetime.now()
+
         # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
@@ -942,6 +1038,18 @@ class GemmaModel(GemmaPreTrainedModel):
                 if isinstance(next_decoder_cache, DynamicCache)
                 else next_decoder_cache
             )
+
+        torch.cuda.synchronize()
+        t = datetime.datetime.now()
+        e = (t - s).total_seconds()
+        idx = 9
+        if idx not in timing:
+            timing[idx] = {"name": "GemmaModel: next_cache", "timing": 0.0}
+        timing[idx]["timing"] += e
+
+        torch.cuda.synchronize()
+        s = datetime.datetime.now()
+
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
@@ -1114,6 +1222,9 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        torch.cuda.synchronize()
+        s = datetime.datetime.now()
+
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
@@ -1128,6 +1239,14 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
             cache_position=cache_position,
         )

+        torch.cuda.synchronize()
+        t = datetime.datetime.now()
+        e = (t - s).total_seconds()
+        idx = 0
+        if idx not in timing:
+            timing[idx] = {"name": "GemmaForCausalLM: outputs = self.model()", "timing": 0.0}
+        timing[idx]["timing"] += e
+
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()
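
Note: the instrumentation in this patch repeats one pattern everywhere: call torch.cuda.synchronize() before reading the wall clock so the datetime delta covers all queued GPU work, then accumulate the elapsed seconds into a per-section dict keyed by index. The standalone sketch below only illustrates that pattern; the record helper and the "matmul"/"softmax" section names are mine, not part of the patch.

import datetime
import json
from collections import OrderedDict

import torch

timing = OrderedDict()


def record(idx, name, start):
    # Wait for queued GPU work, then add the elapsed wall-clock time to the named section.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    now = datetime.datetime.now()
    entry = timing.setdefault(idx, {"name": name, "timing": 0.0})
    entry["timing"] += (now - start).total_seconds()
    return now  # becomes the start time of the next section


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1024, 1024, device=device)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    s = datetime.datetime.now()
    for _ in range(10):
        y = x @ x
        s = record(0, "matmul", s)
        y.softmax(dim=-1)
        s = record(1, "softmax", s)

    print(json.dumps(timing, indent=4))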