scripts: add some basic filter

This commit is contained in:
LinJiawei 2021-09-07 15:08:27 +08:00
parent 9afb08ceaf
commit b1ab4381ef
1 changed files with 50 additions and 7 deletions

View File

@ -1,4 +1,5 @@
import itertools
import inspect
from enum import Enum, unique
@unique
@ -27,10 +28,11 @@ class HitState(Enum):
HIT = 1
class ClientDir:
def __init__(self, tl_state):
def __init__(self, tl_state, hit_state):
self.tl_state = tl_state
self.hit_state = hit_state
def __str__(self):
return (f"{self.tl_state}")
return (f"{self.tl_state} {self.hit_state}")
class SelfDir:
def __init__(self, tl_state, dirty_state, hit_state, client_tl_states):
@ -111,20 +113,61 @@ client_fields = []
for i in range(0, NUM_CLIENTS):
self_fields.append(list(TLState))
client_fields.append(list(TLState))
client_fields.append(list(HitState))
def get_all_states():
def get_all_states(log = False):
all_states = []
for x in itertools.product(*self_fields):
self_state, dirty, hit, block_state = x[:4]
self_clients = x[4:]
self_dir = SelfDir(self_state, dirty, hit, self_clients)
for c in itertools.product(*client_fields):
client_dirs = [ ClientDir(c[i]) for i in range(len(c)) ]
client_dirs = [ ClientDir(c[i], c[i+1]) for i in range(0, len(c), 2) ]
new_state = DirState(self_dir, client_dirs, block_state)
print(new_state)
if log:
print(new_state)
all_states.append(new_state)
print(f"states: {len(all_states)}")
if log:
print(f"states: {len(all_states)}")
return all_states
get_all_states()
all_states = get_all_states()
print(f"all states: {len(all_states)}")
def invalid_filter(s):
if s.self_dir.tl_state != TLState.INVALID and s.block_state.self_block == Block.NULL:
return False
for c, b in zip(s.client_dirs, s.block_state.client_blocks):
if c.tl_state != TLState.INVALID and b == Block.NULL:
return False
return True
def hit_filter(s):
if s.self_dir.hit_state == HitState.HIT:
if s.block_state.self_block == Block.NULL:
return False
if s.self_dir.tl_state == TLState.INVALID:
return False
for c, b in zip(s.client_dirs, s.block_state.client_blocks):
if c.hit_state == HitState.HIT:
if b == Block.NULL:
return False
if c.tl_state == TLState.INVALID:
return False
return True
def retrieve_name(var):
for fi in reversed(inspect.stack()):
names = [var_name for var_name, var_val in fi.frame.f_locals.items() if var_val is var]
if len(names) > 0:
return names[0]
filters = [
invalid_filter,
hit_filter
]
for f in filters:
all_states = list(filter(f, all_states))
print(f"filter: {retrieve_name(f)} states: {len(all_states)}")