fix multi-gpu squad loss
This commit is contained in:
parent
955cee33a5
commit
2f4765d3ed
|
@ -1130,11 +1130,11 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 27,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2018-11-05T12:14:39.425771Z",
|
||||
"start_time": "2018-11-05T12:14:39.302205Z"
|
||||
"end_time": "2018-11-05T12:43:46.437506Z",
|
||||
"start_time": "2018-11-05T12:43:46.396007Z"
|
||||
},
|
||||
"code_folding": []
|
||||
},
|
||||
|
@ -1144,13 +1144,13 @@
|
|||
"all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)\n",
|
||||
"all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)\n",
|
||||
"all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)\n",
|
||||
"all_start_positions = torch.tensor([f.start_position for f in eval_features], dtype=torch.long)\n",
|
||||
"all_end_positions = torch.tensor([f.end_position for f in eval_features], dtype=torch.long)\n",
|
||||
"all_start_positions = torch.tensor([[f.start_position] for f in eval_features], dtype=torch.long)\n",
|
||||
"all_end_positions = torch.tensor([[f.end_position] for f in eval_features], dtype=torch.long)\n",
|
||||
"\n",
|
||||
"eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,\n",
|
||||
" all_start_positions, all_end_positions)\n",
|
||||
"eval_sampler = SequentialSampler(eval_data)\n",
|
||||
"eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=1)\n",
|
||||
"eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=2)\n",
|
||||
"\n",
|
||||
"model.eval()\n",
|
||||
"None"
|
||||
|
@ -1158,11 +1158,11 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 28,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2018-11-05T12:14:39.463511Z",
|
||||
"start_time": "2018-11-05T12:14:39.427506Z"
|
||||
"end_time": "2018-11-05T12:43:47.166396Z",
|
||||
"start_time": "2018-11-05T12:43:47.132531Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
|
@ -1170,16 +1170,16 @@
|
|||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[torch.Size([1, 384]), torch.Size([1, 384]), torch.Size([1, 384]), torch.Size([1]), torch.Size([1])]\n"
|
||||
"[torch.Size([2, 384]), torch.Size([2, 384]), torch.Size([2, 384]), torch.Size([2, 1]), torch.Size([2, 1])]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"torch.Size([1])"
|
||||
"torch.Size([2, 1])"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -1193,11 +1193,11 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 29,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2018-11-05T12:14:40.503188Z",
|
||||
"start_time": "2018-11-05T12:14:39.465446Z"
|
||||
"end_time": "2018-11-05T12:43:52.988437Z",
|
||||
"start_time": "2018-11-05T12:43:50.683929Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
|
@ -1205,7 +1205,24 @@
|
|||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluating: 0%| | 0/270 [00:00<?, ?it/s]\n"
|
||||
"Evaluating: 0%| | 0/135 [00:00<?, ?it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "RuntimeError",
|
||||
"evalue": "multi-target not supported at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/THNN/generic/ClassNLLCriterion.c:21",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m<ipython-input-29-c9a3d92f4c7e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mend_positions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mend_positions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mtotal_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mstart_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_logits\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msegment_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart_positions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_positions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0munique_id\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"unique_ids\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 475\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 476\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 477\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 478\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/Documents/Thomas/Code/HF/BERT/pytorch-pretrained-BERT/modeling.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, token_type_ids, attention_mask, start_positions, end_positions)\u001b[0m\n\u001b[1;32m 457\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstart_positions\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mend_positions\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 458\u001b[0m \u001b[0mloss_fct\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCrossEntropyLoss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 459\u001b[0;31m \u001b[0mstart_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss_fct\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstart_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart_positions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 460\u001b[0m \u001b[0mend_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss_fct\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mend_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_positions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 461\u001b[0m \u001b[0mtotal_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mstart_loss\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mend_loss\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 475\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 476\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 477\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 478\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, target)\u001b[0m\n\u001b[1;32m 860\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 861\u001b[0m return F.cross_entropy(input, target, weight=self.weight,\n\u001b[0;32m--> 862\u001b[0;31m ignore_index=self.ignore_index, reduction=self.reduction)\n\u001b[0m\u001b[1;32m 863\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 864\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mcross_entropy\u001b[0;34m(input, target, weight, size_average, ignore_index, reduce, reduction)\u001b[0m\n\u001b[1;32m 1548\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msize_average\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mreduce\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1549\u001b[0m \u001b[0mreduction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlegacy_get_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize_average\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1550\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mnll_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1551\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1552\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mnll_loss\u001b[0;34m(input, target, weight, size_average, ignore_index, reduce, reduction)\u001b[0m\n\u001b[1;32m 1405\u001b[0m .format(input.size(0), target.size(0)))\n\u001b[1;32m 1406\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1407\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnll_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_enum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1408\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mdim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1409\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnll_loss2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_enum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mRuntimeError\u001b[0m: multi-target not supported at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/THNN/generic/ClassNLLCriterion.c:21"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
@ -455,6 +455,8 @@ class BertForQuestionAnswering(nn.Module):
|
|||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
start_positions = start_positions.squeeze(-1) # If we are on multi-GPU, split add a dimension
|
||||
end_positions = end_positions.squeeze(-1)
|
||||
loss_fct = CrossEntropyLoss()
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
|
|
Loading…
Reference in New Issue