This commit is contained in:
Nathaniel Simard 2024-09-10 12:13:48 -04:00 committed by GitHub
parent 17050db57e
commit d3fbdeaa48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 5 additions and 6 deletions

View File

@ -31,7 +31,7 @@ env:
# Note: It is not possible to define env vars in composite actions.
# To work around this issue we use inputs and define all the env vars here.
RUST_PREVIOUS_VERSION: 1.79.0
RUST_PREVIOUS_VERSION: 1.80.0
# Cargo
CARGO_TERM_COLOR: "always"

View File

@ -41,11 +41,10 @@ fn conv2d_kernel<F: Float>(
let in_channels = weight.shape(1);
let kernel_size_0 = kernel_size_0_unroll.unwrap_or_else(|| weight.shape(2));
let kernel_size_0 = weight.shape(2);
let kernel_size_1 = kernel_size_1_unroll.unwrap_or_else(|| weight.shape(3));
let unroll_1 = kernel_size_1_unroll.is_some();
let b = ABSOLUTE_POS / output.stride(0) % output.shape(0);
let oc = ABSOLUTE_POS / output.stride(1) % output.shape(1);
let oh = ABSOLUTE_POS / output.stride(2) % output.shape(2);
@ -130,7 +129,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
let kernel_1_unroll = if kernel_1 > 8 {
None
} else {
Some(kernel_1.into())
Some(kernel_1 as u32)
};
let out_0 = calculate_conv_output_size(
@ -188,7 +187,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
ScalarArg::new(options.padding[1] as u32),
ScalarArg::new(options.groups as u32),
),
Some(kernel_1 as u32),
kernel_1_unroll,
);
output

View File

@ -9,7 +9,7 @@ name = "burn"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn"
version.workspace = true
rust-version = "1.79"
rust-version = "1.80"
[features]
default = ["burn-core/default", "burn-train?/default", "std"]