Implementing Monte Carlo Tree Search Algorithm

This commit is contained in:
flyly 2022-04-12 18:47:11 +08:00
parent 7f97d633f1
commit fd34b5da2a
3 changed files with 434 additions and 5 deletions

View File

@@ -26,9 +26,11 @@ import logging
try:
from .dao.gsql_execute import GSqlExecute
from .dao.execute_factory import ExecuteFactory
from .mcts import MCTS
except ImportError:
from dao.gsql_execute import GSqlExecute
from dao.execute_factory import ExecuteFactory
from mcts import MCTS
ENABLE_MULTI_NODE = False
SAMPLE_NUM = 5
@@ -192,9 +194,12 @@ class IndexAdvisor:
self.workload_used_index))
if DRIVER:
self.db.close_conn()
opt_config = greedy_determine_opt_config(self.workload_info[0], atomic_config_total,
candidate_indexes, self.index_cost_total[0])
if MAX_INDEX_STORAGE:
opt_config = MCTS(self.workload_info[0], atomic_config_total, candidate_indexes,
MAX_INDEX_STORAGE, MAX_INDEX_NUM)
else:
opt_config = greedy_determine_opt_config(self.workload_info[0], atomic_config_total,
candidate_indexes, self.index_cost_total[0])
self.retain_lower_cost_index(candidate_indexes)
if len(opt_config) == 0:
print("No optimal indexes generated!")
@@ -943,7 +948,7 @@ def check_parameter(args):
raise argparse.ArgumentTypeError("%s is an invalid positive int value" %
args.max_index_num)
if args.max_index_storage is not None and args.max_index_storage <= 0:
raise argparse.ArgumentTypeError("%s is an invalid positive int value" %
raise argparse.ArgumentTypeError("%s is an invalid positive float value" %
args.max_index_storage)
JSON_TYPE = args.json
MAX_INDEX_NUM = args.max_index_num
@@ -971,7 +976,7 @@ def main(argv):
arg_parser.add_argument(
"--max_index_num", help="Maximum number of suggested indexes", type=int)
arg_parser.add_argument("--max_index_storage",
help="Maximum storage of suggested indexes/MB", type=int)
help="Maximum storage of suggested indexes/MB", type=float)
arg_parser.add_argument("--multi_iter_mode", action='store_true',
help="Whether to use multi-iteration algorithm", default=False)
arg_parser.add_argument("--multi_node", action='store_true',

View File

@@ -0,0 +1,397 @@
import sys
import math
import random
import copy
STORAGE_THRESHOLD = 0
AVAILABLE_CHOICES = None
ATOMIC_CHOICES = None
WORKLOAD_INFO = None
MAX_INDEX_NUM = 0
def is_same_index(index, compared_index):
return index.table == compared_index.table and \
index.columns == compared_index.columns and \
index.index_type == compared_index.index_type
def atomic_config_is_valid(atomic_config, config):
# the atomic config is valid only if every one of its atomic indexes appears in the current config
for atomic_index in atomic_config:
is_exist = False
for index in config:
if is_same_index(index, atomic_index):
index.storage = atomic_index.storage
is_exist = True
break
if not is_exist:
return False
return True
def find_subsets_num(choice):
atomic_subsets_num = []
for pos, atomic in enumerate(ATOMIC_CHOICES):
if not atomic or len(atomic) > len(choice):
continue
# find valid atomic index
if atomic_config_is_valid(atomic, choice):
atomic_subsets_num.append(pos)
# find the same atomic index as the candidate index
if len(atomic) == 1 and (is_same_index(choice[-1], atomic[0])):
choice[-1].atomic_pos = pos
return atomic_subsets_num
def find_best_benefit(choice):
atomic_subsets_num = find_subsets_num(choice)
total_benefit = 0
for ind, obj in enumerate(WORKLOAD_INFO):
# calculate the best benefit for the current sql
max_benefit = 0
for pos in atomic_subsets_num:
if (obj.cost_list[0] - obj.cost_list[pos]) > max_benefit:
max_benefit = obj.cost_list[0] - obj.cost_list[pos]
total_benefit += max_benefit
return total_benefit
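A small worked example of the per-query benefit computed above, reusing the cost list from the unit test added below (cost_list[0] is the cost without any index; the other entries are the costs under the corresponding atomic configs):

# suppose the chosen config is covered by the atomic configs at positions 1 and 2
cost_list = [10, 7, 5, 9, 4, 11]
atomic_subsets_num = [1, 2]
max_benefit = max(cost_list[0] - cost_list[pos] for pos in atomic_subsets_num)
assert max_benefit == 5  # 10 - 5; total_benefit sums this per-query maximum over the workload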
def get_diff(available_choices, choices):
except_choices = copy.copy(available_choices)
for i in available_choices:
for j in choices:
if is_same_index(i, j):
except_choices.remove(i)
return except_choices
class State(object):
"""
The state of the Monte Carlo tree search: the data recorded at a given Node,
including the current benefit, the storage consumed so far,
and the sequence of index choices made from the root up to this point.
A state must be able to tell whether it is terminal and must support
drawing the next choice at random from its remaining available choices.
"""
def __init__(self):
self.current_storage = 0.0
self.current_benefit = 0.0
# record the sum of choices up to the current state
self.accumulation_choices = []
# record available choices of current state
self.available_choices = []
self.displayable_choices = []
def get_available_choices(self):
return self.available_choices
def set_available_choices(self, choices):
self.available_choices = choices
def get_current_storage(self):
return self.current_storage
def set_current_storage(self, value):
self.current_storage = value
def get_current_benefit(self):
return self.current_benefit
def set_current_benefit(self, value):
self.current_benefit = value
def get_accumulation_choices(self):
return self.accumulation_choices
def set_accumulation_choices(self, choices):
self.accumulation_choices = choices
def is_terminal(self):
# the state is terminal once the maximum number of indexes has been chosen
return len(self.accumulation_choices) == MAX_INDEX_NUM
def compute_benefit(self):
return self.current_benefit
def get_next_state_with_random_choice(self):
# ensure that the choices taken are not repeated
if not self.available_choices:
return None
random_choice = random.choice([choice for choice in self.available_choices])
self.available_choices.remove(random_choice)
choice = copy.copy(self.accumulation_choices)
choice.append(random_choice)
benefit = find_best_benefit(choice)
# if the current choice does not satisfy the restrictions, try another choice
if benefit <= self.current_benefit or \
self.current_storage + random_choice.storage > STORAGE_THRESHOLD:
return self.get_next_state_with_random_choice()
next_state = State()
# initialize the properties of the new state
next_state.set_accumulation_choices(choice)
next_state.set_current_benefit(benefit)
next_state.set_current_storage(self.current_storage + random_choice.storage)
next_state.set_available_choices(get_diff(AVAILABLE_CHOICES, choice))
return next_state
def __repr__(self):
self.displayable_choices = ['{}: {}'.format(choice.table, choice.columns)
for choice in self.accumulation_choices]
return "reward: {}, storage :{}, choices: {}".format(
self.current_benefit, self.current_storage, self.displayable_choices)
class Node(object):
"""
A node of the Monte Carlo search tree. It stores its parent and children,
the visit count and quality value needed for the UCB calculation,
and the State reached by the choices selected so far.
"""
def __init__(self):
self.visit_number = 0
self.quality = 0.0
self.parent = None
self.children = []
self.state = None
def get_parent(self):
return self.parent
def set_parent(self, parent):
self.parent = parent
def get_children(self):
return self.children
def expand_child(self, node):
node.set_parent(self)
self.children.append(node)
def set_state(self, state):
self.state = state
def get_state(self):
return self.state
def get_visit_number(self):
return self.visit_number
def set_visit_number(self, number):
self.visit_number = number
def update_visit_number(self):
self.visit_number += 1
def get_quality_value(self):
return self.quality
def set_quality_value(self, value):
self.quality = value
def update_quality_value(self, reward):
self.quality += reward
def is_all_expand(self):
return len(self.children) == \
len(AVAILABLE_CHOICES) - len(self.get_state().get_accumulation_choices())
def __repr__(self):
return "Node: {}, Q/N: {}/{}, State: {}".format(
hash(self), self.quality, self.visit_number, self.state)
def tree_policy(node):
"""
Selection and Expansion stages of the Monte Carlo tree search:
given the node to search from (such as the root node), return the best node
to expand according to the exploration/exploitation trade-off.
If the node's state is already terminal, it is returned directly.
The basic strategy is to expand a not-yet-expanded child first, chosen at
random when several are available; once every child has been expanded,
descend to the child with the largest UCB value, breaking ties randomly,
and repeat.
"""
# descend until the state is terminal
while node and not node.get_state().is_terminal():
if node.is_all_expand():
node = best_child(node, True)
else:
# return the new sub node
sub_node = expand(node)
# if none of the remaining choices satisfies the constraints,
# the new state is empty
if sub_node.get_state():
return sub_node
# return the leaf node
return node
def default_policy(node):
"""
Simulation stage of the Monte Carlo tree search: starting from the given node,
play random choices until a terminal state is reached and return that state's reward.
The input node should not be terminal and should still have expandable choices.
The basic strategy is to pick each choice at random.
"""
# get the state of the game
current_state = copy.deepcopy(node.get_state())
# run until the game over
while not current_state.is_terminal():
# pick one random action to play and get next state
next_state = current_state.get_next_state_with_random_choice()
if not next_state:
break
current_state = next_state
final_state_reward = current_state.compute_benefit()
return final_state_reward
def expand(node):
"""
Given a node, create a child node by taking one random, not-yet-taken choice,
attach it to the node, and return it. Because the chosen action is removed from
the available choices, newly added children never repeat an earlier action.
"""
new_state = node.get_state().get_next_state_with_random_choice()
sub_node = Node()
sub_node.set_state(new_state)
node.expand_child(sub_node)
return sub_node
def best_child(node, is_exploration):
"""
Using the UCB algorithm, select the child node with the highest score after
weighing exploration against exploitation.
In the prediction stage (is_exploration=False) the exploration term is dropped,
so the child with the highest average quality (Q / N) is selected directly.
"""
best_score = -sys.maxsize
best_sub_node = None
# travel all sub nodes to find the best one
for sub_node in node.get_children():
# the children may include nodes whose state is empty; such nodes come from
# expansions that did not meet the constraints, so skip them
if not sub_node.get_state():
continue
# ignore exploration for inference
if is_exploration:
C = 1 / math.sqrt(2.0)
else:
C = 0.0
# UCB = quality / times + C * sqrt(2 * ln(total_times) / times)
left = sub_node.get_quality_value() / sub_node.get_visit_number()
right = 2.0 * math.log(node.get_visit_number()) / sub_node.get_visit_number()
score = left + C * math.sqrt(right)
# keep the highest score while filtering out nodes that exceed the storage
# threshold or that bring no benefit
if score > best_score \
and sub_node.get_state().get_current_storage() <= STORAGE_THRESHOLD \
and sub_node.get_state().get_current_benefit() > 0:
best_sub_node = sub_node
best_score = score
return best_sub_node
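A quick numeric check of the score computed above, with made-up values (a child of quality 6 visited 3 times, under a parent visited 10 times):

import math

C = 1 / math.sqrt(2.0)  # exploration constant used during search
score = 6 / 3 + C * math.sqrt(2.0 * math.log(10) / 3)
print(round(score, 2))  # ~2.88; with C = 0 (prediction stage) the score is just Q / N = 2.0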
def backpropagate(node, reward):
"""
Backpropagation stage of the Monte Carlo tree search:
given the newly expanded node and the reward of the simulated playout,
propagate the reward from that node up through all of its ancestors,
updating their visit counts and quality values.
"""
# update until the root node is reached
while node is not None:
# update the visit number
node.update_visit_number()
# update the quality value
node.update_quality_value(reward)
# change the node to the parent node
node = node.parent
def monte_carlo_tree_search(node):
"""
Run the Monte Carlo tree search from the given root node:
within a limited computation budget, repeatedly select the best node to expand,
simulate a random playout from it, and backpropagate the reward through the
part of the tree explored so far.
Finally return the child of the root with the highest exploitation value;
when making the prediction, only the Q value is used to pick that child,
i.e. the next optimal node.
"""
computation_budget = len(AVAILABLE_CHOICES) * 3
# run as much as possible under the computation budget
for i in range(computation_budget):
# 1. find the best node to expand
expand_node = tree_policy(node)
if not expand_node:
# None means every node has already been added but none meets the space limit
break
# 2. random get next action and get reward
reward = default_policy(expand_node)
# 3. update all passing nodes with reward
backpropagate(expand_node, reward)
# get the best next node
best_next_node = best_child(node, False)
return best_next_node
def MCTS(workload_info, atomic_choices, available_choices, storage_threshold, max_index_num):
global ATOMIC_CHOICES, STORAGE_THRESHOLD, WORKLOAD_INFO, AVAILABLE_CHOICES, MAX_INDEX_NUM
WORKLOAD_INFO = workload_info
AVAILABLE_CHOICES = available_choices
ATOMIC_CHOICES = atomic_choices
STORAGE_THRESHOLD = storage_threshold
MAX_INDEX_NUM = max_index_num if max_index_num else len(available_choices)
# create the initialized state and initialized node
init_state = State()
choices = copy.copy(available_choices)
init_state.set_available_choices(choices)
init_node = Node()
init_node.set_state(init_state)
current_node = init_node
opt_config = []
# set the rounds to play
for i in range(len(AVAILABLE_CHOICES)):
if current_node:
current_node = monte_carlo_tree_search(current_node)
if current_node:
opt_config = current_node.state.accumulation_choices
else:
break
return opt_config
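A minimal, self-contained usage sketch of the new module. The stand-in index and query objects below are illustrative: they only carry the attributes mcts.py actually reads (table, columns, index_type, storage, cost_list); the unit test added in the next file exercises the same path with the project's real IndexItem and QueryItem classes.

from types import SimpleNamespace
import mcts

def make_index(table, columns, storage=0):
    # only the attributes read by mcts.py
    return SimpleNamespace(table=table, columns=columns, index_type='global', storage=storage)

index_a = make_index('public.a', 'col1')
index_b = make_index('public.b', 'col1')
# atomic configs: position 0 means "no index"; single-index configs carry the storage estimates
atomic_configs = [[], [make_index('public.a', 'col1', storage=4)],
                  [make_index('public.b', 'col1', storage=7)]]
# cost_list[i] is the query cost under atomic config i (cost_list[0]: no index)
workload = [SimpleNamespace(cost_list=[10, 7, 5])]

# pick at most 2 indexes whose estimated total storage stays within 12 MB
chosen = mcts.MCTS(workload, atomic_configs, [index_a, index_b], 12, 2)
print([(index.table, index.columns) for index in chosen])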

View File

@@ -27,6 +27,7 @@ from collections.abc import Iterable
from collections import defaultdict
import index_advisor_workload as iaw
import mcts
def hash_any(obj):
@@ -227,6 +228,32 @@ select * from student_range_part1 where credit=1;
class IndexAdvisorTester(unittest.TestCase):
def test_mcts(self):
storage_threshold = 12
index1 = iaw.IndexItem('public.a', 'col1', index_type='global')
index2 = iaw.IndexItem('public.b', 'col1', index_type='global')
index3 = iaw.IndexItem('public.c', 'col1', index_type='global')
index4 = iaw.IndexItem('public.d', 'col1', index_type='global')
atomic_index1 = iaw.IndexItem('public.a', 'col1', index_type='global')
atomic_index2 = iaw.IndexItem('public.b', 'col1', index_type='global')
atomic_index3 = iaw.IndexItem('public.c', 'col1', index_type='global')
atomic_index4 = iaw.IndexItem('public.d', 'col1', index_type='global')
atomic_index1.storage = 10
atomic_index2.storage = 4
atomic_index3.storage = 7
available_choices = [index1, index2, index3, index4]
atomic_choices = [[], [atomic_index2], [atomic_index1], [atomic_index3],
[atomic_index2, atomic_index3], [atomic_index4]]
query = iaw.QueryItem('select * from gia_01', 1)
query.cost_list = [10, 7, 5, 9, 4, 11]
workload_info = [query]
results = mcts.MCTS(workload_info, atomic_choices, available_choices, storage_threshold, 2)
self.assertLessEqual([index1.atomic_pos, index2.atomic_pos, index3.atomic_pos], [2, 1, 3])
self.assertSetEqual({results[0].table, results[1].table}, {'public.b', 'public.c'})
def test_get_indexable_columns(self):
tables = 'table1 table2 table2 table3 table3 table3'.split()
columns = 'col1,col2 col2 col3 col1,col2 col2,col3 col2,col5'.split()