Adding read_swag_examples to load the dataset.
This commit is contained in:
parent
7183cded4e
commit
83fdbd6043
|
@ -14,6 +14,9 @@
|
|||
# limitations under the License.
|
||||
"""BERT finetuning runner."""
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class SwagExample(object):
|
||||
"""A single training/test example for the SWAG dataset."""
|
||||
def __init__(self,
|
||||
|
@ -53,26 +56,32 @@ class SwagExample(object):
|
|||
|
||||
return ', '.join(l)
|
||||
|
||||
if __name__ == "__main__":
|
||||
e = SwagExample(
|
||||
3416,
|
||||
'Members of the procession walk down the street holding small horn brass instruments.',
|
||||
'A drum line',
|
||||
'passes by walking down the street playing their instruments.',
|
||||
'has heard approaching them.',
|
||||
"arrives and they're outside dancing and asleep.",
|
||||
'turns the lead singer watches the performance.',
|
||||
)
|
||||
print(e)
|
||||
def read_swag_examples(input_file, is_training):
|
||||
input_df = pd.read_csv(input_file)
|
||||
|
||||
e = SwagExample(
|
||||
3416,
|
||||
'Members of the procession walk down the street holding small horn brass instruments.',
|
||||
'A drum line',
|
||||
'passes by walking down the street playing their instruments.',
|
||||
'has heard approaching them.',
|
||||
"arrives and they're outside dancing and asleep.",
|
||||
'turns the lead singer watches the performance.',
|
||||
0
|
||||
)
|
||||
print(e)
|
||||
if is_training and 'label' not in input_df.columns:
|
||||
raise ValueError(
|
||||
"For training, the input file must contain a label column.")
|
||||
|
||||
examples = [
|
||||
SwagExample(
|
||||
swag_id = row['fold-ind'],
|
||||
context_sentence = row['sent1'],
|
||||
start_ending = row['sent2'],
|
||||
ending_0 = row['ending0'],
|
||||
ending_1 = row['ending1'],
|
||||
ending_2 = row['ending2'],
|
||||
ending_3 = row['ending3'],
|
||||
label = row['label'] if is_training else None
|
||||
) for _, row in input_df.iterrows()
|
||||
]
|
||||
|
||||
return examples
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
examples = read_swag_examples('data/train.csv', True)
|
||||
print(len(examples))
|
||||
for example in examples[:5]:
|
||||
print('###########################')
|
||||
print(example)
|
||||
|
|
Loading…
Reference in New Issue