[Flax] Fix flax pt equivalence tests (#12154)
* fix_torch_device_generate_test * remove @ * upload
This commit is contained in:
parent
d438eee030
commit
007be9e402
|
@ -181,7 +181,7 @@ class FlaxModelTesterMixin:
|
|||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
|
@ -192,10 +192,7 @@ class FlaxModelTesterMixin:
|
|||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
if not isinstance(
|
||||
fx_output_loaded, tuple
|
||||
): # TODO(Patrick, Daniel) - let's discard use_cache for now
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-3)
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
|
@ -229,7 +226,7 @@ class FlaxModelTesterMixin:
|
|||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
|
@ -242,8 +239,7 @@ class FlaxModelTesterMixin:
|
|||
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
|
||||
if not isinstance(fx_output, tuple): # TODO(Patrick, Daniel) - let's discard use_cache for now
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
|
Loading…
Reference in New Issue