pruning should keep on device

This commit is contained in:
thomwolf 2019-06-19 22:23:12 +02:00
parent e4b46d86ce
commit 7f00a36e27
2 changed files with 2 additions and 2 deletions

View File

@ -80,7 +80,7 @@ def prune_linear_layer(layer, index, dim=0):
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None)
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True

View File

@ -55,7 +55,7 @@ def prune_conv1d_layer(layer, index, dim=1):
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = Conv1D(new_size[1], new_size[0])
new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True