Merge commit 'ae538a0b' into HEAD
# Conflicts:
#	src/transformers/models/gemma/modeling_gemma.py
commit 3f2d1b1a23
@@ -481,6 +481,7 @@ class EosTokenCriteria(StoppingCriteria):
    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
+       self.eos_token_id = self.eos_token_id.to(input_ids.device)
        if input_ids.device.type == "mps":
            # https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
            is_done = (
@@ -492,7 +493,7 @@ class EosTokenCriteria(StoppingCriteria):
                .squeeze()
            )
        else:
-           is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
+           is_done = torch.isin(input_ids[:, -1], self.eos_token_id)
        return is_done
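In the two hunks above, `self.eos_token_id` is moved to the input device once at the top of `__call__`, so the `torch.isin` branch no longer carries its own `.to(...)` call, and the `mps` branch keeps the workaround for the linked PyTorch issue. As a reference, here is a minimal standalone sketch (not part of the diff) of the same last-token EOS check:

import torch

def last_token_is_eos(input_ids: torch.LongTensor, eos_token_id: torch.Tensor) -> torch.BoolTensor:
    # Sketch only: move the EOS ids once so torch.isin compares tensors on the same device.
    eos_token_id = eos_token_id.to(input_ids.device)
    # True for every row whose most recent token is one of the EOS ids.
    return torch.isin(input_ids[:, -1], eos_token_id)

# Example: batch of 2 sequences with EOS ids {1, 2}.
ids = torch.tensor([[5, 6, 1], [5, 6, 7]])
print(last_token_is_eos(ids, torch.tensor([1, 2])))  # tensor([ True, False])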
@@ -16,6 +16,7 @@
import copy
import inspect
+import json
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -2471,6 +2472,13 @@ class GenerationMixin:
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
        ```"""
+       import datetime
+       from collections import OrderedDict
+       timing = OrderedDict()
+
+       torch.cuda.synchronize()
+       s_gen = datetime.datetime.now()
+       s = datetime.datetime.now()
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
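The instrumentation added below always follows the same pattern: synchronize CUDA, take a wall-clock timestamp, run a section, synchronize again, and accumulate the elapsed seconds under a named slot in `timing`. The diff writes this out inline at every section; purely as a sketch (not part of the commit, and requiring a CUDA device just like the inline code), the same idea as a reusable helper might look like this:

import datetime
from contextlib import contextmanager

import torch

timing = {}

@contextmanager
def timed(idx, name):
    # Sketch only: time one section on the wall clock, synchronizing CUDA on both
    # sides so queued GPU work is included, and accumulate seconds per named slot.
    torch.cuda.synchronize()
    start = datetime.datetime.now()
    try:
        yield
    finally:
        torch.cuda.synchronize()
        elapsed = (datetime.datetime.now() - start).total_seconds()
        timing.setdefault(idx, {"name": name, "timing": 0.0})
        timing[idx]["timing"] += elapsed

# Usage sketch (hypothetical, mirrors the inline blocks in the diff):
# with timed(1, "prepare_inputs_for_generation"):
#     model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)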
@@ -2536,10 +2544,65 @@ class GenerationMixin:
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

-       while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        torch.cuda.synchronize()
        t = datetime.datetime.now()
        e = (t - s).total_seconds()
        idx = -1
        if idx not in timing:
            timing[idx] = {"name": "before while loop", "timing": 0.0}
        timing[idx]["timing"] += e

        torch.cuda.synchronize()
        s_while = datetime.datetime.now()
        step = 0
        while True:
            step += 1
            if step > 4095:
                break

            torch.cuda.synchronize()
            s = datetime.datetime.now()

            not_stop = self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device)

            torch.cuda.synchronize()
            t = datetime.datetime.now()
            e = (t - s).total_seconds()
            idx = 0
            if idx not in timing:
                timing[idx] = {"name": "_has_unfinished_sequences", "timing": 0.0}
            timing[idx]["timing"] += e

            torch.cuda.synchronize()
            s = datetime.datetime.now()

            if not not_stop:
                break

            torch.cuda.synchronize()
            t = datetime.datetime.now()
            e = (t - s).total_seconds()
            idx = 0.1
            if idx not in timing:
                timing[idx] = {"name": "if not not_stop", "timing": 0.0}
            timing[idx]["timing"] += e

            torch.cuda.synchronize()
            s = datetime.datetime.now()
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            torch.cuda.synchronize()
            t = datetime.datetime.now()
            e = (t - s).total_seconds()
            idx = 1
            if idx not in timing:
                timing[idx] = {"name": "prepare_inputs_for_generation", "timing": 0.0}
            timing[idx]["timing"] += e

            torch.cuda.synchronize()
            s = datetime.datetime.now()

            # forward pass to get next token
            outputs = self(
                **model_inputs,
@@ -2548,6 +2611,17 @@ class GenerationMixin:
                output_hidden_states=output_hidden_states,
            )

            torch.cuda.synchronize()
            t = datetime.datetime.now()
            e = (t - s).total_seconds()
            idx = 2
            if idx not in timing:
                timing[idx] = {"name": "model forward", "timing": 0.0}
            timing[idx]["timing"] += e

            torch.cuda.synchronize()
            s = datetime.datetime.now()

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need
@@ -2556,6 +2630,17 @@ class GenerationMixin:
            # pre-process distribution
            next_tokens_scores = logits_processor(input_ids, next_token_logits)

            torch.cuda.synchronize()
            t = datetime.datetime.now()
            e = (t - s).total_seconds()
            idx = 3
            if idx not in timing:
                timing[idx] = {"name": "logits_processor", "timing": 0.0}
            timing[idx]["timing"] += e

            torch.cuda.synchronize()
            s = datetime.datetime.now()

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
@@ -2576,6 +2661,17 @@ class GenerationMixin:
                        else (outputs.hidden_states,)
                    )

            torch.cuda.synchronize()
            t = datetime.datetime.now()
            e = (t - s).total_seconds()
            idx = 4
            if idx not in timing:
                timing[idx] = {"name": "prepare outputs", "timing": 0.0}
            timing[idx]["timing"] += e

            torch.cuda.synchronize()
            s = datetime.datetime.now()

            # argmax
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
@@ -2589,15 +2685,70 @@ class GenerationMixin:
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())

            torch.cuda.synchronize()
            t = datetime.datetime.now()
            e = (t - s).total_seconds()
            idx = 5
            if idx not in timing:
                timing[idx] = {"name": "next_tokens", "timing": 0.0}
            timing[idx]["timing"] += e

            torch.cuda.synchronize()
            s = datetime.datetime.now()

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            torch.cuda.synchronize()
            t = datetime.datetime.now()
            e = (t - s).total_seconds()
            idx = 6
            if idx not in timing:
                timing[idx] = {"name": "_update_model_kwargs_for_generation", "timing": 0.0}
            timing[idx]["timing"] += e

            torch.cuda.synchronize()
            s = datetime.datetime.now()

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)

            torch.cuda.synchronize()
            t = datetime.datetime.now()
            e = (t - s).total_seconds()
            idx = 7
            if idx not in timing:
                timing[idx] = {"name": "stopping_criteria", "timing": 0.0}
            timing[idx]["timing"] += e

            torch.cuda.synchronize()
            s = datetime.datetime.now()

            this_peer_finished = unfinished_sequences.max() == 0

            torch.cuda.synchronize()
            t = datetime.datetime.now()
            e = (t - s).total_seconds()
            idx = 8
            if idx not in timing:
                timing[idx] = {"name": "final part in while", "timing": 0.0}
            timing[idx]["timing"] += e

        e1 = (t - s_gen).total_seconds()
        e2 = (t - s_while).total_seconds()
        print(f"generation time: {e1}")
        print(f"while time: {e2}")
        print(json.dumps(timing, indent=4))

        import transformers
        o = transformers.models.gemma.modeling_gemma.timing
        print(json.dumps(o, indent=4))

        breakpoint()

        if streamer is not None:
            streamer.end()
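Each section boundary above issues a `torch.cuda.synchronize()`, which blocks the host and has a cost of its own, so the printed per-section times include that overhead. An alternative (not used in this commit) is to record CUDA events around a section and synchronize only once before reading the results; a minimal sketch:

import torch

# Sketch only: time a section with CUDA events instead of synchronize + datetime.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
# ... section to measure, e.g. the forward pass ...
end.record()

torch.cuda.synchronize()  # one sync before reading the timings
elapsed_s = start.elapsed_time(end) / 1000.0  # elapsed_time() returns milliseconds
print(f"section took {elapsed_s:.4f}s")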
@@ -15,6 +15,11 @@
# limitations under the License.
""" PyTorch Gemma model."""

+import json
+import datetime
+
+timing = {}
+
import math
import warnings
from typing import List, Optional, Tuple, Union
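`timing` is a module-level dict that the instrumentation below mutates in place, which is why the generation loop can read the accumulated numbers back through the module object after a run. A short usage sketch, outside the diff:

import json
import transformers.models.gemma.modeling_gemma as modeling_gemma

# After a generation run, the per-section totals accumulated inside the Gemma
# forward passes are available here (the same dict object the model code updates).
print(json.dumps(modeling_gemma.timing, indent=4))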
@@ -864,23 +869,81 @@ class GemmaModel(GemmaPreTrainedModel):
            )
            use_cache = False

        torch.cuda.synchronize()
        s = datetime.datetime.now()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        torch.cuda.synchronize()
        t = datetime.datetime.now()
        e = (t - s).total_seconds()
        idx = 1
        if idx not in timing:
            timing[idx] = {"name": "GemmaModel: embed_tokens", "timing": 0.0}
        timing[idx]["timing"] += e

        torch.cuda.synchronize()
        s = datetime.datetime.now()

        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        torch.cuda.synchronize()
        t = datetime.datetime.now()
        e = (t - s).total_seconds()
        idx = 2
        if idx not in timing:
            timing[idx] = {"name": "GemmaModel: past_seen_tokens", "timing": 0.0}
        timing[idx]["timing"] += e

        torch.cuda.synchronize()
        s = datetime.datetime.now()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        torch.cuda.synchronize()
        t = datetime.datetime.now()
        e = (t - s).total_seconds()
        idx = 3
        if idx not in timing:
            timing[idx] = {"name": "GemmaModel: cache_position", "timing": 0.0}
        timing[idx]["timing"] += e

        torch.cuda.synchronize()
        s = datetime.datetime.now()

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        torch.cuda.synchronize()
        t = datetime.datetime.now()
        e = (t - s).total_seconds()
        idx = 4
        if idx not in timing:
            timing[idx] = {"name": "GemmaModel: position_ids", "timing": 0.0}
        timing[idx]["timing"] += e

        torch.cuda.synchronize()
        s = datetime.datetime.now()

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)

        torch.cuda.synchronize()
        t = datetime.datetime.now()
        e = (t - s).total_seconds()
        idx = 5
        if idx not in timing:
            timing[idx] = {"name": "GemmaModel: causal_mask", "timing": 0.0}
        timing[idx]["timing"] += e

        torch.cuda.synchronize()
        s = datetime.datetime.now()

        # embed positions
        hidden_states = inputs_embeds
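The values accumulated above are totals over every forward call, so a long generation is dominated by the decode steps. Not part of the diff, but a small post-processing sketch that turns the totals into per-call averages (assuming the number of forward calls is known):

def summarize(timing, num_calls):
    # Sketch only: print total seconds and an average per call in milliseconds.
    for idx, entry in sorted(timing.items()):
        avg_ms = 1000.0 * entry["timing"] / num_calls
        print(f'{idx}: {entry["name"]}: total {entry["timing"]:.4f}s, avg {avg_ms:.3f} ms/call')

# Usage sketch: summarize(modeling_gemma.timing, num_calls=number_of_decode_steps)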
@@ -890,6 +953,17 @@ class GemmaModel(GemmaPreTrainedModel):
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        torch.cuda.synchronize()
        t = datetime.datetime.now()
        e = (t - s).total_seconds()
        idx = 6
        if idx not in timing:
            timing[idx] = {"name": "GemmaModel: normalizer", "timing": 0.0}
        timing[idx]["timing"] += e

        torch.cuda.synchronize()
        s = datetime.datetime.now()

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
@@ -929,8 +1003,30 @@ class GemmaModel(GemmaPreTrainedModel):
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        torch.cuda.synchronize()
        t = datetime.datetime.now()
        e = (t - s).total_seconds()
        idx = 7
        if idx not in timing:
            timing[idx] = {"name": "GemmaModel: for decoder_layer in self.layers", "timing": 0.0}
        timing[idx]["timing"] += e

        torch.cuda.synchronize()
        s = datetime.datetime.now()

        hidden_states = self.norm(hidden_states)

        torch.cuda.synchronize()
        t = datetime.datetime.now()
        e = (t - s).total_seconds()
        idx = 8
        if idx not in timing:
            timing[idx] = {"name": "GemmaModel: self.norm", "timing": 0.0}
        timing[idx]["timing"] += e

        torch.cuda.synchronize()
        s = datetime.datetime.now()

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
@@ -942,6 +1038,18 @@ class GemmaModel(GemmaPreTrainedModel):
                if isinstance(next_decoder_cache, DynamicCache)
                else next_decoder_cache
            )

        torch.cuda.synchronize()
        t = datetime.datetime.now()
        e = (t - s).total_seconds()
        idx = 9
        if idx not in timing:
            timing[idx] = {"name": "GemmaModel: next_cache", "timing": 0.0}
        timing[idx]["timing"] += e

        torch.cuda.synchronize()
        s = datetime.datetime.now()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
@@ -1114,6 +1222,9 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        torch.cuda.synchronize()
        s = datetime.datetime.now()

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
@@ -1128,6 +1239,14 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
            cache_position=cache_position,
        )

        torch.cuda.synchronize()
        t = datetime.datetime.now()
        e = (t - s).total_seconds()
        idx = 0
        if idx not in timing:
            timing[idx] = {"name": "GemmaForCausalLM: outputs = self.model()", "timing": 0.0}
        timing[idx]["timing"] += e

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()
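In the Gemma-side dict, slot 0 wraps the whole `self.model(...)` call while slots 1 through 9 cover its internal sections, so the inner totals should roughly add up to the outer one (any gap is untimed glue code). A small check, outside the diff:

import transformers.models.gemma.modeling_gemma as modeling_gemma

# Sketch only: compare the summed GemmaModel sections against the
# GemmaForCausalLM "outputs = self.model()" total.
inner = sum(v["timing"] for k, v in modeling_gemma.timing.items() if k != 0)
outer = modeling_gemma.timing[0]["timing"]
print(f"inner sections: {inner:.4f}s, outer total: {outer:.4f}s")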