diff --git a/Cargo.lock b/Cargo.lock index ccc15f20e..7b58aaf64 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -126,9 +126,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.82" +version = "1.0.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" +checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" [[package]] name = "arboard" @@ -3517,9 +3517,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.36" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] diff --git a/Cargo.toml b/Cargo.toml index 063aef636..ae1025d44 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,7 @@ pretty_assertions = "1.4" proc-macro2 = "1.0.79" protobuf = "3.3" protobuf-codegen = "3.3" -quote = "1.0.36" +quote = "1.0.33" percent-encoding = "2.3.1" r2d2 = "0.8.10" r2d2_sqlite = { version = "0.23.0" } diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index c5bc4b0ab..b13f5486c 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -118,7 +118,7 @@ represent the corresponding Burn Op. | [NegativeLogLikelihoodLoss][110] | ❌ | ❌ | | [NonMaxSuppression][112] | ❌ | ❌ | | [NonZero][113] | ❌ | ❌ | -| [Not][114] | ❌ | ✅ | +| [Not][114] | ✅ | ✅ | | [OneHot][115] | ❌ | ✅ | | [Optional][116] | ❌ | ❌ | | [OptionalGetElement][117] | ❌ | ❌ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index 1a215a62a..6bee747ac 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -10,6 +10,7 @@ fn main() { .input("tests/add/add.onnx") .input("tests/avg_pool2d/avg_pool2d.onnx") .input("tests/batch_norm/batch_norm.onnx") + .input("tests/cast/cast.onnx") .input("tests/clip/clip_opset16.onnx") .input("tests/clip/clip_opset7.onnx") .input("tests/concat/concat.onnx") @@ -32,6 +33,7 @@ fn main() { .input("tests/maxpool2d/maxpool2d.onnx") .input("tests/mul/mul.onnx") .input("tests/neg/neg.onnx") + .input("tests/not/not.onnx") .input("tests/recip/recip.onnx") .input("tests/relu/relu.onnx") .input("tests/leaky_relu/leaky_relu.onnx") diff --git a/crates/burn-import/onnx-tests/tests/cast/cast.onnx b/crates/burn-import/onnx-tests/tests/cast/cast.onnx new file mode 100644 index 000000000..35f3f2042 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/cast/cast.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/cast/cast.py b/crates/burn-import/onnx-tests/tests/cast/cast.py new file mode 100755 index 000000000..fb3940d1e --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/cast/cast.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/cast/cast.onnx + +import torch +import torch.nn as nn + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward( + self, + x_bool, + x_int, + x_float, + x_scalar, + ): + # NOTE: we clone same-type casts for int and bool, otherwise the exporter would + # link other type casts to the output of the bool cast, leading to additional casts + return ( + x_bool.clone().bool(), + x_bool.int(), + x_bool.float(), + x_int.bool(), + x_int.clone().int(), + x_int.float(), + x_float.bool(), + x_float.int(), + x_float.float(), + x_scalar.int(), + ) + + +def main(): + # Set random seed for reproducibility + torch.manual_seed(0) + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + onnx_name = "cast.onnx" + test_bool = torch.ones((2, 1), device=device, dtype=torch.bool) + test_int = torch.ones((2, 1), device=device, dtype=torch.int) + test_float = torch.ones((2, 1), device=device, dtype=torch.float) + test_scalar = torch.ones(1, device=device, dtype=torch.float).squeeze() + test_input = (test_bool, test_int, test_float, test_scalar) + + # NOTE: torch exports logical_not with a cast node even if the input is already bool + # https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset9.py#L2204-L2207 + torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16) + + print(f"Finished exporting model to {onnx_name}") + + # Output some test data for use in the test + print(f"Test input data: {test_input}") + output = model.forward(*test_input) + print(f"Test output data: {output}") + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/onnx-tests/tests/not/not.onnx b/crates/burn-import/onnx-tests/tests/not/not.onnx new file mode 100644 index 000000000..c3d9539b0 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/not/not.onnx @@ -0,0 +1,19 @@ +pytorch2.1.2: +6 + onnx::Cast_0/Cast_output_0/Cast"Cast* +to  + +/Cast_output_02/Not"Not +main_graphZ& + onnx::Cast_0 +  + + + +b +2 +  + + + +B \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/not/not.py b/crates/burn-import/onnx-tests/tests/not/not.py new file mode 100755 index 000000000..d5bcc2c53 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/not/not.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/not/not.onnx + +import torch +import torch.nn as nn + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + return torch.logical_not(x) + + +def main(): + # Set random seed for reproducibility + torch.manual_seed(0) + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + onnx_name = "not.onnx" + test_input = torch.tensor([[[[True, False, True, False]]]], device=device) + + # NOTE: torch exports logical_not with a cast node even if the input is already bool + # https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset9.py#L2204-L2207 + torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16) + + print(f"Finished exporting model to {onnx_name}") + + # Output some test data for use in the test + print(f"Test input data: {test_input}") + output = model.forward(test_input) + print(f"Test output data: {output}") + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index 813e5058e..45c01b9d4 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -17,6 +17,7 @@ include_models!( add, avg_pool2d, batch_norm, + cast, clip_opset16, clip_opset7, concat, @@ -40,6 +41,7 @@ include_models!( maxpool2d, mul, neg, + not, recip, reduce_mean, relu, @@ -64,7 +66,7 @@ mod tests { use super::*; - use burn::tensor::{Data, Int, Shape, Tensor}; + use burn::tensor::{Bool, Data, Int, Shape, Tensor}; use float_cmp::ApproxEq; @@ -854,6 +856,22 @@ mod tests { assert_eq!(output2, expected2); } + #[test] + fn not() { + let device = Default::default(); + let model: not::Model = not::Model::new(&device); + + let input = Tensor::::from_bool( + Data::from([[[[true, false, true, false]]]]), + &device, + ); + + let output = model.forward(input).to_data(); + let expected = Data::from([[[[false, true, false, true]]]]); + + assert_eq!(output, expected); + } + #[test] fn test_model_creation_with_a_default_device() { let device = Default::default(); @@ -908,4 +926,52 @@ mod tests { let output = model.forward(input); assert_eq!(output.shape(), expected_shape); } + + #[test] + fn cast() { + let device = Default::default(); + let model: cast::Model = cast::Model::new(&device); + + let input_bool = + Tensor::::from_bool(Data::from([[true], [true]]), &device); + let input_int = Tensor::::from_ints([[1], [1]], &device); + let input_float = Tensor::::from_floats([[1.], [1.]], &device); + let input_scalar = 1f32; + + let ( + output1, + output2, + output3, + output4, + output5, + output6, + output7, + output8, + output9, + output_scalar, + ) = model.forward( + input_bool.clone(), + input_int.clone(), + input_float.clone(), + input_scalar, + ); + let expected_bool = input_bool.to_data(); + let expected_int = input_int.to_data(); + let expected_float = input_float.to_data(); + let expected_scalar = 1; + + assert_eq!(output1.to_data(), expected_bool); + assert_eq!(output2.to_data(), expected_int); + output3.to_data().assert_approx_eq(&expected_float, 4); + + assert_eq!(output4.to_data(), expected_bool); + assert_eq!(output5.to_data(), expected_int); + output6.to_data().assert_approx_eq(&expected_float, 4); + + assert_eq!(output7.to_data(), expected_bool); + assert_eq!(output8.to_data(), expected_int); + output9.to_data().assert_approx_eq(&expected_float, 4); + + assert_eq!(output_scalar, expected_scalar); + } } diff --git a/crates/burn-import/src/burn/node/unary.rs b/crates/burn-import/src/burn/node/unary.rs index 924c99d5f..d20438314 100644 --- a/crates/burn-import/src/burn/node/unary.rs +++ b/crates/burn-import/src/burn/node/unary.rs @@ -1,5 +1,5 @@ use super::{Node, NodeCodegen}; -use crate::burn::{BurnImports, Scope, ToTokens, Type}; +use crate::burn::{BurnImports, Scope, TensorKind, ToTokens, Type}; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use quote::quote; @@ -20,7 +20,8 @@ pub struct UnaryNode { /// Type of unary node. #[derive(Clone)] pub enum UnaryNodeKind { - Cast, + // Input and output tensor types (required for codegen imports) + Cast(Option, Option), Cos, Erf, Exp, @@ -29,6 +30,7 @@ pub enum UnaryNodeKind { Log, LogSoftmax, Neg, + Not, ReduceMean, Reciprocal, LeakyRelu, @@ -44,7 +46,7 @@ pub enum UnaryNodeKind { impl UnaryNodeKind { pub fn as_str(&self) -> &str { match self { - Self::Cast => "cast", + Self::Cast(..) => "cast", Self::Cos => "cos", Self::Erf => "erf", Self::Exp => "exp", @@ -53,6 +55,7 @@ impl UnaryNodeKind { Self::Log => "log", Self::LogSoftmax => "log_softmax", Self::Neg => "neg", + Self::Not => "not", Self::ReduceMean => "reduce_mean", Self::Reciprocal => "reciprocal", Self::LeakyRelu => "leaky_relu", @@ -120,6 +123,17 @@ impl NodeCodegen for UnaryNode { UnaryNodeKind::Neg => { imports.register("core::ops::Neg"); } + UnaryNodeKind::Not => { + imports.register("burn::tensor::Bool"); + } + UnaryNodeKind::Cast(Some(input_kind), Some(output_kind)) => { + if input_kind == TensorKind::Bool || output_kind == TensorKind::Bool { + imports.register("burn::tensor::Bool"); + } + if input_kind == TensorKind::Int || output_kind == TensorKind::Int { + imports.register("burn::tensor::Int"); + } + } _ => {} } } @@ -217,42 +231,61 @@ impl UnaryNode { Self::new(input, output, UnaryNodeKind::Neg, Rc::new(function)) } + pub(crate) fn not(input: Type, output: Type) -> Self { + // Not ONNX operator is constrained to bool tensors, so no need to check the type. + let function = move |input| quote! { #input.bool_not()}; + Self::new(input, output, UnaryNodeKind::Not, Rc::new(function)) + } + /// Casts the input to the output type. - /// - /// Currently this function only supports the following conversions: - /// 1) scalar -> scalar - /// - /// TODO: Implement the following conversions: - /// 2) tensor int -> tensor float - /// 3) tensor float -> tensor int - /// 4) tensor -> scalar - /// 5) scalar -> tensor pub(crate) fn cast(input: Type, output: Type) -> Self { match (input.clone(), output.clone()) { (Type::Scalar(input_scalar), Type::Scalar(output_scalar)) => { if input_scalar.kind == output_scalar.kind { // If the input and output types are the same, we don't need to cast. - Self::new(input, output, UnaryNodeKind::Cast, Rc::new(|input| input)) + Self::new( + input, + output, + UnaryNodeKind::Cast(None, None), + Rc::new(|input| input), + ) } else { // If the input and output types are different, we need to cast. let ty = output_scalar.ty(); Self::new( input, output, - UnaryNodeKind::Cast, + UnaryNodeKind::Cast(None, None), Rc::new(move |input| quote! { #input as #ty }), ) } } - (Type::Tensor(_input_tensor), Type::Tensor(_output_tensor)) => { - // TODO: Implement this after tensor Int is implemented (@antimora 8/2/2023) - // TODO: If the input is scalar and the output type is a tensor, - // we should generate another code block. (@antimora 8/4/2023) - // Tensor::from_data(Data::from([#input]).convert()).unsqueeze(); - todo!() - } + (Type::Tensor(input_tensor), Type::Tensor(output_tensor)) => { + if input_tensor.kind == output_tensor.kind { + // If the input and output types are the same, we don't need to cast. + Self::new( + input, + output, + UnaryNodeKind::Cast(Some(input_tensor.kind), Some(output_tensor.kind)), + Rc::new(|input| input), + ) + } else { + // If the input and output types are different, we need to cast. + let function = match output_tensor.kind { + TensorKind::Bool => move |input| quote! { #input.bool()}, + TensorKind::Int => move |input| quote! { #input.int()}, + TensorKind::Float => move |input| quote! { #input.float()}, + }; - _ => panic!("output must be a tensor"), + Self::new( + input, + output, + UnaryNodeKind::Cast(Some(input_tensor.kind), Some(output_tensor.kind)), + Rc::new(function), + ) + } + } + _ => panic!("output must be a tensor or scalar"), } } @@ -553,6 +586,51 @@ mod tests { vec!["scalar1".to_string()], vec!["scalar2".to_string()], ); + one_node_graph( + UnaryNode::cast( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_int("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.int(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + one_node_graph( + UnaryNode::cast( + Type::Tensor(TensorType::new_int("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.float(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + one_node_graph( + UnaryNode::cast( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_bool("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.bool(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); } #[test] @@ -687,4 +765,23 @@ mod tests { vec!["tensor2".to_string()], ); } + + #[test] + fn test_unary_codegen_not() { + one_node_graph( + UnaryNode::not( + Type::Tensor(TensorType::new_bool("tensor1", 4)), + Type::Tensor(TensorType::new_bool("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.bool_not(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } } diff --git a/crates/burn-import/src/burn/ty.rs b/crates/burn-import/src/burn/ty.rs index 963ee81ab..82ee06246 100644 --- a/crates/burn-import/src/burn/ty.rs +++ b/crates/burn-import/src/burn/ty.rs @@ -13,7 +13,7 @@ pub struct TensorType { pub shape: Option>, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TensorKind { Int, Float, diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs index b7dd0762d..8e5c4f1d6 100644 --- a/crates/burn-import/src/onnx/dim_inference.rs +++ b/crates/burn-import/src/onnx/dim_inference.rs @@ -38,6 +38,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) { NodeType::MaxPool2d => same_as_input(node), NodeType::Mul => same_as_input(node), NodeType::Neg => same_as_input(node), + NodeType::Not => same_as_input(node), NodeType::Reciprocal => same_as_input(node), NodeType::ReduceMean => reduce_mean_update_outputs(node), NodeType::Relu => same_as_input(node), @@ -135,6 +136,7 @@ fn cast_update_outputs(node: &mut Node) { if node.inputs.len() != 1 { panic!("Cast: multiple inputs are not supported"); } + let input = &mut node.inputs[0]; let output = &mut node.outputs[0]; // Extract cast type and update the output tensor @@ -145,6 +147,7 @@ fn cast_update_outputs(node: &mut Node) { DataType::INT32 => ElementType::Int32, DataType::INT64 => ElementType::Int64, DataType::DOUBLE => ElementType::Float64, + DataType::BOOL => ElementType::Bool, _ => panic!("Cast: unsupported type"), }, _ => panic!("'to' attribute must be an Int64"), @@ -152,19 +155,25 @@ fn cast_update_outputs(node: &mut Node) { None => panic!("Constant node must have a value attribute"), }; - match output.ty.clone() { + match input.ty.clone() { ArgType::Tensor(tensor) => { if tensor.dim == 0 { // treat 0-dim tensor as scalar output.ty = ArgType::Scalar(elem_type); + input.ty = ArgType::Scalar(tensor.elem_type); } else { - todo!("Cast: support casting from different tensor types"); + // Cast input and output are the same shape, but possibly different types + output.ty = ArgType::Tensor(TensorType { + elem_type, + dim: tensor.dim, + shape: tensor.shape.clone(), + }); } } ArgType::Scalar(_scalar) => { output.ty = ArgType::Scalar(elem_type); } - _ => panic!("Cast: only scalar input is valid"), + _ => panic!("Cast: only scalar and tensor inputs are valid"), } } diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 59180cb8b..de1e37d85 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -237,6 +237,7 @@ impl OnnxGraph { NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)), NodeType::MatMul => graph.register(Self::matmul_conversion(node)), NodeType::Neg => graph.register(Self::neg_conversion(node)), + NodeType::Not => graph.register(Self::not_conversion(node)), NodeType::Linear => graph.register(Self::linear_conversion::(node)), NodeType::BatchNormalization => { graph.register(Self::batch_norm_conversion::(node)) @@ -697,6 +698,13 @@ impl OnnxGraph { let output = node.outputs.first().unwrap().to_type(); UnaryNode::neg(input, output) } + + fn not_conversion(node: Node) -> UnaryNode { + let input = node.inputs.first().unwrap().to_type(); + let output = node.outputs.first().unwrap().to_type(); + UnaryNode::not(input, output) + } + fn pow_conversion(node: Node) -> BinaryNode { let lhs = node.inputs.first().unwrap().to_type(); let rhs = node.inputs.get(1).unwrap().to_type(); diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index 3dcdbcc27..41c59221b 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT OR Apache-2.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -anyhow = "1.0.82" +anyhow = "1.0.81" clap = { version = "4.5.4", features = ["derive"] } derive_more = { version = "0.99.17", features = ["display"], default-features = false } env_logger = "0.11.3"