cleaning - change ' to " (black requirements)
This commit is contained in:
parent
ebba9e929d
commit
e83d9f1c1d
|
@ -114,17 +114,17 @@ class LmSeqsDataset(Dataset):
|
|||
"""
|
||||
Remove sequences with a (too) high level of unknown tokens.
|
||||
"""
|
||||
if 'unk_token' not in self.params.special_tok_ids:
|
||||
if "unk_token" not in self.params.special_tok_ids:
|
||||
return
|
||||
else:
|
||||
unk_token_id = self.params.special_tok_ids['unk_token']
|
||||
unk_token_id = self.params.special_tok_ids["unk_token"]
|
||||
init_size = len(self)
|
||||
unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids])
|
||||
indices = (unk_occs/self.lengths) < 0.5
|
||||
indices = (unk_occs / self.lengths) < 0.5
|
||||
self.token_ids = self.token_ids[indices]
|
||||
self.lengths = self.lengths[indices]
|
||||
new_size = len(self)
|
||||
logger.info(f'Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).')
|
||||
logger.info(f"Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).")
|
||||
|
||||
def print_statistics(self):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue