forked from TensorLayer/tensorlayer3
747 lines
27 KiB
Python
747 lines
27 KiB
Python
#! /usr/bin/python
|
|
# -*- coding: utf-8 -*-
|
|
|
|
import os
|
|
import pickle
|
|
import sys
|
|
import time
|
|
from datetime import datetime
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
import gridfs
|
|
import pymongo
|
|
from tensorlayer import logging
|
|
from tensorlayer.files import (
|
|
assign_weights, del_folder, exists_or_mkdir, load_hdf5_to_weights, save_weights_to_hdf5, static_graph2net
|
|
)
|
|
|
|
|
|
class TensorHub(object):
    """MongoDB based manager for datasets, network architectures, parameters, logs and tasks.

    Parameters
    -------------
    ip : str
        Localhost or IP address.
    port : int
        Port number.
    dbname : str
        Database name.
    username : str or None
        User name. Set to None (or keep the default string ``'None'``) if you do not need authentication.
    password : str
        Password, only used when a user name is given.
    project_name : str or None
        Experiment key for this entire project, similar with the repository name of Github.

    Attributes
    ------------
    ip, port, dbname and other input parameters : see above
        See above.
    project_name : str
        The given project name, if no given, it is derived from ``sys.argv[0]``.
    db : mongodb database
        The ``pymongo`` database handle ``pymongo.MongoClient(ip, port)[dbname]``.
    dataset_fs, model_fs : gridfs.GridFS
        GridFS buckets used to store the pickled datasets and model parameters.

    Notes
    -----
    Datasets and parameters are stored with ``pickle`` and tasks are run with ``exec``;
    only connect to a database whose content you trust.
    """

    def __init__(
        self, ip='localhost', port=27017, dbname='dbname', username='None', password='password', project_name=None
    ):
        self.ip = ip
        self.port = port
        self.dbname = dbname
        self.username = username

        print("[Database] Initializing ...")
        # connect mongodb
        client = pymongo.MongoClient(ip, port)
        self.db = client[dbname]
        # The default value is the *string* 'None' (kept for backward compatibility), so both
        # ``None`` and ``'None'`` mean "no authentication". The password is never printed.
        if username is not None and username != 'None':
            # NOTE(review): ``Database.authenticate`` was removed in PyMongo 4; with PyMongo >= 4
            # the credentials have to be passed to ``MongoClient`` instead - confirm the pinned version.
            self.db.authenticate(username, password)
        else:
            print("[Database] No username given, it works if authentication is not required")

        if project_name is None:
            # NOTE(review): splitting on the first '.' truncates paths such as './train.py';
            # left unchanged because existing records are keyed by the resulting name.
            self.project_name = sys.argv[0].split('.')[0]
            print("[Database] No project_name given, use {}".format(self.project_name))
        else:
            self.project_name = project_name

        # define file system (Buckets)
        self.dataset_fs = gridfs.GridFS(self.db, collection="datasetFilesystem")
        self.model_fs = gridfs.GridFS(self.db, collection="modelfs")

        print("[Database] Connected ")
        _s = "[Database] Info:\n"
        _s += " ip : {}\n".format(self.ip)
        _s += " port : {}\n".format(self.port)
        _s += " dbname : {}\n".format(self.dbname)
        _s += " username : {}\n".format(self.username)
        _s += " password : {}\n".format("*******")
        _s += " project_name : {}\n".format(self.project_name)
        self._s = _s
        print(self._s)

    def __str__(self):
        """Return the connection summary built in ``__init__`` (password masked)."""
        return self._s

    def _fill_project_info(self, args):
        """Add ``project_name`` to ``args`` in place (used for all queries and documents).

        Returns None; callers rely on the in-place mutation only.
        """
        args.update({'project_name': self.project_name})

    @staticmethod
    def _serialization(ps):
        """Serialize data to bytes with the highest pickle protocol."""
        return pickle.dumps(ps, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def _deserialization(ps):
        """Deserialize data produced by ``_serialization``.

        SECURITY: ``pickle.loads`` can execute arbitrary code; the bytes come from the
        database, so this is only safe when the database content is trusted.
        """
        return pickle.loads(ps)

    @staticmethod
    def _log_exception(e):
        """Log the exception currently being handled; must be called from inside an ``except`` block."""
        exc_type, exc_obj, exc_tb = sys.exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        logging.info("{} {} {} {} {}".format(exc_type, exc_obj, fname, exc_tb.tb_lineno, e))

    # =========================== MODELS ================================
    def save_model(self, network=None, model_name='model', **kwargs):
        """Save model architecture and parameters into database, timestamp will be added automatically.

        Parameters
        ----------
        network : TensorLayer Model
            TensorLayer Model instance; its ``all_weights`` and ``config`` are stored.
        model_name : str
            The name/key of model.
        kwargs : other events
            Other events, such as name, accuracy, loss, step number and etc (optional).

        Examples
        ---------
        Save model architecture and parameters into database.
        >>> db.save_model(net, accuracy=0.8, loss=2.3, name='second_model')

        Load one model with parameters from database (run this in other script)
        >>> net = db.find_top_model(accuracy=0.8, loss=2.3)

        Find and load the latest model.
        >>> net = db.find_top_model(sort=[("time", pymongo.DESCENDING)])
        >>> net = db.find_top_model(sort=[("time", -1)])

        Find and load the oldest model.
        >>> net = db.find_top_model(sort=[("time", pymongo.ASCENDING)])
        >>> net = db.find_top_model(sort=[("time", 1)])

        Returns
        ---------
        boolean : True for success, False for fail.
        """
        kwargs.update({'model_name': model_name})
        self._fill_project_info(kwargs)  # put project_name into kwargs

        params = network.all_weights

        s = time.time()

        kwargs.update({'architecture': network.config, 'time': datetime.utcnow()})

        try:
            # the weights go to GridFS, the document only keeps a reference to them
            params_id = self.model_fs.put(self._serialization(params))
            kwargs.update({'params_id': params_id, 'time': datetime.utcnow()})
            self.db.Model.insert_one(kwargs)
            print("[Database] Save model: SUCCESS, took: {}s".format(round(time.time() - s, 2)))
            return True
        except Exception as e:
            self._log_exception(e)
            print("[Database] Save model: FAIL")
            return False

    def find_top_model(self, sort=None, model_name='model', **kwargs):
        """Finds and returns a model architecture and its parameters from the database which matches the requirement.

        Parameters
        ----------
        sort : List of tuple
            PyMongo sort argument, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
        model_name : str or None
            The name/key of model.
        kwargs : other events
            Other events, such as name, accuracy, loss, step number and etc (optional).

        Examples
        ---------
        - see ``save_model``.

        Returns
        ---------
        network : TensorLayer Model, or False
            The restored network with its weights assigned. False if no document matches
            or if restoring the network fails.

        Notes
        -----
        The fields of the matching document (e.g. accuracy) are currently NOT attached to the
        returned network; the code that did so is disabled.
        """
        kwargs.update({'model_name': model_name})
        self._fill_project_info(kwargs)

        s = time.time()

        d = self.db.Model.find_one(filter=kwargs, sort=sort)

        if d is None:
            print("[Database] FAIL! Cannot find model: {}".format(kwargs))
            return False

        params_id = d['params_id']
        graphs = d['architecture']
        _datetime = d['time']

        try:
            params = self._deserialization(self.model_fs.get(params_id).read())
            # rebuild the network from the stored architecture, then load the stored weights
            network = static_graph2net(graphs)
            assign_weights(weights=params, network=network)

            pc = self.db.Model.find(kwargs)
            print(
                "[Database] Find one model SUCCESS. kwargs:{} sort:{} save time:{} took: {}s".format(
                    kwargs, sort, _datetime, round(time.time() - s, 2)
                )
            )

            # FIXME : attaching the document fields to the network is disabled upstream
            # (was: network.__dict__.update({"_%s" % key: d[key]}) for every key of d)

            # check whether more parameters match the requirement
            n_params = len(pc.distinct('params_id'))
            if n_params != 1:
                print(" Note that there are {} models match the kwargs".format(n_params))
            return network
        except Exception as e:
            self._log_exception(e)
            return False

    def delete_model(self, **kwargs):
        """Delete model documents of this project.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all models of the project.
        """
        self._fill_project_info(kwargs)
        self.db.Model.delete_many(kwargs)
        logging.info("[Database] Delete Model SUCCESS")

    # =========================== DATASET ===============================
    def save_dataset(self, dataset=None, dataset_name=None, **kwargs):
        """Saves one dataset into database, timestamp will be added automatically.

        Parameters
        ----------
        dataset : any type
            The dataset you want to store (it is pickled).
        dataset_name : str
            The name of dataset.
        kwargs : other events
            Other events, such as description, author and etc (optional).

        Examples
        ----------
        Save dataset
        >>> db.save_dataset([X_train, y_train, X_test, y_test], 'mnist', description='this is a tutorial')

        Get dataset
        >>> dataset = db.find_top_dataset('mnist')

        Returns
        ---------
        boolean : Return True if save success, otherwise, return False.

        Raises
        ------
        Exception
            If ``dataset_name`` is None.
        """
        self._fill_project_info(kwargs)
        if dataset_name is None:
            raise Exception("dataset_name is None, please give a dataset name")
        kwargs.update({'dataset_name': dataset_name})

        s = time.time()
        try:
            dataset_id = self.dataset_fs.put(self._serialization(dataset))
            kwargs.update({'dataset_id': dataset_id, 'time': datetime.utcnow()})
            self.db.Dataset.insert_one(kwargs)
            print("[Database] Save dataset: SUCCESS, took: {}s".format(round(time.time() - s, 2)))
            return True
        except Exception as e:
            self._log_exception(e)
            print("[Database] Save dataset: FAIL")
            return False

    def find_top_dataset(self, dataset_name=None, sort=None, **kwargs):
        """Finds and returns a dataset from the database which matches the requirement.

        Parameters
        ----------
        dataset_name : str
            The name of dataset.
        sort : List of tuple
            PyMongo sort argument, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
        kwargs : other events
            Other events, such as description, author and etc (optional).

        Examples
        ---------
        Save dataset
        >>> db.save_dataset([X_train, y_train, X_test, y_test], 'mnist', description='this is a tutorial')

        Get dataset
        >>> dataset = db.find_top_dataset('mnist')
        >>> datasets = db.find_datasets('mnist')

        Returns
        --------
        dataset : the dataset or False
            Return False if nothing found or if loading fails.

        Raises
        ------
        Exception
            If ``dataset_name`` is None.
        """
        self._fill_project_info(kwargs)
        if dataset_name is None:
            raise Exception("dataset_name is None, please give a dataset name")
        kwargs.update({'dataset_name': dataset_name})

        s = time.time()

        d = self.db.Dataset.find_one(filter=kwargs, sort=sort)

        if d is None:
            print("[Database] FAIL! Cannot find dataset: {}".format(kwargs))
            return False

        dataset_id = d['dataset_id']
        try:
            dataset = self._deserialization(self.dataset_fs.get(dataset_id).read())
            pc = self.db.Dataset.find(kwargs)
            print("[Database] Find one dataset SUCCESS, {} took: {}s".format(kwargs, round(time.time() - s, 2)))

            # check whether more datasets match the requirement
            n_dataset = len(pc.distinct('dataset_id'))
            if n_dataset != 1:
                print(" Note that there are {} datasets match the requirement".format(n_dataset))
            return dataset
        except Exception as e:
            self._log_exception(e)
            return False

    def find_datasets(self, dataset_name=None, **kwargs):
        """Finds and returns all datasets from the database which matches the requirement.
        In some case, the data in a dataset can be stored separately for better management.

        Parameters
        ----------
        dataset_name : str
            The name/key of dataset.
        kwargs : other events
            Other events, such as description, author and etc (optional).

        Returns
        --------
        datasets : list of datasets, or False if nothing found.

        Raises
        ------
        Exception
            If ``dataset_name`` is None.
        """
        self._fill_project_info(kwargs)
        if dataset_name is None:
            raise Exception("dataset_name is None, please give a dataset name")
        kwargs.update({'dataset_name': dataset_name})

        s = time.time()
        # ``find`` always returns a cursor (never None), so "nothing found" has to be
        # detected from the ids it yields.
        dataset_id_list = self.db.Dataset.find(kwargs).distinct('dataset_id')
        if not dataset_id_list:
            print("[Database] FAIL! Cannot find any dataset: {}".format(kwargs))
            return False

        # you may have multiple Buckets files
        dataset_list = [
            self._deserialization(self.dataset_fs.get(dataset_id).read()) for dataset_id in dataset_id_list
        ]

        print("[Database] Find {} datasets SUCCESS, took: {}s".format(len(dataset_list), round(time.time() - s, 2)))
        return dataset_list

    def delete_datasets(self, **kwargs):
        """Delete dataset documents of this project.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all datasets of the project.
        """
        self._fill_project_info(kwargs)
        self.db.Dataset.delete_many(kwargs)
        logging.info("[Database] Delete Dataset SUCCESS")

    # =========================== LOGGING ===============================
    def save_training_log(self, **kwargs):
        """Saves the training log, timestamp will be added automatically.

        Parameters
        -----------
        kwargs : logging information
            Events, such as accuracy, loss, step number and etc.

        Examples
        ---------
        >>> db.save_training_log(accuracy=0.33, loss=0.98)
        """
        self._fill_project_info(kwargs)
        kwargs.update({'time': datetime.utcnow()})
        self.db.TrainLog.insert_one(kwargs)
        logging.info("[Database] train log: " + self._print_dict(kwargs))

    def save_validation_log(self, **kwargs):
        """Saves the validation log, timestamp will be added automatically.

        Parameters
        -----------
        kwargs : logging information
            Events, such as accuracy, loss, step number and etc.

        Examples
        ---------
        >>> db.save_validation_log(accuracy=0.33, loss=0.98)
        """
        self._fill_project_info(kwargs)
        kwargs.update({'time': datetime.utcnow()})
        self.db.ValidLog.insert_one(kwargs)
        logging.info("[Database] valid log: " + self._print_dict(kwargs))

    def save_testing_log(self, **kwargs):
        """Saves the testing log, timestamp will be added automatically.

        Parameters
        -----------
        kwargs : logging information
            Events, such as accuracy, loss, step number and etc.

        Examples
        ---------
        >>> db.save_testing_log(accuracy=0.33, loss=0.98)
        """
        self._fill_project_info(kwargs)
        kwargs.update({'time': datetime.utcnow()})
        self.db.TestLog.insert_one(kwargs)
        logging.info("[Database] test log: " + self._print_dict(kwargs))

    def delete_training_log(self, **kwargs):
        """Deletes training log.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all log.

        Examples
        ---------
        Save training log
        >>> db.save_training_log(accuracy=0.33)
        >>> db.save_training_log(accuracy=0.44)

        Delete logs that match the requirement
        >>> db.delete_training_log(accuracy=0.33)

        Delete all logs
        >>> db.delete_training_log()
        """
        self._fill_project_info(kwargs)
        self.db.TrainLog.delete_many(kwargs)
        logging.info("[Database] Delete TrainLog SUCCESS")

    def delete_validation_log(self, **kwargs):
        """Deletes validation log.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all log.

        Examples
        ---------
        - see ``save_training_log``.
        """
        self._fill_project_info(kwargs)
        self.db.ValidLog.delete_many(kwargs)
        logging.info("[Database] Delete ValidLog SUCCESS")

    def delete_testing_log(self, **kwargs):
        """Deletes testing log.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all log.

        Examples
        ---------
        - see ``save_training_log``.
        """
        self._fill_project_info(kwargs)
        self.db.TestLog.delete_many(kwargs)
        logging.info("[Database] Delete TestLog SUCCESS")

    # =========================== Task ===================================
    def create_task(self, task_name=None, script=None, hyper_parameters=None, saved_result_keys=None, **kwargs):
        """Uploads a task to the database, timestamp will be added automatically.

        Parameters
        -----------
        task_name : str
            The task name; stored in the task document so that workers can select it.
        script : str
            File name of the python script; its content is read and stored in the database.
        hyper_parameters : dictionary
            The hyper parameters pass into the script.
        saved_result_keys : list of str
            The keys of the task results to keep in the database when the task finishes.
        kwargs : other parameters
            Users customized parameters such as description, version number.

        Examples
        -----------
        Uploads a task
        >>> db.create_task(task_name='mnist', script='example/tutorial_mnist_simple.py', description='simple tutorial')

        Finds and runs the latest task
        >>> db.run_top_task(task_name='mnist', sort=[("time", pymongo.DESCENDING)])
        >>> db.run_top_task(task_name='mnist', sort=[("time", -1)])

        Finds and runs the oldest task
        >>> db.run_top_task(task_name='mnist', sort=[("time", pymongo.ASCENDING)])
        >>> db.run_top_task(task_name='mnist', sort=[("time", 1)])

        Raises
        ------
        Exception
            If ``task_name`` or ``script`` is not a string.
        """
        if not isinstance(task_name, str):
            raise Exception("task_name should be string")
        if not isinstance(script, str):
            raise Exception("script should be string")
        if hyper_parameters is None:
            hyper_parameters = {}
        if saved_result_keys is None:
            saved_result_keys = []

        self._fill_project_info(kwargs)
        kwargs.update({'time': datetime.utcnow()})
        kwargs.update({'hyper_parameters': hyper_parameters})
        kwargs.update({'saved_result_keys': saved_result_keys})
        # store the name, otherwise tasks cannot be told apart when they are fetched
        kwargs.update({'task_name': task_name})

        # read the script as bytes and make sure the file handle is closed
        with open(script, 'rb') as script_file:
            _script = script_file.read()

        kwargs.update({'status': 'pending', 'script': _script, 'result': {}})
        self.db.Task.insert_one(kwargs)
        logging.info("[Database] Saved Task - task_name: {} script: {}".format(task_name, script))

    def run_top_task(self, task_name=None, sort=None, **kwargs):
        """Finds and runs a pending task that in the first of the sorting list.

        Parameters
        -----------
        task_name : str
            The task name; only pending tasks stored under this name are considered.
        sort : List of tuple
            PyMongo sort argument, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
        kwargs : other parameters
            Users customized parameters such as description, version number.

        Examples
        ---------
        Monitors the database and pull tasks to run
        >>> while True:
        >>>     print("waiting task from distributor")
        >>>     db.run_top_task(task_name='mnist', sort=[("time", -1)])
        >>>     time.sleep(1)

        Returns
        --------
        boolean : True if a task was found and ran to completion, False if no pending task matches.

        Raises
        ------
        Exception
            If ``task_name`` is not a string. Exceptions raised by the task script propagate
            to the caller, and the task then stays in status ``running``.

        Notes
        -----
        SECURITY: the script stored in the database is run with ``exec`` in this module's
        globals; only use this with a database whose content you trust.
        """
        if not isinstance(task_name, str):
            raise Exception("task_name should be string")
        self._fill_project_info(kwargs)
        kwargs.update({'status': 'pending', 'task_name': task_name})

        # find task and atomically set its status to running
        task = self.db.Task.find_one_and_update(kwargs, {'$set': {'status': 'running'}}, sort=sort)

        if task is None:
            logging.info("[Database] Find Task FAIL: key: {} sort: {}".format(task_name, sort))
            return False

        logging.info("[Database] Find Task SUCCESS: key: {} sort: {}".format(task_name, sort))
        # get task info e.g. hyper parameters, python script
        _datetime = task['time']
        _script = task['script']
        _id = task['_id']
        _hyper_parameters = task['hyper_parameters']
        _saved_result_keys = task['saved_result_keys']
        logging.info(" hyper parameters:")
        for key in _hyper_parameters:
            # hyper parameters are injected as module globals so that the script can read them
            globals()[key] = _hyper_parameters[key]
            logging.info(" {}: {}".format(key, _hyper_parameters[key]))

        # run task
        s = time.time()
        logging.info("[Database] Start Task: key: {} sort: {} push time: {}".format(task_name, sort, _datetime))
        _script = _script.decode('utf-8')
        with tf.Graph().as_default():  # clear all TF graphs
            exec(_script, globals())

        # set status to finished
        self.db.Task.find_one_and_update({'_id': _id}, {'$set': {'status': 'finished'}})

        # collect the requested results from the globals the script left behind
        result = {}
        for _key in _saved_result_keys:
            logging.info(" result: {}={} {}".format(_key, globals()[_key], type(globals()[_key])))
            result.update({"%s" % _key: globals()[_key]})
        self.db.Task.find_one_and_update(
            {'_id': _id}, {'$set': {
                'result': result
            }}, return_document=pymongo.ReturnDocument.AFTER
        )
        logging.info(
            "[Database] Finished Task: task_name - {} sort: {} push time: {} took: {}s".format(
                task_name, sort, _datetime,
                time.time() - s
            )
        )
        return True

    def delete_tasks(self, **kwargs):
        """Delete tasks.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all tasks of the project.

        Examples
        ---------
        >>> db.delete_tasks()
        """
        self._fill_project_info(kwargs)
        self.db.Task.delete_many(kwargs)
        logging.info("[Database] Delete Task SUCCESS")

    def check_unfinished_task(self, task_name=None, **kwargs):
        """Checks whether there are pending or running tasks with the given name.

        Parameters
        -----------
        task_name : str
            The task name.
        kwargs : other parameters
            Users customized parameters such as description, version number.

        Examples
        ---------
        Wait until all tasks finish in user's local console

        >>> while db.check_unfinished_task(task_name='mnist'):
        >>>     time.sleep(1)
        >>> print("all tasks finished")
        >>> net = db.find_top_model(sort=[("test_accuracy", -1)])

        Returns
        --------
        boolean : True if at least one matching task is pending or running, False otherwise.

        Raises
        ------
        Exception
            If ``task_name`` is not a string.
        """
        if not isinstance(task_name, str):
            raise Exception("task_name should be string")
        self._fill_project_info(kwargs)

        kwargs.update({'task_name': task_name})
        kwargs.update({'$or': [{'status': 'pending'}, {'status': 'running'}]})

        # count the distinct unfinished tasks
        n_task = len(self.db.Task.find(kwargs).distinct('_id'))

        if n_task == 0:
            logging.info("[Database] No unfinished task - task_name: {}".format(task_name))
            return False

        logging.info("[Database] Find {} unfinished task - task_name: {}".format(n_task, task_name))
        return True

    @staticmethod
    def _print_dict(args):
        """Format ``args`` as ``"key: value / "`` pairs, skipping MongoDB's ``_id`` field."""
        string = ''
        for key, value in args.items():
            # compare by value: ``is not`` on a str literal only works by interning accident
            if key != '_id':
                string += str(key) + ": " + str(value) + " / "
        return string
|