[FX] Symbolic trace for Bloom (#18356)

* Bloom model can now be traced

* Bloom traced model can be torch scripted and serialized

* Bloom can be traced with variable keyword arguments

* Enable XLNet support

* Disable XLNet for now
This commit is contained in:
Michael Benayoun 2022-07-29 16:12:27 +02:00 committed by GitHub
parent 1763770bd9
commit 4e2f4a92dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 21 deletions

View File

@ -244,7 +244,7 @@ class BloomAttention(nn.Module):
new_tensor_shape = fused_qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim)
# new_tensor_shape = (fused_qkv.size(1), fused_qkv.size(0)*fused_qkv.size(2), fused_qkv.size(-1))
# fused_qkv = fused_qkv.transpose(1, 0)
fused_qkv = fused_qkv.reshape(*new_tensor_shape)
fused_qkv = fused_qkv.reshape(new_tensor_shape)
# fused_qkv = fused_qkv.permute(0, 2, 1, 3)
return torch.split(fused_qkv, self.head_dim, -1)
@ -306,7 +306,7 @@ class BloomAttention(nn.Module):
attn_weights = (attention_scores * self.layer_number) + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
attention_probs = attention_probs * (~attention_mask.bool())
attention_probs = attention_probs * (~attention_mask.to(torch.bool))
# [batch_size, num_heads, q_length, k_length]
attention_probs = self.attention_dropout(attention_probs)
@ -314,7 +314,7 @@ class BloomAttention(nn.Module):
attention_probs = attention_probs * head_mask
# change view [batch_size x num_heads, q_length, k_length]
attention_probs_reshaped = attention_probs.view(*matmul_result.shape)
attention_probs_reshaped = attention_probs.view(matmul_result.shape)
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(

View File

@ -98,6 +98,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"bert",
"blenderbot",
"blenderbot-small",
"bloom",
"clip",
"deberta",
"deberta-v2",
@ -127,8 +128,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"trocr",
"vit",
"xglm",
# "xlnet",
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
# "xlnet",
]
_REGULAR_SUPPORTED_MODELS = []
@ -562,10 +562,8 @@ class HFProxy(Proxy):
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
def __contains__(self, key):
# To handle cases such as :
# `"some_key" in kwargs`
if self.node.op == "placeholder":
return False
if hasattr(self, "_metadata") and self._metadata is not None:
return key in self._metadata
return super().__contains__(key)
@ -905,6 +903,9 @@ class HFTracer(Tracer):
inputs.update(self._generate_dummy_input(root, input_name, shape))
concrete_metas = {input_name: input_.to("meta") for input_name, input_ in inputs.items()}
for param in sig.parameters.values():
if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
concrete_metas[f"**{param.name}"] = {}
self.meta_args = concrete_metas
self.patched_torch_methods = {
target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
@ -933,18 +934,15 @@ class HFTracer(Tracer):
node.type = torch.Tensor
# It is a concrete arg so it is not used and should be removed.
else:
if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
# Newer versions of torch.fx emit an assert statement
# for concrete arguments; delete those before we delete
# the concrete arg.
to_delete = []
for user in node.users:
if user.target == torch.fx._symbolic_trace._assert_is_none:
to_delete.append(user)
for user in to_delete:
self.graph.erase_node(user)
to_visit = [node]
to_delete = collections.OrderedDict()
while to_visit:
n = to_visit.pop(0)
to_delete[n] = None
to_visit += list(n.users.keys())
self.graph.erase_node(node)
for user in reversed(to_delete.keys()):
self.graph.erase_node(user)
# TODO: solves GraphModule creation.
# Without this, return type annotation "Tuple" is causing code execution failure.

View File

@ -320,7 +320,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
)
all_generative_model_classes = (BloomForCausalLM,) if is_torch_available() else ()
fx_compatible = False
fx_compatible = True
test_missing_keys = False
test_pruning = False
test_torchscript = True # torch.autograd functions seems to be not supported