Created
May 6, 2014 15:36
-
-
Save rsj217/a05adb7b71638060d6f4 to your computer and use it in GitHub Desktop.
A Python wrapper for MySQL: opening and closing connections and running queries (mysql 的连接关闭查询的python封装).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
__author__ = 'ghost' | |
import time, uuid, functools, threading, logging | |
def next_id(t=None):
    """Return a unique id string built from a timestamp and a random UUID.

    The id is the timestamp *t* (seconds; defaults to ``time.time()``)
    converted to whole milliseconds and zero-padded to at least 15 digits,
    followed by ``uuid.uuid4().hex`` (32 hex characters) and the literal
    suffix ``000``.
    """
    stamp = time.time() if t is None else t
    millis = int(stamp * 1000)
    return '%015d%s000' % (millis, uuid.uuid4().hex)
def _profiling(start, sql=''): | |
t = time.time() - start | |
if t > 0.1: | |
logging.warning('[PROFILING] [DB] %s: %s' % (t, sql)) | |
else: | |
logging.info('[PROFILING] [DB] %s: %s' % (t, sql)) | |
class DBError(Exception):
    """Base class for errors raised by this database module."""


class MultiColumnsError(DBError):
    """Raised when a single-column result was expected but several columns came back."""
class Dict(dict):
    """A dict whose keys can also be read and written as attributes.

    It can be built from two parallel sequences, like
    ``dict(zip(names, values))``, plus keyword arguments.

    >>> d1 = Dict()
    >>> isinstance(d1, dict)
    True
    >>> d1['name'] = 'python'
    >>> d1.name
    'python'
    >>> d1['age'] = 13
    >>> d1['age']
    13
    >>> d1.get('name')
    'python'
    >>> d1.get('lll', 0)
    0
    >>> d2 = Dict(name='python', age=13)
    >>> d2['name']
    'python'
    >>> d2['none']
    Traceback (most recent call last):
        ...
    KeyError: 'none'
    >>> d2.name
    'python'
    >>> d2.none
    Traceback (most recent call last):
        ...
    AttributeError: 'Dict' object has no attribute 'none'
    >>> d3 = Dict(('name', 'age'), ('python', 13), isgood=True)
    >>> sorted(d3.items())
    [('age', 13), ('isgood', True), ('name', 'python')]
    """

    def __init__(self, names=(), values=(), **kwargs):
        """Fill from *kwargs*, then from ``zip(names, values)``.

        Pairs from *names*/*values* override keyword arguments with the same
        key. Surplus items in the longer sequence are ignored, as with ``zip``.
        """
        super(Dict, self).__init__(**kwargs)
        self.update(dict(zip(names, values)))

    def __getattr__(self, key):
        # Only invoked when normal attribute lookup fails, so real dict
        # methods (get, items, ...) take precedence over same-named keys.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(r"'Dict' object has no attribute '%s'" % key)

    def __setattr__(self, key, value):
        # Attribute assignment stores an item instead of an instance attribute.
        self[key] = value
class _LasyConnection(object): | |
""" | |
获取数据库引擎`连接资源句柄connection` | |
通过connection获取cursor | |
操作 commit, rollback | |
关闭连接 cleanup | |
""" | |
def __init__(self): | |
self.connection = None | |
def cursor(self): | |
if self.connection is None: | |
connection = engine.connect() | |
logging.info('open connection <%s>...' % hex(id(connection))) | |
self.connection = connection | |
return self.connection.cursor() | |
def commit(self): | |
self.connection.commit() | |
def rollback(self): | |
self.connection.rollback() | |
def cleanup(self): | |
if self.connection: | |
connection = self.connection | |
self.connection = None | |
logging.info('close connection <%s>...' % hex(id(connection))) | |
connection.close() | |
class _DbCtx(threading.local): | |
""" | |
数据库上下文操作类,实例全局数据库上下文实例 `_db_Ctx` | |
主要提供给 `_ConnectionCtx` 进行判断 connection 是否初始化`is_init`,进行初始化`init`和关闭`cleanup` | |
""" | |
def __init__(self): | |
self.connection = None | |
self.transactions = 0 | |
def is_init(self): | |
return not self.connection is None | |
def init(self): | |
logging.info('open lazy connections...') | |
self.connection = _LasyConnection() | |
self.transactions = 0 | |
def cleanup(self): | |
self.connection.cleanup() | |
self.connection = None | |
def cursor(self): | |
return self.connection.cursor() | |
class _Engine(object): | |
""" | |
数据库引擎类,用于连接数据库 | |
""" | |
def __init__(self, connect): | |
self._connect = connect | |
def connect(self): | |
return self._connect() | |
# Global engine object; assigned once by create_engine().
engine = None
# Thread-local database context used by the connection/transaction helpers.
_db_ctx = _DbCtx()
def create_engine(user, passwd, db, host='127.0.0.1', port=3306, **kwargs):
    """Create the global ``engine`` that opens MySQLdb connections.

    :param user: MySQL user name.
    :param passwd: password for *user*.
    :param db: database name.
    :param host: server host, default ``127.0.0.1``.
    :param port: server port, default 3306.
    :param kwargs: extra ``MySQLdb.connect`` arguments. ``use_unicode``
        (default True) and ``charset`` (default ``'utf8'``) can be overridden
        here too.
    :raises DBError: if the engine has already been created.
    """
    import MySQLdb
    global engine
    if engine is not None:
        raise DBError('Engine is already initialized.')
    params = dict(user=user, passwd=passwd, db=db, host=host, port=port)
    defaults = dict(use_unicode=True, charset='utf8')
    # .items() works on both Python 2 and 3 (.iteritems() is Python-2-only).
    for k, v in defaults.items():
        params[k] = kwargs.pop(k, v)
    params.update(kwargs)
    # Every engine.connect() call opens a new MySQLdb connection with params.
    engine = _Engine(lambda: MySQLdb.connect(**params))
    logging.info('Init mysql engine <%s> ok.' % hex(id(engine)))
class _ConnectionCtx(object):
    """Context manager that ensures a thread-local connection for its body.

    If the thread-local context was not initialised on entry, it is
    initialised here and cleaned up on exit. Nested uses reuse the outer
    connection and leave cleanup to the outermost one.

    Usage::

        with _ConnectionCtx():
            pass
    """

    def __enter__(self):
        self.should_cleanup = False
        if _db_ctx.is_init():
            return self
        _db_ctx.init()
        self.should_cleanup = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.should_cleanup:
            _db_ctx.cleanup()
def connection():
    """Public entry point: a context manager providing a thread-local connection.

    Usage::

        with connection():
            do_some_db_operation()
    """
    ctx = _ConnectionCtx()
    return ctx
def with_connection(func):
    """Decorator that runs *func* inside a ``_ConnectionCtx``.

    All database calls made by *func* share one connection, which is
    released when the outermost connection context exits.

    Usage::

        @with_connection
        def foo(*args, **kwargs):
            do_some_db_operation()
            do_some_db_operation()
    """
    def _wrapper(*args, **kwargs):
        with _ConnectionCtx():
            return func(*args, **kwargs)
    return functools.wraps(func)(_wrapper)
class _TransactionCtx(object):
    """Context manager that wraps its body in a (possibly nested) transaction.

    ``_db_ctx.transactions`` counts the open levels on this thread. Only when
    the outermost level exits is the work committed (no exception) or rolled
    back (exception). If entering had to initialise the thread-local
    context, exiting cleans it up again.
    """

    def __enter__(self):
        self.should_close_conn = False
        if not _db_ctx.is_init():
            _db_ctx.init()
            self.should_close_conn = True
        _db_ctx.transactions += 1
        if _db_ctx.transactions == 1:
            logging.info('begin transaction...')
        else:
            logging.info('join current transaction...')
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _db_ctx.transactions -= 1
        try:
            if _db_ctx.transactions != 0:
                # Inner level: the outermost transaction decides the outcome.
                return
            if exc_type is None:
                self.commit()
            else:
                self.rollback()
        finally:
            if self.should_close_conn:
                _db_ctx.cleanup()

    def commit(self):
        """Commit; if that fails, roll back and re-raise the original error."""
        logging.info('commit transaction...')
        try:
            _db_ctx.connection.commit()
            logging.info('commit ok.')
        except:  # deliberately bare: roll back on anything, then re-raise
            logging.warning('commit failed. try rollback...')
            _db_ctx.connection.rollback()
            logging.warning('rollback ok.')
            raise

    def rollback(self):
        """Roll back the current transaction."""
        logging.warning('rollback transaction...')
        _db_ctx.connection.rollback()
        logging.info('rollback ok.')
def transaction():
    """Return a new transaction context manager.

    Usage::

        with transaction():
            insert(...)
            update(...)

    A ``transaction()`` opened inside another one joins the outer transaction.
    """
    ctx = _TransactionCtx()
    return ctx
def with_transaction(func):
    """Decorator that runs *func* inside a transaction and logs its duration.

    Usage::

        @with_transaction
        def foo():
            pass
    """
    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        _start = time.time()
        try:
            with _TransactionCtx():
                return func(*args, **kwargs)
        finally:
            # In ``finally`` so it runs despite the ``return`` above (and on
            # errors); the timing includes the commit/rollback done on exit.
            _profiling(_start)
    return _wrapper
@with_connection
def _select(sql, first, *args):
    """Run a query and convert the result rows into ``Dict`` objects.

    ``?`` placeholders in *sql* are rewritten to the driver's ``%s`` style
    and bound to *args*.

    :param sql: SQL query using ``?`` placeholders.
    :param first: if true, fetch only the first row.
    :param args: values bound to the placeholders, in order.
    :return: with *first*, a ``Dict`` for the first row, or ``None`` when
        there is no row; otherwise a (possibly empty) list of ``Dict``.
        ``None`` is also returned when the cursor has no ``description``.
    """
    cursor = None
    sql = sql.replace('?', '%s')
    logging.info('SQL: %s, ARGS: %s' % (sql, args))
    try:
        cursor = _db_ctx.connection.cursor()
        cursor.execute(sql, args)
        if not cursor.description:
            # No result set was produced.
            return None
        names = [column[0] for column in cursor.description]
        if first:
            values = cursor.fetchone()
            return Dict(names, values) if values else None
        return [Dict(names, row) for row in cursor.fetchall()]
    finally:
        if cursor:
            cursor.close()
@with_connection
def _update(sql, *args):
    """Execute a write statement and return ``cursor.rowcount``.

    ``?`` placeholders are rewritten to ``%s`` and bound to *args*. When no
    transaction is open on this thread, the change is committed right away.
    Inside a transaction, the transaction context commits it later.
    """
    sql = sql.replace('?', '%s')
    logging.info('SQL: %s, ARGS: %s' % (sql, args))
    cur = None
    try:
        cur = _db_ctx.connection.cursor()
        cur.execute(sql, args)
        affected = cur.rowcount
        if _db_ctx.transactions == 0:
            logging.info('auto commit')
            _db_ctx.connection.commit()
        return affected
    finally:
        if cur:
            cur.close()
def update(sql, *args):
    """Execute an UPDATE (or other write) statement with ``?`` placeholders.

    :return: number of affected rows, as reported by ``_update``.
    """
    affected = _update(sql, *args)
    return affected
def insert(table, **kwargs):
    """Insert one row into *table* and return the affected row count.

    Column names come from the keyword names and values from the keyword
    values. Values are bound as query parameters. *table* and the column
    names are interpolated (backtick-quoted) into the SQL, so they must come
    from trusted code.

    :raises ValueError: if no column values are given.
    """
    if not kwargs:
        raise ValueError('insert() requires at least one column value')
    # .items() works on both Python 2 and 3; keys and values stay paired.
    cols, args = zip(*kwargs.items())
    sql = 'insert into `%s` (%s) values (%s)' % (
        table,
        ','.join('`%s`' % col for col in cols),
        ','.join('?' * len(cols)),
    )
    return _update(sql, *args)
def delete(sql, *args):
    """Execute a DELETE statement with ``?`` placeholders.

    Behaves exactly like ``update``; the separate name documents intent.

    :return: number of affected rows.
    """
    affected = _update(sql, *args)
    return affected
def select_int(sql, *args):
    """Run a query returning a single column and return that column's value.

    Typical use is ``select_int('select count(*) from user')``.

    :return: the value in the first row, or ``None`` when the query yields
        no row.
    :raises MultiColumnsError: if the row has more than one column.
    """
    row = _select(sql, True, *args)
    if row is None:
        # Without this guard len(None) raised an unhelpful TypeError.
        return None
    if len(row) != 1:
        raise MultiColumnsError('Expect only one column.')
    # list(...) because dict.values() is not indexable on Python 3.
    return list(row.values())[0]
def select(sql, *args):
    """Run a query with ``?`` placeholders and return every row.

    :return: list of ``Dict`` rows (empty when nothing matches).
    """
    rows = _select(sql, False, *args)
    return rows
def select_one(sql, *args):
    """Run a query with ``?`` placeholders and return only its first row.

    :return: a ``Dict`` for the first row, or ``None`` when there is none.
    """
    row = _select(sql, True, *args)
    return row
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    create_engine(user='root', passwd='', db='webapp')
    # update("DROP TABLE IF EXISTS user")
    # update("CREATE TABLE user (id INT UNSIGNED NOT NULL PRIMARY KEY AUTO_INCREMENT, nickname VARCHAR(40), email VARCHAR(40), passwd VARCHAR(40), last_modified REAL)")

    def test(rollback):
        """Insert and update a user in one transaction; raise to force a rollback."""
        with transaction():
            insert('user', nickname='test', email='[email protected]', passwd='test')
            update("UPDATE user SET nickname='chage' WHERE passwd='test'")
            if rollback:
                raise StandardError('will cause rollback...')

    import doctest
    doctest.testmod()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
__author__ = 'ghost' | |
import logging, threading, functools, time | |
def _profiling(start, sql=''): | |
""" | |
""" | |
t = time.time() - start | |
if t > 0.1: | |
logging.warning('[PROFILING] [DB] %s: %s' % (t, sql)) | |
else: | |
logging.info('[PROFILING] [DB] %s: %s' % (t, sql)) | |
class Dict(dict):
    """A dict whose keys can also be read and written as attributes.

    It can be built from two parallel sequences plus keyword arguments.

    >>> d = Dict(name='python', age=12)
    >>> sorted(d.items())
    [('age', 12), ('name', 'python')]
    >>> d['name']
    'python'
    >>> d.name
    'python'
    >>> d.get('passwd', '1111')
    '1111'
    >>> d['passwd'] = 123
    >>> sorted(d.items())
    [('age', 12), ('name', 'python'), ('passwd', 123)]
    """

    def __init__(self, names=(), values=(), **kwargs):
        """Fill from *kwargs*, then from ``zip(names, values)``.

        :param names: sequence (tuple or list) of keys.
        :param values: sequence (tuple or list) of values paired with *names*;
            surplus items in the longer sequence are ignored.
        :param kwargs: initial items; overridden by *names*/*values* pairs
            that use the same key.
        """
        super(Dict, self).__init__(**kwargs)
        self.update(dict(zip(names, values)))

    def __getattr__(self, item):
        # Only invoked when normal attribute lookup fails, so real dict
        # methods (get, items, ...) take precedence over same-named keys.
        try:
            return self[item]
        except KeyError:
            raise AttributeError(r"'Dict' object has no attribute '%s'" % item)

    def __setattr__(self, key, value):
        # Attribute assignment stores an item instead of an instance attribute.
        self[key] = value
class DBError(Exception):
    """Base class for errors raised by this database module."""


class MultiColumnsError(DBError):
    """Signals that a query returned more columns than the caller expected."""
class _LazyConnection(object): | |
""" | |
获取数据库引擎`连接资源句柄connection` | |
通过connection获取cursor | |
操作 commit, rollback | |
关闭连接 cleanup | |
""" | |
def __init__(self): | |
self.connection = None | |
def cursor(self): | |
""" 获取游标 | |
""" | |
if self.connection is None: | |
connection = engine.connect() | |
logging.info('open connection <%s>...' % hex(id(connection))) | |
self.connection = connection | |
return self.connection.cursor() | |
def commit(self): | |
""" 提交session | |
""" | |
self.connection.commit() | |
def rollback(self): | |
""" 回滚 | |
""" | |
self.connection.rollback() | |
def cleanup(self): | |
""" 释放连接 | |
""" | |
if self.connection: | |
connection = self.connection | |
self.connection = None | |
logging.info('close connection <%s>...' % hex(id(connection))) | |
connection.close() | |
class _DbCtx(threading.local): | |
""" | |
数据库上下文操作类,用来生成全局数据库上下文实例 `_db_Ctx` | |
主要提供给 `_ConnectionCtx` 进行判断 connection 是否初始化`is_init`, | |
进行初始化`init`和关闭`cleanup` | |
""" | |
def __init__(self): | |
self.connection = None | |
self.transactions = 0 | |
def is_init(self): | |
""" 判断数据库上下文连接是否存在""" | |
return not self.connection is None | |
def init(self): | |
""" 初始化数据库上下文连接""" | |
logging.info('open lazy connections...') | |
self.connection = _LazyConnection() | |
self.transactions = 0 | |
def cleanup(self): | |
""" 清除数据库上下文连接""" | |
self.connection.cleanup() | |
self.connection = None | |
def cursor(self): | |
""" 获取上下文连接游标""" | |
return self.connection.cursor() | |
# Thread-local database context used by the connection/transaction helpers.
_db_ctx = _DbCtx()
# Global DBUtils connection pool; assigned by create_pool().
dbpool = None
# Global engine that hands out pooled connections; assigned by create_engine().
engine = None
class _Engine(object): | |
""" 数据库引擎对象,用于动态生成连接 | |
""" | |
def __init__(self, connect): | |
self._connect = connect | |
def connect(self): | |
return self._connect() | |
class _ConnectCtx(object):
    """Context manager that ensures a thread-local connection for its body.

    If the thread-local context was not initialised on entry, it is
    initialised here and cleaned up on exit. Nested uses reuse the outer
    connection and leave cleanup to the outermost one.

    Usage::

        with _ConnectCtx():
            pass
    """

    def __enter__(self):
        self.should_cleanup = False
        if _db_ctx.is_init():
            return self
        _db_ctx.init()
        self.should_cleanup = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.should_cleanup:
            _db_ctx.cleanup()
def connection():
    """Public entry point: a context manager providing a thread-local connection.

    Usage::

        with connection():
            do_some_db_operation()
    """
    ctx = _ConnectCtx()
    return ctx
def with_connection(func):
    """Decorator that runs *func* inside a ``_ConnectCtx``.

    All database calls made by *func* share one connection, which is
    released when the outermost connection context exits.

    Usage::

        @with_connection
        def foo(*args, **kwargs):
            do_some_db_operation()
            do_some_db_operation()
    """
    def wrapper(*args, **kwargs):
        with _ConnectCtx():
            return func(*args, **kwargs)
    return functools.wraps(func)(wrapper)
class _TransactionCtx(object):
    """Context manager that wraps its body in a (possibly nested) transaction.

    ``_db_ctx.transactions`` counts the open levels on this thread. Only when
    the outermost level exits is the work committed (no exception) or rolled
    back (exception). If entering had to initialise the thread-local
    context, exiting cleans it up again.
    """

    def __enter__(self):
        self.should_close_conn = False
        if not _db_ctx.is_init():
            _db_ctx.init()
            self.should_close_conn = True
        _db_ctx.transactions += 1
        if _db_ctx.transactions == 1:
            logging.info('begin transaction...')
        else:
            logging.info('join current transaction...')
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _db_ctx.transactions -= 1
        try:
            if _db_ctx.transactions != 0:
                # Inner level: the outermost transaction decides the outcome.
                return
            if exc_type is None:
                self.commit()
            else:
                self.rollback()
        finally:
            if self.should_close_conn:
                _db_ctx.cleanup()

    def commit(self):
        """Commit; if that fails, roll back and re-raise the original error."""
        logging.info('commit transaction...')
        try:
            _db_ctx.connection.commit()
            logging.info('commit ok.')
        except:  # deliberately bare: roll back on anything, then re-raise
            logging.warning('commit failed. try rollback...')
            _db_ctx.connection.rollback()
            logging.warning('rollback ok.')
            raise

    def rollback(self):
        """Roll back the current transaction."""
        logging.warning('rollback transaction...')
        _db_ctx.connection.rollback()
        logging.info('rollback ok.')
def transaction():
    """Return a new transaction context manager.

    Usage::

        with transaction():
            pass

    A ``transaction()`` opened inside another one joins the outer transaction.
    """
    ctx = _TransactionCtx()
    return ctx
def with_transaction(func):
    """Decorator that runs *func* inside a transaction and logs its duration.

    Usage::

        @with_transaction
        def foo():
            pass
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        _start = time.time()
        try:
            with _TransactionCtx():
                return func(*args, **kwargs)
        finally:
            # In ``finally`` so it runs despite the ``return`` above (and on
            # errors); the timing includes the commit/rollback done on exit.
            _profiling(_start)
    return wrapper
@with_connection
def _update(sql, *args):
    """Execute a create/update/delete statement and return the row count.

    ``?`` placeholders in *sql* are rewritten to the driver's ``%s`` style
    and bound to *args*. When no transaction is open on this thread, the
    change is committed right away. Inside a transaction, the transaction
    context commits it later.

    :param sql: SQL statement using ``?`` placeholders.
    :param args: values bound to the placeholders, in order.
    :return: ``cursor.rowcount`` (number of affected rows).
    :raises DBError: if executing the statement fails.
    """
    cursor = None
    sql = sql.replace('?', '%s')
    logging.info('SQL: %s, ARGS: %s' % (sql, args))
    try:
        cursor = _db_ctx.connection.cursor()
        try:
            cursor.execute(sql, args)
        except Exception as e:
            # Keep the driver's error in the message instead of discarding it;
            # %r stays ASCII-safe for non-ASCII driver messages on Python 2.
            raise DBError('sql execute error: %r' % (e,))
        rowcount = cursor.rowcount
        if _db_ctx.transactions == 0:
            logging.info('auto commit')
            _db_ctx.connection.commit()
        return rowcount
    finally:
        if cursor:
            cursor.close()
def update(sql, *args):
    """Execute an UPDATE (or other write) statement.

    :param sql: SQL statement using ``?`` placeholders.
    :param args: values bound to the placeholders, in order.
    :return: number of affected rows.

    Example::

        update("UPDATE user SET nickname=? WHERE nickname=? AND passwd=?",
               'ruby', 'python', '111111')
    """
    affected = _update(sql, *args)
    return affected
def insert(table, **kwargs):
    """Insert one row into *table* and return the affected row count.

    Column names come from the keyword names and values from the keyword
    values. Values are bound as query parameters. *table* and the column
    names are interpolated (backtick-quoted) into the SQL, so they must come
    from trusted code.

    Example::

        insert('user', nickname='python', passwd='111111')

    :raises ValueError: if no column values are given.
    """
    if not kwargs:
        raise ValueError('insert() requires at least one column value')
    # .items() works on both Python 2 and 3; keys and values stay paired.
    cols, args = zip(*kwargs.items())
    sql = 'insert into `%s` (%s) values (%s)' % (
        table,
        ','.join('`%s`' % col for col in cols),
        ','.join('?' * len(cols)),
    )
    return _update(sql, *args)
def delete(sql, *args):
    """Execute a DELETE statement with ``?`` placeholders.

    Behaves exactly like ``update``; the separate name documents intent.

    :return: number of affected rows.
    """
    affected = _update(sql, *args)
    return affected
@with_connection
def _select(sql, first, *args):
    """Run a query and convert the result rows into ``Dict`` objects.

    ``?`` placeholders in *sql* are rewritten to the driver's ``%s`` style
    and bound to *args*.

    :param sql: SQL query using ``?`` placeholders.
    :param first: if true, fetch only the first row.
    :param args: values bound to the placeholders, in order.
    :return: with *first*, a ``Dict`` for the first row, or ``None`` when
        there is no row; otherwise a (possibly empty) list of ``Dict``.
        ``None`` is also returned when the cursor has no ``description``.
    """
    cursor = None
    sql = sql.replace('?', '%s')
    logging.info('SQL: %s, ARGS: %s' % (sql, args))
    try:
        cursor = _db_ctx.connection.cursor()
        cursor.execute(sql, args)
        if not cursor.description:
            # No result set was produced.
            return None
        names = [column[0] for column in cursor.description]
        if first:
            values = cursor.fetchone()
            return Dict(names, values) if values else None
        return [Dict(names, row) for row in cursor.fetchall()]
    finally:
        if cursor:
            cursor.close()
def select(sql, *args):
    """Run a query with ``?`` placeholders and return every row.

    :param sql: SQL query.
    :param args: values bound to the placeholders, in order.
    :return: list of ``Dict`` rows (empty when nothing matches).
    """
    rows = _select(sql, False, *args)
    return rows
def select_one(sql, *args):
    """Run a query with ``?`` placeholders and return only its first row.

    :return: a ``Dict`` for the first row, or ``None`` when there is none.
    """
    row = _select(sql, True, *args)
    return row
def create_pool(user, passwd, db, host='127.0.0.1', port=3306, **kwargs):
    """Create the global DBUtils ``PooledDB`` connection pool (``dbpool``).

    If the pool already exists, this only logs a message and returns.

    :param user: database user name.
    :param passwd: password for *user*.
    :param db: database name.
    :param host: server host, default ``127.0.0.1``.
    :param port: server port, default 3306.
    :param kwargs: extra settings. Keys listed in ``defaults`` below override
        those defaults; any other key is passed on to ``PooledDB`` as well.
    :return: None
    """
    import MySQLdb
    from DBUtils.PooledDB import PooledDB
    global dbpool
    if dbpool is not None:
        logging.info(DBError('pool is already initialized.'))
        return
    params = dict(user=user, passwd=passwd, db=db, host=host, port=port)
    # Defaults, each overridable through kwargs:
    #   use_unicode / charset : driver text handling (unicode on, utf8)
    #   mincached      : idle connections opened at start-up (0 = none)
    #   maxcached      : maximum idle connections kept in the pool
    #   maxshared      : maximum shared connections (0 = all dedicated)
    #   maxconnections : maximum connections overall (0 = unlimited)
    #   blocking       : when the pool is exhausted, block (True) or raise
    #   maxusage       : reuses per connection before reconnecting (0 = unlimited)
    #   setsession     : optional list of SQL commands to prepare each session
    defaults = dict(use_unicode=True, charset='utf8', mincached=10, maxcached=10,
                    maxshared=30, maxconnections=100, blocking=True, maxusage=0,
                    setsession=None)
    # .items() works on both Python 2 and 3 (.iteritems() is Python-2-only).
    for k, v in defaults.items():
        params[k] = kwargs.pop(k, v)
    params.update(kwargs)
    dbpool = PooledDB(MySQLdb, **params)
    logging.info('Init mysql pool <%s>ok' % hex(id(dbpool)))
def create_engine():
    """Create the global ``engine`` that hands out connections from ``dbpool``.

    The factory looks up the global ``dbpool`` each time a connection is
    requested, so ``create_pool()`` must have run before the first query.
    If an engine already exists, this only logs a message.
    """
    global engine
    if engine is None:
        engine = _Engine(lambda: dbpool.connection())
    else:
        logging.info('Engine is already initialized.')
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    create_pool(user='root', passwd='', db='pytest', host='127.0.0.1')
    create_engine()
    # Recreate a scratch `user` table and seed it with two rows.
    update("DROP TABLE IF EXISTS user")
    update("CREATE TABLE user (id INT UNSIGNED NOT NULL PRIMARY KEY AUTO_INCREMENT, nickname VARCHAR(40), email VARCHAR(40), passwd VARCHAR(40))")
    insert('user', nickname='python', email='[email protected]', passwd='123456')
    insert('user', nickname='python', email='[email protected]', passwd='111111')
    update("UPDATE user SET nickname=? WHERE nickname=? AND passwd=?", 'ruby', 'python', '111111')

    @with_transaction
    def update_profile(name, rollback):
        """Insert a user and rename matching rows in one transaction; raise to roll back."""
        insert('user', nickname=name, email='{0}@test.com'.format(name))
        update("UPDATE user SET nickname=? where passwd=?", name.upper(), '111111')
        if rollback:
            raise StandardError('will cause rollback...')

    # Demonstrates the rollback path; the exception propagates to the top level.
    update_profile('test', True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#: -*- coding: utf-8 -*- | |
""" | |
database | |
~~~~~~~~ | |
python对mysql操作的封装类 | |
~~~~~~~~~~~~~~~~~~~~~~~~~ | |
:author: rsj217 | |
:license: BSD. | |
:contact: [email protected] | |
:version: 0.0.1 | |
""" | |
try: # 连接 MySQLdb, 或者 pymysql | |
import MySQLdb as mysql | |
except ImportError: | |
import pymysql as mysql | |
class DataBase(object):
    """Per-call MySQL helper built around decorators.

    Each decorator (``execrone``, ``execrall``, ``execcud``) wraps a function
    that returns an SQL string. Each call of the wrapper opens a new
    connection, runs that SQL, and always closes the cursor and connection
    again. A ``mysql.Error`` is printed to stdout and the wrapper then
    returns ``None``.
    """

    def __init__(self, host, user, passwd, database, port=3306, charset='utf8'):
        """Store the connection settings (port defaults to 3306, charset to utf8)."""
        #: database server host
        self.host = host
        #: database user name
        self.user = user
        #: password for the user
        self.passwd = passwd
        #: database (schema) name
        self.database = database
        #: database server port
        self.port = port
        #: connection character set
        self.charset = charset

    def __get_db(self):
        """Open and return a new connection using the stored settings."""
        return mysql.connect(
            host=self.host,
            user=self.user,
            passwd=self.passwd,
            db=self.database,
            port=self.port,
            charset=self.charset)

    @staticmethod
    def _format_error(e):
        """Render a driver error as ``Mysql Error <code>: <message>``.

        Falls back to ``Mysql Error: <error>`` when ``e.args`` does not hold
        an integer code followed by a message; formatting such errors used
        to raise IndexError/TypeError and hide the original error.
        """
        args = getattr(e, 'args', ())
        if len(args) >= 2:
            try:
                return "Mysql Error %d: %s" % (args[0], args[1])
            except TypeError:
                pass
        return "Mysql Error: %s" % (e,)

    def _run(self, func, args, kwargs, fetch):
        """Shared body of the three decorators.

        :param func: callable returning the SQL string to execute.
        :param args: positional arguments for *func*.
        :param kwargs: keyword arguments for *func*.
        :param fetch: ``'one'`` or ``'all'`` to read results, or ``None`` for
            a write that is committed (and rolled back on ``mysql.Error``).
        :return: ``(rownum, result)`` for reads, ``rownum`` for writes, or
            ``None`` after a ``mysql.Error`` has been printed.
        """
        # Initialised before the try so cleanup is safe even when connecting
        # or creating the cursor fails (previously an UnboundLocalError in
        # ``finally`` replaced the real error).
        db = cursor = None
        try:
            db = self.__get_db()
            cursor = db.cursor()
            sql = func(*args, **kwargs)
            # cursor.execute returns the affected/selected row count.
            rownum = cursor.execute(sql)
            if fetch is None:
                db.commit()
                return rownum
            result = cursor.fetchone() if fetch == 'one' else cursor.fetchall()
            return (rownum, result)
        except mysql.Error as e:
            if fetch is None and db is not None:
                db.rollback()
            print(self._format_error(e))
        finally:
            if cursor is not None:
                cursor.close()
            if db is not None:
                db.close()

    def _decorate(self, func, fetch):
        """Build the wrapper returned by the public decorators."""
        def wrap(*args, **kwargs):
            return self._run(func, args, kwargs, fetch)
        return wrap

    def execrone(self, func):
        """Decorator: run the SQL returned by *func* and fetch a single row.

        :param func: function returning an SQL string; it is wrapped by the
            decorator, not passed explicitly.
        :return: wrapper returning ``(rownum, row)``, or ``None`` on
            ``mysql.Error``.

        Example::

            @db.execrone
            def getone():
                return "SELECT ..."
        """
        return self._decorate(func, 'one')

    def execrall(self, func):
        """Decorator: run the SQL returned by *func* and fetch all rows.

        :param func: function returning an SQL string; it is wrapped by the
            decorator, not passed explicitly.
        :return: wrapper returning ``(rownum, rows)``, or ``None`` on
            ``mysql.Error``.

        Example::

            @db.execrall
            def getall():
                return "SELECT ..."
        """
        return self._decorate(func, 'all')

    def execcud(self, func):
        """Decorator for create/update/delete: run the SQL and commit.

        On ``mysql.Error`` the transaction is rolled back (if a connection
        was opened), the error is printed, and ``None`` is returned.

        :param func: function returning an SQL string; it is wrapped by the
            decorator, not passed explicitly.
        :return: wrapper returning the affected row count.

        Example::

            @db.execcud
            def insert():
                return "INSERT ..."
        """
        return self._decorate(func, None)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
__author__ = 'ghost' | |
import unittest, logging | |
import pool | |
class TestPool(unittest.TestCase):
    """Integration tests for the pooled MySQL helpers in ``pool``.

    Needs a MySQL server on 127.0.0.1 with a passwordless ``root`` account
    and an existing ``pytest`` database. ``setUp`` recreates the ``user``
    table before every test.
    """

    def setUp(self):
        # Called before every test; the pool/engine creators are expected to
        # return early once initialised.
        pool.create_pool(user='root', passwd='', db='pytest', host='127.0.0.1')
        pool.create_engine()
        pool.update("DROP TABLE IF EXISTS user")
        pool.update("CREATE TABLE user (id INT UNSIGNED NOT NULL PRIMARY KEY AUTO_INCREMENT, nickname VARCHAR(40), email VARCHAR(40), passwd VARCHAR(40))")

    def tearDown(self):
        print('test end.')

    def test_insert(self):
        r = pool.insert('user', nickname='insert', email='[email protected]', passwd='insert')
        self.assertEqual(1, r)
        # Inserting into a table that does not exist must surface as DBError.
        with self.assertRaises(pool.DBError):
            pool.insert('users', nickname='python', email='[email protected]', passwd='123456')
        print('test insert end')

    def test_delete(self):
        r = pool.insert('user', nickname='delete', email='[email protected]', passwd='delete')
        self.assertEqual(1, r)
        dr = pool.delete("DELETE FROM user WHERE nickname=?", 'delete')
        self.assertEqual(dr, 1)
        print('test delete end')

    def test_update(self):
        r = pool.insert('user', nickname='update', email='[email protected]', passwd='update')
        self.assertEqual(1, r)
        r = pool.update("UPDATE user SET email=?, passwd=? WHERE nickname=?", '[email protected]', '111111', 'update')
        self.assertEqual(1, r)
        print('test update end')

    def test_select(self):
        pool.insert('user', nickname='python', email='[email protected]', passwd='python')
        pool.insert('user', nickname='ruby', email='[email protected]', passwd='update')
        pool.insert('user', nickname='python', email='[email protected]', passwd='update')
        users = pool.select("SELECT * FROM user WHERE nickname=?", "python")
        self.assertEqual(2, len(users))
        print(users)
        print('test select end')

    def test_transaction(self):
        def update_profile(name, rollback):
            u = dict(nickname=name, email='{0}@test.com'.format(name))
            pool.insert('user', **u)
            pool.update("UPDATE user SET nickname=? where passwd=?", name.upper(), '111111')
            if rollback:
                # RuntimeError exists on Python 2 and 3 (StandardError is 2-only).
                raise RuntimeError('will cause rollback...')

        # The raising transaction must be rolled back...
        with self.assertRaises(RuntimeError):
            with pool.transaction():
                update_profile('test', True)
        # ...so only the committed one leaves a 'test' row behind.
        with pool.transaction():
            update_profile('test', False)
        u = pool.select("SELECT * FROM user WHERE nickname=?", 'test')
        self.assertEqual(1, len(u))
        print(u)
        print('test transaction end')
if __name__ == '__main__':
    # Show INFO-level log output (e.g. executed SQL) while the tests run.
    logging.basicConfig(level=logging.INFO)
    # Discover and run the TestCase classes defined in this module.
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict

# Sample category rows: (id, parent id, title). A parent id of -1 marks a
# top-level category.
#
#   id  classid  title
#   1   -1       python
#   2   -1       ruby
#   3   -1       php
#   4   -1       lisp
#   5    1       flask
#   6    1       django
#   7    1       webpy
#   8    2       rails
#   9    3       zend
#   10   6       dblog
t = (
    (1, -1, 'python'),
    (2, -1, 'ruby'),
    (3, -1, 'php'),
    (4, -1, 'lisp'),
    (5, 1, 'flask'),
    (6, 1, 'django'),
    (7, 1, 'webpy'),
    (8, 2, 'rails'),
    (9, 3, 'zend'),
    (10, 6, 'dblog'),
)


def build_tree(rows, root=-1):
    """Nest flat ``(id, parent_id, title)`` rows into a tree of dicts.

    Every row becomes ``{'id': id, 'fid': parent_id, 'title': title}``. A
    node with children gets a ``'son'`` key holding its child dicts in input
    order; leaves have no ``'son'`` key. Rows may come in any order, because
    children are linked only after all rows have been read.

    For the sample ``t`` the result is::

        python -> flask, django (-> dblog), webpy
        ruby   -> rails
        php    -> zend
        lisp

    :param rows: iterable of ``(id, parent_id, title)`` tuples.
    :param root: parent id that marks top-level rows (default ``-1``).
    :return: list of top-level node dicts, in input order.
    """
    nodes = []
    children = defaultdict(list)
    for node_id, parent_id, title in rows:
        node = {"id": node_id, "fid": parent_id, "title": title}
        nodes.append(node)
        children[parent_id].append(node)
    for node in nodes:
        # A membership test does not create entries in a defaultdict.
        if node["id"] in children:
            node["son"] = children[node["id"]]
    return list(children.get(root, []))


# Top-level categories, with sub-categories nested under "son".
l = build_tree(t)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment