Fix a bug for `CallbackHandler.callback_list` (#8052)

* Fix callback_list

* Add test

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Fix test

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
This commit is contained in:
Harutaka Kawamura 2020-10-27 23:37:04 +09:00 committed by GitHub
parent 8e28c327fc
commit 7bff0af0a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 1 deletions

View File

@ -325,7 +325,7 @@ class CallbackHandler(TrainerCallback):
@property
def callback_list(self):
return "\n".join(self.callbacks)
return "\n".join(cb.__class__.__name__ for cb in self.callbacks)
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_init_end", args, state, control)

View File

@ -221,3 +221,10 @@ class TrainerCallbackTest(unittest.TestCase):
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
# warning should be emitted for duplicated callbacks
with unittest.mock.patch("transformers.trainer_callback.logger.warn") as warn_mock:
trainer = self.get_trainer(
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
)
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]