0%

UV 是一款基于 Rust 编写的高性能 Python 包管理工具,旨在替代传统的工具如 pippip-toolsvenv。UV 的设计目标是简化 Python 项目的依赖管理和虚拟环境配置。

如何安装uv命令?

1)在没有python的环境下,可以使用curl下载安装

curl -LsSf https://gh-proxy.com/https://raw.githubusercontent.com/astral-sh/uv/main/install.sh | sh

2)在有python的环境下,可以通过pip安装

pip install uv

修改Python解释器安装镜像

uv可以在没有python的环境下安装,因为它可以维护python的版本(像anaconda或者nvm一样),默认的镜像服务器从国内访问不稳定,所以可以换成国内镜像地址。

export UV_PYTHON_INSTALL_MIRROR="https://gh-proxy.com/https://github.com/astral-sh/python-build-standalone/releases/download/"

可以加到~/.bashrc

如何初始化一个uv管理的Python项目?

uv init

会在当前文件夹下初始化项目,会自动创建配置文件,包括:

pyproject.toml 现代Python项目通过pyproject.toml管理项目元数据(其中包括项目的依赖包,其替代了传统的requirements.txt管理项目依赖)
.python-version 记录了项目的Python解释器版本,可以通过uv python pin 3.10来修改Python版本(3.10改成希望固定的Python版本)

如何给项目添加依赖包?

给项目添加依赖和安装包是同时进行的,通过

uv add <包名>

这个命令会自动维护pyproject.toml想其中添加包依赖,以及在向虚拟环境安装包的同时更新uv.lock

uv.lock 描述了当前环境具体的完全的包依赖的关系图。

如何给uv指定包的安装源?

和pip一样有时候我们希望把安装源指定成国内镜像或者私有源,可以通过

uv.toml

项目级别私有源配置放在项目根目录(和pyproject.toml同目录)

系统用户级私有源配置放在~/.config/uv/uv.toml

1
2
3
[[index]]
url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
default = true

如何启动项目

uv run main.py

无论是配置好的旧环境还是纯新的环境,都可以通过这一句启动项目,它会自动检查是否存在正确版本的Python解释器,如果不存在则自动下载安装,自动检查是否创建了虚拟环境如果没有创建则创建并激活,自动检查当前虚拟环境和uv.lock中描述的是否匹配,如果不匹配则自动更新虚拟环境,一切准备好了则启动。

提交必要文件到git

pyproject.toml

.python-version

uv.lock

uv.toml 如果存在项目级的

总结下基于uv的现代Python项目标准工作流

  1. 初始化项目uv init
  2. 固定python版本uv python pin 3.10
  3. 手动编辑pyproject.toml中dependencies以外的字段
  4. 通过uv add [-r]安装包(代替pip install [-r]
  5. 运行项目uv run main.py

其他一些也许会用到的命令

uv python list 查看当前系统存在的Python版本
uv run python --version 查看当前Python版本
uv pip freeze|grep numpy 查看虚拟环境下某个包版本
uv pip install <包名> 临时安装一个包到虚拟环境下,用于测试
uv sync 把环境恢复到标准的配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82

import os
import threading


# 抄自 functools.lrucache
class _HashedSeq(list):
__slots__ = 'hashvalue'

def __init__(self, tup, hash=hash):
self[:] = tup
self.hashvalue = hash(tup)

def __hash__(self):
return self.hashvalue


# 抄自 functools.lrucache,用于打包函数参数做key
def _make_key(args, kw, typed=False):
key = args
if kw:
kwd_mark = (object(),)
key += kwd_mark
for item in kw.items():
key += item
if typed:
key += tuple(type(v) for v in args)
if kw:
key += tuple(type(v) for v in kw.values())
elif len(key) == 1 and type(key[0]) in {int, str}:
return key[0]
return _HashedSeq(key)


class ThreadSingletonMeta(type):
_instances = {}
_lock = threading.RLock()

def __call__(cls, *args, **kwargs):
key = (cls, threading.get_ident(), _make_key(args, kwargs))
with cls._lock:
if key not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[key] = instance
return cls._instances[key]


class ProcessSingletonMeta(type):
_instances = {}
_lock = threading.RLock()

def __call__(cls, *args, **kwargs):
key = (cls, os.getpid(), _make_key(args, kwargs))
with cls._lock:
if key not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[key] = instance
return cls._instances[key]


class SingletonMeta(type):
_instances = {}
_lock = threading.RLock()

def __call__(cls, *args, **kwargs):
key = (cls, _make_key(args, kwargs))
with cls._lock:
if key not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[key] = instance
return cls._instances[key]


# Test
if __name__ == '__main__':

class A(metaclass=SingletonMeta):
def __init__(self):
print("A.__init__")

a1 = A()
a2 = A()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import struct
import numpy as np
import pandas as pd


# 返回一个字符串是整数或浮点数按机器默认大小端的16进制字节数据表示,方便精确的对比
def number_to_hex(num):
if isinstance(num, int):
if -2147483648 <= num <= 2147483647:
return struct.pack('@i', num).hex()
elif -9223372036854775808 <= num <= 9223372036854775807:
return struct.pack('@q', num).hex()
else:
return num
if isinstance(num, float):
return struct.pack('@d', num).hex()
if isinstance(num, np.float32):
return num.tobytes().hex()
if isinstance(num, np.float64):
return num.tobytes().hex()
return num


# 把包含number的容器或pd.Series、pd.DataFrame中的number全部转成16进制的数据表示
def to_hex(data):
if isinstance(data, list):
return [to_hex(e) for e in data]
if isinstance(data, dict):
return {to_hex(k): to_hex(v) for k, v in data.items()}
if isinstance(data, pd.Series):
return data.apply(lambda x: to_hex(x) if pd.notna(x) else None)
if isinstance(data, pd.DataFrame):
result = []
for index, row in data.iterrows():
row_hex = []
for value in row:
try:
row_hex.append(to_hex(value))
except ValueError:
row_hex.append(None)
result.append(row_hex)
return pd.DataFrame(result, index=data.index, columns=data.columns)
return number_to_hex(data)


# test
if __name__ == '__main__':
a = 1234567890123
e = 2.71
p = 3.14
print(to_hex(a), to_hex(e), to_hex(p))
print(to_hex(pd.DataFrame([[1,2],[3,4]], columns=["A", "B"])))
print(to_hex({1.3: [1, 2.3]}))

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import re

class Token:
def __init__(self, type, value):
self.type = type
self.value = value


def tokenize(expression):
# 词法规则
token_specification = [
('STOCKSET', r'(ALL|[0-9a-zA-Z]+\.[0-9a-zA-Z]+)'), # 股票集,指数代码或自定义股票组代码
('PLUS', r'\+'), # 匹配加号
('MINUS', r'-'), # 匹配减号
('INTERSECT', r'\&'), # 匹配&
('LPAREN', r'\('), # 匹配左括号
('RPAREN', r'\)'), # 匹配右括号
('WS', r'\s+'), # 匹配空格,忽略
]
tok_regex = '|'.join('(?P<%s>%s)' % pair for pair in token_specification)
token_list = []
for mo in re.finditer(tok_regex, expression):
kind = mo.lastgroup
value = mo.group()
if kind == 'WS':
continue # 忽略空格
token_list.append(Token(kind, value))
return token_list


class ASTNode:
def __init__(self, type, value=None, left=None, right=None):
self.type = type
self.value = value
self.left = left
self.right = right


def parse_expression(tokens):

def _parse_expression(tokens):
node = _parse_term(tokens)
while tokens and (tokens[0].type in ('PLUS', 'MINUS')):
op = tokens.pop(0).type
right = _parse_term(tokens)
node = ASTNode(op, left=node, right=right)

return node

def _parse_term(tokens):
node = _parse_factor(tokens)
while tokens and (tokens[0].type in ('INTERSECT',)):
op = tokens.pop(0).type
right = _parse_factor(tokens)
node = ASTNode(op, left=node, right=right)
return node

def _parse_factor(tokens):
if not tokens:
raise SyntaxError("Invalid syntax")
if tokens[0].type == 'STOCKSET':
value = tokens.pop(0).value
return ASTNode('STOCKSET', value=value)
elif tokens[0].type == 'LPAREN':
tokens.pop(0) # 去掉左括号
node = _parse_expression(tokens)
if not tokens or tokens[0].type!= 'RPAREN':
raise SyntaxError("Invalid syntax")
tokens.pop(0) # 去掉右括号
return node
else:
raise SyntaxError("Invalid syntax")

ret = _parse_expression(tokens)
if len(tokens):
raise SyntaxError("Invalid syntax")
return ret


def evaluate_ast(node, code2set):
if node.type == 'STOCKSET':
return code2set(node.value)
elif node.type == 'PLUS':
return evaluate_ast(node.left, code2set).union(evaluate_ast(node.right, code2set))
elif node.type == 'MINUS':
return evaluate_ast(node.left, code2set) - evaluate_ast(node.right, code2set)
elif node.type == 'INTERSECT':
return evaluate_ast(node.left, code2set) & evaluate_ast(node.right, code2set)


class StockSetExpression:
# exp_str: 表达式的字符串
def __init__(self, exp_str):
exp_str = exp_str
tokens = tokenize(exp_str) # 词法分析,得到token序列
self.ast = parse_expression(tokens) # 语法分析,得到抽象语法树

# 表达式求值
# code2set: 指数/股票组代码到股票代码集合的转换方法
def evaluate(self, code2set):
return evaluate_ast(self.ast, code2set) # 遍历抽象语法树计算表达式的值

# 返回股票集code列表
def get_stock_set_codes(self):
ret = []
def dfs(node):
if node.type == 'STOCKSET':
ret.append(node.value)
else:
if node.left:
dfs(node.left)
if node.right:
dfs(node.right)
dfs(self.ast)
return ret



# test
if __name__ == "__main__":
def bind(func, *args, **kw):
return lambda *_args, **_kw: func(*args, *_args, **kw, **_kw)

# expression = "(000300.SH + 000905.SH) & 000852.SH - 399006.SZ"
expression = "000300.SH + 000905.SH & 000852.SH - 399006.SZ"

# 得到股票成分股集合
def get_index_con(index_code, trade_date):
if index_code == "000300.SH":
return {"000001.SZ", "000002.SZ", "000003.SZ"}
if index_code == "000905.SH":
return {"000002.SZ", "000003.SZ", "000004.SZ"}
if index_code == "000852.SH":
return {"000003.SZ", "000004.SZ", "000005.SZ"}
if index_code == "399006.SZ":
return {"000004.SZ", "000005.SZ", "000006.SZ"}

# 创建一个表达式对象,实际上就是做好词法分析语法分析,把抽象语法树结构建出来
exp_obj = StockSetExpression(expression)
print("Result of {}: {}".format(
expression,
# 表达式求值,实际上就是遍历树,传入的是叶子结点的求值方式,也就是具体的指数如何取得当日成分股
exp_obj.evaluate(bind(get_index_con, trade_date='2024-12-17'))
))

print(exp_obj.get_stock_set_codes())

# 这样就把建立树结构和每天求表达式的值分离,从而不需要每天都重新parse表达式
# 同时又把具体的求指数成分股的函数从表达式求值的过程中分离成可以从外部指定

简单描述kosaraju算法就两个步骤

  1. 在反向图上做后根遍历,按遍历顺序把结点压栈

  2. 以出栈顺序为初始节点去原图上遍历

在第二次遍历时,每次以初始节点遍历到的所有节点都和初始节点在同一个强连通分量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61

void dfs1(int u, const vector<vector<int>>& reverse_edges, vector<bool>& visited, stack<int>& s) {
visited[u] = true;
for(int v: reverse_edges[u]) {
if(!visited[v]) {
dfs1(v, reverse_edges, visited, s);
}
}
s.push(u);
}

void dfs2(int u, const vector<vector<int>>& edges, vector<bool>& visited, int source, vector<int>& result) {
visited[u] = true;
result[u] = source;
for(int v: edges[u]) {
if(!visited[v]) {
dfs2(v, edges, visited, source, result);
}
}
}

// 计算图的强联通分量
// params: edges 邻接表
// returns: 返回每个结点所在的强联通分量的代表元
// 相同代表元说明属于同一强联通分量,不同代表元说明属于不同强联通分量
// 返回的不同值的个数就是强连通分量数
vector<int> kosaraju(const vector<vector<int>>& edges) {
int n = edges.size();
vector<int> result(n);

// 反向图
vector<vector<int>> reverse_edges(n);
for(int u=0;u<n;u++) {
for(int v: edges[u]) {
reverse_edges[v].push_back(u);
}
}

stack<int> s;
do {
// 在反向图上做后根遍历,把结点压栈
vector<bool> visited(n, false);
for(int u=0;u<n;u++) {
if(!visited[u]) {
dfs1(u, reverse_edges, visited, s);
}
}
} while(false);

do {
vector<bool> visited(n, false);
// 这时依次按照s中弹出的结点为开始结点去遍历原图(即使用edges),就可以得到强连通分量
while(!s.empty()) {
int u = s.top(); s.pop();
if(!visited[u]) {
dfs2(u, edges, visited, u/*代表元*/, result);
}
}
} while(false);
return result;
}

有了代表元数组就可以把原图简化成一个由强连通分量构成的有向无环图,再去做进一步处理就方便了,比如拓朴排序之类的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56

template<int MOD>
class Combination {
vector<vector<int>> _comb;
public:
Combination(int max_n) {
// 打表组合值
_comb.resize(max_n+1, vector<int>());
for(int i=0;i<=max_n;i++) {
_comb[i].resize(i/2+1);
}
_comb[0][0] = 1;
for(int i=1;i<=max_n;i++) {
_comb[i][0] = 1;
for(int j=1;j<=i/2;j++) {
_comb[i][j] = (_comb[i-1][j-1] + _comb[i-1][j<=(i-1)/2?j:i-1-j]) % MOD;
}
}
}

// 组合数 n!/k!/(n-k)!
int operator()(int n, int k) {
if(k < 0 || n < 0 || k > n) return 0;
if(k > n/2) return _comb[n][n-k];
else return _comb[n][k];
}
};


// 不取MOD的特化实现
template<>
class Combination<0> {
vector<vector<int>> _comb;
public:
Combination(int max_n) {
// 打表组合值
_comb.resize(max_n+1, vector<int>());
for(int i=0;i<=max_n;i++) {
_comb[i].resize(i/2+1);
}
_comb[0][0] = 1;
for(int i=1;i<=max_n;i++) {
_comb[i][0] = 1;
for(int j=1;j<=i/2;j++) {
_comb[i][j] = (_comb[i-1][j-1] + _comb[i-1][j<=(i-1)/2?j:i-1-j]);
}
}
}

// 组合数 n!/k!/(n-k)!
int operator()(int n, int k) {
if(k < 0 || n < 0 || k > n) return 0;
if(k > n/2) return _comb[n][n-k];
else return _comb[n][k];
}
};

并查集结构支持3种操作:

  1. 合并两个集合
  2. 查询两个元素是否属于相同集合
  3. 查询元素所在集合的大小

实现的思路也很简单,最初每个元素都在一个只包含自身的集合中,之后通过Union操作建立关联,把一个元素的父亲指向另一个元素,那么判断两个元素是否属于同一集合只需要判断两个元素的根是否相同,这里顺便用parent记录下元素的个数,当parent[i]>=0时表示元素的父元素的下标,当parent[i]<0时(意味着i是某个集合的根元素),parent[i]则表示该集合的大小。
优化查询:在查询元素所属集合的根元素时顺便把自己的父元素直接指向根元素,也就是把树压扁。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

class UFSets {
private:
vector<int> parent;
int size;
public:
UFSets(int s): parent(s){
size = s;
memset(&parent[0], -1, sizeof(int)*s);
}

int Find(int x){
if (parent[x] < 0) {
return x;
} else {
return parent[x] = Find(parent[x]);
}
}

void Union(int v1, int v2){
int s1 = Find(v1), s2 = Find(v2);
if(s1==s2)return;
int t = parent[s1] + parent[s2];
if ( parent[s2] < parent[s1] ) {
parent[s1] = s2;
parent[s2] = t;
}
else {
parent[s2] = s1;
parent[s1] = t;
}
}

int Count(int x){
return -parent[Find(x)];
}
};

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69

// 马拉车算法,O(n)时间处理字符串所有子串的回文问题
// https://oi-wiki.org/string/manacher/
// 例题:https://leetcode.cn/problems/check-if-dfs-strings-are-palindromes/description/
// https://leetcode.cn/problems/longest-palindromic-substring/description/
class Manacher {
vector<int> d1; // 奇数长度子串的半径,d1[i] = r 表示,以下标i为中心的最长回文子串半径是r(长度是2r-1)
vector<int> d2; // 偶数长度子串的半径,d2[i] = r 表示,以下标i为右中心的最长回文子串半径是r(长度是2r)
public:
Manacher(const string& s) {
int n = s.size();
d1.resize(n);
d2.resize(n);

for (int i = 0, l = 0, r = -1; i < n; i++) {
int k = (i > r) ? 1 : min(d1[l + r - i], r - i + 1);
while (0 <= i - k && i + k < n && s[i - k] == s[i + k]) {
k++;
}
d1[i] = k--;
if (i + k > r) {
l = i - k;
r = i + k;
}
}


for (int i = 0, l = 0, r = -1; i < n; i++) {
int k = (i > r) ? 0 : min(d2[l + r - i + 1], r - i + 1);
while (0 <= i - k - 1 && i + k < n && s[i - k - 1] == s[i + k]) {
k++;
}
d2[i] = k--;
if (i + k > r) {
l = i - k - 1;
r = i + k;
}
}
}

// 返回s[b:e]是否回文
bool isPalindrome(int b, int e) {
if((e-b)%2) { // 奇数长度的子串
int c = b + (e-b) / 2;
int r = c - b + 1;
return d1[c] >= r;
} else { // 偶数长度的子串
int c = b + (e-b) / 2;
int r = c - b;
return d2[c] >= r;
}
}

// 返回s的最长回文子串的begin、end
pair<int, int> longestPalindrome() {
int b=0, e=0;
for(int c=0;c<d1.size();c++) {
if(d1[c]*2-1>e-b) {
b = c+1-d1[c];
e = c + d1[c];
}
if(d2[c]*2>e-b) {
b = c - d2[c];
e = c + d2[c];
}
}
return {b, e};
}
};

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import colorsys

def generate_colors(n):
"""
生成 n 个尽量有明显区别的颜色。

参数:
n (int):要生成的颜色数量,必须是正整数。

返回:
list:包含 n 个 RGB 颜色值的列表,每个颜色值是一个三元组 (r, g, b)。
"""
colors = []
cnt = 0
k = (n + 9) // 10
for i in range(k):
for j in range(10):
h = (360 / 10) * j + 23 * i
s = 1.0 if cnt < 25 else 0.8
l = 0.3 + (j * 19) % 7 / 7 * 0.4
cnt += 1
r, g, b = colorsys.hls_to_rgb(h / 360, l, s)
colors.append((int(r * 255), int(g * 255), int(b * 255)))
if cnt == n:
break
return colors

import matplotlib.pyplot as plt
import numpy as np

colors = generate_colors(30)

fig, ax = plt.subplots(figsize=(6, 6))
for i, color in enumerate(colors):
rect = plt.Rectangle((0, -i), 1, 1, color=np.array(color) / 255)
ax.add_patch(rect)
ax.set_xlim(0, 1)
ax.set_ylim(-len(colors), 0)
ax.axis('off')
plt.show()

用法

导入全局的profiler或创建一个局部的profiler

给函数添加装饰器@profile(profiler)

给代码段增加with profile_scope(profiler, name)

在统计结束后程序退出前调用print_profile(profiler)会在控制台输出各个线程中各个方法或代码段的耗时和调用次数,通过缩进体现调用关系。

如果想输出到日志文件,可以给print_profile传递print_func参数。

参考print_profile遍历节点的方式可以实现自定义输出或功能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183

import time
import threading
from collections import deque
from functools import wraps
from contextlib import contextmanager


# 使用孩子兄弟树存储调用关系(对于递归的情况特判)
class Node:
def __init__(self, name, parent):
self.name = name
self.parent = parent
self.sibling = None
self.child = None
self.total_calls = 0 # 调用计数
self.recur_counter = 0 # 递归层数
self.start_ts = 0
self.total_ts = 0
self.is_recursive = False

def get_sub_node(self, name):
p = self.child
if p is None:
q = Node(name, self)
self.child = q
return q
else:
# 找到name节点并返回,如果找不到,退出循环时p是当前节点的最后一个孩子
while True:
if p.name == name:
return p
if p.sibling is None:
break
p = p.sibling
# 创建新节点
q = Node(name, self)
# 把新节点作为p的兄弟,即新的最后一个孩子
p.sibling = q
return q

def call(self):
self.total_calls += 1
self.recur_counter += 1
if self.recur_counter == 1:
self.start_ts = time.perf_counter()

def returns(self):
if self.recur_counter == 1:
self.total_ts += time.perf_counter() - self.start_ts
self.recur_counter -= 1

# 返回名字=name的祖先节点或当前节点自身
# 如果没有名字=name的祖先并且自身名字也不是name,则返回None
# 此方法用来辅助处理递归调用
def ancestor_or_self(self, name):
p = self
while p:
if p.name == name:
return p
p = p.parent
return None


class Profiler:
def __init__(self):
self.node_stack = deque()
self.roots = {}
self.current_node = None
self.main_tid = threading.get_ident()

def begin_range(self, name):
# 每个线程需要创建一个root节点
tid = threading.get_ident()
if tid not in self.roots:
root_name = f"thread:{tid}" if tid != self.main_tid else "main_thread"
self.current_node = self.roots[tid] = Node(root_name, None)

self.node_stack.append(self.current_node)
p = self.current_node.ancestor_or_self(name)
if p: # 说明发生了递归
self.current_node = p
else:
self.current_node = self.current_node.get_sub_node(name)
self.current_node.call()

def end_range(self):
self.current_node.returns()
self.current_node = self.node_stack.pop()


class ProfilerViewer:
def __init__(self, root, view_func):
self.root = root
self.view_func = view_func

def view(self):
self._view(self.root, 0)

def _view(self, p, depth):
self.view_func(p, depth)
if p.child:
self._view(p.child, depth+1)
if p.sibling:
self._view(p.sibling, depth)


def print_profile(profiler, print_func=print):
def view_func(p, depth):
if p.parent and p.parent.total_ts:
pct = p.total_ts / p.parent.total_ts
else:
pct = 1.0
s = r"{}{} {:.2f}pct {:.3f}ms {}calls".format(' '*depth, p.name, pct*100, p.total_ts*1000, p.total_calls)
print_func(s)

for root in profiler.roots.values():
viewer = ProfilerViewer(root, view_func)
viewer.view()


# 通过装饰器的方式统计一个函数的耗时和调用次数
def profile(profiler):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
profiler.begin_range(func.__name__)
ret = func(*args, **kwargs)
profiler.end_range()
return ret
return wrapper
return decorator


# 通过with的方式统计一个代码块的耗时和调用次数
@contextmanager
def profile_scope(profiler, name):
try:
profiler.begin_range(name)
yield
except Exception:
raise
finally:
profiler.end_range()


profiler = Profiler()


if __name__ == '__main__':

@profile(profiler)
def fib(n):
if n == 1 or n == 2:
return 1

with profile_scope(profiler, "sleep"):
time.sleep(0.1) # 模拟函数实现中一处耗时很多的操作

# 递归调用自身
return fib(n-1) + fib(n-2)

@profile(profiler)
def add(a, b):
return a + b

th = threading.Thread(target=add, args=(1, 2))
th.start()

@profile(profiler)
def test():
# 计算 fib(5)
print(fib(5))

# 计算10000次add
for i in range(10000):
add(1,i)

test()

print_profile(profiler)