# JGAN/models/cyclegan/utils.py
#
# 29 lines
# 904 B
# Python

import random
import time
import datetime
import sys
import numpy as np
import jittor as jt
class ReplayBuffer:
    """Bounded pool of previously seen samples with random replacement.

    Samples pushed through the buffer are stored until ``max_size`` entries
    are held. After that, each incoming sample has a 50% chance of being
    swapped with a randomly chosen stored sample (the stored one is
    returned, the new one takes its slot); otherwise it is returned
    unchanged and the buffer is left untouched.
    """

    def __init__(self, max_size=50):
        """Create an empty buffer.

        Args:
            max_size: Maximum number of samples kept in the buffer. Must be
                greater than zero.

        Raises:
            AssertionError: If ``max_size`` is not positive.
        """
        # Deliberately left as an assert so the exception type and message
        # seen by existing callers stay the same.
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch through the buffer and return a batch of equal length.

        Args:
            data: Batch to process. It must support ``size(0)``, slicing
                along its first axis, and its slices must support
                ``clone()`` (presumably a ``jt.Var`` -- confirm with callers).

        Returns:
            The per-sample results concatenated along dimension 0 with
            ``jt.contrib.concat``. Each position holds either the incoming
            sample or a clone of a sample that was stored earlier.
        """
        to_return = []
        for idx in range(data.size(0)):
            # A one-element slice (not data[idx]) so every stored entry can
            # be concatenated along dim 0 at the end.
            element = data[idx:idx + 1]

            if len(self.data) < self.max_size:
                # Still filling up: remember the sample and pass it through.
                self.data.append(element)
                to_return.append(element)
            elif random.random() > 0.5:
                # Buffer full: hand back a stored sample and let the new one
                # take its place. `slot` is kept separate from the loop
                # index `idx` to avoid shadowing it.
                slot = random.randint(0, self.max_size - 1)
                to_return.append(self.data[slot].clone())
                self.data[slot] = element
            else:
                # Buffer full: pass the new sample through unchanged.
                to_return.append(element)
        return jt.contrib.concat(to_return, dim=0)