Enable fx tracing for Mistral (#30209)
* tracing for mistral
* typo
* fix copies
commit 304c6a1e0d
parent 98717cb341
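For context on the diff below: `torch.fx` symbolic tracing records operations on `Proxy` objects that carry no concrete values, so Python control flow that branches on a tensor's runtime content cannot be traced. The `if top_x.shape[0] == 0: continue` fast path in the sparse-MoE blocks is exactly such a branch. A minimal standalone sketch of the failure mode (a hypothetical toy module, traced with plain `torch.fx` rather than `HFTracer`):

import torch
from torch.fx import symbolic_trace

class EmptyExpertSkip(torch.nn.Module):
    def forward(self, expert_mask):
        top_x = torch.nonzero(expert_mask).squeeze(-1)  # tokens routed to this expert
        if top_x.shape[0] == 0:  # branches on runtime tensor data
            return expert_mask
        return expert_mask + 1

# symbolic_trace(EmptyExpertSkip()) raises torch.fx.proxy.TraceError
# ("symbolically traced variables cannot be used as inputs to control flow"):
# during tracing, `top_x.shape[0] == 0` is a Proxy, not a bool.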
src/transformers/models/mixtral/modeling_mixtral.py
@@ -868,9 +868,6 @@ class MixtralSparseMoeBlock(nn.Module):
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
 
-            if top_x.shape[0] == 0:
-                continue
-
             # Index the correct hidden states and compute the expert hidden state for
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -840,9 +840,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
 
-            if top_x.shape[0] == 0:
-                continue
-
             # Index the correct hidden states and compute the expert hidden state for
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
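Both removals above are numerically safe: when an expert receives no tokens, `top_x` is empty, the expert MLP runs on a zero-row batch, and the `index_add_` scatter writes nothing back. A quick self-contained check of that claim (variable names mirror the modeling code; the `* 2.0` stands in for the expert MLP):

import torch

hidden_states = torch.randn(4, 8)
final_hidden_states = torch.zeros_like(hidden_states)

# No token was routed to this expert, so the index tensor is empty.
top_x = torch.nonzero(torch.zeros(4, dtype=torch.bool)).squeeze(-1)

current_state = hidden_states[top_x]  # shape (0, 8): a zero-row batch
expert_output = current_state * 2.0   # stand-in for the expert MLP
final_hidden_states.index_add_(0, top_x, expert_output)

# The output is untouched, matching what the removed `continue` produced.
assert torch.equal(final_hidden_states, torch.zeros_like(hidden_states))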
src/transformers/utils/fx.py
@@ -141,12 +141,16 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
     "marian",
     "mbart",
     "megatron-bert",
+    "mistral",
+    "mixtral",
     "mobilebert",
     "mt5",
     "nezha",
     "opt",
     "pegasus",
     "plbart",
+    "qwen2",
+    "qwen2_moe",
     "resnet",
     "roberta",
     "segformer",
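With the four names registered, the Mistral family can go through `transformers.utils.fx.symbolic_trace`, which drives `HFTracer` under the hood. A usage sketch with a deliberately tiny random-weight config (the sizes are illustrative, not recommended settings):

from transformers import MistralConfig, MistralForCausalLM
from transformers.utils.fx import symbolic_trace

# Tiny config so the sketch runs quickly on CPU with random weights.
config = MistralConfig(
    vocab_size=1024,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
)
model = MistralForCausalLM(config)

traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
print(traced.graph)  # the captured torch.fx graph of the forward pass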
src/transformers/utils/fx.py
@@ -758,6 +762,7 @@ class HFTracer(Tracer):
         "tensor",
         "clamp",
         "finfo",
+        "tril",
     ]
     supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
 
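`HFTracer` temporarily swaps the functions named in this list for tracer-aware wrappers so that calls like `torch.tril` (used to build Mistral's sliding-window causal mask) and `torch.finfo(dtype).min` keep working on proxied inputs. The rendered diff lost its +/- markers; since the sliding-window mask is what is new for Mistral, the appended entry here is presumably `"tril"`. A sketch of the general swap-and-restore pattern (illustrative only, not the HFTracer internals):

import contextlib
import torch

@contextlib.contextmanager
def patch_torch_fn(name, make_wrapper):
    """Temporarily replace a torch-level function, restoring it on exit."""
    original = getattr(torch, name)
    setattr(torch, name, make_wrapper(original))
    try:
        yield
    finally:
        setattr(torch, name, original)

def logging_wrapper(original):
    def wrapper(*args, **kwargs):
        print(f"intercepted torch call with args: {args}")
        return original(*args, **kwargs)
    return wrapper

# Every torch.tril call inside the block is observed before being forwarded.
with patch_torch_fn("tril", logging_wrapper):
    context_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool), diagonal=-1)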
tests/models/mistral/test_modeling_mistral.py
@@ -303,6 +303,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True
 
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
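`fx_compatible = True` opts the model into the common fx tests in `ModelTesterMixin`, which trace the model and compare traced outputs against eager ones; the three hunks below flip the same flag for Mixtral, Qwen2, and Qwen2Moe. A hedged sketch of that kind of equivalence check (tiny illustrative config again; the traced module is assumed here to return a dict-like output, as in the common tests):

import torch
from transformers import MistralConfig, MistralForCausalLM
from transformers.utils.fx import symbolic_trace

config = MistralConfig(
    vocab_size=1024, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
)
model = MistralForCausalLM(config).eval()
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)
with torch.no_grad():
    eager_logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    traced_logits = traced(input_ids=input_ids, attention_mask=attention_mask)["logits"]

# Tracing must not change the numerics.
torch.testing.assert_close(eager_logits, traced_logits)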
tests/models/mixtral/test_modeling_mixtral.py
@@ -302,6 +302,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True
 
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
tests/models/qwen2/test_modeling_qwen2.py
@@ -313,6 +313,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True
 
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
tests/models/qwen2_moe/test_modeling_qwen2_moe.py
@@ -342,6 +342,7 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True
 
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(