bump candle to 0.3.1 and conv_transpose_1d (#977)

This commit is contained in:
Louis Fortier-Dubois 2023-11-21 09:13:19 -05:00 committed by GitHub
parent cdf54d0b40
commit 4711db0e18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 3 deletions

View File

@ -18,8 +18,8 @@ derive-new = { workspace = true }
burn-tensor = { path = "../burn-tensor", version = "0.11.0", default-features = false }
half = { workspace = true }
# TODO remove pinned version ("=") once candle-core is updated to 0.3.1
candle-core = { version = "=0.3.0" }
candle-core = { version = "0.3.1" }
[dev-dependencies]
burn-autodiff = { path = "../burn-autodiff", version = "0.11.0", default-features = false, features = [

View File

@ -50,6 +50,7 @@ impl From<candle_core::Device> for CandleDevice {
match device.location() {
DeviceLocation::Cpu => CandleDevice::Cpu,
DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id),
DeviceLocation::Metal => panic!("Metal unsupported"),
}
}
}

View File

@ -83,7 +83,26 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
bias: Option<FloatTensor<Self, 1>>,
options: ConvTransposeOptions<1>,
) -> FloatTensor<Self, 3> {
panic!("Candle does not support conv_transpose1d")
assert!(
options.groups == 1,
"Candle does not support groups in transposed convolutions"
);
let conv_transpose = x
.tensor
.conv_transpose1d(
&weight.tensor,
options.padding[0],
options.padding_out[0],
options.stride[0],
options.dilation[0],
)
.unwrap();
CandleTensor::new(match bias {
Some(bias) => conv_transpose
.broadcast_add(&bias.tensor.unsqueeze(0).unwrap().unsqueeze(2).unwrap())
.unwrap(),
None => conv_transpose,
})
}
fn conv_transpose2d(