Fix DataCollatorForWholeWordMask again (#8397)
This commit is contained in:
parent 610730998f
commit 4a53e8e9e4
@@ -206,6 +206,10 @@ def _collate_batch(examples, tokenizer):
     return result


+def tolist(x: Union[List[Any], torch.Tensor]):
+    return x.tolist() if isinstance(x, torch.Tensor) else x
+
+
 @dataclass
 class DataCollatorForLanguageModeling:
     """
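
The hunk above introduces a small `tolist` helper so the collators no longer assume `input_ids` is a `torch.Tensor`. A minimal standalone sketch of its behavior (the assertions are illustrative, not part of the commit):

    from typing import Any, List, Union

    import torch

    def tolist(x: Union[List[Any], torch.Tensor]):
        # Tensors are converted via .tolist(); plain Python lists pass through.
        return x.tolist() if isinstance(x, torch.Tensor) else x

    assert tolist([101, 7592, 102]) == [101, 7592, 102]
    assert tolist(torch.tensor([101, 7592, 102])) == [101, 7592, 102]
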
@@ -320,13 +324,13 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
         mask_labels = []
         for e in examples:
             ref_tokens = []
-            for id in e["input_ids"].tolist():
+            for id in tolist(e["input_ids"]):
                 token = self.tokenizer._convert_id_to_token(id)
                 ref_tokens.append(token)

             # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
             if "chinese_ref" in e:
-                ref_pos = e["chinese_ref"].tolist()
+                ref_pos = tolist(e["chinese_ref"])
                 len_seq = e["input_ids"].size(0)
                 for i in range(len_seq):
                     if i in ref_pos:
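
The second hunk routes `e["input_ids"]` and `e["chinese_ref"]` through the same helper, so `DataCollatorForWholeWordMask` no longer raises `AttributeError: 'list' object has no attribute 'tolist'` when examples arrive as plain Python lists. A hedged usage sketch of the path this fixes (the checkpoint name and sample texts are arbitrary placeholders, not part of the commit):

    from transformers import BertTokenizer, DataCollatorForWholeWordMask

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Examples whose "input_ids" are plain lists, as produced by a
    # non-tensorized preprocessing pipeline; this path crashed before the fix.
    examples = [
        {"input_ids": tokenizer.encode("Hello world")},
        {"input_ids": tokenizer.encode("Whole word masking")},
    ]

    collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=0.15)
    batch = collator(examples)  # dict with "input_ids" and "labels" tensors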