Last active
January 20, 2018 10:11
-
-
Save rsj217/3bc33adb4fed43795bea8c398362919b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding:utf-8 -*- | |
import socket | |
import sys | |
import threading | |
from itertools import chain | |
from queue import LifoQueue, Full, Empty | |
import os | |
from io import BytesIO | |
class RedisError(Exception):
    """Root of the Redis-specific exceptions defined in this module."""


class ConnectionError(RedisError):
    """Socket-level failure or server-side close (shadows the builtin ConnectionError)."""


class BusyLoadingError(ConnectionError):
    """Built for a ``LOADING`` error reply."""


class InvalidResponse(RedisError):
    """Raised when a reply starts with an unknown protocol prefix byte."""


class ResponseError(RedisError):
    """An error reply sent back by the server."""


class AuthenticationError(RedisError):
    """Raised when ``AUTH`` does not answer ``OK`` during connection setup."""


class NoScriptError(ResponseError):
    """Built for a ``NOSCRIPT`` error reply."""


class ExecAbortError(ResponseError):
    """Built for an ``EXCABORT`` error reply."""


class ReadOnlyError(ResponseError):
    """Built for a ``READONLY`` error reply."""
class Token(object):
    """A literal command word (e.g. ``GET``) to be sent verbatim.

    ``Connection.encode`` converts a Token to bytes with latin-1 rather than
    the connection's configured encoding.  Wrapping a Token re-uses its
    ``value`` instead of nesting.
    """

    def __init__(self, value):
        self.value = value.value if isinstance(value, Token) else value

    def __str__(self):
        return self.value

    __repr__ = __str__
def nativestr(x):
    """Return ``x`` as str; bytes are decoded as UTF-8 with invalid bytes replaced."""
    if isinstance(x, str):
        return x
    return x.decode('utf-8', 'replace')
def b(x):
    """Return ``x`` as bytes: bytes pass through, anything else is latin-1 encoded."""
    if isinstance(x, bytes):
        return x
    return x.encode('latin-1')
def byte_to_chr(x):
    """Turn an integer byte value (such as ``data[0]`` of a bytes object) into a 1-char str."""
    character = chr(x)
    return character
# Protocol framing markers, kept as bytes so they can be joined with encoded payloads.
SYM_STAR = b'*'      # prefix of the argument-count header built by pack_command
SYM_DOLLAR = b'$'    # prefix of each argument's length header
SYM_CRLF = b'\r\n'   # line terminator
SYM_EMPTY = b''      # joiner used to concatenate packet pieces
# Message used when the server closes the connection (empty recv / empty reply line).
SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
def iteritems(x):
    """Return an iterator over the (key, value) pairs of mapping ``x``."""
    pairs = x.items()
    return iter(pairs)
class BaseParser(object):
    """Shared error-reply handling for protocol parsers."""

    # First word of an error reply -> exception class to build.  A dict value
    # refines the choice by the remaining message text; messages not listed
    # there fall back to ResponseError.
    EXCEPTION_CLASSES = {
        'ERR': {
            'max number of clients reached': ConnectionError
        },
        'EXCABORT': ExecAbortError,
        'LOADING': BusyLoadingError,
        'NOSCRIPT': NoScriptError,
        'READONLY': ReadOnlyError
    }

    def parse_error(self, response):
        """Build (without raising) the exception for an error reply.

        ``response`` is the reply text after the leading ``-``.  For a known
        code the code word is stripped from the message given to the
        exception; unknown codes yield a ResponseError with the full text.
        """
        code, _, message = response.partition(' ')
        target = self.EXCEPTION_CLASSES.get(code)
        if target is None:
            return ResponseError(response)
        if isinstance(target, dict):
            target = target.get(message, ResponseError)
        return target(message)
class SocketBuffer(object):
    """Buffer bytes received from a socket and hand them out as lines or chunks.

    Received data is appended to an in-memory BytesIO at ``bytes_written``;
    consumers read from ``bytes_read``.  When everything written has been
    consumed, the buffer is purged so it does not grow without bound.
    """

    def __init__(self, socket, socket_read_size):
        self._sock = socket
        self.socket_read_size = socket_read_size
        self._buffer = BytesIO()
        self.bytes_written = 0
        self.bytes_read = 0

    @property
    def length(self):
        """Number of buffered bytes not yet consumed."""
        return self.bytes_written - self.bytes_read

    def _read_from_socket(self, length=None):
        """Receive from the socket and append the data to the buffer.

        :param length: when given, keep receiving until at least this many new
            bytes were appended; otherwise perform a single ``recv``.
        :raises TimeoutError: if the socket times out.
        :raises ConnectionError: on any other socket error, including the
            server closing the connection (``recv`` returning no data).
        """
        buf = self._buffer
        buf.seek(self.bytes_written)
        received = 0  # bytes appended during this call
        try:
            while True:
                data = self._sock.recv(self.socket_read_size)
                if isinstance(data, bytes) and len(data) == 0:
                    raise socket.error(SERVER_CLOSED_CONNECTION_ERROR)
                buf.write(data)
                self.bytes_written += len(data)
                received += len(data)
                # Keep receiving until the caller's requested amount is buffered.
                if length is not None and length > received:
                    continue
                break
        except socket.timeout as e:
            raise TimeoutError("Timeout reading from socket") from e
        except socket.error as e:
            raise ConnectionError(
                "Error while reading from socket: {}".format(e.args)) from e

    def purge(self):
        """Discard all buffered data and reset both offsets."""
        self._buffer.seek(0)
        self._buffer.truncate()
        self.bytes_read = 0
        self.bytes_written = 0

    def read(self, length):
        """Return the next ``length`` bytes, also consuming the CRLF after them."""
        length += 2  # payload plus its trailing \r\n
        if length > self.length:
            self._read_from_socket(length - self.length)
        self._buffer.seek(self.bytes_read)
        data = self._buffer.read(length)
        self.bytes_read += len(data)
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]

    def readline(self):
        """Return the next CRLF-terminated line without its terminator.

        More data is received until a complete line is buffered.
        """
        buf = self._buffer
        buf.seek(self.bytes_read)
        data = buf.readline()
        while not data.endswith(SYM_CRLF):
            # Incomplete line: fetch more bytes and re-read from the same offset.
            self._read_from_socket()
            buf.seek(self.bytes_read)
            data = buf.readline()
        self.bytes_read += len(data)
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]

    def close(self):
        """Release the buffer and drop the socket reference (best effort)."""
        try:
            self.purge()
            self._buffer.close()
        except Exception:
            # The buffer may already be closed or None (e.g. a second close).
            pass
        self._buffer = None
        self._sock = None
class PythonParser(BaseParser):
    """Reply parser implemented in pure Python on top of SocketBuffer."""

    # Codec used to decode bytes replies; stays None unless the connection
    # was created with decode_responses enabled.
    encoding = None

    def __init__(self, socket_read_size):
        self._sock = None
        self._buffer = None
        self.socket_read_size = socket_read_size

    def __del__(self):
        # Finalizers must not raise; cleanup here is best effort only.
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        """Bind to ``connection``'s open socket and adopt its reply encoding."""
        sock = connection._sock
        self._sock = sock
        self._buffer = SocketBuffer(sock, self.socket_read_size)
        if connection.decode_responses:
            self.encoding = connection.encoding

    def read_response(self):
        """Read and convert a single reply.

        Status ('+') replies come back as bytes, ':' replies as int, bulk
        ('$') replies as bytes or None, and multi-bulk ('*') replies as a list
        (elements read recursively) or None.  When ``encoding`` is set, bytes
        results are decoded to str.  An error ('-') reply is returned as an
        exception instance so that a pipeline can collect it; if it maps to a
        ConnectionError it is raised instead.

        :raises ConnectionError: if an empty line is read.
        :raises InvalidResponse: if the reply starts with an unknown byte.
        """
        line = self._buffer.readline()
        if not line:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
        prefix = byte_to_chr(line[0])
        payload = line[1:]

        if prefix == '-':
            error = self.parse_error(nativestr(payload))
            if isinstance(error, ConnectionError):
                raise error
            return error

        if prefix == '+':
            result = payload
        elif prefix == ':':
            result = int(payload)
        elif prefix in ('$', '*'):
            size = int(payload)
            if size == -1:
                return None
            if prefix == '$':
                result = self._buffer.read(size)
            else:
                result = [self.read_response() for _ in range(size)]
        else:
            raise InvalidResponse("Protocol Error: %s, %s" %
                                  (str(prefix), str(payload)))

        if self.encoding and isinstance(result, bytes):
            result = result.decode(self.encoding)
        return result

    def on_disconnect(self):
        """Close the socket and buffer, then forget the reply encoding."""
        if self._sock is not None:
            self._sock.close()
            self._sock = None
        if self._buffer is not None:
            self._buffer.close()
            self._buffer = None
        self.encoding = None
# Parser class that Connection instantiates when no parser_class is given.
DefaultParser = PythonParser
class Connection(object):
    """A single TCP connection to a Redis server.

    The socket is opened lazily: ``send_packed_command`` calls ``connect``
    when no socket is open yet.  Replies are read through the parser built
    from ``parser_class``.
    """
    description_format = "Connection<host=%(host)s,port=%(port)s,db=%(db)s>"

    def __init__(self, host='127.0.0.1', port=6379, db=0, passwd=None,
                 socket_timeout=None, socket_connect_timeout=None,
                 socket_keepalive=False, socket_keepalive_options=None,
                 encoding='utf-8', encoding_errors='strict', decode_responses=False,
                 parser_class=DefaultParser, socket_read_size=65536):
        # Remember the creating process so pool code can detect connections
        # inherited across fork() by comparing ``connection.pid``.
        self.pid = os.getpid()
        self.host = host
        self.port = port
        self.passwd = passwd
        self.db = db
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses
        self.socket_timeout = socket_timeout
        # The connect timeout falls back to the regular socket timeout.
        self.socket_connect_timeout = socket_connect_timeout or socket_timeout
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self._sock = None
        self._parser = parser_class(socket_read_size=socket_read_size)
        self._description_args = {
            'host': self.host,
            'port': self.port,
            'db': self.db,
        }

    def __repr__(self):
        return self.description_format % self._description_args

    def __del__(self):
        # Best-effort cleanup; a finalizer must never raise.
        try:
            self.disconnect()
        except Exception:
            pass

    def encode(self, value):
        """Convert one command argument to bytes.

        Tokens are encoded with latin-1, bytes pass through unchanged, ints are
        rendered in decimal, and anything else goes through ``str`` and is
        encoded with ``encoding``/``encoding_errors``.
        """
        if isinstance(value, Token):
            return b(value.value)
        if isinstance(value, bytes):
            return value
        if isinstance(value, int):
            return b(str(value))
        if not isinstance(value, str):
            value = str(value)
        return value.encode(self.encoding, self.encoding_errors)

    def pack_command(self, *args):
        """Encode a command in the Redis wire protocol.

        Returns a list of bytes chunks to send in order.  Small arguments are
        concatenated into one chunk.  Once the pending chunk or an argument
        exceeds 6000 bytes, the argument is appended as its own chunk instead
        of being copied into the buffer.  A command name containing spaces
        (e.g. ``'CONFIG GET'``) is split into separate words.
        """
        output = []
        command = args[0]
        if ' ' in command:
            args = tuple(Token(s) for s in command.split(' ')) + args[1:]
        else:
            args = (Token(command),) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, b(str(len(args))), SYM_CRLF))
        for arg in map(self.encode, args):
            if len(buff) > 6000 or len(arg) > 6000:
                buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF))
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))),
                                       SYM_CRLF, arg, SYM_CRLF))
        output.append(buff)
        return output

    def send_command(self, *args):
        """Pack ``args`` and send them to the server."""
        self.send_packed_command(self.pack_command(*args))

    def _error_message(self, exception):
        """Format a socket error raised while connecting."""
        if len(exception.args) == 1:
            return "Error connecting to {}:{}. {}.".format(self.host, self.port, exception.args[0])
        return "Error {} connecting to {}:{} {}.".format(exception.args[0], self.host, self.port, exception.args[1])

    def _connect(self):
        """Open a TCP socket to ``host:port`` and return it.

        Every address from ``getaddrinfo`` is tried in order.  If none accepts
        the connection, the last socket error is re-raised.
        """
        err = None
        addresses = socket.getaddrinfo(self.host, self.port, 0, socket.SOCK_STREAM)
        for family, socktype, proto, _canonname, socket_address in addresses:
            # Reset per attempt: if socket() itself fails, the except clause
            # must not touch an unbound name or an earlier, already closed socket.
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # Disable Nagle's algorithm.
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in iteritems(self.socket_keepalive_options):
                        sock.setsockopt(socket.SOL_TCP, k, v)
                # Timeout for establishing the connection ...
                sock.settimeout(self.socket_connect_timeout)
                sock.connect(socket_address)
                # ... and for every operation afterwards.
                sock.settimeout(self.socket_timeout)
                return sock
            except socket.error as e:
                err = e
                if sock is not None:
                    sock.close()
        if err is not None:
            raise err
        raise socket.error("socket.getaddrinfo returned an empty list")

    def read_response(self):
        """Read one reply; error replies are raised as exceptions.

        Any exception while reading disconnects the connection before it
        propagates.
        """
        try:
            response = self._parser.read_response()
        except BaseException:
            self.disconnect()
            raise
        if isinstance(response, ResponseError):
            raise response
        return response

    def on_connect(self):
        """Initialise the parser, then send AUTH (if passwd) and SELECT (if db)."""
        self._parser.on_connect(self)
        if self.passwd:
            self.send_command('AUTH', self.passwd)
            response = self.read_response()
            if nativestr(response) != 'OK':
                raise AuthenticationError('Invalid Password')
        if self.db:
            self.send_command('SELECT', self.db)
            if nativestr(self.read_response()) != 'OK':
                raise ConnectionError('Invalid Database')

    def connect(self):
        """Open the socket, if not already open, and run ``on_connect``.

        :raises ConnectionError: if the socket cannot be opened.
        """
        if self._sock:
            return
        try:
            sock = self._connect()
        except socket.error as e:
            raise ConnectionError(self._error_message(e)) from e
        self._sock = sock
        try:
            self.on_connect()
        except RedisError:
            self.disconnect()
            raise

    def disconnect(self):
        """Close the parser's resources and the socket; safe to call repeatedly."""
        self._parser.on_disconnect()
        if self._sock is None:
            return
        try:
            self._sock.shutdown(socket.SHUT_RDWR)
            self._sock.close()
        except socket.error:
            pass
        self._sock = None

    def send_packed_command(self, command):
        """Send an already-packed command, connecting first if needed.

        ``command`` is a single bytes object or an iterable of bytes chunks as
        returned by ``pack_command``.  On any failure the connection is closed
        before the error propagates.
        """
        if not self._sock:
            self.connect()
        try:
            # A lone bytes payload must not be iterated (that would yield ints).
            if isinstance(command, bytes):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout as e:
            self.disconnect()
            raise TimeoutError('Timeout writing to socket') from e
        except socket.error as e:
            self.disconnect()
            if len(e.args) >= 2:
                errno, errmsg = e.args[0], e.args[1]
            else:
                errno, errmsg = 'UNKNOWN', (e.args[0] if e.args else str(e))
            raise ConnectionError(
                "Error {} while writing to socket. {}.".format(errno, errmsg)) from e
        except BaseException:
            self.disconnect()
            raise
class ConnectionPool(object):
    """Pool that reuses idle connections and creates new ones on demand.

    At most ``max_connections`` connections are created (``None`` or 0 means
    2**31); ``connection_kwargs`` are passed to ``connection_class``.
    """

    def __init__(self, connection_class=Connection, max_connections=None, **connection_kwargs):
        max_connections = max_connections or 2 ** 31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.reset()

    def reset(self):
        """Forget all connections and bind the pool to the current process."""
        self.pid = os.getpid()
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()
        self._check_lock = threading.Lock()

    def _checkpid(self):
        """Disconnect everything and reset if the process id changed since reset()."""
        if self.pid != os.getpid():
            with self._check_lock:
                if self.pid == os.getpid():
                    # Another thread already did the work while we waited
                    # on the lock.
                    return
                self.disconnect()
                self.reset()

    def make_connection(self):
        """Create a new connection, enforcing ``max_connections``.

        :raises ConnectionError: when the limit has been reached.
        """
        if self._created_connections >= self.max_connections:
            raise ConnectionError("Too many connections")
        self._created_connections += 1
        return self.connection_class(**self.connection_kwargs)

    def disconnect(self):
        "Disconnects all connections in the pool"
        all_conns = chain(self._available_connections,
                          self._in_use_connections)
        for connection in all_conns:
            connection.disconnect()

    def get_connection(self, command_name, *keys, **options):
        """Check out an idle connection, creating one if none is idle.

        ``command_name``, ``keys`` and ``options`` are accepted for interface
        compatibility and are not used.
        """
        self._checkpid()
        try:
            connection = self._available_connections.pop()
        except IndexError:
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def release(self, connection):
        """Return a checked-out ``connection`` to the idle list.

        Connections not currently checked out from this pool (for example
        ones handed out before a fork-triggered reset, or released twice) are
        ignored.
        """
        self._checkpid()
        if connection not in self._in_use_connections:
            return
        self._in_use_connections.remove(connection)
        self._available_connections.append(connection)
class BlockingConnectionPool(ConnectionPool):
    """Pool whose ``get_connection`` waits up to ``timeout`` seconds when exhausted.

    The queue is pre-filled with ``None`` placeholders, one per allowed
    connection.  Taking a placeholder causes a new connection to be created.
    """

    def __init__(self, max_connections=50, timeout=20,
                 connection_class=Connection, queue_class=LifoQueue,
                 **connection_kwargs):
        self.queue_class = queue_class
        self.timeout = timeout
        super(BlockingConnectionPool, self).__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs)

    def reset(self):
        """Rebuild the placeholder queue and bind the pool to this process."""
        self.pid = os.getpid()
        self._check_lock = threading.Lock()
        # Create and fill up a thread-safe queue with ``None`` values.
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break
        # Every connection created since the last reset, for disconnect().
        self._connections = []

    def make_connection(self):
        """Create a connection and remember it for ``disconnect``."""
        connection = self.connection_class(**self.connection_kwargs)
        self._connections.append(connection)
        return connection

    def get_connection(self, command_name, *keys, **options):
        """Take a connection from the queue, waiting up to ``timeout`` seconds.

        :raises ConnectionError: if none becomes available in time.
        """
        self._checkpid()
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            raise ConnectionError("No connection available.") from None
        if connection is None:
            connection = self.make_connection()
        return connection

    def release(self, connection):
        """Put ``connection`` back into the queue.

        Connections this pool did not create since its last reset (e.g. ones
        inherited from a parent process before a fork) are dropped.  Ownership
        is checked via ``_connections`` rather than a ``pid`` attribute, so any
        connection class works.
        """
        self._checkpid()
        if connection not in self._connections:
            return
        try:
            self.pool.put_nowait(connection)
        except Full:
            # Queue already full (e.g. a double release): drop the extra entry.
            pass

    def disconnect(self):
        "Disconnects all connections in the pool."
        for connection in self._connections:
            connection.disconnect()
if __name__ == '__main__':
    import time

    # Demo: five threads share a pool of at most three connections, so some
    # of them wait in get_connection() until another thread releases one.
    p = BlockingConnectionPool(host='127.0.0.1', port=6379, max_connections=3)
    command_name = 'PING'

    class Query(threading.Thread):
        def run(self):
            conn = p.get_connection(command_name)
            try:
                conn.send_command("SET", self.name, 1)
                # Read the reply so that the next user of this connection
                # does not receive this command's response.
                print('thread start ', self.name, conn.read_response())
                time.sleep(1)
            finally:
                p.release(conn)
            print('thread end')

    threads = [Query() for _ in range(5)]
    for t in threads:
        t.start()
    print(len(p._connections))
    for t in threads:
        t.join()
    print(len(p._connections))
    print('main end')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding:utf-8 -*- | |
import socket | |
import sys | |
import threading | |
from itertools import chain | |
import os | |
import warnings | |
from io import BytesIO | |
class RedisError(Exception):
    """Root of the Redis-specific exceptions defined in this module."""


class ConnectionError(RedisError):
    """Socket-level failure or server-side close (shadows the builtin ConnectionError)."""


class BusyLoadingError(ConnectionError):
    """Built for a ``LOADING`` error reply."""


class InvalidResponse(RedisError):
    """Raised when a reply starts with an unknown protocol prefix byte."""


class ResponseError(RedisError):
    """An error reply sent back by the server."""


class AuthenticationError(RedisError):
    """Raised when ``AUTH`` does not answer ``OK`` during connection setup."""


class NoScriptError(ResponseError):
    """Built for a ``NOSCRIPT`` error reply."""


class ExecAbortError(ResponseError):
    """Built for an ``EXCABORT`` error reply."""


class ReadOnlyError(ResponseError):
    """Built for a ``READONLY`` error reply."""
class Token(object):
    """A literal command word (e.g. ``GET``) to be sent verbatim.

    ``Connection.encode`` converts a Token to bytes with latin-1 rather than
    the connection's configured encoding.  Wrapping a Token re-uses its
    ``value`` instead of nesting.
    """

    def __init__(self, value):
        self.value = value.value if isinstance(value, Token) else value

    def __str__(self):
        return self.value

    __repr__ = __str__
def nativestr(x):
    """Return ``x`` as str; bytes are decoded as UTF-8 with invalid bytes replaced."""
    if isinstance(x, str):
        return x
    return x.decode('utf-8', 'replace')
def b(x):
    """Return ``x`` as bytes: bytes pass through, anything else is latin-1 encoded."""
    if isinstance(x, bytes):
        return x
    return x.encode('latin-1')
def byte_to_chr(x):
    """Turn an integer byte value (such as ``data[0]`` of a bytes object) into a 1-char str."""
    character = chr(x)
    return character
# Protocol framing markers, kept as bytes so they can be joined with encoded payloads.
SYM_STAR = b'*'      # prefix of the argument-count header built by pack_command
SYM_DOLLAR = b'$'    # prefix of each argument's length header
SYM_CRLF = b'\r\n'   # line terminator
SYM_EMPTY = b''      # joiner used to concatenate packet pieces
# Message used when the server closes the connection (empty recv / empty reply line).
SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
def iteritems(x):
    """Return an iterator over the (key, value) pairs of mapping ``x``."""
    pairs = x.items()
    return iter(pairs)
def string_keys_to_dict(key_string, callback):
    """Map every whitespace-separated word of ``key_string`` to ``callback``."""
    return {key: callback for key in key_string.split()}
def dict_merge(*dicts):
    """Merge ``dicts`` left to right into a new dict; later keys win."""
    combined = dict()
    for mapping in dicts:
        combined.update(mapping)
    return combined
class BaseParser(object):
    """Shared error-reply handling for protocol parsers."""

    # First word of an error reply -> exception class to build.  A dict value
    # refines the choice by the remaining message text; messages not listed
    # there fall back to ResponseError.
    EXCEPTION_CLASSES = {
        'ERR': {
            'max number of clients reached': ConnectionError
        },
        'EXCABORT': ExecAbortError,
        'LOADING': BusyLoadingError,
        'NOSCRIPT': NoScriptError,
        'READONLY': ReadOnlyError
    }

    def parse_error(self, response):
        """Build (without raising) the exception for an error reply.

        ``response`` is the reply text after the leading ``-``.  For a known
        code the code word is stripped from the message given to the
        exception; unknown codes yield a ResponseError with the full text.
        """
        code, _, message = response.partition(' ')
        target = self.EXCEPTION_CLASSES.get(code)
        if target is None:
            return ResponseError(response)
        if isinstance(target, dict):
            target = target.get(message, ResponseError)
        return target(message)
class SocketBuffer(object):
    """Buffer bytes received from a socket and hand them out as lines or chunks.

    Received data is appended to an in-memory BytesIO at ``bytes_written``;
    consumers read from ``bytes_read``.  When everything written has been
    consumed, the buffer is purged so it does not grow without bound.
    """

    def __init__(self, socket, socket_read_size):
        self._sock = socket
        self.socket_read_size = socket_read_size
        self._buffer = BytesIO()
        self.bytes_written = 0
        self.bytes_read = 0

    @property
    def length(self):
        """Number of buffered bytes not yet consumed."""
        return self.bytes_written - self.bytes_read

    def _read_from_socket(self, length=None):
        """Receive from the socket and append the data to the buffer.

        :param length: when given, keep receiving until at least this many new
            bytes were appended; otherwise perform a single ``recv``.
        :raises TimeoutError: if the socket times out.
        :raises ConnectionError: on any other socket error, including the
            server closing the connection (``recv`` returning no data).
        """
        buf = self._buffer
        buf.seek(self.bytes_written)
        received = 0  # bytes appended during this call
        try:
            while True:
                data = self._sock.recv(self.socket_read_size)
                if isinstance(data, bytes) and len(data) == 0:
                    raise socket.error(SERVER_CLOSED_CONNECTION_ERROR)
                buf.write(data)
                self.bytes_written += len(data)
                received += len(data)
                # Keep receiving until the caller's requested amount is buffered.
                if length is not None and length > received:
                    continue
                break
        except socket.timeout as e:
            raise TimeoutError("Timeout reading from socket") from e
        except socket.error as e:
            raise ConnectionError(
                "Error while reading from socket: {}".format(e.args)) from e

    def purge(self):
        """Discard all buffered data and reset both offsets."""
        self._buffer.seek(0)
        self._buffer.truncate()
        self.bytes_read = 0
        self.bytes_written = 0

    def read(self, length):
        """Return the next ``length`` bytes, also consuming the CRLF after them."""
        length += 2  # payload plus its trailing \r\n
        if length > self.length:
            self._read_from_socket(length - self.length)
        self._buffer.seek(self.bytes_read)
        data = self._buffer.read(length)
        self.bytes_read += len(data)
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]

    def readline(self):
        """Return the next CRLF-terminated line without its terminator.

        More data is received until a complete line is buffered.
        """
        buf = self._buffer
        buf.seek(self.bytes_read)
        data = buf.readline()
        while not data.endswith(SYM_CRLF):
            # Incomplete line: fetch more bytes and re-read from the same offset.
            self._read_from_socket()
            buf.seek(self.bytes_read)
            data = buf.readline()
        self.bytes_read += len(data)
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]

    def close(self):
        """Release the buffer and drop the socket reference (best effort)."""
        try:
            self.purge()
            self._buffer.close()
        except Exception:
            # The buffer may already be closed or None (e.g. a second close).
            pass
        self._buffer = None
        self._sock = None
class PythonParser(BaseParser):
    """Reply parser implemented in pure Python on top of SocketBuffer."""

    # Codec used to decode bytes replies; stays None unless the connection
    # was created with decode_responses enabled.
    encoding = None

    def __init__(self, socket_read_size):
        self._sock = None
        self._buffer = None
        self.socket_read_size = socket_read_size

    def __del__(self):
        # Finalizers must not raise; cleanup here is best effort only.
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        """Bind to ``connection``'s open socket and adopt its reply encoding."""
        sock = connection._sock
        self._sock = sock
        self._buffer = SocketBuffer(sock, self.socket_read_size)
        if connection.decode_responses:
            self.encoding = connection.encoding

    def read_response(self):
        """Read and convert a single reply.

        Status ('+') replies come back as bytes, ':' replies as int, bulk
        ('$') replies as bytes or None, and multi-bulk ('*') replies as a list
        (elements read recursively) or None.  When ``encoding`` is set, bytes
        results are decoded to str.  An error ('-') reply is returned as an
        exception instance so that a pipeline can collect it; if it maps to a
        ConnectionError it is raised instead.

        :raises ConnectionError: if an empty line is read.
        :raises InvalidResponse: if the reply starts with an unknown byte.
        """
        line = self._buffer.readline()
        if not line:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
        prefix = byte_to_chr(line[0])
        payload = line[1:]

        if prefix == '-':
            error = self.parse_error(nativestr(payload))
            if isinstance(error, ConnectionError):
                raise error
            return error

        if prefix == '+':
            result = payload
        elif prefix == ':':
            result = int(payload)
        elif prefix in ('$', '*'):
            size = int(payload)
            if size == -1:
                return None
            if prefix == '$':
                result = self._buffer.read(size)
            else:
                result = [self.read_response() for _ in range(size)]
        else:
            raise InvalidResponse("Protocol Error: %s, %s" %
                                  (str(prefix), str(payload)))

        if self.encoding and isinstance(result, bytes):
            result = result.decode(self.encoding)
        return result

    def on_disconnect(self):
        """Close the socket and buffer, then forget the reply encoding."""
        if self._sock is not None:
            self._sock.close()
            self._sock = None
        if self._buffer is not None:
            self._buffer.close()
            self._buffer = None
        self.encoding = None
# Parser class that Connection instantiates when no parser_class is given.
DefaultParser = PythonParser
class Connection(object):
    """A single TCP connection to a Redis server.

    The socket is opened lazily: ``send_packed_command`` calls ``connect``
    when no socket is open yet.  Replies are read through the parser built
    from ``parser_class``.
    """
    # Used by __repr__; it was previously referenced but never defined.
    description_format = "Connection<host=%(host)s,port=%(port)s,db=%(db)s>"

    def __init__(self, host='127.0.0.1', port=6379, db=0, passwd=None,
                 socket_timeout=None, socket_connect_timeout=None,
                 socket_keepalive=False, socket_keepalive_options=None, retry_on_timeout=False,
                 encoding='utf-8', encoding_errors='strict', decode_responses=False,
                 parser_class=DefaultParser, socket_read_size=65536):
        # The creating process; ConnectionPool.release compares it with the
        # pool's pid to ignore connections inherited across fork().
        self.pid = os.getpid()
        self.host = host
        self.port = int(port)
        self.passwd = passwd
        self.db = db
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses
        self.socket_timeout = socket_timeout
        # The connect timeout falls back to the regular socket timeout.
        self.socket_connect_timeout = socket_connect_timeout or socket_timeout
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        # Read by StrictRedis.execute_command to decide whether a timeout is retried.
        self.retry_on_timeout = retry_on_timeout
        self._sock = None
        self._parser = parser_class(socket_read_size=socket_read_size)
        self._description_args = {
            'host': self.host,
            'port': self.port,
            'db': self.db,
        }

    def __repr__(self):
        return self.description_format % self._description_args

    def __del__(self):
        # Best-effort cleanup; a finalizer must never raise.
        try:
            self.disconnect()
        except Exception:
            pass

    def encode(self, value):
        """Convert one command argument to bytes.

        Tokens are encoded with latin-1, bytes pass through unchanged, ints are
        rendered in decimal, and anything else goes through ``str`` and is
        encoded with ``encoding``/``encoding_errors``.
        """
        if isinstance(value, Token):
            return b(value.value)
        if isinstance(value, bytes):
            return value
        if isinstance(value, int):
            return b(str(value))
        if not isinstance(value, str):
            value = str(value)
        return value.encode(self.encoding, self.encoding_errors)

    def pack_command(self, *args):
        """Encode a command in the Redis wire protocol.

        Returns a list of bytes chunks to send in order.  Small arguments are
        concatenated into one chunk.  Once the pending chunk or an argument
        exceeds 6000 bytes, the argument is appended as its own chunk instead
        of being copied into the buffer.  A command name containing spaces
        (e.g. ``'CONFIG GET'``) is split into separate words.
        """
        output = []
        command = args[0]
        if ' ' in command:
            args = tuple(Token(s) for s in command.split(' ')) + args[1:]
        else:
            args = (Token(command),) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, b(str(len(args))), SYM_CRLF))
        for arg in map(self.encode, args):
            if len(buff) > 6000 or len(arg) > 6000:
                buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF))
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))),
                                       SYM_CRLF, arg, SYM_CRLF))
        output.append(buff)
        return output

    def send_command(self, *args):
        """Pack ``args`` and send them to the server."""
        self.send_packed_command(self.pack_command(*args))

    def _error_message(self, exception):
        """Format a socket error raised while connecting."""
        if len(exception.args) == 1:
            return "Error connecting to {}:{}. {}.".format(self.host, self.port, exception.args[0])
        return "Error {} connecting to {}:{} {}.".format(exception.args[0], self.host, self.port, exception.args[1])

    def _connect(self):
        """Open a TCP socket to ``host:port`` and return it.

        Every address from ``getaddrinfo`` is tried in order.  If none accepts
        the connection, the last socket error is re-raised.
        """
        err = None
        addresses = socket.getaddrinfo(self.host, self.port, 0, socket.SOCK_STREAM)
        for family, socktype, proto, _canonname, socket_address in addresses:
            # Reset per attempt: if socket() itself fails, the except clause
            # must not touch an unbound name or an earlier, already closed socket.
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # Disable Nagle's algorithm.
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in iteritems(self.socket_keepalive_options):
                        sock.setsockopt(socket.SOL_TCP, k, v)
                # Timeout for establishing the connection ...
                sock.settimeout(self.socket_connect_timeout)
                sock.connect(socket_address)
                # ... and for every operation afterwards.
                sock.settimeout(self.socket_timeout)
                return sock
            except socket.error as e:
                err = e
                if sock is not None:
                    sock.close()
        if err is not None:
            raise err
        raise socket.error("socket.getaddrinfo returned an empty list")

    def read_response(self):
        """Read one reply; error replies are raised as exceptions.

        Any exception while reading disconnects the connection before it
        propagates.
        """
        try:
            response = self._parser.read_response()
        except BaseException:
            self.disconnect()
            raise
        if isinstance(response, ResponseError):
            raise response
        return response

    def on_connect(self):
        """Initialise the parser, then send AUTH (if passwd) and SELECT (if db)."""
        self._parser.on_connect(self)
        if self.passwd:
            self.send_command('AUTH', self.passwd)
            response = self.read_response()
            if nativestr(response) != 'OK':
                raise AuthenticationError('Invalid Password')
        if self.db:
            self.send_command('SELECT', self.db)
            if nativestr(self.read_response()) != 'OK':
                raise ConnectionError('Invalid Database')

    def connect(self):
        """Open the socket, if not already open, and run ``on_connect``.

        :raises ConnectionError: if the socket cannot be opened.
        """
        if self._sock:
            return
        try:
            sock = self._connect()
        except socket.error as e:
            raise ConnectionError(self._error_message(e)) from e
        self._sock = sock
        try:
            self.on_connect()
        except RedisError:
            self.disconnect()
            raise

    def disconnect(self):
        """Close the parser's resources and the socket; safe to call repeatedly."""
        self._parser.on_disconnect()
        if self._sock is None:
            return
        try:
            self._sock.shutdown(socket.SHUT_RDWR)
            self._sock.close()
        except socket.error:
            pass
        self._sock = None

    def send_packed_command(self, command):
        """Send an already-packed command, connecting first if needed.

        ``command`` is a single bytes object or an iterable of bytes chunks as
        returned by ``pack_command``.  On any failure the connection is closed
        before the error propagates.
        """
        if not self._sock:
            self.connect()
        try:
            # A lone bytes payload must not be iterated (that would yield ints).
            if isinstance(command, bytes):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout as e:
            self.disconnect()
            raise TimeoutError('Timeout writing to socket') from e
        except socket.error as e:
            self.disconnect()
            if len(e.args) >= 2:
                errno, errmsg = e.args[0], e.args[1]
            else:
                errno, errmsg = 'UNKNOWN', (e.args[0] if e.args else str(e))
            raise ConnectionError(
                "Error {} while writing to socket. {}.".format(errno, errmsg)) from e
        except BaseException:
            self.disconnect()
            raise
class ConnectionPool(object):
    """Pool that reuses idle connections and creates new ones on demand.

    At most ``max_connections`` connections are created (``None`` or 0 means
    2**31); ``connection_kwargs`` are passed to ``connection_class``.
    """

    def __init__(self, connection_class=Connection, max_connections=None, **connection_kwargs):
        max_connections = max_connections or 2 ** 31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.reset()

    def reset(self):
        """Forget all connections and bind the pool to the current process."""
        self.pid = os.getpid()
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()
        self._check_lock = threading.Lock()

    def _checkpid(self):
        """Disconnect everything and reset if the process id changed since reset()."""
        if self.pid != os.getpid():
            with self._check_lock:
                if self.pid == os.getpid():
                    # Another thread already did the work while we waited
                    # on the lock.
                    return
                self.disconnect()
                self.reset()

    def make_connection(self):
        """Create a new connection, enforcing ``max_connections``.

        :raises ConnectionError: when the limit has been reached.
        """
        if self._created_connections >= self.max_connections:
            raise ConnectionError("Too many connections")
        self._created_connections += 1
        return self.connection_class(**self.connection_kwargs)

    def get_connection(self, command_name, *keys, **options):
        """Check out an idle connection, creating one if none is idle.

        ``command_name``, ``keys`` and ``options`` are accepted for interface
        compatibility and are not used.
        """
        self._checkpid()
        try:
            connection = self._available_connections.pop()
        except IndexError:
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def release(self, connection):
        """Return a checked-out ``connection`` to the idle list.

        Connections created in another process (``connection.pid`` differs)
        or not currently checked out from this pool (e.g. released twice) are
        ignored instead of raising KeyError.
        """
        self._checkpid()
        if connection.pid != self.pid:
            return
        if connection not in self._in_use_connections:
            return
        self._in_use_connections.remove(connection)
        self._available_connections.append(connection)

    def disconnect(self):
        """Disconnect every idle and checked-out connection."""
        all_conns = chain(self._available_connections, self._in_use_connections)
        for connection in all_conns:
            connection.disconnect()
class StrictRedis(object): | |
RESPONSE_CALLBACKS = dict_merge( | |
{ | |
'PING': lambda r: nativestr(r) == 'PONG', | |
} | |
) | |
def __init__(self, host='localhost', port=6379, db=0, passwd=None, | |
socket_timeout=None, socket_connect_timeout=None, | |
socket_keepalive=None, socket_keepalive_options=None, | |
connection_pool=None, encoding='utf-8', encoding_errors='strict', | |
charset=None, errors=None, decode_responses=False, retry_on_timeout=False, max_connections=None): | |
if not connection_pool: | |
if charset is not None: | |
warnings.warn(DeprecationWarning( | |
'"charset" is deprecated. Use "encoding" instead')) | |
encoding = charset | |
if errors is not None: | |
warnings.warn(DeprecationWarning( | |
'"errors" is deprecated. Use "encoding_errors" instead')) | |
encoding_errors = errors | |
kwargs = { | |
'db': db, | |
'passwd': passwd, | |
'socket_timeout': socket_timeout, | |
'encoding': encoding, | |
'encoding_errors': encoding_errors, | |
'decode_responses': decode_responses, | |
'retry_on_timeout': retry_on_timeout, | |
'max_connections': max_connections | |
} | |
connection_pool = ConnectionPool(**kwargs) | |
self.connection_pool = connection_pool | |
self.response_callbacks = self.__class__.RESPONSE_CALLBACKS.copy() | |
def parse_response(self, connection, command_name, **options): | |
response = connection.read_response() | |
if command_name in self.response_callbacks: | |
return self.response_callbacks[command_name](response, **options) | |
return response | |
def execute_command(self, *args, **options): | |
pool = self.connection_pool | |
command_name = args[0] | |
connection = pool.get_connection(command_name, **options) | |
try: | |
connection.send_command(*args) | |
return self.parse_response(connection, command_name, **options) | |
except (ConnectionError, TimeoutError) as e: | |
connection.disconnect() | |
if not connection.retry_on_timeout and isinstance(e, TimeoutError): | |
raise | |
connection.send_command(*args) | |
return self.parse_response(connection, command_name, **options) | |
finally: | |
pool.release(connection) | |
def ping(self, *args, **options):
    """Send ``PING`` and return the parsed reply; extra arguments are ignored."""
    command = "PING"
    return self.execute_command(command)
def get(self, *args, **options):
    """Forward a complete command to ``execute_command``.

    NOTE(review): despite the name, this does not prepend ``GET``; callers
    pass the command name themselves, e.g. ``get("GET", "hello")``.
    """
    forward = self.execute_command
    return forward(*args, **options)
if __name__ == '__main__':
    # Demo: needs a Redis server reachable with the default client settings.
    client = StrictRedis()
    print(client.execute_command("PING"))
    print(client.execute_command("GET", "hello"))
    print(client.ping())
    print(client.get("GET", "hello"))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding:utf-8 -*- | |
import socket | |
import sys | |
import threading | |
from itertools import chain | |
import os | |
from io import BytesIO | |
class RedisError(Exception):
    """Root of this client's exception hierarchy."""


class ConnectionError(RedisError):
    """Connection-level failure (connect, read, write).

    NOTE: shadows the builtin ``ConnectionError`` inside this module.
    """


class BusyLoadingError(ConnectionError):
    """Raised for error replies carrying the ``LOADING`` prefix."""


class InvalidResponse(RedisError):
    """A reply line did not start with a known type byte."""


class ResponseError(RedisError):
    """Error reply from the server; default for unmapped error prefixes."""


class AuthenticationError(RedisError):
    """AUTH was not acknowledged with OK."""


class NoScriptError(ResponseError):
    """Raised for error replies carrying the ``NOSCRIPT`` prefix."""


class ExecAbortError(ResponseError):
    """Raised for error replies reporting an aborted transaction."""


class ReadOnlyError(ResponseError):
    """Raised for error replies carrying the ``READONLY`` prefix."""
class Token(object):
    """Marks a command word so ``Connection.encode`` sends it via ``b()``.

    Wrapping an existing Token unwraps it, so nested Tokens never stack.
    """

    def __init__(self, value):
        self.value = value.value if isinstance(value, Token) else value

    def __str__(self):
        return self.value

    __repr__ = __str__
def nativestr(x):
    """Return ``x`` as ``str``; bytes are decoded as UTF-8, invalid bytes become U+FFFD."""
    if isinstance(x, str):
        return x
    return x.decode('utf-8', 'replace')
def b(x):
    """Return ``x`` as bytes; ``str`` input is encoded as latin-1, bytes pass through."""
    if isinstance(x, bytes):
        return x
    return x.encode('latin-1')
def byte_to_chr(x):
    """Turn an integer byte value (e.g. ``data[0]`` of a bytes object) into a one-character str."""
    code_point = x
    return chr(code_point)
# Wire-protocol tokens, kept as bytes.
SYM_STAR = b'*'      # prefix of the argument-count header line
SYM_DOLLAR = b'$'    # prefix of each argument-length line
SYM_CRLF = b'\r\n'   # line terminator
SYM_EMPTY = b''      # joiner used when assembling packets
SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
def iteritems(x):
    """Return an iterator over the ``(key, value)`` pairs of mapping ``x``."""
    pairs = x.items()
    return iter(pairs)
class BaseParser(object):
    """Shared parser helper: map Redis error replies to exception instances."""

    # Error prefix -> exception class. A nested dict refines the class by the
    # exact message after the prefix; unmatched messages fall back to
    # ResponseError.
    EXCEPTION_CLASSES = {
        'ERR': {
            'max number of clients reached': ConnectionError
        },
        # Redis reports aborted transactions with the EXECABORT prefix.
        'EXECABORT': ExecAbortError,
        # Misspelled key from earlier revisions, kept for compatibility.
        'EXCABORT': ExecAbortError,
        'LOADING': BusyLoadingError,
        'NOSCRIPT': NoScriptError,
        'READONLY': ReadOnlyError
    }

    def parse_error(self, response):
        """Build (not raise) the exception for an error reply.

        Args:
            response: the error text without the leading ``-``, e.g.
                ``'LOADING Redis is loading'``.

        Returns:
            An exception instance. For a known prefix, its message is the text
            after the prefix; otherwise a ResponseError with the full text.
        """
        error_code = response.split(' ')[0]
        if error_code in self.EXCEPTION_CLASSES:
            # Drop the prefix and the single space that follows it.
            response = response[len(error_code) + 1:]
            exception_class = self.EXCEPTION_CLASSES[error_code]
            if isinstance(exception_class, dict):
                exception_class = exception_class.get(response, ResponseError)
            return exception_class(response)
        return ResponseError(response)
class SocketBuffer(object):
    """Buffer bytes read from a socket so replies can be consumed by line or by length.

    Received data is appended to an in-memory BytesIO at offset
    ``bytes_written``; consumers read from offset ``bytes_read``. When
    everything written has been consumed, the buffer is reset.
    """

    def __init__(self, socket, socket_read_size):
        self._sock = socket
        self.socket_read_size = socket_read_size
        self._buffer = BytesIO()
        self.bytes_written = 0
        self.bytes_read = 0

    @property
    def length(self):
        """Number of buffered bytes not yet consumed."""
        return self.bytes_written - self.bytes_read

    def _read_from_socket(self, length=None):
        """Append data from the socket to the buffer.

        With ``length=None`` a single ``recv`` is performed. Otherwise ``recv``
        is repeated until at least ``length`` new bytes have been buffered.

        Raises:
            TimeoutError: the socket timed out.
            ConnectionError: the socket failed or the server closed it.
        """
        socket_read_size = self.socket_read_size
        buf = self._buffer
        buf.seek(self.bytes_written)
        # Bytes received during this call.
        marker = 0
        try:
            while True:
                data = self._sock.recv(socket_read_size)
                # recv() returning b'' means the peer closed the connection.
                if isinstance(data, bytes) and len(data) == 0:
                    raise socket.error(SERVER_CLOSED_CONNECTION_ERROR)
                buf.write(data)
                data_length = len(data)
                self.bytes_written += data_length
                marker += data_length
                # Keep reading until the requested amount is buffered.
                if length is not None and length > marker:
                    continue
                break
        except socket.timeout as e:
            raise TimeoutError("Timeout reading from socket") from e
        except socket.error as e:
            raise ConnectionError(
                "Error while reading from socket: {}".format(e.args)) from e

    def purge(self):
        """Discard all buffered data and reset both offsets."""
        self._buffer.seek(0)
        self._buffer.truncate()
        self.bytes_read = 0
        self.bytes_written = 0

    def read(self, length):
        """Return the next ``length`` bytes, consuming and dropping the trailing CRLF."""
        # Bulk payloads are followed by \r\n, which is read but not returned.
        length = length + 2
        if length > self.length:
            self._read_from_socket(length - self.length)
        self._buffer.seek(self.bytes_read)
        data = self._buffer.read(length)
        self.bytes_read += len(data)
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]

    def readline(self):
        """Return the next CRLF-terminated line without its CRLF, reading more as needed."""
        buf = self._buffer
        buf.seek(self.bytes_read)
        data = buf.readline()
        # A line without CRLF is incomplete: fetch more data and retry.
        while not data.endswith(SYM_CRLF):
            self._read_from_socket()
            buf.seek(self.bytes_read)
            data = buf.readline()
        self.bytes_read += len(data)
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]

    def close(self):
        """Release the buffer and drop the socket reference; safe to call twice."""
        try:
            self.purge()
            self._buffer.close()
        except Exception:
            # Best effort: the buffer may already be closed or None.
            pass
        self._buffer = None
        self._sock = None
class PythonParser(BaseParser):
    """Pure-Python reply parser that reads through a SocketBuffer."""

    # Set from the connection in on_connect() when decode_responses is on;
    # while None, bytes replies are returned undecoded.
    encoding = None

    def __init__(self, socket_read_size):
        self.socket_read_size = socket_read_size
        self._sock = None
        self._buffer = None

    def __del__(self):
        # Best-effort cleanup during garbage collection.
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        """Attach to ``connection``'s socket and adopt its decoding settings."""
        self._sock = connection._sock
        self._buffer = SocketBuffer(self._sock, self.socket_read_size)
        if connection.decode_responses:
            self.encoding = connection.encoding

    def read_response(self):
        """Read one complete reply.

        Returns an int for ``:``, bytes (or str when ``encoding`` is set) for
        ``+`` and ``$``, a list for ``*``, None for ``$-1``/``*-1``, and an
        exception instance (not raised) for non-connection error replies.

        Raises:
            ConnectionError: the line was empty or the error reply maps to a
                connection-level error.
            InvalidResponse: the line starts with an unknown type byte.
        """
        line = self._buffer.readline()
        if not line:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        prefix, payload = byte_to_chr(line[0]), line[1:]
        if prefix not in ('-', '+', ':', '$', '*'):
            raise InvalidResponse("Protocol Error: %s, %s" %
                                  (str(prefix), str(payload)))

        if prefix == '-':
            error = self.parse_error(nativestr(payload))
            # Connection-level errors are raised at once so the user is
            # notified. Other errors may belong inside a pipeline reply; the
            # connection's read_response() and/or the pipeline's execute()
            # raise them when necessary, so the instance is returned here.
            if isinstance(error, ConnectionError):
                raise error
            return error

        if prefix == ':':
            payload = int(payload)
        elif prefix == '$':
            size = int(payload)
            if size == -1:
                return None
            payload = self._buffer.read(size)
        elif prefix == '*':
            count = int(payload)
            if count == -1:
                return None
            # Each element is itself a complete reply.
            payload = [self.read_response() for _ in range(count)]
        # '+' (simple string) needs no conversion.

        if isinstance(payload, bytes) and self.encoding:
            payload = payload.decode(self.encoding)
        return payload

    def on_disconnect(self):
        """Close the socket taken from the connection, drop the buffer, reset encoding."""
        if self._sock is not None:
            self._sock.close()
            self._sock = None
        if self._buffer is not None:
            self._buffer.close()
            self._buffer = None
        self.encoding = None


# Parser used by Connection unless another parser_class is passed.
DefaultParser = PythonParser
class Connection(object):
    """A single TCP connection to a Redis server.

    The socket is opened lazily by ``send_packed_command`` (via ``connect``)
    the first time a command is sent. Right after connecting, ``on_connect``
    sends AUTH when ``passwd`` is set and SELECT when ``db`` is non-zero.
    Replies are read through an instance of ``parser_class``.
    """

    # Template used by __repr__, filled from _description_args.
    description_format = "Connection<host=%(host)s,port=%(port)s,db=%(db)s>"

    def __init__(self, host='127.0.0.1', port=6379, db=0, passwd=None,
                 socket_timeout=None, socket_connect_timeout=None,
                 socket_keepalive=False, socket_keepalive_options=None,
                 encoding='utf-8', encoding_errors='strict', decode_responses=False,
                 parser_class=DefaultParser, socket_read_size=65536):
        self.host = host
        self.port = port
        self.passwd = passwd
        self.db = db
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses
        self.socket_timeout = socket_timeout
        # The connect timeout falls back to the general socket timeout.
        self.socket_connect_timeout = socket_connect_timeout or socket_timeout
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self._sock = None
        self._parser = parser_class(socket_read_size=socket_read_size)

    @property
    def _description_args(self):
        """Values interpolated into ``description_format`` by __repr__."""
        return {'host': self.host, 'port': self.port, 'db': self.db}

    def __repr__(self):
        # Both attributes used here were previously undefined, so repr() raised
        # AttributeError.
        return self.description_format % self._description_args

    def __del__(self):
        # Best-effort cleanup at garbage collection; never raise from __del__.
        try:
            self.disconnect()
        except Exception:
            pass

    def encode(self, value):
        """Convert one command argument to bytes.

        Tokens are encoded with ``b`` (latin-1). Bytes pass through unchanged.
        ints (bool included) go through ``str()``. Anything else is ``str()``-ed
        and encoded with ``self.encoding`` / ``self.encoding_errors``.
        """
        if isinstance(value, Token):
            return b(value.value)
        elif isinstance(value, bytes):
            return value
        elif isinstance(value, int):
            value = b(str(value))
        elif not isinstance(value, str):
            value = str(value)
        if isinstance(value, str):
            value = value.encode(self.encoding, self.encoding_errors)
        return value

    def pack_command(self, *args):
        """Encode a command in the Redis protocol and return a list of bytes chunks.

        A multi-word command name such as ``'CONFIG SET'`` is split into
        separate tokens. Arguments are normally packed into one chunk. Once the
        pending chunk or an argument exceeds 6000 bytes, the argument is
        emitted as its own chunk instead of being joined in.
        """
        output = []
        command = args[0]
        if ' ' in command:
            args = tuple([Token(s) for s in command.split(' ')]) + args[1:]
        else:
            args = (Token(command),) + args[1:]
        buff = SYM_EMPTY.join(
            (SYM_STAR, b(str(len(args))), SYM_CRLF))
        for arg in map(self.encode, args):
            if len(buff) > 6000 or len(arg) > 6000:
                # Large data: flush the header and append the argument as a
                # separate chunk rather than copying it into buff.
                buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF))
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF, arg, SYM_CRLF))
        output.append(buff)
        return output

    def send_command(self, *args):
        """Pack ``args`` with pack_command and send them to the server."""
        self.send_packed_command(self.pack_command(*args))

    def _error_message(self, exception):
        """Format a connect failure, including the errno when the exception carries one."""
        if len(exception.args) == 1:
            return "Error connecting to {}:{}. {}.".format(self.host, self.port, exception.args[0])
        return "Error {} connecting to {}:{} {}.".format(exception.args[0], self.host, self.port, exception.args[1])

    def _connect(self):
        """Create and return a connected TCP socket.

        Each address from ``getaddrinfo`` is tried in turn. If all fail, the
        last socket error is raised.
        """
        err = None
        for res in socket.getaddrinfo(self.host, self.port, 0, socket.SOCK_STREAM):
            family, socktype, proto, canonname, socket_address = res
            # Reset per attempt: if socket.socket() itself fails, the except
            # clause must not touch an unbound name (previously
            # UnboundLocalError) or a socket from an earlier attempt.
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY: send small commands immediately.
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    # IPPROTO_TCP is the portable name for the TCP option
                    # level (SOL_TCP is not defined on every platform).
                    for k, v in iteritems(self.socket_keepalive_options):
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)
                # Timeout for establishing the connection...
                sock.settimeout(self.socket_connect_timeout)
                sock.connect(socket_address)
                # ...then the timeout for regular reads and writes.
                sock.settimeout(self.socket_timeout)
                return sock
            except socket.error as e:
                err = e
                if sock is not None:
                    sock.close()
        if err is not None:
            raise err
        raise socket.error("socket.getaddrinfo returned an empty list")

    def read_response(self):
        """Read one reply through the parser.

        Any failure while reading disconnects, because the stream position is
        then unknown, and re-raises. A ResponseError instance returned by the
        parser is raised here.
        """
        try:
            response = self._parser.read_response()
        except BaseException:
            self.disconnect()
            raise
        if isinstance(response, ResponseError):
            raise response
        return response

    def on_connect(self):
        """Initialise a freshly opened socket: attach the parser, then AUTH and SELECT.

        Raises:
            AuthenticationError: AUTH got a reply other than OK.
            ConnectionError: SELECT got a reply other than OK.
        """
        self._parser.on_connect(self)
        if self.passwd:
            self.send_command('AUTH', self.passwd)
            response = self.read_response()
            if nativestr(response) != 'OK':
                raise AuthenticationError('Invalid Password')
        if self.db:
            self.send_command('SELECT', self.db)
            if nativestr(self.read_response()) != 'OK':
                raise ConnectionError('Invalid Database')

    def connect(self):
        """Open the socket if not already open, then run on_connect().

        Raises:
            ConnectionError: the TCP connection could not be established.
            RedisError: initialisation failed; the connection is closed first.
        """
        if self._sock:
            return
        try:
            sock = self._connect()
        except socket.error as e:
            raise ConnectionError(self._error_message(e)) from e
        self._sock = sock
        try:
            self.on_connect()
        except RedisError:
            self.disconnect()
            raise

    def disconnect(self):
        """Close the parser and the socket; a no-op for the socket if already closed."""
        self._parser.on_disconnect()
        if self._sock is None:
            return
        try:
            self._sock.shutdown(socket.SHUT_RDWR)
            self._sock.close()
        except socket.error:
            pass
        self._sock = None

    def send_packed_command(self, command):
        """Send an already-encoded command, connecting first if needed.

        ``command`` is a single bytes object or an iterable of bytes chunks,
        as returned by ``pack_command``. Any failure disconnects.

        Raises:
            TimeoutError: the write timed out.
            ConnectionError: the socket reported an error.
        """
        if not self._sock:
            self.connect()
        try:
            # A lone bytes payload must be wrapped, because iterating bytes
            # yields ints, which sendall() rejects. The old check tested for
            # ``str``, which a Python 3 socket cannot send anyway.
            if isinstance(command, bytes):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError('Timeout writing to socket')
        except socket.error as e:
            self.disconnect()
            if len(e.args) == 1:
                err_no, errmsg = 'UNKNOWN', e.args[0]
            else:
                err_no, errmsg = e.args[0], e.args[1]
            raise ConnectionError("Error {} while writing to socket. {}.".format(err_no, errmsg))
        except BaseException:
            # Anything else, including KeyboardInterrupt, may leave a partial
            # command on the wire: drop the connection and re-raise.
            self.disconnect()
            raise
class ConnectionPool(object):
    """Pool of ``connection_class`` instances created on demand.

    ``get_connection`` hands connections out; ``release`` returns them for
    reuse. If the pool is used from a different process than the one that
    created it (PID change), every inherited connection is disconnected and
    the pool starts over.
    """

    def __init__(self, connection_class=Connection, max_connections=None, **connection_kwargs):
        # None (or 0) means "effectively unlimited".
        max_connections = max_connections or 2 ** 31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.reset()

    def reset(self):
        """Forget all connections and record the current process id."""
        self.pid = os.getpid()
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()
        self._check_lock = threading.Lock()

    def _checkpid(self):
        """Reset the pool when running in a process other than the one recorded by reset()."""
        if self.pid != os.getpid():
            with self._check_lock:
                if self.pid == os.getpid():
                    # Another thread already did the work while we waited
                    # on the lock.
                    return
                self.disconnect()
                self.reset()

    def make_connection(self):
        """Create a new connection from ``connection_kwargs``.

        Raises:
            ConnectionError: ``max_connections`` connections were already created.
        """
        if self._created_connections >= self.max_connections:
            raise ConnectionError("Too many connections")
        self._created_connections += 1
        return self.connection_class(**self.connection_kwargs)

    def disconnect(self):
        "Disconnects all connections in the pool"
        all_conns = chain(self._available_connections,
                          self._in_use_connections)
        for connection in all_conns:
            connection.disconnect()

    def get_connection(self, command_name, *keys, **options):
        """Return an idle connection, or create one if none is idle.

        ``command_name``, ``keys`` and ``options`` are accepted for API
        compatibility but are not used.
        """
        self._checkpid()
        try:
            connection = self._available_connections.pop()
        except IndexError:
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def release(self, connection):
        """Return ``connection`` to the pool so later get_connection() calls reuse it.

        Previously missing, so connections were never reused and every call
        created a new one. Connections this pool is not tracking (foreign ones,
        or ones handed out before a PID-triggered reset) are ignored.
        """
        self._checkpid()
        if connection not in self._in_use_connections:
            return
        self._in_use_connections.remove(connection)
        self._available_connections.append(connection)
if __name__ == '__main__':
    # Demo: needs a Redis server listening on 127.0.0.1:6379.
    p = ConnectionPool(host='127.0.0.1', port=6379)
    # Keep the command name consistent with the command actually sent
    # (previously 'PING' was passed while 'GET hello' was sent).
    args = ('GET', 'hello')
    command_name = args[0]
    conn = p.get_connection(command_name)
    try:
        conn.send_command(*args)
        print(conn.read_response())
    finally:
        # Close the socket instead of leaving it to interpreter shutdown.
        conn.disconnect()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding:utf-8 -*- | |
from io import BytesIO | |
class RedisError(Exception):
    """Error raised by the reply parser in this demo."""
def nativestr(x):
    """Return ``x`` as ``str``, decoding bytes as UTF-8 with U+FFFD for bad bytes."""
    if not isinstance(x, str):
        x = x.decode('utf-8', 'replace')
    return x
def b(x):
    """Return ``x`` as bytes: ``str`` is encoded as latin-1, bytes are returned unchanged."""
    if not isinstance(x, bytes):
        x = x.encode('latin-1')
    return x
def byte_to_chr(x):
    """Map an integer byte value (e.g. ``line[0]``) to the matching one-character str."""
    code_point = x
    return chr(code_point)
# Protocol tokens, as bytes.
SYM_STAR = b'*'      # array/multi-bulk prefix
SYM_DOLLAR = b'$'    # bulk-string length prefix
SYM_CRLF = b'\r\n'   # line terminator
SYM_EMPTY = b''      # joiner
def iteritems(x):
    """Iterate over the ``(key, value)`` pairs of mapping ``x``."""
    view = x.items()
    return iter(view)
class Socket(object):
    """In-memory stand-in for a socket: recv() hands out successive slices of ``data``.

    Once ``data`` is exhausted, recv() returns an empty bytes object.
    """

    def __init__(self, data):
        self.data = data

    def recv(self, length):
        head, self.data = self.data[:length], self.data[length:]
        return head
class SocketBuffer(object):
    """Buffer bytes pulled from ``socket`` so replies can be read by line or by length.

    New data is appended at offset ``bytes_written`` of an in-memory BytesIO;
    consumers read from offset ``bytes_read``. Once everything written has
    been consumed, the buffer is reset.
    """

    def __init__(self, socket, socket_read_size):
        self._sock = socket
        self.socket_read_size = socket_read_size
        self._buffer = BytesIO()
        self.bytes_written = 0
        self.bytes_read = 0

    @property
    def length(self):
        """Number of buffered bytes not yet consumed."""
        return self.bytes_written - self.bytes_read

    def _read_from_socket(self, length=None):
        """Append data from the socket to the buffer.

        With ``length=None`` a single ``recv`` is done. Otherwise ``recv`` is
        repeated until at least ``length`` new bytes have been buffered.

        Raises:
            ConnectionError: ``recv`` returned no data (source exhausted or
                closed). Without this check the loop never terminated when
                fewer than ``length`` bytes were available.
        """
        socket_read_size = self.socket_read_size
        buf = self._buffer
        buf.seek(self.bytes_written)
        # Bytes received during this call.
        marker = 0
        while True:
            data = self._sock.recv(socket_read_size)
            if not data:
                raise ConnectionError("Connection closed by server.")
            buf.write(data)
            data_length = len(data)
            self.bytes_written += data_length
            marker += data_length
            # Keep reading until the requested amount is buffered.
            if length is not None and length > marker:
                continue
            break

    def purge(self):
        """Discard all buffered data and reset both offsets."""
        self._buffer.seek(0)
        self._buffer.truncate()
        self.bytes_read = 0
        self.bytes_written = 0

    def read(self, length):
        """Return the next ``length`` bytes, consuming and dropping the trailing CRLF."""
        # The payload is followed by \r\n, which is read but not returned.
        length = length + 2
        if length > self.length:
            self._read_from_socket(length - self.length)
        self._buffer.seek(self.bytes_read)
        data = self._buffer.read(length)
        self.bytes_read += len(data)
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]

    def readline(self):
        """Return the next CRLF-terminated line without its CRLF, reading more as needed."""
        buf = self._buffer
        buf.seek(self.bytes_read)
        data = buf.readline()
        # A line without CRLF is incomplete: fetch more data and retry.
        while not data.endswith(SYM_CRLF):
            self._read_from_socket()
            buf.seek(self.bytes_read)
            data = buf.readline()
        self.bytes_read += len(data)
        if self.bytes_read == self.bytes_written:
            self.purge()
        return data[:-2]
class PythonParser(object):
    """Stripped-down reply parser for experimenting with SocketBuffer.

    ``on_connect`` wraps raw reply bytes in the in-memory ``Socket``, so no
    network is involved.
    """

    # When set, bytes replies are decoded with it; nothing in this demo sets it.
    encoding = None

    def __init__(self, socket_read_size):
        self.socket_read_size = socket_read_size
        self._sock = None
        self._buffer = None

    def on_connect(self, data):
        """Feed ``data`` (raw reply bytes) through a fake socket into a SocketBuffer."""
        self._sock = Socket(data)
        self._buffer = SocketBuffer(self._sock, self.socket_read_size)

    def read_response(self):
        """Parse and return one reply from the buffer.

        ``:`` gives an int. ``+`` and ``$`` give bytes; ``$-1`` gives None.
        ``*`` gives a list of replies; ``*-1`` gives None. Error replies (``-``)
        are returned as ``str`` rather than raised.

        Raises:
            RedisError: the reply line is empty or has an unknown type byte.
        """
        response = self._buffer.readline()
        if not response:
            # An empty line has no type byte; indexing it would raise IndexError.
            raise RedisError("Protocol Error: empty reply line")
        byte, response = byte_to_chr(response[0]), response[1:]
        if byte not in ('-', '+', ':', '$', '*'):
            raise RedisError("Protocol Error: %s, %s" % (byte, response))
        # Error reply: this demo hands back the message text instead of raising.
        if byte == '-':
            response = nativestr(response)
            return response
        # Simple string: returned as-is.
        elif byte == '+':
            pass
        # Integer reply.
        elif byte == ':':
            response = int(response)
        # Bulk reply: a length line followed by that many bytes.
        elif byte == '$':
            length = int(response)
            if length == -1:
                return None
            response = self._buffer.read(length)
        # Multi-bulk reply: a count followed by that many nested replies.
        elif byte == '*':
            length = int(response)
            if length == -1:
                return None
            response = [self.read_response() for i in range(length)]
        if isinstance(response, bytes) and self.encoding:
            response = response.decode(self.encoding)
        return response
if __name__ == '__main__':
    # Sample replies to feed the parser:
    #   bulk string:   b'$6\r\nfoobar\r\n'
    #   simple string: b'+OK\r\n'
    #   multi-bulk, e.g. reading back a set holding hello, 你好, 777:
    #     b'*3\r\n$3\r\n777\r\n$6\r\n\xe4\xbd\xa0\xe5\xa5\xbd\r\n$5\r\nhello\r\n'
    data = b'$6\r\nfoobar\r\n'
    # A tiny read size forces SocketBuffer to issue several recv() calls.
    parser = PythonParser(socket_read_size=5)
    parser.on_connect(data)
    print(parser.read_response())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding:utf-8 -*- | |
class Token(object):
    """Marks a command word (e.g. 'SET') so Connection.encode() sends it via b()."""

    def __init__(self, value):
        if isinstance(value, Token):
            # Unwrap so nested Tokens never stack.
            value = value.value
        self.value = value

    def __str__(self):
        return self.value

    def __repr__(self):
        return self.__str__()
def b(x):
    """Return ``x`` as bytes; ``str`` input is encoded as latin-1."""
    return x if isinstance(x, bytes) else x.encode('latin-1')
SYM_STAR = b'*'      # argument-count header prefix
SYM_DOLLAR = b'$'    # argument-length prefix
SYM_CRLF = b'\r\n'   # line terminator
SYM_EMPTY = b''      # joiner for assembling packets
class Connection(object):
    """Command encoder demo: turns a Redis command into wire-format bytes chunks.

    This cut-down Connection never opens a socket (``_sock`` stays None).
    """

    def __init__(self, encoding='utf-8', encoding_errors='strict'):
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self._sock = None

    def pack_command(self, *args):
        """Encode a command and return it as a list of bytes chunks.

        A multi-word command name such as ``'CONFIG SET'`` is split into
        separate tokens. Arguments are normally packed into one chunk. Once the
        pending chunk or an argument exceeds 6000 bytes, the argument is
        emitted as its own chunk instead of being joined in.
        """
        name = args[0]
        if ' ' in name:
            head = tuple(Token(word) for word in name.split(' '))
        else:
            head = (Token(name),)
        args = head + args[1:]

        chunks = []
        pending = SYM_STAR + b(str(len(args))) + SYM_CRLF
        for encoded in map(self.encode, args):
            length_line = SYM_DOLLAR + b(str(len(encoded))) + SYM_CRLF
            if len(pending) > 6000 or len(encoded) > 6000:
                # Large data: flush what is pending plus this length line,
                # then emit the argument on its own.
                chunks.append(pending + length_line)
                chunks.append(encoded)
                pending = SYM_CRLF
            else:
                pending = pending + length_line + encoded + SYM_CRLF
        chunks.append(pending)
        return chunks

    def encode(self, value):
        """Convert one argument to bytes.

        Tokens go through ``b`` (latin-1). Bytes are returned unchanged. ints
        (bool included) become ``b(str(value))``. Anything else is ``str()``-ed
        and encoded with ``self.encoding`` / ``self.encoding_errors``.
        """
        if isinstance(value, Token):
            return b(value.value)
        if isinstance(value, bytes):
            return value
        if isinstance(value, int):
            return b(str(value))
        if not isinstance(value, str):
            value = str(value)
        return value.encode(self.encoding, self.encoding_errors)
if __name__ == '__main__':
    # Command to encode. Other examples to try:
    #   ('PING',)
    #   ('GET', 'hello')
    #   ('SET', 'hello', 'world')
    #   ('SET', 'key', '中国')
    #   ('SET', 'hello', 1)
    #   ('SET', '中国', 21.7)
    #   ('AUTH', '')
    args = ('CONFIG SET', 'timeout', 0)
    command = Connection().pack_command(*args)
    print(command)
    # To send it to a local server (join all chunks; large commands span several):
    #   import socket
    #   sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    #   sock.connect(('127.0.0.1', 6379))
    #   sock.sendall(b''.join(command))
    #   print(sock.recv(1024))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment