Source code for motorway.contrib.amazon_kinesis.ramps

import random
from Queue import Queue
import json
from threading import Thread, Lock, Semaphore
import time
import uuid
import datetime
import logging
import boto3
from motorway.messages import Message
from motorway.ramp import Ramp
from boto3.dynamodb.conditions import Attr
# Dedicated logger for shard claiming / balancing decisions, so that election
# chatter can be filtered separately from the rest of this module's logging.
shard_election_logger = logging.getLogger("motorway.contrib.amazon_kinesis.shard_election")

# General-purpose logger for this module.
logger = logging.getLogger(__name__)


class NoItemsReturned(Exception):
    """Signal that the control table holds no record for the requested shard."""

class KinesisRamp(Ramp):
    """
    Ramp that reads records from an Amazon Kinesis stream.

    One thread is started per shard of ``stream_name``. Shard ownership,
    heartbeats and checkpoints are coordinated between workers through a
    DynamoDB "control table" holding one record per shard with the keys
    ``shard_id``, ``checkpoint``, ``worker_id`` and ``heartbeat``.

    Subclasses must set ``stream_name``.
    """

    stream_name = None
    # Seconds to wait for a heartbeat change before a shard is considered stale
    heartbeat_timeout = 30
    # Stop fetching from a shard while this many of its records are unacknowledged
    MAX_UNCOMPLETED_ITEMS = 3000
    # ``Limit`` passed to every regular Kinesis ``get_records`` call
    GET_RECORDS_LIMIT = 1000

    def __init__(self, shard_threads_enabled=True, **kwargs):
        """
        Connect to Kinesis/DynamoDB and, optionally, start one thread per shard.

        :param shard_threads_enabled: when true, make sure the control table
            exists and start a ``process_shard`` thread for every shard.
        :param kwargs: passed on to ``Ramp.__init__``.
        """
        super(KinesisRamp, self).__init__(**kwargs)
        self.conn = boto3.client(**self.connection_parameters('kinesis'))
        assert self.stream_name, "Please define attribute stream_name on your KinesisRamp"

        control_table_name = self.get_control_table_name()

        self.worker_id = str(uuid.uuid4())
        self.semaphore = Semaphore()
        self.uncompleted_ids = {}
        # Always create the queue: next() reads from it even when no shard
        # threads were started.
        self.insertion_queue = Queue()
        self.dynamodb_client = boto3.client(**self.connection_parameters('dynamodb'))

        if shard_threads_enabled:
            self.dynamodb = boto3.resource(**self.connection_parameters('dynamodb'))

            try:
                self.dynamodb_client.describe_table(TableName=control_table_name)
            except self.dynamodb_client.exceptions.ResourceNotFoundException:
                self.dynamodb_client.create_table(
                    TableName=control_table_name,
                    KeySchema=[
                        {
                            'AttributeName': 'shard_id',
                            'KeyType': 'HASH'
                        },
                    ],
                    ProvisionedThroughput={
                        'ReadCapacityUnits': 10,
                        'WriteCapacityUnits': 10
                    },
                    AttributeDefinitions=[
                        {
                            'AttributeName': 'shard_id',
                            'AttributeType': 'S'
                        },
                    ],
                )
            # Wait unconditionally: describe_table also succeeds while another
            # worker's create_table is still in progress.
            self.dynamodb_client.get_waiter('table_exists').wait(TableName=control_table_name)

            self.control_table = self.dynamodb.Table(control_table_name)

            shards = self.conn.describe_stream(StreamName=self.stream_name)['StreamDescription']['Shards']
            random.shuffle(shards)  # Start the threads in random order, in case of bulk restart
            for i, shard in enumerate(shards):
                self.uncompleted_ids[shard['ShardId']] = set()
                t = Thread(
                    target=self.process_shard,
                    name="%s-%s" % (self.__class__.__name__, i),
                    args=(shard['ShardId'], )
                )
                t.start()

    def get_control_table_name(self):
        """Return the name of the DynamoDB control table for this stream."""
        return 'pipeline-control-%s' % self.stream_name

    def claim_shard(self, shard_id):
        """
        Atomically update the shard in DynamoDB so this worker owns it.

        The conditional write only compares ``shard_id`` and ``checkpoint``
        with the values that were just read.

        :param shard_id: id of the shard to claim
        :return: bool - True when the claim was written, False when the
            conditional write was rejected
        :raises NoItemsReturned: when the control table has no record for the shard
        """
        shard_election_logger.info("Claiming shard %s", shard_id)
        try:
            control_record = self.control_table.get_item(Key={'shard_id': shard_id})['Item']
        except KeyError:
            raise NoItemsReturned()
        control_record['worker_id'] = self.worker_id
        control_record['heartbeat'] = 0
        try:
            self.control_table.put_item(
                Item=control_record,
                # ensure that the record was not changed between the get and put.
                ConditionExpression=(
                    Attr('shard_id').eq(shard_id) &
                    Attr('checkpoint').eq(control_record['checkpoint'])
                )
            )
        except self.dynamodb_client.exceptions.ConditionalCheckFailedException:
            # Someone else edited the record
            shard_election_logger.debug("Failed to claim %s to %s", shard_id, self.worker_id)
            return False
        return True

    def can_claim_shard(self, shard_id):
        """
        Determine whether or not a given shard can be claimed because of

        1) It's currently not being processed by another process
        2) It's unevenly balanced between the consuming nodes/workers

        Note that this call sleeps for ``heartbeat_timeout`` seconds.

        :param shard_id: id of the shard to evaluate
        :return: bool
        :raises NoItemsReturned: when the control table has no record for the shard
        """
        # =====================================
        # | C L A I M  S T A L E  S H A R D S |
        # =====================================

        worker_heartbeats = {}  # Store all heartbeats so we can compare them easily to track changes
        control_record = None
        for shard in self.control_table.scan()['Items']:
            if shard['shard_id'] == shard_id:
                control_record = dict(shard)
            worker_heartbeats[shard['worker_id']] = shard['heartbeat']
        if control_record is None:
            raise NoItemsReturned()

        worker_heartbeats[self.worker_id] = 0
        time.sleep(self.heartbeat_timeout)
        updated_control_record = self.control_table.get_item(Key={'shard_id': shard_id})['Item']

        same_worker = control_record['worker_id'] == updated_control_record['worker_id']
        same_heartbeat = control_record['heartbeat'] == updated_control_record['heartbeat']
        if same_heartbeat and same_worker:
            # Nobody touched the record for a full timeout: the owner is gone
            shard_election_logger.debug(
                "Shard %s - heartbeat and worker id remained unchanged for defined time, taking over", shard_id
            )
            return True
        elif not same_worker:
            shard_election_logger.debug(
                "Shard %s - Worker id changed to %s, continue sleeping",
                shard_id, updated_control_record['worker_id']
            )
        else:
            shard_election_logger.debug("Shard %s - Heartbeat changed, continue sleeping", shard_id)

        # =====================
        # | B A L A N C I N G |
        # =====================

        active_workers = {}
        # re-fetch the shards and compare the heartbeat
        shards = self.control_table.scan()['Items']
        for shard in shards:
            owner = shard['worker_id']
            # A worker whose heartbeat did not move during the sleep is inactive
            unchanged = owner in worker_heartbeats and worker_heartbeats[owner] == shard['heartbeat']
            active_workers[owner] = not unchanged
        # This worker is running right now, so it is always active. Without this
        # a freshly claimed shard of ours (heartbeat 0) could mark us inactive
        # and make the division below fail with ZeroDivisionError.
        active_workers[self.worker_id] = True

        number_of_active_workers = sum(1 for is_active in active_workers.values() if is_active)
        number_of_shards = len(shards)
        # Floor division on purpose (this is what "/" does for ints on Python 2)
        optimal_number_of_shards_per_worker = number_of_shards // number_of_active_workers

        shards_per_worker = {}
        for shard in shards:
            shards_per_worker[shard['worker_id']] = shards_per_worker.get(shard['worker_id'], 0) + 1
        own_shard_count = shards_per_worker.get(self.worker_id, 0)

        for shard in shards:
            if shard['shard_id'] != shard_id:
                continue
            owner_shard_count = shards_per_worker.get(shard['worker_id'], 0)
            # Check if the shards current worker has too many, or if this worker has no shards, then take
            # the shard if the current worker has more than one shard!
            owner_can_spare = (
                owner_shard_count > optimal_number_of_shards_per_worker or
                (own_shard_count == 0 and owner_shard_count > 1)
            )
            # Only get shards for balancing purposes, if we have too little
            we_need_more = own_shard_count < optimal_number_of_shards_per_worker
            if owner_can_spare and we_need_more:
                shard_election_logger.debug("Taking over %s from %s", shard_id, shard['worker_id'])
                return True
        return False

    def process_shard(self, shard_id):
        """
        Every shard (at startup) has an active thread that runs this function to either consume or wait to be
        ready to consume data from a shard

        The method never returns: it claims the shard, reads records into
        ``insertion_queue`` while updating heartbeat and checkpoint, and goes
        back to claiming whenever ownership is lost.

        :param shard_id: id of the shard handled by this thread
        :return: None
        """
        while True:
            try:
                # try to claim the shard
                try:
                    # Continuously try to claim until broken out of
                    while True:
                        # First check if we can claim outside the semaphore in parallel on all shards
                        if self.can_claim_shard(shard_id):
                            # If we can claim it, try again with a semaphore ensuring we fully check the entire
                            # table of workers/shards before we take the final decision
                            with self.semaphore:
                                if self.can_claim_shard(shard_id):
                                    if self.claim_shard(shard_id):
                                        break
                        # Wait a bit until we check if it's available to claim again
                        time.sleep(random.randrange(2, 15))
                except NoItemsReturned:
                    # no record for this shard found, nobody ever claimed the shard yet, so claim it
                    self.control_table.put_item(Item={
                        'shard_id': shard_id,
                        'checkpoint': 0,
                        'worker_id': self.worker_id,
                        'heartbeat': 0,
                    })

                # get initial iterator
                control_record = self.control_table.get_item(Key={'shard_id': shard_id})['Item']
                if control_record['checkpoint'] > 0:
                    # if we have a checkpoint, start from the checkpoint
                    iterator = self.conn.get_shard_iterator(
                        StreamName=self.stream_name,
                        ShardId=shard_id,
                        ShardIteratorType="AT_SEQUENCE_NUMBER",
                        StartingSequenceNumber=str(control_record['checkpoint'])
                    )['ShardIterator']
                else:
                    # we have no checkpoint stored, start from the latest item in Kinesis
                    iterator = self.conn.get_shard_iterator(
                        StreamName=self.stream_name,
                        ShardId=shard_id,
                        ShardIteratorType="LATEST",
                    )['ShardIterator']

                cloudwatch = boto3.client(**self.connection_parameters('cloudwatch'))
                minute = None
                latest_item = None
                while True:
                    # always retrieve this at the top of the loop
                    control_record = self.control_table.get_item(Key={'shard_id': shard_id})['Item']
                    current_checkpoint = control_record['checkpoint']
                    current_heartbeat = control_record['heartbeat']

                    # if the shard was claimed by another worker, break out of the loop
                    if control_record['worker_id'] != self.worker_id:
                        shard_election_logger.info("Lost shard %s, going back to standby", shard_id)
                        break

                    # update the heartbeat and the checkpoint
                    control_record['heartbeat'] += 1
                    if self.uncompleted_ids[shard_id]:
                        # Get the "youngest" uncompleted sequence number. Sequence numbers are
                        # numeric strings, so compare them as integers rather than lexicographically.
                        control_record['checkpoint'] = min(self.uncompleted_ids[shard_id], key=int)
                    elif latest_item:
                        # or the latest item we yielded
                        control_record['checkpoint'] = latest_item
                    self.control_table.put_item(
                        Item=control_record,
                        # Will fail if someone else modified it - ConditionalCheckFailedException
                        ConditionExpression=(
                            Attr('shard_id').eq(shard_id) &
                            Attr('checkpoint').eq(current_checkpoint) &
                            Attr('worker_id').eq(self.worker_id) &
                            Attr('heartbeat').eq(current_heartbeat)
                        )
                    )

                    if len(self.uncompleted_ids[shard_id]) < self.MAX_UNCOMPLETED_ITEMS:
                        # get records from Kinesis, using the previously created iterator
                        result = self.conn.get_records(ShardIterator=iterator, Limit=self.GET_RECORDS_LIMIT)

                        # insert the records into the queue, and use the provided iterator for the next loop
                        for record in result['Records']:
                            self.uncompleted_ids[shard_id].add(record['SequenceNumber'])
                            latest_item = record['SequenceNumber']
                            self.insertion_queue.put(record)
                        iterator = result['NextShardIterator']
                    else:
                        logger.debug(
                            "Pausing, too many uncompleted items (%s/%s)",
                            len(self.uncompleted_ids[shard_id]), self.MAX_UNCOMPLETED_ITEMS
                        )
                        # we have too many uncompleted items, so back off for a while
                        # however, the iterator needs to be updated, because it expires after a while
                        if latest_item:
                            # Resume *after* the last record we already queued; starting AT it
                            # would put that record on the queue a second time.
                            iterator_position = {
                                'ShardIteratorType': "AFTER_SEQUENCE_NUMBER",
                                'StartingSequenceNumber': latest_item,
                            }
                        else:
                            # Nothing queued by this loop yet: restart at the stored checkpoint
                            iterator_position = {
                                'ShardIteratorType': "AT_SEQUENCE_NUMBER",
                                'StartingSequenceNumber': str(control_record['checkpoint']),
                            }
                        iterator = self.conn.get_shard_iterator(
                            StreamName=self.stream_name,
                            ShardId=shard_id,
                            **iterator_position
                        )['ShardIterator']

                        # get just one item to update the MillisBehindLatest below; the record itself is
                        # discarded and the iterator is deliberately not advanced
                        result = self.conn.get_records(ShardIterator=iterator, Limit=1)

                    # Push metrics to CloudWatch
                    delay = result['MillisBehindLatest']
                    now_minute = datetime.datetime.now().minute
                    if minute != now_minute:  # push once per minute to CloudWatch.
                        minute = now_minute
                        cloudwatch.put_metric_data(
                            Namespace='Motorway/Kinesis',
                            MetricData=[{
                                'MetricName': 'MillisecondsBehind',
                                'Dimensions': [{
                                    'Name': 'Stream',
                                    'Value': self.stream_name
                                }, {
                                    'Name': 'Shard',
                                    'Value': shard_id
                                }],
                                'Value': delay,
                                'Unit': 'Milliseconds'
                            }]
                        )

                    # recommended pause between fetches from AWS
                    time.sleep(1)
            except self.dynamodb_client.exceptions.ConditionalCheckFailedException as e:
                # we're no longer worker for this shard
                logger.warning(e)
            except (
                self.dynamodb_client.exceptions.ProvisionedThroughputExceededException,
                self.conn.exceptions.LimitExceededException,
                self.conn.exceptions.ProvisionedThroughputExceededException,
            ) as e:
                logger.warning(e)
                # back off for a while; the upper bound is kept above 5 so that
                # randrange() stays valid for small heartbeat timeouts
                time.sleep(random.randrange(5, max(6, self.heartbeat_timeout // 2)))

    def connection_parameters(self, service_name):
        """
        Build the keyword arguments used for every ``boto3`` client/resource.

        :param service_name: AWS service name, e.g. ``'kinesis'`` or ``'dynamodb'``
        :return: dict of keyword arguments
        """
        return {
            'region_name': 'eu-west-1',
            'service_name': service_name,
            # Add this or use ENV VARS
            # 'aws_access_key_id': '',
            # 'aws_secret_access_key': ''
        }

    def next(self):
        """
        Take one record from the insertion queue and yield it as a ``Message``.

        Blocks until a record is available. A record that raises
        ``ValueError`` (e.g. a payload that is not valid JSON) is logged and
        yields nothing.
        """
        msg = self.insertion_queue.get()
        try:
            message = Message(
                msg['SequenceNumber'],
                json.loads(msg['Data']),
                grouping_value=msg['PartitionKey']
            )
        except ValueError as e:
            logger.exception(e)
            # The record will never be yielded, hence never acknowledged: release it
            # here so it does not hold back the shard's checkpoint forever.
            self.success(msg['SequenceNumber'])
            return
        yield message

    def success(self, _id):
        """
        Mark a record as completed so the checkpoint can move past it.

        :param _id: sequence number of the completed record
        """
        for uncompleted_ids in self.uncompleted_ids.values():
            # discard() is a single step and is a no-op for unknown ids
            uncompleted_ids.discard(_id)