# 55 lines | 1.2 KiB | Python
import copy
|
|
import os
|
|
|
|
import torch
|
|
import deepspeed
|
|
|
|
|
|
# Process placement, read from the LOCAL_RANK / WORLD_SIZE environment
# variables (presumably exported by the distributed launcher); when they are
# unset the script falls back to rank 0 of a single-process run.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
|
|
|
class Model(torch.nn.Module):
    """Deep stack of equally sized square ``Linear`` layers applied in order.

    Args:
        num_layers: Number of ``torch.nn.Linear`` layers in the stack.
        hidden_size: Input and output feature size of every layer.
    """

    def __init__(self, num_layers: int = 4000, hidden_size: int = 500):
        super().__init__()

        # Layers are created in order, so parameter initialisation consumes
        # the RNG in the same sequence as appending them one at a time.
        self.ml = torch.nn.ModuleList(
            [torch.nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)]
        )

    def forward(self, batch):
        """Feed ``batch`` through every layer in sequence and return the result.

        Args:
            batch: Tensor whose last dimension equals ``hidden_size``.

        Returns:
            The output of the final layer; same shape as ``batch``.
        """
        for layer in self.ml:
            batch = layer(batch)

        return batch
|
|
|
|
|
|
class DummyDataset(torch.utils.data.Dataset):
    """Synthetic dataset that serves copies of one fixed random tensor.

    Args:
        num_samples: Length reported by ``len()``.
        sample_shape: Shape of the random tensor returned for every index.
    """

    def __init__(self, num_samples: int = 1000, sample_shape=(500, 500)):
        self.num_samples = num_samples
        # Drawn once; every item handed out is a copy of this tensor.
        self.batch = torch.rand(*sample_shape)

    def __getitem__(self, idx):
        """Return an independent copy of the stored tensor.

        Args:
            idx: Sample index; the returned data does not depend on it.

        Raises:
            IndexError: If ``idx`` lies outside ``[-len(self), len(self))``.
                Without this, plain iteration over the dataset (which relies
                on ``IndexError`` to stop) would never terminate.
        """
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.num_samples}"
            )
        # clone() so callers can modify the sample without touching the source.
        return self.batch.clone()

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples
|
|
|
|
# Device owned by this process; reused for both the input and the model.
device = f"cuda:{local_rank}"

# Pull a single example out of the dataset through a DataLoader (default
# batching) and place it on this rank's GPU.
dd = DummyDataset()
dl = torch.utils.data.DataLoader(dd)
example = next(iter(dl)).to(device)

# Build the model and move it to the GPU before handing it to DeepSpeed.
model = Model()
model = model.to(device)

# Wrap the model in a DeepSpeed inference engine, with the model-parallel
# degree taken from WORLD_SIZE. No checkpoint is loaded and no module
# replacement method is requested.
model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    checkpoint=None,
    replace_method=None,
    # replace_method="auto"
)

# Inference only: disable autograd so the forward pass does not retain the
# activations of every layer for a backward pass that never happens.
with torch.no_grad():
    out = model(example)
# if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
#     print(out)
|