diff --git a/examples/movement-pruning/Saving_PruneBERT.ipynb b/examples/movement-pruning/Saving_PruneBERT.ipynb
index 01fcd3257c..b9ce4bb892 100644
--- a/examples/movement-pruning/Saving_PruneBERT.ipynb
+++ b/examples/movement-pruning/Saving_PruneBERT.ipynb
@@ -18,7 +18,9 @@
     "\n",
     "We experiment with a question answering model with only 6% of total remaining weights in the encoder (previously obtained with movement pruning). **We are able to reduce the memory size of the encoder from 340MB (original dense BERT) to 11MB**, which fits on a [91' floppy disk](https://en.wikipedia.org/wiki/Floptical)!\n",
     "\n",
-    ""
+    "\n",
+    "\n",
+    "*Note: this notebook is compatible with `torch>=1.5.0`. If you are using `torch==1.4.0`, please refer to [this previous version of the notebook](https://github.com/huggingface/transformers/commit/b11386e158e86e62d4041eabd86d044cd1695737).*"
    ]
   },
   {
@@ -67,10 +69,7 @@
    "source": [
     "# Load fine-pruned model and quantize the model\n",
     "\n",
-    "model_path = \"serialization_dir/bert-base-uncased/92/squad/l1\"\n",
-    "model_name = \"bertarized_l1_with_distil_0._0.1_1_2_l1_1100._3e-5_1e-2_sigmoied_threshold_constant_0._10_epochs\"\n",
-    "\n",
-    "model = BertForQuestionAnswering.from_pretrained(os.path.join(model_path, model_name))\n",
+    "model = BertForQuestionAnswering.from_pretrained(\"huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad\")\n",
     "model.to('cpu')\n",
     "\n",
     "quantized_model = torch.quantization.quantize_dynamic(\n",
@@ -196,7 +195,7 @@
     "\n",
     "elementary_qtz_st = {}\n",
     "for name, param in qtz_st.items():\n",
-    "    if param.is_quantized:\n",
+    "    if \"dtype\" not in name and param.is_quantized:\n",
     "        print(\"Decompose quantization for\", name)\n",
     "        # We need to extract the scale, the zero_point and the int_repr for the quantized tensor and modules\n",
     "        scale = param.q_scale() # torch.tensor(1,) - float32\n",
@@ -221,6 +220,17 @@
   {
    "cell_type": "code",
    "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create mapping from torch.dtype to string description (we could also use an int8 instead of a string)\n",
+    "str_2_dtype = {\"qint8\": torch.qint8}\n",
+    "dtype_2_str = {torch.qint8: \"qint8\"}\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
    "metadata": {
     "scrolled": true
    },
@@ -245,7 +255,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
     {
@@ -266,9 +276,10 @@
       "Skip bert.pooler.dense._packed_params.weight.int_repr.indices\n",
       "Skip bert.pooler.dense._packed_params.weight.int_repr.shape\n",
       "Skip bert.pooler.dense._packed_params.bias\n",
+      "Skip bert.pooler.dense._packed_params.dtype\n",
       "\n",
-      "Encoder Size (MB) - Dense: 340.25\n",
-      "Encoder Size (MB) - Sparse & Quantized: 11.27\n"
+      "Encoder Size (MB) - Dense: 340.26\n",
+      "Encoder Size (MB) - Sparse & Quantized: 11.28\n"
      ]
     }
    ],
@@ -300,10 +311,14 @@
     "\n",
     "    elif type(param) == float or type(param) == int or type(param) == tuple:\n",
     "        # float - tensor _packed_params.weight.scale\n",
-    "        # int - tensor_packed_params.weight.zero_point\n",
+    "        # int - tensor _packed_params.weight.zero_point\n",
     "        # tuple - tensor _packed_params.weight.shape\n",
     "        hf.attrs[name] = param\n",
     "\n",
+    "    elif type(param) == torch.dtype:\n",
+    "        # dtype - tensor _packed_params.dtype\n",
+    "        hf.attrs[name] = dtype_2_str[param]\n",
+    "    \n",
     "    else:\n",
     "        hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
     "\n",
@@ -319,7 +334,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [
     {
@@ -327,7 +342,7 @@
      "output_type": "stream",
      "text": [
       "\n",
-      "Size (MB): 99.39\n"
+      "Size (MB): 99.41\n"
      ]
     }
    ],
@@ -363,10 +378,15 @@
     "        # tuple - tensor _packed_params.weight.shape\n",
     "        hf.attrs[name] = param\n",
     "\n",
+    "    elif type(param) == torch.dtype:\n",
+    "        # dtype - tensor _packed_params.dtype\n",
+    "        hf.attrs[name] = dtype_2_str[param]\n",
+    "    \n",
     "    else:\n",
     "        hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
     "\n",
     "\n",
+    "\n",
     "with open('dbg/metadata.json', 'w') as f:\n",
     "    f.write(json.dumps(qtz_st._metadata)) \n",
     "\n",
@@ -383,7 +403,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -406,6 +426,8 @@
     "            attr_param = int(attr_param)\n",
     "        else:\n",
     "            attr_param = torch.tensor(attr_param)\n",
+    "    elif \".dtype\" in attr_name:\n",
+    "        attr_param = str_2_dtype[attr_param]\n",
     "    reconstructed_elementary_qtz_st[attr_name] = attr_param\n",
     "    # print(f\"Unpack {attr_name}\")\n",
     "    \n",
@@ -428,7 +450,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -451,7 +473,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -487,7 +509,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -517,7 +539,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 13,
    "metadata": {},
    "outputs": [
     {
@@ -526,7 +548,7 @@
        ""
       ]
      },
-     "execution_count": 12,
+     "execution_count": 13,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -553,7 +575,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 14,
    "metadata": {},
    "outputs": [
     {