CLI: Don't check the model head when there is no model head (#18733)

This commit is contained in:
Joao Gante 2022-08-23 15:38:59 +01:00 committed by GitHub
parent 438698085c
commit 6faf283288
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 2 deletions

View File

@ -286,7 +286,12 @@ class PTtoTFCommand(BaseTransformersCLICommand):
crossload_differences = self.find_pt_tf_differences(pt_outputs, tf_from_pt_outputs)
output_differences = {k: v for k, v in crossload_differences.items() if "hidden" not in k}
hidden_differences = {k: v for k, v in crossload_differences.items() if "hidden" in k}
max_crossload_output_diff = max(output_differences.values())
if len(output_differences) == 0 and architectures is not None:
raise ValueError(
f"Something went wrong -- the config file has architectures ({architectures}), but no model head"
" output was found. All outputs start with 'hidden'"
)
max_crossload_output_diff = max(output_differences.values()) if output_differences else 0.0
max_crossload_hidden_diff = max(hidden_differences.values())
if max_crossload_output_diff > MAX_ERROR or max_crossload_hidden_diff > self._max_hidden_error:
raise ValueError(
@ -310,7 +315,12 @@ class PTtoTFCommand(BaseTransformersCLICommand):
conversion_differences = self.find_pt_tf_differences(pt_outputs, tf_outputs)
output_differences = {k: v for k, v in conversion_differences.items() if "hidden" not in k}
hidden_differences = {k: v for k, v in conversion_differences.items() if "hidden" in k}
max_conversion_output_diff = max(output_differences.values())
if len(output_differences) == 0 and architectures is not None:
raise ValueError(
f"Something went wrong -- the config file has architectures ({architectures}), but no model head"
" output was found. All outputs start with 'hidden'"
)
max_conversion_output_diff = max(output_differences.values()) if output_differences else 0.0
max_conversion_hidden_diff = max(hidden_differences.values())
if max_conversion_output_diff > MAX_ERROR or max_conversion_hidden_diff > self._max_hidden_error:
raise ValueError(