From 8da2af59ca61c2b9faeb6827845322e51c4892c1 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Sat, 25 May 2024 09:52:50 +0200 Subject: [PATCH] sync --- src/transformers/generation/utils.py | 5 +++++ src/transformers/models/gemma/modeling_gemma.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ced10b53cd..3dd487014d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2722,6 +2722,11 @@ class GenerationMixin: print(f"generation time: {e1}") print(f"while time: {e2}") print(json.dumps(timing, indent=4)) + + import transformers + o = transformers.models.gemma.modeling_gemma.timing + print(json.dumps(o, indent=4)) + breakpoint() if streamer is not None: diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 97e4e5d49f..833a3798b6 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -15,6 +15,11 @@ # limitations under the License. """ PyTorch Gemma model.""" +import json +import datetime + +timing = {} + import math import warnings from typing import List, Optional, Tuple, Union @@ -1119,6 +1124,9 @@ class GemmaForCausalLM(GemmaPreTrainedModel): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + torch.cuda.synchronize() + s = datetime.datetime.now() + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -1133,6 +1141,14 @@ class GemmaForCausalLM(GemmaPreTrainedModel): cache_position=cache_position, ) + torch.cuda.synchronize() + t = datetime.datetime.now() + e = (t - s).total_seconds() + idx = 0 + if idx not in timing: + timing[idx] = {"name": "GemmaForCausalLM: outputs = self.model()", "timing": 0.0} + timing[idx]["timing"] += e + hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float()