From 2c67f9bc46071a90a350829727a1183dbabfd892 Mon Sep 17 00:00:00 2001 From: abejgonzalez Date: Sun, 10 Apr 2022 17:57:31 -0700 Subject: [PATCH] awstools typing + small organization --- .github/scripts/requirements.txt | 2 + deploy/awstools/awstools.py | 310 +++++++++++++++++-------------- 2 files changed, 170 insertions(+), 142 deletions(-) diff --git a/.github/scripts/requirements.txt b/.github/scripts/requirements.txt index a554bc4f..188f0523 100644 --- a/.github/scripts/requirements.txt +++ b/.github/scripts/requirements.txt @@ -3,3 +3,5 @@ boto3==1.20.21 pytz pyyaml requests +mypy_boto3_ec2==1.21.9 +mypy_boto3_s3==1.21.0 diff --git a/deploy/awstools/awstools.py b/deploy/awstools/awstools.py index a2470957..d30d61ad 100755 --- a/deploy/awstools/awstools.py +++ b/deploy/awstools/awstools.py @@ -9,12 +9,19 @@ import os from datetime import datetime, timedelta import time import sys +import json import boto3 import botocore from botocore import exceptions from fabric.api import local, hide, settings # type: ignore +# imports needed for python type checking +from typing import Any, Dict, Optional, List, Sequence, cast +from mypy_boto3_ec2.service_resource import Instance as EC2InstanceResource +from mypy_boto3_ec2.type_defs import FilterTypeDef +from mypy_boto3_s3.literals import BucketLocationConstraintType + # setup basic config for logging if __name__ == '__main__': logging.basicConfig() @@ -36,7 +43,7 @@ def depaginated_boto_query(client, operation, operation_params, return_key): return_values_all += page[return_key] return return_values_all -def valid_aws_configure_creds(): +def valid_aws_configure_creds() -> bool: """ See if aws configure has been run. Returns False if aws configure needs to be run, else True. @@ -55,7 +62,76 @@ def valid_aws_configure_creds(): return False return True -def aws_resource_names(): +def get_localhost_instance_info(url_ext: str) -> Optional[str]: + """ Obtain latest instance info from instance metadata service. See + https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html + for more info on what can be accessed. + + Args: + url_ext: Part of URL after 169.254.169.254/latest + + Returns: + Data obtained in string form or None + """ + res = None + # This takes multiple minutes without a timeout from the CI container. In + # practise it should resolve nearly instantly on an initialized EC2 instance. + curl_connection_timeout = 10 + with settings(ok_ret_codes=[0,28]), hide('everything'): + res = local(f"curl -s --connect-timeout {curl_connection_timeout} http://169.254.169.254/{url_ext}", capture=True) + rootLogger.debug(res.stdout) + rootLogger.debug(res.stderr) + + if res.return_code == 28: + return None + else: + return res.stdout + +def get_localhost_instance_id() -> Optional[str]: + """Get current manager instance id, if applicable. + + Returns: + A ``str`` of the instance id or ``None`` + """ + + return get_localhost_instance_info("latest/meta-data/instance-id") + +def get_localhost_tags() -> Dict[str, Any]: + """Get current manager tags. + + Returns: + A ``dict`` of tags (name -> value). Empty if no tags found or can't access the inst id. + """ + instanceid = get_localhost_instance_id() + rootLogger.debug(instanceid) + + resptags: Dict[str, Any] = {} + + if instanceid: + # Look up this instance's ID, if we do not have permission to describe tags, use the default dictionary + client = boto3.client('ec2') + try: + operation_params = { + 'Filters': [ + { + 'Name': 'resource-id', + 'Values': [ + instanceid, + ] + }, + ] + } + resp_pairs = depaginated_boto_query(client, 'describe_tags', operation_params, 'Tags') + except client.exceptions.ClientError: + return resptags + + for pair in resp_pairs: + resptags[pair['Key']] = pair['Value'] + rootLogger.debug(resptags) + + return resptags + +def aws_resource_names() -> Dict[str, Any]: """ Get names for various aws resources the manager relies on. For example: vpcname, securitygroupname, keyname, etc. @@ -85,62 +161,25 @@ def aws_resource_names(): 'runfarmprefix': None, } - resp = None - res = None - # This takes multiple minutes without a timeout from the CI container. In - # practise it should resolve nearly instantly on an initialized EC2 instance. - curl_connection_timeout = 10 - with settings(warn_only=True), hide('everything'): - res = local("""curl -s --connect-timeout {} http://169.254.169.254/latest/meta-data/instance-id""".format(curl_connection_timeout), capture=True) + resptags = get_localhost_tags() + if resptags: + in_tutorial_mode = 'firesim-tutorial-username' in resptags.keys() + if not in_tutorial_mode: + return base_dict - # Use the default dictionary if we're not on an EC2 instance (e.g., when a - # manager is launched from CI; during demos) - if res.return_code != 0: - return base_dict - - - # Look up this instance's ID, if we do not have permission to describe tags, use the default dictionary - client = boto3.client('ec2') - try: - instanceid = res.stdout - rootLogger.debug(instanceid) - - operation_params = { - 'Filters': [ - { - 'Name': 'resource-id', - 'Values': [ - instanceid, - ] - }, - ] - } - resp_pairs = depaginated_boto_query(client, 'describe_tags', operation_params, 'Tags') - except client.exceptions.ClientError: - return base_dict - - resptags = {} - for pair in resp_pairs: - resptags[pair['Key']] = pair['Value'] - rootLogger.debug(resptags) - - in_tutorial_mode = 'firesim-tutorial-username' in resptags.keys() - if not in_tutorial_mode: - return base_dict - - # at this point, assume we are in tutorial mode and get all tags we need - base_dict['tutorial_mode'] = True - base_dict['vpcname'] = resptags['firesim-tutorial-username'] - base_dict['securitygroupname'] = resptags['firesim-tutorial-username'] - base_dict['keyname'] = resptags['firesim-tutorial-username'] - base_dict['s3bucketname'] = resptags['firesim-tutorial-username'] - base_dict['snsname'] = resptags['firesim-tutorial-username'] - base_dict['runfarmprefix'] = resptags['firesim-tutorial-username'] + # at this point, assume we are in tutorial mode and get all tags we need + base_dict['tutorial_mode'] = True + base_dict['vpcname'] = resptags['firesim-tutorial-username'] + base_dict['securitygroupname'] = resptags['firesim-tutorial-username'] + base_dict['keyname'] = resptags['firesim-tutorial-username'] + base_dict['s3bucketname'] = resptags['firesim-tutorial-username'] + base_dict['snsname'] = resptags['firesim-tutorial-username'] + base_dict['runfarmprefix'] = resptags['firesim-tutorial-username'] return base_dict # AMIs are region specific -def get_f1_ami_id(): +def get_f1_ami_id() -> str: """ Get the AWS F1 Developer AMI by looking up the image name -- should be region independent. """ client = boto3.client('ec2') @@ -148,7 +187,7 @@ def get_f1_ami_id(): assert len(response['Images']) == 1 return response['Images'][0]['ImageId'] -def get_aws_userid(): +def get_aws_userid() -> str: """ Get the user's IAM ID to intelligently create a bucket name when doing managerinit. The previous method to do this was: @@ -158,10 +197,13 @@ def get_aws_userid(): But it seems that by default many accounts do not have permission to run this, so instead we get it from instance metadata. """ - res = local("""curl -s http://169.254.169.254/latest/dynamic/instance-identity/document | grep -oP '(?<="accountId" : ")[^"]*(?=")'""", capture=True) - return res.stdout.lower() + info = get_localhost_instance_info("dynamic/instance-identity/document") + if info is not None: + return json.loads(info)['accountId'].lower() + else: + assert False, "Unable to obtain accountId from instance metadata" -def construct_instance_market_options(instancemarket, spotinterruptionbehavior, spotmaxprice): +def construct_instance_market_options(instancemarket: str, spotinterruptionbehavior: str, spotmaxprice: str) -> Dict[str, Any]: """ construct the dictionary necessary to configure instance market selection (on-demand vs spot) See: @@ -169,7 +211,7 @@ def construct_instance_market_options(instancemarket, spotinterruptionbehavior, and https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_InstanceMarketOptionsRequest.html """ - instmarkoptions = dict() + instmarkoptions: Dict[str, Any] = dict() if instancemarket == "spot": instmarkoptions['MarketType'] = "spot" instmarkoptions['SpotOptions'] = dict() @@ -187,9 +229,9 @@ def construct_instance_market_options(instancemarket, spotinterruptionbehavior, else: assert False, "INVALID INSTANCE MARKET TYPE." -def launch_instances(instancetype, count, instancemarket, spotinterruptionbehavior, spotmaxprice, blockdevices=None, - tags=None, randomsubnet=False, user_data_file=None, timeout=timedelta(), always_expand=True, ami_id=None): - """ Launch `count` instances of type `instancetype` +def launch_instances(instancetype: str, count: int, instancemarket: str, spotinterruptionbehavior: str, spotmaxprice: str, blockdevices: Optional[List[Dict[str, Any]]] = None, + tags: Optional[Dict[str, Any]] = None, randomsubnet: bool = False, user_data_file: Optional[str] = None, timeout: timedelta = timedelta(), always_expand: bool = True, ami_id: Optional[str] = None) -> List[EC2InstanceResource]: + """Launch `count` instances of type `instancetype` Using `instancemarket`, `spotinterruptionbehavior` and `spotmaxprice` to define instance market conditions (see also: construct_market_conditions) @@ -197,46 +239,31 @@ def launch_instances(instancetype, count, instancemarket, spotinterruptionbehavi This will launch instances in avail zone 0, then once capacity runs out, zone 1, then zone 2, etc. The ordering of availablility zones can be randomized by passing`randomsubnet=True` - Parameters - ---------- - instancetype : str - String acceptable by `boto3.ec2.create_instances()` `InstanceType` parameter - count : int - The number of instances to launch - instancemarket - spotinterruptionbehavior - spotmaxprice - blockdevices - tags : dict of tag names to string values, default=None - Dict of tags - randomsubnet : bool, default=False - If true, subnets will be chosen randomly instead of starting from 0 and proceeding incrementally. - user_data_file : str, default=None - Path to readable file. Contents of file are passed as `UserData` to AWS - timeout : datetime.timedelta, default=timedelta() (immediate timeout after attempting all subnets) - `timedelta` object representing how long we should continue to try asking for instances - always_expand : bool, default=True - When true, create `count` instances, regardless of whether any already exist. When False, only - create instances until there are `count` total instances that match `tags` and `instancetype` - If `tags` are not passed, `always_expand` must be `True` or `ValueError` is thrown. - ami_id : Optional[str], default=None - Override AMI ID to use for launching instances. `None` results in the default AMI ID specified by - `awstools.get_f1_ami_id()`. + Args: + instancetype: String acceptable by `boto3.ec2.create_instances()` `InstanceType` parameter + count: The number of instances to launch + instancemarket + spotinterruptionbehavior + spotmaxprice + blockdevices + tags: dict of tag names to string values + randomsubnet: If true, subnets will be chosen randomly instead of starting from 0 and proceeding incrementally. + user_data_file: Path to readable file. Contents of file are passed as `UserData` to AWS + timeout: `timedelta` object representing how long we should continue to try asking for instances + always_expand: When true, create `count` instances, regardless of whether any already exist. When False, only + create instances until there are `count` total instances that match `tags` and `instancetype` + If `tags` are not passed, `always_expand` must be `True` or `ValueError` is thrown. + ami_id: Override AMI ID to use for launching instances. `None` results in the default AMI ID specified by + `awstools.get_f1_ami_id()`. - Return type - ----------- - list(boto.ec2.Instance) - - Returns - ------- - list of instance resources. If `always_expand` is True, this list contains only the instances created in this - call. When `always_expand` is False, it contains all instances matching `tags` whether created in this call or not + Returns: + List of instance resources. If `always_expand` is True, this list contains only the instances created in this + call. When `always_expand` is False, it contains all instances matching `tags` whether created in this call or not """ if tags is None and not always_expand: raise ValueError("always_expand=False requires tags to be given") - aws_resource_names_dict = aws_resource_names() keyname = aws_resource_names_dict['keyname'] securitygroupname = aws_resource_names_dict['securitygroupname'] @@ -245,7 +272,7 @@ def launch_instances(instancetype, count, instancemarket, spotinterruptionbehavi ec2 = boto3.resource('ec2') client = boto3.client('ec2') - vpcfilter = [{'Name':'tag:Name', 'Values': [vpcname]}] + vpcfilter: Sequence[FilterTypeDef] = [{'Name':'tag:Name', 'Values': [vpcname]}] # docs show 'NextToken' / 'MaxResults' which suggests pagination, but # the boto3 source says collections handle pagination automatically, # so assume this is fine @@ -346,16 +373,17 @@ def launch_instances(instancetype, count, instancemarket, spotinterruptionbehavi rootLogger.info("Continuing to request remaining {}, {} instances".format(count - len(instances), instancetype)) return instances -def run_block_device_dict(): +def run_block_device_dict() -> List[Dict[str, Any]]: return [ { 'DeviceName': '/dev/sda1', 'Ebs': { 'VolumeSize': 300, 'VolumeType': 'gp2' } } ] -def run_tag_dict(): +def run_tag_dict() -> Dict[str, Any]: return { 'fsimcluster': "defaultcluster" } -def run_filters_list_dict(): +def run_filters_list_dict() -> List[Dict[str, Any]]: return [ { 'Name': 'tag:fsimcluster', 'Values': [ "defaultcluster" ] } ] -def launch_run_instances(instancetype, count, fsimclustertag, instancemarket, spotinterruptionbehavior, spotmaxprice, timeout, always_expand): + +def launch_run_instances(instancetype: str, count: int, fsimclustertag: str, instancemarket: str, spotinterruptionbehavior: str, spotmaxprice: str, timeout: timedelta, always_expand: bool) -> List[EC2InstanceResource]: return launch_instances(instancetype, count, instancemarket, spotinterruptionbehavior, spotmaxprice, timeout=timeout, always_expand=always_expand, blockdevices=[ { @@ -368,7 +396,7 @@ def launch_run_instances(instancetype, count, fsimclustertag, instancemarket, sp ], tags={ 'fsimcluster': fsimclustertag }) -def get_instances_with_filter(filters, allowed_states=['pending', 'running', 'shutting-down', 'stopping', 'stopped']): +def get_instances_with_filter(filters: List[Dict[str, Any]], allowed_states: List[str] = ['pending', 'running', 'shutting-down', 'stopping', 'stopped']) -> List[EC2InstanceResource]: """ Produces a list of instances based on a set of provided filters """ ec2_client = boto3.client('ec2') operation_params = { @@ -385,20 +413,19 @@ def get_instances_with_filter(filters, allowed_states=['pending', 'running', 'sh instances.extend(res['Instances']) return instances -def get_run_instances_by_tag_type(fsimclustertag, instancetype): +def get_run_instances_by_tag_type(fsimclustertag: str, instancetype: str) -> List[EC2InstanceResource]: """ return list of instances that match fsimclustertag and instance type """ return get_instances_by_tag_type( tags={'fsimcluster': fsimclustertag}, instancetype=instancetype ) -def get_instances_by_tag_type(tags, instancetype): +def get_instances_by_tag_type(tags: Dict[str, Any], instancetype: str) -> List[EC2InstanceResource]: """ return list of instances that match all tags and instance type """ res = boto3.resource('ec2') # see note above. collections automatically handle pagination - instances = res.instances.filter( - Filters = [ + filters = [ { 'Name': 'instance-type', 'Values': [ @@ -411,44 +438,42 @@ def get_instances_by_tag_type(tags, instancetype): 'running', ] }, - ] - + - [ + ] + [ { - 'Name': 'tag:{}'.format(k), + 'Name': f'tag:{k}', 'Values': [ v, ] } for k, v in tags.items() - ] - ) - return instances + ] + instances = res.instances.filter(Filters = filters) # type: ignore + return list(instances) -def get_private_ips_for_instances(instances): +def get_private_ips_for_instances(instances: List[EC2InstanceResource]) -> List[str]: """" Take list of instances (as returned by create_instances), return private IPs. """ return [instance.private_ip_address for instance in instances] -def get_instance_ids_for_instances(instances): +def get_instance_ids_for_instances(instances: List[EC2InstanceResource]) -> List[str]: """" Take list of instances (as returned by create_instances), return instance ids. """ return [instance.id for instance in instances] -def instances_sorted_by_avail_ip(instances): +def instances_sorted_by_avail_ip(instances: List[EC2InstanceResource]) -> List[EC2InstanceResource]: """ This returns a list of instance objects, first sorted by their private ip, then sorted by availability zone. """ ips = get_private_ips_for_instances(instances) ips_to_instances = zip(ips, instances) insts = sorted(ips_to_instances, key=lambda x: x[0]) - insts = [x[1] for x in insts] - return sorted(insts, key=lambda x: x.placement['AvailabilityZone']) + ip_sorted_insts = [x[1] for x in insts] + return sorted(ip_sorted_insts, key=lambda x: x.placement['AvailabilityZone']) -def instance_privateip_lookup_table(instances): +def instance_privateip_lookup_table(instances: List[EC2InstanceResource]) -> Dict[str, EC2InstanceResource]: """ Given a list of instances, construct a lookup table that goes from privateip -> instance obj """ ips = get_private_ips_for_instances(instances) ips_to_instances = zip(ips, instances) return { ip: instance for (ip, instance) in ips_to_instances } -def wait_on_instance_launches(instances, message=""): +def wait_on_instance_launches(instances: List[EC2InstanceResource], message: str = "") -> None: """ Take a list of instances (as returned by create_instances), wait until instance is running. """ rootLogger.info("Waiting for instance boots: " + str(len(instances)) + " " + message) @@ -456,13 +481,13 @@ def wait_on_instance_launches(instances, message=""): instance.wait_until_running() rootLogger.info(str(instance.id) + " booted!") -def terminate_instances(instanceids, dryrun=True): +def terminate_instances(instanceids: List[str], dryrun: bool = True) -> None: """ Terminate instances when given a list of instance ids. for safety, this supplies dryrun=True by default. """ client = boto3.client('ec2') client.terminate_instances(InstanceIds=instanceids, DryRun=dryrun) -def auto_create_bucket(userbucketname): +def auto_create_bucket(userbucketname: str) -> None: """ Check if the user-specified s3 bucket is available. If we get a NoSuchBucket exception, create the bucket for the user. If we get any other exception, exit. @@ -484,7 +509,8 @@ def auto_create_bucket(userbucketname): # create the bucket for the user and setup directory structure rootLogger.info("Creating s3 bucket for you named: " + userbucketname) my_session = boto3.session.Session() - my_region = my_session.region_name + my_region: BucketLocationConstraintType + my_region = my_session.region_name # type: ignore # yes, this is unfortunately the correct way of handling this. # you cannot pass 'us-east-1' as a location constraint because @@ -500,12 +526,12 @@ def auto_create_bucket(userbucketname): # now, setup directory structure resp = s3cli.put_object( Bucket = userbucketname, - Body = '', + Body = b'', Key = 'dcp/' ) resp2 = s3cli.put_object( Bucket = userbucketname, - Body = '', + Body = b'', Key = 'logs/' ) @@ -516,7 +542,7 @@ def auto_create_bucket(userbucketname): rootLogger.critical(repr(exc)) assert False -def get_snsname_arn(): +def get_snsname_arn() -> Optional[str]: """ If the Topic doesn't exist create it, send catch exceptions while creating. Or if it exists get arn """ client = boto3.client('sns') @@ -538,7 +564,7 @@ def get_snsname_arn(): return response['TopicArn'] -def subscribe_to_firesim_topic(email): +def subscribe_to_firesim_topic(email: str) -> None: """ Subscribe a user to their FireSim SNS topic for notifications. """ client = boto3.client('sns') @@ -564,7 +590,7 @@ receive any notifications until you click the confirmation link.""".format(email rootLogger.warning(err) -def send_firesim_notification(subject, body): +def send_firesim_notification(subject: str, body: str) -> None: client = boto3.client('sns') arn = get_snsname_arn() @@ -585,7 +611,7 @@ def send_firesim_notification(subject, body): rootLogger.warning("Unknown exception is encountered while trying publish notifications") rootLogger.warning(err) -def main(args): +def main(args: List[str]) -> int: import argparse import yaml parser = argparse.ArgumentParser(description="Launch/terminate instances", formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -601,27 +627,27 @@ def main(args): parser.add_argument("--filters", type=yaml.safe_load, default=run_filters_list_dict(), help="List of dicts used to filter instances. Used by \'terminate\'.") parser.add_argument("--user_data_file", default=None, help="File path to use as user data (run on initialization). Used by \'launch\'.") parser.add_argument("--ami_id", default=get_f1_ami_id(), help="Override AMI ID used for launch. Defaults to \'awstools.get_f1_ami_id()\'. Used by \'launch\'.") - args = parser.parse_args(args) + parsed_args = parser.parse_args(args) - if args.command == "launch": + if parsed_args.command == "launch": insts = launch_instances( - args.inst_type, - args.inst_amt, - args.market, - args.int_behavior, - args.spot_max_price, - args.block_devices, - args.tags, - args.random_subnet, - args.user_data_file, - args.ami_id) + parsed_args.inst_type, + parsed_args.inst_amt, + parsed_args.market, + parsed_args.int_behavior, + parsed_args.spot_max_price, + parsed_args.block_devices, + parsed_args.tags, + parsed_args.random_subnet, + parsed_args.user_data_file, + parsed_args.ami_id) instids = get_instance_ids_for_instances(insts) print("Instance IDs: {}".format(instids)) wait_on_instance_launches(insts) print("Launched instance IPs: {}".format(get_private_ips_for_instances(insts))) else: # "terminate" - insts = get_instances_with_filter(args.filters) - instids = [ inst['InstanceId'] for inst in insts ] + insts = get_instances_with_filter(parsed_args.filters) + instids = [ inst.instance_id for inst in insts ] terminate_instances(instids, False) print("Terminated instance IDs: {}".format(instids)) return 0