mirror of https://github.com/tracel-ai/burn.git
Add matmul ONNX op support (#1638)
* Mul onnx op already supported * Add matmul onnx op checks and tests * Add missing eq derives * Change superscript symbol * Remove dead code * Add support for matmul broadcast * No more broadcasting restrictions * Add results comment for mm, mv and vm
This commit is contained in:
parent
2a721a9d0c
commit
7705fd9c25
|
@ -99,7 +99,7 @@ represent the corresponding Burn Op.
|
|||
| [LpPool][91] | ❌ | ❌ |
|
||||
| [LRN][92] | ❌ | ❌ |
|
||||
| [LSTM][93] | ❌ | ✅ |
|
||||
| [MatMul][94] | ❌ | ✅ |
|
||||
| [MatMul][94] | ✅ | ✅ |
|
||||
| [MatMulInteger][95] | ❌ | ✅ |
|
||||
| [Max][96] | ❌ | ✅ |
|
||||
| [MaxPool1d][97] | ❌ | ✅ |
|
||||
|
@ -112,7 +112,7 @@ represent the corresponding Burn Op.
|
|||
| [Min][104] | ❌ | ✅ |
|
||||
| [Mish][105] | ❌ | ❌ |
|
||||
| [Mod][106] | ❌ | ❌ |
|
||||
| [Mul][107] | ❌ | ✅ |
|
||||
| [Mul][107] | ✅ | ✅ |
|
||||
| [Multinomial][108] | ❌ | ❌ |
|
||||
| [Neg][109] | ✅ | ✅ |
|
||||
| [NegativeLogLikelihoodLoss][110] | ❌ | ❌ |
|
||||
|
|
|
@ -30,6 +30,7 @@ fn main() {
|
|||
.input("tests/linear/linear.onnx")
|
||||
.input("tests/log_softmax/log_softmax.onnx")
|
||||
.input("tests/log/log.onnx")
|
||||
.input("tests/matmul/matmul.onnx")
|
||||
.input("tests/maxpool2d/maxpool2d.onnx")
|
||||
.input("tests/mul/mul.onnx")
|
||||
.input("tests/neg/neg.onnx")
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
pytorch2.1.2:ť
|
||||
4
|
||||
onnx::MatMul_0
|
||||
onnx::MatMul_14/MatMul"MatMul
|
||||
6
|
||||
onnx::MatMul_2
|
||||
onnx::MatMul_35 /MatMul_1"MatMul
|
||||
6
|
||||
onnx::MatMul_3
|
||||
onnx::MatMul_26 /MatMul_2"MatMul
|
||||
main_graphZ(
|
||||
onnx::MatMul_0
|
||||
|
||||
|
||||
|
||||
|
||||
Z(
|
||||
onnx::MatMul_1
|
||||
|
||||
|
||||
|
||||
|
||||
Z(
|
||||
onnx::MatMul_2
|
||||
|
||||
|
||||
|
||||
|
||||
Z
|
||||
onnx::MatMul_3
|
||||
|
||||
|
||||
b
|
||||
4
|
||||
|
||||
|
||||
|
||||
|
||||
b
|
||||
5
|
||||
|
||||
|
||||
|
||||
b
|
||||
6
|
||||
|
||||
|
||||
|
||||
B
|
|
@ -0,0 +1,43 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/matmul/matmul.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Model(nn.Module):
    """Parameter-free module that runs three ``torch.matmul`` calls in one forward pass."""

    def __init__(self):
        super().__init__()

    def forward(self, a, b, c, d):
        """Return the tuple ``(matmul(a, b), matmul(c, d), matmul(d, c))``.

        The local names follow how the export script in this file feeds the
        model: ``a``/``b`` are matrices, ``c`` is a matrix and ``d`` a vector.
        """
        matrix_matrix = torch.matmul(a, b)
        matrix_vector = torch.matmul(c, d)
        vector_matrix = torch.matmul(d, c)
        return matrix_matrix, matrix_vector, vector_matrix
|
||||
|
||||
|
||||
def main():
    """Export the matmul test model to ``matmul.onnx`` and print reference data."""
    # Set random seed for reproducibility
    torch.manual_seed(0)

    device = torch.device("cpu")

    def ramp(count, *shape):
        """Float tensor holding 0..count-1 on ``device``, reshaped when a shape is given."""
        values = torch.arange(count, dtype=torch.float, device=device)
        return values.reshape(*shape) if shape else values

    # Export to onnx
    model = Model()
    model.eval()
    onnx_name = "matmul.onnx"
    test_input = (
        ramp(24, 1, 2, 3, 4),
        ramp(16, 1, 2, 4, 2),
        ramp(96, 2, 3, 4, 4),
        ramp(4),
    )

    torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16)

    print(f"Finished exporting model to {onnx_name}")

    # Output some test data for use in the test
    print(f"Test input data: {test_input}")
    output = model.forward(*test_input)
    print(f"Test output data: {output}")
|
||||
|
||||
|
||||
# Run the export only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
|
||||
main()
|
|
@ -40,6 +40,7 @@ include_models!(
|
|||
linear,
|
||||
log_softmax,
|
||||
log,
|
||||
matmul,
|
||||
maxpool2d,
|
||||
mul,
|
||||
neg,
|
||||
|
@ -135,6 +136,7 @@ mod tests {
|
|||
|
||||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mul_scalar_with_tensor_and_tensor_with_tensor() {
|
||||
// Initialize the model with weights (loaded from the exported file)
|
||||
|
@ -166,6 +168,61 @@ mod tests {
|
|||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matmul() {
|
||||
// Initialize the model with weights (loaded from the exported file)
|
||||
let model: matmul::Model<Backend> = matmul::Model::default();
|
||||
|
||||
let device = Default::default();
|
||||
let a = Tensor::<Backend, 1, Int>::arange(0..24, &device)
|
||||
.reshape([1, 2, 3, 4])
|
||||
.float();
|
||||
let b = Tensor::<Backend, 1, Int>::arange(0..16, &device)
|
||||
.reshape([1, 2, 4, 2])
|
||||
.float();
|
||||
let c = Tensor::<Backend, 1, Int>::arange(0..96, &device)
|
||||
.reshape([2, 3, 4, 4])
|
||||
.float();
|
||||
let d = Tensor::<Backend, 1, Int>::arange(0..4, &device).float();
|
||||
|
||||
let (output_mm, output_mv, output_vm) = model.forward(a, b, c, d);
|
||||
// matrix-matrix `a @ b`
|
||||
let expected_mm = Data::from([[
|
||||
[[28., 34.], [76., 98.], [124., 162.]],
|
||||
[[604., 658.], [780., 850.], [956., 1042.]],
|
||||
]]);
|
||||
// matrix-vector `c @ d` where the rhs vector is expanded and broadcast to the correct dims
|
||||
let expected_mv = Data::from([
|
||||
[
|
||||
[14., 38., 62., 86.],
|
||||
[110., 134., 158., 182.],
|
||||
[206., 230., 254., 278.],
|
||||
],
|
||||
[
|
||||
[302., 326., 350., 374.],
|
||||
[398., 422., 446., 470.],
|
||||
[494., 518., 542., 566.],
|
||||
],
|
||||
]);
|
||||
// vector-matrix `d @ c` where the lhs vector is expanded and broadcast to the correct dims
|
||||
let expected_vm = Data::from([
|
||||
[
|
||||
[56., 62., 68., 74.],
|
||||
[152., 158., 164., 170.],
|
||||
[248., 254., 260., 266.],
|
||||
],
|
||||
[
|
||||
[344., 350., 356., 362.],
|
||||
[440., 446., 452., 458.],
|
||||
[536., 542., 548., 554.],
|
||||
],
|
||||
]);
|
||||
|
||||
assert_eq!(output_mm.to_data(), expected_mm);
|
||||
assert_eq!(output_vm.to_data(), expected_vm);
|
||||
assert_eq!(output_mv.to_data(), expected_mv);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn concat_tensors() {
|
||||
// Initialize the model
|
||||
|
|
|
@ -1,16 +1,27 @@
|
|||
use core::cmp::Ordering;
|
||||
|
||||
use super::{Node, NodeCodegen};
|
||||
use crate::burn::{Scope, TensorType, Type};
|
||||
use crate::burn::{Scope, TensorKind, TensorType, ToTokens, Type};
|
||||
use burn::record::PrecisionSettings;
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
|
||||
#[derive(Debug, Clone, new)]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Code-generation node for a matrix multiplication of two tensors.
pub struct MatmulNode {
|
||||
/// Left-hand operand of the product.
pub lhs: TensorType,
|
||||
/// Right-hand operand of the product.
pub rhs: TensorType,
|
||||
/// Tensor that receives the result of the product.
pub output: TensorType,
|
||||
}
|
||||
|
||||
impl MatmulNode {
|
||||
pub fn new(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self {
|
||||
if lhs.kind != TensorKind::Float {
|
||||
panic!("MatMul is only implemented for float tensors");
|
||||
}
|
||||
Self { lhs, rhs, output }
|
||||
}
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for MatmulNode {
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(self.output.clone())]
|
||||
|
@ -28,8 +39,49 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MatmulNode {
|
|||
let rhs = scope.tensor_use_owned(&self.rhs, node_position);
|
||||
let output = &self.output.name;
|
||||
|
||||
let lhs_dim = self.lhs.dim;
|
||||
let rhs_dim = self.rhs.dim;
|
||||
|
||||
// Support broadcasting for missing dimensions
|
||||
match lhs_dim.cmp(&rhs_dim) {
|
||||
Ordering::Greater => {
|
||||
// Alternate unsqueeze(0) -> unsqueeze(-1) -> unsqueeze(0) -> ...
|
||||
let axes = (0..lhs_dim - rhs_dim)
|
||||
.map(|i| if i % 2 == 0 { 0 } else { -1 })
|
||||
.collect::<Vec<i64>>();
|
||||
let axes = axes.to_tokens();
|
||||
|
||||
if rhs_dim == 1 {
|
||||
// Matrix-vector product: squeeze(-1)
|
||||
let squeeze_dim = lhs_dim - 1;
|
||||
quote! {
|
||||
let #output = #lhs.matmul(#rhs.unsqueeze_dims(&#axes)).squeeze(#squeeze_dim);
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
let #output = #lhs.matmul(#rhs.unsqueeze_dims(&#axes));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ordering::Less => {
|
||||
// Always unsqueeze(0)
|
||||
let axes = [0i64].repeat(rhs_dim - lhs_dim).to_tokens();
|
||||
|
||||
if lhs_dim == 1 {
|
||||
// Vector-matrix product: squeeze(-2)
|
||||
let squeeze_dim = rhs_dim - 2;
|
||||
quote! {
|
||||
let #output = #lhs.unsqueeze_dims(&#axes).matmul(#rhs).squeeze(#squeeze_dim);
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
let #output = #lhs.unsqueeze_dims(&#axes).matmul(#rhs);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ordering::Equal => quote! {
|
||||
let #output = #lhs.matmul(#rhs);
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -51,7 +103,7 @@ mod tests {
|
|||
};
|
||||
|
||||
#[test]
|
||||
fn test_codegen_two_nodes() {
|
||||
fn test_codegen_matmul() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(MatmulNode::new(
|
||||
|
@ -99,4 +151,104 @@ mod tests {
|
|||
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codegen_matmul_matrix_vector() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(MatmulNode::new(
|
||||
TensorType::new_float("tensor1", 4),
|
||||
TensorType::new_float("tensor2", 1),
|
||||
TensorType::new_float("tensor3", 3),
|
||||
));
|
||||
|
||||
graph.register_input_output(
|
||||
vec!["tensor1".to_string(), "tensor2".to_string()],
|
||||
vec!["tensor3".to_string()],
|
||||
);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
tensor1: Tensor<B, 4>,
|
||||
tensor2: Tensor<B, 1>
|
||||
) -> Tensor<B, 3> {
|
||||
let tensor3 = tensor1.matmul(tensor2.unsqueeze_dims(&[0, -1, 0])).squeeze(3usize);
|
||||
|
||||
tensor3
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codegen_matmul_vector_matrix() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(MatmulNode::new(
|
||||
TensorType::new_float("tensor1", 1),
|
||||
TensorType::new_float("tensor2", 4),
|
||||
TensorType::new_float("tensor3", 3),
|
||||
));
|
||||
|
||||
graph.register_input_output(
|
||||
vec!["tensor1".to_string(), "tensor2".to_string()],
|
||||
vec!["tensor3".to_string()],
|
||||
);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
tensor1: Tensor<B, 1>,
|
||||
tensor2: Tensor<B, 4>
|
||||
) -> Tensor<B, 3> {
|
||||
let tensor3 = tensor1.unsqueeze_dims(&[0, 0, 0]).matmul(tensor2).squeeze(2usize);
|
||||
|
||||
tensor3
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use core::cmp::max;
|
||||
use core::panic;
|
||||
|
||||
use protobuf::Enum;
|
||||
|
@ -35,6 +36,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
|
|||
NodeType::Linear => linear_update_outputs(node),
|
||||
NodeType::Log => same_as_input(node),
|
||||
NodeType::LogSoftmax => same_as_input(node),
|
||||
NodeType::MatMul => matmul_update_outputs(node),
|
||||
NodeType::MaxPool2d => same_as_input(node),
|
||||
NodeType::Mul => same_as_input(node),
|
||||
NodeType::Neg => same_as_input(node),
|
||||
|
@ -442,6 +444,28 @@ fn conv_transpose2d_update_outputs(node: &mut Node) {
|
|||
}
|
||||
}
|
||||
|
||||
fn matmul_update_outputs(node: &mut Node) {
|
||||
// NOTE: matmul only supported for float tensors
|
||||
match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) {
|
||||
(ArgType::Tensor(a), ArgType::Tensor(b)) => {
|
||||
// With broadcasting support, output dim has to be computed based on the inputs
|
||||
let mut out_dim = max(a.dim, b.dim);
|
||||
|
||||
// Matrix-vector or vector-matrix product
|
||||
if (a.dim >= 2 && b.dim == 1) || (a.dim == 1 && b.dim >= 2) {
|
||||
out_dim -= 1;
|
||||
}
|
||||
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||
elem_type: a.elem_type.clone(),
|
||||
dim: out_dim,
|
||||
shape: a.shape.clone(),
|
||||
});
|
||||
}
|
||||
_ => panic!("Only tensor input is valid"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Infers the shape of a ReduceMax node and replaces the shape of the output tensor.
|
||||
fn reduce_max_update_outputs(node: &mut Node) {
|
||||
if node.inputs.len() != 1 {
|
||||
|
|
Loading…
Reference in New Issue