Python asyncio 模块实现简单异步 https 请求

网上关于 asyncio 实现异步 https 请求的代码几乎都是基于 Python 第三方库 aiohttp 的,而我仅需要一个无第三方依赖的、能一键运行的简单 Python 脚本。翻了翻官方文档,也没有什么值得参考的 sample 代码。无奈只能自己动手撸一个。

以下示例代码的作用是,请求百度首页,并将响应打印出来。支持 Python 3.7 及以上的版本。

版本一

使用 loop.create_connection() 从零开始撸。

import asyncio
import ssl
import io
import pprint

# Host to connect to; the same name is placed in the request's Host header.
HTTP_HOST = 'www.baidu.com'
# Complete request as bytes: request line, Host header, then the blank line
# that terminates the header section.
HTTP_REQUEST = bytes(
    'GET / HTTP/1.1\r\nHOST: {host}\r\n\r\n'.format(host=HTTP_HOST),
    'utf-8',
)
# Request timeout, in seconds.
HTTP_TIMEOUT = 2.0

class HttpRequestProtocol(asyncio.Protocol):
    """Protocol that assembles one HTTP response and reports it via futures.

    Incoming bytes are buffered until the header section plus
    ``Content-Length`` bytes of body have arrived. The parsed response (a dict
    with ``status``, ``headers`` and ``body``) is then delivered through
    ``on_data_received``. A failure while parsing is delivered through the
    same future as an exception, so the awaiting caller sees it immediately.
    ``on_connection_lost`` is resolved when the connection is closed.

    Only responses carrying a ``Content-Length`` header are supported.
    """

    def __init__(self, on_data_received, on_connection_lost):
        """Store the two futures and create an empty receive buffer.

        Args:
            on_data_received: future resolved with the parsed response dict,
                or failed with the exception raised while parsing.
            on_connection_lost: future resolved with ``None`` once the
                connection has been closed.
        """
        self.on_data_received = on_data_received
        self.on_connection_lost = on_connection_lost
        self._response = dict()
        self._response_data = io.BytesIO()
        self._response_body_offset = 0
        self._content_length = None

    def _set_future_result(self, future, result=None):
        """Resolve *future* with *result* unless it is already finished.

        ``done()`` is also true for a cancelled future (for example after
        ``asyncio.wait_for`` timed out), so a late response is dropped quietly.
        """
        if not future.done():
            future.set_result(result)

    def _set_future_exception(self, future, exc):
        """Fail *future* with *exc* unless it is already finished."""
        if not future.done():
            future.set_exception(exc)

    def _parse_headers(self):
        """Parse the status line and headers once the blank line has arrived.

        Header names keep their original spelling in the stored dict, but the
        ``Content-Length`` header is located case-insensitively because HTTP
        field names are case-insensitive.

        Returns:
            True when the header section was complete and has been stored,
            False when more data is needed.

        Raises:
            ValueError: if the header bytes cannot be decoded, or if
                ``Content-Length`` is missing, not an integer, or negative.
        """
        response_part = self._response_data.getvalue()
        header_end_offset = response_part.find(b'\r\n\r\n')
        if header_end_offset < 0:
            return False

        lines = response_part[:header_end_offset].decode().split('\r\n')
        headers = {}
        content_length = None
        for line in lines[1:]:
            # partition() tolerates a line without ':' (value becomes '').
            name, _, value = line.partition(':')
            name, value = name.strip(), value.strip()
            headers[name] = value
            if name.lower() == 'content-length':
                content_length = int(value)

        if content_length is None:
            raise ValueError('响应数据丢失 "Content-Length" 头部,无法处理该响应')
        if content_length < 0:
            raise ValueError(
                'Content-Length 头部的值无效: {0}'.format(content_length))

        # Commit state only after validation so a failed parse leaves no
        # half-initialised response behind.
        self._response_body_offset = header_end_offset + 4
        self._content_length = content_length
        self._response['status'] = lines[0]
        self._response['headers'] = headers
        return True

    def _parse_response(self):
        """Try to extract the full response from the data buffered so far.

        Returns:
            False while the headers or the body are still incomplete,
            otherwise the response dict with ``status``, ``headers`` and
            ``body`` filled in.

        Raises:
            ValueError: propagated from header parsing.
        """
        if 'headers' not in self._response and not self._parse_headers():
            return False

        body_end = self._response_body_offset + self._content_length
        # The buffer is only ever appended to, so tell() is the byte count
        # received so far.
        if self._response_data.tell() < body_end:
            return False

        self._response_data.seek(self._response_body_offset)
        self._response['body'] = self._response_data.read(self._content_length)
        self._response_data.close()
        return self._response

    def data_received(self, data):
        """Buffer *data* and deliver the response once it is complete."""
        if self._response_data.closed:
            # The response was already delivered (or parsing failed); bytes
            # arriving afterwards are ignored instead of raising on a closed
            # buffer.
            return

        self._response_data.write(data)
        try:
            response = self._parse_response()
        except Exception as exc:
            # An exception escaping a protocol callback is only logged by the
            # event loop; hand it to the waiter so it does not sit there until
            # its timeout expires.
            self._response_data.close()
            self._set_future_exception(self.on_data_received, exc)
            return

        if response:
            # The complete response has arrived; finish the request.
            self._set_future_result(self.on_data_received, response)

    def connection_lost(self, exc):
        """Report that the connection closed and resolve the close future."""
        print('与服务器连接已关闭: {0}'.format(exc))
        self._set_future_result(self.on_connection_lost)


async def main():
    """Fetch the home page of ``HTTP_HOST`` over TLS and pretty-print it.

    A plain TCP connection is opened first and then upgraded with
    ``loop.start_tls()``. The request is written, and the response assembled
    by ``HttpRequestProtocol`` is awaited for at most ``HTTP_TIMEOUT`` seconds.
    Finally the connection is closed and the close notification is awaited.
    """
    loop = asyncio.get_running_loop()
    on_connection_lost = loop.create_future()
    on_data_received = loop.create_future()
    transport, protocol = await loop.create_connection(
        lambda: HttpRequestProtocol(on_data_received, on_connection_lost),
        HTTP_HOST, 443
    )
    # Upgrade the connection to TLS. server_hostname enables SNI and makes the
    # certificate get checked against the host we intend to talk to.
    ssl_context = ssl.create_default_context()
    try:
        ssl_transport = await loop.start_tls(
            transport, protocol, ssl_context, server_hostname=HTTP_HOST)
    except BaseException:
        # Do not leak the TCP connection when the handshake fails.
        transport.close()
        raise
    # From here on only ssl_transport may be used: the transport passed to
    # start_tls() must not be touched again.
    ssl_transport.write(HTTP_REQUEST)
    print('发送请求: {!r}'.format(HTTP_REQUEST))

    try:
        # Wait for the complete response.
        response = await asyncio.wait_for(on_data_received, HTTP_TIMEOUT)
        pprint.pprint(response)
    except (asyncio.CancelledError,
            asyncio.InvalidStateError,
            asyncio.TimeoutError):
        # The top-level asyncio names exist on every supported version;
        # before 3.11 asyncio.TimeoutError is not the builtin TimeoutError.
        print('请求失败')
    except (TypeError, ValueError):
        print('收到异常数据')
    finally:
        # Skip this close if the connection should be kept alive.
        ssl_transport.close()

    try:
        # Wait until the protocol reports that the connection is closed.
        await on_connection_lost
    finally:
        ssl_transport.close()

if __name__ == "__main__":
    # Script entry point: run the demo; Ctrl+C ends it without a traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        pass

版本二

使用 asyncio.open_connection()

import asyncio
import ssl
import pprint


async def main():
    """Fetch the Baidu home page with ``asyncio.open_connection`` and print it.

    The response headers are read up to the blank line, the body length is
    taken from ``Content-Length``, and exactly that many bytes are read.

    Raises:
        ValueError: if the response has no ``Content-Length`` header or its
            value is not an integer.
    """
    reader, writer = await asyncio.open_connection(
        'www.baidu.com',
        443,
        ssl=ssl.create_default_context()
    )
    try:
        writer.write(
            b'GET / HTTP/1.1\r\n'
            b'HOST: www.baidu.com\r\n'
            b'\r\n'
        )
        await writer.drain()
        raw_response_header = await reader.readuntil(b'\r\n\r\n')

        # Element 0 is the status line; the last two elements are the empty
        # strings produced by the terminating b'\r\n\r\n'.
        response_headers = {}
        for line in raw_response_header.split(b'\r\n')[1:-2]:
            # partition() tolerates a header line that has no ':'.
            name, _, value = line.partition(b':')
            response_headers[name.strip()] = value.strip()

        # HTTP field names are case-insensitive, so do not insist on the
        # exact spelling b'Content-Length'.
        content_length = None
        for name, value in response_headers.items():
            if name.lower() == b'content-length':
                content_length = int(value)
                break
        if content_length is None:
            raise ValueError('响应数据丢失 "Content-Length" 头部,无法处理该响应')

        response_body = await reader.readexactly(content_length)
        pprint.pprint(response_headers)
        pprint.pprint(response_body)
    finally:
        # Close the connection on the error paths as well.
        writer.close()
        await writer.wait_closed()


if __name__ == "__main__":
    # Script entry point: run the demo coroutine to completion.
    asyncio.run(main())

此版本的缺陷是必须指定 hostname。如果 asyncio.open_connection() 的 host 参数使用 IP 地址,那么必须同时传递 server_hostname 参数,否则会报错。下面代码会抛出 ssl.SSLCertVerificationError 异常:

# Fragment meant to run inside a coroutine: connecting by IP address with a
# default SSL context and no server_hostname argument.
reader, writer = await asyncio.open_connection(
    '112.80.248.75', 
    443, 
    ssl=ssl.create_default_context()
)

异常信息:

ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: IP address mismatch, certificate is not valid for '112.80.248.75'. (_ssl.c:992)

版本三

在版本二的基础上进行改进,不要求验证证书的 hostname 。照着 asyncio.streams 的代码依样画葫芦。注意:跳过 hostname 验证后,证书虽然仍会校验是否由受信任的 CA 签发,但不再确认它属于目标站点,连接会失去对中间人攻击的防护,仅建议在测试或可信网络环境中使用:

import asyncio
import ssl
import pprint
from asyncio.streams import _DEFAULT_LIMIT, StreamReader, StreamReaderProtocol, StreamWriter


class TlsStreamReaderProtocol(StreamReaderProtocol):
    """``StreamReaderProtocol`` whose ``eof_received()`` always answers False.

    Per the original author's note, returning True here makes asyncio print
    "returning true from eof_received() has no effect when using ssl".
    """

    def eof_received(self):
        """Run the inherited EOF handling, discard its answer, return False."""
        super(TlsStreamReaderProtocol, self).eof_received()
        return False


async def open_connection(host=None, port=None, *,
                          limit=_DEFAULT_LIMIT, server_hostname=None, **kwds):
    """Open a TCP connection, upgrade it to TLS and return stream objects.

    Unlike ``asyncio.open_connection(..., ssl=...)``, *host* may be an IP
    address without also naming the server. The TLS upgrade is done separately
    with ``loop.start_tls()``.

    Security note: with ``server_hostname=None`` (the default) no host name is
    passed to ``start_tls()``, so the connection is made without matching the
    certificate against a host name. Pass the expected name to have it checked.

    Args:
        host: host name or IP address to connect to.
        port: TCP port.
        limit: buffer limit handed to the ``StreamReader``.
        server_hostname: optional name the server certificate must match;
            forwarded to ``loop.start_tls()``.
        **kwds: extra keyword arguments for ``loop.create_connection()``.

    Returns:
        A ``(reader, writer)`` pair bound to the TLS transport.
    """
    ssl_context = ssl.create_default_context()
    loop = asyncio.get_running_loop()
    reader = StreamReader(limit=limit, loop=loop)
    protocol = TlsStreamReaderProtocol(reader, loop=loop)
    transport, _ = await loop.create_connection(
        lambda: protocol, host, port, **kwds)
    try:
        ssl_transport = await loop.start_tls(
            transport, protocol, ssl_context, server_hostname=server_hostname)
    except BaseException:
        # Do not leak the TCP connection when the TLS handshake fails.
        transport.close()
        raise
    writer = StreamWriter(ssl_transport, protocol, reader, loop)
    return reader, writer


async def main():
    """Fetch the Baidu home page by IP address and print headers and body.

    The connection is made with this file's ``open_connection`` to an IP
    address, while the ``Host`` header still names ``www.baidu.com``. The body
    length is taken from the ``Content-Length`` response header.

    Raises:
        ValueError: if the response has no ``Content-Length`` header or its
            value is not an integer.
    """
    reader, writer = await open_connection(
        '112.80.248.75',
        443,
    )
    try:
        writer.write(
            b'GET / HTTP/1.1\r\n'
            b'HOST: www.baidu.com\r\n'
            b'\r\n'
        )
        await writer.drain()
        raw_response_header = await reader.readuntil(b'\r\n\r\n')

        # Element 0 is the status line; the last two elements are the empty
        # strings produced by the terminating b'\r\n\r\n'.
        response_headers = {}
        for line in raw_response_header.split(b'\r\n')[1:-2]:
            # partition() tolerates a header line that has no ':'.
            name, _, value = line.partition(b':')
            response_headers[name.strip()] = value.strip()

        # HTTP field names are case-insensitive, so do not insist on the
        # exact spelling b'Content-Length'.
        content_length = None
        for name, value in response_headers.items():
            if name.lower() == b'content-length':
                content_length = int(value)
                break
        if content_length is None:
            raise ValueError('响应数据丢失 "Content-Length" 头部,无法处理该响应')

        response_body = await reader.readexactly(content_length)
        pprint.pprint(response_headers)
        pprint.pprint(response_body)
    finally:
        # Close the connection on the error paths as well.
        writer.close()
        await writer.wait_closed()


if __name__ == "__main__":
    # Script entry point: run the demo coroutine to completion.
    asyncio.run(main())