xref: /haiku/src/tests/kits/net/service/testserver.py (revision a127b88ecbfab58f64944c98aa47722a18e363b2)
1#
2# Copyright 2020 Haiku, Inc. All rights reserved.
3# Distributed under the terms of the MIT License.
4#
5# Authors:
6#  Kyle Ambroff-Kao, kyle@ambroffkao.com
7#
8
9"""
10HTTP(S) server used for integration testing of ServicesKit.
11
12This service receives HTTP requests and just echos them back in the response.
13
14This is intentionally not using any fancy frameworks or libraries so as to not
15require any dependencies, and also to allow for adding endpoints to replicate
16behavior of other servers in the future.
17"""
18
19import abc
20import base64
21import gzip
22import hashlib
23import http.server
24import io
25import optparse
26import os
27import re
28import socket
29import ssl
30import subprocess
31import sys
32import tempfile
33import zlib
34
35
# Matches the Content-Type of a multipart/form-data request and captures the
# client-generated boundary, so it can be masked in the echoed response.
MULTIPART_FORM_BOUNDARY_RE = re.compile(
    r'^multipart/form-data; boundary=(----------------------------\d+)$')

# Matches /auth/<strategy>/<username>/<password> request paths. Requests to
# such paths must authenticate with the given strategy and credentials.
AUTH_PATH_RE = re.compile(
    r'^/auth/(?P<strategy>(basic|digest))'
    r'/(?P<username>[a-z0-9]+)/(?P<password>[a-z0-9]+)',
    flags=re.IGNORECASE)
42
43
class RequestHandler(http.server.BaseHTTPRequestHandler):
    """
    Any GET or POST request just gets echoed back to the sender. If the path
    ends with a numeric component like "/404" or "/500", then that value will
    be set as the status code in the response.

    Note that this isn't meant to replicate expected functionality exactly.
    Rather than implementing all of these status codes as expected per RFC,
    such as having an empty response body for 201 response, only the
    functionality that is required to handle requests from HttpTests is
    implemented.

    There can also be endpoints here that are intentionally non-compliant in
    order to exercise the HTTP client's behavior when a server is badly
    behaved.
    """
    def do_GET(self, write_response=True):
        """
        Echo the request back to the client.

        :param write_response: When False, only the status line and headers
                               are sent (used by do_HEAD).
        """
        authorized, extra_headers = self._authorize()
        if not authorized:
            return

        encoding, response_body = self._build_response_body()

        status_code = extract_desired_status_code_from_path(self.path, 200)
        self.send_response(status_code)
        if 300 <= status_code < 400:
            # Every redirect points back at the root path.
            self.send_header('Location', '/')

        if status_code == 204:
            # A 204 response must not contain a message body.
            write_response = False
        else:
            self.send_header('Content-Type', 'text/plain')
            self.send_header('Content-Length', str(len(response_body)))
            if encoding:
                self.send_header('Content-Encoding', encoding)

        for header_name, header_value in extra_headers:
            self.send_header(header_name, header_value)
        self.end_headers()

        if write_response:
            self.wfile.write(response_body)

    def do_HEAD(self):
        """Same as GET, but without writing the response body."""
        self.do_GET(False)

    def do_POST(self):
        """Echo the request (including its body) back to the client."""
        authorized, extra_headers = self._authorize()
        if not authorized:
            return

        encoding, response_body = self._build_response_body()
        self.send_response(
            extract_desired_status_code_from_path(self.path, 200))
        self.send_header('Content-Type', 'text/plain')
        self.send_header('Content-Length', str(len(response_body)))
        if encoding:
            self.send_header('Content-Encoding', encoding)
        for header_name, header_value in extra_headers:
            self.send_header(header_name, header_value)

        self.end_headers()
        self.wfile.write(response_body)

    def do_DELETE(self):
        self._not_supported()

    def do_PATCH(self):
        self._not_supported()

    def do_OPTIONS(self):
        self._not_supported()

    def send_response(self, code, message=None):
        """
        Send the status line plus fixed Server and Date headers.

        The Date header is a constant so responses stay deterministic and
        can be compared byte-for-byte by the tests.
        """
        self.log_request(code)
        self.send_response_only(code, message)
        self.send_header('Server', 'Test HTTP Server for Haiku')
        self.send_header('Date', 'Sun, 09 Feb 2020 19:32:42 GMT')

    def _build_response_body(self):
        """
        Build the echo body describing this request.

        The body lists the request path, every request header (with ports,
        schemes and multipart boundaries masked) and, if present, the request
        body. It is compressed with gzip or deflate when the client's
        Accept-Encoding lists one of them (gzip is preferred).

        :return: ``(encoding, body)`` where ``encoding`` is the value for the
                 Content-Encoding header (``'gzip'``, ``'deflate'`` or None)
                 and ``body`` is the resulting bytes.
        """
        # The post-body may be multi-part/form-data, in which case the client
        # will have generated some random identifier to identify the boundary.
        # If that's the case, we'll replace it here in order to allow the test
        # client to validate the response data without needing to predict the
        # boundary identifier. This makes the response body deterministic even
        # though the boundary will change with every request, and lets the
        # tests in HttpTests hard-code the entire expected response body for
        # validation.
        boundary_id_value = None

        supported_encodings = [
            e.strip()
            for e in self.headers.get('Accept-Encoding', '').split(',')
            if e.strip()]
        if 'gzip' in supported_encodings:
            encoding = 'gzip'
            output_stream = GzipResponseBodyBuilder()
        elif 'deflate' in supported_encodings:
            encoding = 'deflate'
            output_stream = DeflateResponseBodyBuilder()
        else:
            encoding = None
            output_stream = RawResponseBodyBuilder()

        output_stream.write(
            'Path: {}\r\n\r\n'.format(self.path).encode('utf-8'))
        output_stream.write(b'Headers:\r\n')
        output_stream.write(b'--------\r\n')
        # items() yields every (name, value) pair exactly once, in the order
        # received. Iterating the names and calling get_all() would echo a
        # repeated header once per occurrence of its name.
        for header, header_value in self.headers.items():
            # Header field names are case-insensitive.
            lowered_name = header.lower()
            if lowered_name in ('host', 'referer', 'x-forwarded-for'):
                # The server port can change between runs which will change
                # the size and contents of the response body. To make tests
                # that verify the contents of the response body easier the
                # server port will be stripped from these headers when
                # echoed to the response body.
                header_value = re.sub(r':[0-9]+', ':PORT', header_value)

                # The scheme will also be in this header value, and we want
                # to return the same regardless of whether http:// or
                # https:// was used.
                header_value = re.sub(
                    r'https?://',
                    'SCHEME://',
                    header_value)
            elif lowered_name == 'content-type':
                match = MULTIPART_FORM_BOUNDARY_RE.match(header_value)
                if match is not None:
                    boundary_id_value = match.group(1)
                    header_value = header_value.replace(
                        boundary_id_value,
                        '<<BOUNDARY-ID>>')
            output_stream.write(
                '{}: {}\r\n'.format(header, header_value).encode('utf-8'))

        content_length = int(self.headers.get('Content-Length', 0))
        if content_length > 0:
            output_stream.write(b'\r\n')
            output_stream.write(b'Request body:\r\n')
            output_stream.write(b'-------------\r\n')

            # Work on raw bytes so a body that is not valid UTF-8 is echoed
            # instead of raising UnicodeDecodeError. For UTF-8 bodies the
            # result is identical to decoding, replacing and re-encoding.
            body_bytes = self.rfile.read(content_length)
            if boundary_id_value:
                body_bytes = body_bytes.replace(
                    boundary_id_value.encode('utf-8'), b'<<BOUNDARY-ID>>')

            output_stream.write(body_bytes)
            output_stream.write(b'\r\n')

        return encoding, output_stream.get_bytes()

    def _not_supported(self):
        """Reply 405 with a short plain-text explanation."""
        self.send_response(405, '{} not supported'.format(self.command))
        self.end_headers()
        self.wfile.write(
            '{} not supported\r\n'.format(self.command).encode('utf-8'))

    def _authorize(self):
        """
        Authorize the request.

        Only paths of the form
        ``/auth/<strategy>/<expected-username>/<expected-password>`` require
        authorization. Every other path is always authorized.

        :return: ``(authorized, extra_headers)``. When ``authorized`` is
                 False, a 401 response has already been sent to the client.
                 ``extra_headers`` is a list of (name, value) pairs to add
                 to a successful response.
        """
        match = AUTH_PATH_RE.match(self.path)
        if match is None:
            return True, []

        # AUTH_PATH_RE is case-insensitive, so normalize the strategy before
        # dispatching. Otherwise "/auth/Basic/..." would fall into the
        # NotImplementedError branch.
        strategy = match.group('strategy').lower()
        expected_username = match.group('username')
        expected_password = match.group('password')

        if strategy == 'basic':
            return self._handle_basic_auth(
                expected_username,
                expected_password)
        elif strategy == 'digest':
            return self._handle_digest_auth(
                expected_username,
                expected_password)
        else:
            raise NotImplementedError(
                'Unimplemented authorization strategy ' + strategy)

    def _read_basic_credentials(self):
        """
        Return ``(username, password)`` from a Basic Authorization header.

        Returns ``(None, None)`` if the header is missing, uses another
        scheme, or is malformed.
        """
        authorization = self.headers.get('Authorization', None)
        if authorization is None:
            return None, None

        parts = authorization.split(None, 1)
        # Authentication scheme names are case-insensitive.
        if len(parts) != 2 or parts[0].lower() != 'basic':
            return None, None

        try:
            decoded = base64.decodebytes(
                parts[1].strip().encode('utf-8')).decode('utf-8')
        except ValueError:
            # binascii.Error and UnicodeDecodeError both derive from
            # ValueError.
            return None, None

        # The user-id cannot contain ':' but the password can (RFC 7617),
        # so split on the first colon only.
        username, separator, password = decoded.partition(':')
        if not separator:
            return None, None
        return username, password

    def _handle_basic_auth(self, expected_username, expected_password):
        """
        Check HTTP Basic credentials against the expected values.

        :return: ``(authorized, extra_headers)``. A 401 challenge has been
                 sent when ``authorized`` is False.
        """
        username, password = self._read_basic_credentials()

        if username != expected_username or password != expected_password:
            self.send_response(401, 'Not authorized')
            self.send_header('Www-Authenticate', 'Basic realm="Fake Realm"')
            self.end_headers()
            return False, []

        return True, [('Www-Authenticate', 'Basic realm="Fake Realm"')]

    def _handle_digest_auth(self, expected_username, expected_password):
        """
        Implement enough of the digest auth RFC to make tests pass.

        :return: ``(authorized, extra_headers)``. A 401 challenge has been
                 sent when ``authorized`` is False.
        """
        # Note: These values will always be the same because we want the
        # response to be deterministic for testing purposes.
        NONCE = 'f3a95f20879dd891a5544bf96a3e5518'
        OPAQUE = 'f0bb55f1221a51b6d38117c331611799'

        extra_headers = []
        credentials = None
        authorization = self.headers.get('Authorization', None)
        if authorization is not None:
            parts = authorization.split(None, 1)
            # Authentication scheme names are case-insensitive.
            if parts and parts[0].lower() == 'digest':
                try:
                    credentials = parse_kv_pair_header(
                        parts[1] if len(parts) > 1 else '')
                except ValueError:
                    # A malformed field list counts as missing credentials.
                    credentials = None

        expected_response_hash = None
        if credentials and credentials.get('username') == expected_username:
            expected_response_hash = compute_digest_challenge_response_hash(
                self.command,
                self.path,
                '',
                credentials,
                expected_password)

        # A hash of None means the credentials could not be verified. It must
        # never match a missing 'response' field.
        if expected_response_hash is None \
                or expected_response_hash != credentials.get('response'):
            self.send_response(401, 'Not authorized')
            self.send_header(
                'Www-Authenticate',
                'Digest realm="user@shredder",'
                ' nonce="{}",'
                ' qop="auth",'
                ' opaque={},'
                ' algorithm=MD5,'
                ' stale=FALSE'.format(NONCE, OPAQUE))
            self.send_header('Set-Cookie', 'stale_after=never; Path=/')
            self.send_header('Set-Cookie', 'fake=fake_value; Path=/')
            self.end_headers()
            return False, extra_headers

        return True, extra_headers
300
301
class ResponseBodyBuilder(abc.ABC):
    """
    Interface for objects that accumulate a response body.

    Subclasses receive the body in pieces through write() and produce the
    final (possibly compressed) bytes with get_bytes(). Inheriting from
    abc.ABC makes the abstract methods below actually enforced, so a
    subclass that forgets one cannot be instantiated.
    """

    @abc.abstractmethod
    def write(self, bytes):
        """Append ``bytes`` to the body. (Parameter name kept for API
        compatibility even though it shadows the builtin.)"""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_bytes(self):
        """Return the finished body as a bytes object."""
        raise NotImplementedError()
312
313
class RawResponseBodyBuilder(ResponseBodyBuilder):
    """Collect the response body verbatim, without any compression."""

    def __init__(self):
        # In-memory sink for everything passed to write().
        self.buf = io.BytesIO()

    def write(self, bytes):
        """Append ``bytes`` to the body."""
        sink = self.buf
        sink.write(bytes)

    def get_bytes(self):
        """Return everything written so far as a single bytes object."""
        collected = self.buf.getvalue()
        return collected
323
324
class GzipResponseBodyBuilder(ResponseBodyBuilder):
    """
    Stream the response body through a gzip compressor (level 4).

    The gzip header's modification time is pinned to 0. Otherwise the
    header would embed the current time, and the same request would produce
    different response bytes on every run, even though the server sends a
    fixed Date header to keep responses deterministic.
    """

    def __init__(self):
        self.buf = io.BytesIO()
        self.compressor = gzip.GzipFile(
            mode='wb',
            compresslevel=4,
            fileobj=self.buf,
            mtime=0)

    def write(self, bytes):
        """Compress ``bytes`` into the body."""
        self.compressor.write(bytes)

    def get_bytes(self):
        """
        Finish the gzip stream and return the compressed body.

        The compressor is closed here, so write() must not be called
        afterwards.
        """
        self.compressor.close()
        return self.buf.getvalue()
339
340
class DeflateResponseBodyBuilder(ResponseBodyBuilder):
    """
    Buffer the whole body and compress it with zlib in a single step when
    the result is requested.
    """

    def __init__(self):
        # Uncompressed data is kept here until get_bytes() is called.
        self.raw = RawResponseBodyBuilder()

    def write(self, bytes):
        """Buffer ``bytes`` uncompressed."""
        self.raw.write(bytes)

    def get_bytes(self):
        """Return the buffered body compressed with zlib.compress()."""
        uncompressed = self.raw.get_bytes()
        compressed = zlib.compress(uncompressed)
        return compressed
350
351
def extract_desired_status_code_from_path(path, default=200):
    """
    Return the status code encoded as the last path segment, or ``default``.

    ``/foo/404`` yields 404. A query string or fragment is ignored, so
    ``/404?x=1`` also yields 404. If the last segment is not made up
    entirely of ASCII digits (e.g. ``/``, ``/foo``, ``/404/``), ``default``
    is returned.

    :param path: Request target as received, e.g. ``self.path``.
    :param default: Value returned when no numeric segment is present.
    :return: The status code as an int.
    """
    # URL paths always use '/', so split them by hand rather than with the
    # platform-dependent os.path module.
    path_only = path.split('?', 1)[0].split('#', 1)[0]
    last_segment = path_only.rsplit('/', 1)[-1]
    # Accept only plain digits; int() alone would also accept signs and
    # surrounding whitespace.
    if re.fullmatch(r'[0-9]+', last_segment):
        return int(last_segment)
    return default
360
361
def generate_self_signed_tls_cert(common_name, cert_path, key_path):
    """
    Create a throwaway self-signed certificate and private key using the
    ``openssl`` command line tool.

    The certificate uses a freshly generated 4096-bit RSA key, is valid for
    one day, and the key is written without a passphrase (``-nodes``).

    :param common_name: Value used for the certificate subject's CN.
    :param cert_path: Where the certificate is written.
    :param key_path: Where the private key is written.
    :raises subprocess.CalledProcessError: If openssl exits with an error.
    """
    subject = '/CN={}'.format(common_name)
    command = ['openssl', 'req', '-x509', '-nodes']
    command += ['-subj', subject]
    command += ['-newkey', 'rsa:4096']
    command += ['-keyout', key_path, '-out', cert_path]
    command += ['-days', '1']
    subprocess.check_call(command)
374
375
def compute_digest_challenge_response_hash(
        request_method,
        request_uri,
        request_body,
        credentials,
        expected_password):
    """
    Compute the expected Digest ``response`` value for a request.

    Supports the RFC 2069 form (no ``qop``) and the RFC 2617 ``auth`` and
    ``auth-int`` forms with MD5, SHA-256 or SHA-512. This isn't an attempt to
    be complete, just enough for the integration tests in HttpTests to work.

    :param request_method: HTTP method of the request, e.g. ``'GET'``.
    :param request_uri: Request URI used in HA2.
    :param request_body: Entity body (str or bytes). Only used for
                         ``qop=auth-int``, where its hash is part of HA2.
    :param credentials: Map of values parsed from the Authorization header
                        from the client.
    :param expected_password: The known correct password of the user
                              attempting to authenticate.
    :return: None if a hash cannot be produced (unsupported algorithm or
             qop, or missing username/realm), otherwise the hex digest.
    """
    hash_functions = {
        'MD5': hashlib.md5,
        'SHA-256': hashlib.sha256,
        'SHA-512': hashlib.sha512,
    }
    # RFC 2617 section 3.2.2: if no algorithm is given, MD5 is assumed.
    algorithm = credentials.get('algorithm') or 'MD5'
    hashfunc = hash_functions.get(algorithm.upper())
    if hashfunc is None:
        return None

    def hex_hash(*parts):
        """Hash the ':'-joined parts and return the hex digest."""
        return hashfunc(':'.join(parts).encode('utf-8')).hexdigest()

    username = credentials.get('username')
    realm = credentials.get('realm')
    if username is None or realm is None:
        return None

    ha1 = hex_hash(username, realm, expected_password)

    qop = credentials.get('qop')
    if qop is None or qop == 'auth':
        ha2 = hex_hash(request_method, request_uri)
    elif qop == 'auth-int':
        # RFC 2617 section 3.2.2.3: HA2 covers H(entity-body), not the body.
        if isinstance(request_body, str):
            request_body = request_body.encode('utf-8')
        body_hash = hashfunc(request_body).hexdigest()
        ha2 = hex_hash(request_method, request_uri, body_hash)
    else:
        return None

    nonce = credentials.get('nonce', '')
    if qop is None:
        # RFC 2069 compatibility form.
        return hex_hash(ha1, nonce, ha2)
    return hex_hash(
        ha1,
        nonce,
        credentials.get('nc', ''),
        credentials.get('cnonce', ''),
        qop,
        ha2)
441
442
def parse_kv_pair_header(header_value, sep=','):
    """
    Parse a ``key=value`` list header value, such as the fields of a Digest
    Authorization header, into a dict.

    Whitespace around keys and values is stripped, and double quotes are
    stripped from values. Only the first ``=`` in each entry separates key
    from value, so values may themselves contain ``=`` (e.g. padded nonces or
    URIs with a query). Empty entries, such as the one produced by a trailing
    separator, are ignored.

    Limitation: ``sep`` must not appear inside quoted values.

    :param header_value: Raw header value to parse.
    :param sep: Separator between entries.
    :return: Dict mapping keys to unquoted values.
    :raises ValueError: If a non-empty entry contains no ``=``.
    """
    d = {}
    for kvpair in header_value.split(sep):
        kvpair = kvpair.strip()
        if not kvpair:
            continue
        key, equals, value = kvpair.partition('=')
        if not equals:
            raise ValueError('Malformed key/value pair: {!r}'.format(kvpair))
        d[key.strip()] = value.strip().strip('"')
    return d
449
450
def main():
    """
    Start the echo server and serve requests until interrupted.

    The server either binds ``--bind-addr``/``--port`` itself (port 0, i.e.
    OS-chosen, when no port is given) or serves on an already created socket
    passed with ``--fd``. With ``--use-tls`` the listening socket is wrapped
    with a freshly generated self-signed certificate. Once ready, the line
    "Test server listening on port <port>" is printed to stderr.
    """
    options = parse_args(sys.argv)

    bind_addr = (
        options.bind_addr,
        0 if options.port is None else options.port)

    server = http.server.HTTPServer(
        bind_addr,
        RequestHandler,
        bind_and_activate=False)

    try:
        # Compare against None so that FD 0 is not mistaken for "not given".
        if options.server_socket_fd is not None:
            # Serve on the socket handed over by the caller instead of the
            # fresh, unbound one HTTPServer created (which is closed here so
            # it does not leak).
            server.socket.close()
            server.socket = socket.fromfd(
                options.server_socket_fd,
                socket.AF_INET,
                socket.SOCK_STREAM)
            server.server_port = server.socket.getsockname()[1]
        else:
            # server_bind() updates server.server_port with the actual port,
            # which matters when port 0 was requested.
            server.server_bind()
            server.server_activate()

        if options.use_tls:
            with tempfile.TemporaryDirectory() as temp_cert_dir:
                common_name = (
                    options.bind_addr + ':' + str(server.server_port))
                cert_file = os.path.join(temp_cert_dir, 'cert.pem')
                key_file = os.path.join(temp_cert_dir, 'key.pem')
                generate_self_signed_tls_cert(
                    common_name,
                    cert_file,
                    key_file)
                # ssl.wrap_socket() was removed in Python 3.12; a server-side
                # SSLContext is the supported replacement. The certificate is
                # loaded into memory here, so the temporary files may be
                # deleted once this block exits.
                context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
                context.load_cert_chain(certfile=cert_file, keyfile=key_file)
                server.socket = context.wrap_socket(
                    server.socket,
                    server_side=True,
                    do_handshake_on_connect=False)

        print(
            'Test server listening on port',
            server.server_port,
            file=sys.stderr)
        server.serve_forever(0.01)
    except KeyboardInterrupt:
        pass
    finally:
        server.server_close()
504
505
def parse_args(argv):
    """
    Parse the server's command line.

    :param argv: Full argument vector including the program name at index 0
                 (as in ``sys.argv``).
    :return: The optparse options object (``bind_addr``, ``use_tls``,
             ``port``, ``server_socket_fd``).
    Exits via ``parser.error`` if positional arguments are given.
    """
    parser = optparse.OptionParser(
        usage='Usage: %prog [OPTIONS]',
        description=__doc__)
    parser.add_option('--bind-addr', dest='bind_addr', default='127.0.0.1',
                      help='By default only bind to loopback')
    parser.add_option('--use-tls', dest='use_tls', action='store_true',
                      default=False,
                      help=('If set, a self-signed TLS certificate, key and CA'
                            ' will be generated for testing purposes.'))
    parser.add_option('--port', dest='port', type='int', default=None,
                      help='If not specified a random port will be used.')
    parser.add_option('--fd', dest='server_socket_fd', type='int',
                      default=None,
                      help=('A socket FD to use for accept() instead of'
                            ' binding a new one.'))

    parsed_options, positional = parser.parse_args(argv)
    # positional[0] is the program name itself; anything after it is junk.
    unexpected = positional[1:]
    if unexpected:
        parser.error(
            'Unexpected arguments: {}'.format(', '.join(unexpected)))
    return parsed_options
538
539
if __name__ == '__main__':
    # Start the server only when executed as a script, not when imported.
    main()
542