Fix make fix-copies with type annotations (#13586)
This commit is contained in:
parent
cec1c63642
commit
88dbbfb2d6
|
@ -109,6 +109,10 @@ def _compute_mask_indices(
|
||||||
# scatter indices to mask
|
# scatter indices to mask
|
||||||
spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
|
spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# make sure padded input ids cannot be masked
|
||||||
|
spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False)
|
||||||
|
|
||||||
return spec_aug_mask
|
return spec_aug_mask
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -258,7 +258,7 @@ def _compute_mask_indices(
|
||||||
tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, spec_aug_mask.shape
|
tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, spec_aug_mask.shape
|
||||||
)
|
)
|
||||||
|
|
||||||
return tf.cast(spec_aug_mask, tf.float32)
|
return spec_aug_mask
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
|
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
|
||||||
|
|
|
@ -508,9 +508,9 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
||||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif inputs["input_ids"] is not None:
|
elif inputs["input_ids"] is not None:
|
||||||
input_shape = shape_list(tensor=inputs["input_ids"])
|
input_shape = shape_list(inputs["input_ids"])
|
||||||
elif inputs["inputs_embeds"] is not None:
|
elif inputs["inputs_embeds"] is not None:
|
||||||
input_shape = shape_list(tensor=inputs["inputs_embeds"])[:-1]
|
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
|
|
@ -52,7 +52,7 @@ LOCALIZED_READMES = {
|
||||||
|
|
||||||
|
|
||||||
def _should_continue(line, indent):
|
def _should_continue(line, indent):
|
||||||
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\):\s*$", line) is not None
|
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
|
||||||
|
|
||||||
|
|
||||||
def find_code_in_transformers(object_name):
|
def find_code_in_transformers(object_name):
|
||||||
|
|
Loading…
Reference in New Issue