diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index b5cfd5519..a9517c04c 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -191,7 +191,7 @@ represent the corresponding Burn Op. | [Tanh][182] | ✅ | ✅ | | [TfIdfVectorizer][183] | ❌ | ❌ | | [ThresholdedRelu][184] | ❌ | ❌ | -| [Tile][185] | ❌ | ✅ | +| [Tile][185] | ✅ | ✅ | | [TopK][186] | ❌ | ✅ | | [Transpose][187] | ✅ | ✅ | | [Trilu][188] | ❌ | ✅ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index a41f3df3c..457019318 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -93,6 +93,7 @@ fn main() { .input("tests/sum/sum.onnx") .input("tests/sum/sum_int.onnx") .input("tests/tanh/tanh.onnx") + .input("tests/tile/tile.onnx") .input("tests/transpose/transpose.onnx") .input("tests/unsqueeze/unsqueeze.onnx") .input("tests/unsqueeze/unsqueeze_opset11.onnx") diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index a108aaf4f..6bdaaaab1 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -102,6 +102,7 @@ include_models!( sum, sum_int, tanh, + tile, transpose, unsqueeze, unsqueeze_opset11, @@ -1712,6 +1713,23 @@ mod tests { output.to_data().assert_eq(&expected, true); } + #[test] + fn tile() { + let device = Default::default(); + let model: tile::Model = tile::Model::new(&device); + + let input = Tensor::::from_floats([[1., 2.], [3., 4.]], &device); + let output = model.forward(input).to_data(); + let expected = TensorData::from([ + [1.0f32, 2.0f32, 1.0f32, 2.0f32], + [3.0f32, 4.0f32, 3.0f32, 4.0f32], + [1.0f32, 2.0f32, 1.0f32, 2.0f32], + [3.0f32, 4.0f32, 3.0f32, 4.0f32], + ]); + + output.assert_eq(&expected, true); + } + #[test] fn unsqueeze() { let device = Default::default(); diff --git a/crates/burn-import/onnx-tests/tests/tile/tile.onnx b/crates/burn-import/onnx-tests/tests/tile/tile.onnx new file mode 100644 index 000000000..a2162746e Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/tile/tile.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/tile/tile.py b/crates/burn-import/onnx-tests/tests/tile/tile.py new file mode 100644 index 000000000..5067bdb80 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/tile/tile.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +import onnx +import onnx.helper +import onnx.checker + + +def build_model(): + # Define the input tensor as a graph input + input_tensor = onnx.helper.make_tensor_value_info( + name="input_tensor", + elem_type=onnx.TensorProto.FLOAT, + shape=[2, 2] + ) + + output_tensor = onnx.helper.make_tensor_value_info( + name="output_tensor", + elem_type=onnx.TensorProto.FLOAT, + shape=[4, 4] + ) + + # Define the shape tensor for tiling as an initializer + shape_tensor = onnx.helper.make_tensor( + name="shape_tensor", + data_type=onnx.TensorProto.INT64, + dims=[2], + vals=[2, 2] + ) + # Create the Tile node + tile_node = onnx.helper.make_node( + "Tile", + inputs=["input_tensor", "shape_tensor"], + outputs=["output_tensor"] + ) + + # Build the graph + graph = onnx.helper.make_graph( + nodes=[tile_node], + name="main_graph", + inputs=[input_tensor], + outputs=[output_tensor], + initializer=[shape_tensor] + ) + + # Build the model + model = onnx.helper.make_model( + graph, + ir_version=8, + opset_imports=[onnx.helper.make_operatorsetid("", 16)] + ) + + return model + + +def main(): + onnx_model = build_model() + + onnx_model = onnx.shape_inference.infer_shapes(onnx_model) + + file_name = "tile.onnx" + onnx.save(onnx_model, file_name) + onnx.checker.check_model(onnx_model) + print(f"ONNX model saved as {file_name}") + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index 46e1d5e1a..a1c9103b4 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -11,7 +11,7 @@ use super::{ max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, - squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, + squeeze::SqueezeNode, sum::SumNode, tile::TileNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, }; use crate::burn::{BurnImports, Scope, Type}; use burn::backend::NdArray; @@ -113,6 +113,7 @@ pub enum Node { Slice(SliceNode), Squeeze(SqueezeNode), Sum(SumNode), + Tile(TileNode), Unary(UnaryNode), Unsqueeze(UnsqueezeNode), Where(WhereNode), @@ -160,6 +161,7 @@ macro_rules! match_all { Node::Slice(node) => $func(node), Node::Squeeze(node) => $func(node), Node::Sum(node) => $func(node), + Node::Tile(node) => $func(node), Node::Unary(node) => $func(node), Node::Unsqueeze(node) => $func(node), Node::Where(node) => $func(node), @@ -215,6 +217,7 @@ impl Node { Node::Slice(_) => "slice", Node::Squeeze(_) => "squeeze", Node::Sum(_) => "add", + Node::Tile(_) => "tile", Node::Unary(unary) => unary.kind.as_str(), Node::Unsqueeze(_) => "unsqueeze", Node::Where(_) => "where", diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index 875e3e5af..ee294ddfd 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -36,6 +36,7 @@ pub(crate) mod resize; pub(crate) mod slice; pub(crate) mod squeeze; pub(crate) mod sum; +pub(crate) mod tile; pub(crate) mod unary; pub(crate) mod unsqueeze; pub(crate) use base::*; diff --git a/crates/burn-import/src/burn/node/tile.rs b/crates/burn-import/src/burn/node/tile.rs new file mode 100644 index 000000000..cf56f7ad2 --- /dev/null +++ b/crates/burn-import/src/burn/node/tile.rs @@ -0,0 +1,97 @@ +use super::{Node, NodeCodegen}; +use crate::burn::{Scope, TensorType, ToTokens, Type}; +use burn::config::Config; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Config, Debug)] +pub struct TileConfig { + pub repeats: Vec, +} + +#[derive(Debug, Clone, new)] +pub struct TileNode { + pub input: TensorType, + pub output: TensorType, + pub config: TileConfig, +} + +impl NodeCodegen for TileNode { + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + + let repeats = self.config.repeats.iter().map(|r| r.to_tokens()); + + quote! { + let #output = #input.repeat(&[#(#repeats),*]); + } + } + + fn into_node(self) -> Node { + Node::Tile(self) + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{test::assert_tokens, tile::TileConfig, tile::TileNode}, + TensorType, + }; + + #[test] + fn test_codegen_tile() { + let mut graph = BurnGraph::::default(); + let config = TileConfig::new(vec![2, 3, 4]); + graph.register(TileNode::new( + TensorType::new_float("input", 3), + TensorType::new_float("output", 3), + config, + )); + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = input.repeat(&[2, 3, 4]); + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 586493ebe..d701150c8 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -7,7 +7,7 @@ use burn::nn::{ PaddingConfig2d, PaddingConfig3d, }; -use crate::burn::node::pad::PadConfig; +use crate::burn::node::{pad::PadConfig, tile::TileConfig}; use onnx_ir::ir::{ArgType, AttributeValue, Data, Node}; /// Create a Conv1dConfig from the attributes of the node @@ -745,6 +745,26 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { ) } +/// Create a TileConfig from the attributes of the node +pub fn tile_config(node: &Node) -> TileConfig { + let repeat = node + .inputs + .get(1) + .map(|input| { + if let Some(data) = &input.value { + data.clone() + .into_i64s() + .iter() + .map(|&x| x as usize) + .collect() + } else { + vec![] + } + }) + .unwrap_or_default(); + TileConfig::new(repeat) +} + /// Create a PadConfig from the attributes of the node pub fn pad_config(node: &Node) -> PadConfig { fn get_pads(node: &Node) -> Vec { diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 39e9a428c..18d2981e8 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -50,6 +50,7 @@ use crate::{ slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, + tile::TileNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, }, @@ -66,7 +67,8 @@ use super::op_configuration::{ hard_sigmoid_config, layer_norm_config, leaky_relu_config, linear_config, log_softmax_config, max_pool1d_config, max_pool2d_config, pad_config, reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config, resize_config, - shape_config, slice_config, softmax_config, squeeze_config, transpose_config, unsqueeze_config, + shape_config, slice_config, softmax_config, squeeze_config, tile_config, transpose_config, + unsqueeze_config, }; use onnx_ir::{ convert_constant_value, @@ -335,6 +337,7 @@ impl ParsedOnnxGraph { NodeType::Sign => graph.register(Self::sign_conversion(node)), NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)), NodeType::RandomUniform => graph.register(Self::random_uniform_conversion(node)), + NodeType::Tile => graph.register(Self::tile_conversion(node)), NodeType::RandomNormal => graph.register(Self::random_normal_conversion(node)), NodeType::ConstantOfShape => { graph.register(Self::constant_of_shape_conversion(node)) @@ -1167,6 +1170,14 @@ impl ParsedOnnxGraph { SqueezeNode::new(input, output, axes) } + + fn tile_conversion(node: Node) -> TileNode { + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); + let config = tile_config(&node); + + TileNode::new(input, output, config) + } } /// Extract data from node states and convert it to `TensorData`.