sync

parent 0ba60a3df9
commit 8da2af59ca
src/transformers/generation/utils.py

@@ -2722,6 +2722,11 @@ class GenerationMixin:
         print(f"generation time: {e1}")
         print(f"while time: {e2}")
         print(json.dumps(timing, indent=4))
+
+        import transformers
+        o = transformers.models.gemma.modeling_gemma.timing
+        print(json.dumps(o, indent=4))
+
         breakpoint()
 
         if streamer is not None:
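Note: the block added here reads the accumulator that modeling_gemma.py (second hunk below) defines at module scope. Module globals are attributes of the module object, so every importer sees the same dict the forward pass mutates. A minimal sketch of that pattern (it assumes the patched transformers from this commit; stock modeling_gemma has no `timing` global):

    import json
    import transformers.models.gemma.modeling_gemma as mg

    # `timing = {}` is created once at import time; this attribute lookup
    # returns that same dict object, already populated by forward() calls.
    print(json.dumps(mg.timing, indent=4))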
src/transformers/models/gemma/modeling_gemma.py

@@ -15,6 +15,11 @@
 # limitations under the License.
 """ PyTorch Gemma model."""
 
+import json
+import datetime
+
+timing = {}
+
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
@@ -1119,6 +1124,9 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        torch.cuda.synchronize()
+        s = datetime.datetime.now()
+
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
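Note: the `torch.cuda.synchronize()` before each timestamp matters because CUDA kernels launch asynchronously; without draining the queue, host-side clocks around GPU work mostly measure kernel-launch overhead. A self-contained sketch of the difference (illustrative; guarded so it only runs where a CUDA device exists):

    import datetime
    import torch

    if torch.cuda.is_available():
        x = torch.randn(4096, 4096, device="cuda")
        _ = x @ x                  # warm-up so one-time setup cost is excluded
        torch.cuda.synchronize()

        # No sync: the matmul is still running on the GPU when the second
        # timestamp is taken, so this captures little more than the launch.
        s = datetime.datetime.now()
        _ = x @ x
        launch_only = (datetime.datetime.now() - s).total_seconds()
        torch.cuda.synchronize()

        # Sync on both sides, as the diff does: the delta bounds the GPU work.
        s = datetime.datetime.now()
        _ = x @ x
        torch.cuda.synchronize()
        full = (datetime.datetime.now() - s).total_seconds()

        print(f"launch-only: {launch_only:.6f}s  synchronized: {full:.6f}s")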
@@ -1133,6 +1141,14 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
             cache_position=cache_position,
         )
 
+        torch.cuda.synchronize()
+        t = datetime.datetime.now()
+        e = (t - s).total_seconds()
+        idx = 0
+        if idx not in timing:
+            timing[idx] = {"name": "GemmaForCausalLM: outputs = self.model()", "timing": 0.0}
+        timing[idx]["timing"] += e
+
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
         logits = logits.float()
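Note: taken together, the two forward() hunks form a per-call-site accumulator: a synchronized timestamp pair whose elapsed seconds are summed into `timing[idx]` across decode steps, then dumped as JSON from the generation loop. A standalone sketch of the same pattern factored into a helper (the `timed` function and the matmul workload are illustrative, not part of the commit):

    import datetime
    import json

    import torch

    timing = {}  # keyed by call-site id, as in the patched modeling_gemma.py

    def timed(idx, name, fn, *args, **kwargs):
        # Synchronize on both sides of the call so the wall-clock delta
        # covers the GPU work, then fold it into this site's running total.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        s = datetime.datetime.now()
        out = fn(*args, **kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        e = (datetime.datetime.now() - s).total_seconds()
        if idx not in timing:
            timing[idx] = {"name": name, "timing": 0.0}
        timing[idx]["timing"] += e
        return out

    x = torch.randn(512, 512)
    for _ in range(3):             # stands in for the token-by-token decode loop
        _ = timed(0, "matmul", torch.matmul, x, x)

    print(json.dumps(timing, indent=4))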