diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md
index 59169ce90..34a64cc75 100644
--- a/crates/burn-import/SUPPORTED-ONNX-OPS.md
+++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -99,7 +99,7 @@ represent the corresponding Burn Op.
 | [LpPool][91] | ❌ | ❌ |
 | [LRN][92] | ❌ | ❌ |
 | [LSTM][93] | ❌ | ✅ |
-| [MatMul][94] | ❌ | ✅ |
+| [MatMul][94] | ✅ | ✅ |
 | [MatMulInteger][95] | ❌ | ✅ |
 | [Max][96] | ❌ | ✅ |
 | [MaxPool1d][97] | ❌ | ✅ |
@@ -112,7 +112,7 @@ represent the corresponding Burn Op.
 | [Min][104] | ❌ | ✅ |
 | [Mish][105] | ❌ | ❌ |
 | [Mod][106] | ❌ | ❌ |
-| [Mul][107] | ❌ | ✅ |
+| [Mul][107] | ✅ | ✅ |
 | [Multinomial][108] | ❌ | ❌ |
 | [Neg][109] | ✅ | ✅ |
 | [NegativeLogLikelihoodLoss][110] | ❌ | ❌ |
diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs
index a53376c2c..89d2cd06c 100644
--- a/crates/burn-import/onnx-tests/build.rs
+++ b/crates/burn-import/onnx-tests/build.rs
@@ -30,6 +30,7 @@ fn main() {
         .input("tests/linear/linear.onnx")
         .input("tests/log_softmax/log_softmax.onnx")
         .input("tests/log/log.onnx")
+        .input("tests/matmul/matmul.onnx")
         .input("tests/maxpool2d/maxpool2d.onnx")
         .input("tests/mul/mul.onnx")
         .input("tests/neg/neg.onnx")
diff --git a/crates/burn-import/onnx-tests/tests/matmul/matmul.onnx b/crates/burn-import/onnx-tests/tests/matmul/matmul.onnx
new file mode 100644
index 000000000..29b4ddb4d
Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/matmul/matmul.onnx differ
diff --git a/crates/burn-import/onnx-tests/tests/matmul/matmul.py b/crates/burn-import/onnx-tests/tests/matmul/matmul.py
new file mode 100755
index 000000000..93928df43
--- /dev/null
+++ b/crates/burn-import/onnx-tests/tests/matmul/matmul.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+
+# used to generate model: onnx-tests/tests/matmul/matmul.onnx
+
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, a, b, c, d):
+        return torch.matmul(a, b), torch.matmul(c, d), torch.matmul(d, c)
+
+
+def main():
+    # Set random seed for reproducibility
+    torch.manual_seed(0)
+
+    # Export to onnx
+    model = Model()
+    model.eval()
+    device = torch.device("cpu")
+    onnx_name = "matmul.onnx"
+    a = torch.arange(24, dtype=torch.float, device=device).reshape(1, 2, 3, 4)
+    b = torch.arange(16, dtype=torch.float, device=device).reshape(1, 2, 4, 2)
+    c = torch.arange(96, dtype=torch.float, device=device).reshape(2, 3, 4, 4)
+    d = torch.arange(4, dtype=torch.float, device=device)
+    test_input = (a, b, c, d)
+
+    torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16)
+
+    print(f"Finished exporting model to {onnx_name}")
+
+    # Output some test data for use in the test
+    print(f"Test input data: {test_input}")
+    output = model.forward(*test_input)
+    print(f"Test output data: {output}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
index ecd8209fd..ad504f406 100644
--- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs
+++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -40,6 +40,7 @@ include_models!(
     linear,
     log_softmax,
     log,
+    matmul,
     maxpool2d,
     mul,
     neg,
@@ -135,6 +136,7 @@ mod tests {
         assert_eq!(output.to_data(), expected);
     }
 
+    #[test]
    fn mul_scalar_with_tensor_and_tensor_with_tensor() {
        // Initialize the model with weights (loaded from the exported file)
@@ -166,6 +168,61 @@ mod tests {
         assert_eq!(output.to_data(), expected);
     }
 
+    #[test]
+    fn matmul() {
+        // Initialize the model with weights (loaded from the exported file)
+        let model: matmul::Model<Backend> = matmul::Model::default();
+
+        let device = Default::default();
+        let a = Tensor::<Backend, 1, Int>::arange(0..24, &device)
+            .reshape([1, 2, 3, 4])
+            .float();
+        let b = Tensor::<Backend, 1, Int>::arange(0..16, &device)
+            .reshape([1, 2, 4, 2])
+            .float();
+        let c = Tensor::<Backend, 1, Int>::arange(0..96, &device)
+            .reshape([2, 3, 4, 4])
+            .float();
+        let d = Tensor::<Backend, 1, Int>::arange(0..4, &device).float();
+
+        let (output_mm, output_mv, output_vm) = model.forward(a, b, c, d);
+        // matrix-matrix `a @ b`
+        let expected_mm = Data::from([[
+            [[28., 34.], [76., 98.], [124., 162.]],
+            [[604., 658.], [780., 850.], [956., 1042.]],
+        ]]);
+        // matrix-vector `c @ d` where the rhs vector is expanded and broadcast to the correct dims
+        let expected_mv = Data::from([
+            [
+                [14., 38., 62., 86.],
+                [110., 134., 158., 182.],
+                [206., 230., 254., 278.],
+            ],
+            [
+                [302., 326., 350., 374.],
+                [398., 422., 446., 470.],
+                [494., 518., 542., 566.],
+            ],
+        ]);
+        // vector-matrix `d @ c` where the lhs vector is expanded and broadcast to the correct dims
+        let expected_vm = Data::from([
+            [
+                [56., 62., 68., 74.],
+                [152., 158., 164., 170.],
+                [248., 254., 260., 266.],
+            ],
+            [
+                [344., 350., 356., 362.],
+                [440., 446., 452., 458.],
+                [536., 542., 548., 554.],
+            ],
+        ]);
+
+        assert_eq!(output_mm.to_data(), expected_mm);
+        assert_eq!(output_vm.to_data(), expected_vm);
+        assert_eq!(output_mv.to_data(), expected_mv);
+    }
+
     #[test]
     fn concat_tensors() {
         // Initialize the model
diff --git a/crates/burn-import/src/burn/node/matmul.rs b/crates/burn-import/src/burn/node/matmul.rs
index 53daebdc6..351f14374 100644
--- a/crates/burn-import/src/burn/node/matmul.rs
+++ b/crates/burn-import/src/burn/node/matmul.rs
@@ -1,16 +1,27 @@
+use core::cmp::Ordering;
+
 use super::{Node, NodeCodegen};
-use crate::burn::{Scope, TensorType, Type};
+use crate::burn::{Scope, TensorKind, TensorType, ToTokens, Type};
 use burn::record::PrecisionSettings;
 use proc_macro2::TokenStream;
 use quote::quote;
 
-#[derive(Debug, Clone, new)]
+#[derive(Debug, Clone)]
 pub struct MatmulNode {
     pub lhs: TensorType,
     pub rhs: TensorType,
     pub output: TensorType,
 }
 
+impl MatmulNode {
+    pub fn new(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self {
+        if lhs.kind != TensorKind::Float {
+            panic!("MatMul is only implemented for float tensors");
+        }
+        Self { lhs, rhs, output }
+    }
+}
+
 impl<PS: PrecisionSettings> NodeCodegen<PS> for MatmulNode {
     fn output_types(&self) -> Vec<Type> {
         vec![Type::Tensor(self.output.clone())]
@@ -28,8 +39,49 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MatmulNode {
         let rhs = scope.tensor_use_owned(&self.rhs, node_position);
         let output = &self.output.name;
 
-        quote! {
-            let #output = #lhs.matmul(#rhs);
+        let lhs_dim = self.lhs.dim;
+        let rhs_dim = self.rhs.dim;
+
+        // Support broadcasting for missing dimensions
+        match lhs_dim.cmp(&rhs_dim) {
+            Ordering::Greater => {
+                // Alternate unsqueeze(0) -> unsqueeze(-1) -> unsqueeze(0) -> ...
+                let axes = (0..lhs_dim - rhs_dim)
+                    .map(|i| if i % 2 == 0 { 0 } else { -1 })
+                    .collect::<Vec<i64>>();
+                let axes = axes.to_tokens();
+
+                if rhs_dim == 1 {
+                    // Matrix-vector product: squeeze(-1)
+                    let squeeze_dim = lhs_dim - 1;
+                    quote! {
+                        let #output = #lhs.matmul(#rhs.unsqueeze_dims(&#axes)).squeeze(#squeeze_dim);
+                    }
+                } else {
+                    quote! {
+                        let #output = #lhs.matmul(#rhs.unsqueeze_dims(&#axes));
+                    }
+                }
+            }
+            Ordering::Less => {
+                // Always unsqueeze(0)
+                let axes = [0i64].repeat(rhs_dim - lhs_dim).to_tokens();
+
+                if lhs_dim == 1 {
+                    // Vector-matrix product: squeeze(-2)
+                    let squeeze_dim = rhs_dim - 2;
+                    quote! {
+                        let #output = #lhs.unsqueeze_dims(&#axes).matmul(#rhs).squeeze(#squeeze_dim);
+                    }
+                } else {
+                    quote! {
+                        let #output = #lhs.unsqueeze_dims(&#axes).matmul(#rhs);
+                    }
+                }
+            }
+            Ordering::Equal => quote! {
+                let #output = #lhs.matmul(#rhs);
+            },
+        }
     }
 
@@ -51,7 +103,7 @@ mod tests {
     };
 
     #[test]
-    fn test_codegen_two_nodes() {
+    fn test_codegen_matmul() {
         let mut graph = BurnGraph::<FullPrecisionSettings>::default();
 
         graph.register(MatmulNode::new(
@@ -99,4 +151,104 @@
 
         assert_tokens(graph.codegen(), expected);
     }
+
+    #[test]
+    fn test_codegen_matmul_matrix_vector() {
+        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
+
+        graph.register(MatmulNode::new(
+            TensorType::new_float("tensor1", 4),
+            TensorType::new_float("tensor2", 1),
+            TensorType::new_float("tensor3", 3),
+        ));
+
+        graph.register_input_output(
+            vec!["tensor1".to_string(), "tensor2".to_string()],
+            vec!["tensor3".to_string()],
+        );
+
+        let expected = quote! {
+            use burn::{
+                module::Module,
+                tensor::{backend::Backend, Tensor},
+            };
+
+            #[derive(Module, Debug)]
+            pub struct Model<B: Backend> {
+                phantom: core::marker::PhantomData<B>,
+            }
+
+            impl<B: Backend> Model<B> {
+                #[allow(unused_variables)]
+                pub fn new(device: &B::Device) -> Self {
+                    Self {
+                        phantom: core::marker::PhantomData,
+                    }
+                }
+
+                #[allow(clippy::let_and_return, clippy::approx_constant)]
+                pub fn forward(
+                    &self,
+                    tensor1: Tensor<B, 4>,
+                    tensor2: Tensor<B, 1>
+                ) -> Tensor<B, 3> {
+                    let tensor3 = tensor1.matmul(tensor2.unsqueeze_dims(&[0, -1, 0])).squeeze(3usize);
+
+                    tensor3
+                }
+            }
+        };
+
+        assert_tokens(graph.codegen(), expected);
+    }
+
+    #[test]
+    fn test_codegen_matmul_vector_matrix() {
+        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
+
+        graph.register(MatmulNode::new(
+            TensorType::new_float("tensor1", 1),
+            TensorType::new_float("tensor2", 4),
+            TensorType::new_float("tensor3", 3),
+        ));
+
+        graph.register_input_output(
+            vec!["tensor1".to_string(), "tensor2".to_string()],
+            vec!["tensor3".to_string()],
+        );
+
+        let expected = quote! {
+            use burn::{
+                module::Module,
+                tensor::{backend::Backend, Tensor},
+            };
+
+            #[derive(Module, Debug)]
+            pub struct Model<B: Backend> {
+                phantom: core::marker::PhantomData<B>,
+            }
+
+            impl<B: Backend> Model<B> {
+                #[allow(unused_variables)]
+                pub fn new(device: &B::Device) -> Self {
+                    Self {
+                        phantom: core::marker::PhantomData,
+                    }
+                }
+
+                #[allow(clippy::let_and_return, clippy::approx_constant)]
+                pub fn forward(
+                    &self,
+                    tensor1: Tensor<B, 1>,
+                    tensor2: Tensor<B, 4>
+                ) -> Tensor<B, 3> {
+                    let tensor3 = tensor1.unsqueeze_dims(&[0, 0, 0]).matmul(tensor2).squeeze(2usize);
+
+                    tensor3
+                }
+            }
+        };
+
+        assert_tokens(graph.codegen(), expected);
+    }
 }
diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs
index 303e6583e..d52d1346b 100644
--- a/crates/burn-import/src/onnx/dim_inference.rs
+++ b/crates/burn-import/src/onnx/dim_inference.rs
@@ -1,3 +1,4 @@
+use core::cmp::max;
 use core::panic;
 
 use protobuf::Enum;
@@ -35,6 +36,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
         NodeType::Linear => linear_update_outputs(node),
         NodeType::Log => same_as_input(node),
         NodeType::LogSoftmax => same_as_input(node),
+        NodeType::MatMul => matmul_update_outputs(node),
         NodeType::MaxPool2d => same_as_input(node),
         NodeType::Mul => same_as_input(node),
         NodeType::Neg => same_as_input(node),
@@ -442,6 +444,28 @@ fn conv_transpose2d_update_outputs(node: &mut Node) {
     }
 }
 
+fn matmul_update_outputs(node: &mut Node) {
+    // NOTE: matmul only supported for float tensors
+    match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) {
+        (ArgType::Tensor(a), ArgType::Tensor(b)) => {
+            // With broadcasting support, output dim has to be computed based on the inputs
+            let mut out_dim = max(a.dim, b.dim);
+
+            // Matrix-vector or vector-matrix product
+            if (a.dim >= 2 && b.dim == 1) || (a.dim == 1 && b.dim >= 2) {
+                out_dim -= 1;
+            }
+
+            node.outputs[0].ty = ArgType::Tensor(TensorType {
+                elem_type: a.elem_type.clone(),
+                dim: out_dim,
+                shape: a.shape.clone(),
+            });
+        }
+        _ => panic!("Only tensor input is valid"),
+    }
+}
+
 /// Infers the shape of a ReduceMax node and replaces the shape of the output tensor.
 fn reduce_max_update_outputs(node: &mut Node) {
     if node.inputs.len() != 1 {
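
Reviewer note (not part of the diff): the `unsqueeze_dims`/`squeeze` pairs emitted by the codegen above, and the `out_dim -= 1` adjustment in `matmul_update_outputs`, mirror ONNX MatMul's NumPy-style handling of 1-D operands: a vector on the right is promoted to a column matrix, a vector on the left to a row matrix, batch dimensions are broadcast, and the promoted dimension is squeezed back out of the result. The sketch below, assuming only `torch` (the same library `matmul.py` uses), demonstrates that equivalence on the PR's test shapes; the `mv`/`vm` names are illustrative only.

```python
#!/usr/bin/env python3
# Illustrative only. Mirrors the promote/broadcast/squeeze behavior that the
# generated Burn code implements for MatMul with 1-D operands.
import torch

c = torch.arange(96, dtype=torch.float).reshape(2, 3, 4, 4)
d = torch.arange(4, dtype=torch.float)

# Matrix-vector (c @ d): the rhs vector is promoted to shape [1, 1, 4, 1],
# broadcast against c's batch dims, then the trailing dim is squeezed away.
mv = torch.matmul(c, d.reshape(1, 1, 4, 1)).squeeze(-1)
assert mv.shape == (2, 3, 4)
assert torch.allclose(torch.matmul(c, d), mv)

# Vector-matrix (d @ c): the lhs vector is promoted to shape [1, 1, 1, 4],
# then the second-to-last dim is squeezed away.
vm = torch.matmul(d.reshape(1, 1, 1, 4), c).squeeze(-2)
assert vm.shape == (2, 3, 4)
assert torch.allclose(torch.matmul(d, c), vm)

print(mv.shape, vm.shape)  # torch.Size([2, 3, 4]) torch.Size([2, 3, 4])
```

This is also why the mixed-rank outputs in the new tests are rank 3 rather than 4, and why `matmul_update_outputs` subtracts one from `max(a.dim, b.dim)` in the matrix-vector and vector-matrix cases.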