Python SQLAlchemy的ORM模块封装

封装基类

模型基类、engine 管理、事务处理与 JSON 序列化的代码

通用代码：db_.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# -*- coding: utf-8 -*-
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager
from urllib import parse
import re
import datetime
import types
import decimal


class ModelBase(object):
    """Mixin for declarative models, meant to be used as ``declarative_base(cls=ModelBase)``.

    Gives every model a snake_case table name, column introspection helpers
    (``props``/``prop_names``/``prop``), a readable ``__repr__``, dict-style
    access (``keys``/``__getitem__``/``__setitem__``) and ``to_dict()`` for
    JSON output.
    """

    @declared_attr
    def __tablename__(cls):
        """Derive the table name from the class name: CamelCase -> snake_case.

        e.g. ``Clazz`` -> ``clazz``, ``StudentCourse`` -> ``student_course``.
        """
        name = cls.__name__
        # Lower the first letter before inserting underscores so the result
        # never starts with "_".
        return re.sub(r'([A-Z])', r'_\1', name[0].lower() + name[1:]).lower()

    @classmethod
    def props(cls):
        """Return this model's Column objects, ancestor tables' columns first.

        A model whose parent is the declarative ``Base`` (or a parent that has
        no mapped ``__table__``, such as an abstract class or plain mixin)
        returns its own table's columns. A model that inherits from another
        mapped model returns the parent's ``props()`` followed by its own
        table's columns except ``id``, which is assumed to repeat the parent's
        key. The recursion handles any inheritance depth.
        """
        parent = cls.__base__
        if parent.__name__ == "Base" or not hasattr(parent, "__table__"):
            return list(cls.__table__.columns)
        own_columns = [c for c in cls.__table__.columns if c.name != 'id']
        return parent.props() + own_columns

    @classmethod
    def prop_names(cls):
        """Return the names of the columns from ``props()``, in the same order."""
        return [c.name for c in cls.props()]

    @classmethod
    def prop(cls, prop_name):
        """Return the Column named *prop_name* from this model's own table."""
        return cls.__table__.columns[prop_name]

    def __repr__(self):
        """Render as ``ClassName(col=value, ...)``.

        Values whose exact type is str/date/time/datetime are wrapped in single
        quotes; everything else is shown with its default formatting.
        """
        attrs = []
        for name in self.prop_names():
            attr = getattr(self, name)
            if type(attr) in (str, datetime.date, datetime.time, datetime.datetime):
                attrs.append(f"{name}='{attr}'")
            else:
                attrs.append(f"{name}={attr}")
        return f"{self.__class__.__name__}({', '.join(attrs)})"

    def keys(self):
        """Keys used by ``dict(obj)`` and ``to_dict()``.

        Subclasses may override this to also expose relationship attributes.
        """
        return self.prop_names()

    def __getitem__(self, item):
        """Dict-style read access: ``obj["name"]`` is ``obj.name``."""
        return getattr(self, item)

    def __setitem__(self, item, value):
        """Dict-style write access: ``obj["name"] = v`` sets ``obj.name``."""
        setattr(self, item, value)

    def to_dict(self):
        """Return ``{key: to_jsonable(value)}`` for every key in ``keys()``."""
        return {name: to_jsonable(getattr(self, name)) for name in self.keys()}


def to_dict(db_obj):
    """Turn a ModelBase instance into its JSON-ready dict; return any other value unchanged."""
    return db_obj.to_dict() if isinstance(db_obj, ModelBase) else db_obj

def to_list(db_objs):
    """Apply ``to_dict`` to each element of *db_objs* and collect the results in a list."""
    return list(map(to_dict, db_objs))


# Base = declarative_base(cls=ModelBase)


def to_jsonable(o):
    """Recursively convert *o* into a value that ``json.dumps`` can serialize.

    - list / tuple / set / frozenset -> list with every element converted
    - dict -> dict with every value converted (keys are kept as-is)
    - ModelBase instance -> ``o.to_dict()``
    - datetime -> 'YYYY-MM-DD HH:MM:SS', date -> 'YYYY-MM-DD',
      time -> 'HH:MM:SS' (sub-second parts and tzinfo are dropped)
    - Decimal -> float (may lose precision)
    - anything else is returned unchanged
    """
    # Tuples and sets are converted too, so that datetimes/Decimals/models
    # nested inside them are not left in a form json.dumps rejects.
    if isinstance(o, (list, tuple, set, frozenset)):
        return [to_jsonable(e) for e in o]
    if isinstance(o, dict):
        return {k: to_jsonable(v) for k, v in o.items()}
    if isinstance(o, ModelBase):
        return o.to_dict()
    # datetime must be tested before date: datetime is a subclass of date.
    if isinstance(o, datetime.datetime):
        return o.strftime('%Y-%m-%d %H:%M:%S')
    if isinstance(o, datetime.date):
        return o.strftime('%Y-%m-%d')
    if isinstance(o, datetime.time):
        return o.strftime('%H:%M:%S')
    if isinstance(o, decimal.Decimal):
        return float(o)
    return o


# Registry of engines created by init_engine(), keyed by the ``name`` passed to
# it ("main" by default); read by get_engine()/session_scope(), entries are
# removed by exit_engine().
engines = {}

# Short dialect names accepted by init_engine(), mapped to "dialect+driver".
_DIALECT_ALIASES = {
    "mysql": "mysql+pymysql",
    "postgresql": "postgresql+psycopg2",
    "postgres": "postgresql+psycopg2",
    "pgsql": "postgresql+psycopg2",
    "mssql": "mssql+pymssql",
    "sqlserver": "mssql+pymssql",
}


def _build_url(dialect, username, password, server, dbname):
    """Assemble ``dialect://[user[:password]@]server/dbname`` for create_engine.

    Only the password is percent-encoded (``quote_plus``). A ``None`` dbname
    yields an empty database path instead of the literal text "None".

    Raises:
        ValueError: if *dialect* is empty.
    """
    if not dialect:
        raise ValueError("init_engine() needs either url or dialect")
    if password:
        auth = f"{username}:{parse.quote_plus(password)}@"
    elif username:
        auth = f"{username}@"
    else:
        auth = ""
    database = "" if dbname is None else dbname
    return f"{dialect}://{auth}{server}/{database}"


def init_engine(
        url=None, name="main",
        dialect=None, username="", password="", server="", dbname=None,
        **kwargs):
    r"""Create an engine, attach a ``session_scope()`` helper and register it under *name*.

    Either pass a complete *url*, or pass *dialect* plus connection parts
    (*username*, *password*, *server*, *dbname*) and the URL is built for you.
    The short dialects "mysql", "postgresql"/"postgres"/"pgsql" and
    "mssql"/"sqlserver" are expanded to pymysql / psycopg2 / pymssql drivers.

    Example URLs:
        "sqlite:///:memory:"                                         SQLite in-memory database
        "sqlite:///C:\path\foo.db"                                   SQLite on Windows
        "sqlite:////root/data/foo.db"                                SQLite on Linux
        "mysql+pymysql://username:password@server/dbname"            MySQL
        "postgresql+psycopg2://username:password@server/dbname"      PostgreSQL
        "mssql+pymssql://username:password@server/dbname?charset=utf8"  Microsoft SQL Server

    Extra *kwargs* go to ``create_engine`` (e.g. ``pool_size``, or
    ``pool_recycle`` below MySQL's default 8-hour idle timeout);
    ``pool_pre_ping`` defaults to True.

    Returns:
        The engine. ``engine.session_scope()`` returns a context manager that
        yields a session, commits on success, rolls back on any exception
        (re-raising it) and always closes the session.

    Raises:
        ValueError: if neither *url* nor *dialect* is given.
    """
    if url is None:
        if dialect:
            dialect = _DIALECT_ALIASES.get(dialect.lower(), dialect)
        url = _build_url(dialect, username, password, server, dbname)

    # Check pooled connections before use so stale ones are replaced.
    kwargs.setdefault("pool_pre_ping", True)

    engine = create_engine(url, **kwargs)
    # expire_on_commit=False: instances are not expired on commit, so their
    # loaded attributes stay readable after the session scope ends.
    session_maker = sessionmaker(bind=engine, expire_on_commit=False)

    @contextmanager
    def _session_scope(maker):
        """Yield a new session; commit on success, roll back on error, always close."""
        session = maker()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    # Bind the sessionmaker as the first argument so callers just write
    # engine.session_scope().
    engine.session_scope = types.MethodType(_session_scope, session_maker)
    engines[name] = engine
    return engine


def exit_engine(name="main"):
    """Unregister the engine stored under *name* and release its pooled connections.

    Does nothing if no engine is registered under that name.
    """
    engine = engines.pop(name, None)
    if engine is not None:
        # Dropping the registry entry alone would leave the connection pool
        # open; dispose() closes its checked-in connections now.
        engine.dispose()


def get_engine(name="main"):
    """Look up the engine registered under *name*; ``None`` when nothing is registered."""
    return engines.get(name)


def session_scope(name="main"):
    """Return the transactional session context manager of the engine *name*.

    Usage::

        with session_scope() as session:
            ...

    Raises:
        RuntimeError: if no engine was registered under *name* via
            init_engine(). RuntimeError subclasses Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    try:
        engine = engines[name]
    except KeyError:
        raise RuntimeError("engine未初始化") from None
    return engine.session_scope()

数据库模型(表)定义

db_test.py 基于 db_.py 定义并使用数据库，演示了如何定义关系映射、外键、唯一键和索引，以及如何将查询结果转换成可 JSON 序列化的对象（以便通过 RESTful API 返回）。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# -*- coding: utf-8 -*-
from sqlalchemy import Table, Column, Integer, Float, Numeric, String, DateTime, Date, Time, Boolean
from sqlalchemy import PrimaryKeyConstraint, ForeignKey, Index, UniqueConstraint
from sqlalchemy import text, func, and_, or_, not_, asc, inspect, desc, distinct
from sqlalchemy.orm import relationship, backref
from sqlalchemy.ext.declarative import declarative_base

from utils.db_ import ModelBase, init_engine, exit_engine, get_engine, session_scope, to_jsonable

# Declarative base shared by every model in this module; mixing in ModelBase
# supplies __tablename__, props()/prop_names(), __repr__, keys() and to_dict().
Base = declarative_base(cls=ModelBase)


class Clazz(Base):
    """A school class; its table name "clazz" is derived by ModelBase.__tablename__."""

    # Surrogate auto-increment primary key.
    id = Column(Integer, autoincrement=True, primary_key=True)
    enroll_year = Column(Integer, comment="入学年份")


class Student(Base):
    """A student; optionally belongs to one Clazz (clazz_id is nullable)."""

    __table_args__ = (
        UniqueConstraint('clazz_id', 'name'),  # composite UNIQUE constraint
        Index('ix_clazz_id_stu_no', 'clazz_id', 'stu_no'),  # composite index
    )
    # Auto-increment single-column primary key. For a composite primary key,
    # put primary_key=True on every participating column instead.
    id = Column(Integer(), primary_key=True, autoincrement=True)
    # index=True together with unique=True creates a unique index on stu_no.
    stu_no = Column(String(), index=True, unique=True, comment="学号")
    # nullable=False is what creates the NOT NULL constraint.
    name = Column(String(), nullable=False, comment="姓名")
    # Foreign key to clazz.id; ON DELETE RESTRICT refuses to delete a class
    # that is still referenced by students.
    clazz_id = Column(Integer(), ForeignKey('clazz.id', ondelete='RESTRICT'), nullable=True, comment="班级ID")

    # Many-to-one relationship: student.clazz is the referenced Clazz object.
    clazz = relationship('Clazz')


class Course(Base):
    """A course, linked many-to-many to students through the student_course table."""

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String, comment="课程名")

    # Many-to-many relationship; the backref also adds a ``courses``
    # collection to every Student.
    students = relationship(
        'Student',
        secondary='student_course',
        backref=backref('courses'),
    )

    def keys(self):
        """Column names plus "students", so to_dict()/to_jsonable() include enrolled students."""
        return [*super().keys(), "students"]


# Enrollment association table backing Course.students / Student.courses;
# (student_id, course_id) together form its composite primary key.
student_course = Table(
    'student_course',
    Base.metadata,
    Column('student_id', Integer, ForeignKey('student.id'), primary_key=True),
    Column('course_id', Integer, ForeignKey('course.id'), primary_key=True),
)


def main():
    """Demo: build an in-memory SQLite schema, insert and link rows, update and query them."""
    engine = init_engine(url=r"sqlite:///:memory:")
    Base.metadata.create_all(engine)  # create every table declared on Base
    with session_scope() as session:
        clazz = Clazz(enroll_year=2022)
        session.add(clazz)
        session.flush()         # emit the INSERT now; this assigns clazz.id
        session.refresh(clazz)  # reload all columns of the row from the database
        print(clazz)

        course = Course(name="数学")
        session.add(course)
        session.flush()
        # Query the row that was just inserted.
        course = session.query(Course).filter(Course.name == "数学").first()
        print(course)

        student = Student(stu_no="007", name="张三", clazz_id=clazz.id)
        # The ``courses`` backref declared on Course.students lets the student
        # side maintain the student_course association rows.
        student.courses.append(course)
        # course is already in the session; add the new student explicitly
        # instead of relying on backref cascade (cascade_backrefs, removed in
        # SQLAlchemy 2.0) to pull it in.
        session.add(student)
        session.flush()

        session.refresh(course)
        print(course.students)
        # Course overrides keys() so its students are part of the output.
        print(to_jsonable(course))

        student.name = "李四"
        session.flush()  # emits an UPDATE

        # Raw SQL must be wrapped in text(): Session.execute() rejects plain
        # strings in SQLAlchemy 2.0 (deprecated since 1.4).
        rows = session.execute(text("select id, name from student")).fetchall()
        for row in rows:
            print(row)  # the name has already changed to 李四


if __name__ == '__main__':
    main()

使用sqlacodegen对已有数据库生成模型定义代码

通过pip安装

1
pip install sqlacodegen

命令格式：在 sqlacodegen 后面加上初始化 engine 时使用的 URL

1
sqlacodegen <sqldialect>[+<sqldriver>]://<username>:<password>@<server>/<dbname>

执行命令后，生成的代码会输出到控制台，可以将其重定向到 .py 文件

示例

1
sqlacodegen mysql+pymysql://mzdai:123456@192.168.1.140/mzdai > db_mzdai.py

然后可以按需修改生成的代码，例如从 db_ 中导入所需内容，并把模型基类改为 declarative_base(cls=ModelBase) 创建的 Base