#
# RuCTFe 2014 Exploit Helpers
# by patrikf
#

# Tuning knobs for the attack framework.
FLAG_SERVER_TIMEOUT = 5      # seconds; requests timeout when submitting a flag
TARGET_CONNECT_TIMEOUT = 5   # seconds; socket timeout while connecting to a target
TARGET_READ_TIMEOUT = 5      # seconds; socket timeout for I/O after connecting
ATTACK_THREADS = 20          # default value of --threads (parallel attack workers)
ATTACK_NHOSTS = 330          # hosts per round: <subservice>.0.attack .. <subservice>.329.attack
TICK_TIME = 120              # presumably game tick length in seconds; NOTE(review): not used by the code visible in this file (wait_tick sleeps a fixed 20 s)
FLAG_THREADS = 2             # number of flag-submission worker threads

import argparse
import logging
import re
import requests
import socket
import sys
import threading
import time
import traceback
import sys

import queue

# Timeout (seconds) handed to queue put() calls instead of blocking forever.
# Kept as a "workaround for Queue bug" -- presumably the historical issue where
# untimed blocking Queue operations could not be interrupted with Ctrl-C
# (TODO confirm).
HUGE_TIMEOUT = 3600

def indent(amount, string):
    '''Return `string` with each of its lines prefixed by `amount` spaces.'''
    pad = ' ' * amount
    lines = string.split('\n')
    return '\n'.join(pad + ln for ln in lines)


class ExploitBase(object):
    '''Base class for an exploit against a single target host.

    Subclasses override run(); run_catch() wraps it and turns exceptions into
    a status string. Flags found by the exploit are passed to submit_flag().
    '''

    # A flag is 31 word characters followed by '='. Raw string so that the
    # backslash is a regex escape, not an (invalid) string escape.
    FLAG_REGEX = r'^\w{31}=$'

    class Adapter(logging.LoggerAdapter):
        '''Logger adapter that prefixes every message with "[host]".'''

        def process(self, msg, kwargs):
            return '[%s] %s' % (self.extra['host'], msg), kwargs

    def __init__(self, host, port, flag_service=None, **kwargs):
        '''Store the target address and the optional flag service.

        host, port: target address.
        flag_service: object with a submit_flag(flag) method; when None,
            flags are only logged (test mode).
        **kwargs: extra options (e.g. for subclasses); ignored here.
        '''
        self.host = host
        self.port = port
        self.logger = self.Adapter(logging.getLogger('Exploit'),
                                   extra={'host': host, 'port': port})
        self.flag_pattern = re.compile(self.FLAG_REGEX)
        self.flag_service = flag_service

    def submit_flag(self, flag):
        '''Submit a flag to the flag server.

        flag may be str or bytes (bytes are decoded as ASCII, e.g. data read
        straight from a socket). Raises Exception if it does not match
        FLAG_REGEX.
        '''
        if isinstance(flag, (bytes, bytearray)):
            # A str pattern cannot be matched against bytes (TypeError).
            flag = flag.decode('ascii', 'replace')
        if not self.flag_pattern.match(flag):
            raise Exception('doesn\'t look like a flag: %s' % (repr(flag)))
        if self.flag_service:
            self.flag_service.submit_flag(flag)
        else:
            self.logger.info('flag: %s (not submitted in test mode)', flag)

    def run_catch(self):
        '''Run the exploit and return its status, mapping failures to a string.

        Returns run()'s result on normal completion, 'timeout' on a socket
        timeout and 'error' on any other Exception (which is logged).
        '''
        self.logger.debug('starting attack')
        try:
            return self.run()
        except socket.timeout:
            # Must precede OSError: socket.timeout is a subclass of it.
            self.logger.warning('socket timeout')
            return 'timeout'
        except requests.exceptions.RequestException as e:
            # Must precede OSError: RequestException derives from IOError
            # (== OSError), so it would otherwise be caught as a socket error.
            self.logger.error('request exception: %s', str(e))
            return 'error'
        except OSError as e:  # socket.error is an alias of OSError
            # strerror is None for OSErrors created without an errno.
            self.logger.error('socket error (%s.%s): %s',
                              type(e).__module__,
                              type(e).__name__,
                              e.strerror or e)
            return 'error'
        except Exception as e:
            self.logger.exception('caught exception: %s.%s',
                                  type(e).__module__, type(e).__name__)
            return 'error'

    def run(self):
        '''To be implemented by subclasses.'''
        pass


class ExploitBaseTcp(ExploitBase):
    '''Exploit base class for targets spoken to over one TCP connection.

    Subclasses implement exploit(); run() connects, calls exploit() and
    always closes the connection afterwards.
    '''

    def run(self):
        '''Connect, exploit and close the connection.

        Returns 'down' if the connect times out or is refused, otherwise
        'success' once exploit() returns. Exceptions raised by exploit()
        propagate after the connection has been closed.
        '''
        try:
            self.connect()
        except socket.timeout:
            self.logger.info('connect timeout')
            return 'down'
        except ConnectionRefusedError:
            self.logger.info('connection refused')
            return 'down'

        try:
            self.exploit()
        finally:
            self.close()
        return 'success'

    def connect(self):
        '''Establish a connection to (self.host, self.port).

        If connecting fails, the socket is closed before the exception is
        re-raised so its file descriptor is not leaked.
        '''
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        try:
            self.socket.settimeout(TARGET_CONNECT_TIMEOUT)
            self.socket.connect((self.host, self.port))
        except BaseException:
            self.socket.close()
            raise
        self.logger.debug('connected')
        self.socket.settimeout(TARGET_READ_TIMEOUT)

    def close(self):
        '''Close the target connection.'''
        self.socket.close()

    def recv_fix(self, length):
        '''Receive exactly `length` bytes.

        Raises Exception('preliminary eof') if the peer closes the
        connection before `length` bytes have arrived.
        '''
        # Loop rather than rely on MSG_WAITALL: a socket with a timeout is
        # non-blocking at the OS level, so a single recv() may return fewer
        # bytes than requested even though more are on the way.
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = self.socket.recv(remaining)
            if not chunk:
                raise Exception('preliminary eof')
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def recv_until(self, magic_str):
        '''Receive data until a magic string (e.g. prompt) is encountered.

        magic_str may be str (encoded as ASCII) or bytes. Returns everything
        received, including the magic string itself.
        '''
        if not isinstance(magic_str, bytes):
            magic_str = magic_str.encode('ascii')
        buf = bytearray()
        # Read byte by byte so nothing after the magic string is consumed;
        # checking the suffix handles overlapping prefixes (e.g. "> " in ">> ").
        while not buf.endswith(magic_str):
            buf += self.recv_fix(1)
        return bytes(buf)

    def recv_all(self):
        '''Receive data until eof.'''
        chunks = []
        while True:
            chunk = self.socket.recv(1024)
            if not chunk:
                break
            chunks.append(chunk)
        return b''.join(chunks)

    def send(self, buf):
        '''Send all of `buf` to the target.'''
        self.socket.sendall(buf)

    def exploit(self):
        '''To be implemented by subclasses.'''
        pass


class RateLimit(object):
    '''Decide whether a periodic action is due.

    step() returns True at most once per `delta` seconds. The very first call
    only starts the clock and returns `initial`.
    '''

    def __init__(self, delta=1, initial=False):
        # Monotonic time of the first step() or of the last step() that
        # returned True; None until step() is first called.
        self.last_update = None
        self.delta = delta
        self.initial = initial

    def step(self):
        '''Return True if more than `delta` seconds passed since the last reset.'''
        # Monotonic clock: immune to wall-clock jumps (NTP, manual changes)
        # that could otherwise suppress or flood the periodic action.
        now = time.monotonic()
        if self.last_update is None:
            self.last_update = now
            return self.initial
        if now > self.last_update + self.delta:
            self.last_update = now
            return True
        return False


class FlagService(object):
    '''Submit captured flags to the flag server from background threads.

    submit_flag() enqueues a flag; FLAG_THREADS worker threads drain the
    queue and send each flag via HTTP, stopping when they dequeue None.
    '''

    class Thread(threading.Thread):
        '''Worker thread that drains FlagService.flag_queue.'''

        def __init__(self, service, thread_no):
            threading.Thread.__init__(self)
            self.daemon = True  # do not keep the process alive at exit
            self.service = service
            self.thread_no = thread_no

        def progress(self):
            '''Let only thread 0 emit the periodic progress log line.'''
            if self.thread_no == 0:
                self.service.progress()

        def submit_flag(self, flag):
            '''Send one flag to the flag server; HTTP failures are logged.'''
            self.service.logger.debug('submitting flag: %s', flag)
            service_name = 'ctfutil-%s' % (sys.argv[0])
            try:
                r = requests.get('http://10.0.1.10/submit.php',
                                 params={'flag': flag,
                                         'service': service_name},
                                 timeout=FLAG_SERVER_TIMEOUT,
                                 stream=False)
                if r.status_code == 200:
                    self.service.record_submitted()
                else:
                    self.service.logger.error('error %d', r.status_code)
            except requests.exceptions.RequestException as e:
                self.service.logger.error('caught exception: %s', str(e))

        def _next_item(self):
            '''Block until a queue item arrives, reporting progress every second.'''
            while True:
                self.progress()
                try:
                    return self.service.flag_queue.get(timeout=1)
                except queue.Empty:
                    pass

        def run(self):
            flag_queue = self.service.flag_queue
            while True:
                item = self._next_item()
                try:
                    if item is None:  # shutdown sentinel
                        return
                    self.submit_flag(item)
                except Exception:
                    # One bad flag must neither kill this worker nor skip
                    # the task_done() below.
                    self.service.logger.exception('unexpected error submitting flag')
                finally:
                    flag_queue.task_done()

    def __init__(self):
        self.logger = logging.getLogger('FlagService')
        self.flag_queue = queue.Queue(500)
        self.threads = [self.Thread(self, i) for i in range(FLAG_THREADS)]
        self.periodic_update = RateLimit(1, True)
        self.lock = threading.Lock()
        self.flags_submitted = 0  # successful submissions since last progress()

    def start(self):
        '''Start all worker threads.'''
        for thread in self.threads:
            thread.start()

    def submit_flag(self, flag):
        '''Enqueue a flag; raises queue.Full if no slot frees up within HUGE_TIMEOUT.'''
        # timeout must be passed by keyword: the second positional parameter
        # of Queue.put() is `block`, not `timeout`.
        self.flag_queue.put(flag, timeout=HUGE_TIMEOUT)

    def record_submitted(self):
        '''Count one successfully submitted flag (called from worker threads).'''
        with self.lock:
            self.flags_submitted += 1

    def progress(self):
        '''At most once per second, log and reset the submission counter.'''
        if not self.periodic_update.step():
            return
        with self.lock:
            submitted = self.flags_submitted
            self.flags_submitted = 0
        self.logger.info('%d flags submitted, %d pending',
                         submitted, self.flag_queue.qsize())


class ParallelAttack(object):
    '''Attack every host of one subservice in parallel, round after round.

    Work items are dicts {'host': ...} on self.queue; worker threads run the
    exploit for each and record the resulting status. A per-host status map
    and a histogram of statuses are kept for progress reporting.
    '''

    class Thread(threading.Thread):
        '''Worker thread: runs the exploit for each queued host until it gets None.'''

        def __init__(self, attack):
            threading.Thread.__init__(self)
            self.daemon = True  # do not keep the process alive at exit
            self.attack = attack

        def run(self):
            work = self.attack.queue
            while True:
                item = work.get()
                try:
                    if item is None:  # shutdown sentinel
                        return
                    self.attack.record_status(item, self._attack_one(item))
                finally:
                    # Exactly one task_done() per get(), even on errors, so
                    # queue.join() in ParallelAttack.run() can complete.
                    work.task_done()

        def _attack_one(self, item):
            '''Build the exploit for `item` and run it; return its status.

            Errors while constructing the exploit are logged and reported as
            'error' so the host does not stay 'pending' forever.
            '''
            kwargs = self.attack.kwargs.copy()
            kwargs.update(item)
            try:
                exploit = self.attack.exploit_class(**kwargs)
            except Exception:
                self.attack.logger.exception('failed to set up exploit for %s',
                                             item.get('host'))
                return 'error'
            return exploit.run_catch()

    def __init__(self, exploit_class, nthreads, subservice, **kwargs):
        '''Prepare the attack (threads are started by run()).

        exploit_class: called with keyword arguments once per host; the
            result must provide run_catch().
        nthreads: number of worker threads.
        subservice: first number in target names '<subservice>.<i>.attack'.
        **kwargs: passed to every exploit, plus 'flag_service' and 'host'.
        '''
        self.logger = logging.getLogger('ParallelAttack')
        self.queue = queue.Queue()
        self.flag_service = FlagService()
        self.exploit_class = exploit_class
        self.nthreads = nthreads
        self.subservice = subservice

        self.kwargs = kwargs
        self.kwargs['flag_service'] = self.flag_service

        self.lock = threading.Lock()  # guards hoststatus and statushist
        self.hoststatus = {}  # host -> latest status
        self.statushist = {}  # status -> number of hosts currently in it

    def wait_tick(self):
        '''Pause between attack rounds.'''
        # NOTE(review): sleeps a fixed 20 s; TICK_TIME and self.round_start
        # are not used here -- confirm whether rounds should align to ticks.
        self.logger.info('waiting for next tick')
        time.sleep(20)

    def record_status(self, workitem, status):
        '''Set the status of workitem['host'] and update the histogram.'''
        host = workitem['host']
        with self.lock:
            # Membership test rather than truthiness, so falsy statuses
            # (e.g. None from a run() that returns nothing) are also moved.
            if host in self.hoststatus:
                self.statushist[self.hoststatus[host]] -= 1
            self.hoststatus[host] = status
            self.statushist[status] = self.statushist.get(status, 0) + 1

    def status_summary(self):
        '''Log the status histogram, well-known statuses first.'''
        with self.lock:
            # Snapshot: workers keep mutating the live dict after we unlock.
            hist = dict(self.statushist)
        order = ['success', 'timeout', 'down', 'error', 'pending']
        order += [key for key in hist if key not in order]
        elems = ["%d %s" % (hist.get(key, 0), key) for key in order]
        self.logger.info('status: %s', ', '.join(elems))

    def wait_for_queue(self):
        '''Block until no host is 'pending', logging a summary every second.'''
        rl = RateLimit(1, True)
        while True:
            if rl.step():
                self.status_summary()
            with self.lock:
                n_pending = self.statushist.get('pending', 0)
            if n_pending == 0:
                break
            time.sleep(0.1)
        self.status_summary()

    def run(self):
        '''Attack all hosts round after round until Ctrl-C, then stop the workers.'''
        self.logger.info('starting parallel attack on %d.{0..%d}.attack with ' +
                         '%d threads', self.subservice, ATTACK_NHOSTS-1, self.nthreads)

        self.flag_service.start()

        self.threads = [self.Thread(self) for _ in range(self.nthreads)]
        for thread in self.threads:
            thread.start()

        try:
            while True:
                self.round_start = time.time()

                for i in range(ATTACK_NHOSTS):
                    item = {'host': '%d.%d.attack' % (self.subservice, i)}
                    self.record_status(item, 'pending')
                    self.queue.put(item, timeout=HUGE_TIMEOUT)

                self.wait_for_queue()
                self.wait_tick()

        except KeyboardInterrupt:
            self.logger.warning('interrupted, shutting down')

        # One sentinel per worker; each worker exits after consuming one.
        for _ in self.threads:
            self.queue.put(None, timeout=HUGE_TIMEOUT)

        self.queue.join()

        for thread in self.threads:
            thread.join()


class Formatter(logging.Formatter):
    '''Log formatter that indents exception tracebacks by four spaces.'''

    def formatException(self, exc_info):
        '''Render exc_info as a traceback with every line indented by four spaces.'''
        # format_exception() returns chunks that each contain one or more
        # newline-terminated lines. Join them first and then split into lines;
        # joining the chunks with '\n' would insert blank lines and leave the
        # continuation lines of each chunk unindented.
        text = ''.join(traceback.format_exception(*exc_info))
        return '\n'.join('    ' + line for line in text.splitlines())


class AttackTool(object):
    '''Command-line front end: attack all hosts (-a) or test single targets (-t).'''

    def __init__(self, exploit_class, **kwargs):
        '''exploit_class: exploit to run; **kwargs: forwarded to every exploit.'''
        self.exploit_class = exploit_class
        self.logger = logging.getLogger('AttackTool')
        self.kwargs = kwargs

    def init_logging(self):
        '''Attach a timestamped stderr handler to the root logger (INFO level).'''
        handler = logging.StreamHandler()
        handler.setFormatter(Formatter(
            '%(asctime)s %(levelname)-7s %(name)-14s  %(message)s',
            '%H:%M:%S'))
        root = logging.getLogger()
        root.addHandler(handler)
        root.setLevel(logging.INFO)

        # Quieten urllib3's connection-pool messages.
        logging.getLogger('urllib3.connectionpool').setLevel(logging.WARNING)

    def parse_args(self, argv=None):
        '''Parse command-line options from `argv` (default: sys.argv[1:]).'''
        parser = argparse.ArgumentParser(usage='%(prog)s [options...] [-a|-t HOST]')
        parser.add_argument('-p', '--port', metavar='PORT', type=int, required=True)
        parser.add_argument('-v', '--verbose',
                            help='increase logging level to DEBUG',
                            action='store_true', dest='debug')
        parser.add_argument('--threads', metavar='N', type=int,
                            default=ATTACK_THREADS,
                            help='number of threads for --attack-all')
        g = parser.add_mutually_exclusive_group(required=True)
        g.add_argument('-a', '--attack-all',
                       help='attack all hosts',
                       action='store_true', dest='attack_all')
        g.add_argument('-t', '--test', metavar='HOST', type=str,
                       help='attack a single host',
                       action='append', dest='targets')

        return parser.parse_args(argv)

    def run(self, argv=None):
        '''Set up logging, parse `argv` (default: sys.argv[1:]) and attack.

        With -a, runs ParallelAttack (requires 'subservice' among the
        constructor kwargs). Otherwise runs the exploit directly, without
        run_catch(), against each -t target and logs the status.
        '''
        self.init_logging()

        opts = self.parse_args(argv)

        self.kwargs['port'] = opts.port

        if opts.debug:
            logging.getLogger().setLevel(logging.DEBUG)

        if opts.attack_all:
            attack = ParallelAttack(exploit_class=self.exploit_class,
                                    nthreads=opts.threads,
                                    **self.kwargs)
            attack.run()
            return

        for target in opts.targets:
            self.logger.info('test-attacking %s', target)
            kwargs = dict(self.kwargs, host=target)
            exploit = self.exploit_class(**kwargs)
            status = exploit.run()
            self.logger.info('status: %s', status)

