commplying with isort
This commit is contained in:
parent
db2a3b2e01
commit
5c8e5b3709
|
@ -17,13 +17,13 @@ For instance, once the a model from the :class:`~emmental.MaskedBertForSequenceC
|
|||
as a standard :class:`~transformers.BertForSequenceClassification`.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from emmental.modules import MagnitudeBinarizer, TopKBinarizer, ThresholdBinarizer
|
||||
from emmental.modules import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
|
||||
|
||||
|
||||
def main(args):
|
||||
|
@ -40,13 +40,13 @@ def main(args):
|
|||
for name, tensor in model.items():
|
||||
if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
|
||||
pruned_model[name] = tensor
|
||||
print(f"Pruned layer {name}")
|
||||
print(f"Copied layer {name}")
|
||||
elif "classifier" in name or "qa_output" in name:
|
||||
pruned_model[name] = tensor
|
||||
print(f"Pruned layer {name}")
|
||||
print(f"Copied layer {name}")
|
||||
elif "bias" in name:
|
||||
pruned_model[name] = tensor
|
||||
print(f"Pruned layer {name}")
|
||||
print(f"Copied layer {name}")
|
||||
else:
|
||||
if pruning_method == "magnitude":
|
||||
mask = MagnitudeBinarizer.apply(inputs=tensor, threshold=threshold)
|
||||
|
|
|
@ -15,12 +15,12 @@
|
|||
Count remaining (non-zero) weights in the encoder (i.e. the transformer layers).
|
||||
Sparsity and remaining weights levels are equivalent: sparsity % = 100 - remaining weights %.
|
||||
"""
|
||||
import os
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from emmental.modules import TopKBinarizer, ThresholdBinarizer
|
||||
from emmental.modules import ThresholdBinarizer, TopKBinarizer
|
||||
|
||||
|
||||
def main(args):
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
from .modules import *
|
||||
|
||||
from .configuration_bert_masked import MaskedBertConfig
|
||||
|
||||
from .modeling_bert_masked import (
|
||||
MaskedBertModel,
|
||||
MaskedBertForMultipleChoice,
|
||||
MaskedBertForQuestionAnswering,
|
||||
MaskedBertForSequenceClassification,
|
||||
MaskedBertForTokenClassification,
|
||||
MaskedBertForMultipleChoice,
|
||||
MaskedBertModel,
|
||||
)
|
||||
from .modules import *
|
||||
|
|
|
@ -19,8 +19,9 @@ and adapts it to the specificities of MaskedBert (`pruning_method`, `mask_init`
|
|||
|
||||
import logging
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -26,13 +26,16 @@ import torch
|
|||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from emmental import MaskedBertConfig, MaskedLinear
|
||||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from transformers.modeling_bert import (
|
||||
ACT2FN,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
BertLayerNorm,
|
||||
load_tf_weights_in_bert,
|
||||
)
|
||||
from transformers.modeling_utils import PreTrainedModel, prune_linear_layer
|
||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from transformers.modeling_bert import load_tf_weights_in_bert, ACT2FN, BertLayerNorm
|
||||
|
||||
from emmental import MaskedLinear
|
||||
from emmental import MaskedBertConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
from .binarizer import ThresholdBinarizer, TopKBinarizer, MagnitudeBinarizer
|
||||
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
|
||||
from .masked_nn import MaskedLinear
|
||||
|
|
|
@ -19,14 +19,14 @@ the weight matrix to prune a portion of the weights.
|
|||
The pruned weight matrix is then multiplied against the inputs (and if necessary, the bias is added).
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn import init
|
||||
|
||||
import math
|
||||
|
||||
from .binarizer import ThresholdBinarizer, TopKBinarizer, MagnitudeBinarizer
|
||||
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
|
||||
|
||||
|
||||
class MaskedLinear(nn.Linear):
|
||||
|
|
|
@ -24,12 +24,13 @@ import random
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from emmental import MaskedBertConfig, MaskedBertForSequenceClassification
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
|
@ -43,7 +44,6 @@ from transformers import glue_convert_examples_to_features as convert_examples_t
|
|||
from transformers import glue_output_modes as output_modes
|
||||
from transformers import glue_processors as processors
|
||||
|
||||
from emmental import MaskedBertConfig, MaskedBertForSequenceClassification
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
|
|
@ -25,12 +25,13 @@ import timeit
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from emmental import MaskedBertConfig, MaskedBertForQuestionAnswering
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
|
@ -48,8 +49,6 @@ from transformers.data.metrics.squad_metrics import (
|
|||
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor
|
||||
|
||||
|
||||
from emmental import MaskedBertConfig, MaskedBertForQuestionAnswering
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
|
|
Loading…
Reference in New Issue