thiiiiings

This commit is contained in:
Jonas Zeunert
2020-02-13 19:05:55 +01:00
parent 82f1ad6ed8
commit 7414734dad
719 changed files with 41551 additions and 227 deletions

View File

@@ -0,0 +1,77 @@
import requests
import sys
from .adapters import UnixAdapter
DEFAULT_SCHEME = 'http+unix://'
class Session(requests.Session):
def __init__(self, url_scheme=DEFAULT_SCHEME, *args, **kwargs):
super(Session, self).__init__(*args, **kwargs)
self.mount(url_scheme, UnixAdapter())
class monkeypatch(object):
def __init__(self, url_scheme=DEFAULT_SCHEME):
self.session = Session()
requests = self._get_global_requests_module()
# Methods to replace
self.methods = ('request', 'get', 'head', 'post',
'patch', 'put', 'delete', 'options')
# Store the original methods
self.orig_methods = dict(
(m, requests.__dict__[m]) for m in self.methods)
# Monkey patch
g = globals()
for m in self.methods:
requests.__dict__[m] = g[m]
def _get_global_requests_module(self):
return sys.modules['requests']
def __enter__(self):
return self
def __exit__(self, *args):
requests = self._get_global_requests_module()
for m in self.methods:
requests.__dict__[m] = self.orig_methods[m]
# These are the same methods defined for the global requests object
def request(method, url, **kwargs):
session = Session()
return session.request(method=method, url=url, **kwargs)
def get(url, **kwargs):
kwargs.setdefault('allow_redirects', True)
return request('get', url, **kwargs)
def head(url, **kwargs):
kwargs.setdefault('allow_redirects', False)
return request('head', url, **kwargs)
def post(url, data=None, json=None, **kwargs):
return request('post', url, data=data, json=json, **kwargs)
def patch(url, data=None, **kwargs):
return request('patch', url, data=data, **kwargs)
def put(url, data=None, **kwargs):
return request('put', url, data=data, **kwargs)
def delete(url, **kwargs):
return request('delete', url, **kwargs)
def options(url, **kwargs):
kwargs.setdefault('allow_redirects', True)
return request('options', url, **kwargs)

View File

@@ -0,0 +1,89 @@
import socket
from requests.adapters import HTTPAdapter
from requests.compat import urlparse, unquote
try:
import http.client as httplib
except ImportError:
import httplib
try:
from requests.packages import urllib3
except ImportError:
import urllib3
# The following was adapted from some code from docker-py
# https://github.com/docker/docker-py/blob/master/docker/transport/unixconn.py
class UnixHTTPConnection(httplib.HTTPConnection, object):
def __init__(self, unix_socket_url, timeout=60):
"""Create an HTTP connection to a unix domain socket
:param unix_socket_url: A URL with a scheme of 'http+unix' and the
netloc is a percent-encoded path to a unix domain socket. E.g.:
'http+unix://%2Ftmp%2Fprofilesvc.sock/status/pid'
"""
super(UnixHTTPConnection, self).__init__('localhost', timeout=timeout)
self.unix_socket_url = unix_socket_url
self.timeout = timeout
self.sock = None
def __del__(self): # base class does not have d'tor
if self.sock:
self.sock.close()
def connect(self):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(self.timeout)
socket_path = unquote(urlparse(self.unix_socket_url).netloc)
sock.connect(socket_path)
self.sock = sock
class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
def __init__(self, socket_path, timeout=60):
super(UnixHTTPConnectionPool, self).__init__(
'localhost', timeout=timeout)
self.socket_path = socket_path
self.timeout = timeout
def _new_conn(self):
return UnixHTTPConnection(self.socket_path, self.timeout)
class UnixAdapter(HTTPAdapter):
def __init__(self, timeout=60, pool_connections=25):
super(UnixAdapter, self).__init__()
self.timeout = timeout
self.pools = urllib3._collections.RecentlyUsedContainer(
pool_connections, dispose_func=lambda p: p.close()
)
super(UnixAdapter, self).__init__()
def get_connection(self, url, proxies=None):
proxies = proxies or {}
proxy = proxies.get(urlparse(url.lower()).scheme)
if proxy:
raise ValueError('%s does not support specifying proxies'
% self.__class__.__name__)
with self.pools.lock:
pool = self.pools.get(url)
if pool:
return pool
pool = UnixHTTPConnectionPool(url, self.timeout)
self.pools[url] = pool
return pool
def request_url(self, request, proxies):
return request.path_url
def close(self):
self.pools.clear()

View File

@@ -0,0 +1,121 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Tests for requests_unixsocket"""
import logging
import pytest
import requests
import requests_unixsocket
from requests_unixsocket.testutils import UnixSocketServerThread
logger = logging.getLogger(__name__)
def test_unix_domain_adapter_ok():
with UnixSocketServerThread() as usock_thread:
session = requests_unixsocket.Session('http+unix://')
urlencoded_usock = requests.compat.quote_plus(usock_thread.usock)
url = 'http+unix://%s/path/to/page' % urlencoded_usock
for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
'options']:
logger.debug('Calling session.%s(%r) ...', method, url)
r = getattr(session, method)(url)
logger.debug(
'Received response: %r with text: %r and headers: %r',
r, r.text, r.headers)
assert r.status_code == 200
assert r.headers['server'] == 'waitress'
assert r.headers['X-Transport'] == 'unix domain socket'
assert r.headers['X-Requested-Path'] == '/path/to/page'
assert r.headers['X-Socket-Path'] == usock_thread.usock
assert isinstance(r.connection, requests_unixsocket.UnixAdapter)
assert r.url.lower() == url.lower()
if method == 'head':
assert r.text == ''
else:
assert r.text == 'Hello world!'
def test_unix_domain_adapter_url_with_query_params():
with UnixSocketServerThread() as usock_thread:
session = requests_unixsocket.Session('http+unix://')
urlencoded_usock = requests.compat.quote_plus(usock_thread.usock)
url = ('http+unix://%s'
'/containers/nginx/logs?timestamp=true' % urlencoded_usock)
for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
'options']:
logger.debug('Calling session.%s(%r) ...', method, url)
r = getattr(session, method)(url)
logger.debug(
'Received response: %r with text: %r and headers: %r',
r, r.text, r.headers)
assert r.status_code == 200
assert r.headers['server'] == 'waitress'
assert r.headers['X-Transport'] == 'unix domain socket'
assert r.headers['X-Requested-Path'] == '/containers/nginx/logs'
assert r.headers['X-Requested-Query-String'] == 'timestamp=true'
assert r.headers['X-Socket-Path'] == usock_thread.usock
assert isinstance(r.connection, requests_unixsocket.UnixAdapter)
assert r.url.lower() == url.lower()
if method == 'head':
assert r.text == ''
else:
assert r.text == 'Hello world!'
def test_unix_domain_adapter_connection_error():
session = requests_unixsocket.Session('http+unix://')
for method in ['get', 'post', 'head', 'patch', 'put', 'delete', 'options']:
with pytest.raises(requests.ConnectionError):
getattr(session, method)(
'http+unix://socket_does_not_exist/path/to/page')
def test_unix_domain_adapter_connection_proxies_error():
session = requests_unixsocket.Session('http+unix://')
for method in ['get', 'post', 'head', 'patch', 'put', 'delete', 'options']:
with pytest.raises(ValueError) as excinfo:
getattr(session, method)(
'http+unix://socket_does_not_exist/path/to/page',
proxies={"http+unix": "http://10.10.1.10:1080"})
assert ('UnixAdapter does not support specifying proxies'
in str(excinfo.value))
def test_unix_domain_adapter_monkeypatch():
with UnixSocketServerThread() as usock_thread:
with requests_unixsocket.monkeypatch('http+unix://'):
urlencoded_usock = requests.compat.quote_plus(usock_thread.usock)
url = 'http+unix://%s/path/to/page' % urlencoded_usock
for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
'options']:
logger.debug('Calling session.%s(%r) ...', method, url)
r = getattr(requests, method)(url)
logger.debug(
'Received response: %r with text: %r and headers: %r',
r, r.text, r.headers)
assert r.status_code == 200
assert r.headers['server'] == 'waitress'
assert r.headers['X-Transport'] == 'unix domain socket'
assert r.headers['X-Requested-Path'] == '/path/to/page'
assert r.headers['X-Socket-Path'] == usock_thread.usock
assert isinstance(r.connection,
requests_unixsocket.UnixAdapter)
assert r.url.lower() == url.lower()
if method == 'head':
assert r.text == ''
else:
assert r.text == 'Hello world!'
for method in ['get', 'post', 'head', 'patch', 'put', 'delete', 'options']:
with pytest.raises(requests.exceptions.InvalidSchema):
getattr(requests, method)(url)

View File

@@ -0,0 +1,97 @@
"""
Utilities helpful for writing tests
Provides a UnixSocketServerThread that creates a running server, listening on a
newly created unix socket.
Example usage:
.. code-block:: python
def test_unix_domain_adapter_monkeypatch():
with UnixSocketServerThread() as usock_thread:
with requests_unixsocket.monkeypatch('http+unix://'):
urlencoded_usock = quote_plus(usock_process.usock)
url = 'http+unix://%s/path/to/page' % urlencoded_usock
r = requests.get(url)
"""
import logging
import os
import threading
import time
import uuid
import waitress
logger = logging.getLogger(__name__)
class KillThread(threading.Thread):
def __init__(self, server, *args, **kwargs):
super(KillThread, self).__init__(*args, **kwargs)
self.server = server
def run(self):
time.sleep(1)
logger.debug('Sleeping')
self.server._map.clear()
class WSGIApp:
server = None
def __call__(self, environ, start_response):
logger.debug('WSGIApp.__call__: Invoked for %s', environ['PATH_INFO'])
logger.debug('WSGIApp.__call__: environ = %r', environ)
status_text = '200 OK'
response_headers = [
('X-Transport', 'unix domain socket'),
('X-Socket-Path', environ['SERVER_PORT']),
('X-Requested-Query-String', environ['QUERY_STRING']),
('X-Requested-Path', environ['PATH_INFO'])]
body_bytes = b'Hello world!'
start_response(status_text, response_headers)
logger.debug(
'WSGIApp.__call__: Responding with '
'status_text = %r; '
'response_headers = %r; '
'body_bytes = %r',
status_text, response_headers, body_bytes)
return [body_bytes]
class UnixSocketServerThread(threading.Thread):
def __init__(self, *args, **kwargs):
super(UnixSocketServerThread, self).__init__(*args, **kwargs)
self.usock = self.get_tempfile_name()
self.server = None
self.server_ready_event = threading.Event()
def get_tempfile_name(self):
# I'd rather use tempfile.NamedTemporaryFile but IDNA limits
# the hostname to 63 characters and we'll get a "InvalidURL:
# URL has an invalid label" error if we exceed that.
args = (os.stat(__file__).st_ino, os.getpid(), uuid.uuid4().hex[-8:])
return '/tmp/test_requests.%s_%s_%s' % args
def run(self):
logger.debug('Call waitress.serve in %r ...', self)
wsgi_app = WSGIApp()
server = waitress.create_server(wsgi_app, unix_socket=self.usock)
wsgi_app.server = server
self.server = server
self.server_ready_event.set()
server.run()
def __enter__(self):
logger.debug('Starting %r ...' % self)
self.start()
logger.debug('Started %r.', self)
self.server_ready_event.wait()
return self
def __exit__(self, *args):
self.server_ready_event.wait()
if self.server:
KillThread(self.server).start()