Add matmul ONNX op support (#1638)

* Mul onnx op already supported

* Add matmul onnx op checks and tests

* Add missing eq derives

* Change superscript symbol

* Remove dead code

* Add support for matmul broadcast

* No more broadcasting restrictions

* Add results comment for mm, mv and vm
Guillaume Lagrange 2024-04-18 09:20:31 -04:00 committed by GitHub
parent 2a721a9d0c
commit 7705fd9c25
7 changed files with 333 additions and 7 deletions


@@ -99,7 +99,7 @@ represent the corresponding Burn Op.
| [LpPool][91] | ❌ | ❌ |
| [LRN][92] | ❌ | ❌ |
| [LSTM][93] | ❌ | ✅ |
| [MatMul][94] | ❌ | ✅ |
| [MatMul][94] | ✅ | ✅ |
| [MatMulInteger][95] | ❌ | ✅ |
| [Max][96] | ❌ | ✅ |
| [MaxPool1d][97] | ❌ | ✅ |
@@ -112,7 +112,7 @@ represent the corresponding Burn Op.
| [Min][104] | ❌ | ✅ |
| [Mish][105] | ❌ | ❌ |
| [Mod][106] | ❌ | ❌ |
| [Mul][107] | ❌ | ✅ |
| [Mul][107] | ✅ | ✅ |
| [Multinomial][108] | ❌ | ❌ |
| [Neg][109] | ✅ | ✅ |
| [NegativeLogLikelihoodLoss][110] | ❌ | ❌ |


@@ -30,6 +30,7 @@ fn main() {
        .input("tests/linear/linear.onnx")
        .input("tests/log_softmax/log_softmax.onnx")
        .input("tests/log/log.onnx")
        .input("tests/matmul/matmul.onnx")
        .input("tests/maxpool2d/maxpool2d.onnx")
        .input("tests/mul/mul.onnx")
        .input("tests/neg/neg.onnx")


(Binary file not shown: the exported ONNX model, generated by PyTorch 2.1.2. Its main_graph takes four inputs, onnx::MatMul_0 through onnx::MatMul_3, and contains three MatMul nodes, /MatMul, /MatMul_1, and /MatMul_2, whose outputs are 4, 5, and 6.)

@@ -0,0 +1,43 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/matmul/matmul.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, a, b, c, d):
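        # a @ b: batched matrix-matrix product, (1, 2, 3, 4) @ (1, 2, 4, 2) -> (1, 2, 3, 2)
        # c @ d: matrix-vector product, (2, 3, 4, 4) @ (4,) -> (2, 3, 4)
        # d @ c: vector-matrix product, (4,) @ (2, 3, 4, 4) -> (2, 3, 4)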
        return torch.matmul(a, b), torch.matmul(c, d), torch.matmul(d, c)


def main():
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "matmul.onnx"

    a = torch.arange(24, dtype=torch.float, device=device).reshape(1, 2, 3, 4)
    b = torch.arange(16, dtype=torch.float, device=device).reshape(1, 2, 4, 2)
    c = torch.arange(96, dtype=torch.float, device=device).reshape(2, 3, 4, 4)
    d = torch.arange(4, dtype=torch.float, device=device)
    test_input = (a, b, c, d)
    torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16)

    print(f"Finished exporting model to {onnx_name}")

    # Output some test data for use in the test
    print(f"Test input data: {test_input}")
    output = model.forward(*test_input)
    print(f"Test output data: {output}")


if __name__ == "__main__":
    main()


@@ -40,6 +40,7 @@ include_models!(
    linear,
    log_softmax,
    log,
    matmul,
    maxpool2d,
    mul,
    neg,
@@ -135,6 +136,7 @@ mod tests {
        assert_eq!(output.to_data(), expected);
    }

    #[test]
    fn mul_scalar_with_tensor_and_tensor_with_tensor() {
        // Initialize the model with weights (loaded from the exported file)
@@ -166,6 +168,61 @@ mod tests {
        assert_eq!(output.to_data(), expected);
    }

    #[test]
    fn matmul() {
        // Initialize the model with weights (loaded from the exported file)
        let model: matmul::Model<Backend> = matmul::Model::default();

        let device = Default::default();
        let a = Tensor::<Backend, 1, Int>::arange(0..24, &device)
            .reshape([1, 2, 3, 4])
            .float();
        let b = Tensor::<Backend, 1, Int>::arange(0..16, &device)
            .reshape([1, 2, 4, 2])
            .float();
        let c = Tensor::<Backend, 1, Int>::arange(0..96, &device)
            .reshape([2, 3, 4, 4])
            .float();
        let d = Tensor::<Backend, 1, Int>::arange(0..4, &device).float();

        let (output_mm, output_mv, output_vm) = model.forward(a, b, c, d);
        // matrix-matrix `a @ b`
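        // e.g. first entry: row [0., 1., 2., 3.] of `a` dot column [0., 2., 4., 6.] of `b` = 0 + 2 + 8 + 18 = 28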
        let expected_mm = Data::from([[
            [[28., 34.], [76., 98.], [124., 162.]],
            [[604., 658.], [780., 850.], [956., 1042.]],
        ]]);
        // matrix-vector `c @ d` where the rhs vector is expanded and broadcast to the correct dims
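        // e.g. first entry: row [0., 1., 2., 3.] of `c` dot `d` = [0., 1., 2., 3.], i.e. 0 + 1 + 4 + 9 = 14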
        let expected_mv = Data::from([
            [
                [14., 38., 62., 86.],
                [110., 134., 158., 182.],
                [206., 230., 254., 278.],
            ],
            [
                [302., 326., 350., 374.],
                [398., 422., 446., 470.],
                [494., 518., 542., 566.],
            ],
        ]);
        // vector-matrix `d @ c` where the lhs vector is expanded and broadcast to the correct dims
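        // e.g. first entry: `d` = [0., 1., 2., 3.] dot column [0., 4., 8., 12.] of `c`, i.e. 0 + 4 + 16 + 36 = 56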
        let expected_vm = Data::from([
            [
                [56., 62., 68., 74.],
                [152., 158., 164., 170.],
                [248., 254., 260., 266.],
            ],
            [
                [344., 350., 356., 362.],
                [440., 446., 452., 458.],
                [536., 542., 548., 554.],
            ],
        ]);

        assert_eq!(output_mm.to_data(), expected_mm);
        assert_eq!(output_vm.to_data(), expected_vm);
        assert_eq!(output_mv.to_data(), expected_mv);
    }

    #[test]
    fn concat_tensors() {
        // Initialize the model

@@ -1,16 +1,27 @@
use core::cmp::Ordering;

use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use crate::burn::{Scope, TensorKind, TensorType, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Debug, Clone, new)]
#[derive(Debug, Clone)]
pub struct MatmulNode {
    pub lhs: TensorType,
    pub rhs: TensorType,
    pub output: TensorType,
}

impl MatmulNode {
    pub fn new(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self {
        if lhs.kind != TensorKind::Float {
            panic!("MatMul is only implemented for float tensors");
        }

        Self { lhs, rhs, output }
    }
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for MatmulNode {
    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
@@ -28,8 +39,49 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MatmulNode {
        let rhs = scope.tensor_use_owned(&self.rhs, node_position);
        let output = &self.output.name;

        let lhs_dim = self.lhs.dim;
        let rhs_dim = self.rhs.dim;

        // Support broadcasting for missing dimensions
        match lhs_dim.cmp(&rhs_dim) {
            Ordering::Greater => {
                // Alternate unsqueeze(0) -> unsqueeze(-1) -> unsqueeze(0) -> ...
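                // e.g. with lhs_dim == 4 and a 1-d rhs of shape [k], axes is [0, -1, 0]
                // and the rhs is reshaped to [1, 1, k, 1] (a broadcastable column vector)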
                let axes = (0..lhs_dim - rhs_dim)
                    .map(|i| if i % 2 == 0 { 0 } else { -1 })
                    .collect::<Vec<i64>>();
                let axes = axes.to_tokens();

                if rhs_dim == 1 {
                    // Matrix-vector product: squeeze(-1)
                    let squeeze_dim = lhs_dim - 1;
                    quote! {
                        let #output = #lhs.matmul(#rhs.unsqueeze_dims(&#axes)).squeeze(#squeeze_dim);
                    }
                } else {
                    quote! {
                        let #output = #lhs.matmul(#rhs.unsqueeze_dims(&#axes));
                    }
                }
            }
            Ordering::Less => {
                // Always unsqueeze(0)
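                // e.g. with rhs_dim == 4 and a 1-d lhs of shape [k], the lhs is
                // reshaped to [1, 1, 1, k] (a broadcastable row vector)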
                let axes = [0i64].repeat(rhs_dim - lhs_dim).to_tokens();

                if lhs_dim == 1 {
                    // Vector-matrix product: squeeze(-2)
                    let squeeze_dim = rhs_dim - 2;
                    quote! {
                        let #output = #lhs.unsqueeze_dims(&#axes).matmul(#rhs).squeeze(#squeeze_dim);
                    }
                } else {
                    quote! {
                        let #output = #lhs.unsqueeze_dims(&#axes).matmul(#rhs);
                    }
                }
            }
            Ordering::Equal => quote! {
                let #output = #lhs.matmul(#rhs);
            },
        }
    }
@@ -51,7 +103,7 @@
    };

    #[test]
    fn test_codegen_two_nodes() {
    fn test_codegen_matmul() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        graph.register(MatmulNode::new(
@@ -99,4 +151,104 @@ mod tests {

        assert_tokens(graph.codegen(), expected);
    }

    #[test]
    fn test_codegen_matmul_matrix_vector() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        graph.register(MatmulNode::new(
            TensorType::new_float("tensor1", 4),
            TensorType::new_float("tensor2", 1),
            TensorType::new_float("tensor3", 3),
        ));

        graph.register_input_output(
            vec!["tensor1".to_string(), "tensor2".to_string()],
            vec!["tensor3".to_string()],
        );

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                    }
                }

                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(
                    &self,
                    tensor1: Tensor<B, 4>,
                    tensor2: Tensor<B, 1>
                ) -> Tensor<B, 3> {
                    let tensor3 = tensor1.matmul(tensor2.unsqueeze_dims(&[0, -1, 0])).squeeze(3usize);

                    tensor3
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }

    #[test]
    fn test_codegen_matmul_vector_matrix() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        graph.register(MatmulNode::new(
            TensorType::new_float("tensor1", 1),
            TensorType::new_float("tensor2", 4),
            TensorType::new_float("tensor3", 3),
        ));

        graph.register_input_output(
            vec!["tensor1".to_string(), "tensor2".to_string()],
            vec!["tensor3".to_string()],
        );

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                    }
                }

                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(
                    &self,
                    tensor1: Tensor<B, 1>,
                    tensor2: Tensor<B, 4>
                ) -> Tensor<B, 3> {
                    let tensor3 = tensor1.unsqueeze_dims(&[0, 0, 0]).matmul(tensor2).squeeze(2usize);

                    tensor3
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}


@@ -1,3 +1,4 @@
use core::cmp::max;
use core::panic;
use protobuf::Enum;
@@ -35,6 +36,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
        NodeType::Linear => linear_update_outputs(node),
        NodeType::Log => same_as_input(node),
        NodeType::LogSoftmax => same_as_input(node),
        NodeType::MatMul => matmul_update_outputs(node),
        NodeType::MaxPool2d => same_as_input(node),
        NodeType::Mul => same_as_input(node),
        NodeType::Neg => same_as_input(node),
@@ -442,6 +444,28 @@ fn conv_transpose2d_update_outputs(node: &mut Node) {
    }
}

fn matmul_update_outputs(node: &mut Node) {
    // NOTE: matmul only supported for float tensors
    match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) {
        (ArgType::Tensor(a), ArgType::Tensor(b)) => {
            // With broadcasting support, output dim has to be computed based on the inputs
            let mut out_dim = max(a.dim, b.dim);

            // Matrix-vector or vector-matrix product
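            // e.g. lhs [2, 3, 4, 4] (dim 4) matmul rhs [4] (dim 1) -> output dim 3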
            if (a.dim >= 2 && b.dim == 1) || (a.dim == 1 && b.dim >= 2) {
                out_dim -= 1;
            }

            node.outputs[0].ty = ArgType::Tensor(TensorType {
                elem_type: a.elem_type.clone(),
                dim: out_dim,
                shape: a.shape.clone(),
            });
        }
        _ => panic!("Only tensor input is valid"),
    }
}

/// Infers the shape of a ReduceMax node and replaces the shape of the output tensor.
fn reduce_max_update_outputs(node: &mut Node) {
    if node.inputs.len() != 1 {