mirror of https://gitee.com/anolis/sysom.git
98 lines
3.5 KiB
Python
98 lines
3.5 KiB
Python
from typing import Callable
|
|
import paramiko
|
|
from io import StringIO
|
|
from paramiko.client import SSHClient, AutoAddPolicy
|
|
from paramiko.rsakey import RSAKey
|
|
|
|
DEFAULT_CONNENT_TIMEOUT = 5 # Default SSH connection timeout: 5s
|
|
DEFAULT_NODE_USER = 'root' # Default node user name: root
|
|
|
|
|
|
class SSH:
    """SSH client used to run commands on a remote node.

    The connection is opened in the constructor. Authentication uses the
    given password when one is supplied; otherwise it uses the RSA private
    key returned by the getter registered with ``set_private_key_getter``.

    Args:
        hostname(str): Host name or address of the remote node.

    Keyword Args:
        username(str): User name, default 'root'.
        port(int): SSH port, default 22.
        timeout(int): Connection timeout in seconds, default 5.
            ``connect_timeout`` is accepted as an alias when ``timeout``
            is not given.
        password(str): Login password. When omitted or None, key-based
            authentication is used instead.

    Raises:
        Exception: If no password is given and no private-key getter has
            been registered, or if the connection cannot be established.
    """

    # Class-level cache for a key pair. It is not referenced by any method
    # of this class; kept so external users of the attribute keep working.
    _key_pair = {}
    # Zero-argument callables returning the private/public key text.
    # They stay None until the matching set_*_getter classmethod is called.
    _private_key_getter: Callable[[], str] = None
    _public_key_getter: Callable[[], str] = None

    def __init__(self, hostname: str, **kwargs) -> None:
        # 'timeout' wins when present (even if explicitly None, which
        # paramiko treats as "no timeout"); 'connect_timeout' is the name
        # the class documentation historically advertised.
        if 'timeout' in kwargs:
            timeout = kwargs['timeout']
        else:
            timeout = kwargs.get('connect_timeout', DEFAULT_CONNENT_TIMEOUT)

        self.connect_args = {
            'hostname': hostname,
            'username': kwargs.get('username', DEFAULT_NODE_USER),
            'port': kwargs.get('port', 22),
            'timeout': timeout,
        }

        password = kwargs.get('password')
        if password is not None:
            self.connect_args['password'] = password
        else:
            # Look the getter up on the class (not the instance) so a plain
            # function is not bound as a method, and so a getter registered
            # on a subclass is honoured.
            private_key_getter = type(self)._private_key_getter
            if private_key_getter is None:
                raise Exception("_private_key_getter not set")
            self.connect_args['pkey'] = RSAKey.from_private_key(
                StringIO(private_key_getter())
            )

        self._client: SSHClient = self.client()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    def client(self):
        """Open and return a new connected ``SSHClient``.

        Returns:
            SSHClient: A client connected with ``self.connect_args``.

        Raises:
            Exception: 'authorization fail, password or pkey error!' when
                authentication is rejected, 'authorization fail!' for any
                other connection error. The original error is chained as
                ``__cause__``.
        """
        client = SSHClient()
        # NOTE(review): AutoAddPolicy accepts unknown host keys without
        # verification; confirm this is acceptable for the deployment.
        client.set_missing_host_key_policy(AutoAddPolicy())
        try:
            client.connect(**self.connect_args)
        except paramiko.AuthenticationException as exc:
            client.close()
            raise Exception(
                'authorization fail, password or pkey error!') from exc
        except Exception as exc:
            # Catch Exception rather than everything so KeyboardInterrupt
            # and SystemExit are not converted into an auth failure.
            client.close()
            raise Exception('authorization fail!') from exc
        return client

    def close(self) -> None:
        """Close the underlying connection; safe to call more than once."""
        client = getattr(self, '_client', None)
        if client is not None:
            client.close()
            self._client = None

    @classmethod
    def set_private_key_getter(cls, private_key_getter: Callable[[], str]):
        """Register the callable that returns the private key text."""
        cls._private_key_getter = private_key_getter

    @classmethod
    def set_public_key_getter(cls, public_key_getter: Callable[[], str]):
        """Register the callable that returns the public key text."""
        cls._public_key_getter = public_key_getter

    def run_command(self, command):
        """Run ``command`` on the remote node.

        stderr is merged into stdout.

        Args:
            command(str): Command line to execute remotely.

        Returns:
            tuple: ``(exit_status, output)`` where ``output`` is the decoded
            combined stdout/stderr text.

        Raises:
            Exception: 'No client!' when there is no usable connection.
        """
        if not self._client:
            raise Exception('No client!')
        transport = self._client.get_transport()
        if transport is None:
            raise Exception('No client!')

        session = transport.open_session()
        try:
            session.set_combine_stderr(True)
            session.exec_command(command)
            stdout = session.makefile("rb", -1)
            # Drain the output BEFORE waiting for the exit status: waiting
            # first can block forever once the remote side fills the
            # channel window with unread output.
            raw_output = stdout.read()
            status = session.recv_exit_status()
        finally:
            session.close()
        # Remote output is not guaranteed to be valid UTF-8.
        return status, raw_output.decode(errors='replace')

    def add_public_key(self):
        """Append the registered public key to the remote authorized_keys.

        Raises:
            Exception: "_public_key_getter not set" when no getter has been
                registered, 'add public key faild!' when the remote command
                exits with a non-zero status.
        """
        import shlex

        public_key_getter = type(self)._public_key_getter
        if public_key_getter is None:
            raise Exception("_public_key_getter not set")

        # Quote for the shell (repr() is Python quoting, not shell quoting)
        # and drop surrounding whitespace such as a trailing newline.
        quoted_key = shlex.quote(public_key_getter().strip())
        command = (
            'mkdir -p -m 700 ~/.ssh && '
            f"printf '%s\\n' {quoted_key} >> ~/.ssh/authorized_keys && "
            'chmod 600 ~/.ssh/authorized_keys'
        )
        status, _ = self.run_command(command)
        if status != 0:
            raise Exception('add public key faild!')

    @staticmethod
    def validate_ssh_host(ip: str, password: str, port: int = 22, username: str = 'root'):
        """Check password login to a host and install the public key on it.

        Args:
            ip(str): Host address.
            password(str): Login password.
            port(int): SSH port, default 22.
            username(str): User name, default 'root'.

        Returns:
            tuple: ``(True, 'authorization success')`` on success, otherwise
            ``(False, 'error: <reason>')``. Never raises.
        """
        try:
            ssh = SSH(hostname=ip, password=password,
                      port=port, username=username, timeout=2)
            try:
                ssh.add_public_key()
            finally:
                # Release the connection whether or not the key was added.
                ssh.close()
            return True, 'authorization success'
        except Exception as e:
            return False, f'error: {e}'
|