# Copyright (c) OpenMMLab. All rights reserved.
import os
from unittest.mock import MagicMock

import torch
import torch.distributed as torch_dist
import torch.nn as nn

from mmengine.dist import all_gather
from mmengine.hooks import SyncBuffersHook
from mmengine.registry import MODELS
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.testing.runner_test_case import RunnerTestCase, ToyModel


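# ``ToyModuleWithNorm`` wraps the toy model's first linear layer in a
# BatchNorm1d so the model owns buffers (running statistics), and
# ``init_weights`` fills every buffer with the process rank so each rank
# starts with different buffer values before synchronization.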
class ToyModuleWithNorm(ToyModel):

    def __init__(self, data_preprocessor=None):
        super().__init__(data_preprocessor=data_preprocessor)
        bn = nn.BatchNorm1d(2)
        self.linear1 = nn.Sequential(self.linear1, bn)

    def init_weights(self):
        for buffer in self.buffers():
            buffer.fill_(
                torch.tensor(int(os.environ['RANK']), dtype=torch.float32))
        return super().init_weights()


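# ``SyncBuffersHook`` synchronizes model buffers across ranks in
# ``after_train_epoch``; the tests below assert that buffers differ per rank
# before the hook runs and are identical afterwards.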
class TestSyncBuffersHook(MultiProcessTestCase, RunnerTestCase):

    def setUp(self) -> None:
        super().setUp()
        # Spawn one worker process per rank for the distributed test case.
        self._spawn_processes()

    def prepare_subprocess(self):
        # Register the toy model in each subprocess so that the runner config
        # ``dict(type='ToyModuleWithNorm')`` can be built there.
        MODELS.register_module(module=ToyModuleWithNorm, force=True)
        super(MultiProcessTestCase, self).setUp()

    def test_sync_buffers_hook(self):
        self.setup_dist_env()
        runner = MagicMock()
        runner.model = ToyModuleWithNorm()
        runner.model.init_weights()

        # Each rank filled its buffers with its own rank id, so the gathered
        # buffers must differ before synchronization.
        for buffer in runner.model.buffers():
            buffer1, buffer2 = all_gather(buffer)
            self.assertFalse(torch.allclose(buffer1, buffer2))

        hook = SyncBuffersHook()
        hook.after_train_epoch(runner)

        # After the hook runs, the buffers must match on every rank.
        for buffer in runner.model.buffers():
            buffer1, buffer2 = all_gather(buffer)
            self.assertTrue(torch.allclose(buffer1, buffer2))

    def test_with_runner(self):
        self.setup_dist_env()
        cfg = self.epoch_based_cfg
        cfg.model = dict(type='ToyModuleWithNorm')
        cfg.launcher = 'pytorch'
        # Register the hook as a custom hook for a full training run.
        cfg.custom_hooks = [dict(type='SyncBuffersHook')]
        runner = self.build_runner(cfg)
        runner.train()

        # Buffers must match across ranks once training has finished.
        for buffer in runner.model.buffers():
            buffer1, buffer2 = all_gather(buffer)
            self.assertTrue(torch.allclose(buffer1, buffer2))

    def setup_dist_env(self):
        super().setup_dist_env()
        # ``init_weights`` reads RANK to give each process distinct buffers.
        os.environ['RANK'] = str(self.rank)
        torch_dist.init_process_group(
            backend='gloo', rank=self.rank, world_size=self.world_size)