Fix pytorch image classification example (#14883)
* Update example * Remove skip in tests
This commit is contained in:
parent
7df4b90c76
commit
1045a36c1f
|
@ -279,12 +279,14 @@ def main():
|
|||
|
||||
def train_transforms(example_batch):
|
||||
"""Apply _train_transforms across a batch."""
|
||||
example_batch["pixel_values"] = [_train_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
|
||||
example_batch["pixel_values"] = [
|
||||
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
|
||||
]
|
||||
return example_batch
|
||||
|
||||
def val_transforms(example_batch):
|
||||
"""Apply _val_transforms across a batch."""
|
||||
example_batch["pixel_values"] = [_val_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
|
||||
example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
|
||||
return example_batch
|
||||
|
||||
if training_args.do_train:
|
||||
|
|
|
@ -19,7 +19,6 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
@ -409,7 +408,6 @@ class ExamplesTests(TestCasePlus):
|
|||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_bleu"], 30)
|
||||
|
||||
@unittest.skip("Fix me Nate!")
|
||||
def test_run_image_classification(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
|
Loading…
Reference in New Issue