Fix pytorch image classification example (#14883)

* Update example

* Remove skip in tests
This commit is contained in:
Mario Šaško 2021-12-22 14:42:19 +01:00 committed by GitHub
parent 7df4b90c76
commit 1045a36c1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 4 deletions

View File

@ -279,12 +279,14 @@ def main():
def train_transforms(example_batch):
"""Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [_train_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
example_batch["pixel_values"] = [
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
]
return example_batch
def val_transforms(example_batch):
"""Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [_val_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
return example_batch
if training_args.do_train:

View File

@ -19,7 +19,6 @@ import json
import logging
import os
import sys
import unittest
from unittest.mock import patch
import torch
@ -409,7 +408,6 @@ class ExamplesTests(TestCasePlus):
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_bleu"], 30)
@unittest.skip("Fix me Nate!")
def test_run_image_classification(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)