diff --git a/examples/movement-pruning/Saving_PruneBERT.ipynb b/examples/movement-pruning/Saving_PruneBERT.ipynb
index 01fcd3257c..b9ce4bb892 100644
--- a/examples/movement-pruning/Saving_PruneBERT.ipynb
+++ b/examples/movement-pruning/Saving_PruneBERT.ipynb
@@ -18,7 +18,9 @@
"\n",
"We experiment with a question answering model with only 6% of total remaining weights in the encoder (previously obtained with movement pruning). **We are able to reduce the memory size of the encoder from 340MB (original dense BERT) to 11MB**, which fits on a [91' floppy disk](https://en.wikipedia.org/wiki/Floptical)!\n",
"\n",
- ""
+ "\n",
+ "\n",
+ "*Note: this notebook is compatible with `torch>=1.5.0` If you are using, `torch==1.4.0`, please refer to [this previous version of the notebook](https://github.com/huggingface/transformers/commit/b11386e158e86e62d4041eabd86d044cd1695737).*"
]
},
{
@@ -67,10 +69,7 @@
"source": [
"# Load fine-pruned model and quantize the model\n",
"\n",
- "model_path = \"serialization_dir/bert-base-uncased/92/squad/l1\"\n",
- "model_name = \"bertarized_l1_with_distil_0._0.1_1_2_l1_1100._3e-5_1e-2_sigmoied_threshold_constant_0._10_epochs\"\n",
- "\n",
- "model = BertForQuestionAnswering.from_pretrained(os.path.join(model_path, model_name))\n",
+ "model = BertForQuestionAnswering.from_pretrained(\"huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad\")\n",
"model.to('cpu')\n",
"\n",
"quantized_model = torch.quantization.quantize_dynamic(\n",
@@ -196,7 +195,7 @@
"\n",
"elementary_qtz_st = {}\n",
"for name, param in qtz_st.items():\n",
- " if param.is_quantized:\n",
+ " if \"dtype\" not in name and param.is_quantized:\n",
" print(\"Decompose quantization for\", name)\n",
" # We need to extract the scale, the zero_point and the int_repr for the quantized tensor and modules\n",
" scale = param.q_scale() # torch.tensor(1,) - float32\n",
@@ -221,6 +220,17 @@
{
"cell_type": "code",
"execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create mapping from torch.dtype to string description (we could also used an int8 instead of string)\n",
+ "str_2_dtype = {\"qint8\": torch.qint8}\n",
+ "dtype_2_str = {torch.qint8: \"qint8\"}\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
"metadata": {
"scrolled": true
},
@@ -245,7 +255,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -266,9 +276,10 @@
"Skip bert.pooler.dense._packed_params.weight.int_repr.indices\n",
"Skip bert.pooler.dense._packed_params.weight.int_repr.shape\n",
"Skip bert.pooler.dense._packed_params.bias\n",
+ "Skip bert.pooler.dense._packed_params.dtype\n",
"\n",
- "Encoder Size (MB) - Dense: 340.25\n",
- "Encoder Size (MB) - Sparse & Quantized: 11.27\n"
+ "Encoder Size (MB) - Dense: 340.26\n",
+ "Encoder Size (MB) - Sparse & Quantized: 11.28\n"
]
}
],
@@ -300,10 +311,14 @@
"\n",
" elif type(param) == float or type(param) == int or type(param) == tuple:\n",
" # float - tensor _packed_params.weight.scale\n",
- " # int - tensor_packed_params.weight.zero_point\n",
+ " # int - tensor _packed_params.weight.zero_point\n",
" # tuple - tensor _packed_params.weight.shape\n",
" hf.attrs[name] = param\n",
"\n",
+ " elif type(param) == torch.dtype:\n",
+ " # dtype - tensor _packed_params.dtype\n",
+ " hf.attrs[name] = dtype_2_str[param]\n",
+ " \n",
" else:\n",
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
"\n",
@@ -319,7 +334,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -327,7 +342,7 @@
"output_type": "stream",
"text": [
"\n",
- "Size (MB): 99.39\n"
+ "Size (MB): 99.41\n"
]
}
],
@@ -363,10 +378,15 @@
" # tuple - tensor _packed_params.weight.shape\n",
" hf.attrs[name] = param\n",
"\n",
+ " elif type(param) == torch.dtype:\n",
+ " # dtype - tensor _packed_params.dtype\n",
+ " hf.attrs[name] = dtype_2_str[param]\n",
+ " \n",
" else:\n",
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
"\n",
"\n",
+ "\n",
"with open('dbg/metadata.json', 'w') as f:\n",
" f.write(json.dumps(qtz_st._metadata)) \n",
"\n",
@@ -383,7 +403,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -406,6 +426,8 @@
" attr_param = int(attr_param)\n",
" else:\n",
" attr_param = torch.tensor(attr_param)\n",
+ " elif \".dtype\" in attr_name:\n",
+ " attr_param = str_2_dtype[attr_param]\n",
" reconstructed_elementary_qtz_st[attr_name] = attr_param\n",
" # print(f\"Unpack {attr_name}\")\n",
" \n",
@@ -428,7 +450,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -451,7 +473,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -487,7 +509,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -517,7 +539,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -526,7 +548,7 @@
""
]
},
- "execution_count": 12,
+ "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -553,7 +575,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{