update `mvmt-pruning/saving_prunebert` (updating torch to 1.5)

This commit is contained in:
VictorSanh 2020-06-11 19:42:45 +00:00
parent caf3746678
commit 473808da0d
1 changed files with 41 additions and 19 deletions

View File

@ -18,7 +18,9 @@
"\n",
"We experiment with a question answering model with only 6% of total remaining weights in the encoder (previously obtained with movement pruning). **We are able to reduce the memory size of the encoder from 340MB (original dense BERT) to 11MB**, which fits on a [91' floppy disk](https://en.wikipedia.org/wiki/Floptical)!\n",
"\n",
"<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/0/00/Floptical_disk_21MB.jpg/440px-Floptical_disk_21MB.jpg\" width=\"200\">"
"<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/0/00/Floptical_disk_21MB.jpg/440px-Floptical_disk_21MB.jpg\" width=\"200\">\n",
"\n",
"*Note: this notebook is compatible with `torch>=1.5.0` If you are using, `torch==1.4.0`, please refer to [this previous version of the notebook](https://github.com/huggingface/transformers/commit/b11386e158e86e62d4041eabd86d044cd1695737).*"
]
},
{
@ -67,10 +69,7 @@
"source": [
"# Load fine-pruned model and quantize the model\n",
"\n",
"model_path = \"serialization_dir/bert-base-uncased/92/squad/l1\"\n",
"model_name = \"bertarized_l1_with_distil_0._0.1_1_2_l1_1100._3e-5_1e-2_sigmoied_threshold_constant_0._10_epochs\"\n",
"\n",
"model = BertForQuestionAnswering.from_pretrained(os.path.join(model_path, model_name))\n",
"model = BertForQuestionAnswering.from_pretrained(\"huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad\")\n",
"model.to('cpu')\n",
"\n",
"quantized_model = torch.quantization.quantize_dynamic(\n",
@ -196,7 +195,7 @@
"\n",
"elementary_qtz_st = {}\n",
"for name, param in qtz_st.items():\n",
" if param.is_quantized:\n",
" if \"dtype\" not in name and param.is_quantized:\n",
" print(\"Decompose quantization for\", name)\n",
" # We need to extract the scale, the zero_point and the int_repr for the quantized tensor and modules\n",
" scale = param.q_scale() # torch.tensor(1,) - float32\n",
@ -221,6 +220,17 @@
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Create mapping from torch.dtype to string description (we could also used an int8 instead of string)\n",
"str_2_dtype = {\"qint8\": torch.qint8}\n",
"dtype_2_str = {torch.qint8: \"qint8\"}\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": true
},
@ -245,7 +255,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@ -266,9 +276,10 @@
"Skip bert.pooler.dense._packed_params.weight.int_repr.indices\n",
"Skip bert.pooler.dense._packed_params.weight.int_repr.shape\n",
"Skip bert.pooler.dense._packed_params.bias\n",
"Skip bert.pooler.dense._packed_params.dtype\n",
"\n",
"Encoder Size (MB) - Dense: 340.25\n",
"Encoder Size (MB) - Sparse & Quantized: 11.27\n"
"Encoder Size (MB) - Dense: 340.26\n",
"Encoder Size (MB) - Sparse & Quantized: 11.28\n"
]
}
],
@ -300,10 +311,14 @@
"\n",
" elif type(param) == float or type(param) == int or type(param) == tuple:\n",
" # float - tensor _packed_params.weight.scale\n",
" # int - tensor_packed_params.weight.zero_point\n",
" # int - tensor _packed_params.weight.zero_point\n",
" # tuple - tensor _packed_params.weight.shape\n",
" hf.attrs[name] = param\n",
"\n",
" elif type(param) == torch.dtype:\n",
" # dtype - tensor _packed_params.dtype\n",
" hf.attrs[name] = dtype_2_str[param]\n",
" \n",
" else:\n",
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
"\n",
@ -319,7 +334,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [
{
@ -327,7 +342,7 @@
"output_type": "stream",
"text": [
"\n",
"Size (MB): 99.39\n"
"Size (MB): 99.41\n"
]
}
],
@ -363,10 +378,15 @@
" # tuple - tensor _packed_params.weight.shape\n",
" hf.attrs[name] = param\n",
"\n",
" elif type(param) == torch.dtype:\n",
" # dtype - tensor _packed_params.dtype\n",
" hf.attrs[name] = dtype_2_str[param]\n",
" \n",
" else:\n",
" hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n",
"\n",
"\n",
"\n",
"with open('dbg/metadata.json', 'w') as f:\n",
" f.write(json.dumps(qtz_st._metadata)) \n",
"\n",
@ -383,7 +403,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@ -406,6 +426,8 @@
" attr_param = int(attr_param)\n",
" else:\n",
" attr_param = torch.tensor(attr_param)\n",
" elif \".dtype\" in attr_name:\n",
" attr_param = str_2_dtype[attr_param]\n",
" reconstructed_elementary_qtz_st[attr_name] = attr_param\n",
" # print(f\"Unpack {attr_name}\")\n",
" \n",
@ -428,7 +450,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@ -451,7 +473,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@ -487,7 +509,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@ -517,7 +539,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@ -526,7 +548,7 @@
"<All keys matched successfully>"
]
},
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@ -553,7 +575,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {},
"outputs": [
{