ydshieh 2024-05-25 09:26:22 +02:00
parent 0ae789e043
commit 0ba60a3df9
2 changed files with 149 additions and 2 deletions

src/transformers/generation/stopping_criteria.py

@@ -481,6 +481,7 @@ class EosTokenCriteria(StoppingCriteria):
     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
+        self.eos_token_id = self.eos_token_id.to(input_ids.device)
         if input_ids.device.type == "mps":
             # https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
             is_done = (
@@ -492,7 +493,7 @@ class EosTokenCriteria(StoppingCriteria):
                 .squeeze()
             )
         else:
-            is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
+            is_done = torch.isin(input_ids[:, -1], self.eos_token_id)
         return is_done
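The two hunks above hoist the `.to(input_ids.device)` call out of the per-step hot path: the moved tensor is cached back onto `self`, so the host-to-device copy happens once instead of on every decoding step, and `torch.isin` then compares tensors that are already co-located. A minimal standalone sketch of the same caching pattern (illustrative names, not the transformers API):

    import torch

    class EosCheck:
        """Caches the EOS id tensor on the inputs' device after the first call."""

        def __init__(self, eos_token_ids):
            self.eos_token_id = torch.tensor(eos_token_ids)

        def __call__(self, input_ids):
            # .to() is a no-op once the tensor already lives on input_ids.device,
            # so the transfer cost is paid only once instead of every step.
            self.eos_token_id = self.eos_token_id.to(input_ids.device)
            return torch.isin(input_ids[:, -1], self.eos_token_id)

    input_ids = torch.tensor([[5, 7, 2], [5, 7, 9]])
    print(EosCheck([2])(input_ids))  # tensor([ True, False])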

src/transformers/generation/utils.py

@@ -16,6 +16,7 @@
 import copy
 import inspect
+import json
 import warnings
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -2451,6 +2452,13 @@ class GenerationMixin:
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
         ```"""
+        import datetime
+        from collections import OrderedDict
+
+        timing = OrderedDict()
+        torch.cuda.synchronize()
+        s_gen = datetime.datetime.now()
+        s = datetime.datetime.now()
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
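This hunk establishes the measurement pattern repeated throughout the commit: call `torch.cuda.synchronize()` before taking a host-side timestamp, because CUDA kernels launch asynchronously and a bare `datetime.datetime.now()` pair would otherwise measure only launch overhead. Factored into a small helper, the pattern looks roughly like this (a sketch; `SectionTimer` is hypothetical, not part of transformers):

    import datetime
    from collections import OrderedDict

    import torch

    class SectionTimer:
        """Accumulates wall-clock seconds per named section, GPU-synchronized."""

        def __init__(self):
            self.timing = OrderedDict()
            self._start = None

        def start(self):
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # drain queued kernels before reading the clock
            self._start = datetime.datetime.now()

        def stop(self, idx, name):
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # ensure the measured kernels have finished
            elapsed = (datetime.datetime.now() - self._start).total_seconds()
            if idx not in self.timing:
                self.timing[idx] = {"name": name, "timing": 0.0}
            self.timing[idx]["timing"] += elapsed

Used as `timer.start(); outputs = model(...); timer.stop(2, "model forward")`, which is exactly the shape of each instrumented block below.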
@@ -2516,10 +2524,65 @@
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+        torch.cuda.synchronize()
+        t = datetime.datetime.now()
+        e = (t - s).total_seconds()
+        idx = -1
+        if idx not in timing:
+            timing[idx] = {"name": "before while loop", "timing": 0.0}
+        timing[idx]["timing"] += e
+
+        torch.cuda.synchronize()
+        s_while = datetime.datetime.now()
+
+        step = 0
+        while True:
+            step += 1
+            if step > 4095:
+                break
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
+            not_stop = self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device)
+
+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 0
+            if idx not in timing:
+                timing[idx] = {"name": "_has_unfinished_sequences", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
+            if not not_stop:
+                break
+
+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 0.1
+            if idx not in timing:
+                timing[idx] = {"name": "if not not_stop", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 1
+            if idx not in timing:
+                timing[idx] = {"name": "prepare_inputs_for_generation", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             # forward pass to get next token
             outputs = self(
                 **model_inputs,
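Note the restructuring in the hunk above: the original `while self._has_unfinished_sequences(...):` becomes `while True:` with explicit breaks, so the loop condition itself can be timed as its own section, and a hard cap of 4095 steps guarantees a profiling run terminates. The generic shape of that transformation (a sketch, not transformers code):

    def run_loop(cond, body, max_steps=4095):
        step = 0
        while True:
            step += 1
            if step > max_steps:  # hard cap so a profiling run always terminates
                break
            if not cond():        # former `while cond():` condition, now timeable
                break
            body()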
@@ -2528,6 +2591,17 @@ class GenerationMixin:
                 output_hidden_states=output_hidden_states,
             )
 
+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 2
+            if idx not in timing:
+                timing[idx] = {"name": "model forward", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need
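Synchronizing around the forward pass is what makes the "model forward" number meaningful, but it also serializes host and device on every step. A lower-interference alternative is CUDA events, which record timestamps on the stream and only block the host once at the end; a sketch of that approach (not what this commit does):

    import torch

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    # ... forward pass or any other GPU work ...
    end.record()

    torch.cuda.synchronize()               # block once, at the very end
    elapsed_ms = start.elapsed_time(end)   # milliseconds between the two events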
@@ -2536,6 +2610,17 @@
             # pre-process distribution
             next_tokens_scores = logits_processor(input_ids, next_token_logits)
 
+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 3
+            if idx not in timing:
+                timing[idx] = {"name": "logits_processor", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
                 if output_scores:
@@ -2556,6 +2641,17 @@
                         else (outputs.hidden_states,)
                     )
 
+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 4
+            if idx not in timing:
+                timing[idx] = {"name": "prepare outputs", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             # argmax
             next_tokens = torch.argmax(next_tokens_scores, dim=-1)
@@ -2569,15 +2665,65 @@
             # update generated ids, model inputs, and length for next generation step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
+
+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 5
+            if idx not in timing:
+                timing[idx] = {"name": "next_tokens", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs,
                 model_kwargs,
                 is_encoder_decoder=self.config.is_encoder_decoder,
             )
+
+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 6
+            if idx not in timing:
+                timing[idx] = {"name": "_update_model_kwargs_for_generation", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+
+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 7
+            if idx not in timing:
+                timing[idx] = {"name": "stopping_criteria", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+            torch.cuda.synchronize()
+            s = datetime.datetime.now()
+
             this_peer_finished = unfinished_sequences.max() == 0
+
+            torch.cuda.synchronize()
+            t = datetime.datetime.now()
+            e = (t - s).total_seconds()
+            idx = 8
+            if idx not in timing:
+                timing[idx] = {"name": "final part in while", "timing": 0.0}
+            timing[idx]["timing"] += e
+
+        e1 = (t - s_gen).total_seconds()
+        e2 = (t - s_while).total_seconds()
+        print(f"generation time: {e1}")
+        print(f"while time: {e2}")
+        print(json.dumps(timing, indent=4))
+
+        breakpoint()
+
         if streamer is not None:
             streamer.end()
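When the run stops at the `breakpoint()`, `timing` holds cumulative seconds per section. A small post-processing sketch (assuming the `timing` dict and `step` counter defined in the diff above) that turns the totals into per-step averages and percentage shares:

    total = sum(v["timing"] for v in timing.values())
    for entry in timing.values():
        share = 100.0 * entry["timing"] / total if total else 0.0
        print(f'{entry["name"]:40s} {entry["timing"] / step:.6f} s/step ({share:5.1f}%)')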