sysom1/sysom_server/sysom_vul/apps/vul/async_fetch.py

201 lines
6.0 KiB
Python

'''
@File: async_fetch.py
@Time: 2022-12-20 14:54:18
@Author: DM
@Desc: Asynchronously fetch vulnerability (VUL) data
'''
import json
import asyncio
import requests
from typing import List, Dict, Union
from clogger import logger
from django.db import connection
from django.conf import settings
from aiohttp import ClientSession
from asyncio import AbstractEventLoop
from aiohttp.client_exceptions import ContentTypeError
from .models import VulAddrModel
from lib.utils import FetchHttpBase
class FetchVulData:
    """Concurrently download paginated vulnerability data from one source.

    ``start`` schedules one request per page (``1..total_page``); at most
    ``CONCURRENCY`` requests are in flight at a time. Each page payload is
    reduced by following ``cve_data_path`` key by key and the resulting
    sequence is appended to an internal list exposed through ``results``.
    ``run`` is the synchronous entry point that wires everything together.
    """

    # Upper bound on the number of simultaneously running page requests.
    CONCURRENCY = 7

    def __init__(self,
                 loop: AbstractEventLoop = None,
                 instance: "VulAddrModel" = None,
                 session: "ClientSession" = None,
                 cve_data_path: List[str] = None,
                 total_page: int = 98
                 ) -> None:
        """Store the crawl configuration.

        Args:
            loop: Event loop used to create the page tasks.
            instance: Source description used to build every request.
            session: Optional aiohttp session; created lazily when omitted.
            cve_data_path: Keys to follow inside each page payload to reach
                the records. Falsy means the payload itself is the records.
            total_page: Number of pages to request.
        """
        self.session = session
        self.loop = loop
        self._instance = instance
        self.total_page = total_page
        self._cve_data_path = cve_data_path
        self._tasks: List[asyncio.Task] = []
        self.semaphore = asyncio.Semaphore(self.CONCURRENCY)
        self._results: List[Dict] = []

    @property
    def request_session(self) -> "ClientSession":
        """Return the aiohttp session, creating it on first access."""
        if self.session is None:
            self.session = ClientSession()
        return self.session

    async def fetch(self, kwargs) -> Union[Dict, str]:
        """Perform one page request and return its decoded JSON payload.

        Args:
            kwargs: Keyword arguments for ``ClientSession.request`` as built
                by ``make_request_params``.

        Returns:
            The decoded JSON body, or an error string when the response does
            not have a JSON content type.
        """
        # Function-scope import: the module header only imports ClientSession.
        from aiohttp import BasicAuth

        # make_request_params emits a plain (user, password) tuple because
        # ``requests`` wants one; aiohttp only accepts a BasicAuth instance.
        auth = kwargs.get('auth')
        if isinstance(auth, tuple) and not isinstance(auth, BasicAuth):
            kwargs = {**kwargs, 'auth': BasicAuth(*auth)}

        async with self.semaphore:
            async with self.request_session.request(**kwargs) as response:
                # Handle the content-type failure while the response is still
                # open instead of touching it after it has been released.
                try:
                    return await response.json()
                except ContentTypeError:
                    res = await response.text()
                    return f'被检测为Spider: {res}'

    def result_callback(self, future: asyncio.Task) -> None:
        """Collect the records of a finished page task.

        Failed or cancelled tasks and malformed payloads are logged/skipped
        instead of raising inside the event-loop callback.
        """
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            logger.error(f'Page request failed: {error!r}')
            return

        res = future.result()
        if isinstance(res, str):
            # ``fetch`` returns a string only to report a non-JSON response.
            logger.error(res)
            return

        for key in self._cve_data_path or ():
            if not isinstance(res, dict):
                logger.error(
                    f'Unexpected payload while resolving {self._cve_data_path}: {res!r}')
                return
            res = res.get(key)

        if res is None:
            logger.error(f'No data found at path {self._cve_data_path}')
            return
        self._results.extend(res)

    async def _session_close(self):
        """Close the aiohttp session if one was ever created."""
        if self.session is not None:
            await self.session.close()

    async def _cancel_pending(self) -> None:
        """Cancel page tasks that are still running and wait for them."""
        pending = [task for task in self._tasks if not task.done()]
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)

    async def start(self) -> None:
        """Schedule every page request and wait for all of them.

        The first failing page request propagates to the caller; remaining
        requests are cancelled and the session is closed in every case.
        """
        loop = self.loop or asyncio.get_event_loop()
        try:
            for page_num in range(1, self.total_page + 1):
                kwargs = self.make_request_params(
                    instance=self._instance, page_num=page_num)
                task: asyncio.Task = loop.create_task(self.fetch(kwargs))
                task.add_done_callback(self.result_callback)
                self._tasks.append(task)
            await asyncio.gather(*self._tasks)
        finally:
            # On the success path nothing is pending, so this is a no-op.
            try:
                await self._cancel_pending()
            finally:
                await self._session_close()

    @property
    def results(self):
        """Yield the collected records (a fresh generator on each access)."""
        yield from self._results

    @staticmethod
    def make_request_params(instance: "VulAddrModel", page_num: int = 1) -> Dict:
        """Build the request keyword arguments for one page.

        The returned dict has the shape::

            {
                'method': instance.get_method_display(),
                'url': instance.url,
                'headers': json.loads(instance.headers),
                'data': json.loads(instance.body),
                'params': json.loads(instance.params) + {'page_num': page_num},
                'auth': (username, password),   # only for "basic" auth
            }

        Args:
            instance: Source description; ``headers``, ``body``, ``params``
                and ``authorization_body`` are parsed as JSON text.
            page_num: Page number placed into the query parameters.

        Returns:
            The request keyword-argument dict.
        """
        kwargs = dict()
        kwargs['url'] = instance.url
        kwargs['method'] = instance.get_method_display()
        kwargs['headers'] = json.loads(instance.headers)
        kwargs['data'] = json.loads(instance.body)
        params = json.loads(instance.params)
        params['page_num'] = page_num
        kwargs['params'] = params
        # Tolerate an unset authorization type instead of raising on None.
        if (instance.authorization_type or '').lower() == "basic":
            authorization_body = json.loads(instance.authorization_body)
            kwargs['auth'] = (authorization_body.get('username'),
                              authorization_body.get('password'))
        return kwargs

    @classmethod
    def _get_page_total_num(cls, kwargs) -> Union[bool, int]:
        """Ask the vulnerability source how many pages it has.

        Args:
            kwargs: Request keyword arguments from ``make_request_params``.

        Returns:
            ``result['data']['total_page']`` from the JSON response, or
            ``False`` when the status is not 200 or the body lacks that field.
        """
        # NOTE(review): no timeout is passed, so this call can block
        # indefinitely on an unresponsive server.
        response = requests.request(**kwargs)
        if response.status_code != 200:
            return False
        try:
            return response.json()['data']['total_page']
        except (KeyError, TypeError, ValueError):
            # Malformed or non-JSON body: report failure as documented
            # rather than leaking a parsing error to the caller.
            return False

    @classmethod
    def _update_vul_data_status(cls, instance: "VulAddrModel", status: int):
        """Persist ``status`` on ``instance`` and close the DB connection.

        NOTE(review): the meaning of the status codes is not defined in this
        file; ``run`` writes 1 on page-count failure and 0 before crawling.
        """
        try:
            instance.status = status
            instance.save()
        finally:
            connection.close()

    @classmethod
    def run(cls, instance: "VulAddrModel", cve_data_path: List[str], loop=None):
        """Crawl every page of ``instance`` synchronously.

        Args:
            instance: Source description (required).
            cve_data_path: Keys locating the records inside each page payload.
            loop: Event loop to use; a new one is created when omitted. The
                loop is set as current and closed before returning, even
                when it was supplied by the caller.

        Returns:
            A generator over the collected records.

        Raises:
            Exception: When the total page count cannot be obtained.
        """
        kwargs = cls.make_request_params(instance=instance)
        page_total_num = cls._get_page_total_num(kwargs)
        if not page_total_num:
            # NOTE(review): kwargs may contain credentials ('auth', headers);
            # consider redacting them before logging.
            message = f'总页码数获取失败, 参数: {kwargs}'
            logger.error(message)
            cls._update_vul_data_status(instance, 1)
            raise Exception(message)
        cls._update_vul_data_status(instance, 0)

        _loop = loop or asyncio.new_event_loop()
        asyncio.set_event_loop(_loop)
        spider = cls(
            loop=_loop,
            instance=instance,
            cve_data_path=cve_data_path,
            total_page=page_total_num
        )
        try:
            spider.loop.run_until_complete(spider.start())
        finally:
            # Release the loop's resources even when the crawl fails.
            spider.loop.close()
        return spider.results
class FetchHost(FetchHttpBase):
    """Client for the host service, configured from Django settings.

    NOTE(review): how ``FetchHttpBase`` consumes these class attributes is
    not visible in this file — confirm against ``lib.utils``.
    """

    SERVICE_NAME = settings.HOST_SERVICE_NAME
    SERVICE_URI = settings.SERVICE_URL
    SERVICE_PORT = settings.HOST_SERVICE_PORT
    # NOTE(review): the trailing underscore in the setting name looks
    # unusual; verify it matches the key defined in the project settings.
    API_VERSION = settings.HOST_API_VERSION_

    @classmethod
    def get_host_list(cls) -> list:
        """Return the result of a ``'get'`` fetch with an empty path.

        The client is obtained from the inherited ``initializtion`` factory
        (spelling as defined by the base class).
        """
        client = cls.initializtion()
        return client.fetch('get', '')