0%

之前用在SQLAlchemy的ORM模型的类名(驼峰风格)和数据库表名(下划线风格)的转换。

Python类名驼峰风格这个不用解释,数据库表名使用下划线风格主要是因为一些数据库系统如果使用了带大写字母的表名,那么在select、insert、update、delete语句中都要用特殊分割符包住表名才能使用,很麻烦。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 驼峰转下划线
def hump_to_underline(name, drop_first_underline=True):
result = re.sub(r'([A-Z])', r'_\1', name).lower()
if drop_first_underline and result[0] == '_':
result = result[1:]
return result

# 下划线转驼峰
def underline_to_hump(name, capitalize_first_letter=True):
ret = ""
i = 0
while i < len(name):
if name[i] == '_' and i+1 < len(name) and name[i+1] != '_':
i += 1
ret += name[i].upper()
else:
ret += name[i]
i += 1
if capitalize_first_letter:
return ret[0].upper()+ret[1:]
else:
return ret

基于SQLAlchemy的Upserter,当时是基于SQLAlchemy写的,不过最后似乎没怎么用到SQLAlchemy的特性,只是取了一下数据库的类型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
# -*- coding: utf-8 -*-

import math
import datetime
import numpy as np
import pandas as pd
import sqlalchemy
import decimal
from sqlalchemy import text
from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager

'''
engine: SQLAlchemy Engine
buffer_size: 缓存条目数,当缓存满时自动flush
update_on_duplicate: 当唯一键重复时的行为,默认是update,设置为False表示不更新,即忽略插入失败。
'''
def create_upserter(engine, buffer_size=5000, update_on_duplicate=True, dummy=False):
if dummy:
return DummyUpserter(engine)
if engine.dialect.name.lower().find("mysql") != -1:
return MySQLUpserter(engine, buffer_size, update_on_duplicate)
elif engine.dialect.name.lower().find("postgresql") != -1:
return PSQLUpserter(engine, buffer_size, update_on_duplicate)
else:
print(f"没有为{engine.dialect.name}实现特殊的Upsert,使用默认版本,请确认可以正常工作,建议特化一个专门版本")
return UpserterBase(engine, buffer_size, update_on_duplicate)


def is_duplicate_key(e):
for T in UpserterBase.__subclasses__():
if T.is_duplicate_key(e):
return True
return UpserterBase.is_duplicate_key(e)


'''
class Upserter:

# 传入的engine类型应该和使用的Upserter支持的数据库类型相匹配
# buffer_size表示插入或更新数据缓存到多少才flush(即向数据库插入或更新),None表示在析构时flush,0表示不缓存
# update_on_duplicate当唯一键重复时的行为,默认是update,设置为False表示不更新,即忽略插入失败。
def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
pass

# tablename为数据库表名
# pk为主键的元组,可以不是真正的表主键,但是可以用来判重决定insert还是update,例如('exchange_id', 'trade_id')
# data为单条数据,dict的形式,例如{'exchange_id': 'DCE', 'trade_id': ' 1', 'price': 1.2, 'volume': 1}
def upsert(self, tablename, pk, data):
pass

# 立即把缓冲器的数据推到数据库,会在buffer_size满了或者析构时自动调用,也可以手动调用
def flush(self):
pass
'''
class UpserterBase:
def __init__(self, engine, buffer_size=None, update_on_duplicate=True, field_quote_mark=''):
self.engine = engine
self.session_maker = sessionmaker(expire_on_commit=False)
self.session_maker.configure(bind=engine)
self.tablename2pk = {}
self.tablename2datas = {}
self.buffer_size = buffer_size
self.update_on_duplicate = update_on_duplicate
self.field_quote_mark = field_quote_mark

def __del__(self):
self.flush()

@contextmanager
def session_scope(self):
session = self.session_maker()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()

def flush(self):
for (tablename, pk) in self.tablename2pk.items():
datas = self.tablename2datas[tablename]
if len(datas) > 0:
with self.session_scope() as session:
self._flush(session, tablename, pk, datas)
self.tablename2datas[tablename] = []

def _gen_batch_insert_sql(self, tablename, datas):
quote_mark = self.field_quote_mark
columns = datas[0].keys()
# ex: INSERT INTO tablename(`field1`, `field2`, `field3`) VALUES
sql = f"""INSERT INTO {tablename}({quote_mark}{f"{quote_mark}, {quote_mark}".join(columns)}{quote_mark}) VALUES\n"""
for i, data in enumerate(datas):
if i != len(datas) - 1:
sql += f""" ({self._format_values(data.values())}),\n"""
else:
sql += f""" ({self._format_values(data.values())});\n"""
return sql

def _flush(self, session, tablename, pk, datas):
sql = self._gen_batch_insert_sql(tablename, datas)
try:
session.execute(text(sql))
except sqlalchemy.exc.IntegrityError as e:
if self.is_duplicate_key(e):
# 插入遇到重复KEY
if len(datas) <= 500:
for data in datas:
self.upsert_one(session, tablename, pk, data)
else:
l = len(datas)
p = int(l // 2)
self._flush(session, tablename, pk, datas[:p])
self._flush(session, tablename, pk, datas[p:])
else:
raise e

def upsert_one(self, session, tablename, pk, data):
quote_mark = self.field_quote_mark
r = None
if self.update_on_duplicate:
update_str = self._format_update_values(pk, data)
if self.update_on_duplicate and update_str.strip():
r = session.execute(text(f"UPDATE {tablename} SET {update_str} WHERE {self._format_update_conditions(pk, data)}"))
if not r or r.rowcount == 0:
try:
r = session.execute(text(
f"INSERT INTO {tablename}({quote_mark}{f'{quote_mark}, {quote_mark}'.join(data.keys())}{quote_mark}) VALUES ({self._format_values(data.values())})"
))
except sqlalchemy.exc.IntegrityError as e:
if self.is_duplicate_key(e):
pass
else:
raise e

@staticmethod
def is_duplicate_key(e):
if type(e) != sqlalchemy.exc.IntegrityError:
return False
return (str(e.orig).lower().find("duplicate") != -1)

def _isinf(self, x):
return x>=9223372036854775807 or x<=-9223372036854775808

def _format_value(self, v):
if v is None:
return "null"
elif type(v) == float or type(v) == decimal.Decimal:
if math.isnan(v) or math.isinf(v) or self._isinf(v):
return "null"
else:
return f"{v}"
elif type(v) == int:
if self._isinf(v):
return "null"
else:
return f"{v}"
elif type(v) == datetime.datetime:
return "'"+v.strftime("%Y-%m-%d %H:%M:%S")+"'"
elif type(v) == datetime.date:
return "'"+v.strftime("%Y-%m-%d")+"'"
elif type(v) == pd.Timestamp:
return "'"+v.strftime("%Y-%m-%d %H:%M:%S")+"'"
elif type(v) == str:
return repr(v).replace(r'%',r"%%") # 可以解决字符串包含'"%
else:
return repr(f'{v}').replace(r'%',r"%%")

def _format_values(self, data):
s = ''
for i, e in enumerate(data):
s += self._format_value(e)
s += ', '
return s[:-2]

def _format_update_values(self, pk, data):
s = ''
for i, (k, v) in enumerate(data.items()):
if k not in pk:
s += f"{self.field_quote_mark}{k}{self.field_quote_mark}={self._format_value(v)}, "
return s[:-2]

def _format_update_conditions(self, pk, data):
s = ''
for i, (k, v) in enumerate(data.items()):
if k in pk:
s += f"{self.field_quote_mark}{k}{self.field_quote_mark}={self._format_value(v)} and "
return s[:-4]

def upsert(self, tablename, pk, data):
if self.buffer_size is not None and self.buffer_size == 0:
with self.session_scope() as session:
self.upsert_one(session, tablename, pk, data)
else:
if pk:
self.tablename2pk[tablename] = pk
if tablename not in self.tablename2datas:
self.tablename2datas[tablename] = []
self.tablename2datas[tablename].append(data)

if self.buffer_size is not None and len(self.tablename2datas[tablename]) >= self.buffer_size:
with self.session_scope() as session:
self._flush(session, tablename, self.tablename2pk[tablename], self.tablename2datas[tablename])
self.tablename2datas[tablename] = []

def upsert_dataframe(self, tablename, pk, df):
for index, row in df.iterrows():
self.upsert(tablename, pk, row.to_dict())


class MySQLUpserter(UpserterBase):
def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
super().__init__(engine, buffer_size, update_on_duplicate, field_quote_mark='`')

def __del__(self):
super().__del__()

@staticmethod
def is_duplicate_key(e):
if type(e) != sqlalchemy.exc.IntegrityError:
return False
if len(e.orig.args) > 1 and str(e.orig.args[1]).startswith("Duplicate entry"):
return True
return False

def upsert_one(self, session, tablename, pk, data):
if self.update_on_duplicate:
update_str = self._format_update_values(pk, data)
if self.update_on_duplicate and update_str.strip():
sql = f"""INSERT INTO {tablename}(`{"`, `".join(data.keys())}`) VALUES
({self._format_values(data.values())})
ON DUPLICATE KEY UPDATE {update_str}\n"""
else:
sql = f"""INSERT INTO {tablename}(`{"`, `".join(data.keys())}`) VALUES
({self._format_values(data.values())})
ON DUPLICATE KEY UPDATE `{pk[0]}`=VALUES(`{pk[0]}`)\n"""
session.execute(text(sql))


class PSQLUpserter(UpserterBase):
def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
super().__init__(engine, buffer_size, update_on_duplicate, field_quote_mark='"')

def __del__(self):
super().__del__()

@staticmethod
def is_duplicate_key(e):
if type(e) != sqlalchemy.exc.IntegrityError:
return False
if str(e.orig).startswith("duplicate key"):
return True
return False

def upsert_one(self, session, tablename, pk, data):
if self.update_on_duplicate:
update_str = self._format_update_values(pk, data)
if self.update_on_duplicate and update_str.strip():
sql = f"""INSERT INTO {tablename}("{'", "'.join(data.keys())}") VALUES
({self._format_values(data.values())})
on conflict ("{'", "'.join(pk)}")
do update set {update_str}\n"""
else:
sql = f"""INSERT INTO {tablename}("{'", "'.join(data.keys())}") VALUES
({self._format_values(data.values())})
on conflict ("{'", "'.join(pk)}")
do nothing\n"""
session.execute(text(sql))


class DummyUpserter(UpserterBase):
def __init__(self, engine):
super().__init__(engine)

def upsert(self, tablename, pk, data):
pass

def upsert_one(self, session, tablename, pk, data):
pass

def upsert_dataframe(self, tablename, pk, df):
pass

std::priority_queue

std::priority_queue 是C++标准库提供的优先队列(最大堆)实现,位于头文件

默认情况下要求元素有“小于”运算,取堆顶,返回最大值。

可以通过模板参数调整排序方式让其返回最小值,或者为自定义类型定义排序方式。

1
2
template <class T, class Container = vector<T>, class Compare = less<typename Container::value_type> >
class priority_queue;

模板参数:

T是数据类型

Container是维护最大(小)堆使用的容器类型

Compare是一个function object的类型,定义了排序方式

什么是function object?

function object是一种对象,这个对象的类重载了括号运算符,也就是 operator() ,所以这个对象可以使用 obj(…),看上去就像在调用一个function一样。

使用比较器类定义优先队列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#include <queue>
#include <vector>
using namespace std;

struct Student { // 学生
int id; // 学号
int height; // 身高
};

int main()
{
// 定义比较器
struct Cmp {
bool operator() (const Node& a, const Node& b) {
return a.height < b.height;
}
};
// 定义优先队列
priority_queue<Student, vector<Student>, Cmp> PQ;
}

使用lamda表达式排序

1
2
3
4
5
6
7
vector<Student> students = {...};
std::sort(students.begin(), students.end(), [](const Student& a, const Student& b){
return a.height < b.height;
});

// 等价的
std::sort(students.begin(), students.end(), Cmp());

使用lambda表达式的好处是让“比较方法的描述”接近sort的调用,无论从编写还是阅读都是更好的。

使用lambda表达式的坏处是,不方便复用比较方法。

使用lamda表达式定义优先队列

实际上priority_queue有一个构造函数,可以传递一个比较对象,如果不传递就会用模板参数定义默认的比较对象。

1
explicit priority_queue (const Compare& comp = Compare(), Container&& ctnr = Container());

我们可以通过构造函数参数传递一个lambda表达式定义比较方式,我们期望的定义优先队列的方式是

1
2
3
priority_queue<Node> PQ([](const Node& a, const Node& b) {
return a.height < b.height;
});

但是很遗憾,我们并不能这样定义,这会导致编译错误,原因是我们在模板参数仅传递了数据类型T,而没有传递Compare,因此Compare使用了默认的less,而我们传递的lambda表达式显然不是less类型,因此不符合构造函数的参数要求。

使用decltype获取lambda表达式类型

因此我们不得不传递Compare为我们定义的lambda表达式的类型,这里可以使用 decltype 关键字,这个关键字直到C++11才被引入。

1
2
3
4
5
// 通过lambda表达式定义序
auto cmp = [](const Node& a, const Node& b) {
return a.height < b.height;
};
priority_queue<Node, vector<Node>, decltype(cmp)> PQ(cmp);

看上去和通过定义比较器定义优先队列似乎差不多,实际上lambda表达式的魅力在于可以访问当前上下文中的其他变量。

例如:假设我们有一个 vector<Student>存储着学生信息,我们想定义一个存储学号的优先队列priority_queue,依然按照身高对其中学号排序

1
2
3
4
5
6
7
8
9
10
11
vector<Student> students = {...};
unordered_map<int, Student> id2stu;
for(auto& stu: students) {
id2stu[stu.id] = stu;
}

// 我们可以很方便地把id2stu绑定到lambda表达式中用来排序
auto cmp = [&](int a, int b) {
return id2stu[a].height < id2stu[b].height;
};
priority_queue<int, vector<int>, decltype(cmp)> PQ(cmp);

如果用定义比较器类的方式则需要通过构造函数传递id2stu的引用,然后绑定给成员变量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
vector<Student> students = {...};
unordered_map<int, Student> id2stu;
for(auto& stu: students) {
id2stu[stu.id] = stu;
}

struct Cmp {
const unordered_map<int, Student>& id2stu;
Cmp(const unordered_map<int, Student>& id2stu) : id2stu(id2stu) {}
bool operator() (int a, int b) const {
return id2stu[a].height < id2stu[b].height;
}
};
priority_queue<int, vector<int>, Cmp> PQ;

本质是一样的,但是写法有些累赘。

问题:在字符串s中查找字符串p首次出现的位置。

正常情况下对s和p进行匹配的最坏时间复杂度是O(len(s)*len(p)),我们用i,j分别从s,p的头部进行匹配,每次匹配失败我们回退j到0,i+=1,进行下一轮匹配。

1
2
3
4
5
6
7
8
9
10
11
int find(const string& s, const string& p) {
for(int i=0;i<=s.size()-p.size();i++) {
int j=0; // 每次j都从0开始
for(;j<p.size();j++) {
if(s[i+j] != p[j]) break;
}
if(j == p.size())return i;
// 如果当前以i开始的字串不匹配,则从i+1继续尝试
}
return -1;
}

KMP的思想就是预处理p得到next数组,保证i不回退,next就是预先算出i不回退的情况下j应该回退到哪,这样算法复杂度就降到了O(len(s)+len(p)) 也就是 O(len(s))。

有时候模式串是固定的,需要重复在不同的串中查找模式串,所以next数组也可以预先算好一直复用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// 计算next数组,即j不匹配时的回退位置
vector<int> compute_next(const string& p){
vector<int> next(p.size(), 0);
for(int i=2; i<p.size(); i++){
int j = next[i-1];
while(p[i-1] != p[j] && j>0)
j = next[j];
if(p[i-1]==p[j])
next[i] = j+1;
else
next[i] = 0;
}
return move(next);
}

//串匹配的KMP算法
//返回s中第一个与p匹配的子串的起始下标,若找不到则返回-1
int find(const string& s, const string& p, const vector<int>* pnext=NULL) {
vector<int> _next;
if(pnext == NULL) {
_next = compute_next(p);
pnext = &_next;
}
const vector<int>& next(*pnext);
int i = 0, j = 0;
while(s[i] && p[j]){
if(s[i] == p[j]){
++i;
++j;
}
else {
if(j==0) ++i; // 第一个字符就不匹配,直接后移i
else if(p[j]==0) break; // p[j] 表示找到了匹配,跳出循环
else j = next[j]; // 这里就是利用预处理好的next来回退j,而i不用变
}
}
if(p[j]==0) return i-j;
else return -1;
}

有了next数组后的匹配就想前面说的,只要根据next进行回退就可以了,没有过多技巧。

那么主要讲一下next数组的生成思路,根据next的定义,其实next[i]表示的是p[i]前面最长能有多少字符和p的开头匹配

例如:我们生成 "aabaaab"的next数组,考察next[4]和next[5],p[4]的前面最长有“a”和p的开头匹配,所以next[4]=1,

p[5]的前面最长有“aa”和p的开头匹配,所以next[5]=2。

1
2
3
4
5
0123456
aabaaab
- -↑
aabaaab
-- --↑

总有next[0] = next[1] = 0,我们只要从下标2开始计算next。

对于next[i],我们可以采用数学归纳法的思维,我们找到i-1回退的位置,取j = next[i-1],如果p[i-1]==p[j],那么显然next[i] = next[i-1] + 1,如果p[i-1]!=p[j]呢,next[i]=0吗?并不是

我们还是以 aabaaab 为例,考察next[6],首先我们算出了next[0…5]=[0, 0, 1, 0, 1, 2],而p[5] !=p[2] (‘a’ != ‘b’)

1
2
3
4
5
0123456
aabaaab
-- --↑
aabaaab
-- --↑

这里就有个技巧了,对于p[i-1]和p[j]不匹配时,我们想知道让j回退多少,我们可以利用next数组的含义,尝试让j回退到next[j],再看看p[i-1]和p[j]是否相等,我们在生成next的时候就用到了规模更小的next,还是数学归纳法的思维,j=next[5]=2, 因为p[5]!=p[2] ,令j=next[j]=next[2]=1,而p[5]==p[1],所以next[6] = next[1] + 1 = 2,大致的理解思路就是这样,严格的证明见:前缀函数与 KMP 算法 - OI Wiki (oi-wiki.org)

实现比较简单直接看代码,说几点:

  1. (i%MAX_SIZE+MAX_SIZE)%MAX_SIZE 是为了支持负数下标,如果不需要负下标可以直接i%MAX_SIZE
  2. 循环队列中,因为begin、end一直增加,所以不需要full标志仍然可以把空间用足,不存在队满和队空条件相同。
  3. 如果限定MAX_SIZE是2的幂,可以用 i&(MAX_SIZE-1) 来代替取模,而且同样支持负数下标,真是又快又好 _
  4. CircularArray的主要用途是在DP循环的时候如果状态仅依赖前N项,那么可以简单地把空间节省到N。
  5. Queue的用途是为了替代std::queue,但实际上std::queue已经相当快了。

总的来说这两个数据结构在比赛中基本用不到。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
template<typename T, int MAX_SIZE>
class CircularArray {
T data[MAX_SIZE];
public:
CircularArray(const T& t=T()) {fill(data, data+MAX_SIZE, t);}
T& operator [] (int i) {
return data[(i%MAX_SIZE+MAX_SIZE)%MAX_SIZE];
}
};

template<typename T, int MAX_SIZE>
class Queue {
CircularArray<T, MAX_SIZE> data;
int begin = 0;
int end = 0;
public:
void push(const T& t) {
assert(!full());
data[end++] = t;
}

T pop() {
assert(!empty());
return data[begin++];
}

T& peek() {
assert(!empty());
return data[begin];
}

const T& peek() const {
return const_cast<Queue*>(this)->peek();
}

bool empty() const {
return end == begin;
}

size_t size() const {
return end-begin;
}

bool full() const {
return size() == MAX_SIZE;
}
};

什么是树状数组?

树状数组又称二叉索引树(Binary Indexed Tree),又以其发明者命名为Fenwick树

是一种支持以O(logn)时间计算区间和同时以O(logn)时间修改元素值的数据结构。

它的功能可以被线段树替代,而且线段树提供了更多功能,树状数组的优势是实现简单。

树状数组的的实现

树状数组提供两种操作:

1)对单点赋值。

2)查询区间和。

(事实上也可以扩展出区间修改和单点查询,我们暂不考虑)

假设我们有一个数组

1
arr = {3, 8, 3, 3, 5, 6, 8, 7};

通常情况比如我们求区间[2,7)的和需要遍历区间上的元素,O(区间长度),如何减少运算次数呢,最简单的思路是我们预处理前缀和sum[],令

1
sum[i]= arr[0]+arr[1]+...+arr[i-1]

当我们需要算[2,7)的区间和,其实就算[0,7)前缀和减去[0,2)的前缀和,即sum[7]-sum[2],这样就可以用O(1)的时间算出任意区间的和,如果数组元素不会发生动态变化这样是可以的,但如果需要交替修改数组元素和查询区间和,这样处理会导致前缀和维护的成本很高,原本的arr[i]=x,我们不得不修改所有k>i的sum[k]来维护前缀和,这样修改数组元素的时间就从原本的O(1)变成了O(n)。

有没有方法可以在修改单点值的便利性和查询区间和的便利性上做个折中呢,肯定是有的,我们可以预处理一些子段和而不是所有前缀和,让修改单点值和查询区间和都只需要访问O(logn)的元素,树状数组线段树都是类似这个思想。

这是在讲线段树时的图,如果我们只考虑前缀和,即从0开始的区间和,这中间很多子段和的存储是不必要的,我们来看[0, 7)的和,如果我们有了下图维护的子段和信息,[0, 7)的和最快可以通过 [0, 4)的和 + [4, 6)的和 + [6,7)的和 = 17 + 11 + 8 计算得到。

segtree

我们看图时是很容易想到的,那么这个[0, 4)、[4, 6)、[6,7)划分是怎么得出来的呢,我们可以把前缀和[0, 7)的右端点7转换成2进制,即111,如果仅保留最左侧的1其他位置0,得到100,就是十进制的4,保留最左侧的两个1,得到110,就是十进制的6,保留最最侧的三个1,得到111,就是十进制的7。4,6,7 正好和我们的划分是一样的。7并不是特殊的,可以选择其他数字也都有这个规律。那么给定一个i,我们就可以通过不断把最右侧1变成0,记录这个过程中所有的数,就可以得到需要用到的子段和的划分。具体操作方式我们可以通过位运算。

1
2
3
4
// 得到n二进制最右侧的1表示的数
int lowbit(int n) {
return n & (-n);
}

还有一些等价的写法

1
n & ~(n-1) 或 n ^ (n & (n-1))

划分子段和的方式有了,刚才提到如果我们只计算前缀和,线段树维护的这些子段很多是多余的,再结合我们的划分方式,其实只要把子段和存储在子段的右端点即可。需要用到的子段右端点是不会重复的,因为任何一个右端点i 对应唯一子段就是 [k , i),其中k是i的二进制去掉最右的1。也就是下面这样,虚线上方圆圈内的值就是存储在右端点的子段和。

bitree

那么我们就可以仅用一个数组来存储上面的这棵树了。

1
2
bitree[] = {0, 3, 11, 3, 17, 5, 11, 8, 43};
0 1 2 3 4 5 6 7 8

实际上这里 len(bitree) = len(arr) + 1,不过bitree[0] 代表空的子段和,总是0,想省下多出来的1个单位空间也是可以的,不过没有必要。

所以求任意的区间和,我们先转成前缀和相减,再划分为子段和去bitree[]里取值就可以了

例如:

1
2
3
4
5
  [2, 7)的和
= [0, 7)的和 - [0, 2)的和
= (bitree[4] + bitree[6] + bitree[7]) - bitree[2]
= (17 + 11 + 8) - 11
= 25

用代码实现就是

1
2
3
4
5
6
7
8
9
// 查询前缀和
int query(int i) {
int result = 0;
while(i) {
result += bitree[i];
i -= lowbit(i); // 等价的 i &= i-1
}
return result;
}

剩下的问题是如何修改元素值呢?我们看上面的树状图,当要修改元素4的时候,会影响到[4,5)、[4,6),[0,8) 三个区间,即一直要沿着父节点修改到根,那么就是我们用bitree的下标表示就是5-6-8,如何得到这一串数呢,是否也和二进制存在某种关系呢,直接说答案,从要修改的元素编号+1开始,每次令i+=lowbit(i) 得到下个序号(正好是把查询里的减法变成加法),直到下标超出bitree的长度,事实上最后一次的下标总是bitree的根,也就是最后一个的元素。

1
2
3
4
5
6
7
8
// 修改元素i的值
void add(int i, int x) {
++i;
while(i<=8) {
bitree[i] += x;
i += lowbit(i);
}
}

这里实现的是add方法,如果我们想设置元素i为新值,可以

1
2
add(i, -query(i,i+1)); // 减去旧值
add(i, x); // 加上新值

完整实现

和线段树比一下是不是简单很多

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
template <typename T, int N=(1<<17)>
class BITree{
T C[N+1];
int lowbit(int i){return i&(-i);}
public:
BITree(){clear();}

void clear(){memset(C,0,sizeof(C));}

// 增量修改元素i的值
void add(int i, T d){
for(i++; i<=N; i+=lowbit(i)) C[i]+=d;
}

// [0, i)元素的和
T sum(int i){
T r=0;
for(; i; i-=lowbit(i)) r+=C[i];
return r;
}

// [b, e)元素的和
T sum(int b, int e) {
return sum(e) - sum(b);
}

// 修改元素i的值
void set(int i, T d) {
add(i, -sum(i, i+1));
add(i, d);
}
};

C++ STL 有提供priority_queue,但是它没有接口修改已经加入队列的元素的优先级。有时候我们希望修改队列中元素的优先级,可以通过priority_queue + 最后更新的时间戳,例如用一个map或unordered_map存储每个key最后更新的时间戳,然后将重复的key和priority的pair加入到priority_queue中,在取得队头元素时根据时间戳判断是否丢弃,这可以解决大多数问题,但如果遇到插入和更新特别多的情况,为了避免priority_queue迅速膨胀,我们不得不自己实现最大堆。

实际上通过之前讲过的Treap稍加改造就可以得到一个可以动态更新key优先级的Heap

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
#include <iostream>
#include <vector>
#include <queue>
using namespace std;

namespace detail {

template<typename KT, typename PT>
struct TreeNode {
KT key;
PT priority;
int size;
TreeNode* left;
TreeNode* right;

TreeNode(const KT& key, const PT& priority):key(key),priority(priority),size(1),left(NULL),right(NULL){}
};

template<typename KT, typename PT>
class TreeNodePool {
int batchsize = 1024;
vector<char*> memchunks;
queue<char*> blocks;
public:
~TreeNodePool() {
for(int i=0;i<memchunks.size();i++) {
free(memchunks[i]);
}
}

TreeNode<KT, PT>* create(const KT& key, const PT& priority) {
if(blocks.empty()) {
char* p = (char*)malloc(sizeof(TreeNode<KT, PT>)*batchsize);
memchunks.emplace_back(p);
for(int i=0;i<batchsize;i++) {
blocks.push(p+sizeof(TreeNode<KT, PT>)*i);
}
batchsize *= 2;
}
char* p = blocks.front(); blocks.pop();
return new (p)TreeNode<KT, PT>(key, priority);
}

void release(TreeNode<KT, PT>* p) {
blocks.push((char*)p);
}
};
} // namespace detail

template<typename KT, typename PT, typename CMPF=less<PT> >
class Heap {
public:
// 插入或更新key对应的priority
void upsert(const KT& key, const PT& priority) {
erase(key);
insert(key, priority);
}
// 查看堆顶key
const KT& peek() const {
return root->key;
}
KT& peek() {
return root->key;
}
// 取出堆顶key
KT pop() {
KT result = peek();
erase(result);
return result;
}

bool empty() const {
return size() == 0;
}

int size() const {
return _size(root);
}

private:
int insert(const KT& key, const PT& priority) {
return _insert(root, create_tree_node(key, priority));
}

void erase(const KT& key) {
_erase(root, key);
}

private:
detail::TreeNode<KT, PT>* create_tree_node(const KT& key, const PT& priority) {
return pool.create(key, priority);
}

void release_tree_node(detail::TreeNode<KT,PT>* root) {
pool.release(root);
}

static int _size(detail::TreeNode<KT,PT>* root) {
return (root != NULL ? root->size : 0);
}

// 返回root子树的size减去左右子树的size
static int _count(detail::TreeNode<KT,PT>* root) {
if(root == NULL) return 0;
return _size(root) - _size(root->left) - _size(root->right);
}

static void lrotate(detail::TreeNode<KT,PT>* &root) {
auto right = root->right;
auto rightleft = right->left;
root->right = rightleft;
right->left = root;
int root_size = root->size;
root->size += _size(rightleft) - _size(right);
right->size = root_size;
root = right;
}

static void rrotate(detail::TreeNode<KT,PT>* &root) {
auto left = root->left;
auto leftright = left->right;
root->left = leftright;
left->right = root;
int root_size = root->size;
root->size += _size(leftright) - _size(left);
left->size = root_size;
root = left;
}

int _insert(detail::TreeNode<KT,PT>* &root, detail::TreeNode<KT,PT>* p) {
if(root == NULL) {
root = p;
return 1;
} else {
if(p->key < root->key) {
int r = _insert(root->left, p);
root->size += r;
if(CMPF()(root->priority, root->left->priority)){
rrotate(root);
}
return r;
} else if(root->key < p->key) {
int r = _insert(root->right, p);
root->size += r;
if(CMPF()(root->priority, root->right->priority)){
lrotate(root);
}
return r;
} else {
root->size += 1;
release_tree_node(p);
return 1;
}
}
}

static int _count(detail::TreeNode<KT,PT>* root, const KT& key) {
if(root == NULL) return 0;
if(key < root->key) return _count(root->left, key);
else if(root->key < key) return _count(root->right, key);
else return _count(root);
}

int _erase(detail::TreeNode<KT,PT>* &root, const KT& key) {
if(root == NULL) return 0;
if(key == root->key) {
auto left = root->left;
auto right = root->right;
if(left == NULL) {
int root_count = _count(root);
release_tree_node(root);
root = right;
return root_count;
} else if(right == NULL) {
int root_count = _count(root);
release_tree_node(root);
root = left;
return root_count;
} else {
int root_count = _count(root);
int r = 0;
if(CMPF()(left->priority, right->priority)) {
lrotate(root); // 把优先级大的孩子转成根
r = _erase(root->left, key);
} else {
rrotate(root);
r = _erase(root->right, key);
}
root->size -= r;
return r;
}
} else if(key < root->key) {
int r = _erase(root->left, key);
root->size -= r;
return r;
} else {
int r = _erase(root->right, key);
root->size -= r;
return r;
}
}

private:
detail::TreeNodePool<KT, PT> pool;
detail::TreeNode<KT, PT>* root = NULL;
};

什么是线段树?

线段树是一棵二叉树,每个节点维护一个区间和区间上的值,是一种用来维护 区间信息 的数据结构。

我们先从线段树能提供的操作上来理解。想象一个数组,每个下标上可以存一个值。所谓区间就是一段连续的数组下标。

我们可以进行的操作有:

1)区间赋值:设置一个区间内所有下标对应的值

2)单点赋值:设置单点的值

3)区间查询:查询一个区间内的最大值 / 最小值 / 值的和

4)单点查询:查询一个点的值

如果用数组来实现上述操作,所有区间操作的时间复杂度都是O(n),n是区间长度,单点操作的复杂度都是O(1),而线段树可以把上述全部操作的时间复杂度同时变成O(logn)

线段树的实现

根据维护的信息和支持的操作不同线段树分很多种,常用的有:着色线段树、覆盖线段树、求和线段树、最值线段树。

计数线段树:可以对区间反复覆盖,动态查询单点的覆盖次数。

着色线段树:可以对区间反复着色,动态查询单点的颜色。

求和线段树:可以对单点或区间赋值,动态查询区间数据的和。

最值线段树:可以对单点或区间赋值,动态查询区间数据的最值。

为了方便起见我用求和线段树进行讲解,并且只实现两种操作:

1)对单点赋值。

2)查询区间和。

假设我们有一个数组

1
arr = {3, 8, 3, 3, 5, 6, 8, 7};

通常情况比如我们求区间[2,7)的和需要遍历区间上的元素,O(区间长度),如何减少运算次数呢,最简单的思路是我们预处理前缀和sum[],令

1
sum[i]= arr[0]+arr[1]+...+arr[i-1]

当我们需要算[2,7)的区间和,其实就算[0,7)前缀和减去[0,2)的前缀和,即sum[7]-sum[2],这样就可以用O(1)的时间算出任意区间的和,如果数组元素不会发生动态变化这样是可以的,但如果需要交替修改数组元素和查询区间和,这样处理会导致前缀和维护的成本很高,原本的arr[i]=x,我们不得不修改所有k>i的sum[k]来维护前缀和,这样修改数组元素的时间就从原本的O(1)变成了O(n)。

有没有方法可以在修改单点值的便利性和查询区间和的便利性上做个折中呢,肯定是有的,我们可以预处理一些子段和而不是所有前缀和,让修改单点值和查询区间和都只需要访问O(logn)的元素,树状数组线段树都是类似这个思想,其实树状数组的设计更符合这个思想,但它只能用于求和,树状数组能做的线段树都能做,我们之后有空再讲树状数组,对于线段树我们先在这个数组上生成一个子段和数组,让

1
b = { arr[0]+arr[1],  arr[2]+arr[3], arr[4]+arr[5], arr[6]+arr[7]} = {11, 6, 11, 15};

这样求区间[2, 7)的和的时候,可以取b[1] + b[2] + arr[6],只需要算三个数的和就行了。

更近一步,我们可以在数组b上再合并建立 c = { b[0]+ b[1], b[2]+b[3], …}, d, e, …,直至把整个数组合并成一个元素,如下图

segtree

这样之后,对于这个数组上的任意区间的和都可以只取O(logn)的元素的和来得到。

为了方便,我们通过二叉树来维护这些数据,我们把最上面的整个数组的和43作为二叉树的根,同时记住这个节点对应的左右端点即0和8,我们把这个节点的数据表示为(0, 8, 43),然后把它的左孩子就是如图的和为17的节点,数据为(0, 4, 17),同理右孩子(4, 8, 26),以此类推,这样就可以建立一颗完全二叉树(我这里有意将数组的大小选为2的幂,如果不是2的幂我们也可以扩充到2的幂来建树)。

对于修改单点值的操作,我们要从树根一直修改到叶子,正好修改了树的深度个节点,复杂度也是O(logn)。

因为是完全二叉树,我们可以通过数组来实现,用data[0]作为根,data[k]的孩子是data[2*k+1]data[2*k+2]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include <iostream>
using namespace std;

#define MAXN 8
int bg[MAXN*2]; // 区间左端点,闭
int nd[MAXN*2]; // 区间右端点,开
int sm[MAXN*2]; // 节点值,即区间和

// 递归构建树
void build(int k, int b, int e) {
bg[k] = b;
nd[k] = e;
sm[k] = 0;
if(e-b>1) {
int mid = b + (e-b) / 2;
build(k*2+1, b, mid);
build(k*2+2, mid, e);
}
}

// 设置数组下标i的元素值,类似arr[i] = val
void change(int k, int i, int val) {
if(nd[k]-bg[k]==1) { // 叶子节点
sm[k] = val; // 直接更新
} else {
int mid = bg[k] + (nd[k]-bg[k]) / 2;
// 递归向下更新
if(i<mid) change(k*2+1, i, val);
else change(k*2+2, i, val);
// 回溯时通过孩子的值更新当前节点值
sm[k] = sm[k*2+1] + sm[k*2+2];
}
}

// 查询[b, e)的元素和
int query(int k, int b, int e) {
if(b<=bg[k] && nd[k]<=e) {
cout << "DEBUG:[" << bg[k] << "," << nd[k] << ") sm[k]=" << sm[k] << endl;
return sm[k];
}
int mid = bg[k] + (nd[k]-bg[k]) / 2;
return (b<mid ?query(k*2+1, b, e):0) + (mid<e ? query(k*2+2, b, e):0);
}

// 把n向上对齐到2的幂
int align(int n) {
if(n & (n-1)) {
while(n & (n-1)) {
n &= (n-1);
}
n = (n << 1);
}
return n;
}

int main()
{
int arr[] = {3,8,3,3,5,6,8,7};
build(0, 0, align(8));
for(int i=0;i<8;i++) {
change(0, i, arr[i]);
}
cout << query(0, 2, 7) << endl;
}

输出

1
2
3
4
DEBUG:[2,4) sm[k]=6
DEBUG:[4,6) sm[k]=11
DEBUG:[6,7) sm[k]=8
25

可以看到,和预期的一样,只通过3个节点的值求和算出了区间[2,7)的和为25。

最值线段树

对于最大值、最小值的线段树,我们只需要把更新节点值的代码改掉。

区间更新

对于区间更新,例如把[a,b)的值全部设置为x,或者让[a,b)的值全部增加x,我们不能对区间内的点逐一更新,否则复杂度会变成O(nlogn)了,区间更新时,我们要引入懒惰标志,我们还是从根开始拆分区间,如果当前节点的区间被完全覆盖,我们就更新当前节点值并不向下继续更新,并把待更新的值记在懒惰标志内,直到查询需要用到节点值时再逐级更新并把懒惰标志一层层推下去。

离散化

我们可以看到上述线段树实现方式使用的数据空间是数组最大下标的2倍,有时候区间的范围很大,但区间的更新和查询操作的数量并不多,我们可以先统计所有的区间端点,比如有n个端点,然后排序,在所有区间端点和 0,1,…,n-1 之间做一一映射,然后就可以把线段树的空间降到 2*n,和区间个数相关而和区间范围无关。

动态开点

我们也可以不用数组表示的完全二叉树来实现线段树,而用一棵记录左右孩子指针的动态申请空间的普通二叉树,根据更新的区间来动态生成区间节点,这一般被称作动态开点。

动态开点线段树实现

下面是动态开点线段树的完整实现,包含了懒惰标志的使用,默认区间范围(-IINF, IINF),支持负下标,支持使用下标读写单点数据,支持使用下标修改区间数据,支持区间求和、最大值、最小值查询。
区间表示前闭后开[b, e)
支持单点修改 tree[i] = x; tree[i] += x;
支持区间修改 tree[{b,e}] = x; tree[{b,e}] += x;
支持取单点值 tree[i];
支持取区间sum/min/max tree.sum(b, e);

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
typedef long long ll;

class SegTree {
private:
enum { INF = ll(4.6e18) };

struct Node {
Node* left;
Node* right;
ll sum;
ll min;
ll max;
ll lazy;
bool set_not_add;
Node():left(0), right(0), sum(0), min(0), max(0), lazy(INF), set_not_add(false) {}
};

void update(Node* node, int start, int end, int l, int r, bool set_not_add, ll val) {
if (l <= start && end <= r) {
if(set_not_add) {
node->sum = (end - start + 1) * val;
node->min = val;
node->max = val;
node->lazy = val;
} else {
node->sum += (end - start + 1) * val;
node->min += val;
node->max += val;
node->lazy = (node->lazy!=INF?node->lazy:0) + val;
}
node->set_not_add = set_not_add;
return;
}
int mid = (start + end) >> 1;
pushDown(node, mid - start + 1, end - mid);
if (l <= mid) update(node->left, start, mid, l, r, set_not_add, val);
if (r > mid) update(node->right, mid + 1, end, l, r, set_not_add, val);
pushUp(node);
}

ll sum(Node* node, int start, int end, int l, int r) {
if (l <= start && end <= r) return node->sum;
int mid = (start + end) >> 1;
ll ans = 0;
pushDown(node, mid - start + 1, end - mid);
if (l <= mid) ans += sum(node->left, start, mid, l, r);
if (r > mid) ans += sum(node->right, mid + 1, end, l, r);
return ans;
}

ll min(Node* node, int start, int end, int l, int r) {
if (l <= start && end <= r) return node->min;
int mid = (start + end) >> 1;
ll ans = INF;
pushDown(node, mid - start + 1, end - mid);
if (l <= mid) ans = std::min(ans, min(node->left, start, mid, l, r));
if (r > mid) ans = std::min(ans, min(node->right, mid + 1, end, l, r));
return ans;
}

ll max(Node* node, int start, int end, int l, int r) {
if (l <= start && end <= r) return node->max;
int mid = (start + end) >> 1;
ll ans = -INF;
pushDown(node, mid - start + 1, end - mid);
if (l <= mid) ans = std::max(ans, max(node->left, start, mid, l, r));
if (r > mid) ans = std::max(ans, max(node->right, mid + 1, end, l, r));
return ans;
}

void pushUp(Node* node) {
node->sum = node->left->sum + node->right->sum;
node->min = std::min(node->left->min, node->right->min);
node->max = std::max(node->left->max, node->right->max);
}

void pushDown(Node* node, int leftNum, int rightNum) {
if (node->left == nullptr) node->left = new Node();
if (node->right == nullptr) node->right = new Node();
if (node->lazy == INF) return ;
if(node->set_not_add) {
node->left->sum = node->lazy * leftNum;
node->left->min = node->lazy;
node->left->max = node->lazy;
node->left->lazy = node->lazy;
node->right->sum = node->lazy * rightNum;
node->right->min = node->lazy;
node->right->max = node->lazy;
node->right->lazy = node->lazy;
} else {
node->left->sum += node->lazy * leftNum;
node->left->min += node->lazy;
node->left->max += node->lazy;
node->left->lazy += node->lazy;
node->right->sum += node->lazy * rightNum;
node->right->min += node->lazy;
node->right->max += node->lazy;
node->right->lazy += node->lazy;
}
node->left->set_not_add = node->set_not_add;
node->right->set_not_add = node->set_not_add;
node->lazy = INF;
}

void destory(Node* p) {
if(p == NULL) return;
auto left = p->left;
auto right = p->right;
delete p;
destory(left);
destory(right);
}

public:
Node* root = NULL;
int start;
int end;
SegTree(int start=0, int end=1e9) : root(new Node()), start(start), end(end) {}
~SegTree() { destory(root); }
private:
class ItemAccessor {
public:
ItemAccessor(SegTree& st, int i): st(st), i(i) {}
ItemAccessor& operator = (ll x) {st.update(st.root, st.start, st.end, i, i, true, x);return *this;}
ItemAccessor& operator += (ll x) {st.update(st.root, st.start, st.end, i, i, false, x);return *this;}
ItemAccessor& operator -= (ll x) {st.update(st.root, st.start, st.end, i, i, false, -x);return *this;}
operator ll() const {return st.min(st.root, st.start, st.end, i, i);}
private:
SegTree& st;
int i;
};

class RangeAccessor {
public:
RangeAccessor(SegTree& st, int b, int e):st(st), b(b), e(e) {}
void operator = (ll x) {st.update(st.root, st.start, st.end, b, e-1, true, x);}
void operator += (ll x) {st.update(st.root, st.start, st.end, b, e-1, false, x);}
void operator -= (ll x) {st.update(st.root, st.start, st.end, b, e-1, false, -x);}
private:
SegTree& st;
int b;
int e;
};

public:
// 通过下标访问单点值的语法糖
ItemAccessor operator [] (int n) {
return ItemAccessor(*this, n);
}

// 通过下标访问区间值的语法糖
RangeAccessor operator [] (const tuple<int, int>& range) {
return RangeAccessor(*this, std::get<0>(range), std::get<1>(range));
}

ll sum(int l, int r) { return sum(root, start, end, l, r-1); }
ll min(int l, int r) { return min(root, start, end, l, r-1); }
ll max(int l, int r) { return max(root, start, end, l, r-1); }
};

类比C++ STL容器,记录Python容器初始化的一些技巧和陷阱

初始化一个集合的数组

C++写法

1
vector<set<int>> a(n);

正确的Python写法

1
a = [set() for i in range(n)]

一个错误的Python写法

1
a = [set()] * n

如果是常数值,这样写没问题,但是set()是对象,这会导致a[0] a[1] … a[n-1] 存储的都是指向同一个set对象的引用

初始化一个整数到整数集合的映射

C++写法

1
2
3
4
5
map<int, set<int>> b;

for(int i=0;i<10;i++) {
b[i].insert(i); // 在首次访问b[i]时会自动创建一个空的set<int>,所以可以直接调用insert
}

有用的但不舒服Python写法

1
2
3
4
5
b = {}
for i in range(10):
if i not in b: # 首次使用b[i]前要先判断并初始化
b[i] = set()
b[i].add(i)

defaultdict

有用且舒服的Python写法

1
2
3
4
5
from collections import defaultdict

b = defaultdict(set) # 这里指定了dict的值类型
for i in range(10):
b[i].add(i) # 这里就可以像C++一样直接add了

嵌套defaultdict

如果我们想定义一个二重嵌套的defaultdict,类似map<int, map<int>>,不能简单地使用 defaultdict(defaultdict(int)),因为defaultdict需要传递的参数是一个类型T,可以通过T()构造出默认的对象,而defaultdict(int)是对象而不是对象的类型,其实关键是通过T()能得到构造出默认对象就可以,所以我们可以定义一个方法返回defaultdict(int)

1
2
3
def defaultdict_int_creator():
return defaultdict(int)
defaultdict(defaultdict_int_creator)

这样就可以定义一个二重嵌套的defaultdict,更简单地,我们可以把函数改成lambda表达式

1
2
3
4
# 二重嵌套 defaultdict
defaultdict(lambda: defaultdict(int))
# 三重嵌套 defaultdict
defaultdict(lambda: defaultdict(lambda: defaultdict(int)))

质数(又称素数):只能被1和自身整除的大于1的正整数。

质数是初等数论的重要研究对象,有关质数的常用算法有如下一些:

线性筛法:O(n)时间筛出n以内的所有质数

质数判定:判断给定数n是否是质数。大质数判定有Miller Rabin测试法

合数分解:找到给定合数n的一个非平凡因子。大数分解有Pollard Rho分解法

BTW:一些和整除、因子、质数相关的题目,先做质因数分解会后豁然开朗,int范围内的数至多有1300多个因子,先分解再枚举所有因子,甚至对所有因子双层循环的时间都是可以接受的。

下面给出一个包含线性筛、Miller Rabin测试、Pollard Rho分解、质因数分解的质数工具类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
#include <iostream>
#include <cmath>
#include <cassert>
#include <map>
#include <cstring>
#include <vector>
using namespace std;

typedef long long ll;

template<typename T>
T gcd(T a, T b) { while(b) { T r = a%b; a = b; b = r;} return a;}
template<typename T>
T lcm(T a, T b) { return a/gcd(a,b)*b; }

// 打表大小,[0, UPBOUND)查表,[UPBOUND, UPBOUND*UPBOUND)根据表测试,[UPBOUND*UPBOUND, ll_max)米勒罗宾测试
const ll UPBOUND = 5e6;
class CPrime {
public:
ll modMul(ll a,ll b,ll m) {
ll t=0;
a=(a%m+m)%m;
b=(b%m+m)%m;
while(b){
if(b&1){
t=(t+a)%m;
}
a=(a+a)%m;
b>>=1;
}
return t;
}

ll modExp(ll a,ll b,ll m) {
ll t=1,y=a%m;
while(b){
if(b&1){
t=modMul(t,y,m);
}
y=modMul(y,y,m);
b=(b>>1);
}
return t;
}

bool miller_rabin(ll n,ll b) {
ll m=n-1;
int j=0;
while(!(m&1)){
m>>=1;
j++;
}
ll v=modExp(b,m,n);
if(v==1 || v==n-1)return 1;
for(int i=1;i<j;i++){
v=modMul(v,v,n);
if(v==n-1)return 1;
}
return 0;
}

bool _isprime(ll n)
{
const int K=10;//Miller_Rabin的偏真正确率为75%, isprime的正确率为1-(1/4)^K
if(n<2)return 0;
if(n==2 || n==3)return 1;
for(int i=0;i<K;i++){
if(!miller_rabin(n,rand()%(n-2)+2))return 0;
}
return 1;
}

//pollard_rho分解,给出N的一个非1因数,返回N时为一次没有找到,C为一个[1,N]的随机数
ll pollard_rho(ll C, ll N)
{
ll I, X, Y, K, D;
I = 1;
X = rand() % N;
Y = X;
K = 2;
do{
I++;
D = gcd(N + Y - X, N);
if (D > 1 && D < N) return D;
if (I == K) Y = X, K *= 2;
X = (modMul(X, X, N) + N - C) % N;
}while (Y != X);
return N;
}

//找出N的最小(质)因数
ll _min_divisor(ll N)
{
if (isprime(N)) return N;
while(1){
ll T = pollard_rho(rand() % (N - 1) + 1, N);
if (T < N){
ll A, B;
A = _min_divisor(T);
B = _min_divisor(N / T);
return A < B ? A : B;
}
}
}

int a[UPBOUND];
int p[UPBOUND];
int pn = 0;

//号称线性的筛素数算法,实际性能确实不错
//p[]={2,3,5,7,...},pn为小于UPBOUND的素数个数
//若i是合数a[i]为i的最小因子,若i是素数a[i]=0
void primefilter(){
int i, j;
for(i = 2; i < UPBOUND; ++i){
if(!(a[i])) p[pn++] = i;
for(j = 0; (j<pn && i*p[j]<UPBOUND && (p[j]<=a[i]||a[i]==0)); ++j) {
a[i*p[j]] = p[j];
}
}
}
public:
CPrime(){
memset(a, 0, sizeof(a));
memset(p, 0, sizeof(p));
primefilter();
}

bool isprime(ll n){
if(n<UPBOUND) return (n>0?!a[n]:0);
int c = min(3500, pn-1); // 3500个足够支持到1e9
if(n<p[c]*p[c]) {
for(int i=0; i<c; i++) {
if(n%p[i]==0) return false;
}
return true;
} else {
int c = min(1000, pn);
for(int i=0; i<c; i++) {
if(n%p[i]==0) return false;
}
return _isprime(n);
}
}

// 返回质数表第n个质数,p[0]=2, ..., p[n]<UPBOUND
int nth_prime(int n) {
return n<pn?p[n]:-32768;
}

// 返回n的最小因子
ll min_divisor(ll n) {
if(n<UPBOUND) return a[n]?a[n]:n;
int c = min(3500, pn-1); // 3500个足够支持到1e9
if(n<p[c]*p[c]) {
for(int i=0; i<c; i++) {
if(n%p[i]==0) return p[i];
}
return n;
} else {
// miller_rabin 和 pollard_rho 需要较多的运算
// 如果有一个小因子不妨先找出来
int c = min(1000, pn);
for(int i=0; i<c; i++) {
if(n%p[i]==0) return p[i];
}
return _min_divisor(n);
}
}

// 分解质因数
vector<pair<ll, int>> factorize(ll n) {
vector<pair<ll, int>> result;
while(n!=1) {
ll p = min_divisor(n);
int c = 0;
do {
n/=p;
++c;
} while(n%p==0);
result.emplace_back(make_pair(p, c));
}
return result;
}

// 欧拉函数
ll phi(ll n) {
auto factors = factorize(n);
for(auto [p, c] : factors) {
n = n / p * (p-1);
}
return n;
}
} Prime;