xref: /haiku/src/tests/kits/net/service/testserver.py (revision 21258e2674226d6aa732321b6f8494841895af5f)
1#
2# Copyright 2020 Haiku, Inc. All rights reserved.
3# Distributed under the terms of the MIT License.
4#
5# Authors:
6#  Kyle Ambroff-Kao, kyle@ambroffkao.com
7#
8
9"""
10HTTP(S) server used for integration testing of ServicesKit.
11
This service receives HTTP requests and just echoes them back in the response.
13
14This is intentionally not using any fancy frameworks or libraries so as to not
15require any dependencies, and also to allow for adding endpoints to replicate
16behavior of other servers in the future.
17"""
18
19import abc
20import base64
21import gzip
22import hashlib
23import http.server
24import io
25import optparse
26import os
27import re
28import socket
29import ssl
30import subprocess
31import sys
32import tempfile
33import zlib
34
35
# Matches a multipart/form-data Content-Type header value and captures the
# boundary identifier (a fixed run of dashes followed by digits) so that it
# can be replaced with a stable placeholder in the echoed response.
MULTIPART_FORM_BOUNDARY_RE = re.compile(
    r'^multipart/form-data; boundary=(----------------------------\d+)$')
# Matches request paths that start with
# /auth/<strategy>/<username>/<password>, where <strategy> is "basic" or
# "digest" and the credentials are alphanumeric. Matching is case-insensitive.
AUTH_PATH_RE = re.compile(
    r'^/auth/(?P<strategy>(basic|digest))'
    '/(?P<username>[a-z0-9]+)/(?P<password>[a-z0-9]+)',
    re.IGNORECASE)
42
43
class RequestHandler(http.server.BaseHTTPRequestHandler):
    """
    Any GET or POST request just gets echoed back to the sender. If the path
    ends with a numeric component like "/404" or "/500", then that value will
    be set as the status code in the response.

    Note that this isn't meant to replicate expected functionality exactly.
    Rather than implementing all of these status codes as expected per RFC,
    such as having an empty response body for 201 response, only the
    functionality that is required to handle requests from HttpTests is
    implemented.

    There can also be endpoints here that are intentionally non-compliant in
    order to exercise the HTTP client's behavior when a server is badly
    behaved.
    """
    def do_GET(self, write_response=True):
        """
        Echo the request back to the client.

        :param write_response: When False only the status line and headers
                               are sent and the body is omitted.
        """
        self._send_echo_response(write_response)

    def do_HEAD(self):
        """Respond like GET, but without writing the response body."""
        self.do_GET(False)

    def do_POST(self):
        """Echo the request, including its body, back to the client."""
        self._send_echo_response(True)

    def do_DELETE(self):
        self._not_supported()

    def do_PATCH(self):
        self._not_supported()

    def do_OPTIONS(self):
        self._not_supported()

    def send_response(self, code, message=None):
        """
        Log the request and send the status line followed by fixed Server
        and Date headers, so that the response headers are identical from
        one run to the next.
        """
        self.log_request(code)
        self.send_response_only(code, message)
        self.send_header('Server', 'Test HTTP Server for Haiku')
        self.send_header('Date', 'Sun, 09 Feb 2020 19:32:42 GMT')

    def _send_echo_response(self, write_body):
        """
        Authorize the request and, if authorized, send the echo response.

        :param write_body: Whether the response body is written after the
                           headers.
        """
        authorized, extra_headers = self._authorize()
        if not authorized:
            # The 401 response has already been sent by the auth handler.
            return

        encoding, response_body = self._build_response_body()

        self.send_response(
            extract_desired_status_code_from_path(self.path, 200))
        self.send_header('Content-Type', 'text/plain')
        self.send_header('Content-Length', str(len(response_body)))
        if encoding:
            self.send_header('Content-Encoding', encoding)
        for header_name, header_value in extra_headers:
            self.send_header(header_name, header_value)
        self.end_headers()

        if write_body:
            self.wfile.write(response_body)

    def _build_response_body(self):
        """
        Build the echo body describing the request path, headers and body.

        :return: Tuple of (content encoding name or None, body bytes). The
                 body is compressed when the client's Accept-Encoding header
                 lists gzip or deflate.
        """
        # The post-body may be multi-part/form-data, in which case the client
        # will have generated some random identifier to identify the boundary.
        # If that's the case, we'll replace it here in order to allow the test
        # client to validate the response data without needing to predict the
        # boundary identifier. This makes the response body deterministic even
        # though the boundary will change with every request, and lets the
        # tests in HttpTests hard-code the entire expected response body for
        # validation.
        boundary_id_value = None

        supported_encodings = [
            e.strip()
            for e in self.headers.get('Accept-Encoding', '').split(',')
            if e.strip()]
        if 'gzip' in supported_encodings:
            encoding = 'gzip'
            output_stream = GzipResponseBodyBuilder()
        elif 'deflate' in supported_encodings:
            encoding = 'deflate'
            output_stream = DeflateResponseBodyBuilder()
        else:
            encoding = None
            output_stream = RawResponseBodyBuilder()

        output_stream.write(
            'Path: {}\r\n\r\n'.format(self.path).encode('utf-8'))
        output_stream.write(b'Headers:\r\n')
        output_stream.write(b'--------\r\n')
        # items() yields every header field exactly once, in the order
        # received, so a header that is sent several times is echoed once per
        # occurrence.
        for header, header_value in self.headers.items():
            # Header field names are case-insensitive.
            header_key = header.lower()
            if header_key in ('host', 'referer', 'x-forwarded-for'):
                # The server port can change between runs which will change
                # the size and contents of the response body. To make tests
                # that verify the contents of the response body easier the
                # server port will be stripped from these headers when
                # echoed to the response body.
                header_value = re.sub(r':[0-9]+', ':PORT', header_value)

                # The scheme will also be in this header value, and we want
                # to return the same regardless of whether http:// or
                # https:// was used.
                header_value = re.sub(
                    r'https?://',
                    'SCHEME://',
                    header_value)
            if header_key == 'content-type':
                match = MULTIPART_FORM_BOUNDARY_RE.match(header_value)
                if match is not None:
                    boundary_id_value = match.group(1)
                    header_value = header_value.replace(
                        boundary_id_value,
                        '<<BOUNDARY-ID>>')
            output_stream.write(
                '{}: {}\r\n'.format(header, header_value).encode('utf-8'))

        try:
            content_length = int(self.headers.get('Content-Length', 0))
        except ValueError:
            # An unparsable Content-Length is treated as "no body".
            content_length = 0
        if content_length > 0:
            output_stream.write(b'\r\n')
            output_stream.write(b'Request body:\r\n')
            output_stream.write(b'-------------\r\n')

            # The body is handled as raw bytes so that payloads which are
            # not valid UTF-8 can be echoed as well.
            body_bytes = self.rfile.read(content_length)
            if boundary_id_value:
                body_bytes = body_bytes.replace(
                    boundary_id_value.encode('utf-8'),
                    b'<<BOUNDARY-ID>>')

            output_stream.write(body_bytes)
            output_stream.write(b'\r\n')

        return encoding, output_stream.get_bytes()

    def _not_supported(self):
        """Send a 405 response naming the unsupported request method."""
        self.send_response(405, '{} not supported'.format(self.command))
        self.end_headers()
        self.wfile.write(
            '{} not supported\r\n'.format(self.command).encode('utf-8'))

    def _authorize(self):
        """
        Authorizes the request.

        :return: Tuple of (authorized, extra_headers). When authorized is
                 False the 401 response has already been sent to the client
                 and the caller must not send anything else. When it is True,
                 extra_headers is a list of (name, value) pairs to add to the
                 response.
        """
        # We only authorize paths like
        # /auth/<strategy>/<expected-username>/<expected-password>
        match = AUTH_PATH_RE.match(self.path)
        if match is None:
            return True, []

        # The path regex is case-insensitive, so normalize the strategy
        # before comparing it.
        strategy = match.group('strategy').lower()
        expected_username = match.group('username')
        expected_password = match.group('password')

        if strategy == 'basic':
            return self._handle_basic_auth(
                expected_username,
                expected_password)
        elif strategy == 'digest':
            return self._handle_digest_auth(
                expected_username,
                expected_password)
        else:
            raise NotImplementedError(
                'Unimplemented authorization strategy ' + strategy)

    @staticmethod
    def _parse_basic_credentials(authorization):
        """
        Extract (username, password) from a Basic Authorization header value.

        :return: (None, None) when the header is missing, uses another
                 scheme, or cannot be decoded.
        """
        if not authorization:
            return None, None

        parts = authorization.split()
        if len(parts) != 2 or parts[0].lower() != 'basic':
            return None, None

        try:
            decoded = base64.decodebytes(parts[1].encode('utf-8'))
            text = decoded.decode('utf-8')
        except ValueError:
            # Covers invalid base64 (binascii.Error) and invalid UTF-8
            # (UnicodeDecodeError), both of which subclass ValueError.
            return None, None

        # Only the first colon separates the username from the password.
        username, separator, password = text.partition(':')
        if not separator:
            return None, None
        return username, password

    def _handle_basic_auth(self, expected_username, expected_password):
        """
        Check the Basic credentials sent by the client.

        :return: (True, extra_headers) when the credentials match, otherwise
                 (False, []) after a 401 response has been sent.
        """
        username, password = self._parse_basic_credentials(
            self.headers.get('Authorization', None))

        if username != expected_username or password != expected_password:
            self.send_response(401, 'Not authorized')
            self.send_header('Www-Authenticate', 'Basic realm="Fake Realm"')
            self.end_headers()
            return False, []

        return True, [('Www-Authenticate', 'Basic realm="Fake Realm"')]

    def _handle_digest_auth(self, expected_username, expected_password):
        """
        Implement enough of the digest auth RFC to make tests pass.

        :return: (True, extra_headers) when the client's response hash
                 matches the expected one, otherwise (False, extra_headers)
                 after a 401 challenge has been sent.
        """
        # Note: These values will always be the same because we want the
        # response to be deterministic for testing purposes.
        NONCE = 'f3a95f20879dd891a5544bf96a3e5518'
        OPAQUE = 'f0bb55f1221a51b6d38117c331611799'

        extra_headers = []
        authorization = self.headers.get('Authorization', None)
        credentials = None
        if authorization is not None:
            parts = authorization.split(maxsplit=1)
            if len(parts) == 2 and parts[0].lower() == 'digest':
                try:
                    credentials = parse_kv_pair_header(parts[1])
                except ValueError:
                    # Malformed field list: treat as missing credentials.
                    credentials = None

        expected_response_hash = None
        if credentials \
                and credentials.get('username') == expected_username \
                and credentials.get('realm') is not None:
            expected_response_hash = compute_digest_challenge_response_hash(
                self.command,
                self.path,
                '',
                credentials,
                expected_password)

        # A hash that could not be computed must never count as a match,
        # even if the client did not send a "response" field either.
        if expected_response_hash is None \
                or expected_response_hash != credentials.get('response'):
            self.send_response(401, 'Not authorized')
            # NOTE(review): RFC 7616 specifies opaque as a quoted string; it
            # is emitted unquoted here and left unchanged to keep the wire
            # format stable.
            self.send_header(
                'Www-Authenticate',
                'Digest realm="user@shredder",'
                ' nonce="{}",'
                ' qop="auth",'
                ' opaque={},'
                ' algorithm=MD5,'
                ' stale=FALSE'.format(NONCE, OPAQUE))
            self.send_header('Set-Cookie', 'stale_after=never; Path=/')
            self.send_header('Set-Cookie', 'fake=fake_value; Path=/')
            self.end_headers()
            return False, extra_headers

        return True, extra_headers
293
294
class ResponseBodyBuilder(abc.ABC):
    """
    Interface for objects that accumulate response body bytes.

    Deriving from abc.ABC makes the abstract methods enforced: a subclass
    that does not implement both of them cannot be instantiated.
    """

    @abc.abstractmethod
    def write(self, bytes):
        """Append the given bytes to the response body."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_bytes(self):
        """Return the complete response body as bytes."""
        raise NotImplementedError()
305
306
class RawResponseBodyBuilder(ResponseBodyBuilder):
    """Collects the response body in memory without transforming it."""

    def __init__(self):
        # In-memory buffer holding everything written so far.
        self.buf = io.BytesIO()

    def write(self, bytes):
        """Append the given bytes to the buffer."""
        buffer = self.buf
        buffer.write(bytes)

    def get_bytes(self):
        """Return everything written so far as a single bytes object."""
        contents = self.buf.getvalue()
        return contents
316
317
class GzipResponseBodyBuilder(ResponseBodyBuilder):
    """Collects the response body and gzip-compresses it as it is written."""

    def __init__(self):
        # Compressed output lands in this in-memory buffer.
        self.buf = io.BytesIO()
        gzip_options = {
            'mode': 'wb',
            'compresslevel': 4,
            'fileobj': self.buf,
        }
        self.compressor = gzip.GzipFile(**gzip_options)

    def write(self, bytes):
        """Feed the given bytes through the gzip compressor."""
        compressor = self.compressor
        compressor.write(bytes)

    def get_bytes(self):
        """Finish the gzip stream and return the compressed bytes."""
        # Closing the compressor flushes the remaining data and the gzip
        # trailer into the buffer.
        self.compressor.close()
        compressed = self.buf.getvalue()
        return compressed
332
333
class DeflateResponseBodyBuilder(ResponseBodyBuilder):
    """Collects the response body and zlib-compresses it on retrieval."""

    def __init__(self):
        # The uncompressed body is gathered here first.
        self.raw = RawResponseBodyBuilder()

    def write(self, bytes):
        """Append the given bytes to the uncompressed body."""
        uncompressed = self.raw
        uncompressed.write(bytes)

    def get_bytes(self):
        """Return the whole body compressed with zlib.compress."""
        plain_body = self.raw.get_bytes()
        return zlib.compress(plain_body)
343
344
def extract_desired_status_code_from_path(path, default=200):
    """
    Return the HTTP status code requested by the last component of a path.

    The text after the final "/" is used when it consists only of ASCII
    digits and lies in the range 100-599, e.g. "/404" or "/foo/500".

    :param path: The request path.
    :param default: Value returned when the last path component is not a
                    usable status code.
    :return: The status code as an int, or default.
    """
    # Request paths always use "/" as the separator, independent of the
    # platform's filesystem conventions.
    last_component = path.rsplit('/', 1)[-1]

    # int() alone would also accept signs, underscores and surrounding
    # whitespace, so insist on plain ASCII digits.
    if not last_component \
            or any(ch not in '0123456789' for ch in last_component):
        return default

    status_code = int(last_component)
    if not 100 <= status_code <= 599:
        return default
    return status_code
353
354
def generate_self_signed_tls_cert(common_name, cert_path, key_path):
    """
    Run the openssl command line tool to create a self-signed certificate.

    :param common_name: Value placed in the CN field of the subject.
    :param cert_path: File the certificate is written to.
    :param key_path: File the unencrypted private key is written to.
    :raises subprocess.CalledProcessError: If openssl exits with a non-zero
                                           status.
    """
    subject = '/CN={}'.format(common_name)
    command = ['openssl', 'req', '-x509', '-nodes']
    command += ['-subj', subject]
    command += ['-newkey', 'rsa:4096']
    command += ['-keyout', key_path]
    command += ['-out', cert_path]
    command += ['-days', '1']
    subprocess.check_call(command)
367
368
def compute_digest_challenge_response_hash(
        request_method,
        request_uri,
        request_body,
        credentials,
        expected_password):
    """
    Compute the digest auth response hash the client is expected to send.

    This covers the original scheme without "qop" as well as the "auth" and
    "auth-int" qop values. It isn't an attempt to be perfect, just enough for
    basic integration tests in HttpTests to work.

    :param request_method: HTTP method of the request, e.g. "GET".
    :param request_uri: The request URI used in the hash.
    :param request_body: The request body (str or bytes). Only used when qop
                         is "auth-int".
    :param credentials: Map of values parsed from the Authorization header
                        from the client.
    :param expected_password: The known correct password of the user
                              attempting to authenticate.
    :return: None if a hash cannot be produced (unsupported algorithm or
             qop, or missing username/realm/password), otherwise the hash as
             a hex string.
    """
    hash_functions = {
        'MD5': hashlib.md5,
        'SHA-256': hashlib.sha256,
        'SHA-512': hashlib.sha512,
    }
    # When the client does not name an algorithm, MD5 is assumed.
    algorithm = credentials.get('algorithm') or 'MD5'
    hashfunc = hash_functions.get(algorithm.upper())
    if hashfunc is None:
        return None

    def digest(*parts):
        # Hash the colon-joined parts and return the hex digest.
        return hashfunc(':'.join(parts).encode('utf-8')).hexdigest()

    realm = credentials.get('realm')
    username = credentials.get('username')
    if realm is None or username is None or expected_password is None:
        return None

    ha1 = digest(username, realm, expected_password)

    qop = credentials.get('qop')
    if qop is None or qop == 'auth':
        ha2 = digest(request_method, request_uri)
    elif qop == 'auth-int':
        # With auth-int the hash of the entity body, not the body itself,
        # is part of HA2.
        body = request_body or b''
        if isinstance(body, str):
            body = body.encode('utf-8')
        ha2 = digest(request_method, request_uri, hashfunc(body).hexdigest())
    else:
        return None

    nonce = credentials.get('nonce', '')
    if qop is None:
        return digest(ha1, nonce, ha2)

    return digest(
        ha1,
        nonce,
        credentials.get('nc', ''),
        credentials.get('cnonce', ''),
        qop,
        ha2)
434
435
def _split_outside_quotes(text, sep):
    # Split text on sep, ignoring separators inside double-quoted sections.
    if not sep:
        raise ValueError('empty separator')

    parts = []
    current = []
    in_quotes = False
    index = 0
    while index < len(text):
        char = text[index]
        if char == '"':
            in_quotes = not in_quotes
        if not in_quotes and text.startswith(sep, index):
            parts.append(''.join(current))
            current = []
            index += len(sep)
            continue
        current.append(char)
        index += 1
    parts.append(''.join(current))
    return parts


def parse_kv_pair_header(header_value, sep=','):
    """
    Parse a header value made of key=value pairs into a dict.

    Each pair is split at its first "=", so values may themselves contain
    "=". Separators inside double-quoted values do not end the pair.
    Surrounding whitespace and double quotes are stripped from values, and
    empty segments (for example after a trailing separator) are skipped.

    :param header_value: The text to parse, e.g. 'a=1, b="two"'.
    :param sep: The string separating one pair from the next.
    :return: Dict mapping each key to its value.
    :raises ValueError: If a non-empty segment contains no "=", or sep is
                        empty.
    """
    d = {}
    for kvpair in _split_outside_quotes(header_value, sep):
        kvpair = kvpair.strip()
        if not kvpair:
            continue
        key, separator, value = kvpair.partition('=')
        if not separator:
            raise ValueError(
                'Expected key=value but got {!r}'.format(kvpair))
        d[key.strip()] = value.strip().strip('"')
    return d
442
443
def main():
    """
    Parse the command line, set up the HTTP(S) test server and serve
    requests until interrupted.

    The listening socket is either created and bound here, or inherited
    through the file descriptor given with --fd. With --use-tls a temporary
    self-signed certificate is generated and the socket is wrapped in TLS.
    """
    options = parse_args(sys.argv)
    # A file descriptor of 0 is valid, so compare against None explicitly.
    use_inherited_socket = options.server_socket_fd is not None

    bind_addr = (
        options.bind_addr,
        0 if options.port is None else options.port)

    server = http.server.HTTPServer(
        bind_addr,
        RequestHandler,
        bind_and_activate=False)

    if use_inherited_socket:
        # Release the socket created by the constructor before replacing it
        # so that it does not leak.
        server.socket.close()
        server.socket = socket.fromfd(
            options.server_socket_fd,
            socket.AF_INET,
            socket.SOCK_STREAM)
        if options.port is None:
            # Ask the inherited socket which port it is using.
            server.server_port = server.socket.getsockname()[1]
        else:
            server.server_port = options.port
    else:
        # Placeholder only: server_bind() below records the real port.
        server.server_port = bind_addr[1]

    def run_server():
        if not use_inherited_socket:
            server.server_bind()
            server.server_activate()
        print(
            'Test server listening on port',
            server.server_port,
            file=sys.stderr)
        server.serve_forever(0.01)

    try:
        if options.use_tls:
            with tempfile.TemporaryDirectory() as temp_cert_dir:
                # Only include the port in the certificate name when one was
                # requested explicitly.
                if options.port is None:
                    common_name = options.bind_addr
                else:
                    common_name = options.bind_addr + ':' + str(options.port)
                cert_file = os.path.join(temp_cert_dir, 'cert.pem')
                key_file = os.path.join(temp_cert_dir, 'key.pem')
                generate_self_signed_tls_cert(
                    common_name,
                    cert_file,
                    key_file)
                # ssl.wrap_socket() no longer exists in current Python
                # versions; an SSLContext is the supported way to wrap a
                # server socket. The certificate is loaded while the
                # temporary directory still exists.
                tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
                tls_context.load_cert_chain(
                    certfile=cert_file,
                    keyfile=key_file)
                server.socket = tls_context.wrap_socket(
                    server.socket,
                    server_side=True,
                    do_handshake_on_connect=False)
        run_server()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server.
        pass
    finally:
        server.server_close()
497
498
def parse_args(argv):
    """
    Parse the server's command line options.

    :param argv: Full argument vector, including the program name as its
                 first element.
    :return: The optparse values object with the attributes bind_addr,
             use_tls, port and server_socket_fd.
    """
    parser = optparse.OptionParser(
        usage='Usage: %prog [OPTIONS]',
        description=__doc__)

    # Each entry is (option flag, keyword settings for add_option).
    option_table = (
        ('--bind-addr', {
            'default': '127.0.0.1',
            'dest': 'bind_addr',
            'help': 'By default only bind to loopback',
        }),
        ('--use-tls', {
            'dest': 'use_tls',
            'default': False,
            'action': 'store_true',
            'help': 'If set, a self-signed TLS certificate, key and CA will be'
                    ' generated for testing purposes.',
        }),
        ('--port', {
            'dest': 'port',
            'default': None,
            'type': 'int',
            'help': 'If not specified a random port will be used.',
        }),
        ("--fd", {
            'dest': 'server_socket_fd',
            'default': None,
            'type': 'int',
            'help': 'A socket FD to use for accept() instead of binding a'
                    ' new one.',
        }),
    )
    for flag, settings in option_table:
        parser.add_option(flag, **settings)

    options, positional = parser.parse_args(argv)
    # The first positional argument is the program name itself.
    unexpected = positional[1:]
    if unexpected:
        parser.error('Unexpected arguments: {}'.format(', '.join(unexpected)))
    return options
531
532
# Run the server only when this file is executed as a script.
if __name__ == '__main__':
    main()
535