thiiiiings
This commit is contained in:
@@ -0,0 +1,77 @@
|
||||
import requests
|
||||
import sys
|
||||
|
||||
from .adapters import UnixAdapter
|
||||
|
||||
DEFAULT_SCHEME = 'http+unix://'
|
||||
|
||||
|
||||
class Session(requests.Session):
|
||||
def __init__(self, url_scheme=DEFAULT_SCHEME, *args, **kwargs):
|
||||
super(Session, self).__init__(*args, **kwargs)
|
||||
self.mount(url_scheme, UnixAdapter())
|
||||
|
||||
|
||||
class monkeypatch(object):
|
||||
def __init__(self, url_scheme=DEFAULT_SCHEME):
|
||||
self.session = Session()
|
||||
requests = self._get_global_requests_module()
|
||||
|
||||
# Methods to replace
|
||||
self.methods = ('request', 'get', 'head', 'post',
|
||||
'patch', 'put', 'delete', 'options')
|
||||
# Store the original methods
|
||||
self.orig_methods = dict(
|
||||
(m, requests.__dict__[m]) for m in self.methods)
|
||||
# Monkey patch
|
||||
g = globals()
|
||||
for m in self.methods:
|
||||
requests.__dict__[m] = g[m]
|
||||
|
||||
def _get_global_requests_module(self):
|
||||
return sys.modules['requests']
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
requests = self._get_global_requests_module()
|
||||
for m in self.methods:
|
||||
requests.__dict__[m] = self.orig_methods[m]
|
||||
|
||||
|
||||
# These are the same methods defined for the global requests object
|
||||
def request(method, url, **kwargs):
|
||||
session = Session()
|
||||
return session.request(method=method, url=url, **kwargs)
|
||||
|
||||
|
||||
def get(url, **kwargs):
|
||||
kwargs.setdefault('allow_redirects', True)
|
||||
return request('get', url, **kwargs)
|
||||
|
||||
|
||||
def head(url, **kwargs):
|
||||
kwargs.setdefault('allow_redirects', False)
|
||||
return request('head', url, **kwargs)
|
||||
|
||||
|
||||
def post(url, data=None, json=None, **kwargs):
|
||||
return request('post', url, data=data, json=json, **kwargs)
|
||||
|
||||
|
||||
def patch(url, data=None, **kwargs):
|
||||
return request('patch', url, data=data, **kwargs)
|
||||
|
||||
|
||||
def put(url, data=None, **kwargs):
|
||||
return request('put', url, data=data, **kwargs)
|
||||
|
||||
|
||||
def delete(url, **kwargs):
|
||||
return request('delete', url, **kwargs)
|
||||
|
||||
|
||||
def options(url, **kwargs):
|
||||
kwargs.setdefault('allow_redirects', True)
|
||||
return request('options', url, **kwargs)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,89 @@
|
||||
import socket
|
||||
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.compat import urlparse, unquote
|
||||
|
||||
try:
|
||||
import http.client as httplib
|
||||
except ImportError:
|
||||
import httplib
|
||||
|
||||
try:
|
||||
from requests.packages import urllib3
|
||||
except ImportError:
|
||||
import urllib3
|
||||
|
||||
|
||||
# The following was adapted from some code from docker-py
|
||||
# https://github.com/docker/docker-py/blob/master/docker/transport/unixconn.py
|
||||
class UnixHTTPConnection(httplib.HTTPConnection, object):
|
||||
|
||||
def __init__(self, unix_socket_url, timeout=60):
|
||||
"""Create an HTTP connection to a unix domain socket
|
||||
|
||||
:param unix_socket_url: A URL with a scheme of 'http+unix' and the
|
||||
netloc is a percent-encoded path to a unix domain socket. E.g.:
|
||||
'http+unix://%2Ftmp%2Fprofilesvc.sock/status/pid'
|
||||
"""
|
||||
super(UnixHTTPConnection, self).__init__('localhost', timeout=timeout)
|
||||
self.unix_socket_url = unix_socket_url
|
||||
self.timeout = timeout
|
||||
self.sock = None
|
||||
|
||||
def __del__(self): # base class does not have d'tor
|
||||
if self.sock:
|
||||
self.sock.close()
|
||||
|
||||
def connect(self):
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
sock.settimeout(self.timeout)
|
||||
socket_path = unquote(urlparse(self.unix_socket_url).netloc)
|
||||
sock.connect(socket_path)
|
||||
self.sock = sock
|
||||
|
||||
|
||||
class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
|
||||
|
||||
def __init__(self, socket_path, timeout=60):
|
||||
super(UnixHTTPConnectionPool, self).__init__(
|
||||
'localhost', timeout=timeout)
|
||||
self.socket_path = socket_path
|
||||
self.timeout = timeout
|
||||
|
||||
def _new_conn(self):
|
||||
return UnixHTTPConnection(self.socket_path, self.timeout)
|
||||
|
||||
|
||||
class UnixAdapter(HTTPAdapter):
|
||||
|
||||
def __init__(self, timeout=60, pool_connections=25):
|
||||
super(UnixAdapter, self).__init__()
|
||||
self.timeout = timeout
|
||||
self.pools = urllib3._collections.RecentlyUsedContainer(
|
||||
pool_connections, dispose_func=lambda p: p.close()
|
||||
)
|
||||
super(UnixAdapter, self).__init__()
|
||||
|
||||
def get_connection(self, url, proxies=None):
|
||||
proxies = proxies or {}
|
||||
proxy = proxies.get(urlparse(url.lower()).scheme)
|
||||
|
||||
if proxy:
|
||||
raise ValueError('%s does not support specifying proxies'
|
||||
% self.__class__.__name__)
|
||||
|
||||
with self.pools.lock:
|
||||
pool = self.pools.get(url)
|
||||
if pool:
|
||||
return pool
|
||||
|
||||
pool = UnixHTTPConnectionPool(url, self.timeout)
|
||||
self.pools[url] = pool
|
||||
|
||||
return pool
|
||||
|
||||
def request_url(self, request, proxies):
|
||||
return request.path_url
|
||||
|
||||
def close(self):
|
||||
self.pools.clear()
|
||||
Binary file not shown.
@@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""Tests for requests_unixsocket"""
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
import requests_unixsocket
|
||||
from requests_unixsocket.testutils import UnixSocketServerThread
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def test_unix_domain_adapter_ok():
|
||||
with UnixSocketServerThread() as usock_thread:
|
||||
session = requests_unixsocket.Session('http+unix://')
|
||||
urlencoded_usock = requests.compat.quote_plus(usock_thread.usock)
|
||||
url = 'http+unix://%s/path/to/page' % urlencoded_usock
|
||||
|
||||
for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
|
||||
'options']:
|
||||
logger.debug('Calling session.%s(%r) ...', method, url)
|
||||
r = getattr(session, method)(url)
|
||||
logger.debug(
|
||||
'Received response: %r with text: %r and headers: %r',
|
||||
r, r.text, r.headers)
|
||||
assert r.status_code == 200
|
||||
assert r.headers['server'] == 'waitress'
|
||||
assert r.headers['X-Transport'] == 'unix domain socket'
|
||||
assert r.headers['X-Requested-Path'] == '/path/to/page'
|
||||
assert r.headers['X-Socket-Path'] == usock_thread.usock
|
||||
assert isinstance(r.connection, requests_unixsocket.UnixAdapter)
|
||||
assert r.url.lower() == url.lower()
|
||||
if method == 'head':
|
||||
assert r.text == ''
|
||||
else:
|
||||
assert r.text == 'Hello world!'
|
||||
|
||||
|
||||
def test_unix_domain_adapter_url_with_query_params():
|
||||
with UnixSocketServerThread() as usock_thread:
|
||||
session = requests_unixsocket.Session('http+unix://')
|
||||
urlencoded_usock = requests.compat.quote_plus(usock_thread.usock)
|
||||
url = ('http+unix://%s'
|
||||
'/containers/nginx/logs?timestamp=true' % urlencoded_usock)
|
||||
|
||||
for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
|
||||
'options']:
|
||||
logger.debug('Calling session.%s(%r) ...', method, url)
|
||||
r = getattr(session, method)(url)
|
||||
logger.debug(
|
||||
'Received response: %r with text: %r and headers: %r',
|
||||
r, r.text, r.headers)
|
||||
assert r.status_code == 200
|
||||
assert r.headers['server'] == 'waitress'
|
||||
assert r.headers['X-Transport'] == 'unix domain socket'
|
||||
assert r.headers['X-Requested-Path'] == '/containers/nginx/logs'
|
||||
assert r.headers['X-Requested-Query-String'] == 'timestamp=true'
|
||||
assert r.headers['X-Socket-Path'] == usock_thread.usock
|
||||
assert isinstance(r.connection, requests_unixsocket.UnixAdapter)
|
||||
assert r.url.lower() == url.lower()
|
||||
if method == 'head':
|
||||
assert r.text == ''
|
||||
else:
|
||||
assert r.text == 'Hello world!'
|
||||
|
||||
|
||||
def test_unix_domain_adapter_connection_error():
|
||||
session = requests_unixsocket.Session('http+unix://')
|
||||
|
||||
for method in ['get', 'post', 'head', 'patch', 'put', 'delete', 'options']:
|
||||
with pytest.raises(requests.ConnectionError):
|
||||
getattr(session, method)(
|
||||
'http+unix://socket_does_not_exist/path/to/page')
|
||||
|
||||
|
||||
def test_unix_domain_adapter_connection_proxies_error():
|
||||
session = requests_unixsocket.Session('http+unix://')
|
||||
|
||||
for method in ['get', 'post', 'head', 'patch', 'put', 'delete', 'options']:
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
getattr(session, method)(
|
||||
'http+unix://socket_does_not_exist/path/to/page',
|
||||
proxies={"http+unix": "http://10.10.1.10:1080"})
|
||||
assert ('UnixAdapter does not support specifying proxies'
|
||||
in str(excinfo.value))
|
||||
|
||||
|
||||
def test_unix_domain_adapter_monkeypatch():
|
||||
with UnixSocketServerThread() as usock_thread:
|
||||
with requests_unixsocket.monkeypatch('http+unix://'):
|
||||
urlencoded_usock = requests.compat.quote_plus(usock_thread.usock)
|
||||
url = 'http+unix://%s/path/to/page' % urlencoded_usock
|
||||
|
||||
for method in ['get', 'post', 'head', 'patch', 'put', 'delete',
|
||||
'options']:
|
||||
logger.debug('Calling session.%s(%r) ...', method, url)
|
||||
r = getattr(requests, method)(url)
|
||||
logger.debug(
|
||||
'Received response: %r with text: %r and headers: %r',
|
||||
r, r.text, r.headers)
|
||||
assert r.status_code == 200
|
||||
assert r.headers['server'] == 'waitress'
|
||||
assert r.headers['X-Transport'] == 'unix domain socket'
|
||||
assert r.headers['X-Requested-Path'] == '/path/to/page'
|
||||
assert r.headers['X-Socket-Path'] == usock_thread.usock
|
||||
assert isinstance(r.connection,
|
||||
requests_unixsocket.UnixAdapter)
|
||||
assert r.url.lower() == url.lower()
|
||||
if method == 'head':
|
||||
assert r.text == ''
|
||||
else:
|
||||
assert r.text == 'Hello world!'
|
||||
|
||||
for method in ['get', 'post', 'head', 'patch', 'put', 'delete', 'options']:
|
||||
with pytest.raises(requests.exceptions.InvalidSchema):
|
||||
getattr(requests, method)(url)
|
||||
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Utilities helpful for writing tests
|
||||
|
||||
Provides a UnixSocketServerThread that creates a running server, listening on a
|
||||
newly created unix socket.
|
||||
|
||||
Example usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def test_unix_domain_adapter_monkeypatch():
|
||||
with UnixSocketServerThread() as usock_thread:
|
||||
with requests_unixsocket.monkeypatch('http+unix://'):
|
||||
urlencoded_usock = quote_plus(usock_process.usock)
|
||||
url = 'http+unix://%s/path/to/page' % urlencoded_usock
|
||||
r = requests.get(url)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
import waitress
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KillThread(threading.Thread):
|
||||
def __init__(self, server, *args, **kwargs):
|
||||
super(KillThread, self).__init__(*args, **kwargs)
|
||||
self.server = server
|
||||
|
||||
def run(self):
|
||||
time.sleep(1)
|
||||
logger.debug('Sleeping')
|
||||
self.server._map.clear()
|
||||
|
||||
|
||||
class WSGIApp:
|
||||
server = None
|
||||
|
||||
def __call__(self, environ, start_response):
|
||||
logger.debug('WSGIApp.__call__: Invoked for %s', environ['PATH_INFO'])
|
||||
logger.debug('WSGIApp.__call__: environ = %r', environ)
|
||||
status_text = '200 OK'
|
||||
response_headers = [
|
||||
('X-Transport', 'unix domain socket'),
|
||||
('X-Socket-Path', environ['SERVER_PORT']),
|
||||
('X-Requested-Query-String', environ['QUERY_STRING']),
|
||||
('X-Requested-Path', environ['PATH_INFO'])]
|
||||
body_bytes = b'Hello world!'
|
||||
start_response(status_text, response_headers)
|
||||
logger.debug(
|
||||
'WSGIApp.__call__: Responding with '
|
||||
'status_text = %r; '
|
||||
'response_headers = %r; '
|
||||
'body_bytes = %r',
|
||||
status_text, response_headers, body_bytes)
|
||||
return [body_bytes]
|
||||
|
||||
|
||||
class UnixSocketServerThread(threading.Thread):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(UnixSocketServerThread, self).__init__(*args, **kwargs)
|
||||
self.usock = self.get_tempfile_name()
|
||||
self.server = None
|
||||
self.server_ready_event = threading.Event()
|
||||
|
||||
def get_tempfile_name(self):
|
||||
# I'd rather use tempfile.NamedTemporaryFile but IDNA limits
|
||||
# the hostname to 63 characters and we'll get a "InvalidURL:
|
||||
# URL has an invalid label" error if we exceed that.
|
||||
args = (os.stat(__file__).st_ino, os.getpid(), uuid.uuid4().hex[-8:])
|
||||
return '/tmp/test_requests.%s_%s_%s' % args
|
||||
|
||||
def run(self):
|
||||
logger.debug('Call waitress.serve in %r ...', self)
|
||||
wsgi_app = WSGIApp()
|
||||
server = waitress.create_server(wsgi_app, unix_socket=self.usock)
|
||||
wsgi_app.server = server
|
||||
self.server = server
|
||||
self.server_ready_event.set()
|
||||
server.run()
|
||||
|
||||
def __enter__(self):
|
||||
logger.debug('Starting %r ...' % self)
|
||||
self.start()
|
||||
logger.debug('Started %r.', self)
|
||||
self.server_ready_event.wait()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.server_ready_event.wait()
|
||||
if self.server:
|
||||
KillThread(self.server).start()
|
||||
Reference in New Issue
Block a user