Skip to content

Instantly share code, notes, and snippets.

@rsj217
Last active January 20, 2018 10:11
Show Gist options
  • Save rsj217/3bc33adb4fed43795bea8c398362919b to your computer and use it in GitHub Desktop.
Save rsj217/3bc33adb4fed43795bea8c398362919b to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import socket
import sys
import threading
from itertools import chain
from queue import LifoQueue, Full, Empty
import os
from io import BytesIO
class RedisError(Exception):
    """Root of the exception hierarchy used by this client."""


class ConnectionError(RedisError):
    """Connection-level failure.

    NOTE: inside this module the name shadows the builtin ``ConnectionError``.
    """


class BusyLoadingError(ConnectionError):
    """Mapped from server error replies that start with ``LOADING``."""


class InvalidResponse(RedisError):
    """Raised when a reply does not start with a known type marker."""


class ResponseError(RedisError):
    """Generic error reply returned by the server."""


class AuthenticationError(RedisError):
    """Raised when the ``AUTH`` handshake does not answer ``OK``."""


class NoScriptError(ResponseError):
    """Mapped from server error replies that start with ``NOSCRIPT``."""


class ExecAbortError(ResponseError):
    """Error reply for an aborted ``EXEC``."""


class ReadOnlyError(ResponseError):
    """Mapped from server error replies that start with ``READONLY``."""
class Token(object):
    """Wrapper marking a value as a literal command token.

    ``Connection.encode`` sends ``Token.value`` through ``b()`` instead of the
    connection's configured encoding.
    """

    def __init__(self, value):
        # Unwrap nested tokens so Token(Token(x)) stores the raw value.
        self.value = value.value if isinstance(value, Token) else value

    def __repr__(self):
        return self.value

    def __str__(self):
        return self.value
def nativestr(x):
    """Return ``x`` as ``str``; bytes are decoded as UTF-8, replacing invalid bytes."""
    if isinstance(x, str):
        return x
    return x.decode('utf-8', 'replace')
def b(x):
    """Return ``x`` as ``bytes``; anything that is not already bytes is encoded as latin-1."""
    if isinstance(x, bytes):
        return x
    return x.encode('latin-1')
def byte_to_chr(x):
    """Convert an integer code point (e.g. one element of a ``bytes`` object) to a 1-character string."""
    character = chr(x)
    return character
# Protocol framing symbols, kept as bytes so pack_command() can join them
# straight into the outgoing buffer.
SYM_STAR = b'*'      # prefixes the argument count of a command
SYM_DOLLAR = b'$'    # prefixes the byte length of each argument
SYM_CRLF = b'\r\n'   # line terminator
SYM_EMPTY = b''      # join separator
# Message used when recv() returns no data or an empty reply line is read.
SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
def iteritems(x):
    """Return an iterator over the ``(key, value)`` pairs of mapping ``x``."""
    pairs = x.items()
    return iter(pairs)
class BaseParser(object):
    """Shared error-reply handling for response parsers."""

    # Maps the first word of an error reply either to an exception class or to
    # a nested dict that picks the class from the remaining message text.
    EXCEPTION_CLASSES = {
        'ERR': {
            'max number of clients reached': ConnectionError
        },
        # The server prefix is "EXECABORT"; the previous key 'EXCABORT' could
        # never match, so aborted transactions surfaced as plain ResponseError.
        'EXECABORT': ExecAbortError,
        'LOADING': BusyLoadingError,
        'NOSCRIPT': NoScriptError,
        'READONLY': ReadOnlyError
    }

    def parse_error(self, response):
        """Build (not raise) the exception instance for an error reply.

        :param response: error text as ``str``, without the leading ``-``.
        :return: an instance of the class selected via ``EXCEPTION_CLASSES``
            carrying the message without its prefix, or a ``ResponseError``
            carrying the full text when the prefix is unknown.
        """
        error_code, _, message = response.partition(' ')
        exception_class = self.EXCEPTION_CLASSES.get(error_code)
        if exception_class is None:
            return ResponseError(response)
        if isinstance(exception_class, dict):
            # Nested table: select by exact message, else generic error.
            exception_class = exception_class.get(message, ResponseError)
        return exception_class(message)
class SocketBuffer(object):
    """Read buffer that accumulates bytes received from a socket in a ``BytesIO``.

    ``bytes_written`` counts bytes appended from the socket, ``bytes_read``
    counts bytes already handed to the caller.
    """

    def __init__(self, socket, socket_read_size):
        # NOTE: the ``socket`` parameter shadows the module only inside __init__.
        self._sock = socket
        self.socket_read_size = socket_read_size
        self._buffer = BytesIO()
        self.bytes_written = 0
        self.bytes_read = 0

    @property
    def length(self):
        """Number of buffered bytes that have not been consumed yet."""
        return self.bytes_written - self.bytes_read

    def _read_from_socket(self, length=None):
        """Append data received from the socket to the buffer.

        With ``length=None`` a single ``recv`` is issued; otherwise ``recv`` is
        repeated until at least ``length`` new bytes have arrived.

        :raises TimeoutError: when the socket times out.
        :raises ConnectionError: on any other socket error, including the
            peer closing the connection (``recv`` returning no data).
        """
        socket_read_size = self.socket_read_size
        buf = self._buffer
        buf.seek(self.bytes_written)
        # Bytes received during this call.
        received = 0
        try:
            while True:
                data = self._sock.recv(socket_read_size)
                if isinstance(data, bytes) and len(data) == 0:
                    raise socket.error(SERVER_CLOSED_CONNECTION_ERROR)
                buf.write(data)
                data_length = len(data)
                self.bytes_written += data_length
                received += data_length
                # Keep reading until the requested amount is buffered.
                if length is not None and length > received:
                    continue
                break
        except socket.timeout:
            raise TimeoutError("Timeout reading from socket")
        except socket.error as e:
            raise ConnectionError("Error while reading from socket: {}".format(e.args))

    def purge(self):
        """Discard the buffer contents and reset both counters."""
        self._buffer.seek(0)
        self._buffer.truncate()
        self.bytes_read = 0
        self.bytes_written = 0

    def read(self, length):
        """Return ``length`` payload bytes, consuming and dropping the 2 bytes that follow."""
        length = length + 2
        if length > self.length:
            self._read_from_socket(length - self.length)
        self._buffer.seek(self.bytes_read)
        data = self._buffer.read(length)
        self.bytes_read += len(data)
        # Everything consumed: reclaim the buffer.
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]

    def readline(self):
        """Return the next CRLF-terminated line without its terminator."""
        buf = self._buffer
        buf.seek(self.bytes_read)
        data = buf.readline()
        # Incomplete line: pull more data and re-read from the same offset.
        while not data.endswith(SYM_CRLF):
            self._read_from_socket()
            buf.seek(self.bytes_read)
            data = buf.readline()
        self.bytes_read += len(data)
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]

    def close(self):
        """Release the buffer and drop the socket reference (best effort).

        Ordinary errors while discarding the buffer are ignored, but
        ``KeyboardInterrupt``/``SystemExit`` are no longer swallowed. The
        references are cleared in every case.
        """
        try:
            self.purge()
            self._buffer.close()
        except Exception:
            pass
        finally:
            self._buffer = None
            self._sock = None
class PythonParser(BaseParser):
    """Pure-Python reply reader layered on ``SocketBuffer``."""

    # Set from the connection in on_connect() when decode_responses is on;
    # None means bytes replies are returned undecoded.
    encoding = None

    def __init__(self, socket_read_size):
        self.socket_read_size = socket_read_size
        self._sock = None
        self._buffer = None

    def __del__(self):
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        """Bind to the connection's socket and start with a fresh buffer."""
        sock = connection._sock
        self._sock = sock
        self._buffer = SocketBuffer(sock, self.socket_read_size)
        if connection.decode_responses:
            self.encoding = connection.encoding

    def read_response(self):
        """Read one reply and convert it according to its leading type marker.

        Returns the parsed value; error replies that are not
        ``ConnectionError`` are returned as exception instances, not raised.
        """
        line = self._buffer.readline()
        if not line:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
        marker = byte_to_chr(line[0])
        payload = line[1:]
        if marker not in ('-', '+', ':', '$', '*'):
            raise InvalidResponse("Protocol Error: %s, %s" %
                                  (str(marker), str(payload)))

        if marker == '-':
            # Error reply: connection errors are raised right away, anything
            # else is handed back so the caller decides when to raise it.
            error = self.parse_error(nativestr(payload))
            if isinstance(error, ConnectionError):
                raise error
            return error

        if marker == ':':
            # Integer reply.
            payload = int(payload)
        elif marker == '$':
            # Bulk reply; a length of -1 means "no value".
            size = int(payload)
            if size == -1:
                return None
            payload = self._buffer.read(size)
        elif marker == '*':
            # Multi-bulk reply: each element is itself a full reply.
            count = int(payload)
            if count == -1:
                return None
            payload = [self.read_response() for _ in range(count)]
        # '+' (status reply) needs no conversion.

        if isinstance(payload, bytes) and self.encoding:
            payload = payload.decode(self.encoding)
        return payload

    def on_disconnect(self):
        """Close the socket and the buffer and forget the encoding."""
        if self._sock is not None:
            self._sock.close()
            self._sock = None
        if self._buffer is not None:
            self._buffer.close()
            self._buffer = None
        self.encoding = None


# Parser used by Connection unless another parser_class is supplied.
DefaultParser = PythonParser
class Connection(object):
    """A single TCP connection to a Redis server.

    Creates the socket, optionally authenticates and selects a database after
    connecting, packs and sends commands, and reads replies via the parser.
    """

    description_format = "Connection<host=%(host)s,port=%(port)s,db=%(db)s>"

    def __init__(self, host='127.0.0.1', port=6379, db=0, passwd=None,
                 socket_timeout=None, socket_connect_timeout=None,
                 socket_keepalive=False, socket_keepalive_options=None,
                 encoding='utf-8', encoding_errors='strict', decode_responses=False,
                 parser_class=DefaultParser, socket_read_size=65536):
        # Recorded so a pool's release() can compare it with its own pid.
        self.pid = os.getpid()
        self.host = host
        self.port = port
        self.passwd = passwd
        self.db = db
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses
        self.socket_timeout = socket_timeout
        # The connect timeout falls back to the general socket timeout.
        self.socket_connect_timeout = socket_connect_timeout or socket_timeout
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self._sock = None
        self._parser = parser_class(socket_read_size=socket_read_size)
        self._description_args = {
            'host': self.host,
            'port': self.port,
            'db': self.db,
        }

    def __repr__(self):
        return self.description_format % self._description_args

    def __del__(self):
        try:
            self.disconnect()
        except Exception:
            pass

    def encode(self, value):
        """Return ``value`` as bytes.

        Tokens go through ``b()``, bytes pass through untouched, everything
        else is stringified and encoded with the connection's encoding.
        """
        if isinstance(value, Token):
            return b(value.value)
        elif isinstance(value, bytes):
            return value
        elif isinstance(value, int):
            value = b(str(value))
        elif not isinstance(value, str):
            value = str(value)
        if isinstance(value, str):
            value = value.encode(self.encoding, self.encoding_errors)
        return value

    def pack_command(self, *args):
        """Encode a command in the Redis wire format.

        Returns a list of bytes chunks; large commands are split across
        several chunks instead of being joined into one buffer.
        """
        output = []
        command = args[0]
        # A command name containing spaces is split into one Token per word.
        if ' ' in command:
            args = tuple([Token(s) for s in command.split(' ')]) + args[1:]
        else:
            args = (Token(command),) + args[1:]
        buff = SYM_EMPTY.join(
            (SYM_STAR, b(str(len(args))), SYM_CRLF))
        for arg in map(self.encode, args):
            # Large data: emit the pending buffer and the argument as separate
            # chunks rather than copying them into one buffer.
            if len(buff) > 6000 or len(arg) > 6000:
                buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF))
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF, arg, SYM_CRLF))
        output.append(buff)
        return output

    def send_command(self, *args):
        """Pack ``args`` and send them to the server."""
        self.send_packed_command(self.pack_command(*args))

    def _error_message(self, exception):
        """Format a socket error raised while connecting."""
        if len(exception.args) == 1:
            return "Error connecting to {}:{}. {}.".format(self.host, self.port, exception.args[0])
        return "Error {} connecting to {}:{} {}.".format(exception.args[0], self.host, self.port, exception.args[1])

    def _connect(self):
        """Create the TCP socket, connect it and return it.

        Every address returned by ``getaddrinfo`` is tried in order; the last
        socket error is raised when none of them works.
        """
        err = None
        for res in socket.getaddrinfo(self.host, self.port, 0, socket.SOCK_STREAM):
            family, socktype, proto, canonname, socket_address = res
            # Bind the name before the try block: if socket.socket() itself
            # fails, the handler below must not hit an unbound local.
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in iteritems(self.socket_keepalive_options):
                        sock.setsockopt(socket.SOL_TCP, k, v)
                # Timeout that applies while establishing the connection.
                sock.settimeout(self.socket_connect_timeout)
                sock.connect(socket_address)
                # Timeout that applies once connected.
                sock.settimeout(self.socket_timeout)
                return sock
            except socket.error as exc:
                err = exc
                if sock is not None:
                    sock.close()
        if err is not None:
            raise err
        raise socket.error("socket.getaddrinfo returned an empty list")

    def read_response(self):
        """Read one reply; disconnect on any failure and raise error replies."""
        try:
            response = self._parser.read_response()
        except BaseException:
            # Cleanup only; the exception always propagates.
            self.disconnect()
            raise
        if isinstance(response, ResponseError):
            raise response
        return response

    def on_connect(self):
        """Run the post-connect handshake: attach the parser, AUTH, SELECT."""
        self._parser.on_connect(self)
        # if a password is specified, authenticate
        if self.passwd:
            self.send_command('AUTH', self.passwd)
            response = self.read_response()
            if nativestr(response) != 'OK':
                raise AuthenticationError('Invalid Password')
        if self.db:
            self.send_command('SELECT', self.db)
            if nativestr(self.read_response()) != 'OK':
                raise ConnectionError('Invalid Database')

    def connect(self):
        """Connect if not already connected, then run ``on_connect``."""
        if self._sock:
            return
        try:
            sock = self._connect()
        except socket.error:
            e = sys.exc_info()[1]
            raise ConnectionError(self._error_message(e))
        self._sock = sock
        try:
            self.on_connect()
        except RedisError:
            self.disconnect()
            raise

    def disconnect(self):
        """Tear down the parser and the socket; safe to call repeatedly."""
        self._parser.on_disconnect()
        if self._sock is None:
            return
        try:
            # The default parser has already closed this socket object, so
            # shutdown() may raise here; that error is deliberately ignored.
            self._sock.shutdown(socket.SHUT_RDWR)
            self._sock.close()
        except socket.error:
            pass
        self._sock = None

    def send_packed_command(self, command):
        """Send an already packed command (bytes, or an iterable of bytes chunks).

        :raises TimeoutError: when the socket times out while writing.
        :raises ConnectionError: on any other socket error.
        """
        if not self._sock:
            self.connect()
        try:
            # A single buffer must be wrapped; iterating bytes would yield
            # ints, which sendall() rejects.
            if isinstance(command, (bytes, str)):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError('Timeout writing to socket')
        except socket.error:
            e = sys.exc_info()[1]
            self.disconnect()
            if len(e.args) == 1:
                errno, errmsg = 'UNKNOWN', e.args[0]
            else:
                errno = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError("Error {} while writing to socket. {}.".format(errno, errmsg))
        except BaseException:
            self.disconnect()
            raise
class ConnectionPool(object):
    """Non-blocking pool that creates connections on demand, up to ``max_connections``."""

    def __init__(self, connection_class=Connection, max_connections=None, **connection_kwargs):
        # A falsy limit (None or 0) is replaced by 2 ** 31.
        max_connections = max_connections or 2 ** 31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.reset()

    def reset(self):
        """Forget all connections and bind the pool to the current process id."""
        self.pid = os.getpid()
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()
        self._check_lock = threading.Lock()

    def _checkpid(self):
        """Disconnect and reset the pool when the process id has changed."""
        if self.pid != os.getpid():
            with self._check_lock:
                if self.pid == os.getpid():
                    # another thread already did the work while we waited
                    # on the lock.
                    return
                self.disconnect()
                self.reset()

    def make_connection(self):
        """Create a new connection, enforcing ``max_connections``."""
        if self._created_connections >= self.max_connections:
            raise ConnectionError("Too many connections")
        self._created_connections += 1
        return self.connection_class(**self.connection_kwargs)

    def disconnect(self):
        "Disconnects all connections in the pool"
        all_conns = chain(self._available_connections,
                          self._in_use_connections)
        for connection in all_conns:
            connection.disconnect()

    def get_connection(self, command_name, *keys, **options):
        """Return an idle connection, creating one when none is available."""
        self._checkpid()
        try:
            connection = self._available_connections.pop()
        except IndexError:
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def release(self, connection):
        """Return ``connection`` to the pool so ``get_connection`` can reuse it.

        Without this method connections were only ever moved into the in-use
        set, so the pool could never hand one out a second time.
        """
        self._checkpid()
        # Connections that do not record a pid are treated as belonging to
        # the current process.
        if getattr(connection, 'pid', self.pid) != self.pid:
            return
        self._in_use_connections.remove(connection)
        self._available_connections.append(connection)
class BlockingConnectionPool(ConnectionPool):
    """Pool that blocks in ``get_connection`` until a slot is free.

    The queue is pre-filled with ``None`` placeholders; taking a placeholder
    means a new connection may be created, taking a connection means reuse.
    """

    def __init__(self, max_connections=50, timeout=20,
                 connection_class=Connection, queue_class=LifoQueue,
                 **connection_kwargs):
        self.queue_class = queue_class
        self.timeout = timeout
        super(BlockingConnectionPool, self).__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs)

    def reset(self):
        """Rebuild the queue of placeholders and forget known connections."""
        self.pid = os.getpid()
        self._check_lock = threading.Lock()
        # Create and fill up a thread safe queue with ``None`` values.
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break
        # Every connection ever created, so disconnect() can reach them all.
        self._connections = []

    def make_connection(self):
        """Create a connection and remember it for ``disconnect``."""
        connection = self.connection_class(**self.connection_kwargs)
        self._connections.append(connection)
        return connection

    def get_connection(self, command_name, *keys, **options):
        """Take a slot from the queue, waiting up to ``timeout`` seconds.

        :raises ConnectionError: when no slot becomes free in time.
        """
        self._checkpid()
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            raise ConnectionError("No connection available.")
        # A placeholder was taken: create the real connection now.
        if connection is None:
            connection = self.make_connection()
        return connection

    def release(self, connection):
        """Put ``connection`` back into the queue.

        Connections owned by another process are dropped. A connection object
        that does not record a ``pid`` is treated as belonging to this
        process; reading ``connection.pid`` directly raised AttributeError
        for such objects and the slot was never returned.
        """
        self._checkpid()
        if getattr(connection, 'pid', self.pid) != self.pid:
            return
        try:
            self.pool.put_nowait(connection)
        except Full:
            # The queue is already full; there is no slot to return this to.
            pass

    def disconnect(self):
        "Disconnects all connections in the pool."
        for connection in self._connections:
            connection.disconnect()
if __name__ == '__main__':
    import time

    # Demo: five threads compete for a blocking pool limited to three slots.
    p = BlockingConnectionPool(host='127.0.0.1', port=6379, max_connections=3)
    args = ('PING',)
    command_name = args[0]

    class Query(threading.Thread):
        def run(self):
            conn = p.get_connection(command_name, **{})
            print('thread start ', self.getName(), conn.send_command("SET", self.getName(), 1))
            time.sleep(1)
            p.release(conn)
            print('thread end')

    threads = [Query() for i in range(5)]
    for worker in threads:
        worker.start()
    print(len(p._connections))
    print('main end')
    # Keep reporting the number of created connections until interrupted.
    while True:
        print(len(p._connections))
        time.sleep(1)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import socket
import sys
import threading
from itertools import chain
import os
import warnings
from io import BytesIO
class RedisError(Exception):
    """Root of the exception hierarchy used by this client."""


class ConnectionError(RedisError):
    """Connection-level failure.

    NOTE: inside this module the name shadows the builtin ``ConnectionError``.
    """


class BusyLoadingError(ConnectionError):
    """Mapped from server error replies that start with ``LOADING``."""


class InvalidResponse(RedisError):
    """Raised when a reply does not start with a known type marker."""


class ResponseError(RedisError):
    """Generic error reply returned by the server."""


class AuthenticationError(RedisError):
    """Raised when the ``AUTH`` handshake does not answer ``OK``."""


class NoScriptError(ResponseError):
    """Mapped from server error replies that start with ``NOSCRIPT``."""


class ExecAbortError(ResponseError):
    """Error reply for an aborted ``EXEC``."""


class ReadOnlyError(ResponseError):
    """Mapped from server error replies that start with ``READONLY``."""
class Token(object):
    """Wrapper marking a value as a literal command token.

    ``Connection.encode`` sends ``Token.value`` through ``b()`` instead of the
    connection's configured encoding.
    """

    def __init__(self, value):
        # Unwrap nested tokens so Token(Token(x)) stores the raw value.
        self.value = value.value if isinstance(value, Token) else value

    def __repr__(self):
        return self.value

    def __str__(self):
        return self.value
def nativestr(x):
    """Return ``x`` as ``str``; bytes are decoded as UTF-8, replacing invalid bytes."""
    if isinstance(x, str):
        return x
    return x.decode('utf-8', 'replace')
def b(x):
    """Return ``x`` as ``bytes``; anything that is not already bytes is encoded as latin-1."""
    if isinstance(x, bytes):
        return x
    return x.encode('latin-1')
def byte_to_chr(x):
    """Convert an integer code point (e.g. one element of a ``bytes`` object) to a 1-character string."""
    character = chr(x)
    return character
# Protocol framing symbols, kept as bytes so pack_command() can join them
# straight into the outgoing buffer.
SYM_STAR = b'*'      # prefixes the argument count of a command
SYM_DOLLAR = b'$'    # prefixes the byte length of each argument
SYM_CRLF = b'\r\n'   # line terminator
SYM_EMPTY = b''      # join separator
# Message used when recv() returns no data or an empty reply line is read.
SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
def iteritems(x):
    """Return an iterator over the ``(key, value)`` pairs of mapping ``x``."""
    pairs = x.items()
    return iter(pairs)
def string_keys_to_dict(key_string, callback):
    """Map every whitespace-separated name in ``key_string`` to ``callback``."""
    return {name: callback for name in key_string.split()}
def dict_merge(*dicts):
    """Merge the given dicts into a new dict; later arguments override earlier ones."""
    result = dict()
    for mapping in dicts:
        result.update(mapping)
    return result
class BaseParser(object):
    """Shared error-reply handling for response parsers."""

    # Maps the first word of an error reply either to an exception class or to
    # a nested dict that picks the class from the remaining message text.
    EXCEPTION_CLASSES = {
        'ERR': {
            'max number of clients reached': ConnectionError
        },
        # The server prefix is "EXECABORT"; the previous key 'EXCABORT' could
        # never match, so aborted transactions surfaced as plain ResponseError.
        'EXECABORT': ExecAbortError,
        'LOADING': BusyLoadingError,
        'NOSCRIPT': NoScriptError,
        'READONLY': ReadOnlyError
    }

    def parse_error(self, response):
        """Build (not raise) the exception instance for an error reply.

        :param response: error text as ``str``, without the leading ``-``.
        :return: an instance of the class selected via ``EXCEPTION_CLASSES``
            carrying the message without its prefix, or a ``ResponseError``
            carrying the full text when the prefix is unknown.
        """
        error_code, _, message = response.partition(' ')
        exception_class = self.EXCEPTION_CLASSES.get(error_code)
        if exception_class is None:
            return ResponseError(response)
        if isinstance(exception_class, dict):
            # Nested table: select by exact message, else generic error.
            exception_class = exception_class.get(message, ResponseError)
        return exception_class(message)
class SocketBuffer(object):
    """Read buffer that accumulates bytes received from a socket in a ``BytesIO``.

    ``bytes_written`` counts bytes appended from the socket, ``bytes_read``
    counts bytes already handed to the caller.
    """

    def __init__(self, socket, socket_read_size):
        # NOTE: the ``socket`` parameter shadows the module only inside __init__.
        self._sock = socket
        self.socket_read_size = socket_read_size
        self._buffer = BytesIO()
        self.bytes_written = 0
        self.bytes_read = 0

    @property
    def length(self):
        """Number of buffered bytes that have not been consumed yet."""
        return self.bytes_written - self.bytes_read

    def _read_from_socket(self, length=None):
        """Append data received from the socket to the buffer.

        With ``length=None`` a single ``recv`` is issued; otherwise ``recv`` is
        repeated until at least ``length`` new bytes have arrived.

        :raises TimeoutError: when the socket times out.
        :raises ConnectionError: on any other socket error, including the
            peer closing the connection (``recv`` returning no data).
        """
        socket_read_size = self.socket_read_size
        buf = self._buffer
        buf.seek(self.bytes_written)
        # Bytes received during this call.
        received = 0
        try:
            while True:
                data = self._sock.recv(socket_read_size)
                if isinstance(data, bytes) and len(data) == 0:
                    raise socket.error(SERVER_CLOSED_CONNECTION_ERROR)
                buf.write(data)
                data_length = len(data)
                self.bytes_written += data_length
                received += data_length
                # Keep reading until the requested amount is buffered.
                if length is not None and length > received:
                    continue
                break
        except socket.timeout:
            raise TimeoutError("Timeout reading from socket")
        except socket.error as e:
            raise ConnectionError("Error while reading from socket: {}".format(e.args))

    def purge(self):
        """Discard the buffer contents and reset both counters."""
        self._buffer.seek(0)
        self._buffer.truncate()
        self.bytes_read = 0
        self.bytes_written = 0

    def read(self, length):
        """Return ``length`` payload bytes, consuming and dropping the 2 bytes that follow."""
        length = length + 2
        if length > self.length:
            self._read_from_socket(length - self.length)
        self._buffer.seek(self.bytes_read)
        data = self._buffer.read(length)
        self.bytes_read += len(data)
        # Everything consumed: reclaim the buffer.
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]

    def readline(self):
        """Return the next CRLF-terminated line without its terminator."""
        buf = self._buffer
        buf.seek(self.bytes_read)
        data = buf.readline()
        # Incomplete line: pull more data and re-read from the same offset.
        while not data.endswith(SYM_CRLF):
            self._read_from_socket()
            buf.seek(self.bytes_read)
            data = buf.readline()
        self.bytes_read += len(data)
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]

    def close(self):
        """Release the buffer and drop the socket reference (best effort).

        Ordinary errors while discarding the buffer are ignored, but
        ``KeyboardInterrupt``/``SystemExit`` are no longer swallowed. The
        references are cleared in every case.
        """
        try:
            self.purge()
            self._buffer.close()
        except Exception:
            pass
        finally:
            self._buffer = None
            self._sock = None
class PythonParser(BaseParser):
    """Pure-Python reply reader layered on ``SocketBuffer``."""

    # Set from the connection in on_connect() when decode_responses is on;
    # None means bytes replies are returned undecoded.
    encoding = None

    def __init__(self, socket_read_size):
        self.socket_read_size = socket_read_size
        self._sock = None
        self._buffer = None

    def __del__(self):
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        """Bind to the connection's socket and start with a fresh buffer."""
        sock = connection._sock
        self._sock = sock
        self._buffer = SocketBuffer(sock, self.socket_read_size)
        if connection.decode_responses:
            self.encoding = connection.encoding

    def read_response(self):
        """Read one reply and convert it according to its leading type marker.

        Returns the parsed value; error replies that are not
        ``ConnectionError`` are returned as exception instances, not raised.
        """
        line = self._buffer.readline()
        if not line:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
        marker = byte_to_chr(line[0])
        payload = line[1:]
        if marker not in ('-', '+', ':', '$', '*'):
            raise InvalidResponse("Protocol Error: %s, %s" %
                                  (str(marker), str(payload)))

        if marker == '-':
            # Error reply: connection errors are raised right away, anything
            # else is handed back so the caller decides when to raise it.
            error = self.parse_error(nativestr(payload))
            if isinstance(error, ConnectionError):
                raise error
            return error

        if marker == ':':
            # Integer reply.
            payload = int(payload)
        elif marker == '$':
            # Bulk reply; a length of -1 means "no value".
            size = int(payload)
            if size == -1:
                return None
            payload = self._buffer.read(size)
        elif marker == '*':
            # Multi-bulk reply: each element is itself a full reply.
            count = int(payload)
            if count == -1:
                return None
            payload = [self.read_response() for _ in range(count)]
        # '+' (status reply) needs no conversion.

        if isinstance(payload, bytes) and self.encoding:
            payload = payload.decode(self.encoding)
        return payload

    def on_disconnect(self):
        """Close the socket and the buffer and forget the encoding."""
        if self._sock is not None:
            self._sock.close()
            self._sock = None
        if self._buffer is not None:
            self._buffer.close()
            self._buffer = None
        self.encoding = None


# Parser used by Connection unless another parser_class is supplied.
DefaultParser = PythonParser
class Connection(object):
    """A single TCP connection to a Redis server.

    Creates the socket, optionally authenticates and selects a database after
    connecting, packs and sends commands, and reads replies via the parser.
    """

    # Template used by __repr__; it was referenced there but never defined,
    # so repr() of a connection raised AttributeError.
    description_format = "Connection<host=%(host)s,port=%(port)s,db=%(db)s>"

    def __init__(self, host='127.0.0.1', port=6379, db=0, passwd=None,
                 socket_timeout=None, socket_connect_timeout=None,
                 socket_keepalive=False, socket_keepalive_options=None, retry_on_timeout=False,
                 encoding='utf-8', encoding_errors='strict', decode_responses=False,
                 parser_class=DefaultParser, socket_read_size=65536):
        # Compared by ConnectionPool.release() with the pool's own pid.
        self.pid = os.getpid()
        self.host = host
        self.port = int(port)
        self.passwd = passwd
        self.db = db
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses
        self.socket_timeout = socket_timeout
        # The connect timeout falls back to the general socket timeout.
        self.socket_connect_timeout = socket_connect_timeout or socket_timeout
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        # Read by StrictRedis.execute_command to decide whether a timed-out
        # command is retried; previously the argument was silently dropped.
        self.retry_on_timeout = retry_on_timeout
        self._sock = None
        self._parser = parser_class(socket_read_size=socket_read_size)
        self._description_args = {
            'host': self.host,
            'port': self.port,
            'db': self.db,
        }

    def __repr__(self):
        return self.description_format % self._description_args

    def __del__(self):
        try:
            self.disconnect()
        except Exception:
            pass

    def encode(self, value):
        """Return ``value`` as bytes.

        Tokens go through ``b()``, bytes pass through untouched, everything
        else is stringified and encoded with the connection's encoding.
        """
        if isinstance(value, Token):
            return b(value.value)
        elif isinstance(value, bytes):
            return value
        elif isinstance(value, int):
            value = b(str(value))
        elif not isinstance(value, str):
            value = str(value)
        if isinstance(value, str):
            value = value.encode(self.encoding, self.encoding_errors)
        return value

    def pack_command(self, *args):
        """Encode a command in the Redis wire format.

        Returns a list of bytes chunks; large commands are split across
        several chunks instead of being joined into one buffer.
        """
        output = []
        command = args[0]
        # A command name containing spaces is split into one Token per word.
        if ' ' in command:
            args = tuple([Token(s) for s in command.split(' ')]) + args[1:]
        else:
            args = (Token(command),) + args[1:]
        buff = SYM_EMPTY.join(
            (SYM_STAR, b(str(len(args))), SYM_CRLF))
        for arg in map(self.encode, args):
            # Large data: emit the pending buffer and the argument as separate
            # chunks rather than copying them into one buffer.
            if len(buff) > 6000 or len(arg) > 6000:
                buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF))
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF, arg, SYM_CRLF))
        output.append(buff)
        return output

    def send_command(self, *args):
        """Pack ``args`` and send them to the server."""
        self.send_packed_command(self.pack_command(*args))

    def _error_message(self, exception):
        """Format a socket error raised while connecting."""
        if len(exception.args) == 1:
            return "Error connecting to {}:{}. {}.".format(self.host, self.port, exception.args[0])
        return "Error {} connecting to {}:{} {}.".format(exception.args[0], self.host, self.port, exception.args[1])

    def _connect(self):
        """Create the TCP socket, connect it and return it.

        Every address returned by ``getaddrinfo`` is tried in order; the last
        socket error is raised when none of them works.
        """
        err = None
        for res in socket.getaddrinfo(self.host, self.port, 0, socket.SOCK_STREAM):
            family, socktype, proto, canonname, socket_address = res
            # Bind the name before the try block: if socket.socket() itself
            # fails, the handler below must not hit an unbound local.
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in iteritems(self.socket_keepalive_options):
                        sock.setsockopt(socket.SOL_TCP, k, v)
                # Timeout that applies while establishing the connection.
                sock.settimeout(self.socket_connect_timeout)
                sock.connect(socket_address)
                # Timeout that applies once connected.
                sock.settimeout(self.socket_timeout)
                return sock
            except socket.error as exc:
                err = exc
                if sock is not None:
                    sock.close()
        if err is not None:
            raise err
        raise socket.error("socket.getaddrinfo returned an empty list")

    def read_response(self):
        """Read one reply; disconnect on any failure and raise error replies."""
        try:
            response = self._parser.read_response()
        except BaseException:
            # Cleanup only; the exception always propagates.
            self.disconnect()
            raise
        if isinstance(response, ResponseError):
            raise response
        return response

    def on_connect(self):
        """Run the post-connect handshake: attach the parser, AUTH, SELECT."""
        self._parser.on_connect(self)
        # if a password is specified, authenticate
        if self.passwd:
            self.send_command('AUTH', self.passwd)
            response = self.read_response()
            if nativestr(response) != 'OK':
                raise AuthenticationError('Invalid Password')
        if self.db:
            self.send_command('SELECT', self.db)
            if nativestr(self.read_response()) != 'OK':
                raise ConnectionError('Invalid Database')

    def connect(self):
        """Connect if not already connected, then run ``on_connect``."""
        if self._sock:
            return
        try:
            sock = self._connect()
        except socket.error:
            e = sys.exc_info()[1]
            raise ConnectionError(self._error_message(e))
        self._sock = sock
        try:
            self.on_connect()
        except RedisError:
            self.disconnect()
            raise

    def disconnect(self):
        """Tear down the parser and the socket; safe to call repeatedly."""
        self._parser.on_disconnect()
        if self._sock is None:
            return
        try:
            # The default parser has already closed this socket object, so
            # shutdown() may raise here; that error is deliberately ignored.
            self._sock.shutdown(socket.SHUT_RDWR)
            self._sock.close()
        except socket.error:
            pass
        self._sock = None

    def send_packed_command(self, command):
        """Send an already packed command (bytes, or an iterable of bytes chunks).

        :raises TimeoutError: when the socket times out while writing.
        :raises ConnectionError: on any other socket error.
        """
        if not self._sock:
            self.connect()
        try:
            # A single buffer must be wrapped; iterating bytes would yield
            # ints, which sendall() rejects.
            if isinstance(command, (bytes, str)):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError('Timeout writing to socket')
        except socket.error:
            e = sys.exc_info()[1]
            self.disconnect()
            if len(e.args) == 1:
                errno, errmsg = 'UNKNOWN', e.args[0]
            else:
                errno = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError("Error {} while writing to socket. {}.".format(errno, errmsg))
        except BaseException:
            self.disconnect()
            raise
class ConnectionPool(object):
    """Non-blocking pool that creates connections on demand, up to ``max_connections``."""

    def __init__(self, connection_class=Connection, max_connections=None, **connection_kwargs):
        # A falsy limit (None or 0) is replaced by 2 ** 31.
        limit = max_connections or 2 ** 31
        if not isinstance(limit, int) or limit < 0:
            raise ValueError('"max_connections" must be a positive integer')
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = limit
        self.reset()

    def reset(self):
        """Forget all connections and bind the pool to the current process id."""
        self.pid = os.getpid()
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()
        self._check_lock = threading.Lock()

    def _checkpid(self):
        """Disconnect and reset the pool when the process id has changed."""
        if self.pid == os.getpid():
            return
        with self._check_lock:
            # Re-check under the lock: another thread may have reset already.
            if self.pid != os.getpid():
                self.disconnect()
                self.reset()

    def make_connection(self):
        """Create a new connection, enforcing ``max_connections``."""
        if self._created_connections >= self.max_connections:
            raise ConnectionError("Too many connections")
        self._created_connections += 1
        factory = self.connection_class
        return factory(**self.connection_kwargs)

    def get_connection(self, command_name, *keys, **options):
        """Return an idle connection, creating one when none is available."""
        self._checkpid()
        try:
            conn = self._available_connections.pop()
        except IndexError:
            conn = self.make_connection()
        self._in_use_connections.add(conn)
        return conn

    def release(self, connection):
        """Move ``connection`` from the in-use set back to the idle list.

        Connections whose ``pid`` differs from the pool's are ignored.
        """
        self._checkpid()
        if connection.pid == self.pid:
            self._in_use_connections.remove(connection)
            self._available_connections.append(connection)

    def disconnect(self):
        """Disconnect every connection known to the pool, idle or in use."""
        for conn in chain(self._available_connections, self._in_use_connections):
            conn.disconnect()
class StrictRedis(object):
    """Minimal client: runs commands on connections borrowed from a pool."""

    # Per-command post-processing applied to the raw reply.
    RESPONSE_CALLBACKS = dict_merge(
        {
            'PING': lambda r: nativestr(r) == 'PONG',
        }
    )

    def __init__(self, host='localhost', port=6379, db=0, passwd=None,
                 socket_timeout=None, socket_connect_timeout=None,
                 socket_keepalive=None, socket_keepalive_options=None,
                 connection_pool=None, encoding='utf-8', encoding_errors='strict',
                 charset=None, errors=None, decode_responses=False, retry_on_timeout=False, max_connections=None):
        """Create a client; a ``ConnectionPool`` is built unless one is passed in.

        ``charset`` and ``errors`` are deprecated aliases of ``encoding`` and
        ``encoding_errors``.
        """
        if not connection_pool:
            if charset is not None:
                warnings.warn(DeprecationWarning(
                    '"charset" is deprecated. Use "encoding" instead'))
                encoding = charset
            if errors is not None:
                warnings.warn(DeprecationWarning(
                    '"errors" is deprecated. Use "encoding_errors" instead'))
                encoding_errors = errors
            kwargs = {
                # host/port and the connect/keepalive options used to be
                # accepted but never forwarded, so they had no effect.
                'host': host,
                'port': port,
                'db': db,
                'passwd': passwd,
                'socket_timeout': socket_timeout,
                'socket_connect_timeout': socket_connect_timeout,
                'socket_keepalive': socket_keepalive,
                'socket_keepalive_options': socket_keepalive_options,
                'encoding': encoding,
                'encoding_errors': encoding_errors,
                'decode_responses': decode_responses,
                'retry_on_timeout': retry_on_timeout,
                'max_connections': max_connections
            }
            connection_pool = ConnectionPool(**kwargs)
        self.connection_pool = connection_pool
        self.response_callbacks = self.__class__.RESPONSE_CALLBACKS.copy()

    def parse_response(self, connection, command_name, **options):
        """Read the reply and run the callback registered for ``command_name``, if any."""
        response = connection.read_response()
        if command_name in self.response_callbacks:
            return self.response_callbacks[command_name](response, **options)
        return response

    def execute_command(self, *args, **options):
        """Send a command and return its parsed reply.

        On a connection error the command is retried once on a fresh
        connection. Timeouts are retried only when the connection has
        ``retry_on_timeout`` set. The connection is always released.
        """
        pool = self.connection_pool
        command_name = args[0]
        connection = pool.get_connection(command_name, **options)
        try:
            connection.send_command(*args)
            return self.parse_response(connection, command_name, **options)
        except (ConnectionError, TimeoutError) as e:
            connection.disconnect()
            # Connections that do not expose the flag are treated as
            # "do not retry on timeout".
            retry_on_timeout = getattr(connection, 'retry_on_timeout', False)
            if not retry_on_timeout and isinstance(e, TimeoutError):
                raise
            connection.send_command(*args)
            return self.parse_response(connection, command_name, **options)
        finally:
            pool.release(connection)

    def ping(self, *args, **options):
        """Send ``PING``; positional and keyword arguments are ignored."""
        return self.execute_command("PING")

    def get(self, *args, **options):
        """Forward ``args`` unchanged to ``execute_command``.

        The command name is not prepended here, so callers pass it themselves,
        e.g. ``rc.get("GET", "hello")``.
        """
        return self.execute_command(*args, **options)
if __name__ == '__main__':
    # Smoke test against a local server: raw commands, then the helpers.
    rc = StrictRedis()
    print(rc.execute_command("PING"))
    print(rc.execute_command("GET", "hello"))
    print(rc.ping())
    print(rc.get("GET", "hello"))
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import socket
import sys
import threading
from itertools import chain
import os
from io import BytesIO
class RedisError(Exception):
    """Root of the exception hierarchy used by this client."""


class ConnectionError(RedisError):
    """Connection-level failure.

    NOTE: inside this module the name shadows the builtin ``ConnectionError``.
    """


class BusyLoadingError(ConnectionError):
    """Mapped from server error replies that start with ``LOADING``."""


class InvalidResponse(RedisError):
    """Raised when a reply does not start with a known type marker."""


class ResponseError(RedisError):
    """Generic error reply returned by the server."""


class AuthenticationError(RedisError):
    """Authentication failure."""


class NoScriptError(ResponseError):
    """Mapped from server error replies that start with ``NOSCRIPT``."""


class ExecAbortError(ResponseError):
    """Error reply for an aborted ``EXEC``."""


class ReadOnlyError(ResponseError):
    """Mapped from server error replies that start with ``READONLY``."""
class Token(object):
    """Wrapper marking a value as a literal command token."""

    def __init__(self, value):
        # Unwrap nested tokens so Token(Token(x)) stores the raw value.
        self.value = value.value if isinstance(value, Token) else value

    def __repr__(self):
        return self.value

    def __str__(self):
        return self.value
def nativestr(x):
    """Return ``x`` as ``str``; bytes are decoded as UTF-8, replacing invalid bytes."""
    if isinstance(x, str):
        return x
    return x.decode('utf-8', 'replace')
def b(x):
    """Return ``x`` as ``bytes``; anything that is not already bytes is encoded as latin-1."""
    if isinstance(x, bytes):
        return x
    return x.encode('latin-1')
def byte_to_chr(x):
    """Convert an integer code point (e.g. one element of a ``bytes`` object) to a 1-character string."""
    character = chr(x)
    return character
# Protocol framing symbols and messages, kept as module-level constants.
SYM_STAR = b'*'
SYM_DOLLAR = b'$'
SYM_CRLF = b'\r\n'   # line terminator checked by SocketBuffer.readline()
SYM_EMPTY = b''
# Message used when recv() returns no data or an empty reply line is read.
SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
def iteritems(x):
    """Return an iterator over the ``(key, value)`` pairs of mapping ``x``."""
    pairs = x.items()
    return iter(pairs)
class BaseParser(object):
    """Shared error-reply handling for response parsers."""

    # Maps the first word of an error reply either to an exception class or to
    # a nested dict that picks the class from the remaining message text.
    EXCEPTION_CLASSES = {
        'ERR': {
            'max number of clients reached': ConnectionError
        },
        # The server prefix is "EXECABORT"; the previous key 'EXCABORT' could
        # never match, so aborted transactions surfaced as plain ResponseError.
        'EXECABORT': ExecAbortError,
        'LOADING': BusyLoadingError,
        'NOSCRIPT': NoScriptError,
        'READONLY': ReadOnlyError
    }

    def parse_error(self, response):
        """Build (not raise) the exception instance for an error reply.

        :param response: error text as ``str``, without the leading ``-``.
        :return: an instance of the class selected via ``EXCEPTION_CLASSES``
            carrying the message without its prefix, or a ``ResponseError``
            carrying the full text when the prefix is unknown.
        """
        error_code, _, message = response.partition(' ')
        exception_class = self.EXCEPTION_CLASSES.get(error_code)
        if exception_class is None:
            return ResponseError(response)
        if isinstance(exception_class, dict):
            # Nested table: select by exact message, else generic error.
            exception_class = exception_class.get(message, ResponseError)
        return exception_class(message)
class SocketBuffer(object):
    """Read buffer that accumulates bytes received from a socket in a ``BytesIO``.

    ``bytes_written`` counts bytes appended from the socket, ``bytes_read``
    counts bytes already handed to the caller.
    """

    def __init__(self, socket, socket_read_size):
        # NOTE: the ``socket`` parameter shadows the module only inside __init__.
        self._sock = socket
        self.socket_read_size = socket_read_size
        self._buffer = BytesIO()
        self.bytes_written = 0
        self.bytes_read = 0

    @property
    def length(self):
        """Number of buffered bytes that have not been consumed yet."""
        return self.bytes_written - self.bytes_read

    def _read_from_socket(self, length=None):
        """Append data received from the socket to the buffer.

        With ``length=None`` a single ``recv`` is issued; otherwise ``recv`` is
        repeated until at least ``length`` new bytes have arrived.

        :raises TimeoutError: when the socket times out.
        :raises ConnectionError: on any other socket error, including the
            peer closing the connection (``recv`` returning no data).
        """
        socket_read_size = self.socket_read_size
        buf = self._buffer
        buf.seek(self.bytes_written)
        # Bytes received during this call.
        received = 0
        try:
            while True:
                data = self._sock.recv(socket_read_size)
                if isinstance(data, bytes) and len(data) == 0:
                    raise socket.error(SERVER_CLOSED_CONNECTION_ERROR)
                buf.write(data)
                data_length = len(data)
                self.bytes_written += data_length
                received += data_length
                # Keep reading until the requested amount is buffered.
                if length is not None and length > received:
                    continue
                break
        except socket.timeout:
            raise TimeoutError("Timeout reading from socket")
        except socket.error as e:
            raise ConnectionError("Error while reading from socket: {}".format(e.args))

    def purge(self):
        """Discard the buffer contents and reset both counters."""
        self._buffer.seek(0)
        self._buffer.truncate()
        self.bytes_read = 0
        self.bytes_written = 0

    def read(self, length):
        """Return ``length`` payload bytes, consuming and dropping the 2 bytes that follow."""
        length = length + 2
        if length > self.length:
            self._read_from_socket(length - self.length)
        self._buffer.seek(self.bytes_read)
        data = self._buffer.read(length)
        self.bytes_read += len(data)
        # Everything consumed: reclaim the buffer.
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]

    def readline(self):
        """Return the next CRLF-terminated line without its terminator."""
        buf = self._buffer
        buf.seek(self.bytes_read)
        data = buf.readline()
        # Incomplete line: pull more data and re-read from the same offset.
        while not data.endswith(SYM_CRLF):
            self._read_from_socket()
            buf.seek(self.bytes_read)
            data = buf.readline()
        self.bytes_read += len(data)
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]

    def close(self):
        """Release the buffer and drop the socket reference (best effort).

        Ordinary errors while discarding the buffer are ignored, but
        ``KeyboardInterrupt``/``SystemExit`` are no longer swallowed. The
        references are cleared in every case.
        """
        try:
            self.purge()
            self._buffer.close()
        except Exception:
            pass
        finally:
            self._buffer = None
            self._sock = None
class PythonParser(BaseParser):
    """Pure-Python parser for replies read from a Redis connection."""

    # Codec applied to bytes replies; None means replies are returned as bytes.
    encoding = None

    def __init__(self, socket_read_size):
        """Remember the per-``recv`` read size; nothing is attached yet."""
        self._buffer = None
        self._sock = None
        self.socket_read_size = socket_read_size

    def __del__(self):
        # Best-effort cleanup; never let a destructor raise.
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        """Attach to *connection*'s socket and create a fresh read buffer."""
        sock = connection._sock
        self._sock = sock
        self._buffer = SocketBuffer(sock, self.socket_read_size)
        if connection.decode_responses:
            self.encoding = connection.encoding

    def read_response(self):
        """Read and parse one reply.

        Returns bytes (or ``str`` when an encoding is set), an ``int``, a
        list of nested replies, ``None`` for a ``-1`` length, or an exception
        instance for a server error that is not a ``ConnectionError``.

        Raises:
            ConnectionError: the reply line was empty, or the server error
                maps to a ``ConnectionError``.
            InvalidResponse: the reply starts with an unknown type byte.
        """
        line = self._buffer.readline()
        if not line:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        marker = byte_to_chr(line[0])
        payload = line[1:]
        if marker not in ('-', '+', ':', '$', '*'):
            raise InvalidResponse("Protocol Error: %s, %s" %
                                  (str(marker), str(payload)))

        if marker == '-':
            # Server-reported error.
            error = self.parse_error(nativestr(payload))
            # Connection-level errors are raised right away so the caller is
            # notified. Anything else is handed back as an instance, because
            # it may belong inside a pipeline response and the caller decides
            # whether to raise it.
            if isinstance(error, ConnectionError):
                raise error
            return error

        if marker == ':':
            # Integer reply.
            return int(payload)

        if marker == '*':
            # Multi-bulk reply: a count followed by that many nested replies.
            count = int(payload)
            if count == -1:
                return None
            return [self.read_response() for _ in range(count)]

        if marker == '$':
            # Bulk reply: a byte count followed by the payload itself.
            size = int(payload)
            if size == -1:
                return None
            payload = self._buffer.read(size)

        # Status ('+') and bulk ('$') replies are bytes at this point.
        if isinstance(payload, bytes) and self.encoding:
            payload = payload.decode(self.encoding)
        return payload

    def on_disconnect(self):
        """Close the socket and buffer if present, and reset the encoding."""
        sock = self._sock
        if sock is not None:
            sock.close()
            self._sock = None
        buffer = self._buffer
        if buffer is not None:
            buffer.close()
            self._buffer = None
        self.encoding = None
# Parser class used as the default value of Connection's ``parser_class`` argument.
DefaultParser = PythonParser
class Connection(object):
    """Manage one TCP connection to a Redis server.

    Covers opening the socket, the optional ``AUTH``/``SELECT`` handshake,
    encoding commands into the wire format and reading parsed replies.
    """

    # Template for ``__repr__``; interpolated with ``self._description_args``.
    description_format = "Connection<host=%(host)s,port=%(port)s,db=%(db)s>"

    def __init__(self, host='127.0.0.1', port=6379, db=0, passwd=None,
                 socket_timeout=None, socket_connect_timeout=None,
                 socket_keepalive=False, socket_keepalive_options=None,
                 encoding='utf-8', encoding_errors='strict', decode_responses=False,
                 parser_class=DefaultParser, socket_read_size=65536):
        """Store the connection settings; no socket is opened here.

        Args:
            host: Server host name or address.
            port: Server TCP port.
            db: Database selected with ``SELECT`` after connecting (if truthy).
            passwd: Password sent with ``AUTH`` after connecting (if truthy).
            socket_timeout: Timeout set on the socket once it is connected.
            socket_connect_timeout: Timeout used while connecting; falls back
                to ``socket_timeout`` when not given.
            socket_keepalive: Enable ``SO_KEEPALIVE`` on the socket.
            socket_keepalive_options: Mapping of TCP-level option to value,
                applied only when keepalive is enabled.
            encoding: Codec used to encode ``str`` command arguments.
            encoding_errors: Error handler passed along with ``encoding``.
            decode_responses: Stored for the parser to consult on connect.
            parser_class: Called with ``socket_read_size`` to build the parser.
            socket_read_size: Forwarded to ``parser_class``.
        """
        self.host = host
        self.port = port
        self.passwd = passwd
        self.db = db
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses
        self.socket_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout or socket_timeout
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self._sock = None
        self._parser = parser_class(socket_read_size=socket_read_size)
        # Values interpolated into ``description_format`` by ``__repr__``.
        self._description_args = {
            'host': self.host,
            'port': self.port,
            'db': self.db,
        }

    def __repr__(self):
        return self.description_format % self._description_args

    def __del__(self):
        # Best-effort cleanup; never let a destructor raise.
        try:
            self.disconnect()
        except Exception:
            pass

    def encode(self, value):
        """Return *value* as bytes suitable for sending to the server.

        Tokens and ints are encoded through ``b()``, bytes pass through
        unchanged, and everything else is converted with ``str()`` and then
        encoded using the connection's encoding settings.
        """
        if isinstance(value, Token):
            return b(value.value)
        elif isinstance(value, bytes):
            return value
        elif isinstance(value, int):
            value = b(str(value))
        elif not isinstance(value, str):
            value = str(value)
        if isinstance(value, str):
            value = value.encode(self.encoding, self.encoding_errors)
        return value

    def pack_command(self, *args):
        """Encode a command in the Redis wire format.

        Returns a list of byte chunks. Small commands end up in a single
        chunk; when the pending chunk or an argument exceeds 6000 bytes the
        argument is emitted as its own chunk instead of being copied.
        """
        output = []
        command = args[0]
        # A command such as 'CONFIG SET' is sent as separate words.
        if ' ' in command:
            args = tuple([Token(s) for s in command.split(' ')]) + args[1:]
        else:
            args = (Token(command),) + args[1:]
        buff = SYM_EMPTY.join(
            (SYM_STAR, b(str(len(args))), SYM_CRLF))
        for arg in map(self.encode, args):
            if len(buff) > 6000 or len(arg) > 6000:
                # Flush the header, then the large argument on its own.
                buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF))
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF, arg, SYM_CRLF))
        output.append(buff)
        return output

    def send_command(self, *args):
        """Pack *args* as a command and send it to the server."""
        self.send_packed_command(self.pack_command(*args))

    def _error_message(self, exception):
        """Format a readable message for a socket error raised while connecting."""
        if len(exception.args) == 1:
            return "Error connecting to {}:{}. {}.".format(self.host, self.port, exception.args[0])
        return "Error {} connecting to {}:{} {}.".format(exception.args[0], self.host, self.port, exception.args[1])

    def _connect(self):
        """Create a connected TCP socket and return it.

        Each address returned by ``getaddrinfo`` is tried in turn. The last
        socket error is re-raised when none of them can be connected.
        """
        err = None
        for res in socket.getaddrinfo(self.host, self.port, 0, socket.SOCK_STREAM):
            family, socktype, proto, canonname, socket_address = res
            # Reset per attempt so the error path never sees an unbound name
            # or a socket left over from the previous address.
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in iteritems(self.socket_keepalive_options):
                        sock.setsockopt(socket.SOL_TCP, k, v)
                # Timeout that applies while establishing the connection.
                sock.settimeout(self.socket_connect_timeout)
                sock.connect(socket_address)
                # Timeout that applies to traffic after connecting.
                sock.settimeout(self.socket_timeout)
                return sock
            except socket.error as exc:
                err = exc
                if sock is not None:
                    sock.close()
        if err is not None:
            raise err
        raise socket.error("socket.getaddrinfo returned an empty list")

    def read_response(self):
        """Read one reply from the parser.

        Any failure while reading disconnects before re-raising. A reply
        that is a ``ResponseError`` instance is raised instead of returned.
        """
        try:
            response = self._parser.read_response()
        except BaseException:
            # The stream state is unknown after any failure, interrupts included.
            self.disconnect()
            raise
        if isinstance(response, ResponseError):
            raise response
        return response

    def on_connect(self):
        """Initialise a freshly connected socket: authenticate and select the database.

        Raises:
            AuthenticationError: the reply to ``AUTH`` was not ``OK``.
            ConnectionError: the reply to ``SELECT`` was not ``OK``.
        """
        self._parser.on_connect(self)
        # Authenticate only when a password was configured.
        if self.passwd:
            self.send_command('AUTH', self.passwd)
            response = self.read_response()
            if nativestr(response) != 'OK':
                raise AuthenticationError('Invalid Password')
        if self.db:
            self.send_command('SELECT', self.db)
            if nativestr(self.read_response()) != 'OK':
                raise ConnectionError('Invalid Database')

    def connect(self):
        """Open the connection if it is not already open.

        Raises:
            ConnectionError: the socket could not be connected.
        """
        if self._sock:
            return
        try:
            sock = self._connect()
        except socket.error as e:
            raise ConnectionError(self._error_message(e))
        self._sock = sock
        try:
            self.on_connect()
        except RedisError:
            # The handshake failed: do not keep a half-initialised connection.
            self.disconnect()
            raise

    def disconnect(self):
        """Tear down the parser state and close the socket, ignoring socket errors."""
        self._parser.on_disconnect()
        if self._sock is None:
            return
        try:
            self._sock.shutdown(socket.SHUT_RDWR)
            self._sock.close()
        except socket.error:
            pass
        self._sock = None

    def send_packed_command(self, command):
        """Send an already packed command to the server.

        *command* is either a single packed chunk or an iterable of chunks
        as returned by ``pack_command``. The connection is opened on demand
        and dropped again on any failure.

        Raises:
            TimeoutError: writing to the socket timed out.
            ConnectionError: any other socket error while writing.
        """
        if not self._sock:
            self.connect()
        try:
            # A single chunk must be wrapped, otherwise iterating bytes would
            # hand integers to ``sendall``.
            if isinstance(command, (bytes, str)):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError('Timeout writing to socket')
        except socket.error as e:
            self.disconnect()
            if len(e.args) == 1:
                errno, errmsg = 'UNKNOWN', e.args[0]
            else:
                errno = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError("Error {} while writing to socket. {}.".format(errno, errmsg))
        except BaseException:
            # Drop the connection on anything else, interrupts included.
            self.disconnect()
            raise
class ConnectionPool(object):
    """Create and track connections that share one set of constructor options."""

    def __init__(self, connection_class=Connection, max_connections=None, **connection_kwargs):
        """Configure the pool.

        Args:
            connection_class: Callable used to build each connection.
            max_connections: Upper bound on created connections; a falsy
                value means ``2 ** 31``.
            **connection_kwargs: Passed to ``connection_class`` unchanged.

        Raises:
            ValueError: ``max_connections`` is not an int or is negative.
        """
        limit = max_connections or 2 ** 31
        if not isinstance(limit, int) or limit < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = limit
        self.reset()

    def reset(self):
        """Forget every tracked connection and record the current process id."""
        self._check_lock = threading.Lock()
        self._in_use_connections = set()
        self._available_connections = []
        self._created_connections = 0
        self.pid = os.getpid()

    def _checkpid(self):
        """Disconnect and reset the pool when running under a different pid."""
        if self.pid == os.getpid():
            return
        with self._check_lock:
            # Another thread may have done the work while this one waited
            # for the lock.
            if self.pid == os.getpid():
                return
            self.disconnect()
            self.reset()

    def make_connection(self):
        """Build a new connection, enforcing ``max_connections``.

        Raises:
            ConnectionError: the connection limit has been reached.
        """
        if self._created_connections >= self.max_connections:
            raise ConnectionError("Too many connections")
        self._created_connections += 1
        return self.connection_class(**self.connection_kwargs)

    def disconnect(self):
        """Disconnect every connection the pool knows about."""
        tracked = chain(self._available_connections, self._in_use_connections)
        for conn in tracked:
            conn.disconnect()

    def get_connection(self, command_name, *keys, **options):
        """Return an idle connection, or a new one if none is available.

        The returned connection is recorded as in use. The arguments are
        accepted but not used.
        """
        self._checkpid()
        try:
            conn = self._available_connections.pop()
        except IndexError:
            conn = self.make_connection()
        self._in_use_connections.add(conn)
        return conn
if __name__ == '__main__':
    # Manual smoke test against a server at 127.0.0.1:6379: take a
    # connection from the pool, send ``GET hello`` and print the reply.
    p = ConnectionPool(port=6379, host='127.0.0.1')
    args = ('PING',)
    command_name = args[0]
    conn = p.get_connection(command_name)
    conn.send_command("GET", "hello")
    reply = conn.read_response()
    print(reply)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from io import BytesIO
class RedisError(Exception):
    """Base error type for the Redis protocol code in this module."""
def nativestr(x):
    """Return *x* as ``str``.

    ``str`` input is returned unchanged; anything else is decoded as UTF-8
    with undecodable bytes replaced.
    """
    if isinstance(x, str):
        return x
    return x.decode('utf-8', 'replace')
def b(x):
    """Return *x* as ``bytes``; non-bytes input is encoded with latin-1."""
    if isinstance(x, bytes):
        return x
    return x.encode('latin-1')
def byte_to_chr(x):
    """Return the one-character string for integer code *x* (e.g. an indexed byte)."""
    char = chr(x)
    return char
# Byte constants for the wire format, in order: the '*' and '$' prefix
# bytes, the CRLF line terminator, and an empty separator for joining.
SYM_STAR, SYM_DOLLAR, SYM_CRLF, SYM_EMPTY = map(b, ('*', '$', '\r\n', ''))
def iteritems(x):
    """Return an iterator over the ``(key, value)`` pairs of mapping *x*."""
    pairs = x.items()
    return iter(pairs)
class Socket(object):
    """In-memory stand-in for a socket that serves ``recv`` from fixed data."""

    def __init__(self, data):
        """Hold *data* as the bytes still to be delivered."""
        self.data = data

    def recv(self, length):
        """Return up to *length* leading bytes of the remaining data and drop them."""
        chunk, self.data = self.data[:length], self.data[length:]
        return chunk
class SocketBuffer(object):
    """Read buffer layered over a socket-like object, backed by ``BytesIO``.

    ``bytes_written`` counts bytes appended from the source and
    ``bytes_read`` counts bytes already returned to the caller. Whenever the
    two meet, the buffer is truncated so consumed data is not kept around.
    """

    def __init__(self, socket, socket_read_size):
        """Wrap *socket*, requesting *socket_read_size* bytes per ``recv``."""
        self._sock = socket
        self.socket_read_size = socket_read_size
        self._buffer = BytesIO()
        self.bytes_written = 0
        self.bytes_read = 0

    @property
    def length(self):
        """Number of buffered bytes that have not been consumed yet."""
        return self.bytes_written - self.bytes_read

    def _read_from_socket(self, length=None):
        """Append data received from the source to the end of the buffer.

        With ``length=None`` a single ``recv`` is performed. Otherwise
        ``recv`` is repeated until at least *length* new bytes have arrived.

        Raises:
            ConnectionError: ``recv`` returned no data, meaning the source
                is exhausted or closed.
        """
        buf = self._buffer
        buf.seek(self.bytes_written)
        # Bytes received during this call.
        marker = 0
        while True:
            data = self._sock.recv(self.socket_read_size)
            if not data:
                # Without this check an exhausted source would make the
                # callers loop forever waiting for bytes that never arrive.
                raise ConnectionError("Connection closed by server.")
            buf.write(data)
            data_length = len(data)
            self.bytes_written += data_length
            marker += data_length
            # Keep reading until the requested amount is buffered.
            if length is not None and length > marker:
                continue
            break

    def purge(self):
        """Discard all buffered data and reset both counters."""
        self._buffer.seek(0)
        self._buffer.truncate()
        self.bytes_read = 0
        self.bytes_written = 0

    def read(self, length):
        """Return *length* payload bytes, consuming the CRLF that follows them."""
        # Account for the two-byte line terminator after the payload.
        length = length + 2
        if length > self.length:
            self._read_from_socket(length - self.length)
        self._buffer.seek(self.bytes_read)
        data = self._buffer.read(length)
        self.bytes_read += len(data)
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]

    def readline(self):
        """Return the next CRLF-terminated line without its trailing CRLF."""
        buf = self._buffer
        buf.seek(self.bytes_read)
        data = buf.readline()
        # The line may be incomplete; pull more data until it ends in CRLF.
        while not data.endswith(SYM_CRLF):
            self._read_from_socket()
            buf.seek(self.bytes_read)
            data = buf.readline()
        self.bytes_read += len(data)
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]
class PythonParser(object):
    """Stand-alone reply parser that reads from in-memory bytes."""

    # Codec applied to bytes replies; None means replies are returned as bytes.
    encoding = None

    def __init__(self, socket_read_size):
        """Remember the per-``recv`` read size; nothing is attached yet."""
        self._buffer = None
        self._sock = None
        self.socket_read_size = socket_read_size

    def on_connect(self, data):
        """Wrap *data* in a ``Socket`` stand-in and create a fresh read buffer."""
        source = Socket(data)
        self._sock = source
        self._buffer = SocketBuffer(source, self.socket_read_size)

    def read_response(self):
        """Read and parse one reply.

        Returns bytes (or ``str`` when an encoding is set), an ``int``, a
        list of nested replies, ``None`` for a ``-1`` length, or the error
        text as ``str`` for an error reply.

        Raises:
            RedisError: the reply starts with an unknown type byte.
        """
        line = self._buffer.readline()
        marker = byte_to_chr(line[0])
        payload = line[1:]
        if marker not in ('-', '+', ':', '$', '*'):
            raise RedisError

        if marker == '-':
            # Error reply: handed back as text, not raised.
            return nativestr(payload)

        if marker == ':':
            # Integer reply.
            return int(payload)

        if marker == '*':
            # Multi-bulk reply: a count followed by that many nested replies.
            count = int(payload)
            if count == -1:
                return None
            return [self.read_response() for _ in range(count)]

        if marker == '$':
            # Bulk reply: a byte count followed by the payload itself.
            size = int(payload)
            if size == -1:
                return None
            payload = self._buffer.read(size)

        # Status ('+') and bulk ('$') replies are bytes at this point.
        if isinstance(payload, bytes) and self.encoding:
            payload = payload.decode(self.encoding)
        return payload
if __name__ == '__main__':
    # Demo: parse a hard-coded reply using a deliberately small read size.
    # The wire form of the bulk string "foobar" is $6\r\nfoobar\r\n.
    # Other payloads to try:
    #   b'+OK\r\n'
    #   b'*3\r\n$3\r\n777\r\n$6\r\n\xe4\xbd\xa0\xe5\xa5\xbd\r\n$5\r\nhello\r\n'
    #   (a three-element multi-bulk reply: 777, two UTF-8 encoded CJK
    #   characters, hello)
    payload = b'$6\r\nfoobar\r\n'
    parser = PythonParser(socket_read_size=5)
    parser.on_connect(payload)
    print(parser.read_response())
#!/usr/bin/env python
# -*- coding:utf-8 -*-
class Token(object):
    """Thin wrapper around a value whose ``str``/``repr`` is the value itself."""

    def __init__(self, value):
        # Unwrap so that wrapping a Token in a Token does not nest.
        self.value = value.value if isinstance(value, Token) else value

    def __str__(self):
        return self.value

    def __repr__(self):
        return self.value
def b(x):
    """Return *x* as ``bytes``; non-bytes input is encoded with latin-1."""
    already_bytes = isinstance(x, bytes)
    return x if already_bytes else x.encode('latin-1')
# Byte constants used by pack_command, in order: the '*' prefix of the
# argument count, the '$' prefix of each argument length, the CRLF line
# terminator, and an empty separator for joining pieces.
SYM_STAR, SYM_DOLLAR, SYM_CRLF, SYM_EMPTY = map(b, ('*', '$', '\r\n', ''))
class Connection(object):
    """Stand-alone command encoder; no socket is ever opened by this class."""

    def __init__(self, encoding='utf-8', encoding_errors='strict'):
        """Store the codec and error handler used for ``str`` arguments."""
        self._sock = None
        self.encoding = encoding
        self.encoding_errors = encoding_errors

    def pack_command(self, *args):
        """Encode a command in the Redis wire format.

        Returns a list of byte chunks. Small commands end up in a single
        chunk; when the pending chunk or an argument exceeds 6000 bytes the
        argument is emitted as its own chunk instead of being copied.
        """
        command = args[0]
        # A command such as 'CONFIG SET' is sent as separate words.
        if ' ' in command:
            head = tuple(Token(word) for word in command.split(' '))
        else:
            head = (Token(command),)
        args = head + args[1:]

        chunks = []
        pending = SYM_EMPTY.join((SYM_STAR, b(str(len(args))), SYM_CRLF))
        for arg in map(self.encode, args):
            size_header = SYM_EMPTY.join((SYM_DOLLAR, b(str(len(arg))), SYM_CRLF))
            if len(pending) > 6000 or len(arg) > 6000:
                # Flush the header, then the large argument on its own.
                chunks.append(SYM_EMPTY.join((pending, size_header)))
                chunks.append(arg)
                pending = SYM_CRLF
            else:
                pending = SYM_EMPTY.join((pending, size_header, arg, SYM_CRLF))
        chunks.append(pending)
        return chunks

    def encode(self, value):
        """Return *value* as bytes suitable for sending to the server.

        Tokens and ints are encoded through ``b()``, bytes pass through
        unchanged, and everything else is converted with ``str()`` and then
        encoded using the configured encoding settings.
        """
        if isinstance(value, Token):
            return b(value.value)
        if isinstance(value, bytes):
            return value
        if isinstance(value, int):
            return b(str(value))
        if not isinstance(value, str):
            value = str(value)
        return value.encode(self.encoding, self.encoding_errors)
if __name__ == '__main__':
    # Demo: pack one command and print the resulting chunks.
    # Other argument tuples to try:
    #   ('PING',)
    #   ('GET', 'hello')
    #   ('SET', 'hello', 'world')
    #   ('SET', 'key', '中国')
    #   ('SET', 'hello', 1)
    #   ('SET', '中国', 21.7)
    #   ('AUTH', '')
    args = ('CONFIG SET', 'timeout', 0)
    packer = Connection()
    command = packer.pack_command(*args)
    print(command)
    # print(type(command))
    #
    # To send the packed command to a live server:
    #
    # import socket
    #
    # address = ('127.0.0.1', 6379)
    # sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # sock.connect(address)
    # sock.sendall(command[0])
    # print(sock.recv(1024))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment