29 lines
904 B
Python
29 lines
904 B
Python
import random
|
|
import time
|
|
import datetime
|
|
import sys
|
|
import numpy as np
|
|
import jittor as jt
|
|
|
|
class ReplayBuffer:
    """Fixed-capacity pool of previously seen batch elements.

    Each call to :meth:`push_and_pop` walks the incoming batch one element
    at a time (slicing along the first axis) and, for every element, either
    hands it straight back or swaps it with a randomly chosen element that
    was stored earlier.

    Attributes are plain and public:
        max_size: maximum number of elements kept in ``data``.
        data: list of stored single-element slices, at most ``max_size`` long.
    """

    def __init__(self, max_size=50):
        """Create an empty buffer.

        Args:
            max_size: capacity of the buffer; must be strictly positive.

        Raises:
            AssertionError: if ``max_size`` is not greater than zero.
        """
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch into the buffer and return a batch of the same length.

        For each element of ``data`` (taken as ``data[k:k + 1]`` so the
        leading axis is kept):

        * while the buffer holds fewer than ``max_size`` elements, the
          element is stored and also returned unchanged;
        * once the buffer is full, a ``random.uniform(0, 1) > 0.5`` draw
          decides: on success a randomly selected stored element is returned
          (as a clone) and its slot is overwritten with the new element;
          otherwise the new element is returned and the buffer is untouched.

        Args:
            data: batch indexed along its first axis; must expose ``shape``
                and slicing. Presumably a Jittor ``Var`` - the result is
                assembled with ``jt.contrib.concat``.

        Returns:
            The selected elements concatenated along dimension 0. An empty
            batch is returned as-is, leaving the buffer unmodified.
        """
        batch_size = data.shape[0]
        if batch_size == 0:
            # Nothing to push or pop; concatenating an empty list would fail.
            return data

        to_return = []
        for idx in range(batch_size):
            element = data[idx:idx + 1]

            if len(self.data) < self.max_size:
                # Still filling up: remember the element and pass it through.
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                # Swap: hand back an older element and store the new one in
                # its place. randint is inclusive on both ends. A separate
                # name is used so the loop index is not clobbered.
                slot = random.randint(0, self.max_size - 1)
                to_return.append(self.data[slot].clone())
                self.data[slot] = element
            else:
                to_return.append(element)

        return jt.contrib.concat(to_return, dim=0)