mirror of https://github.com/tracel-ai/burn.git
Add Rank0 variant to AdaptorRecordV1 and AdaptorRecordItemV1 (#1442)
This commit is contained in:
parent
bc39e4c7a1
commit
80aac1dde4
|
@ -8,6 +8,9 @@ use serde::{Deserialize, Serialize};
|
|||
|
||||
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
|
||||
pub enum AdaptorRecordV1<O: SimpleOptimizer<B>, B: Backend> {
|
||||
/// Rank 0.
|
||||
Rank0(O::State<0>),
|
||||
|
||||
/// Rank 1.
|
||||
Rank1(O::State<1>),
|
||||
|
||||
|
@ -36,6 +39,7 @@ pub enum AdaptorRecordV1<O: SimpleOptimizer<B>, B: Backend> {
|
|||
impl<O: SimpleOptimizer<B>, B: Backend> Clone for AdaptorRecordV1<O, B> {
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
AdaptorRecordV1::Rank0(record) => AdaptorRecordV1::Rank0(record.clone()),
|
||||
AdaptorRecordV1::Rank1(record) => AdaptorRecordV1::Rank1(record.clone()),
|
||||
AdaptorRecordV1::Rank2(record) => AdaptorRecordV1::Rank2(record.clone()),
|
||||
AdaptorRecordV1::Rank3(record) => AdaptorRecordV1::Rank3(record.clone()),
|
||||
|
@ -52,6 +56,9 @@ impl<O: SimpleOptimizer<B>, B: Backend> Clone for AdaptorRecordV1<O, B> {
|
|||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(bound = "")]
|
||||
pub enum AdaptorRecordItemV1<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
|
||||
/// Rank 0.
|
||||
Rank0(<O::State<0> as Record<B>>::Item<S>),
|
||||
|
||||
/// Rank 1.
|
||||
Rank1(<O::State<1> as Record<B>>::Item<S>),
|
||||
|
||||
|
@ -93,6 +100,7 @@ where
|
|||
/// Panics if the state dimension is not supported.
|
||||
pub fn into_state<const D: usize>(self) -> O::State<D> {
|
||||
let boxed_state: Box<dyn Any> = match self {
|
||||
AdaptorRecordV1::Rank0(s) => Box::new(s),
|
||||
AdaptorRecordV1::Rank1(s) => Box::new(s),
|
||||
AdaptorRecordV1::Rank2(s) => Box::new(s),
|
||||
AdaptorRecordV1::Rank3(s) => Box::new(s),
|
||||
|
@ -121,6 +129,7 @@ where
|
|||
let state: Box<dyn Any> = Box::new(state);
|
||||
|
||||
match D {
|
||||
0 => AdaptorRecordV1::Rank0(*state.downcast().unwrap()),
|
||||
1 => AdaptorRecordV1::Rank1(*state.downcast().unwrap()),
|
||||
2 => AdaptorRecordV1::Rank2(*state.downcast().unwrap()),
|
||||
3 => AdaptorRecordV1::Rank3(*state.downcast().unwrap()),
|
||||
|
@ -143,6 +152,7 @@ where
|
|||
|
||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
|
||||
match self {
|
||||
AdaptorRecordV1::Rank0(record) => AdaptorRecordItemV1::Rank0(record.into_item()),
|
||||
AdaptorRecordV1::Rank1(record) => AdaptorRecordItemV1::Rank1(record.into_item()),
|
||||
AdaptorRecordV1::Rank2(record) => AdaptorRecordItemV1::Rank2(record.into_item()),
|
||||
AdaptorRecordV1::Rank3(record) => AdaptorRecordItemV1::Rank3(record.into_item()),
|
||||
|
@ -156,6 +166,9 @@ where
|
|||
|
||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
|
||||
match item {
|
||||
AdaptorRecordItemV1::Rank0(item) => {
|
||||
AdaptorRecordV1::Rank0(<O::State<0> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
AdaptorRecordItemV1::Rank1(item) => {
|
||||
AdaptorRecordV1::Rank1(<O::State<1> as Record<B>>::from_item(item, device))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue