#
# Copyright 2020 Haiku, Inc. All rights reserved.
# Distributed under the terms of the MIT License.
#
# Authors:
# Kyle Ambroff-Kao, kyle@ambroffkao.com
#

"""
Transparent HTTP proxy.
"""

import http.client
import http.server
import optparse
import socket
import sys
import urllib.parse


class RequestHandler(http.server.BaseHTTPRequestHandler):
    """
    Implement the basic requirements for a transparent HTTP proxy as defined
    by RFC 7230. Enough of the functionality is implemented to support the
    integration tests in HttpTest that use the HTTP proxy feature.

    There are many error conditions and failure modes which are not handled.
    Those cases can be added as the test suite expands to handle more error
    cases. In particular, a response body is only relayed when the downstream
    server announces it with a Content-Length header.
    """

    # Request headers (lower-cased) that are never copied from the client:
    # http.client generates its own Host and Content-Length for the
    # downstream request.
    _SKIPPED_REQUEST_HEADERS = frozenset(('host', 'content-length'))

    def __init__(self, *args, **kwargs):
        # This is used to hold on to persistent connections to the downstream
        # servers. This maps downstream_host:port => HTTPConnection
        #
        # This implementation is not thread safe, but that's OK we only have
        # a single thread anyway.
        #
        # It has to be assigned before calling the base class constructor,
        # because that constructor already handles the request(s).
        self._connections = {}

        super(RequestHandler, self).__init__(*args, **kwargs)

    def finish(self):
        """
        Finish the client connection and close every downstream connection
        that was kept alive on its behalf, so no sockets are leaked.
        """
        try:
            super(RequestHandler, self).finish()
        finally:
            for conn in self._connections.values():
                conn.close()
            self._connections.clear()

    def _collect_client_headers(self, client_address):
        """
        Build the header dict which is sent to the downstream server.

        client_address is the 'host:port' string of the client, which is
        appended to the X-Forwarded-For header.

        Header names are compared case insensitively. Host and Content-Length
        are dropped, and a header which occurs several times is folded into
        a single comma separated value (RFC 7230 section 3.2.2).
        """
        client_headers = {}
        forwarded_for = None
        seen = set()
        for header_name in self.headers.keys():
            key = header_name.lower()
            if key in seen or key in self._SKIPPED_REQUEST_HEADERS:
                continue
            seen.add(key)

            header_value = ', '.join(self.headers.get_all(header_name))
            if key == 'x-forwarded-for':
                # Emitted below, extended with the address of this client.
                forwarded_for = header_value
            else:
                client_headers[header_name] = header_value

        # Compute X-Forwarded-For header
        if forwarded_for is None:
            client_headers['X-Forwarded-For'] = client_address
        else:
            client_headers['X-Forwarded-For'] = \
                forwarded_for + ', ' + client_address
        return client_headers

    def _proxy_request(self):
        """
        Forward the current request to the server named in the request
        target and echo that server's response back to the client.
        """
        # Extract the downstream server from the request path.
        #
        # Note that no attempt is made to prevent message forwarding loops
        # here. This doesn't need to be a complete proxy implementation, just
        # enough of one for integration tests. RFC 7230 section 5.7 says if
        # this were a complete implementation, it would have to make sure that
        # the target system was not this process to avoid a loop.
        target = urllib.parse.urlparse(self.path)
        client_address = '{}:{}'.format(*self.client_address)

        # Validate the body length up front. A negative length would make
        # rfile.read() block until the client closes the connection.
        try:
            request_body_length = int(self.headers.get('Content-Length', '0'))
        except ValueError:
            request_body_length = -1
        if request_body_length < 0:
            self.send_error(400, 'Invalid Content-Length header')
            return

        # If Connection: close wasn't used, then we may still have a connection
        # to this downstream server handy.
        conn = self._connections.get(target.netloc, None)
        if conn is None:
            conn = http.client.HTTPConnection(target.netloc)

        # Collect headers from client which will be sent to the downstream
        # server.
        client_headers = self._collect_client_headers(client_address)

        # Read the request body from client.
        request_body = self.rfile.read(request_body_length)

        # Send the request to the downstream server
        target_path = target.path
        if target.query:
            target_path += '?' + target.query
        conn.request(self.command, target_path, request_body, client_headers)
        response = conn.getresponse()

        # Echo the response to the client.
        self.send_response_only(response.status, response.reason)
        for header_name, header_value in response.headers.items():
            self.send_header(header_name, header_value)
        self.end_headers()

        # Read the response body from upstream and write it to downstream, if
        # there is a response body at all.
        response_content_length = \
            int(response.headers.get('Content-Length', '0'))
        if response_content_length > 0:
            self.wfile.write(response.read(response_content_length))

        # Cleanup, possibly hang on to persistent connection to target
        # server. Connection options are case insensitive tokens.
        connection_options = [
            option.strip().lower()
            for option in self.headers.get('Connection', '').split(',')]
        if response.will_close or 'close' in connection_options:
            conn.close()
            # Don't keep a closed connection around in the cache.
            self._connections.pop(target.netloc, None)
            self.close_connection = True
        else:
            # Hang on to this connection for future requests. This isn't
            # really bulletproof but it's good enough for integration tests.
            self._connections[target.netloc] = conn

        self.log_message(
            'Proxied request from %s to %s',
            client_address,
            self.path)

    def do_GET(self):
        self._proxy_request()

    def do_HEAD(self):
        self._proxy_request()

    def do_POST(self):
        self._proxy_request()

    def do_PUT(self):
        self._proxy_request()

    def do_DELETE(self):
        self._proxy_request()

    def do_PATCH(self):
        self._proxy_request()

    def do_OPTIONS(self):
        self._proxy_request()


def main():
    """
    Parse the command line, set up the listening socket (either a freshly
    bound one or one inherited through --fd) and serve until interrupted.
    """
    options = parse_args(sys.argv)

    bind_addr = (
        options.bind_addr,
        0 if options.port is None else options.port)

    server = http.server.HTTPServer(
        bind_addr,
        RequestHandler,
        bind_and_activate=False)

    # Compare against None, because 0 is a valid file descriptor.
    if options.server_socket_fd is not None:
        # The socket created by HTTPServer is not needed, release it instead
        # of leaking it.
        server.socket.close()
        server.socket = socket.fromfd(
            options.server_socket_fd,
            socket.AF_INET,
            socket.SOCK_STREAM)
        server.server_address = server.socket.getsockname()
    else:
        server.server_bind()
        server.server_activate()

    # The port can only be looked up once the final socket is in place. An
    # unbound socket would always report port 0.
    if options.port is None:
        server.server_port = server.server_address[1]
    else:
        server.server_port = options.port

    print(
        'Transparent HTTP proxy listening on port',
        server.server_port,
        file=sys.stderr)
    try:
        server.serve_forever(0.01)
    except KeyboardInterrupt:
        pass
    finally:
        server.server_close()


def parse_args(argv):
    """
    Parse the command line given in argv (including the program name as the
    first element) and return the optparse options object with the
    attributes bind_addr, port and server_socket_fd.

    Exits through parser.error() if positional arguments are given.
    """
    parser = optparse.OptionParser(
        usage='Usage: %prog [OPTIONS]',
        description=__doc__)
    parser.add_option(
        '--bind-addr',
        default='127.0.0.1',
        dest='bind_addr',
        help='By default only bind to loopback')
    parser.add_option(
        '--port',
        dest='port',
        default=None,
        type='int',
        help='If not specified a random port will be used.')
    parser.add_option(
        "--fd",
        dest='server_socket_fd',
        default=None,
        type='int',
        help='A socket FD to use for accept() instead of binding a new one.')
    options, args = parser.parse_args(argv)
    # args[0] is the program name, since the full argv is parsed.
    if len(args) > 1:
        parser.error('Unexpected arguments: {}'.format(', '.join(args[1:])))
    return options


if __name__ == '__main__':
    main()