From 4711db0e182427c34e08cb9b9135b7a7685d2dbc Mon Sep 17 00:00:00 2001
From: Louis Fortier-Dubois
Date: Tue, 21 Nov 2023 09:13:19 -0500
Subject: [PATCH] bump candle to 0.3.1 and conv_transpose_1d (#977)

---
 burn-candle/Cargo.toml        |  4 ++--
 burn-candle/src/backend.rs    |  1 +
 burn-candle/src/ops/module.rs | 21 ++++++++++++++++++++-
 3 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/burn-candle/Cargo.toml b/burn-candle/Cargo.toml
index 74c59c8a3..bad485f5f 100644
--- a/burn-candle/Cargo.toml
+++ b/burn-candle/Cargo.toml
@@ -18,8 +18,8 @@ derive-new = { workspace = true }
 burn-tensor = { path = "../burn-tensor", version = "0.11.0", default-features = false }
 half = { workspace = true }
 
-# TODO remove pinned version ("=") once candle-core is updated to 0.3.1
-candle-core = { version = "=0.3.0" }
+candle-core = { version = "0.3.1" }
+
 
 [dev-dependencies]
 burn-autodiff = { path = "../burn-autodiff", version = "0.11.0", default-features = false, features = [
diff --git a/burn-candle/src/backend.rs b/burn-candle/src/backend.rs
index c2bec2417..79bc1517d 100644
--- a/burn-candle/src/backend.rs
+++ b/burn-candle/src/backend.rs
@@ -50,6 +50,7 @@ impl From<candle_core::Device> for CandleDevice {
         match device.location() {
             DeviceLocation::Cpu => CandleDevice::Cpu,
             DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id),
+            DeviceLocation::Metal => panic!("Metal unsupported"),
         }
     }
 }
diff --git a/burn-candle/src/ops/module.rs b/burn-candle/src/ops/module.rs
index 0a169277f..b0b1c2264 100644
--- a/burn-candle/src/ops/module.rs
+++ b/burn-candle/src/ops/module.rs
@@ -83,7 +83,26 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Candle<F, I>> for Candle<F, I> {
         bias: Option<FloatTensor<Self, 1>>,
         options: ConvTransposeOptions<1>,
     ) -> FloatTensor<Self, 3> {
-        panic!("Candle does not support conv_transpose1d")
+        assert!(
+            options.groups == 1,
+            "Candle does not support groups in transposed convolutions"
+        );
+        let conv_transpose = x
+            .tensor
+            .conv_transpose1d(
+                &weight.tensor,
+                options.padding[0],
+                options.padding_out[0],
+                options.stride[0],
+                options.dilation[0],
+            )
+            .unwrap();
+        CandleTensor::new(match bias {
+            Some(bias) => conv_transpose
+                .broadcast_add(&bias.tensor.unsqueeze(0).unwrap().unsqueeze(2).unwrap())
+                .unwrap(),
+            None => conv_transpose,
+        })
     }
 
     fn conv_transpose2d(
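
Note (not part of the patch): a minimal standalone sketch of the candle-core 0.3.1 calls the new conv_transpose1d implementation relies on. The example program, tensor shapes, and values are made up for illustration, and the kernel layout (channels_in, channels_out, kernel_size) is assumed from candle's PyTorch-style convention.

// Illustrative sketch only; requires candle-core = "0.3.1".
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // Input laid out as (batch, channels_in, length); kernel assumed to be
    // (channels_in, channels_out, kernel_size).
    let x = Tensor::randn(0f32, 1f32, (2, 4, 10), &dev)?;
    let weight = Tensor::randn(0f32, 1f32, (4, 8, 3), &dev)?;
    let bias = Tensor::randn(0f32, 1f32, 8, &dev)?;

    // Same candle-core call the patch makes, with the same argument order:
    // (kernel, padding, output_padding, stride, dilation). There is no groups
    // parameter, which is why the patch asserts options.groups == 1.
    let out = x.conv_transpose1d(&weight, 0, 0, 1, 1)?;

    // The (channels_out,) bias is reshaped to (1, channels_out, 1) so it broadcasts
    // over batch and length, mirroring the two unsqueeze calls in the patch.
    let out = out.broadcast_add(&bias.unsqueeze(0)?.unsqueeze(2)?)?;

    // With stride 1 and no padding/dilation: length_out = length_in + kernel_size - 1 = 12.
    assert_eq!(out.dims(), &[2, 8, 12]);
    Ok(())
}

The bias is applied through broadcast_add because candle-core's conv_transpose1d takes no bias argument, which is why the patch unsqueezes the bias before adding it rather than passing it to candle.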