Add Rank0 variant to AdaptorRecordV1 and AdaptorRecordItemV1 (#1442)

This commit is contained in:
carrotflakes 2024-03-13 02:08:20 +09:00 committed by GitHub
parent bc39e4c7a1
commit 80aac1dde4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 13 additions and 0 deletions

View File

@ -8,6 +8,9 @@ use serde::{Deserialize, Serialize};
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
pub enum AdaptorRecordV1<O: SimpleOptimizer<B>, B: Backend> {
/// Rank 0.
Rank0(O::State<0>),
/// Rank 1.
Rank1(O::State<1>),
@ -36,6 +39,7 @@ pub enum AdaptorRecordV1<O: SimpleOptimizer<B>, B: Backend> {
impl<O: SimpleOptimizer<B>, B: Backend> Clone for AdaptorRecordV1<O, B> {
fn clone(&self) -> Self {
match self {
AdaptorRecordV1::Rank0(record) => AdaptorRecordV1::Rank0(record.clone()),
AdaptorRecordV1::Rank1(record) => AdaptorRecordV1::Rank1(record.clone()),
AdaptorRecordV1::Rank2(record) => AdaptorRecordV1::Rank2(record.clone()),
AdaptorRecordV1::Rank3(record) => AdaptorRecordV1::Rank3(record.clone()),
@ -52,6 +56,9 @@ impl<O: SimpleOptimizer<B>, B: Backend> Clone for AdaptorRecordV1<O, B> {
#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
pub enum AdaptorRecordItemV1<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
/// Rank 0.
Rank0(<O::State<0> as Record<B>>::Item<S>),
/// Rank 1.
Rank1(<O::State<1> as Record<B>>::Item<S>),
@ -93,6 +100,7 @@ where
/// Panics if the state dimension is not supported.
pub fn into_state<const D: usize>(self) -> O::State<D> {
let boxed_state: Box<dyn Any> = match self {
AdaptorRecordV1::Rank0(s) => Box::new(s),
AdaptorRecordV1::Rank1(s) => Box::new(s),
AdaptorRecordV1::Rank2(s) => Box::new(s),
AdaptorRecordV1::Rank3(s) => Box::new(s),
@ -121,6 +129,7 @@ where
let state: Box<dyn Any> = Box::new(state);
match D {
0 => AdaptorRecordV1::Rank0(*state.downcast().unwrap()),
1 => AdaptorRecordV1::Rank1(*state.downcast().unwrap()),
2 => AdaptorRecordV1::Rank2(*state.downcast().unwrap()),
3 => AdaptorRecordV1::Rank3(*state.downcast().unwrap()),
@ -143,6 +152,7 @@ where
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
match self {
AdaptorRecordV1::Rank0(record) => AdaptorRecordItemV1::Rank0(record.into_item()),
AdaptorRecordV1::Rank1(record) => AdaptorRecordItemV1::Rank1(record.into_item()),
AdaptorRecordV1::Rank2(record) => AdaptorRecordItemV1::Rank2(record.into_item()),
AdaptorRecordV1::Rank3(record) => AdaptorRecordItemV1::Rank3(record.into_item()),
@ -156,6 +166,9 @@ where
fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
match item {
AdaptorRecordItemV1::Rank0(item) => {
AdaptorRecordV1::Rank0(<O::State<0> as Record<B>>::from_item(item, device))
}
AdaptorRecordItemV1::Rank1(item) => {
AdaptorRecordV1::Rank1(<O::State<1> as Record<B>>::from_item(item, device))
}