sysom1/sysom_server/sysom_channel/app/executor.py

# -*- coding: utf-8 -*- #
"""
Time 2022/10/11 16:13
Author: mingfeng (SunnyQjm)
Email mfeng@linux.alibaba.com
File executor.py
Description:
"""
import os
import asyncio
import time
import json
from typing import Callable, Optional
from queue import Queue
from importlib import import_module
import asyncssh
from cec_base.event import Event
from cec_base.consumer import Consumer
from cec_base.producer import Producer, dispatch_producer
from cec_base.cec_client import MultiConsumer, CecAsyncConsumeTask, StoppableThread
from clogger import logger
from conf.settings import *
from app.crud import update_or_create_channel_params, get_channel_params_by_instance
from app.database import SessionLocal
from app import schemas
from lib.channels.base import ChannelResult, ChannelException, ChannelCode
from channel_job import JobEntry
class ChannelListener(MultiConsumer):
""" A cec-based channel listener
A cec-based channel listener, used to listen to requests for channels from
other modules and output the results to cec after performing the corresponding
operation on the target node
Args:
task_process_thread_num(int): The number of threads contained in the thread
pool used to execute the task
"""
def __init__(self) -> None:
super().__init__(SYSOM_CEC_URL, custom_callback=self.on_receive_event)
self.append_group_consume_task(
SYSOM_CEC_CHANNEL_TOPIC,
SYSOM_CEC_CHANNEL_CONSUMER_GROUP,
Consumer.generate_consumer_id(),
ensure_topic_exist=True
)
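# Results are published back to CEC on the channel result topic through a dedicated producer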
self._target_topic = SYSOM_CEC_CHANNEL_RESULT_TOPIC
self._producer: Producer = dispatch_producer(SYSOM_CEC_URL)
# Opt table: maps the request "type" field to its handler coroutine
self._opt_table = {
'init': self._do_init_channel,
"cmd": self._do_run_command
}
# Worker thread that executes the queued channel tasks
self._task_process_thread: Optional[StoppableThread] = None
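# Bounded queue handing task coroutines from the CEC consumer callback to the worker thread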
self._task_queue: Queue = Queue(maxsize=1000)
def _get_channel(self, channel_type):
"""
Dynamically import a Channel implementation for the given channel type and use it to execute the command
"""
try:
return import_module(f'lib.channels.{channel_type}').Channel
except Exception as e:
raise Exception(f'No channels available => {str(e)}')
def _delivery(self, topic: str, value: dict):
self._producer.produce(topic, value)
self._producer.flush()
async def _perform_opt(self, opt_func: Callable[[str, dict], ChannelResult],
default_channel: str, task: dict) -> ChannelResult:
"""
Use the specified channel to perform operations on the remote
node and return the results.
"""
async def _try_another_channel(result: ChannelResult):
channels_path = os.path.join(BASE_DIR, 'lib', 'channels')
packages = [dir.replace('.py', '') for dir in os.listdir(
channels_path) if not dir.startswith('__')]
packages.remove('base')
packages.remove(default_channel)
err = None
for _, pkg in enumerate(packages):
try:
result = await opt_func(pkg, task)
err = None
break
except Exception as exc:
logger.error(str(exc))
err = exc
return result, err
result, err = ChannelResult(code=1), None
try:
result = await opt_func(default_channel, task)
if result.code != 0:
result, inner_err = await _try_another_channel(result)
if inner_err is not None:
err = inner_err
except Exception as exc:
logger.error(str(exc))
err = exc
result, inner_err = await _try_another_channel(result)
if inner_err is not None:
err = inner_err
if err is not None:
raise err
return result
async def _do_run_command(self, channel_type: str, task: dict) -> ChannelResult:
"""cmd opt"""
def on_data_received(data: str, data_type: asyncssh.DataType):
echo = task.get("echo", {})
bind_result_topic = task.get("bind_result_topic", None)
if bind_result_topic is not None:
self._delivery(bind_result_topic, {
"code": 100,
"err_msg": "",
"echo": echo,
"result": data
})
self._producer.flush()
params = task.get("params", {}).copy()
instance = params.get("instance", "")
timeout = params.pop(JobEntry.CHANNEL_PARAMS_TIMEOUT, None)
auto_retry = params.pop(JobEntry.CHANNEL_PARAMS_AUTO_RETRY, False)
no_need_sudo = params.pop(JobEntry.CHANNEL_PARAMS_NO_NEED_SUDO, False)
return_as_stream = params.pop(
JobEntry.CHANNEL_PARAMS_RETURN_AS_STREAM, False)
# Get params from sys_channel_params table if instance exists
with SessionLocal() as db:
channel_params_item = get_channel_params_by_instance(
db, instance=instance)
if channel_params_item is not None:
params = {
**params, **json.loads(channel_params_item.params)
}
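# Instantiate the channel with the merged params and run the command, honouring
# timeout / auto_retry / no_need_sudo and optional streaming output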
res = await self._get_channel(channel_type)(**params).run_command_auto_retry_async(
timeout=timeout,
auto_retry=auto_retry,
no_need_sudo=no_need_sudo,
on_data_received=on_data_received if return_as_stream else None
)
return res
async def _do_init_channel(self, channel_type: str, task: dict) -> ChannelResult:
"""init opt"""
params = task.get("params", {}).copy()
instance = params.get("instance", "")
timeout = params.pop(JobEntry.CHANNEL_PARAMS_TIMEOUT, None)
auto_retry = params.pop(JobEntry.CHANNEL_PARAMS_AUTO_RETRY, False)
res1 = await self._get_channel(channel_type).initial_async(
**params, timeout=timeout, auto_retry=auto_retry
)
res = res1
if "password" in params and res1.code != 0:
# Use password failed, try to use key
params_without_password = params.copy()
params_without_password.pop("password", "")
res = await self._get_channel(channel_type).initial_async(
**params_without_password, timeout=timeout, auto_retry=auto_retry
)
# Save params after init channel success
if res.code == 0:
with SessionLocal() as db:
params.pop("password", "")
update_or_create_channel_params(db, schemas.ChannelParams(
instance=instance,
params=json.dumps(params)
))
return res
async def _process_each_task(self, event: Event, cecConsumeTask: CecAsyncConsumeTask):
"""
Process a single task
"""
task = event.value
result = {
"code": 0,
"err_msg": "",
"echo": task.get("echo", {}),
"result": ""
}
bind_result_topic = task.get("bind_result_topic", None)
try:
opt_type = task.get("type", "cmd")
channel_type = task.get("channel", "ssh")
params = task.get("params", {})
if opt_type not in self._opt_table:
result["code"] = 1
result["err_msg"] = f"Not support opt: {opt_type}"
else:
channel_result = await self._perform_opt(
self._opt_table[opt_type],
channel_type,
task=task
)
result["code"] = channel_result.code
if channel_result.code != 0:
result["err_msg"] = channel_result.err_msg
if channel_result.err_msg == "" and channel_result.result != "":
result["err_msg"] = channel_result.result
result["result"] = channel_result.result
except ChannelException as ce:
logger.exception(ce)
result["code"] = ce.code
result["err_msg"] = ce.message
result["result"] = ce.summary
except Exception as e:
logger.exception(e)
result["code"] = ChannelCode.SERVER_ERROR.value
result["err_msg"] = str(e)
result["result"] = "Channel Server Error"
finally:
# Acknowledge the event
cecConsumeTask.ack(event)
# Write the task result to the event center so the Task module can fetch it
self._delivery(self._target_topic, result)
# If a result topic was explicitly specified, also send a copy to that topic
if bind_result_topic:
self._delivery(bind_result_topic, result)
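# Invoked by the CEC consumer for each incoming event: wrap it into a coroutine
# and hand it to the worker thread via the task queue (it is not awaited here)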
def on_receive_event(self, event: Event, task: CecAsyncConsumeTask):
self._task_queue.put(
self._process_each_task(event, task)
)
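# Worker loop: drain pending coroutines from the queue and drive them on a private
# event loop, re-queueing whatever has not completed within the wait window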
def _process_task(self):
def _get_task_from_queue():
_tasks = []
while not self._task_queue.empty():
_task = self._task_queue.get_nowait()
if _task:
_tasks.append(_task)
else:
break
return _tasks
tasks = _get_task_from_queue()
loop = asyncio.new_event_loop()
while not self._task_process_thread.stopped():
if len(tasks) == 0:
time.sleep(0.1)
tasks = _get_task_from_queue()
continue
finished, unfinished = loop.run_until_complete(
asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED, timeout=0.5)
)
for task in finished:
if task.exception() is not None:
logger.error(str(task.exception()))
tasks = _get_task_from_queue()
if unfinished is not None:
tasks += list(unfinished)
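# Start the CEC consumers plus the daemon worker thread; calling start() again
# while the worker thread is still alive is a no-op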
def start(self):
super().start()
if self._task_process_thread is not None \
and not self._task_process_thread.stopped() \
and self._task_process_thread.is_alive():
return
self._task_process_thread = StoppableThread(target=self._process_task)
self._task_process_thread.daemon = True
self._task_process_thread.start()
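# Minimal usage sketch (assumption: the real entrypoint lives elsewhere in
# sysom_channel; shown here only for illustration):
#
#   if __name__ == "__main__":
#       listener = ChannelListener()
#       listener.start()  # non-blocking: consumer and worker thread run as daemons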