import logging
import os
from typing import List, TextIO, Union

from conllu import parse_incr

from utils_ner import InputExample, Split, TokenClassificationTask


logger = logging.getLogger(__name__)


class NER(TokenClassificationTask):
    def __init__(self, label_idx=-1):
        # In NER datasets, the last column is usually reserved for the NER label.
        self.label_idx = label_idx

    def read_examples_from_file(self, data_dir, mode: Union[Split, str]) -> List[InputExample]:
        if isinstance(mode, Split):
            mode = mode.value
        file_path = os.path.join(data_dir, f"{mode}.txt")
        guid_index = 1
        examples = []
        with open(file_path, encoding="utf-8") as f:
            words = []
            labels = []
            for line in f:
                if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                    # A blank line (or document marker) closes the current sentence.
                    if words:
                        examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
                        guid_index += 1
                        words = []
                        labels = []
                else:
                    splits = line.split(" ")
                    words.append(splits[0])
                    if len(splits) > 1:
                        labels.append(splits[self.label_idx].replace("\n", ""))
                    else:
                        # Examples could have no label for mode = "test".
                        labels.append("O")
            if words:
                examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
        return examples

    def write_predictions_to_file(self, writer: TextIO, test_input_reader: TextIO, preds_list: List):
        example_id = 0
        for line in test_input_reader:
            if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                writer.write(line)
                # Advance to the next example once its predictions are exhausted.
                if not preds_list[example_id]:
                    example_id += 1
            elif preds_list[example_id]:
                output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
                writer.write(output_line)
            else:
                logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])

    def get_labels(self, path: str) -> List[str]:
        if path:
            with open(path, "r") as f:
                labels = f.read().splitlines()
            if "O" not in labels:
                labels = ["O"] + labels
            return labels
        else:
            return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
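

# For reference, a minimal sketch of the whitespace-separated CoNLL-2003-style
# layout that NER.read_examples_from_file expects: the word is the first
# column and the NER tag is the last. The middle columns (POS and chunk tags)
# are illustrative and not required by the reader.
#
#   -DOCSTART- -X- -X- O
#
#   EU NNP B-NP B-ORG
#   rejects VBZ B-VP O
#   German JJ B-NP B-MISC

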
class Chunk(NER):
    def __init__(self):
        # In the CoNLL-2003 dataset, the chunk column is second-to-last.
        super().__init__(label_idx=-2)

    def get_labels(self, path: str) -> List[str]:
        if path:
            with open(path, "r") as f:
                labels = f.read().splitlines()
            if "O" not in labels:
                labels = ["O"] + labels
            return labels
        else:
            return [
                "O",
                "B-ADVP",
                "B-INTJ",
                "B-LST",
                "B-PRT",
                "B-NP",
                "B-SBAR",
                "B-VP",
                "B-ADJP",
                "B-CONJP",
                "B-PP",
                "I-ADVP",
                "I-INTJ",
                "I-LST",
                "I-PRT",
                "I-NP",
                "I-SBAR",
                "I-VP",
                "I-ADJP",
                "I-CONJP",
                "I-PP",
            ]
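

# The POS task below reads Universal Dependencies treebanks in CoNLL-U format
# via conllu.parse_incr and consumes only each token's FORM and UPOS fields.
# An illustrative, truncated token listing (fields are tab-separated in real
# files; the remaining CoNLL-U columns are elided here):
#
#   1   The   the   DET    ...
#   2   dog   dog   NOUN   ...

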
class POS(TokenClassificationTask):
    def read_examples_from_file(self, data_dir, mode: Union[Split, str]) -> List[InputExample]:
        if isinstance(mode, Split):
            mode = mode.value
        file_path = os.path.join(data_dir, f"{mode}.txt")
        guid_index = 1
        examples = []

        with open(file_path, encoding="utf-8") as f:
            for sentence in parse_incr(f):
                words = []
                labels = []
                for token in sentence:
                    words.append(token["form"])
                    labels.append(token["upos"])
                assert len(words) == len(labels)
                if words:
                    examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
                    guid_index += 1
        return examples

    def write_predictions_to_file(self, writer: TextIO, test_input_reader: TextIO, preds_list: List):
        example_id = 0
        for sentence in parse_incr(test_input_reader):
            s_p = preds_list[example_id]
            out = ""
            for token in sentence:
                # Emit "word (gold|predicted)" pairs for easy inspection.
                out += f'{token["form"]} ({token["upos"]}|{s_p.pop(0)}) '
            out += "\n"
            writer.write(out)
            example_id += 1

    def get_labels(self, path: str) -> List[str]:
        if path:
            with open(path, "r") as f:
                return f.read().splitlines()
        else:
            return [
                "ADJ",
                "ADP",
                "ADV",
                "AUX",
                "CCONJ",
                "DET",
                "INTJ",
                "NOUN",
                "NUM",
                "PART",
                "PRON",
                "PROPN",
                "PUNCT",
                "SCONJ",
                "SYM",
                "VERB",
                "X",
            ]
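

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original module: it assumes a
    # hypothetical ./data directory containing a CoNLL-formatted train.txt.
    task = NER()
    # With an empty path, get_labels falls back to the CoNLL-2003 label set.
    print(task.get_labels(path=""))
    examples = task.read_examples_from_file("./data", "train")
    print(f"Read {len(examples)} examples; first guid: {examples[0].guid}")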