diff --git a/vn.trader/vnrpc.py b/vn.trader/vnrpc.py new file mode 100644 index 00000000..c3a979c9 --- /dev/null +++ b/vn.trader/vnrpc.py @@ -0,0 +1,316 @@ +# encoding: UTF-8 + +import threading +import traceback +import signal + +import zmq +from msgpack import packb, unpackb +from json import dumps, loads + +import cPickle +pDumps = cPickle.dumps +pLoads = cPickle.loads + + +# 实现Ctrl-c中断recv +signal.signal(signal.SIGINT, signal.SIG_DFL) + + +######################################################################## +class RpcObject(object): + """ + RPC对象 + + 提供对数据的序列化打包和解包接口,目前提供了json、msgpack、cPickle三种工具。 + + msgpack:性能更高,但通常需要安装msgpack相关工具; + json:性能略低但通用性更好,大部分编程语言都内置了相关的库。 + cPickle:性能一般且仅能用于Python,但是可以直接传送Python对象,非常方便。 + + 因此建议尽量使用msgpack,如果要和某些语言通讯没有提供msgpack时再使用json, + 当传送的数据包含很多自定义的Python对象时建议使用cPickle。 + + 如果希望使用其他的序列化工具也可以在这里添加。 + """ + + #---------------------------------------------------------------------- + def __init__(self): + """Constructor""" + # 默认使用msgpack作为序列化工具 + #self.useMsgpack() + self.usePickle() + + #---------------------------------------------------------------------- + def pack(self, data): + """打包""" + pass + + #---------------------------------------------------------------------- + def unpack(self, data): + """解包""" + pass + + #---------------------------------------------------------------------- + def __jsonPack(self, data): + """使用json打包""" + return dumps(data) + + #---------------------------------------------------------------------- + def __jsonUnpack(self, data): + """使用json解包""" + return loads(data) + + #---------------------------------------------------------------------- + def __msgpackPack(self, data): + """使用msgpack打包""" + return packb(data) + + #---------------------------------------------------------------------- + def __msgpackUnpack(self, data): + """使用msgpack解包""" + return unpackb(data) + + #---------------------------------------------------------------------- + def __picklePack(self, data): + """使用cPickle打包""" + return pDumps(data) + + #---------------------------------------------------------------------- + def __pickleUnpack(self, data): + """使用cPickle解包""" + return pLoads(data) + + #---------------------------------------------------------------------- + def useJson(self): + """使用json作为序列化工具""" + self.pack = self.__jsonPack + self.unpack = self.__jsonUnpack + + #---------------------------------------------------------------------- + def useMsgpack(self): + """使用msgpack作为序列化工具""" + self.pack = self.__msgpackPack + self.unpack = self.__msgpackUnpack + + #---------------------------------------------------------------------- + def usePickle(self): + """使用cPickle作为序列化工具""" + self.pack = self.__picklePack + self.unpack = self.__pickleUnpack + + +######################################################################## +class RpcServer(RpcObject): + """RPC服务器""" + + #---------------------------------------------------------------------- + def __init__(self, repAddress, pubAddress): + """Constructor""" + super(RpcServer, self).__init__() + + # 保存功能函数的字典,key是函数名,value是函数对象 + self.__functions = {} + + # zmq端口相关 + self.__context = zmq.Context() + + self.__socketREP = self.__context.socket(zmq.REP) # 请求回应socket + self.__socketREP.bind(repAddress) + + self.__socketPUB = self.__context.socket(zmq.PUB) # 数据广播socket + self.__socketPUB.bind(pubAddress) + + # 工作线程相关 + self.__active = False # 服务器的工作状态 + self.__thread = threading.Thread(target=self.run) # 服务器的工作线程 + + #---------------------------------------------------------------------- + def start(self): + """启动服务器""" + # 将服务器设为启动 + self.__active = True + + # 启动工作线程 + if not self.__thread.isAlive(): + self.__thread.start() + + #---------------------------------------------------------------------- + def stop(self): + """停止服务器""" + # 将服务器设为停止 + self.__active = False + + # 等待工作线程退出 + if self.__thread.isAlive(): + self.__thread.join() + + #---------------------------------------------------------------------- + def run(self): + """服务器运行函数""" + while self.__active: + # 使用poll来等待事件到达,等待1秒(1000毫秒) + if not self.__socketREP.poll(1000): + continue + + # 从请求响应socket收取请求数据 + reqb = self.__socketREP.recv() + + # 序列化解包 + req = self.unpack(reqb) + + # 获取函数名和参数 + name, args, kwargs = req + + # 获取引擎中对应的函数对象,并执行调用,如果有异常则捕捉后返回 + try: + func = self.__functions[name] + r = func(*args, **kwargs) + rep = [True, r] + except Exception as e: + rep = [False, traceback.format_exc()] + + # 序列化打包 + repb = self.pack(rep) + + # 通过请求响应socket返回调用结果 + self.__socketREP.send(repb) + + #---------------------------------------------------------------------- + def publish(self, topic, data): + """ + 广播推送数据 + topic:主题内容 + data:具体的数据 + """ + # 序列化数据 + datab = self.pack(data) + + # 通过广播socket发送数据 + self.__socketPUB.send_multipart([topic, datab]) + + #---------------------------------------------------------------------- + def register(self, func): + """注册函数""" + self.__functions[func.__name__] = func + + +######################################################################## +class RpcClient(RpcObject): + """RPC客户端""" + + #---------------------------------------------------------------------- + def __init__(self, reqAddress, subAddress): + """Constructor""" + super(RpcClient, self).__init__() + + # zmq端口相关 + self.__reqAddress = reqAddress + self.__subAddress = subAddress + + self.__context = zmq.Context() + self.__socketREQ = self.__context.socket(zmq.REQ) # 请求发出socket + self.__socketSUB = self.__context.socket(zmq.SUB) # 广播订阅socket + + # 工作线程相关,用于处理服务器推送的数据 + self.__active = False # 客户端的工作状态 + self.__thread = threading.Thread(target=self.run) # 客户端的工作线程 + + #---------------------------------------------------------------------- + def __getattr__(self, name): + """实现远程调用功能""" + # 执行远程调用任务 + def dorpc(*args, **kwargs): + # 生成请求 + req = [name, args, kwargs] + + # 序列化打包请求 + reqb = self.pack(req) + + # 发送请求并等待回应 + self.__socketREQ.send(reqb) + repb = self.__socketREQ.recv() + + # 序列化解包回应 + rep = self.unpack(repb) + + # 若正常则返回结果,调用失败则触发异常 + if rep[0]: + return rep[1] + else: + raise RemoteException(rep[1]) + + return dorpc + + #---------------------------------------------------------------------- + def start(self): + """启动客户端""" + # 连接端口 + self.__socketREQ.connect(self.__reqAddress) + self.__socketSUB.connect(self.__subAddress) + + # 将服务器设为启动 + self.__active = True + + # 启动工作线程 + if not self.__thread.isAlive(): + self.__thread.start() + + #---------------------------------------------------------------------- + def stop(self): + """停止客户端""" + # 将客户端设为停止 + self.__active = False + + # 等待工作线程退出 + if self.__thread.isAlive(): + self.__thread.join() + + #---------------------------------------------------------------------- + def run(self): + """客户端运行函数""" + while self.__active: + # 使用poll来等待事件到达,等待1秒(1000毫秒) + if not self.__socketSUB.poll(1000): + continue + + # 从订阅socket收取广播数据 + topic, datab = self.__socketSUB.recv_multipart() + + # 序列化解包 + data = self.unpack(datab) + + # 调用回调函数处理 + self.callback(topic, data) + + #---------------------------------------------------------------------- + def callback(self, topic, data): + """回调函数,必须由用户实现""" + raise NotImplementedError + + #---------------------------------------------------------------------- + def subscribeTopic(self, topic): + """ + 订阅特定主题的广播数据 + + 可以使用topic=''来订阅所有的主题 + """ + self.__socketSUB.setsockopt(zmq.SUBSCRIBE, topic) + + + +######################################################################## +class RemoteException(Exception): + """RPC远程异常""" + + #---------------------------------------------------------------------- + def __init__(self, value): + """Constructor""" + self.__value = value + + #---------------------------------------------------------------------- + def __str__(self): + """输出错误信息""" + return self.__value + + \ No newline at end of file diff --git a/vn.trader/vtClient.py b/vn.trader/vtClient.py new file mode 100644 index 00000000..dfe51ba7 --- /dev/null +++ b/vn.trader/vtClient.py @@ -0,0 +1,206 @@ +# encoding: utf-8 + +import sys +import os +import ctypes +import platform + +import vtPath +from uiMainWindow import * + +from eventEngine import * +from vnrpc import RpcClient + +from ctaAlgo.ctaEngine import CtaEngine +from dataRecorder.drEngine import DrEngine +from riskManager.rmEngine import RmEngine + + + +# 文件路径名 +path = os.path.abspath(os.path.dirname(__file__)) +ICON_FILENAME = 'vnpy.ico' +ICON_FILENAME = os.path.join(path, ICON_FILENAME) + +SETTING_FILENAME = 'VT_setting.json' +SETTING_FILENAME = os.path.join(path, SETTING_FILENAME) + + +######################################################################## +class VtClient(RpcClient): + """vn.trader客户端""" + + #---------------------------------------------------------------------- + def __init__(self, reqAddress, subAddress, eventEngine): + """Constructor""" + super(VtClient, self).__init__(reqAddress, subAddress) + + self.eventEngine = eventEngine + + self.usePickle() + + #---------------------------------------------------------------------- + def callback(self, topic, data): + """回调函数""" + self.eventEngine.put(data) + + +######################################################################## +class ClientEngine(object): + """客户端引擎,提供和MainEngine完全相同的API接口""" + + #---------------------------------------------------------------------- + def __init__(self, client, eventEngine): + """Constructor""" + self.client = client + self.eventEngine = eventEngine + + # 扩展模块 + self.ctaEngine = CtaEngine(self, self.eventEngine) + self.drEngine = DrEngine(self, self.eventEngine) + self.rmEngine = RmEngine(self, self.eventEngine) + + #---------------------------------------------------------------------- + def connect(self, gatewayName): + """连接特定名称的接口""" + self.client.connect(gatewayName) + + #---------------------------------------------------------------------- + def subscribe(self, subscribeReq, gatewayName): + """订阅特定接口的行情""" + self.client.subscribe(subscribeReq, gatewayName) + + #---------------------------------------------------------------------- + def sendOrder(self, orderReq, gatewayName): + """对特定接口发单""" + self.client.sendOrder(orderReq, gatewayName) + + #---------------------------------------------------------------------- + def cancelOrder(self, cancelOrderReq, gatewayName): + """对特定接口撤单""" + self.client.cancelOrder(cancelOrderReq, gatewayName) + + #---------------------------------------------------------------------- + def qryAccont(self, gatewayName): + """查询特定接口的账户""" + self.client.qryAccount(gatewayName) + + #---------------------------------------------------------------------- + def qryPosition(self, gatewayName): + """查询特定接口的持仓""" + self.client.qryPosition(gatewayName) + + #---------------------------------------------------------------------- + def exit(self): + """退出程序前调用,保证正常退出""" + # 停止事件引擎 + self.eventEngine.stop() + + # 关闭客户端的推送数据接收 + self.client.stop() + + # 停止数据记录引擎 + self.drEngine.stop() + + #---------------------------------------------------------------------- + def writeLog(self, content): + """快速发出日志事件""" + self.client.writeLog(content) + + #---------------------------------------------------------------------- + def dbConnect(self): + """连接MongoDB数据库""" + self.client.dbConnect() + + #---------------------------------------------------------------------- + def dbInsert(self, dbName, collectionName, d): + """向MongoDB中插入数据,d是具体数据""" + self.client.dbInsert(dbName, collectionName, d) + + #---------------------------------------------------------------------- + def dbQuery(self, dbName, collectionName, d): + """从MongoDB中读取数据,d是查询要求,返回的是数据库查询的数据列表""" + self.client.dbQuery(dbName, collectionName, d) + + #---------------------------------------------------------------------- + def dbUpdate(self, dbName, collectionName, d, flt, upsert=False): + """向MongoDB中更新数据,d是具体数据,flt是过滤条件,upsert代表若无是否要插入""" + self.client.dbUpdate(dbName, collectionName, d, flt, upsert) + + #---------------------------------------------------------------------- + def getContract(self, vtSymbol): + """查询合约""" + return self.client.getContract(vtSymbol) + + #---------------------------------------------------------------------- + def getAllContracts(self): + """查询所有合约(返回列表)""" + return self.client.getAllContracts() + + #---------------------------------------------------------------------- + def getOrder(self, vtOrderID): + """查询委托""" + return self.client.getOrder(vtOrderID) + + #---------------------------------------------------------------------- + def getAllWorkingOrders(self): + """查询所有的活跃的委托(返回列表)""" + return self.client.getAllWorkingOrders() + + #---------------------------------------------------------------------- + def getAllGatewayNames(self): + """查询所有的接口名称""" + return self.client.getAllGatewayNames() + + +#---------------------------------------------------------------------- +def main(): + """客户端主程序入口""" + # 重载sys模块,设置默认字符串编码方式为utf8 + reload(sys) + sys.setdefaultencoding('utf8') + + # 设置Windows底部任务栏图标 + if 'Windows' in platform.uname() : + ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID('vn.trader') + + # 创建事件引擎 + eventEngine = EventEngine() + eventEngine.start(timer=False) + + # 创建客户端 + reqAddress = 'tcp://localhost:2014' + subAddress = 'tcp://localhost:0602' + client = VtClient(reqAddress, subAddress, eventEngine) + + client.subscribeTopic('') + client.start() + + # 初始化Qt应用对象 + app = QtGui.QApplication(sys.argv) + app.setWindowIcon(QtGui.QIcon(ICON_FILENAME)) + app.setFont(BASIC_FONT) + + # 设置Qt的皮肤 + try: + f = file(SETTING_FILENAME) + setting = json.load(f) + if setting['darkStyle']: + import qdarkstyle + app.setStyleSheet(qdarkstyle.load_stylesheet(pyside=False)) + except: + pass + + # 初始化主引擎和主窗口对象 + mainEngine = ClientEngine(client, eventEngine) + mainWindow = MainWindow(mainEngine, mainEngine.eventEngine) + mainWindow.showMaximized() + + # 在主线程中启动Qt事件循环 + sys.exit(app.exec_()) + + +if __name__ == '__main__': + main() + + \ No newline at end of file diff --git a/vn.trader/vtEngine.py b/vn.trader/vtEngine.py index eb68d627..103317b0 100644 --- a/vn.trader/vtEngine.py +++ b/vn.trader/vtEngine.py @@ -27,7 +27,7 @@ class MainEngine(object): self.todayDate = datetime.now().strftime('%Y%m%d') # 创建事件引擎 - self.eventEngine = EventEngine() + self.eventEngine = EventEngine2() self.eventEngine.start() # 创建数据引擎 @@ -275,7 +275,10 @@ class MainEngine(object): db = self.dbClient[dbName] collection = db[collectionName] cursor = collection.find(d) - return list(cursor) + if cursor: + return list(cursor) + else: + return [] else: self.writeLog(u'数据查询失败,MongoDB没有连接') return [] diff --git a/vn.trader/vtServer.py b/vn.trader/vtServer.py new file mode 100644 index 00000000..bf7f805d --- /dev/null +++ b/vn.trader/vtServer.py @@ -0,0 +1,97 @@ +# encoding: utf-8 + +import sys +import os + +from datetime import datetime +from time import sleep +from threading import Thread + +import eventType +from vnrpc import RpcServer +from vtEngine import MainEngine + + +######################################################################## +class VtServer(RpcServer): + """vn.trader服务器""" + + #---------------------------------------------------------------------- + def __init__(self, repAddress, pubAddress): + """Constructor""" + super(VtServer, self).__init__(repAddress, pubAddress) + self.usePickle() + + # 创建主引擎对象 + self.engine = MainEngine() + + # 注册主引擎的方法到服务器的RPC函数 + self.register(self.engine.connect) + self.register(self.engine.subscribe) + self.register(self.engine.sendOrder) + self.register(self.engine.cancelOrder) + self.register(self.engine.qryAccont) + self.register(self.engine.qryPosition) + self.register(self.engine.exit) + self.register(self.engine.writeLog) + self.register(self.engine.dbConnect) + self.register(self.engine.dbInsert) + self.register(self.engine.dbQuery) + self.register(self.engine.dbUpdate) + self.register(self.engine.getContract) + self.register(self.engine.getAllContracts) + self.register(self.engine.getOrder) + self.register(self.engine.getAllWorkingOrders) + self.register(self.engine.getAllGatewayNames) + + # 注册事件引擎发送的事件处理监听 + self.engine.eventEngine.registerGeneralHandler(self.eventHandler) + + #---------------------------------------------------------------------- + def eventHandler(self, event): + """事件处理""" + self.publish(event.type_, event) + + #---------------------------------------------------------------------- + def stopServer(self): + """停止服务器""" + # 关闭引擎 + self.engine.exit() + + # 停止服务器线程 + self.stop() + + +#---------------------------------------------------------------------- +def printLog(content): + """打印日志""" + print datetime.now().strftime("%H:%M:%S"), '\t', content + + +#---------------------------------------------------------------------- +def runServer(): + """运行服务器""" + repAddress = 'tcp://*:2014' + pubAddress = 'tcp://*:0602' + + # 创建并启动服务器 + server = VtServer(repAddress, pubAddress) + server.start() + + printLog('-'*50) + printLog(u'vn.trader服务器已启动') + + # 进入主循环 + while True: + printLog(u'请输入exit来关闭服务器') + if raw_input() != 'exit': + continue + + printLog(u'确认关闭服务器?yes|no') + if raw_input() == 'yes': + break + + server.stopServer() + +if __name__ == '__main__': + runServer() \ No newline at end of file