pruning should keep on device
This commit is contained in:
parent
e4b46d86ce
commit
7f00a36e27
|
@ -80,7 +80,7 @@ def prune_linear_layer(layer, index, dim=0):
|
|||
b = layer.bias[index].clone().detach()
|
||||
new_size = list(layer.weight.size())
|
||||
new_size[dim] = len(index)
|
||||
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None)
|
||||
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
|
||||
new_layer.weight.requires_grad = False
|
||||
new_layer.weight.copy_(W.contiguous())
|
||||
new_layer.weight.requires_grad = True
|
||||
|
|
|
@ -55,7 +55,7 @@ def prune_conv1d_layer(layer, index, dim=1):
|
|||
b = layer.bias[index].clone().detach()
|
||||
new_size = list(layer.weight.size())
|
||||
new_size[dim] = len(index)
|
||||
new_layer = Conv1D(new_size[1], new_size[0])
|
||||
new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
|
||||
new_layer.weight.requires_grad = False
|
||||
new_layer.weight.copy_(W.contiguous())
|
||||
new_layer.weight.requires_grad = True
|
||||
|
|
Loading…
Reference in New Issue