[FX] Symbolic trace for Bloom (#18356)
* Bloom model can now be traced
* Bloom traced model can be torch scripted and serialized
* Bloom can be traced with variable keyword arguments
* Enable XLNet support
* Disable XLNet for now
This commit is contained in:
parent 1763770bd9
commit 4e2f4a92dd
@@ -244,7 +244,7 @@ class BloomAttention(nn.Module):
         new_tensor_shape = fused_qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim)
         # new_tensor_shape = (fused_qkv.size(1), fused_qkv.size(0)*fused_qkv.size(2), fused_qkv.size(-1))
         # fused_qkv = fused_qkv.transpose(1, 0)
-        fused_qkv = fused_qkv.reshape(*new_tensor_shape)
+        fused_qkv = fused_qkv.reshape(new_tensor_shape)
         # fused_qkv = fused_qkv.permute(0, 2, 1, 3)
         return torch.split(fused_qkv, self.head_dim, -1)
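Why this change matters for FX: unpacking the shape with `*` forces Python to iterate over it at trace time, which fails when `size()` returns a proxy rather than concrete integers, whereas `reshape` accepts the whole shape as a single tuple-like argument. A minimal sketch of the now-traceable pattern (the `SplitHeads` module and its sizes are illustrative, not the commit's code):

    import torch
    import torch.fx

    class SplitHeads(torch.nn.Module):
        def forward(self, fused_qkv: torch.Tensor) -> torch.Tensor:
            # Keep the shape as one object instead of unpacking it with `*`,
            # so the tracer records a single reshape node with a proxied shape.
            new_shape = fused_qkv.size()[:-1] + (4, 3 * 8)
            return fused_qkv.reshape(new_shape)

    traced = torch.fx.symbolic_trace(SplitHeads())
    print(traced.code)  # the reshape appears as one node in the graph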
@@ -306,7 +306,7 @@ class BloomAttention(nn.Module):
         attn_weights = (attention_scores * self.layer_number) + attention_mask
         attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
         attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
-        attention_probs = attention_probs * (~attention_mask.bool())
+        attention_probs = attention_probs * (~attention_mask.to(torch.bool))
         # [batch_size, num_heads, q_length, k_length]
         attention_probs = self.attention_dropout(attention_probs)
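At runtime the two spellings are equivalent: `Tensor.bool()` is shorthand for `Tensor.to(torch.bool)`; the explicit `.to(torch.bool)` form is presumably just friendlier to the tracing machinery. A quick equivalence check:

    import torch

    mask = torch.tensor([[0, 1, 0]])
    assert torch.equal(mask.bool(), mask.to(torch.bool))
    # Inverting the boolean mask zeroes the masked positions:
    print(torch.ones(1, 3) * (~mask.to(torch.bool)))  # tensor([[1., 0., 1.]])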
@@ -314,7 +314,7 @@ class BloomAttention(nn.Module):
         attention_probs = attention_probs * head_mask

         # change view [batch_size x num_heads, q_length, k_length]
-        attention_probs_reshaped = attention_probs.view(*matmul_result.shape)
+        attention_probs_reshaped = attention_probs.view(matmul_result.shape)

         # matmul: [batch_size * num_heads, q_length, head_dim]
         context_layer = torch.bmm(
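Same idea as the reshape change above: `Tensor.view` also accepts a full shape (a tuple or `torch.Size`) as a single argument, so there is no need to `*`-unpack `matmul_result.shape`. For example:

    import torch

    x = torch.arange(6)
    target = torch.zeros(2, 3)
    # Passing target.shape as one argument avoids iterating a proxied
    # shape during symbolic tracing.
    assert x.view(target.shape).shape == (2, 3)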
@@ -98,6 +98,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
     "bert",
     "blenderbot",
     "blenderbot-small",
+    "bloom",
     "clip",
     "deberta",
     "deberta-v2",
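Registering "bloom" here is what lets the high-level entry point accept Bloom models. A sketch of tracing one (the tiny random-weight config is illustrative; any Bloom checkpoint should behave the same way):

    from transformers import BloomConfig, BloomForCausalLM
    from transformers.utils.fx import symbolic_trace

    # Small made-up sizes so the example is cheap to run.
    model = BloomForCausalLM(BloomConfig(hidden_size=64, n_head=4, n_layer=2))
    traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
    print(type(traced))  # a torch.fx GraphModule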
@@ -128,7 +129,6 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
     "vit",
     "xglm",
-    "xlnet",
+    # "xlnet",
     # TODO: add support for them as it should be quite easy to do so (small blocking issues).
 ]

 _REGULAR_SUPPORTED_MODELS = []
@@ -562,10 +562,8 @@ class HFProxy(Proxy):
         return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})

     def __contains__(self, key):
         # To handle cases such as :
         # `"some_key" in kwargs`
         if self.node.op == "placeholder":
             return False
-        if hasattr(self, "_metadata") and self._metadata is not None:
-            return key in self._metadata
         return super().__contains__(key)
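The placeholder branch is what allows code like `"some_key" in kwargs` to execute while tracing with variable keyword arguments: a `**kwargs` placeholder has no knowable contents, so membership tests answer False instead of trying to emit a graph node. A toy illustration of the contract (this is not the real HFProxy):

    class FakeKwargsProxy:
        """Stand-in for a traced `**kwargs` placeholder (illustrative only)."""

        def __contains__(self, key):
            # Contents are unknowable at trace time, so conservatively
            # report that nothing is present.
            return False

    kwargs = FakeKwargsProxy()
    if "use_cache" not in kwargs:
        print("taking the default branch at trace time")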
@@ -905,6 +903,9 @@ class HFTracer(Tracer):
             inputs.update(self._generate_dummy_input(root, input_name, shape))

         concrete_metas = {input_name: input_.to("meta") for input_name, input_ in inputs.items()}
+        for param in sig.parameters.values():
+            if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
+                concrete_metas[f"**{param.name}"] = {}
         self.meta_args = concrete_metas
         self.patched_torch_methods = {
             target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
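The new loop finds the catch-all parameter through the function signature: `inspect.Parameter.VAR_KEYWORD` identifies it, and the tracer registers an empty dict for it under a `**name` key in `meta_args`. The detection in isolation (Bloom's `forward` does take a `**deprecated_arguments` catch-all, which is what this targets):

    import inspect

    def forward(input_ids, attention_mask=None, **deprecated_arguments):
        pass

    sig = inspect.signature(forward)
    for param in sig.parameters.values():
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            # Mirrors the tracer's bookkeeping: a "**name" key mapped to {}.
            print(f"**{param.name} -> {{}}")  # prints: **deprecated_arguments -> {}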
@@ -933,18 +934,15 @@ class HFTracer(Tracer):
                 node.type = torch.Tensor
             # It is a concrete arg so it is not used and should be removed.
             else:
-                if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
-                    # Newer versions of torch.fx emit an assert statement
-                    # for concrete arguments; delete those before we delete
-                    # the concrete arg.
-                    to_delete = []
-                    for user in node.users:
-                        if user.target == torch.fx._symbolic_trace._assert_is_none:
-                            to_delete.append(user)
-                    for user in to_delete:
-                        self.graph.erase_node(user)
+                to_visit = [node]
+                to_delete = collections.OrderedDict()
+                while to_visit:
+                    n = to_visit.pop(0)
+                    to_delete[n] = None
+                    to_visit += list(n.users.keys())

-                self.graph.erase_node(node)
+                for user in reversed(to_delete.keys()):
+                    self.graph.erase_node(user)

             # TODO: solves GraphModule creation.
             # Without this, return type annotation "Tuple" is causing code execution failure.
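The replacement generalizes the old special case: instead of deleting only `_assert_is_none` users, it collects every transitive user of the concrete-arg placeholder breadth-first, then erases in reverse discovery order, since `Graph.erase_node` refuses to remove a node that still has users. A self-contained sketch of the same traversal on a toy graph:

    import collections

    import torch
    import torch.fx

    graph = torch.fx.Graph()
    x = graph.placeholder("x")
    y = graph.call_function(torch.relu, (x,))
    graph.output(y)

    # Collect x and everything that depends on it, breadth-first.
    to_visit = [x]
    to_delete = collections.OrderedDict()
    while to_visit:
        n = to_visit.pop(0)
        to_delete[n] = None
        to_visit += list(n.users.keys())

    # Erase leaves first: reverse discovery order keeps every node
    # user-free at the moment it is erased.
    for node in reversed(to_delete.keys()):
        graph.erase_node(node)

    print(len(graph.nodes))  # 0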
@@ -320,7 +320,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
     )

     all_generative_model_classes = (BloomForCausalLM,) if is_torch_available() else ()
-    fx_compatible = False
+    fx_compatible = True
     test_missing_keys = False
     test_pruning = False
     test_torchscript = True  # torch.autograd functions seems to be not supported
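Flipping `fx_compatible` opts Bloom into the shared FX test suite. The commit message also claims the traced module survives TorchScript compilation and serialization; a sketch of that round trip, under the same tiny-config assumption as above:

    import torch
    from transformers import BloomConfig, BloomModel
    from transformers.utils.fx import symbolic_trace

    model = BloomModel(BloomConfig(hidden_size=64, n_head=4, n_layer=2))
    traced = symbolic_trace(model, input_names=["input_ids"])

    scripted = torch.jit.script(traced)      # compile the GraphModule
    torch.jit.save(scripted, "bloom_fx.pt")  # serialize to disk
    restored = torch.jit.load("bloom_fx.pt")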