Adding read_swag_examples to load the dataset.

This commit is contained in:
Grégory Châtel 2018-12-06 14:02:46 +01:00
parent 7183cded4e
commit 83fdbd6043
1 changed file with 31 additions and 22 deletions

View File

@@ -14,6 +14,9 @@
# limitations under the License.
"""BERT finetuning runner."""
import pandas as pd
class SwagExample(object):
"""A single training/test example for the SWAG dataset."""
def __init__(self,
@@ -53,26 +56,32 @@ class SwagExample(object):
return ', '.join(l)
if __name__ == "__main__":
    # Hard-coded sample used to eyeball the printed form of a SwagExample.
    # The values are passed positionally, in the order listed here.
    sample_args = (
        3416,
        'Members of the procession walk down the street holding small horn brass instruments.',
        'A drum line',
        'passes by walking down the street playing their instruments.',
        'has heard approaching them.',
        "arrives and they're outside dancing and asleep.",
        'turns the lead singer watches the performance.',
    )
    e = SwagExample(*sample_args)
    print(e)
def read_swag_examples(input_file, is_training):
    """Load SWAG examples from a CSV file.

    Args:
        input_file: Anything ``pandas.read_csv`` accepts (it is passed
            through unchanged). The columns read from it are ``fold-ind``,
            ``sent1``, ``sent2``, ``ending0`` .. ``ending3`` and, when
            ``is_training`` is true, ``label``.
        is_training: When true, a ``label`` column is required and its value
            is stored on each example; otherwise every example gets
            ``label=None``.

    Returns:
        A list with one ``SwagExample`` per CSV row, in file order.

    Raises:
        ValueError: If ``is_training`` is true and the file has no ``label``
            column.
    """
    input_df = pd.read_csv(input_file)

    # Validate right after reading so a file without labels fails fast,
    # before any example objects are built.
    if is_training and 'label' not in input_df.columns:
        raise ValueError(
            "For training, the input file must contain a label column.")

    # 'fold-ind' is not a valid Python identifier, so rows are accessed by
    # key via iterrows() rather than by attribute via itertuples().
    return [
        SwagExample(
            swag_id=row['fold-ind'],
            context_sentence=row['sent1'],
            start_ending=row['sent2'],
            ending_0=row['ending0'],
            ending_1=row['ending1'],
            ending_2=row['ending2'],
            ending_3=row['ending3'],
            label=row['label'] if is_training else None,
        )
        for _, row in input_df.iterrows()
    ]
if __name__ == "__main__":
    # Manual smoke test: load data/train.csv with labels required, report
    # how many examples were read, then print the first five.
    examples = read_swag_examples('data/train.csv', is_training=True)
    print(len(examples))
    banner = '###########################'
    for example in examples[:5]:
        print(banner)
        print(example)