[FX] Symbolic trace for Bloom (#18356)

* Bloom model can now be traced

* Bloom traced model can be torch scripted and serialized

* Bloom can be traced with variable keyword arguments

* Enable XLNet support

* Disable XLNet for now
Michael Benayoun, 2022-07-29 16:12:27 +02:00 (committed by GitHub)
parent 1763770bd9
commit 4e2f4a92dd
3 changed files with 19 additions and 21 deletions
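In practice, the change means a Bloom model can be run through the Hugging Face FX tracer and the resulting GraphModule scripted and saved. A minimal sketch, using a tiny randomly initialized config purely for illustration:

import torch
from transformers import BloomConfig, BloomForCausalLM
from transformers.utils.fx import symbolic_trace

# Tiny randomly initialized model so the sketch runs without downloading weights.
config = BloomConfig(vocab_size=1024, hidden_size=64, n_layer=2, n_head=4)
model = BloomForCausalLM(config)

# Symbolic tracing with the HF tracer; dummy inputs are generated for input_names.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

# Per the commit message, the traced module can be TorchScripted and serialized.
scripted = torch.jit.script(traced)
torch.jit.save(scripted, "bloom_traced_scripted.pt")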

src/transformers/models/bloom/modeling_bloom.py

@@ -244,7 +244,7 @@ class BloomAttention(nn.Module):
         new_tensor_shape = fused_qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim)
         # new_tensor_shape = (fused_qkv.size(1), fused_qkv.size(0)*fused_qkv.size(2), fused_qkv.size(-1))
         # fused_qkv = fused_qkv.transpose(1, 0)
-        fused_qkv = fused_qkv.reshape(*new_tensor_shape)
+        fused_qkv = fused_qkv.reshape(new_tensor_shape)
         # fused_qkv = fused_qkv.permute(0, 2, 1, 3)

         return torch.split(fused_qkv, self.head_dim, -1)
@@ -306,7 +306,7 @@ class BloomAttention(nn.Module):
         attn_weights = (attention_scores * self.layer_number) + attention_mask
         attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
         attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
-        attention_probs = attention_probs * (~attention_mask.bool())
+        attention_probs = attention_probs * (~attention_mask.to(torch.bool))

         # [batch_size, num_heads, q_length, k_length]
         attention_probs = self.attention_dropout(attention_probs)
@@ -314,7 +314,7 @@ class BloomAttention(nn.Module):
             attention_probs = attention_probs * head_mask

         # change view [batch_size x num_heads, q_length, k_length]
-        attention_probs_reshaped = attention_probs.view(*matmul_result.shape)
+        attention_probs_reshaped = attention_probs.view(matmul_result.shape)

         # matmul: [batch_size * num_heads, q_length, head_dim]
         context_layer = torch.bmm(
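The modeling changes above replace argument unpacking (reshape(*shape), view(*shape)) with passing the shape object directly, and .bool() with .to(torch.bool); both spellings are numerically equivalent in eager mode, and the latter forms are presumably what keeps the ops traceable through FX proxies. A quick eager-mode check of the equivalences (standalone sketch, not part of the diff):

import torch

x = torch.randn(2, 3, 8)
new_shape = x.size()[:-1] + (2, 4)  # a torch.Size, mirroring the fused-QKV reshape above

# reshape accepts the shape object itself as well as unpacked integers.
assert torch.equal(x.reshape(new_shape), x.reshape(*new_shape))

# view likewise accepts a torch.Size directly.
assert torch.equal(x.view(x.shape), x.view(*x.shape))

# .to(torch.bool) matches .bool() when building the inverted attention mask.
attention_mask = torch.tensor([[0.0, 1.0, 1.0]])
assert torch.equal(~attention_mask.to(torch.bool), ~attention_mask.bool())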

src/transformers/utils/fx.py

@@ -98,6 +98,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
     "bert",
     "blenderbot",
     "blenderbot-small",
+    "bloom",
     "clip",
     "deberta",
     "deberta-v2",
@@ -127,8 +128,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
     "trocr",
     "vit",
     "xglm",
     # "xlnet",
-    # TODO: add support for them as it should be quite easy to do so (small blocking issues).
 ]

 _REGULAR_SUPPORTED_MODELS = []
@@ -562,10 +562,8 @@ class HFProxy(Proxy):
         return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})

     def __contains__(self, key):
-        # To handle cases such as :
-        # `"some_key" in kwargs`
-        if self.node.op == "placeholder":
-            return False
+        if hasattr(self, "_metadata") and self._metadata is not None:
+            return key in self._metadata
         return super().__contains__(key)
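With this change, membership tests on a proxy (the `"some_key" in kwargs` case mentioned in the removed comment) are answered from the metadata the tracer recorded for it, instead of always returning False for placeholders. A hypothetical stand-in illustrating the pattern (not the real HFProxy):

class MetadataBackedProxy:
    # Stand-in only: real HFProxy objects get `_metadata` filled in by HFTracer
    # with the dummy inputs generated for the traced forward.
    def __init__(self, metadata=None):
        self._metadata = metadata

    def __contains__(self, key):
        if getattr(self, "_metadata", None) is not None:
            return key in self._metadata
        # The real implementation falls back to Proxy.__contains__ here.
        return False

kwargs = MetadataBackedProxy({"input_ids": None, "attention_mask": None})
assert "attention_mask" in kwargs
assert "labels" not in kwargs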
@@ -905,6 +903,9 @@ class HFTracer(Tracer):
             inputs.update(self._generate_dummy_input(root, input_name, shape))

         concrete_metas = {input_name: input_.to("meta") for input_name, input_ in inputs.items()}
+        for param in sig.parameters.values():
+            if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
+                concrete_metas[f"**{param.name}"] = {}
         self.meta_args = concrete_metas
         self.patched_torch_methods = {
             target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
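This is what enables tracing with variable keyword arguments: any VAR_KEYWORD parameter of the traced forward that was not requested as an input now gets an empty dict as its concrete meta value (Bloom's forward accepts a trailing **deprecated_arguments). A standalone sketch of the detection, using a hypothetical forward signature for illustration:

import inspect

def forward(input_ids=None, attention_mask=None, **deprecated_arguments):
    pass  # illustrative signature only

sig = inspect.signature(forward)
input_names = ["input_ids", "attention_mask"]

concrete_metas = {}
for param in sig.parameters.values():
    # Catch-all keyword parameters get an empty dict so the traced graph has
    # something concrete to bind them to.
    if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
        concrete_metas[f"**{param.name}"] = {}

print(concrete_metas)  # {'**deprecated_arguments': {}}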
@@ -933,18 +934,15 @@ class HFTracer(Tracer):
                     node.type = torch.Tensor
                 # It is a concrete arg so it is not used and should be removed.
                 else:
-                    if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
-                        # Newer versions of torch.fx emit an assert statement
-                        # for concrete arguments; delete those before we delete
-                        # the concrete arg.
-                        to_delete = []
-                        for user in node.users:
-                            if user.target == torch.fx._symbolic_trace._assert_is_none:
-                                to_delete.append(user)
-                        for user in to_delete:
-                            self.graph.erase_node(user)
-
-                    self.graph.erase_node(node)
+                    to_visit = [node]
+                    to_delete = collections.OrderedDict()
+                    while to_visit:
+                        n = to_visit.pop(0)
+                        to_delete[n] = None
+                        to_visit += list(n.users.keys())
+
+                    for user in reversed(to_delete.keys()):
+                        self.graph.erase_node(user)

         # TODO: solves GraphModule creation.
         # Without this, return type annotation "Tuple" is causing code execution failure.
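The cleanup of unused concrete-argument placeholders no longer special-cases torch.fx's _assert_is_none helper; it now collects the placeholder and all of its transitive users breadth-first and erases them in reverse discovery order, so no node is removed while it still has consumers. A small self-contained sketch of the same idea on a toy FX graph (the traced function is illustrative):

import collections
import torch.fx

def fn(x, unused):
    _ = unused + 1  # `unused` feeds a computation whose result is discarded
    return x * 2

graph_module = torch.fx.symbolic_trace(fn)
graph = graph_module.graph

# Erase the `unused` placeholder together with everything that transitively
# consumes it: breadth-first collection, then erase in reverse discovery order.
node = next(n for n in graph.nodes if n.op == "placeholder" and n.target == "unused")
to_visit = [node]
to_delete = collections.OrderedDict()
while to_visit:
    n = to_visit.pop(0)
    to_delete[n] = None
    to_visit += list(n.users.keys())

for n in reversed(to_delete.keys()):
    graph.erase_node(n)

graph_module.recompile()
print(graph_module.code)  # forward now only depends on `x`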

tests/models/bloom/test_modeling_bloom.py

@@ -320,7 +320,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     )

     all_generative_model_classes = (BloomForCausalLM,) if is_torch_available() else ()
-    fx_compatible = False
+    fx_compatible = True
     test_missing_keys = False
     test_pruning = False
     test_torchscript = True  # torch.autograd functions seems to be not supported