ydshieh 2024-05-25 09:52:50 +02:00
parent 0ba60a3df9
commit 8da2af59ca
2 changed files with 21 additions and 0 deletions

src/transformers/generation/utils.py

@@ -2722,6 +2722,11 @@ class GenerationMixin:
print(f"generation time: {e1}")
print(f"while time: {e2}")
print(json.dumps(timing, indent=4))
import transformers
o = transformers.models.gemma.modeling_gemma.timing
print(json.dumps(o, indent=4))
breakpoint()
if streamer is not None:

src/transformers/models/gemma/modeling_gemma.py

@@ -15,6 +15,11 @@
# limitations under the License.
""" PyTorch Gemma model."""
import json
import datetime
timing = {}
import math
import warnings
from typing import List, Optional, Tuple, Union
@@ -1119,6 +1124,9 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
torch.cuda.synchronize()
s = datetime.datetime.now()
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
@@ -1133,6 +1141,14 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
cache_position=cache_position,
)
torch.cuda.synchronize()
t = datetime.datetime.now()
e = (t - s).total_seconds()
idx = 0
if idx not in timing:
timing[idx] = {"name": "GemmaForCausalLM: outputs = self.model()", "timing": 0.0}
timing[idx]["timing"] += e
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
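
For reference, the timing pattern this diff wires into modeling_gemma.py can be sketched standalone: a module-level timing dict, torch.cuda.synchronize() before each clock read so queued kernels are counted, per-slot accumulation of elapsed seconds, and a JSON dump at the end (as done from GenerationMixin in the first file). The record() helper and the matmul placeholder workload below are illustrative assumptions and do not appear in the commit.

import datetime
import json

import torch

# module-level accumulator, mirroring `timing = {}` in modeling_gemma.py
timing = {}


def record(idx, name, start):
    """Accumulate wall-clock seconds elapsed since `start` under slot `idx`."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for pending CUDA work before reading the clock
    elapsed = (datetime.datetime.now() - start).total_seconds()
    if idx not in timing:
        timing[idx] = {"name": name, "timing": 0.0}
    timing[idx]["timing"] += elapsed


if __name__ == "__main__":
    # time a stand-in for `outputs = self.model(...)` in GemmaForCausalLM.forward
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = datetime.datetime.now()
    x = torch.randn(1024, 1024) @ torch.randn(1024, 1024)  # placeholder workload
    record(0, "GemmaForCausalLM: outputs = self.model()", start)

    # dump the accumulated timings, as the GenerationMixin hunk does after generation
    print(json.dumps(timing, indent=4))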