scripts: add some basic filter
This commit is contained in:
parent
9afb08ceaf
commit
b1ab4381ef
|
@ -1,4 +1,5 @@
|
|||
import itertools
|
||||
import inspect
|
||||
from enum import Enum, unique
|
||||
|
||||
@unique
|
||||
|
@ -27,10 +28,11 @@ class HitState(Enum):
|
|||
HIT = 1
|
||||
|
||||
class ClientDir:
|
||||
def __init__(self, tl_state):
|
||||
def __init__(self, tl_state, hit_state):
|
||||
self.tl_state = tl_state
|
||||
self.hit_state = hit_state
|
||||
def __str__(self):
|
||||
return (f"{self.tl_state}")
|
||||
return (f"{self.tl_state} {self.hit_state}")
|
||||
|
||||
class SelfDir:
|
||||
def __init__(self, tl_state, dirty_state, hit_state, client_tl_states):
|
||||
|
@ -111,20 +113,61 @@ client_fields = []
|
|||
for i in range(0, NUM_CLIENTS):
|
||||
self_fields.append(list(TLState))
|
||||
client_fields.append(list(TLState))
|
||||
client_fields.append(list(HitState))
|
||||
|
||||
def get_all_states():
|
||||
def get_all_states(log = False):
|
||||
all_states = []
|
||||
for x in itertools.product(*self_fields):
|
||||
self_state, dirty, hit, block_state = x[:4]
|
||||
self_clients = x[4:]
|
||||
self_dir = SelfDir(self_state, dirty, hit, self_clients)
|
||||
for c in itertools.product(*client_fields):
|
||||
client_dirs = [ ClientDir(c[i]) for i in range(len(c)) ]
|
||||
client_dirs = [ ClientDir(c[i], c[i+1]) for i in range(0, len(c), 2) ]
|
||||
new_state = DirState(self_dir, client_dirs, block_state)
|
||||
print(new_state)
|
||||
if log:
|
||||
print(new_state)
|
||||
all_states.append(new_state)
|
||||
print(f"states: {len(all_states)}")
|
||||
if log:
|
||||
print(f"states: {len(all_states)}")
|
||||
return all_states
|
||||
|
||||
get_all_states()
|
||||
all_states = get_all_states()
|
||||
print(f"all states: {len(all_states)}")
|
||||
|
||||
def invalid_filter(s):
|
||||
if s.self_dir.tl_state != TLState.INVALID and s.block_state.self_block == Block.NULL:
|
||||
return False
|
||||
for c, b in zip(s.client_dirs, s.block_state.client_blocks):
|
||||
if c.tl_state != TLState.INVALID and b == Block.NULL:
|
||||
return False
|
||||
return True
|
||||
|
||||
def hit_filter(s):
|
||||
if s.self_dir.hit_state == HitState.HIT:
|
||||
if s.block_state.self_block == Block.NULL:
|
||||
return False
|
||||
if s.self_dir.tl_state == TLState.INVALID:
|
||||
return False
|
||||
for c, b in zip(s.client_dirs, s.block_state.client_blocks):
|
||||
if c.hit_state == HitState.HIT:
|
||||
if b == Block.NULL:
|
||||
return False
|
||||
if c.tl_state == TLState.INVALID:
|
||||
return False
|
||||
return True
|
||||
|
||||
def retrieve_name(var):
|
||||
for fi in reversed(inspect.stack()):
|
||||
names = [var_name for var_name, var_val in fi.frame.f_locals.items() if var_val is var]
|
||||
if len(names) > 0:
|
||||
return names[0]
|
||||
|
||||
filters = [
|
||||
invalid_filter,
|
||||
hit_filter
|
||||
]
|
||||
|
||||
for f in filters:
|
||||
all_states = list(filter(f, all_states))
|
||||
print(f"filter: {retrieve_name(f)} states: {len(all_states)}")
|
||||
|
||||
|
|
Loading…
Reference in New Issue