ONNX Tile operation (#2092)

* renaming repeat to repeat_dim

* implementing repeat function

* renaming repeat files to repeat_dim

* renaming part 2

* renaming part 3

* renaming part 4

* renaming part 5

* adding test file

* adding unit test

* adding rust book documentation

* adding function args doc

* fixing tests

* changing repeat api to match pytorch equivalent

* fixing clippy error

* implementing tile onnx file

* temp

* working implementation and test

* working e2e test

* adding new supported onnx operation to the md file
This commit is contained in:
mepatrick73 2024-08-07 17:43:59 -04:00 committed by GitHub
parent 6b61ad5a61
commit d770b1f470
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 222 additions and 4 deletions

View File

@ -191,7 +191,7 @@ represent the corresponding Burn Op.
| [Tanh][182] | ✅ | ✅ |
| [TfIdfVectorizer][183] | ❌ | ❌ |
| [ThresholdedRelu][184] | ❌ | ❌ |
| [Tile][185] | | ✅ |
| [Tile][185] | | ✅ |
| [TopK][186] | ❌ | ✅ |
| [Transpose][187] | ✅ | ✅ |
| [Trilu][188] | ❌ | ✅ |

View File

@ -93,6 +93,7 @@ fn main() {
.input("tests/sum/sum.onnx")
.input("tests/sum/sum_int.onnx")
.input("tests/tanh/tanh.onnx")
.input("tests/tile/tile.onnx")
.input("tests/transpose/transpose.onnx")
.input("tests/unsqueeze/unsqueeze.onnx")
.input("tests/unsqueeze/unsqueeze_opset11.onnx")

View File

@ -102,6 +102,7 @@ include_models!(
sum,
sum_int,
tanh,
tile,
transpose,
unsqueeze,
unsqueeze_opset11,
@ -1712,6 +1713,23 @@ mod tests {
output.to_data().assert_eq(&expected, true);
}
#[test]
fn tile() {
let device = Default::default();
let model: tile::Model<Backend> = tile::Model::new(&device);
let input = Tensor::<Backend, 2>::from_floats([[1., 2.], [3., 4.]], &device);
let output = model.forward(input).to_data();
let expected = TensorData::from([
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
]);
output.assert_eq(&expected, true);
}
#[test]
fn unsqueeze() {
let device = Default::default();

Binary file not shown.

View File

@ -0,0 +1,67 @@
#!/usr/bin/env python3
import onnx
import onnx.helper
import onnx.checker
def build_model():
# Define the input tensor as a graph input
input_tensor = onnx.helper.make_tensor_value_info(
name="input_tensor",
elem_type=onnx.TensorProto.FLOAT,
shape=[2, 2]
)
output_tensor = onnx.helper.make_tensor_value_info(
name="output_tensor",
elem_type=onnx.TensorProto.FLOAT,
shape=[4, 4]
)
# Define the shape tensor for tiling as an initializer
shape_tensor = onnx.helper.make_tensor(
name="shape_tensor",
data_type=onnx.TensorProto.INT64,
dims=[2],
vals=[2, 2]
)
# Create the Tile node
tile_node = onnx.helper.make_node(
"Tile",
inputs=["input_tensor", "shape_tensor"],
outputs=["output_tensor"]
)
# Build the graph
graph = onnx.helper.make_graph(
nodes=[tile_node],
name="main_graph",
inputs=[input_tensor],
outputs=[output_tensor],
initializer=[shape_tensor]
)
# Build the model
model = onnx.helper.make_model(
graph,
ir_version=8,
opset_imports=[onnx.helper.make_operatorsetid("", 16)]
)
return model
def main():
onnx_model = build_model()
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
file_name = "tile.onnx"
onnx.save(onnx_model, file_name)
onnx.checker.check_model(onnx_model)
print(f"ONNX model saved as {file_name}")
if __name__ == "__main__":
main()

View File

@ -11,7 +11,7 @@ use super::{
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
squeeze::SqueezeNode, sum::SumNode, tile::TileNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
@ -113,6 +113,7 @@ pub enum Node<PS: PrecisionSettings> {
Slice(SliceNode),
Squeeze(SqueezeNode),
Sum(SumNode),
Tile(TileNode),
Unary(UnaryNode),
Unsqueeze(UnsqueezeNode),
Where(WhereNode),
@ -160,6 +161,7 @@ macro_rules! match_all {
Node::Slice(node) => $func(node),
Node::Squeeze(node) => $func(node),
Node::Sum(node) => $func(node),
Node::Tile(node) => $func(node),
Node::Unary(node) => $func(node),
Node::Unsqueeze(node) => $func(node),
Node::Where(node) => $func(node),
@ -215,6 +217,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Slice(_) => "slice",
Node::Squeeze(_) => "squeeze",
Node::Sum(_) => "add",
Node::Tile(_) => "tile",
Node::Unary(unary) => unary.kind.as_str(),
Node::Unsqueeze(_) => "unsqueeze",
Node::Where(_) => "where",

View File

@ -36,6 +36,7 @@ pub(crate) mod resize;
pub(crate) mod slice;
pub(crate) mod squeeze;
pub(crate) mod sum;
pub(crate) mod tile;
pub(crate) mod unary;
pub(crate) mod unsqueeze;
pub(crate) use base::*;

View File

@ -0,0 +1,97 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use burn::config::Config;
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;
#[derive(Config, Debug)]
pub struct TileConfig {
pub repeats: Vec<usize>,
}
#[derive(Debug, Clone, new)]
pub struct TileNode {
pub input: TensorType,
pub output: TensorType,
pub config: TileConfig,
}
impl<PS: PrecisionSettings> NodeCodegen<PS> for TileNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let repeats = self.config.repeats.iter().map(|r| r.to_tokens());
quote! {
let #output = #input.repeat(&[#(#repeats),*]);
}
}
fn into_node(self) -> Node<PS> {
Node::Tile(self)
}
}
#[cfg(test)]
mod tests {
use burn::record::FullPrecisionSettings;
use super::*;
use crate::burn::{
graph::BurnGraph,
node::{test::assert_tokens, tile::TileConfig, tile::TileNode},
TensorType,
};
#[test]
fn test_codegen_tile() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
let config = TileConfig::new(vec![2, 3, 4]);
graph.register(TileNode::new(
TensorType::new_float("input", 3),
TensorType::new_float("output", 3),
config,
));
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model<B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let output = input.repeat(&[2, 3, 4]);
output
}
}
};
assert_tokens(graph.codegen(), expected);
}
}

View File

@ -7,7 +7,7 @@ use burn::nn::{
PaddingConfig2d, PaddingConfig3d,
};
use crate::burn::node::pad::PadConfig;
use crate::burn::node::{pad::PadConfig, tile::TileConfig};
use onnx_ir::ir::{ArgType, AttributeValue, Data, Node};
/// Create a Conv1dConfig from the attributes of the node
@ -745,6 +745,26 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) {
)
}
/// Create a TileConfig from the attributes of the node
pub fn tile_config(node: &Node) -> TileConfig {
let repeat = node
.inputs
.get(1)
.map(|input| {
if let Some(data) = &input.value {
data.clone()
.into_i64s()
.iter()
.map(|&x| x as usize)
.collect()
} else {
vec![]
}
})
.unwrap_or_default();
TileConfig::new(repeat)
}
/// Create a PadConfig from the attributes of the node
pub fn pad_config(node: &Node) -> PadConfig {
fn get_pads(node: &Node) -> Vec<usize> {

View File

@ -50,6 +50,7 @@ use crate::{
slice::SliceNode,
squeeze::SqueezeNode,
sum::SumNode,
tile::TileNode,
unary::UnaryNode,
unsqueeze::UnsqueezeNode,
},
@ -66,7 +67,8 @@ use super::op_configuration::{
hard_sigmoid_config, layer_norm_config, leaky_relu_config, linear_config, log_softmax_config,
max_pool1d_config, max_pool2d_config, pad_config, reduce_max_config, reduce_mean_config,
reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config, resize_config,
shape_config, slice_config, softmax_config, squeeze_config, transpose_config, unsqueeze_config,
shape_config, slice_config, softmax_config, squeeze_config, tile_config, transpose_config,
unsqueeze_config,
};
use onnx_ir::{
convert_constant_value,
@ -335,6 +337,7 @@ impl ParsedOnnxGraph {
NodeType::Sign => graph.register(Self::sign_conversion(node)),
NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)),
NodeType::RandomUniform => graph.register(Self::random_uniform_conversion(node)),
NodeType::Tile => graph.register(Self::tile_conversion(node)),
NodeType::RandomNormal => graph.register(Self::random_normal_conversion(node)),
NodeType::ConstantOfShape => {
graph.register(Self::constant_of_shape_conversion(node))
@ -1167,6 +1170,14 @@ impl ParsedOnnxGraph {
SqueezeNode::new(input, output, axes)
}
fn tile_conversion(node: Node) -> TileNode {
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let config = tile_config(&node);
TileNode::new(input, output, config)
}
}
/// Extract data from node states and convert it to `TensorData`.