openfold/scripts/deepspeed_inference_test.py

55 lines
1.2 KiB
Python

import copy
import os
import torch
import deepspeed
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.ml = torch.nn.ModuleList()
for _ in range(4000):
self.ml.append(torch.nn.Linear(500, 500))
def forward(self, batch):
for i, l in enumerate(self.ml):
# print(f"{i}: {l.weight.device}")
batch = l(batch)
return batch
class DummyDataset(torch.utils.data.Dataset):
def __init__(self):
self.batch = torch.rand(500, 500)
def __getitem__(self, idx):
return copy.deepcopy(self.batch)
def __len__(self):
return 1000
dd = DummyDataset()
dl = torch.utils.data.DataLoader(dd)
example = next(iter(dl)).to(f"cuda:{local_rank}")
model = Model()
model = model.to(f"cuda:{local_rank}")
model = deepspeed.init_inference(
model,
mp_size=world_size,
checkpoint=None,
replace_method=None,
#replace_method="auto"
)
out = model(example)
#if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
# print(out)