对pymongo的封装

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# -*- coding: utf-8 -*-
import pymongo
import threading


class _HashedSeq(list):
__slots__ = 'hashvalue'

def __init__(self, tup, hash=hash):
self[:] = tup
self.hashvalue = hash(tup)

def __hash__(self):
return self.hashvalue


def _make_key(args, kw, typed=False):
key = args
if kw:
kwd_mark = (object(),)
key += kwd_mark
for item in kw.items():
key += item
if typed:
key += tuple(type(v) for v in args)
if kw:
key += tuple(type(v) for v in kw.values())
elif len(key) == 1 and type(key[0]) in {int, str}:
return key[0]
return _HashedSeq(key)


# Thread-local storage; MongoDB.instance() keeps its cache of wrappers in `local.mongodb`.
local = threading.local()

class MongoDB:
    """Thin wrapper around ``pymongo.MongoClient`` bound to a single database.

    Collections of the bound database can be reached either as attributes
    (``mongo.securities``) or by subscription (``mongo["securities"]``).
    """

    @staticmethod
    def instance(*args, **kw):
        """Return the wrapper cached for these arguments in the current thread.

        The cache lives in the module-level thread-local ``local``. A new
        ``MongoDB`` is created, with the same arguments, the first time a
        given argument combination is seen in a thread.
        """
        cache = getattr(local, "mongodb", None)
        if cache is None:
            cache = local.mongodb = {}
        key = _make_key(args, kw, typed=False)
        if key not in cache:
            cache[key] = MongoDB(*args, **kw)
        return cache[key]

    def __init__(self, host="127.0.0.1", port=27017, dbname=None, username=None, password=None, **kw):
        """Create the client and select the database.

        Args:
            host: Server host name or address.
            port: Server port.
            dbname: Database to bind to; the ``test`` database is used when
                this is empty or None.
            username: User name passed to ``MongoClient``.
            password: Password passed to ``MongoClient``.
            **kw: Extra keyword arguments forwarded to ``MongoClient``.
        """
        uri = f"mongodb://{host}:{port}/"
        if dbname:
            # Only name a database in the URI when one was given; formatting
            # None into the URI would yield the literal path "/None".
            uri += dbname
        self.client = pymongo.MongoClient(uri, username=username, password=password, **kw)
        self.db = self.client[dbname] if dbname else self.client.test

    def close(self):
        """Close the underlying client and drop this wrapper from the cache.

        Only the current thread's cache is touched. Evicting the wrapper
        keeps ``instance()`` from handing out an already-closed client
        (whether a closed client can be reused depends on the PyMongo
        version).
        """
        cache = getattr(local, "mongodb", None)
        if cache:
            for key in [k for k, v in cache.items() if v is self]:
                del cache[key]
        return self.client.close()

    def __getattr__(self, key):
        """Return the collection named ``key`` for unknown attributes.

        Names starting with an underscore, and the wrapper's own ``client``
        and ``db`` attributes, raise ``AttributeError`` instead. Without this
        guard, reading ``self.db`` before it is assigned would re-enter
        ``__getattr__`` endlessly. Such collections remain reachable through
        ``mongo[name]``.
        """
        if key.startswith("_") or key in ("client", "db"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            )
        return self.db[key]

    def __getitem__(self, key):
        """Return the collection named ``key``."""
        return self.db[key]

    def has_collection(self, name):
        """Return True if the bound database lists a collection called ``name``."""
        return name in self.db.list_collection_names()


# Common query operators
# Logical operators
# {$or:[expression1,expression2,...]}
# {$and:[expression1,expression2,...]}
# {field:{$not:expression1}}
# Comparison operators
# {field:{$eq:value}} ==
# {field:{$ne:value}} !=
# {field:{$lt:value}} <
# {field:{$lte:value}} <=
# {field:{$gt:value}} >
# {field:{$gte:value}} >=
# {field:{$in:[value1,value2,...]}} in
# {field:{$nin:[value1,value2,...]}} not in
# Regular-expression match
# {field:{"$regex": "regular expression"}}

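# query: filter conditions, e.g. {"type": "stock"}
# fields: fields to return, e.g. ["code", "sec_name"]
# sort: sort fields and directions, e.g. [("code", -1)] sorts by code in descending order
# limit: return only the first `limit` documents, e.g. 10
# **kw: filter conditions; equivalent to `query` with its keys written out as keyword arguments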
def get_securities(query=None, fields=None, sort=None, limit=None, **kw):
    """Query the ``securities`` collection and return the matches as a list.

    Args:
        query: Filter dict, e.g. ``{"type": "stock"}``; None means no filter.
        fields: Fields to return, e.g. ``["code", "sec_name"]``.
        sort: Sort specification, e.g. ``[("code", -1)]``.
        limit: Maximum number of documents; a falsy value means no limit.
        **kw: Additional filter conditions, passed on together with ``query``.

    Returns:
        A list of the documents yielded by the query.
    """
    # Pass an explicit dict so gen_query never receives None.
    cursor = gen_query("securities", query or {}, fields, sort, **kw)
    if limit:
        cursor = cursor.limit(limit)
    return list(cursor)

def gen_query(collection, query=None, fields=None, sort=None, **kw):
    """Build a find() query on ``collection`` and return what it yields.

    Args:
        collection: Name of the collection to query.
        query: Filter dict; None means no filter.
        fields: Projection, converted with ``to_projection_dict``.
        sort: Sort specification, converted with ``to_sort_list``; no sort is
            applied when it is falsy.
        **kw: Additional filter conditions merged into ``query``. When a key
            appears in both, the keyword argument wins.

    Returns:
        The result of ``find(...)``, with ``sort(...)`` applied if requested.
    """
    # NOTE(review): connection settings are hard-coded here; they would be
    # better supplied by configuration. They cannot be added as keyword
    # parameters because **kw is reserved for filter fields.
    mongo = MongoDB.instance(host="192.168.1.99", dbname="ftresearch", username="ftresearch", password="******")
    # Merge with dict unpacking rather than dict(**query, **kw): the latter
    # raises TypeError when the same key is present in both.
    criteria = {**(query or {}), **kw}
    q = mongo[collection].find(criteria, to_projection_dict(fields))
    if sort:
        q = q.sort(to_sort_list(sort))
    return q


def to_projection_dict(fields):
    """Convert a field selection into a projection dict.

    Args:
        fields: None, a dict, a single field name, or an iterable of field
            names.

    Returns:
        ``{"_id": 0}`` when ``fields`` is None; ``fields`` itself when it is
        already a dict; otherwise ``{"_id": 0}`` extended with ``name: 1``
        for every requested field name.
    """
    if fields is None:
        return {"_id": 0}
    if isinstance(fields, dict):
        return fields
    if isinstance(fields, str):
        # A bare string is one field name, not an iterable of characters
        # (to_sort_list accepts a bare string the same way).
        fields = [fields]
    projection = {"_id": 0}
    projection.update((name, 1) for name in fields)
    return projection


def to_sort_list(sort):
    """Normalize a sort specification into a list of sort entries.

    Args:
        sort: A field name, a single non-string entry such as
            ``("code", -1)``, or a list mixing field names and entries.

    Returns:
        None when ``sort`` is falsy. Otherwise a list in which every bare
        field name has become ``(name, 1)`` and every other entry is kept
        unchanged.
    """
    if not sort:
        return None
    # Anything that is not a list is treated as one single entry.
    entries = sort if isinstance(sort, list) else [sort]
    return [(entry, 1) if isinstance(entry, str) else entry for entry in entries]


if __name__ == '__main__':
    # Find every stock whose code starts with 68; fetch only code and sec_name.
    arr = get_securities(code={"$regex": "^68"}, type="stock", fields=["code", "sec_name"])
    for e in arr:
        print(e)

    # Find stocks with special type ST or *ST that are still active (not
    # delisted), sorted by code in descending order, first 10 only.
    arr = get_securities(
        query={
            "type": "stock",
            "contract_state": "Active",
            "special_type": {"$in": ["ST", "StarST"]},
        },
        fields=["code", "sec_name", "exchange", "special_type"],
        sort=[("code", -1)],
        limit=10)
    for e in arr:
        print(e)

    # Collect the codes of the stock-index futures contracts that were
    # listed (and not yet delisted) on 2023-01-04.
    codes = [e["code"]
             for e in get_securities(
                 fields=["code"],
                 type="future",
                 product={"$in": ["IC", "IF", "IH", "IM"]},
                 listed_date={"$lte":"2023-01-04"},
                 de_listed_date={"$gte":"2023-01-04"}
             )
             ]
    print(codes)