CLI: Don't check the model head when there is no model head (#18733)
This commit is contained in:
parent
438698085c
commit
6faf283288
|
@ -286,7 +286,12 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
|||
crossload_differences = self.find_pt_tf_differences(pt_outputs, tf_from_pt_outputs)
|
||||
output_differences = {k: v for k, v in crossload_differences.items() if "hidden" not in k}
|
||||
hidden_differences = {k: v for k, v in crossload_differences.items() if "hidden" in k}
|
||||
max_crossload_output_diff = max(output_differences.values())
|
||||
if len(output_differences) == 0 and architectures is not None:
|
||||
raise ValueError(
|
||||
f"Something went wrong -- the config file has architectures ({architectures}), but no model head"
|
||||
" output was found. All outputs start with 'hidden'"
|
||||
)
|
||||
max_crossload_output_diff = max(output_differences.values()) if output_differences else 0.0
|
||||
max_crossload_hidden_diff = max(hidden_differences.values())
|
||||
if max_crossload_output_diff > MAX_ERROR or max_crossload_hidden_diff > self._max_hidden_error:
|
||||
raise ValueError(
|
||||
|
@ -310,7 +315,12 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
|||
conversion_differences = self.find_pt_tf_differences(pt_outputs, tf_outputs)
|
||||
output_differences = {k: v for k, v in conversion_differences.items() if "hidden" not in k}
|
||||
hidden_differences = {k: v for k, v in conversion_differences.items() if "hidden" in k}
|
||||
max_conversion_output_diff = max(output_differences.values())
|
||||
if len(output_differences) == 0 and architectures is not None:
|
||||
raise ValueError(
|
||||
f"Something went wrong -- the config file has architectures ({architectures}), but no model head"
|
||||
" output was found. All outputs start with 'hidden'"
|
||||
)
|
||||
max_conversion_output_diff = max(output_differences.values()) if output_differences else 0.0
|
||||
max_conversion_hidden_diff = max(hidden_differences.values())
|
||||
if max_conversion_output_diff > MAX_ERROR or max_conversion_hidden_diff > self._max_hidden_error:
|
||||
raise ValueError(
|
||||
|
|
Loading…
Reference in New Issue