awstools typing + small organization

This commit is contained in:
abejgonzalez 2022-04-10 17:57:31 -07:00 committed by Abraham Gonzalez
parent 4926b8ea45
commit 2c67f9bc46
2 changed files with 170 additions and 142 deletions

View File

@ -3,3 +3,5 @@ boto3==1.20.21
pytz
pyyaml
requests
mypy_boto3_ec2==1.21.9
mypy_boto3_s3==1.21.0

View File

@ -9,12 +9,19 @@ import os
from datetime import datetime, timedelta
import time
import sys
import json
import boto3
import botocore
from botocore import exceptions
from fabric.api import local, hide, settings # type: ignore
# imports needed for python type checking
from typing import Any, Dict, Optional, List, Sequence, cast
from mypy_boto3_ec2.service_resource import Instance as EC2InstanceResource
from mypy_boto3_ec2.type_defs import FilterTypeDef
from mypy_boto3_s3.literals import BucketLocationConstraintType
# setup basic config for logging
if __name__ == '__main__':
logging.basicConfig()
@ -36,7 +43,7 @@ def depaginated_boto_query(client, operation, operation_params, return_key):
return_values_all += page[return_key]
return return_values_all
def valid_aws_configure_creds():
def valid_aws_configure_creds() -> bool:
""" See if aws configure has been run. Returns False if aws configure
needs to be run, else True.
@ -55,7 +62,76 @@ def valid_aws_configure_creds():
return False
return True
def aws_resource_names():
def get_localhost_instance_info(url_ext: str) -> Optional[str]:
""" Obtain latest instance info from instance metadata service. See
https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html
for more info on what can be accessed.
Args:
url_ext: Part of URL after 169.254.169.254/latest
Returns:
Data obtained in string form or None
"""
res = None
# This takes multiple minutes without a timeout from the CI container. In
# practise it should resolve nearly instantly on an initialized EC2 instance.
curl_connection_timeout = 10
with settings(ok_ret_codes=[0,28]), hide('everything'):
res = local(f"curl -s --connect-timeout {curl_connection_timeout} http://169.254.169.254/{url_ext}", capture=True)
rootLogger.debug(res.stdout)
rootLogger.debug(res.stderr)
if res.return_code == 28:
return None
else:
return res.stdout
def get_localhost_instance_id() -> Optional[str]:
"""Get current manager instance id, if applicable.
Returns:
A ``str`` of the instance id or ``None``
"""
return get_localhost_instance_info("latest/meta-data/instance-id")
def get_localhost_tags() -> Dict[str, Any]:
"""Get current manager tags.
Returns:
A ``dict`` of tags (name -> value). Empty if no tags found or can't access the inst id.
"""
instanceid = get_localhost_instance_id()
rootLogger.debug(instanceid)
resptags: Dict[str, Any] = {}
if instanceid:
# Look up this instance's ID, if we do not have permission to describe tags, use the default dictionary
client = boto3.client('ec2')
try:
operation_params = {
'Filters': [
{
'Name': 'resource-id',
'Values': [
instanceid,
]
},
]
}
resp_pairs = depaginated_boto_query(client, 'describe_tags', operation_params, 'Tags')
except client.exceptions.ClientError:
return resptags
for pair in resp_pairs:
resptags[pair['Key']] = pair['Value']
rootLogger.debug(resptags)
return resptags
def aws_resource_names() -> Dict[str, Any]:
""" Get names for various aws resources the manager relies on. For example:
vpcname, securitygroupname, keyname, etc.
@ -85,62 +161,25 @@ def aws_resource_names():
'runfarmprefix': None,
}
resp = None
res = None
# This takes multiple minutes without a timeout from the CI container. In
# practise it should resolve nearly instantly on an initialized EC2 instance.
curl_connection_timeout = 10
with settings(warn_only=True), hide('everything'):
res = local("""curl -s --connect-timeout {} http://169.254.169.254/latest/meta-data/instance-id""".format(curl_connection_timeout), capture=True)
resptags = get_localhost_tags()
if resptags:
in_tutorial_mode = 'firesim-tutorial-username' in resptags.keys()
if not in_tutorial_mode:
return base_dict
# Use the default dictionary if we're not on an EC2 instance (e.g., when a
# manager is launched from CI; during demos)
if res.return_code != 0:
return base_dict
# Look up this instance's ID, if we do not have permission to describe tags, use the default dictionary
client = boto3.client('ec2')
try:
instanceid = res.stdout
rootLogger.debug(instanceid)
operation_params = {
'Filters': [
{
'Name': 'resource-id',
'Values': [
instanceid,
]
},
]
}
resp_pairs = depaginated_boto_query(client, 'describe_tags', operation_params, 'Tags')
except client.exceptions.ClientError:
return base_dict
resptags = {}
for pair in resp_pairs:
resptags[pair['Key']] = pair['Value']
rootLogger.debug(resptags)
in_tutorial_mode = 'firesim-tutorial-username' in resptags.keys()
if not in_tutorial_mode:
return base_dict
# at this point, assume we are in tutorial mode and get all tags we need
base_dict['tutorial_mode'] = True
base_dict['vpcname'] = resptags['firesim-tutorial-username']
base_dict['securitygroupname'] = resptags['firesim-tutorial-username']
base_dict['keyname'] = resptags['firesim-tutorial-username']
base_dict['s3bucketname'] = resptags['firesim-tutorial-username']
base_dict['snsname'] = resptags['firesim-tutorial-username']
base_dict['runfarmprefix'] = resptags['firesim-tutorial-username']
# at this point, assume we are in tutorial mode and get all tags we need
base_dict['tutorial_mode'] = True
base_dict['vpcname'] = resptags['firesim-tutorial-username']
base_dict['securitygroupname'] = resptags['firesim-tutorial-username']
base_dict['keyname'] = resptags['firesim-tutorial-username']
base_dict['s3bucketname'] = resptags['firesim-tutorial-username']
base_dict['snsname'] = resptags['firesim-tutorial-username']
base_dict['runfarmprefix'] = resptags['firesim-tutorial-username']
return base_dict
# AMIs are region specific
def get_f1_ami_id():
def get_f1_ami_id() -> str:
""" Get the AWS F1 Developer AMI by looking up the image name -- should be region independent.
"""
client = boto3.client('ec2')
@ -148,7 +187,7 @@ def get_f1_ami_id():
assert len(response['Images']) == 1
return response['Images'][0]['ImageId']
def get_aws_userid():
def get_aws_userid() -> str:
""" Get the user's IAM ID to intelligently create a bucket name when doing managerinit.
The previous method to do this was:
@ -158,10 +197,13 @@ def get_aws_userid():
But it seems that by default many accounts do not have permission to run this,
so instead we get it from instance metadata.
"""
res = local("""curl -s http://169.254.169.254/latest/dynamic/instance-identity/document | grep -oP '(?<="accountId" : ")[^"]*(?=")'""", capture=True)
return res.stdout.lower()
info = get_localhost_instance_info("dynamic/instance-identity/document")
if info is not None:
return json.loads(info)['accountId'].lower()
else:
assert False, "Unable to obtain accountId from instance metadata"
def construct_instance_market_options(instancemarket, spotinterruptionbehavior, spotmaxprice):
def construct_instance_market_options(instancemarket: str, spotinterruptionbehavior: str, spotmaxprice: str) -> Dict[str, Any]:
""" construct the dictionary necessary to configure instance market selection
(on-demand vs spot)
See:
@ -169,7 +211,7 @@ def construct_instance_market_options(instancemarket, spotinterruptionbehavior,
and
https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_InstanceMarketOptionsRequest.html
"""
instmarkoptions = dict()
instmarkoptions: Dict[str, Any] = dict()
if instancemarket == "spot":
instmarkoptions['MarketType'] = "spot"
instmarkoptions['SpotOptions'] = dict()
@ -187,9 +229,9 @@ def construct_instance_market_options(instancemarket, spotinterruptionbehavior,
else:
assert False, "INVALID INSTANCE MARKET TYPE."
def launch_instances(instancetype, count, instancemarket, spotinterruptionbehavior, spotmaxprice, blockdevices=None,
tags=None, randomsubnet=False, user_data_file=None, timeout=timedelta(), always_expand=True, ami_id=None):
""" Launch `count` instances of type `instancetype`
def launch_instances(instancetype: str, count: int, instancemarket: str, spotinterruptionbehavior: str, spotmaxprice: str, blockdevices: Optional[List[Dict[str, Any]]] = None,
tags: Optional[Dict[str, Any]] = None, randomsubnet: bool = False, user_data_file: Optional[str] = None, timeout: timedelta = timedelta(), always_expand: bool = True, ami_id: Optional[str] = None) -> List[EC2InstanceResource]:
"""Launch `count` instances of type `instancetype`
Using `instancemarket`, `spotinterruptionbehavior` and `spotmaxprice` to define instance market conditions
(see also: construct_market_conditions)
@ -197,46 +239,31 @@ def launch_instances(instancetype, count, instancemarket, spotinterruptionbehavi
This will launch instances in avail zone 0, then once capacity runs out, zone 1, then zone 2, etc.
The ordering of availablility zones can be randomized by passing`randomsubnet=True`
Parameters
----------
instancetype : str
String acceptable by `boto3.ec2.create_instances()` `InstanceType` parameter
count : int
The number of instances to launch
instancemarket
spotinterruptionbehavior
spotmaxprice
blockdevices
tags : dict of tag names to string values, default=None
Dict of tags
randomsubnet : bool, default=False
If true, subnets will be chosen randomly instead of starting from 0 and proceeding incrementally.
user_data_file : str, default=None
Path to readable file. Contents of file are passed as `UserData` to AWS
timeout : datetime.timedelta, default=timedelta() (immediate timeout after attempting all subnets)
`timedelta` object representing how long we should continue to try asking for instances
always_expand : bool, default=True
When true, create `count` instances, regardless of whether any already exist. When False, only
create instances until there are `count` total instances that match `tags` and `instancetype`
If `tags` are not passed, `always_expand` must be `True` or `ValueError` is thrown.
ami_id : Optional[str], default=None
Override AMI ID to use for launching instances. `None` results in the default AMI ID specified by
`awstools.get_f1_ami_id()`.
Args:
instancetype: String acceptable by `boto3.ec2.create_instances()` `InstanceType` parameter
count: The number of instances to launch
instancemarket
spotinterruptionbehavior
spotmaxprice
blockdevices
tags: dict of tag names to string values
randomsubnet: If true, subnets will be chosen randomly instead of starting from 0 and proceeding incrementally.
user_data_file: Path to readable file. Contents of file are passed as `UserData` to AWS
timeout: `timedelta` object representing how long we should continue to try asking for instances
always_expand: When true, create `count` instances, regardless of whether any already exist. When False, only
create instances until there are `count` total instances that match `tags` and `instancetype`
If `tags` are not passed, `always_expand` must be `True` or `ValueError` is thrown.
ami_id: Override AMI ID to use for launching instances. `None` results in the default AMI ID specified by
`awstools.get_f1_ami_id()`.
Return type
-----------
list(boto.ec2.Instance)
Returns
-------
list of instance resources. If `always_expand` is True, this list contains only the instances created in this
call. When `always_expand` is False, it contains all instances matching `tags` whether created in this call or not
Returns:
List of instance resources. If `always_expand` is True, this list contains only the instances created in this
call. When `always_expand` is False, it contains all instances matching `tags` whether created in this call or not
"""
if tags is None and not always_expand:
raise ValueError("always_expand=False requires tags to be given")
aws_resource_names_dict = aws_resource_names()
keyname = aws_resource_names_dict['keyname']
securitygroupname = aws_resource_names_dict['securitygroupname']
@ -245,7 +272,7 @@ def launch_instances(instancetype, count, instancemarket, spotinterruptionbehavi
ec2 = boto3.resource('ec2')
client = boto3.client('ec2')
vpcfilter = [{'Name':'tag:Name', 'Values': [vpcname]}]
vpcfilter: Sequence[FilterTypeDef] = [{'Name':'tag:Name', 'Values': [vpcname]}]
# docs show 'NextToken' / 'MaxResults' which suggests pagination, but
# the boto3 source says collections handle pagination automatically,
# so assume this is fine
@ -346,16 +373,17 @@ def launch_instances(instancetype, count, instancemarket, spotinterruptionbehavi
rootLogger.info("Continuing to request remaining {}, {} instances".format(count - len(instances), instancetype))
return instances
def run_block_device_dict():
def run_block_device_dict() -> List[Dict[str, Any]]:
return [ { 'DeviceName': '/dev/sda1', 'Ebs': { 'VolumeSize': 300, 'VolumeType': 'gp2' } } ]
def run_tag_dict():
def run_tag_dict() -> Dict[str, Any]:
return { 'fsimcluster': "defaultcluster" }
def run_filters_list_dict():
def run_filters_list_dict() -> List[Dict[str, Any]]:
return [ { 'Name': 'tag:fsimcluster', 'Values': [ "defaultcluster" ] } ]
def launch_run_instances(instancetype, count, fsimclustertag, instancemarket, spotinterruptionbehavior, spotmaxprice, timeout, always_expand):
def launch_run_instances(instancetype: str, count: int, fsimclustertag: str, instancemarket: str, spotinterruptionbehavior: str, spotmaxprice: str, timeout: timedelta, always_expand: bool) -> List[EC2InstanceResource]:
return launch_instances(instancetype, count, instancemarket, spotinterruptionbehavior, spotmaxprice, timeout=timeout, always_expand=always_expand,
blockdevices=[
{
@ -368,7 +396,7 @@ def launch_run_instances(instancetype, count, fsimclustertag, instancemarket, sp
],
tags={ 'fsimcluster': fsimclustertag })
def get_instances_with_filter(filters, allowed_states=['pending', 'running', 'shutting-down', 'stopping', 'stopped']):
def get_instances_with_filter(filters: List[Dict[str, Any]], allowed_states: List[str] = ['pending', 'running', 'shutting-down', 'stopping', 'stopped']) -> List[EC2InstanceResource]:
""" Produces a list of instances based on a set of provided filters """
ec2_client = boto3.client('ec2')
operation_params = {
@ -385,20 +413,19 @@ def get_instances_with_filter(filters, allowed_states=['pending', 'running', 'sh
instances.extend(res['Instances'])
return instances
def get_run_instances_by_tag_type(fsimclustertag, instancetype):
def get_run_instances_by_tag_type(fsimclustertag: str, instancetype: str) -> List[EC2InstanceResource]:
""" return list of instances that match fsimclustertag and instance type """
return get_instances_by_tag_type(
tags={'fsimcluster': fsimclustertag},
instancetype=instancetype
)
def get_instances_by_tag_type(tags, instancetype):
def get_instances_by_tag_type(tags: Dict[str, Any], instancetype: str) -> List[EC2InstanceResource]:
""" return list of instances that match all tags and instance type """
res = boto3.resource('ec2')
# see note above. collections automatically handle pagination
instances = res.instances.filter(
Filters = [
filters = [
{
'Name': 'instance-type',
'Values': [
@ -411,44 +438,42 @@ def get_instances_by_tag_type(tags, instancetype):
'running',
]
},
]
+
[
] + [
{
'Name': 'tag:{}'.format(k),
'Name': f'tag:{k}',
'Values': [
v,
]
} for k, v in tags.items()
]
)
return instances
]
instances = res.instances.filter(Filters = filters) # type: ignore
return list(instances)
def get_private_ips_for_instances(instances):
def get_private_ips_for_instances(instances: List[EC2InstanceResource]) -> List[str]:
"""" Take list of instances (as returned by create_instances), return private IPs. """
return [instance.private_ip_address for instance in instances]
def get_instance_ids_for_instances(instances):
def get_instance_ids_for_instances(instances: List[EC2InstanceResource]) -> List[str]:
"""" Take list of instances (as returned by create_instances), return instance ids. """
return [instance.id for instance in instances]
def instances_sorted_by_avail_ip(instances):
def instances_sorted_by_avail_ip(instances: List[EC2InstanceResource]) -> List[EC2InstanceResource]:
""" This returns a list of instance objects, first sorted by their private ip,
then sorted by availability zone. """
ips = get_private_ips_for_instances(instances)
ips_to_instances = zip(ips, instances)
insts = sorted(ips_to_instances, key=lambda x: x[0])
insts = [x[1] for x in insts]
return sorted(insts, key=lambda x: x.placement['AvailabilityZone'])
ip_sorted_insts = [x[1] for x in insts]
return sorted(ip_sorted_insts, key=lambda x: x.placement['AvailabilityZone'])
def instance_privateip_lookup_table(instances):
def instance_privateip_lookup_table(instances: List[EC2InstanceResource]) -> Dict[str, EC2InstanceResource]:
""" Given a list of instances, construct a lookup table that goes from
privateip -> instance obj """
ips = get_private_ips_for_instances(instances)
ips_to_instances = zip(ips, instances)
return { ip: instance for (ip, instance) in ips_to_instances }
def wait_on_instance_launches(instances, message=""):
def wait_on_instance_launches(instances: List[EC2InstanceResource], message: str = "") -> None:
""" Take a list of instances (as returned by create_instances), wait until
instance is running. """
rootLogger.info("Waiting for instance boots: " + str(len(instances)) + " " + message)
@ -456,13 +481,13 @@ def wait_on_instance_launches(instances, message=""):
instance.wait_until_running()
rootLogger.info(str(instance.id) + " booted!")
def terminate_instances(instanceids, dryrun=True):
def terminate_instances(instanceids: List[str], dryrun: bool = True) -> None:
""" Terminate instances when given a list of instance ids. for safety,
this supplies dryrun=True by default. """
client = boto3.client('ec2')
client.terminate_instances(InstanceIds=instanceids, DryRun=dryrun)
def auto_create_bucket(userbucketname):
def auto_create_bucket(userbucketname: str) -> None:
""" Check if the user-specified s3 bucket is available.
If we get a NoSuchBucket exception, create the bucket for the user.
If we get any other exception, exit.
@ -484,7 +509,8 @@ def auto_create_bucket(userbucketname):
# create the bucket for the user and setup directory structure
rootLogger.info("Creating s3 bucket for you named: " + userbucketname)
my_session = boto3.session.Session()
my_region = my_session.region_name
my_region: BucketLocationConstraintType
my_region = my_session.region_name # type: ignore
# yes, this is unfortunately the correct way of handling this.
# you cannot pass 'us-east-1' as a location constraint because
@ -500,12 +526,12 @@ def auto_create_bucket(userbucketname):
# now, setup directory structure
resp = s3cli.put_object(
Bucket = userbucketname,
Body = '',
Body = b'',
Key = 'dcp/'
)
resp2 = s3cli.put_object(
Bucket = userbucketname,
Body = '',
Body = b'',
Key = 'logs/'
)
@ -516,7 +542,7 @@ def auto_create_bucket(userbucketname):
rootLogger.critical(repr(exc))
assert False
def get_snsname_arn():
def get_snsname_arn() -> Optional[str]:
""" If the Topic doesn't exist create it, send catch exceptions while creating. Or if it exists get arn """
client = boto3.client('sns')
@ -538,7 +564,7 @@ def get_snsname_arn():
return response['TopicArn']
def subscribe_to_firesim_topic(email):
def subscribe_to_firesim_topic(email: str) -> None:
""" Subscribe a user to their FireSim SNS topic for notifications. """
client = boto3.client('sns')
@ -564,7 +590,7 @@ receive any notifications until you click the confirmation link.""".format(email
rootLogger.warning(err)
def send_firesim_notification(subject, body):
def send_firesim_notification(subject: str, body: str) -> None:
client = boto3.client('sns')
arn = get_snsname_arn()
@ -585,7 +611,7 @@ def send_firesim_notification(subject, body):
rootLogger.warning("Unknown exception is encountered while trying publish notifications")
rootLogger.warning(err)
def main(args):
def main(args: List[str]) -> int:
import argparse
import yaml
parser = argparse.ArgumentParser(description="Launch/terminate instances", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@ -601,27 +627,27 @@ def main(args):
parser.add_argument("--filters", type=yaml.safe_load, default=run_filters_list_dict(), help="List of dicts used to filter instances. Used by \'terminate\'.")
parser.add_argument("--user_data_file", default=None, help="File path to use as user data (run on initialization). Used by \'launch\'.")
parser.add_argument("--ami_id", default=get_f1_ami_id(), help="Override AMI ID used for launch. Defaults to \'awstools.get_f1_ami_id()\'. Used by \'launch\'.")
args = parser.parse_args(args)
parsed_args = parser.parse_args(args)
if args.command == "launch":
if parsed_args.command == "launch":
insts = launch_instances(
args.inst_type,
args.inst_amt,
args.market,
args.int_behavior,
args.spot_max_price,
args.block_devices,
args.tags,
args.random_subnet,
args.user_data_file,
args.ami_id)
parsed_args.inst_type,
parsed_args.inst_amt,
parsed_args.market,
parsed_args.int_behavior,
parsed_args.spot_max_price,
parsed_args.block_devices,
parsed_args.tags,
parsed_args.random_subnet,
parsed_args.user_data_file,
parsed_args.ami_id)
instids = get_instance_ids_for_instances(insts)
print("Instance IDs: {}".format(instids))
wait_on_instance_launches(insts)
print("Launched instance IPs: {}".format(get_private_ips_for_instances(insts)))
else: # "terminate"
insts = get_instances_with_filter(args.filters)
instids = [ inst['InstanceId'] for inst in insts ]
insts = get_instances_with_filter(parsed_args.filters)
instids = [ inst.instance_id for inst in insts ]
terminate_instances(instids, False)
print("Terminated instance IDs: {}".format(instids))
return 0