]> git.somenet.org - ctf/pub/exploit_framework.git/blob - ctfutil.py
GITOLITE.txt
[ctf/pub/exploit_framework.git] / ctfutil.py
1 #
2 # RuCTFe 2014 Exploit Helpers
3 # by patrikf
4 #
5
# --- Timeouts (seconds) -----------------------------------------------------
FLAG_SERVER_TIMEOUT = 5     # HTTP timeout for one flag submission request
TARGET_CONNECT_TIMEOUT = 5  # socket timeout while connecting to a target
TARGET_READ_TIMEOUT = 5     # socket timeout for I/O once connected

# --- Attack scheduling ------------------------------------------------------
ATTACK_THREADS = 20         # default worker count for --attack-all
ATTACK_NHOSTS = 330         # hosts 0 .. ATTACK_NHOSTS-1 are attacked each round
TICK_TIME = 2 * 60          # presumably the game tick length; not referenced in this module

# --- Flag submission --------------------------------------------------------
FLAG_THREADS = 2            # number of concurrent flag-submitter threads
13
14 import argparse
15 import logging
16 import re
17 import requests
18 import socket
19 import sys
20 import threading
21 import time
22 import traceback
23 import sys
24
25 import queue
26
# One hour, in seconds. Passed as an explicit timeout to blocking Queue.put()
# calls instead of waiting with no timeout at all ("workaround for Queue bug").
# NOTE(review): presumably the old CPython issue where an untimed queue wait
# cannot be interrupted by Ctrl-C -- confirm.
HUGE_TIMEOUT = 60 * 60
28
def indent(amount, string):
    r'''Return *string* with every line prefixed by *amount* spaces.

    Lines are split on '\n' only, so a trailing newline produces a final
    line that consists of just the padding, and an empty string becomes
    the padding alone.
    '''
    pad = amount * ' '
    return '\n'.join(pad + piece for piece in string.split('\n'))
32
33
class ExploitBase(object):
    '''Base class for one exploit attempt against a single target.

    Subclasses implement run() and report every captured flag through
    submit_flag(). run_catch() wraps run() and converts errors into a
    status string.
    '''

    # Default flag format: 31 word characters followed by '='. Subclasses may
    # override it. Raw string so that '\w' is not parsed as a string escape.
    FLAG_REGEX = r'^\w{31}=$'

    class Adapter(logging.LoggerAdapter):
        '''Logger adapter that prefixes each message with "[host]".'''

        def process(self, msg, kwargs):
            return '[%s] %s' % (self.extra['host'], msg), kwargs

    def __init__(self, host, port, flag_service=None, **kwargs):
        '''
        host, port   -- address of the target.
        flag_service -- object with a submit_flag(flag) method; when None
                        (test mode) flags are only logged, not submitted.
        Further keyword arguments are accepted and ignored, so callers can
        pass one shared kwargs dict.
        '''
        self.host = host
        self.port = port
        self.logger = self.Adapter(logging.getLogger('Exploit'),
                                   extra={'host': host, 'port': port})
        self.flag_pattern = re.compile(self.FLAG_REGEX)
        self.flag_service = flag_service

    def submit_flag(self, flag):
        '''Submit a flag to the flag server.

        bytes/bytearray (as returned by the recv helpers) are decoded as
        ASCII, and surrounding whitespace is stripped. Without stripping, a
        flag read together with its line terminator would pass the check
        ('$' also matches before a trailing newline) and be submitted with
        the newline attached.

        Raises Exception if the result does not match FLAG_REGEX.
        '''
        if isinstance(flag, (bytes, bytearray)):
            # Non-ASCII bytes become U+FFFD, which the default FLAG_REGEX rejects.
            flag = flag.decode('ascii', 'replace')
        flag = flag.strip()
        if not self.flag_pattern.match(flag):
            raise Exception('doesn\'t look like a flag: %s' % (repr(flag)))
        if self.flag_service:
            self.flag_service.submit_flag(flag)
        else:
            self.logger.info('flag: %s (not submitted in test mode)', flag)

    def run_catch(self):
        '''Call run() and turn any exception into a status string.

        Returns run()'s own return value on success, 'timeout' for socket
        timeouts and 'error' for any other Exception. Every error is logged,
        so callers (e.g. ParallelAttack workers) get a status, not a raise.
        '''
        self.logger.debug('starting attack')
        try:
            status = self.run()
        except socket.timeout:
            self.logger.warning('socket timeout')
            status = 'timeout'
        except socket.error as e:
            # strerror is None for OSErrors created with a plain message.
            self.logger.error('socket error (%s.%s): %s',
                              type(e).__module__,
                              type(e).__name__,
                              e.strerror or e)
            status = 'error'
        except requests.exceptions.RequestException as e:
            self.logger.error('request exception: %s', str(e))
            status = 'error'
        except Exception as e:
            self.logger.exception('caught exception: %s.%s', type(e).__module__, type(e).__name__)
            status = 'error'

        return status

    def run(self):
        '''To be implemented by subclasses.'''
        pass


class ExploitBaseTcp(ExploitBase):
    '''ExploitBase for plain TCP services: connect, exploit(), close.'''

    def run(self):
        '''Connect, run exploit() and always close the connection.

        Returns 'down' if the connect times out or is refused, 'success' if
        exploit() returns normally. Other errors propagate (run_catch() maps
        them to a status).
        '''
        try:
            self.connect()
        except socket.timeout:
            self.logger.info('connect timeout')
            return 'down'
        except ConnectionRefusedError:
            self.logger.info('connection refused')
            return 'down'

        try:
            self.exploit()
        finally:
            self.close()
        return 'success'

    def connect(self):
        '''Establish a connection to the target.

        TARGET_CONNECT_TIMEOUT applies to the connect itself and
        TARGET_READ_TIMEOUT to every later socket operation. On failure the
        socket is closed before re-raising, so no descriptor is leaked.
        '''
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        try:
            self.socket.settimeout(TARGET_CONNECT_TIMEOUT)
            self.socket.connect((self.host, self.port))
        except BaseException:
            self.socket.close()
            raise
        self.logger.debug('connected')
        self.socket.settimeout(TARGET_READ_TIMEOUT)

    def close(self):
        '''Close the target connection.'''
        self.socket.close()

    def recv_fix(self, length):
        '''Receive exactly `length` bytes; raise Exception on early EOF.

        Loops over recv() rather than relying on MSG_WAITALL. A socket with a
        timeout is internally non-blocking, so a single MSG_WAITALL recv may
        return a short read.
        '''
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = self.socket.recv(remaining)
            if not chunk:
                raise Exception('preliminary eof')
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def recv_until(self, magic_str):
        '''Receive data until it ends with magic_str (e.g. a prompt).

        Returns everything read, including magic_str. A str argument is
        encoded as ASCII. Reads one byte at a time so that nothing past the
        marker is consumed.
        '''
        if not isinstance(magic_str, bytes):
            magic_str = magic_str.encode('ascii')
        buf = bytearray()
        # Testing the buffer suffix (instead of one running match index)
        # handles partial matches correctly, e.g. b'ab' inside b'aab'.
        while not buf.endswith(magic_str):
            buf += self.recv_fix(1)
        return bytes(buf)

    def recv_all(self):
        '''Receive data until the peer closes the connection (EOF).'''
        chunks = []
        while True:
            chunk = self.socket.recv(1024, socket.MSG_WAITALL)
            if not chunk:
                break
            chunks.append(chunk)
        return b''.join(chunks)

    def send(self, buf):
        '''Send all of buf; a str is encoded as ASCII (like recv_until).'''
        if isinstance(buf, str):
            buf = buf.encode('ascii')
        self.socket.sendall(buf)

    def exploit(self):
        '''To be implemented by subclasses.'''
        pass
155
156
class RateLimit(object):
    '''Gate that opens at most once per `delta` seconds.

    step() returns True when more than `delta` seconds have passed since the
    last time it returned True. The very first call only starts the clock
    and returns `initial`.
    '''

    def __init__(self, delta=1, initial=False):
        self.last_update = None   # time of the last opening; None until first step()
        self.delta = delta
        self.initial = initial

    def step(self):
        # monotonic() is immune to wall-clock jumps (NTP, manual changes) that
        # would otherwise stall the limiter or let it fire repeatedly.
        now = time.monotonic()
        if self.last_update is None:
            self.last_update = now
            return self.initial
        if now > self.last_update + self.delta:
            self.last_update = now
            return True
        return False
172
173
class FlagService(object):
    '''Background flag submitter.

    Exploits hand flags to submit_flag(). FLAG_THREADS daemon threads take
    them off a bounded queue and submit them to the flag server over HTTP.
    '''

    class Thread(threading.Thread):
        '''Worker thread that drains the service's flag queue.'''

        def __init__(self, service, thread_no):
            threading.Thread.__init__(self)
            self.daemon = True   # do not keep the process alive on exit
            self.service = service
            self.thread_no = thread_no

        def progress(self):
            # Only worker 0 reports, so the periodic status line is logged
            # once rather than once per thread.
            if self.thread_no == 0:
                self.service.progress()

        def submit_flag(self, flag):
            '''Submit one flag to the flag server over HTTP.

            A 200 response counts as submitted. Any other status code, or a
            requests error, is logged and the flag is not retried.
            '''
            self.service.logger.debug('submitting flag: %s', flag)
            service_name = 'ctfutil-%s' % (sys.argv[0])
            try:
                r = requests.get('http://10.0.1.10/submit.php',
                                 params={'flag': flag,
                                         'service': service_name},
                                 timeout=FLAG_SERVER_TIMEOUT,
                                 stream=False)
                if r.status_code == 200:
                    self.service.record_submitted()
                else:
                    self.service.logger.error('error %d', r.status_code)
            except requests.exceptions.RequestException as e:
                self.service.logger.error('caught exception: %s', str(e))

        def _next_item(self):
            '''Wait for the next queue item.

            Polls with a 1 s timeout so that progress() keeps being called
            while the queue is idle.
            '''
            while True:
                self.progress()
                try:
                    return self.service.flag_queue.get(timeout=1)
                except queue.Empty:
                    pass

        def run(self):
            while True:
                item = self._next_item()
                try:
                    if item is None:   # None stops this worker
                        return
                    self.submit_flag(item)
                except Exception:
                    # Thread top level: log and keep going. A dead worker would
                    # silently reduce submission capacity and skip task_done().
                    self.service.logger.exception(
                        'unexpected error while submitting %r', item)
                finally:
                    self.service.flag_queue.task_done()

    def __init__(self):
        self.logger = logging.getLogger('FlagService')
        # Bounded: producers block in submit_flag() while 500 flags are queued.
        self.flag_queue = queue.Queue(500)
        self.threads = [self.Thread(self, i) for i in range(FLAG_THREADS)]
        self.periodic_update = RateLimit(1, True)
        self.lock = threading.Lock()          # guards flags_submitted
        self.flags_submitted = 0              # successes since the last progress report

    def start(self):
        '''Start all submitter threads.'''
        for thread in self.threads:
            thread.start()

    def submit_flag(self, flag):
        '''Queue a flag for submission.

        Blocks while the queue is full and raises queue.Full after
        HUGE_TIMEOUT seconds.
        '''
        # The timeout must be a keyword argument: Queue.put's second
        # positional parameter is `block`, so put(flag, HUGE_TIMEOUT) would
        # block forever without any timeout.
        self.flag_queue.put(flag, timeout=HUGE_TIMEOUT)

    def record_submitted(self):
        '''Count one successfully submitted flag.'''
        with self.lock:
            self.flags_submitted += 1

    def progress(self):
        '''Log flags submitted since the last report and the queue backlog.

        Rate-limited by periodic_update (one report per second at most).
        '''
        if not self.periodic_update.step():
            return
        with self.lock:
            submitted, self.flags_submitted = self.flags_submitted, 0
        self.logger.info('%d flags submitted, %d pending',
                         submitted, self.flag_queue.qsize())
248
249
class ParallelAttack(object):
    '''Attack every host of one subservice, round after round, with a pool
    of worker threads, feeding captured flags to a FlagService.'''

    class Thread(threading.Thread):
        '''Worker: takes host work items off the attack queue and runs the
        exploit against each one.'''

        def __init__(self, attack):
            threading.Thread.__init__(self)
            self.daemon = True
            self.attack = attack

        def _attack(self, item):
            '''Run one exploit for `item` and return its status string.'''
            kwargs = self.attack.kwargs.copy()
            kwargs.update(item)
            try:
                exploit = self.attack.exploit_class(**kwargs)
                return exploit.run_catch()
            except Exception:
                # run_catch() already handles errors raised inside run(); this
                # covers failures outside it (e.g. the exploit's constructor)
                # so the host does not stay 'pending' and stall the round.
                self.attack.logger.exception('[%s] unhandled exploit error',
                                             item.get('host'))
                return 'error'

        def run(self):
            while True:
                item = self.attack.queue.get()
                try:
                    if item is None:   # shutdown sentinel, see ParallelAttack.run()
                        return
                    self.attack.record_status(item, self._attack(item))
                finally:
                    self.attack.queue.task_done()

    def __init__(self, exploit_class, nthreads, subservice, **kwargs):
        '''
        exploit_class -- ExploitBase subclass instantiated once per host.
        nthreads      -- number of worker threads.
        subservice    -- first component of the target hostnames.
        kwargs        -- extra exploit constructor arguments (e.g. port).
        '''
        self.logger = logging.getLogger('ParallelAttack')
        self.queue = queue.Queue()
        self.flag_service = FlagService()
        self.exploit_class = exploit_class
        self.nthreads = nthreads
        self.subservice = subservice

        self.kwargs = kwargs
        self.kwargs['flag_service'] = self.flag_service

        self.lock = threading.Lock()   # guards hoststatus and statushist
        self.hoststatus = {}           # host -> latest status string
        self.statushist = {}           # status string -> number of hosts in it

    def wait_tick(self):
        '''Pause between attack rounds.

        NOTE(review): sleeps a fixed 20 s; round_start and TICK_TIME are not
        used to align with the game tick.
        '''
        self.logger.info('waiting for next tick')
        time.sleep(20)

    def record_status(self, workitem, status):
        '''Set the host's current status and update the per-status counts.'''
        host = workitem['host']
        with self.lock:
            oldstatus = self.hoststatus.get(host)
            self.hoststatus[host] = status

            if oldstatus:
                self.statushist[oldstatus] -= 1
            self.statushist[status] = self.statushist.get(status, 0) + 1

    def status_summary(self):
        '''Log how many hosts are in each status (known statuses first).'''
        with self.lock:
            # Snapshot: workers keep mutating statushist while we format it.
            hist = dict(self.statushist)
        order = ['success', 'timeout', 'down', 'error', 'pending']
        order += [key for key in hist if key not in order]
        elems = ["%d %s" % (hist.get(key, 0), key) for key in order]
        self.logger.info('status: %s', ', '.join(elems))

    def wait_for_queue(self):
        '''Block until no host is 'pending', logging a summary about once a second.'''
        rl = RateLimit(1, True)
        while True:
            if rl.step():
                self.status_summary()
            with self.lock:
                n_pending = self.statushist.get('pending', 0)
            if n_pending == 0:
                break
            time.sleep(0.1)
        self.status_summary()

    def run(self):
        '''Start the flag service and workers, then attack all hosts in
        rounds until interrupted (KeyboardInterrupt).

        On interrupt, hosts already queued are still processed before the
        workers stop.
        '''
        self.logger.info('starting parallel attack on %d.{0..%d}.attack with ' +
                         '%d threads', self.subservice, ATTACK_NHOSTS-1, self.nthreads)

        self.flag_service.start()

        self.threads = []
        for _ in range(self.nthreads):
            thread = self.Thread(self)
            self.threads.append(thread)
            thread.start()

        try:
            while True:
                self.round_start = time.time()

                for i in range(ATTACK_NHOSTS):
                    item = {'host': '%d.%d.attack' % (self.subservice, i)}
                    # Mark pending before queueing so a fast worker's final
                    # status cannot be overwritten with 'pending'.
                    self.record_status(item, 'pending')
                    self.queue.put(item, timeout=HUGE_TIMEOUT)

                self.wait_for_queue()
                self.wait_tick()

        except KeyboardInterrupt:
            self.logger.warning('interrupted, shutting down')

        # One None sentinel per worker; each worker exits after taking one.
        for _ in self.threads:
            self.queue.put(None, timeout=HUGE_TIMEOUT)

        self.queue.join()

        for thread in self.threads:
            thread.join()
357
358
class Formatter(logging.Formatter):
    '''logging.Formatter that indents exception tracebacks by four spaces.'''

    def formatException(self, exc_info):
        '''Return the traceback for exc_info with every line indented.

        format_exception() returns chunks that each hold one or more
        newline-terminated lines. Joining them first and then splitting into
        lines indents every line and avoids blank lines between chunks. The
        trailing newline is dropped, as in logging.Formatter.formatException.
        '''
        text = ''.join(traceback.format_exception(*exc_info)).rstrip('\n')
        return '\n'.join(' ' * 4 + line for line in text.split('\n'))
366
367
class AttackTool(object):
    '''Command-line front end for an exploit class.

    Parses options, then either attacks every host in parallel
    (--attack-all) or runs the exploit directly against the given
    test hosts (-t HOST, repeatable).
    '''

    def __init__(self, exploit_class, **kwargs):
        '''
        exploit_class -- ExploitBase subclass to run.
        kwargs        -- extra arguments. They are forwarded to ParallelAttack in
                         --attack-all mode (which requires `subservice`) and to
                         the exploit constructor in test mode.
        '''
        self.exploit_class = exploit_class
        self.logger = logging.getLogger('AttackTool')
        self.kwargs = kwargs

    def init_logging(self):
        '''Log to stderr at INFO level with a compact, time-stamped format.'''
        streamHandler = logging.StreamHandler()
        streamHandler.setFormatter(Formatter(
            '%(asctime)s %(levelname)-7s %(name)-14s  %(message)s',
            '%H:%M:%S'))
        rootLogger = logging.getLogger()
        rootLogger.addHandler(streamHandler)
        rootLogger.setLevel(logging.INFO)

        # Silence per-request connection chatter from requests/urllib3.
        logging.getLogger('urllib3.connectionpool').setLevel(logging.WARNING)

    def parse_args(self, argv=None):
        '''Parse command-line options.

        argv -- argument list to parse; None (the default) means sys.argv[1:].
        '''
        parser = argparse.ArgumentParser(usage='%(prog)s [options...] [-a|-t HOST]')
        parser.add_argument('-p', '--port', metavar='PORT', type=int, required=True)
        parser.add_argument('-v', '--verbose',
                            help='increase logging level to DEBUG',
                            action='store_true', dest='debug')
        parser.add_argument('--threads', metavar='N', type=int,
                            default=ATTACK_THREADS,
                            help='number of threads for --attack-all')
        g = parser.add_mutually_exclusive_group(required=True)
        g.add_argument('-a', '--attack-all',
                       help='attack all hosts',
                       action='store_true', dest='attack_all')
        g.add_argument('-t', '--test', metavar='HOST', type=str,
                       help='attack a single host',
                       action='append', dest='targets')

        return parser.parse_args(argv)

    def run(self, argv=None):
        '''Entry point: set up logging, parse argv, then attack.

        Test mode calls exploit.run() directly rather than run_catch(), so
        exceptions propagate with their traceback.
        '''
        self.init_logging()

        opts = self.parse_args(argv)

        self.kwargs['port'] = opts.port

        if opts.debug:
            logging.getLogger().setLevel(logging.DEBUG)

        if opts.attack_all:
            attack = ParallelAttack(exploit_class=self.exploit_class,
                                    nthreads=opts.threads,
                                    **self.kwargs)
            attack.run()
            return

        for target in opts.targets:
            self.logger.info('test-attacking %s', target)
            kwargs = self.kwargs.copy()
            kwargs['host'] = target
            exploit = self.exploit_class(**kwargs)
            status = exploit.run()
            self.logger.info('status: %s', status)
429