mirror of https://github.com/tracel-ai/burn.git
Fusion mix precision (#2247)
This commit is contained in:
parent
fc311323d9
commit
a567c6e888
|
@ -612,6 +612,7 @@ dependencies = [
|
|||
"burn-common",
|
||||
"burn-tensor",
|
||||
"derive-new",
|
||||
"half",
|
||||
"hashbrown 0.14.5",
|
||||
"log",
|
||||
"serde",
|
||||
|
|
|
@ -23,6 +23,7 @@ derive-new = {workspace = true }
|
|||
spin = { workspace = true }
|
||||
log = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
half = { workspace = true }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
|
|
|
@ -62,7 +62,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Float(FloatOperationDescription::Random(desc.clone())),
|
||||
OperationDescription::Float(
|
||||
FloatElem::<Self>::dtype(),
|
||||
FloatOperationDescription::Random(desc.clone()),
|
||||
),
|
||||
RandomOps::<B, D>::new(desc, device.clone()),
|
||||
);
|
||||
|
||||
|
@ -92,7 +95,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
let desc = out.to_description_out();
|
||||
client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Zeros(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Zeros(desc.clone()),
|
||||
),
|
||||
ZerosOps::<B, D>::new(desc, device.clone()),
|
||||
);
|
||||
|
||||
|
@ -122,7 +128,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
let desc = out.to_description_out();
|
||||
client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Ones(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Ones(desc.clone()),
|
||||
),
|
||||
OnesOps::<B, D>::new(desc, device.clone()),
|
||||
);
|
||||
|
||||
|
@ -158,7 +167,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
let desc = (out.to_description_out(), fill_value.elem::<f32>());
|
||||
client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Full(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Full(desc.clone()),
|
||||
),
|
||||
FullOps::<B, D>::new(desc.0, desc.1, device.clone()),
|
||||
);
|
||||
|
||||
|
@ -226,7 +238,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Float(FloatOperationDescription::IntoInt(desc.clone())),
|
||||
OperationDescription::Float(
|
||||
FloatElem::<Self>::dtype(),
|
||||
FloatOperationDescription::IntoInt(desc.clone()),
|
||||
),
|
||||
IntoIntOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -267,7 +282,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Add(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Add(desc.clone()),
|
||||
),
|
||||
AddOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -292,9 +310,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::AddScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::AddScalar(desc.clone()),
|
||||
),
|
||||
AddOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -334,7 +353,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Clamp(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Clamp(desc.clone()),
|
||||
),
|
||||
ClampOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -361,7 +383,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Sub(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Sub(desc.clone()),
|
||||
),
|
||||
SubOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -386,9 +411,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::SubScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::SubScalar(desc.clone()),
|
||||
),
|
||||
SubOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -415,7 +441,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Mul(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Mul(desc.clone()),
|
||||
),
|
||||
MulOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -440,9 +469,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::MulScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MulScalar(desc.clone()),
|
||||
),
|
||||
MulOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -469,7 +499,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Div(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Div(desc.clone()),
|
||||
),
|
||||
DivOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -494,9 +527,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::DivScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::DivScalar(desc.clone()),
|
||||
),
|
||||
DivOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -521,9 +555,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::RemScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::RemScalar(desc.clone()),
|
||||
),
|
||||
ModOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -554,7 +589,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::Float(FloatOperationDescription::Matmul(desc.clone())),
|
||||
OperationDescription::Float(
|
||||
FloatElem::<Self>::dtype(),
|
||||
FloatOperationDescription::Matmul(desc.clone()),
|
||||
),
|
||||
MatmulOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -680,7 +718,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Gather(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Gather(desc.clone()),
|
||||
),
|
||||
GatherOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -729,7 +770,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2, stream_3],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Scatter(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Scatter(desc.clone()),
|
||||
),
|
||||
ScatterOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -773,7 +817,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Select(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Select(desc.clone()),
|
||||
),
|
||||
SelectOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -821,9 +868,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2, stream_3],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::SelectAssign(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::SelectAssign(desc.clone()),
|
||||
),
|
||||
SelectAssignOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -966,9 +1014,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2, stream_3],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::MaskWhere(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MaskWhere(desc.clone()),
|
||||
),
|
||||
MaskWhereOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1011,7 +1060,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::MaskFill(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MaskFill(desc.clone()),
|
||||
),
|
||||
MaskFillOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1062,9 +1114,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::EqualElem(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::EqualElem(desc.clone()),
|
||||
),
|
||||
EqualElemOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1090,7 +1143,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Greater(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Greater(desc.clone()),
|
||||
),
|
||||
GreaterOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1115,9 +1171,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::GreaterElem(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::GreaterElem(desc.clone()),
|
||||
),
|
||||
GreaterElemOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1143,9 +1200,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::GreaterEqual(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::GreaterEqual(desc.clone()),
|
||||
),
|
||||
GreaterEqualOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1170,9 +1228,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::GreaterEqualElem(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::GreaterEqualElem(desc.clone()),
|
||||
),
|
||||
GreaterEqualElemOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1198,7 +1257,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Lower(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Lower(desc.clone()),
|
||||
),
|
||||
LowerOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1223,9 +1285,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::LowerElem(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::LowerElem(desc.clone()),
|
||||
),
|
||||
LowerElemOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1251,9 +1314,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::LowerEqual(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::LowerEqual(desc.clone()),
|
||||
),
|
||||
LowerEqualOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1278,9 +1342,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::LowerEqualElem(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::LowerEqualElem(desc.clone()),
|
||||
),
|
||||
LowerEqualElemOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1301,7 +1366,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Sum(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Sum(desc.clone()),
|
||||
),
|
||||
SumOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1328,7 +1396,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::SumDim(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::SumDim(desc.clone()),
|
||||
),
|
||||
SumDimOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1349,7 +1420,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Mean(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Mean(desc.clone()),
|
||||
),
|
||||
MeanOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1376,7 +1450,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::MeanDim(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MeanDim(desc.clone()),
|
||||
),
|
||||
MeanDimOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1397,7 +1474,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Float(FloatOperationDescription::Exp(desc.clone())),
|
||||
OperationDescription::Float(
|
||||
FloatElem::<Self>::dtype(),
|
||||
FloatOperationDescription::Exp(desc.clone()),
|
||||
),
|
||||
ExpOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1418,7 +1498,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Float(FloatOperationDescription::Log(desc.clone())),
|
||||
OperationDescription::Float(
|
||||
FloatElem::<Self>::dtype(),
|
||||
FloatOperationDescription::Log(desc.clone()),
|
||||
),
|
||||
LogOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1439,7 +1522,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Float(FloatOperationDescription::Log1p(desc.clone())),
|
||||
OperationDescription::Float(
|
||||
FloatElem::<Self>::dtype(),
|
||||
FloatOperationDescription::Log1p(desc.clone()),
|
||||
),
|
||||
Log1pOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1464,7 +1550,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Float(FloatOperationDescription::PowfScalar(desc.clone())),
|
||||
OperationDescription::Float(
|
||||
FloatElem::<Self>::dtype(),
|
||||
FloatOperationDescription::PowfScalar(desc.clone()),
|
||||
),
|
||||
PowfOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1485,7 +1574,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Float(FloatOperationDescription::Sqrt(desc.clone())),
|
||||
OperationDescription::Float(
|
||||
FloatElem::<Self>::dtype(),
|
||||
FloatOperationDescription::Sqrt(desc.clone()),
|
||||
),
|
||||
SqrtOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1506,7 +1598,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Abs(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Abs(desc.clone()),
|
||||
),
|
||||
AbsOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1527,7 +1622,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Float(FloatOperationDescription::Cos(desc.clone())),
|
||||
OperationDescription::Float(
|
||||
FloatElem::<Self>::dtype(),
|
||||
FloatOperationDescription::Cos(desc.clone()),
|
||||
),
|
||||
CosOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1548,7 +1646,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Float(FloatOperationDescription::Sin(desc.clone())),
|
||||
OperationDescription::Float(
|
||||
FloatElem::<Self>::dtype(),
|
||||
FloatOperationDescription::Sin(desc.clone()),
|
||||
),
|
||||
SinOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1569,7 +1670,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Float(FloatOperationDescription::Tanh(desc.clone())),
|
||||
OperationDescription::Float(
|
||||
FloatElem::<Self>::dtype(),
|
||||
FloatOperationDescription::Tanh(desc.clone()),
|
||||
),
|
||||
TanhOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1589,7 +1693,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Float(FloatOperationDescription::Recip(desc.clone())),
|
||||
OperationDescription::Float(
|
||||
FloatElem::<Self>::dtype(),
|
||||
FloatOperationDescription::Recip(desc.clone()),
|
||||
),
|
||||
Recip::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1610,7 +1717,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::Float(FloatOperationDescription::Erf(desc.clone())),
|
||||
OperationDescription::Float(
|
||||
FloatElem::<Self>::dtype(),
|
||||
FloatOperationDescription::Erf(desc.clone()),
|
||||
),
|
||||
TanhOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1689,7 +1799,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::ArgMax(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::ArgMax(desc.clone()),
|
||||
),
|
||||
ArgMaxOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1759,7 +1872,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::ArgMin(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::ArgMin(desc.clone()),
|
||||
),
|
||||
ArgMinOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1780,7 +1896,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Max(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Max(desc.clone()),
|
||||
),
|
||||
MaxOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1807,7 +1926,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::MaxDim(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MaxDim(desc.clone()),
|
||||
),
|
||||
MaxDimOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1849,9 +1971,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::MaxDimWithIndices(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MaxDimWithIndices(desc.clone()),
|
||||
),
|
||||
MaxDimWithIndicesOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1872,7 +1995,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Min(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Min(desc.clone()),
|
||||
),
|
||||
MinOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1899,7 +2025,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::MinDim(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MinDim(desc.clone()),
|
||||
),
|
||||
MinDimOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1941,9 +2070,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::MinDimWithIndices(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MinDimWithIndices(desc.clone()),
|
||||
),
|
||||
MinDimWithIndicesOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1970,7 +2100,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Powf(desc.clone())),
|
||||
OperationDescription::NumericFloat(
|
||||
FloatElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Powf(desc.clone()),
|
||||
),
|
||||
PowOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
|
|
@ -255,7 +255,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2, stream_3],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::MaskWhere(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MaskWhere(desc.clone()),
|
||||
),
|
||||
MaskWhereOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -298,7 +301,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::MaskFill(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MaskFill(desc.clone()),
|
||||
),
|
||||
MaskFillOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -340,7 +346,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Gather(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Gather(desc.clone()),
|
||||
),
|
||||
GatherOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -387,7 +396,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2, stream_3],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Scatter(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Scatter(desc.clone()),
|
||||
),
|
||||
ScatterOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -431,7 +443,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Select(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Select(desc.clone()),
|
||||
),
|
||||
SelectOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -478,9 +493,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2, stream_3],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::SelectAssign(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::SelectAssign(desc.clone()),
|
||||
),
|
||||
SelectAssignOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -580,7 +596,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::EqualElem(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::EqualElem(desc.clone()),
|
||||
),
|
||||
EqualElemOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -606,7 +625,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Greater(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Greater(desc.clone()),
|
||||
),
|
||||
GreaterOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -631,9 +653,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::GreaterElem(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::GreaterElem(desc.clone()),
|
||||
),
|
||||
GreaterElemOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -659,9 +682,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::GreaterEqual(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::GreaterEqual(desc.clone()),
|
||||
),
|
||||
GreaterEqualOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -686,9 +710,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::GreaterEqualElem(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::GreaterEqualElem(desc.clone()),
|
||||
),
|
||||
GreaterEqualElemOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -714,7 +739,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Lower(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Lower(desc.clone()),
|
||||
),
|
||||
LowerOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -739,7 +767,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::LowerElem(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::LowerElem(desc.clone()),
|
||||
),
|
||||
LowerElemOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -765,7 +796,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::LowerEqual(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::LowerEqual(desc.clone()),
|
||||
),
|
||||
LowerEqualOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -790,9 +824,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::LowerEqualElem(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::LowerEqualElem(desc.clone()),
|
||||
),
|
||||
LowerEqualElemOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -819,7 +854,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::Add(desc.clone())),
|
||||
repr::OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Add(desc.clone()),
|
||||
),
|
||||
AddOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -844,9 +882,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::AddScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
repr::OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::AddScalar(desc.clone()),
|
||||
),
|
||||
AddOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -873,7 +912,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::Sub(desc.clone())),
|
||||
repr::OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Sub(desc.clone()),
|
||||
),
|
||||
SubOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -898,9 +940,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::SubScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
repr::OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::SubScalar(desc.clone()),
|
||||
),
|
||||
SubOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -927,7 +970,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::Mul(desc.clone())),
|
||||
repr::OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Mul(desc.clone()),
|
||||
),
|
||||
MulOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -952,9 +998,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::MulScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
repr::OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MulScalar(desc.clone()),
|
||||
),
|
||||
MulOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -981,7 +1028,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream_1, stream_2],
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::Div(desc.clone())),
|
||||
repr::OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Div(desc.clone()),
|
||||
),
|
||||
DivOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1006,9 +1056,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::DivScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
repr::OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::DivScalar(desc.clone()),
|
||||
),
|
||||
DivOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1033,9 +1084,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
repr::OperationDescription::NumericInt(NumericOperationDescription::RemScalar(
|
||||
desc.clone(),
|
||||
)),
|
||||
repr::OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::RemScalar(desc.clone()),
|
||||
),
|
||||
ModOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1064,7 +1116,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
let desc = out.to_description_out();
|
||||
client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Zeros(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Zeros(desc.clone()),
|
||||
),
|
||||
ZerosOps::<B, D>::new(desc, device.clone()),
|
||||
);
|
||||
|
||||
|
@ -1094,7 +1149,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
let desc = out.to_description_out();
|
||||
client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Ones(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Ones(desc.clone()),
|
||||
),
|
||||
OnesOps::<B, D>::new(desc, device.clone()),
|
||||
);
|
||||
|
||||
|
@ -1115,7 +1173,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Sum(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Sum(desc.clone()),
|
||||
),
|
||||
SumOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1139,7 +1200,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::SumDim(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::SumDim(desc.clone()),
|
||||
),
|
||||
SumDimOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1160,7 +1224,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Prod(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Prod(desc.clone()),
|
||||
),
|
||||
ProdOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1184,7 +1251,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::ProdDim(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::ProdDim(desc.clone()),
|
||||
),
|
||||
ProdDimOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1205,7 +1275,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Mean(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Mean(desc.clone()),
|
||||
),
|
||||
MeanOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1229,7 +1302,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::MeanDim(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MeanDim(desc.clone()),
|
||||
),
|
||||
MeanDimOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1253,7 +1329,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::ArgMax(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::ArgMax(desc.clone()),
|
||||
),
|
||||
ArgMaxOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1277,7 +1356,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::ArgMin(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::ArgMin(desc.clone()),
|
||||
),
|
||||
ArgMinOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1316,7 +1398,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Clamp(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Clamp(desc.clone()),
|
||||
),
|
||||
ClampOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1337,7 +1422,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Abs(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Abs(desc.clone()),
|
||||
),
|
||||
AbsOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1433,7 +1521,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Max(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Max(desc.clone()),
|
||||
),
|
||||
MaxOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1457,7 +1548,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::MaxDim(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MaxDim(desc.clone()),
|
||||
),
|
||||
MaxDimOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1498,9 +1592,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::MaxDimWithIndices(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MaxDimWithIndices(desc.clone()),
|
||||
),
|
||||
MaxDimWithIndicesOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1521,7 +1616,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::Min(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::Min(desc.clone()),
|
||||
),
|
||||
MinOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1545,7 +1643,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::MinDim(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MinDim(desc.clone()),
|
||||
),
|
||||
MinDimOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1586,9 +1687,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::MinDimWithIndices(
|
||||
desc.clone(),
|
||||
)),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::MinDimWithIndices(desc.clone()),
|
||||
),
|
||||
MinDimWithIndicesOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
|
@ -1626,7 +1728,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
client.register(
|
||||
vec![stream],
|
||||
OperationDescription::NumericInt(NumericOperationDescription::IntRandom(desc.clone())),
|
||||
OperationDescription::NumericInt(
|
||||
IntElem::<Self>::dtype(),
|
||||
NumericOperationDescription::IntRandom(desc.clone()),
|
||||
),
|
||||
IntRandomOps::<B, D>::new(desc, device.clone()),
|
||||
);
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use burn_tensor::{repr::*, Element, ElementConversion};
|
||||
use burn_tensor::{repr::*, DType, Element, ElementConversion};
|
||||
use half::{bf16, f16};
|
||||
use hashbrown::HashMap;
|
||||
|
||||
/// The context contains the relative graph tensor mapping so that a relative tensor id can be
|
||||
|
@ -13,8 +14,12 @@ pub struct Context<'a, H> {
|
|||
pub tensors: &'a HashMap<TensorId, TensorDescription>,
|
||||
/// Handle container to retrieve tensors based on their description.
|
||||
pub handles: &'a mut HandleContainer<H>,
|
||||
/// Float scalars found in the graph in the order they appeared.
|
||||
pub scalar_floats: &'a Vec<f32>,
|
||||
/// F32 scalars found in the graph in the order they appeared.
|
||||
pub scalar_f32: &'a Vec<f32>,
|
||||
/// F16 scalars found in the graph in the order they appeared.
|
||||
pub scalar_f16: &'a Vec<f16>,
|
||||
/// BF16 scalars found in the graph in the order they appeared.
|
||||
pub scalar_bf16: &'a Vec<bf16>,
|
||||
/// Int scalars found in the graph in the order they appeared.
|
||||
pub scalar_ints: &'a Vec<i32>,
|
||||
}
|
||||
|
@ -26,7 +31,9 @@ pub(crate) struct OperationConverter {
|
|||
/// Only useful to create new shape ID.
|
||||
/// You should use tensor descriptions to retrieve the proper shape.
|
||||
shapes_global2relative: HashMap<usize, usize>,
|
||||
scalar_floats: Vec<f32>,
|
||||
scalar_f32: Vec<f32>,
|
||||
scalar_f16: Vec<f16>,
|
||||
scalar_bf16: Vec<bf16>,
|
||||
scalar_ints: Vec<i32>,
|
||||
}
|
||||
|
||||
|
@ -45,7 +52,9 @@ impl OperationConverter {
|
|||
Context {
|
||||
handles,
|
||||
tensors: &self.tensors_relative2global,
|
||||
scalar_floats: &self.scalar_floats,
|
||||
scalar_f32: &self.scalar_f32,
|
||||
scalar_f16: &self.scalar_f16,
|
||||
scalar_bf16: &self.scalar_bf16,
|
||||
scalar_ints: &self.scalar_ints,
|
||||
}
|
||||
}
|
||||
|
@ -54,12 +63,20 @@ impl OperationConverter {
|
|||
self.tensors_relative2global.clear();
|
||||
self.tensors_global2relative.clear();
|
||||
self.shapes_global2relative.clear();
|
||||
self.scalar_floats.clear();
|
||||
self.scalar_f32.clear();
|
||||
self.scalar_f16.clear();
|
||||
self.scalar_bf16.clear();
|
||||
self.scalar_ints.clear();
|
||||
}
|
||||
|
||||
pub(crate) fn relative_float<E: Element>(&mut self, elem: &E) -> E {
|
||||
self.scalar_floats.push(elem.elem());
|
||||
pub(crate) fn relative_float<E: Element>(&mut self, elem: &E, dtype: &DType) -> E {
|
||||
match dtype {
|
||||
burn_tensor::DType::F32 => self.scalar_f32.push(elem.elem()),
|
||||
burn_tensor::DType::F16 => self.scalar_f16.push(elem.elem()),
|
||||
burn_tensor::DType::BF16 => self.scalar_bf16.push(elem.elem()),
|
||||
_ => todo!("Unsupported"),
|
||||
}
|
||||
|
||||
// We return 0 so that the id from a scalar operation is the same no matter its scalar
|
||||
// value.
|
||||
0.elem()
|
||||
|
@ -85,19 +102,24 @@ impl RelativeOps for OperationDescription {
|
|||
OperationDescription::BaseBool(ops) => {
|
||||
OperationDescription::BaseBool(ops.to_relative(converter))
|
||||
}
|
||||
OperationDescription::NumericFloat(ops) => OperationDescription::NumericFloat(
|
||||
ops.to_relative(converter, |converter, e| converter.relative_float(e)),
|
||||
OperationDescription::NumericFloat(dtype, ops) => OperationDescription::NumericFloat(
|
||||
*dtype,
|
||||
ops.to_relative(converter, |converter, e| converter.relative_float(e, dtype)),
|
||||
),
|
||||
OperationDescription::NumericInt(ops) => OperationDescription::NumericInt(
|
||||
OperationDescription::NumericInt(dtype, ops) => OperationDescription::NumericInt(
|
||||
*dtype,
|
||||
ops.to_relative(converter, |converter, e| converter.relative_int(e)),
|
||||
),
|
||||
OperationDescription::Bool(ops) => {
|
||||
OperationDescription::Bool(ops.to_relative(converter))
|
||||
}
|
||||
OperationDescription::Int(ops) => OperationDescription::Int(ops.to_relative(converter)),
|
||||
OperationDescription::Float(ops) => {
|
||||
OperationDescription::Float(ops.to_relative(converter))
|
||||
}
|
||||
OperationDescription::Float(dtype, ops) => OperationDescription::Float(
|
||||
*dtype,
|
||||
RelativeOpsScalar::<f32>::to_relative(ops, converter, |converter, e| {
|
||||
converter.relative_float(e, dtype)
|
||||
}),
|
||||
),
|
||||
OperationDescription::Module(ops) => {
|
||||
OperationDescription::Module(ops.to_relative(converter))
|
||||
}
|
||||
|
@ -342,8 +364,11 @@ impl RelativeOps for ModuleOperationDescription {
|
|||
}
|
||||
}
|
||||
|
||||
impl RelativeOps for FloatOperationDescription {
|
||||
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
|
||||
impl RelativeOpsScalar<f32> for FloatOperationDescription {
|
||||
fn to_relative<F>(&self, converter: &mut OperationConverter, local_elem: F) -> Self
|
||||
where
|
||||
F: Fn(&mut OperationConverter, &f32) -> f32,
|
||||
{
|
||||
match self {
|
||||
FloatOperationDescription::Exp(desc) => {
|
||||
FloatOperationDescription::Exp(UnaryOperationDescription {
|
||||
|
@ -372,7 +397,7 @@ impl RelativeOps for FloatOperationDescription {
|
|||
FloatOperationDescription::PowfScalar(desc) => {
|
||||
FloatOperationDescription::PowfScalar(ScalarOperationDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: converter.relative_float(&desc.rhs),
|
||||
rhs: local_elem(converter, &desc.rhs.elem()),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
|
|
|
@ -549,10 +549,10 @@ mod tests {
|
|||
// Out node.
|
||||
self.new_empty_node(out_id);
|
||||
|
||||
self.operations
|
||||
.push(OperationDescription::Float(FloatOperationDescription::Log(
|
||||
self.unary_description(),
|
||||
)));
|
||||
self.operations.push(OperationDescription::Float(
|
||||
DType::F32,
|
||||
FloatOperationDescription::Log(self.unary_description()),
|
||||
));
|
||||
}
|
||||
|
||||
fn new_empty_node(&mut self, id: u64) {
|
||||
|
|
|
@ -520,8 +520,9 @@ impl<'i> StreamSegment<TestOptimization> for TestSegment<'i> {
|
|||
|
||||
/// Just a simple operation.
|
||||
fn operation_1() -> OperationDescription {
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Add(
|
||||
BinaryOperationDescription {
|
||||
OperationDescription::NumericFloat(
|
||||
DType::F32,
|
||||
NumericOperationDescription::Add(BinaryOperationDescription {
|
||||
lhs: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
|
@ -540,14 +541,15 @@ fn operation_1() -> OperationDescription {
|
|||
status: TensorStatus::NotInit,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
},
|
||||
))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Just a simple operation.
|
||||
fn operation_2() -> OperationDescription {
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::AddScalar(
|
||||
ScalarOperationDescription {
|
||||
OperationDescription::NumericFloat(
|
||||
DType::F32,
|
||||
NumericOperationDescription::AddScalar(ScalarOperationDescription {
|
||||
lhs: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
|
@ -561,24 +563,27 @@ fn operation_2() -> OperationDescription {
|
|||
status: TensorStatus::NotInit,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
},
|
||||
))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Just a simple operation.
|
||||
fn operation_3() -> OperationDescription {
|
||||
OperationDescription::Float(FloatOperationDescription::Log(UnaryOperationDescription {
|
||||
input: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::ReadOnly,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
out: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::NotInit,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
}))
|
||||
OperationDescription::Float(
|
||||
DType::F32,
|
||||
FloatOperationDescription::Log(UnaryOperationDescription {
|
||||
input: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::ReadOnly,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
out: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
status: TensorStatus::NotInit,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -218,8 +218,9 @@ mod tests {
|
|||
}
|
||||
|
||||
fn ops_1() -> OperationDescription {
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Add(
|
||||
BinaryOperationDescription {
|
||||
OperationDescription::NumericFloat(
|
||||
DType::F32,
|
||||
NumericOperationDescription::Add(BinaryOperationDescription {
|
||||
lhs: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
|
@ -238,13 +239,14 @@ mod tests {
|
|||
status: TensorStatus::NotInit,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
},
|
||||
))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn ops_2() -> OperationDescription {
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::AddScalar(
|
||||
ScalarOperationDescription {
|
||||
OperationDescription::NumericFloat(
|
||||
DType::F32,
|
||||
NumericOperationDescription::AddScalar(ScalarOperationDescription {
|
||||
lhs: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
|
@ -258,13 +260,14 @@ mod tests {
|
|||
status: TensorStatus::NotInit,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
},
|
||||
))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn ops_3() -> OperationDescription {
|
||||
OperationDescription::NumericFloat(NumericOperationDescription::Sub(
|
||||
BinaryOperationDescription {
|
||||
OperationDescription::NumericFloat(
|
||||
DType::F32,
|
||||
NumericOperationDescription::Sub(BinaryOperationDescription {
|
||||
lhs: TensorDescription {
|
||||
id: TensorId::new(0),
|
||||
shape: vec![32, 32],
|
||||
|
@ -283,7 +286,7 @@ mod tests {
|
|||
status: TensorStatus::NotInit,
|
||||
dtype: DType::F32,
|
||||
},
|
||||
},
|
||||
))
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,19 +46,19 @@ impl<R: JitRuntime> OptimizationBuilder<JitOptimization<R>> for ElementWiseBuild
|
|||
return;
|
||||
}
|
||||
}
|
||||
OperationDescription::Float(ops) => {
|
||||
OperationDescription::Float(_dtype, ops) => {
|
||||
if !self.register_float(ops) {
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
OperationDescription::NumericFloat(ops) => {
|
||||
OperationDescription::NumericFloat(_dtype, ops) => {
|
||||
if !self.register_numeric::<f32>(ops) {
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
}
|
||||
}
|
||||
OperationDescription::NumericInt(ops) => {
|
||||
OperationDescription::NumericInt(_dtype, ops) => {
|
||||
if !self.register_numeric::<i32>(ops) {
|
||||
self.status = OptimizationStatus::Closed;
|
||||
return;
|
||||
|
@ -390,8 +390,9 @@ impl<R: JitRuntime> ElementWiseBuilder<R> {
|
|||
return false;
|
||||
}
|
||||
|
||||
let elem = desc.lhs.dtype.into();
|
||||
let lhs = self.builder.input(&desc.lhs, Variable::AbsolutePos);
|
||||
let rhs = self.builder.scalar(&desc.rhs, desc.lhs.dtype.into());
|
||||
let rhs = self.builder.scalar(&desc.rhs, elem);
|
||||
let out = self.builder.output(&desc.out, Variable::AbsolutePos);
|
||||
|
||||
self.builder.register_operation(func(lhs, rhs, out));
|
||||
|
|
|
@ -148,7 +148,13 @@ impl<R: JitRuntime> FusionKernel<R> {
|
|||
let info_size = (num_tensors * rank * 2) + 1;
|
||||
|
||||
let mut num_handles = num_tensors + 1;
|
||||
if running_info.scalars.num_float > 0 {
|
||||
if running_info.scalars.num_f32 > 0 {
|
||||
num_handles += 1;
|
||||
}
|
||||
if running_info.scalars.num_f16 > 0 {
|
||||
num_handles += 1;
|
||||
}
|
||||
if running_info.scalars.num_bf16 > 0 {
|
||||
num_handles += 1;
|
||||
}
|
||||
if running_info.scalars.num_int > 0 {
|
||||
|
@ -216,14 +222,20 @@ impl<R: JitRuntime> FusionKernel<R> {
|
|||
bindings.push(client.create(bytemuck::cast_slice(&info)).binding());
|
||||
|
||||
// Finally we finish with the named bindings.
|
||||
if running_info.scalars.num_float > 0 {
|
||||
bindings.push(
|
||||
client
|
||||
.create(bytemuck::cast_slice(
|
||||
&context.scalar_floats[0..running_info.scalars.num_float],
|
||||
))
|
||||
.binding(),
|
||||
);
|
||||
if running_info.scalars.num_f32 > 0 {
|
||||
let bytes = bytemuck::cast_slice(&context.scalar_f32[0..running_info.scalars.num_f32]);
|
||||
bindings.push(client.create(bytes).binding());
|
||||
}
|
||||
|
||||
if running_info.scalars.num_f16 > 0 {
|
||||
let bytes = bytemuck::cast_slice(&context.scalar_f16[0..running_info.scalars.num_f16]);
|
||||
bindings.push(client.create(bytes).binding());
|
||||
}
|
||||
|
||||
if running_info.scalars.num_bf16 > 0 {
|
||||
let bytes =
|
||||
bytemuck::cast_slice(&context.scalar_bf16[0..running_info.scalars.num_bf16]);
|
||||
bindings.push(client.create(bytes).binding());
|
||||
}
|
||||
|
||||
if running_info.scalars.num_int > 0 {
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Default, Clone, Serialize, Deserialize)]
|
||||
pub struct Scalars {
|
||||
pub(crate) num_float: usize,
|
||||
pub(crate) num_f32: usize,
|
||||
pub(crate) num_f16: usize,
|
||||
pub(crate) num_bf16: usize,
|
||||
pub(crate) num_int: usize,
|
||||
pub(crate) num_uint: usize,
|
||||
pub(crate) num_bool: usize,
|
||||
|
|
|
@ -103,13 +103,33 @@ impl TraceBuilder {
|
|||
/// Create a variable from an input [scalar](Element).
|
||||
pub fn scalar<E: Element>(&mut self, _value: &E, elem_type: Elem) -> Variable {
|
||||
match elem_type {
|
||||
Elem::Float(_) => {
|
||||
let var = self
|
||||
.scope
|
||||
.read_scalar(self.scalars.num_float as u16, elem_type);
|
||||
self.scalars.num_float += 1;
|
||||
var
|
||||
}
|
||||
Elem::Float(kind) => match kind {
|
||||
cubecl::ir::FloatKind::F16 => {
|
||||
let var = self
|
||||
.scope
|
||||
.read_scalar(self.scalars.num_f16 as u16, elem_type);
|
||||
|
||||
self.scalars.num_f16 += 1;
|
||||
var
|
||||
}
|
||||
cubecl::ir::FloatKind::F32 => {
|
||||
let var = self
|
||||
.scope
|
||||
.read_scalar(self.scalars.num_f32 as u16, elem_type);
|
||||
|
||||
self.scalars.num_f32 += 1;
|
||||
var
|
||||
}
|
||||
cubecl::ir::FloatKind::BF16 => {
|
||||
let var = self
|
||||
.scope
|
||||
.read_scalar(self.scalars.num_bf16 as u16, elem_type);
|
||||
|
||||
self.scalars.num_bf16 += 1;
|
||||
var
|
||||
}
|
||||
cubecl::ir::FloatKind::F64 => todo!(),
|
||||
},
|
||||
Elem::Int(_) => {
|
||||
let var = self
|
||||
.scope
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use super::Scalars;
|
||||
use burn_tensor::repr::TensorDescription;
|
||||
use cubecl::{
|
||||
ir::{Elem, FloatKind, IntKind, Item, Scope, Variable, Visibility},
|
||||
ir::{Elem, IntKind, Item, Scope, Variable, Visibility},
|
||||
InputInfo, KernelExpansion, OutputInfo,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -64,10 +64,23 @@ impl Trace {
|
|||
.collect::<Vec<_>>();
|
||||
|
||||
// NOTE: we might want to pass a struct including all inputs/outputs metadata instead of 3 arrays
|
||||
if self.scalars.num_float > 0 {
|
||||
if self.scalars.num_f32 > 0 {
|
||||
inputs.push(InputInfo::Scalar {
|
||||
elem: Elem::Float(FloatKind::F32),
|
||||
size: self.scalars.num_float,
|
||||
elem: Elem::Float(cubecl::ir::FloatKind::F32),
|
||||
size: self.scalars.num_f32,
|
||||
})
|
||||
}
|
||||
if self.scalars.num_f16 > 0 {
|
||||
inputs.push(InputInfo::Scalar {
|
||||
elem: Elem::Float(cubecl::ir::FloatKind::F16),
|
||||
size: self.scalars.num_f16,
|
||||
})
|
||||
}
|
||||
|
||||
if self.scalars.num_bf16 > 0 {
|
||||
inputs.push(InputInfo::Scalar {
|
||||
elem: Elem::Float(cubecl::ir::FloatKind::BF16),
|
||||
size: self.scalars.num_bf16,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ use std::ops::Range;
|
|||
use crate::{
|
||||
ops::{ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions},
|
||||
repr::tensor::TensorDescription,
|
||||
Distribution, Element,
|
||||
DType, Distribution, Element,
|
||||
};
|
||||
|
||||
/// Describe all tensor operations possible.
|
||||
|
@ -17,15 +17,15 @@ pub enum OperationDescription {
|
|||
/// Basic operation on a bool tensor.
|
||||
BaseBool(BaseOperationDescription),
|
||||
/// Numeric operation on a float tensor.
|
||||
NumericFloat(NumericOperationDescription<f32>),
|
||||
NumericFloat(DType, NumericOperationDescription<f32>),
|
||||
/// Numeric operation on an int tensor.
|
||||
NumericInt(NumericOperationDescription<i32>),
|
||||
NumericInt(DType, NumericOperationDescription<i32>),
|
||||
/// Operation specific to a bool tensor.
|
||||
Bool(BoolOperationDescription),
|
||||
/// Operation specific to an int tensor.
|
||||
Int(IntOperationDescription),
|
||||
/// Operation specific to a float tensor.
|
||||
Float(FloatOperationDescription),
|
||||
Float(DType, FloatOperationDescription),
|
||||
/// Module operation.
|
||||
Module(ModuleOperationDescription),
|
||||
}
|
||||
|
@ -1149,11 +1149,11 @@ impl OperationDescription {
|
|||
OperationDescription::BaseFloat(ops) => ops.nodes(),
|
||||
OperationDescription::BaseInt(ops) => ops.nodes(),
|
||||
OperationDescription::BaseBool(ops) => ops.nodes(),
|
||||
OperationDescription::NumericFloat(ops) => ops.nodes(),
|
||||
OperationDescription::NumericInt(ops) => ops.nodes(),
|
||||
OperationDescription::NumericFloat(_dtype, ops) => ops.nodes(),
|
||||
OperationDescription::NumericInt(_dtype, ops) => ops.nodes(),
|
||||
OperationDescription::Bool(ops) => ops.nodes(),
|
||||
OperationDescription::Int(ops) => ops.nodes(),
|
||||
OperationDescription::Float(ops) => ops.nodes(),
|
||||
OperationDescription::Float(_dtype, ops) => ops.nodes(),
|
||||
OperationDescription::Module(ops) => ops.nodes(),
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue