import os
from collections import deque

import torch
from torch.utils.data import Dataset


# ------------
# Data loading
# ------------


class CNNDMDataset(Dataset):
    """Abstracts the dataset used to train seq2seq models.

    The class will process the documents that are located in the specified
    folder. The preprocessing will work on any document that is reasonably
    formatted. On the CNN/DailyMail dataset it will extract both the story
    and the summary.

    CNN/Daily News:

    The CNN/Daily News raw datasets are downloaded from [1]. The stories are
    stored in different files; the summary appears at the end of the story as
    sentences that are prefixed by the special `@highlight` line. To process
    the data, untar both datasets in the same folder, and pass the path to this
    folder as the "data_dir" argument. The formatting code was inspired by [2].

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/
    """

    def __init__(self, path="", prefix="train"):
        """We initialize the class by listing all the documents to summarize.
        Files are not read in memory due to the size of some datasets (like CNN/DailyMail).
        """
        assert os.path.isdir(path)

        self.documents = []
        story_filenames_list = os.listdir(path)
        for story_filename in story_filenames_list:
            if "summary" in story_filename:
                continue
            path_to_story = os.path.join(path, story_filename)
            if not os.path.isfile(path_to_story):
                continue
            self.documents.append(path_to_story)

    def __len__(self):
        """Returns the number of documents."""
        return len(self.documents)

    def __getitem__(self, idx):
        document_path = self.documents[idx]
        document_name = document_path.split("/")[-1]
        with open(document_path, encoding="utf-8") as source:
            raw_story = source.read()
            story_lines, summary_lines = process_story(raw_story)
        return document_name, story_lines, summary_lines


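# A minimal usage sketch (not part of the original module; the directory path
# below is a placeholder). Each item is a (filename, story sentences, summary
# sentences) triple, read lazily from disk:
#
#     dataset = CNNDMDataset(path="/path/to/cnn_dm_stories")
#     document_name, story_lines, summary_lines = dataset[0]

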
def process_story(raw_story):
    """Extract the story and summary from a story file.

    Arguments:
        raw_story (str): content of the story file as an utf-8 encoded string.

    Returns:
        story_lines (list of str), summary_lines (list of str): the article
        sentences and the `@highlight` summary sentences. If the file contains
        no `@highlight` marker, summary_lines is empty.
    """
    nonempty_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))

    # for some unknown reason some lines miss a period, add it
    nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]

    # gather article lines
    story_lines = []
    lines = deque(nonempty_lines)
    while True:
        try:
            element = lines.popleft()
            if element.startswith("@highlight"):
                break
            story_lines.append(element)
        except IndexError:
            # if "@highlight" is absent from the file we pop all elements
            # until the deque is empty, which raises the IndexError we catch
            # here; the summary is then returned empty.
            return story_lines, []

    # gather summary lines
    summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))

    return story_lines, summary_lines


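# Example (illustrative input, not taken from the dataset): a story followed by
# one `@highlight` block is split into article and summary sentences, with
# missing periods added:
#
#     process_story("First sentence.\nSecond sentence\n@highlight\nA summary line")
#     # -> (["First sentence.", "Second sentence."], ["A summary line."])

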
def _add_missing_period(line):
    # Sentence-final characters, including the Unicode close quotes \u2019 and \u201d.
    END_TOKENS = [".", "!", "?", "...", "'", "`", '"', "\u2019", "\u201d", ")"]
    if line.startswith("@highlight"):
        return line
    if line[-1] in END_TOKENS:
        return line
    return line + "."


# --------------------------
# Encoding and preprocessing
# --------------------------


def truncate_or_pad(sequence, block_size, pad_token_id):
    """Adapt the source and target sequences' lengths to the block size.
    If the sequence is shorter than the block size we append padding tokens
    to the right of the sequence.
    """
    if len(sequence) > block_size:
        return sequence[:block_size]
    else:
        sequence.extend([pad_token_id] * (block_size - len(sequence)))
        return sequence


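# For example, a short sequence is padded and a long one is truncated (the
# token ids below are arbitrary, with 0 as the padding id):
#
#     truncate_or_pad([7, 8, 9], block_size=5, pad_token_id=0)          # -> [7, 8, 9, 0, 0]
#     truncate_or_pad([7, 8, 9, 10, 11], block_size=3, pad_token_id=0)  # -> [7, 8, 9]
#
# Note that padding mutates the input list via `extend`, while truncation
# returns a new list.

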
def build_mask(sequence, pad_token_id):
    """Builds the mask. The attention mechanism will only attend to positions
    with value 1."""
    mask = torch.ones_like(sequence)
    idx_pad_tokens = sequence == pad_token_id
    mask[idx_pad_tokens] = 0
    return mask


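# For example, with 0 assumed to be the padding id:
#
#     build_mask(torch.tensor([101, 2023, 102, 0, 0]), pad_token_id=0)
#     # -> tensor([1, 1, 1, 0, 0])

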
def encode_for_summarization(story_lines, summary_lines, tokenizer):
    """Encode the story and summary lines, and join them
    as specified in [1] by using `[SEP] [CLS]` tokens to separate
    sentences.
    """
    # Each line is encoded separately, so the tokenizer wraps every sentence
    # with its special tokens; the per-sentence ids are then flattened.
    story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
    story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
    summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
    summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]

    return story_token_ids, summary_token_ids


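# A usage sketch (assuming a BERT-style tokenizer; not part of the original
# module). With such a tokenizer each sentence is wrapped in `[CLS] ... [SEP]`,
# so the flattened story reads `[CLS] s1 [SEP] [CLS] s2 [SEP] ...`:
#
#     story_ids, summary_ids = encode_for_summarization(
#         story_lines, summary_lines, tokenizer
#     )

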
def compute_token_type_ids(batch, separator_token_id):
    """Segment embeddings as described in [1].

    The values {0,1} were found in the repository [2].

    Arguments:
        batch: torch.Tensor, size [batch_size, block_size]
            Batch of input.
        separator_token_id: int
            The value of the token that separates the segments.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    [2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
    """
    batch_embeddings = []
    for sequence in batch:
        sentence_num = -1
        embeddings = []
        for s in sequence:
            if s == separator_token_id:
                sentence_num += 1
            embeddings.append(sentence_num % 2)
        batch_embeddings.append(embeddings)
    return torch.tensor(batch_embeddings)


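# For example, with 101 assumed to be the separator id, the segment value
# alternates between 0 and 1 each time the separator is seen:
#
#     compute_token_type_ids(torch.tensor([[101, 7, 102, 101, 8, 102]]), 101)
#     # -> tensor([[0, 0, 0, 1, 1, 1]])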