diff --git a/docs/install.md b/docs/install.md index a36dc137..c50d1bca 100644 --- a/docs/install.md +++ b/docs/install.md @@ -7,6 +7,7 @@ ### 使用VNConda + #### 1.下载VNConda (Python 3.7 64位) 下载地址如下:[VNConda-2.0.1-Windows-x86_64](https://conda.vnpy.com/VNConda-2.0.1-Windows-x86_64.exe) @@ -44,6 +45,7 @@     + ### 手动安装 #### 1.下载并安装最新版Anaconda3.7 64位 diff --git a/requirements.txt b/requirements.txt index 2b64ee1d..ee935ce6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ PyQt5<5.12 +pyqtgraph dataclasses; python_version<="3.6" qdarkstyle requests diff --git a/tests/backtesting/genetic_algorithm.ipynb b/tests/backtesting/genetic_algorithm.ipynb new file mode 100644 index 00000000..c9e92e12 --- /dev/null +++ b/tests/backtesting/genetic_algorithm.ipynb @@ -0,0 +1,189 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "import multiprocessing\n", + "import numpy as np\n", + "from deap import creator, base, tools, algorithms\n", + "from vnpy.app.cta_strategy.backtesting import BacktestingEngine\n", + "from boll_channel_strategy import BollChannelStrategy\n", + "from datetime import datetime\n", + "import multiprocessing #多进程\n", + "from scoop import futures #多进程" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def parameter_generate():\n", + " '''\n", + " 根据设置的起始值,终止值和步进,随机生成待优化的策略参数\n", + " '''\n", + " parameter_list = []\n", + " p1 = random.randrange(4,50,2) #布林带窗口\n", + " p2 = random.randrange(4,50,2) #布林带通道阈值\n", + " p3 = random.randrange(4,50,2) #CCI窗口\n", + " p4 = random.randrange(18,40,2) #ATR窗口 \n", + "\n", + " parameter_list.append(p1)\n", + " parameter_list.append(p2)\n", + " parameter_list.append(p3)\n", + " parameter_list.append(p4)\n", + "\n", + " return parameter_list" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def object_func(strategy_avg):\n", + " \"\"\"\n", + " 本函数为优化目标函数,根据随机生成的策略参数,运行回测后自动返回2个结果指标:收益回撤比和夏普比率\n", + " \"\"\"\n", + " # 创建回测引擎对象\n", + " engine = BacktestingEngine()\n", + " engine.set_parameters(\n", + " vt_symbol=\"IF88.CFFEX\",\n", + " interval=\"1m\",\n", + " start=datetime(2018, 9, 1),\n", + " end=datetime(2019, 1,1),\n", + " rate=0,\n", + " slippage=0,\n", + " size=300,\n", + " pricetick=0.2,\n", + " capital=1_000_000,\n", + " )\n", + "\n", + " setting = {'boll_window': strategy_avg[0], #布林带窗口\n", + " 'boll_dev': strategy_avg[1], #布林带通道阈值\n", + " 'cci_window': strategy_avg[2], #CCI窗口\n", + " 'atr_window': strategy_avg[3],} #ATR窗口 \n", + "\n", + " #加载策略 \n", + " #engine.initStrategy(TurtleTradingStrategy, setting)\n", + " engine.add_strategy(BollChannelStrategy, setting)\n", + " engine.load_data()\n", + " engine.run_backtesting()\n", + " engine.calculate_result()\n", + " result = engine.calculate_statistics(Output=False)\n", + "\n", + " return_drawdown_ratio = round(result['return_drawdown_ratio'],2) #收益回撤比\n", + " sharpe_ratio= round(result['sharpe_ratio'],2) #夏普比率\n", + " return return_drawdown_ratio , sharpe_ratio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "#设置优化方向:最大化收益回撤比,最大化夏普比率\n", + "creator.create(\"FitnessMulti\", base.Fitness, weights=(1.0, 1.0)) # 1.0 求最大值;-1.0 求最小值\n", + "creator.create(\"Individual\", list, fitness=creator.FitnessMulti)\n", + "\n", + "def optimize():\n", + " \"\"\"\"\"\" \n", + " toolbox = base.Toolbox() #Toolbox是deap库内置的工具箱,里面包含遗传算法中所用到的各种函数\n", + "\n", + " # 初始化 \n", + " toolbox.register(\"individual\", tools.initIterate, creator.Individual,parameter_generate) # 注册个体:随机生成的策略参数parameter_generate() \n", + " toolbox.register(\"population\", tools.initRepeat, list, toolbox.individual) #注册种群:个体形成种群 \n", + " toolbox.register(\"mate\", tools.cxTwoPoint) #注册交叉:两点交叉 \n", + " toolbox.register(\"mutate\", tools.mutUniformInt,low = 4,up = 40,indpb=0.6) #注册变异:随机生成一定区间内的整数\n", + " toolbox.register(\"evaluate\", object_func) #注册评估:优化目标函数object_func() \n", + " toolbox.register(\"select\", tools.selNSGA2) #注册选择:NSGA-II(带精英策略的非支配排序的遗传算法)\n", + " #pool = multiprocessing.Pool()\n", + " #toolbox.register(\"map\", pool.map)\n", + " #toolbox.register(\"map\", futures.map)\n", + "\n", + " #遗传算法参数设置\n", + " MU = 40 #设置每一代选择的个体数\n", + " LAMBDA = 160 #设置每一代产生的子女数\n", + " pop = toolbox.population(400) #设置族群里面的个体数量\n", + " CXPB, MUTPB, NGEN = 0.5, 0.35,40 #分别为种群内部个体的交叉概率、变异概率、产生种群代数\n", + " hof = tools.ParetoFront() #解的集合:帕累托前沿(非占优最优集)\n", + "\n", + " #解的集合的描述统计信息\n", + " #集合内平均值,标准差,最小值,最大值可以体现集合的收敛程度\n", + " #收敛程度低可以增加算法的迭代次数\n", + " stats = tools.Statistics(lambda ind: ind.fitness.values)\n", + " np.set_printoptions(suppress=True) #对numpy默认输出的科学计数法转换\n", + " stats.register(\"mean\", np.mean, axis=0) #统计目标优化函数结果的平均值\n", + " stats.register(\"std\", np.std, axis=0) #统计目标优化函数结果的标准差\n", + " stats.register(\"min\", np.min, axis=0) #统计目标优化函数结果的最小值\n", + " stats.register(\"max\", np.max, axis=0) #统计目标优化函数结果的最大值\n", + "\n", + " #运行算法\n", + " algorithms.eaMuPlusLambda(pop, toolbox, MU, LAMBDA, CXPB, MUTPB, NGEN, stats,\n", + " halloffame=hof) #esMuPlusLambda是一种基于(μ+λ)选择策略的多目标优化分段遗传算法\n", + "\n", + " return pop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "optimize()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.1" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/trader/run.py b/tests/trader/run.py index 1a402f82..2cda6ae2 100644 --- a/tests/trader/run.py +++ b/tests/trader/run.py @@ -10,9 +10,13 @@ from vnpy.gateway.ib import IbGateway from vnpy.gateway.ctp import CtpGateway from vnpy.gateway.tiger import TigerGateway from vnpy.gateway.oes import OesGateway +from vnpy.gateway.okex import OkexGateway +from vnpy.gateway.huobi import HuobiGateway from vnpy.app.cta_strategy import CtaStrategyApp from vnpy.app.csv_loader import CsvLoaderApp +from vnpy.app.algo_trading import AlgoTradingApp +from vnpy.app.cta_backtester import CtaBacktesterApp def main(): @@ -28,9 +32,13 @@ def main(): main_engine.add_gateway(BitmexGateway) main_engine.add_gateway(TigerGateway) main_engine.add_gateway(OesGateway) + main_engine.add_gateway(OkexGateway) + main_engine.add_gateway(HuobiGateway) main_engine.add_app(CtaStrategyApp) + main_engine.add_app(CtaBacktesterApp) main_engine.add_app(CsvLoaderApp) + main_engine.add_app(AlgoTradingApp) main_window = MainWindow(main_engine, event_engine) main_window.showMaximized() diff --git a/vnpy/api/websocket/websocket_client.py b/vnpy/api/websocket/websocket_client.py index 5fde1d82..0fe28ee7 100644 --- a/vnpy/api/websocket/websocket_client.py +++ b/vnpy/api/websocket/websocket_client.py @@ -48,14 +48,18 @@ class WebsocketClient(object): self.proxy_host = None self.proxy_port = None + self.ping_interval = 60 # seconds # For debugging self._last_sent_text = None self._last_received_text = None - def init(self, host: str, proxy_host: str = "", proxy_port: int = 0): - """""" + def init(self, host: str, proxy_host: str = "", proxy_port: int = 0, ping_interval: int = 60): + """ + :param ping_interval: unit: seconds, type: int + """ self.host = host + self.ping_interval = ping_interval # seconds if proxy_host and proxy_port: self.proxy_host = proxy_host @@ -206,7 +210,7 @@ class WebsocketClient(object): et, ev, tb = sys.exc_info() self.on_error(et, ev, tb) self._reconnect() - for i in range(60): + for i in range(self.ping_interval): if not self._active: break sleep(1) diff --git a/vnpy/app/algo_trading/__init__.py b/vnpy/app/algo_trading/__init__.py new file mode 100644 index 00000000..dba58a66 --- /dev/null +++ b/vnpy/app/algo_trading/__init__.py @@ -0,0 +1,18 @@ +from pathlib import Path + +from vnpy.trader.app import BaseApp + +from .engine import AlgoEngine, APP_NAME +from .template import AlgoTemplate + + +class AlgoTradingApp(BaseApp): + """""" + + app_name = APP_NAME + app_module = __module__ + app_path = Path(__file__).parent + display_name = "算法交易" + engine_class = AlgoEngine + widget_name = "AlgoManager" + icon_name = "algo.ico" diff --git a/vnpy/app/algo_trading/algos/__init__.py b/vnpy/app/algo_trading/algos/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vnpy/app/algo_trading/algos/iceberg_algo.py b/vnpy/app/algo_trading/algos/iceberg_algo.py new file mode 100644 index 00000000..74ab7f44 --- /dev/null +++ b/vnpy/app/algo_trading/algos/iceberg_algo.py @@ -0,0 +1,137 @@ +from vnpy.trader.constant import Offset, Direction +from vnpy.trader.object import TradeData, OrderData, TickData +from vnpy.trader.engine import BaseEngine + +from vnpy.app.algo_trading import AlgoTemplate + + +class IcebergAlgo(AlgoTemplate): + """""" + + display_name = "Iceberg 冰山" + + default_setting = { + "vt_symbol": "", + "direction": [Direction.LONG.value, Direction.SHORT.value], + "price": 0.0, + "volume": 0.0, + "display_volume": 0.0, + "interval": 0, + "offset": [ + Offset.NONE.value, + Offset.OPEN.value, + Offset.CLOSE.value, + Offset.CLOSETODAY.value, + Offset.CLOSEYESTERDAY.value + ] + } + + variables = [ + "traded", + "timer_count", + "vt_orderid" + ] + + def __init__( + self, + algo_engine: BaseEngine, + algo_name: str, + setting: dict + ): + """""" + super().__init__(algo_engine, algo_name, setting) + + # Parameters + self.vt_symbol = setting["vt_symbol"] + self.direction = Direction(setting["direction"]) + self.price = setting["price"] + self.volume = setting["volume"] + self.display_volume = setting["display_volume"] + self.interval = setting["interval"] + self.offset = Offset(setting["offset"]) + + # Variables + self.timer_count = 0 + self.vt_orderid = "" + self.traded = 0 + + self.last_tick = None + + self.subscribe(self.vt_symbol) + self.put_parameters_event() + self.put_variables_event() + + def on_stop(self): + """""" + self.write_log("停止算法") + + def on_tick(self, tick: TickData): + """""" + self.last_tick = tick + + def on_order(self, order: OrderData): + """""" + msg = f"委托号:{order.vt_orderid},委托状态:{order.status.value}" + self.write_log(msg) + + if not order.is_active(): + self.vt_orderid = "" + self.put_variables_event() + + def on_trade(self, trade: TradeData): + """""" + self.traded += trade.volume + + if self.traded >= self.volume: + self.write_log(f"已交易数量:{self.traded},总数量:{self.volume}") + self.stop() + else: + self.put_variables_event() + + def on_timer(self): + """""" + self.timer_count += 1 + + if self.timer_count < self.interval: + self.put_variables_event() + return + + self.timer_count = 0 + + contract = self.get_contract(self.vt_symbol) + if not contract: + return + + # If order already finished, just send new order + if not self.vt_orderid: + order_volume = self.volume - self.traded + order_volume = min(order_volume, self.display_volume) + + if self.direction == Direction.LONG: + self.vt_orderid = self.buy( + self.vt_symbol, + self.price, + order_volume, + offset=self.offset + ) + else: + self.vt_orderid = self.sell( + self.vt_symbol, + self.price, + order_volume, + offset=self.offset + ) + # Otherwise check for cancel + else: + if self.direction == Direction.LONG: + if self.last_tick.ask_price_1 <= self.price: + self.cancel_order(self.vt_orderid) + self.vt_orderid = "" + self.write_log(u"最新Tick卖一价,低于买入委托价格,之前委托可能丢失,强制撤单") + else: + if self.last_tick.bid_price_1 >= self.price: + self.cancel_order(self.vt_orderid) + self.vt_orderid = "" + self.write_log(u"最新Tick买一价,高于卖出委托价格,之前委托可能丢失,强制撤单") + + self.put_variables_event() diff --git a/vnpy/app/algo_trading/algos/sniper_algo.py b/vnpy/app/algo_trading/algos/sniper_algo.py new file mode 100644 index 00000000..5f4b1c56 --- /dev/null +++ b/vnpy/app/algo_trading/algos/sniper_algo.py @@ -0,0 +1,101 @@ +from vnpy.trader.constant import Offset, Direction +from vnpy.trader.object import TradeData, OrderData, TickData +from vnpy.trader.engine import BaseEngine + +from vnpy.app.algo_trading import AlgoTemplate + + +class SniperAlgo(AlgoTemplate): + """""" + + display_name = "Sniper 狙击手" + + default_setting = { + "vt_symbol": "", + "direction": [Direction.LONG.value, Direction.SHORT.value], + "price": 0.0, + "volume": 0.0, + "offset": [ + Offset.NONE.value, + Offset.OPEN.value, + Offset.CLOSE.value, + Offset.CLOSETODAY.value, + Offset.CLOSEYESTERDAY.value + ] + } + + variables = [ + "traded", + "vt_orderid" + ] + + def __init__( + self, + algo_engine: BaseEngine, + algo_name: str, + setting: dict + ): + """""" + super().__init__(algo_engine, algo_name, setting) + + # Parameters + self.vt_symbol = setting["vt_symbol"] + self.direction = Direction(setting["direction"]) + self.price = setting["price"] + self.volume = setting["volume"] + self.offset = Offset(setting["offset"]) + + # Variables + self.vt_orderid = "" + self.traded = 0 + + self.subscribe(self.vt_symbol) + self.put_parameters_event() + self.put_variables_event() + + def on_tick(self, tick: TickData): + """""" + if self.vt_orderid: + self.cancel_all() + return + + if self.direction == Direction.LONG: + if tick.ask_price_1 <= self.price: + order_volume = self.volume - self.traded + order_volume = min(order_volume, tick.ask_volume_1) + + self.vt_orderid = self.buy( + self.vt_symbol, + self.price, + order_volume, + offset=self.offset + ) + else: + if tick.bid_price_1 >= self.price: + order_volume = self.volume - self.traded + order_volume = min(order_volume, tick.bid_volume_1) + + self.vt_orderid = self.sell( + self.vt_symbol, + self.price, + order_volume, + offset=self.offset + ) + + self.put_variables_event() + + def on_order(self, order: OrderData): + """""" + if not order.is_active(): + self.vt_orderid = "" + self.put_variables_event() + + def on_trade(self, trade: TradeData): + """""" + self.traded += trade.volume + + if self.traded >= self.volume: + self.write_log(f"已交易数量:{self.traded},总数量:{self.volume}") + self.stop() + else: + self.put_variables_event() diff --git a/vnpy/app/algo_trading/algos/twap_algo.py b/vnpy/app/algo_trading/algos/twap_algo.py new file mode 100644 index 00000000..ebd24833 --- /dev/null +++ b/vnpy/app/algo_trading/algos/twap_algo.py @@ -0,0 +1,105 @@ +from vnpy.trader.constant import Offset, Direction +from vnpy.trader.object import TradeData +from vnpy.trader.engine import BaseEngine + +from vnpy.app.algo_trading import AlgoTemplate + + +class TwapAlgo(AlgoTemplate): + """""" + + display_name = "TWAP 时间加权平均" + + default_setting = { + "vt_symbol": "", + "direction": [Direction.LONG.value, Direction.SHORT.value], + "price": 0.0, + "volume": 0.0, + "time": 600, + "interval": 60, + "offset": [ + Offset.NONE.value, + Offset.OPEN.value, + Offset.CLOSE.value, + Offset.CLOSETODAY.value, + Offset.CLOSEYESTERDAY.value + ] + } + + variables = [ + "traded", + "order_volume", + "timer_count", + "total_count" + ] + + def __init__( + self, + algo_engine: BaseEngine, + algo_name: str, + setting: dict + ): + """""" + super().__init__(algo_engine, algo_name, setting) + + # Parameters + self.vt_symbol = setting["vt_symbol"] + self.direction = Direction(setting["direction"]) + self.price = setting["price"] + self.volume = setting["volume"] + self.time = setting["time"] + self.interval = setting["interval"] + self.offset = Offset(setting["offset"]) + + # Variables + self.order_volume = self.volume / (self.time / self.interval) + self.timer_count = 0 + self.total_count = 0 + self.traded = 0 + + self.subscribe(self.vt_symbol) + self.put_parameters_event() + self.put_variables_event() + + def on_trade(self, trade: TradeData): + """""" + self.traded += trade.volume + + if self.traded >= self.volume: + self.write_log(f"已交易数量:{self.traded},总数量:{self.volume}") + self.stop() + else: + self.put_variables_event() + + def on_timer(self): + """""" + self.timer_count += 1 + self.total_count += 1 + self.put_variables_event() + + if self.total_count >= self.time: + self.write_log("执行时间已结束,停止算法") + self.stop() + return + + if self.timer_count < self.interval: + return + self.timer_count = 0 + + tick = self.get_tick(self.vt_symbol) + if not tick: + return + + self.cancel_all() + + left_volume = self.volume - self.traded + order_volume = min(self.order_volume, left_volume) + + if self.direction == Direction.LONG: + if tick.ask_price_1 <= self.price: + self.buy(self.vt_symbol, self.price, + order_volume, offset=self.offset) + else: + if tick.bid_price_1 >= self.price: + self.sell(self.vt_symbol, self.price, + order_volume, offset=self.offset) diff --git a/vnpy/app/algo_trading/engine.py b/vnpy/app/algo_trading/engine.py new file mode 100644 index 00000000..8607cc71 --- /dev/null +++ b/vnpy/app/algo_trading/engine.py @@ -0,0 +1,265 @@ + +from vnpy.event import EventEngine, Event +from vnpy.trader.engine import BaseEngine, MainEngine +from vnpy.trader.event import ( + EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE) +from vnpy.trader.constant import (Direction, Offset, OrderType) +from vnpy.trader.object import (SubscribeRequest, OrderRequest) +from vnpy.trader.utility import load_json, save_json + +from .template import AlgoTemplate + + +APP_NAME = "AlgoTrading" + +EVENT_ALGO_LOG = "eAlgoLog" +EVENT_ALGO_SETTING = "eAlgoSetting" +EVENT_ALGO_VARIABLES = "eAlgoVariables" +EVENT_ALGO_PARAMETERS = "eAlgoParameters" + + +class AlgoEngine(BaseEngine): + """""" + setting_filename = "algo_trading_setting.json" + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """Constructor""" + super().__init__(main_engine, event_engine, APP_NAME) + + self.algos = {} + self.symbol_algo_map = {} + self.orderid_algo_map = {} + + self.algo_templates = {} + self.algo_settings = {} + + self.load_algo_template() + self.register_event() + + def init_engine(self): + """""" + self.write_log("算法交易引擎启动") + self.load_algo_setting() + + def load_algo_template(self): + """""" + from .algos.twap_algo import TwapAlgo + from .algos.iceberg_algo import IcebergAlgo + from .algos.sniper_algo import SniperAlgo + + self.add_algo_template(TwapAlgo) + self.add_algo_template(IcebergAlgo) + self.add_algo_template(SniperAlgo) + + def add_algo_template(self, template: AlgoTemplate): + """""" + self.algo_templates[template.__name__] = template + + def load_algo_setting(self): + """""" + self.algo_settings = load_json(self.setting_filename) + + for setting_name, setting in self.algo_settings.items(): + self.put_setting_event(setting_name, setting) + + self.write_log("算法配置载入成功") + + def save_algo_setting(self): + """""" + save_json(self.setting_filename, self.algo_settings) + + def register_event(self): + """""" + self.event_engine.register(EVENT_TICK, self.process_tick_event) + self.event_engine.register(EVENT_TIMER, self.process_timer_event) + self.event_engine.register(EVENT_ORDER, self.process_order_event) + self.event_engine.register(EVENT_TRADE, self.process_trade_event) + + def process_tick_event(self, event: Event): + """""" + tick = event.data + + algos = self.symbol_algo_map.get(tick.vt_symbol, None) + if algos: + for algo in algos: + algo.update_tick(tick) + + def process_timer_event(self, event: Event): + """""" + for algo in self.algos.values(): + algo.update_timer() + + def process_trade_event(self, event: Event): + """""" + trade = event.data + + algo = self.orderid_algo_map.get(trade.vt_orderid, None) + if algo: + algo.update_trade(trade) + + def process_order_event(self, event: Event): + """""" + order = event.data + + algo = self.orderid_algo_map.get(order.vt_orderid, None) + if algo: + algo.update_order(order) + + def start_algo(self, setting: dict): + """""" + template_name = setting["template_name"] + algo_template = self.algo_templates[template_name] + + algo = algo_template.new(self, setting) + algo.start() + + self.algos[algo.algo_name] = algo + return algo.algo_name + + def stop_algo(self, algo_name: str): + """""" + algo = self.algos.get(algo_name, None) + if algo: + algo.stop() + self.algos.pop(algo_name) + + def stop_all(self): + """""" + for algo_name in list(self.algos.keys()): + self.stop_algo(algo_name) + + def subscribe(self, algo: AlgoTemplate, vt_symbol: str): + """""" + contract = self.main_engine.get_contract(vt_symbol) + if not contract: + self.write_log(f'订阅行情失败,找不到合约:{vt_symbol}', algo) + return + + algos = self.symbol_algo_map.setdefault(vt_symbol, set()) + + if not algos: + req = SubscribeRequest( + symbol=contract.symbol, + exchange=contract.exchange + ) + self.main_engine.subscribe(req, contract.gateway_name) + + algos.add(algo) + + def send_order( + self, + algo: AlgoTemplate, + vt_symbol: str, + direction: Direction, + price: float, + volume: float, + order_type: OrderType, + offset: Offset + ): + """""" + contract = self.main_engine.get_contract(vt_symbol) + if not contract: + self.write_log(f'委托下单失败,找不到合约:{vt_symbol}', algo) + return + + req = OrderRequest( + symbol=contract.symbol, + exchange=contract.exchange, + direction=direction, + type=order_type, + volume=volume, + price=price, + offset=offset + ) + vt_orderid = self.main_engine.send_order(req, contract.gateway_name) + + self.orderid_algo_map[vt_orderid] = algo + return vt_orderid + + def cancel_order(self, algo: AlgoTemplate, vt_orderid: str): + """""" + order = self.main_engine.get_order(vt_orderid) + + if not order: + self.write_log(f"委托撤单失败,找不到委托:{vt_orderid}", algo) + return + + req = order.create_cancel_request() + self.main_engine.cancel_order(req, order.gateway_name) + + def get_tick(self, algo: AlgoTemplate, vt_symbol: str): + """""" + tick = self.main_engine.get_tick(vt_symbol) + + if not tick: + self.write_log(f"查询行情失败,找不到行情:{vt_symbol}", algo) + + return tick + + def get_contract(self, algo: AlgoTemplate, vt_symbol: str): + """""" + contract = self.main_engine.get_contract(vt_symbol) + + if not contract: + self.write_log(f"查询合约失败,找不到合约:{vt_symbol}", algo) + + return contract + + def write_log(self, msg: str, algo: AlgoTemplate = None): + """""" + if algo: + msg = f"{algo.algo_name}:{msg}" + + event = Event(EVENT_ALGO_LOG) + event.data = msg + self.event_engine.put(event) + + def put_setting_event(self, setting_name: str, setting: dict): + """""" + event = Event(EVENT_ALGO_SETTING) + event.data = { + "setting_name": setting_name, + "setting": setting + } + self.event_engine.put(event) + + def update_algo_setting(self, setting_name: str, setting: dict): + """""" + self.algo_settings[setting_name] = setting + + self.save_algo_setting() + + self.put_setting_event(setting_name, setting) + + def remove_algo_setting(self, setting_name: str): + """""" + if setting_name not in self.algo_settings: + return + self.algo_settings.pop(setting_name) + + event = Event(EVENT_ALGO_SETTING) + event.data = { + "setting_name": setting_name, + "setting": None + } + self.event_engine.put(event) + + self.save_algo_setting() + + def put_parameters_event(self, algo: AlgoTemplate, parameters: dict): + """""" + event = Event(EVENT_ALGO_PARAMETERS) + event.data = { + "algo_name": algo.algo_name, + "parameters": parameters + } + self.event_engine.put(event) + + def put_variables_event(self, algo: AlgoTemplate, variables: dict): + """""" + event = Event(EVENT_ALGO_VARIABLES) + event.data = { + "algo_name": algo.algo_name, + "variables": variables + } + self.event_engine.put(event) diff --git a/vnpy/app/algo_trading/template.py b/vnpy/app/algo_trading/template.py new file mode 100644 index 00000000..3b6620b7 --- /dev/null +++ b/vnpy/app/algo_trading/template.py @@ -0,0 +1,193 @@ +from vnpy.trader.engine import BaseEngine +from vnpy.trader.object import TickData, OrderData, TradeData +from vnpy.trader.constant import OrderType, Offset, Direction +from vnpy.trader.utility import virtual + + +class AlgoTemplate: + """""" + + _count = 0 + display_name = "" + default_setting = {} + variables = [] + + def __init__( + self, + algo_engine: BaseEngine, + algo_name: str, + setting: dict + ): + """Constructor""" + self.algo_engine = algo_engine + self.algo_name = algo_name + + self.active = False + self.active_orders = {} # vt_orderid:order + + self.variables.insert(0, "active") + + @classmethod + def new(cls, algo_engine: BaseEngine, setting: dict): + """Create new algo instance""" + cls._count += 1 + algo_name = f"{cls.__name__}_{cls._count}" + algo = cls(algo_engine, algo_name, setting) + return algo + + def update_tick(self, tick: TickData): + """""" + if self.active: + self.on_tick(tick) + + def update_order(self, order: OrderData): + """""" + if self.active: + if order.is_active(): + self.active_orders[order.vt_orderid] = order + elif order.vt_orderid in self.active_orders: + self.active_orders.pop(order.vt_orderid) + + self.on_order(order) + + def update_trade(self, trade: TradeData): + """""" + if self.active: + self.on_trade(trade) + + def update_timer(self): + """""" + if self.active: + self.on_timer() + + def on_start(self): + """""" + pass + + @virtual + def on_stop(self): + """""" + pass + + @virtual + def on_tick(self, tick: TickData): + """""" + pass + + @virtual + def on_order(self, order: OrderData): + """""" + pass + + @virtual + def on_trade(self, trade: TradeData): + """""" + pass + + @virtual + def on_timer(self): + """""" + pass + + def start(self): + """""" + self.active = True + self.on_start() + self.put_variables_event() + + def stop(self): + """""" + self.active = False + self.cancel_all() + self.on_stop() + self.put_variables_event() + + self.write_log("停止算法") + + def subscribe(self, vt_symbol): + """""" + self.algo_engine.subscribe(self, vt_symbol) + + def buy( + self, + vt_symbol, + price, + volume, + order_type: OrderType = OrderType.LIMIT, + offset: Offset = Offset.NONE + ): + """""" + msg = f"委托买入{vt_symbol}:{volume}@{price}" + self.write_log(msg) + + return self.algo_engine.send_order( + self, + vt_symbol, + Direction.LONG, + price, + volume, + order_type, + offset + ) + + def sell( + self, + vt_symbol, + price, + volume, + order_type: OrderType = OrderType.LIMIT, + offset: Offset = Offset.NONE + ): + """""" + msg = f"委托卖出{vt_symbol}:{volume}@{price}" + self.write_log(msg) + + return self.algo_engine.send_order( + self, + vt_symbol, + Direction.SHORT, + price, + volume, + order_type, + offset + ) + + def cancel_order(self, vt_orderid: str): + """""" + self.algo_engine.cancel_order(self, vt_orderid) + + def cancel_all(self): + """""" + if not self.active_orders: + return + + for vt_orderid in self.active_orders.keys(): + self.cancel_order(vt_orderid) + + def get_tick(self, vt_symbol: str): + """""" + return self.algo_engine.get_tick(self, vt_symbol) + + def get_contract(self, vt_symbol: str): + """""" + return self.algo_engine.get_contract(self, vt_symbol) + + def write_log(self, msg: str): + """""" + self.algo_engine.write_log(msg, self) + + def put_parameters_event(self): + """""" + parameters = {} + for name in self.default_setting.keys(): + parameters[name] = getattr(self, name) + + self.algo_engine.put_parameters_event(self, parameters) + + def put_variables_event(self): + """""" + variables = {} + for name in self.variables: + variables[name] = getattr(self, name) + + self.algo_engine.put_variables_event(self, variables) diff --git a/vnpy/app/algo_trading/ui/__init__.py b/vnpy/app/algo_trading/ui/__init__.py new file mode 100644 index 00000000..9ac801bb --- /dev/null +++ b/vnpy/app/algo_trading/ui/__init__.py @@ -0,0 +1 @@ +from .widget import AlgoManager diff --git a/vnpy/app/algo_trading/ui/algo.ico b/vnpy/app/algo_trading/ui/algo.ico new file mode 100644 index 00000000..83114df8 Binary files /dev/null and b/vnpy/app/algo_trading/ui/algo.ico differ diff --git a/vnpy/app/algo_trading/ui/display.py b/vnpy/app/algo_trading/ui/display.py new file mode 100644 index 00000000..4f7b9502 --- /dev/null +++ b/vnpy/app/algo_trading/ui/display.py @@ -0,0 +1,16 @@ +NAME_DISPLAY_MAP = { + "vt_symbol": "本地代码", + "direction": "方向", + "price": "价格", + "volume": "数量", + "time": "执行时间(秒)", + "interval": "每轮间隔(秒)", + "offset": "开平", + "active": "算法状态", + "traded": "成交数量", + "order_volume": "单笔委托", + "timer_count": "本轮读秒", + "total_count": "累计读秒", + "template_name": "算法模板", + "display_volume": "挂出数量" +} diff --git a/vnpy/app/algo_trading/ui/widget.py b/vnpy/app/algo_trading/ui/widget.py new file mode 100644 index 00000000..111d599a --- /dev/null +++ b/vnpy/app/algo_trading/ui/widget.py @@ -0,0 +1,575 @@ +""" +Widget for algo trading. +""" + +from functools import partial +from datetime import datetime + +from vnpy.event import EventEngine, Event +from vnpy.trader.engine import MainEngine +from vnpy.trader.ui import QtWidgets, QtCore + +from ..engine import ( + AlgoEngine, + AlgoTemplate, + APP_NAME, + EVENT_ALGO_LOG, + EVENT_ALGO_PARAMETERS, + EVENT_ALGO_VARIABLES, + EVENT_ALGO_SETTING +) +from .display import NAME_DISPLAY_MAP + + +class AlgoWidget(QtWidgets.QWidget): + """ + Start connection of a certain gateway. + """ + + def __init__( + self, + algo_engine: AlgoEngine, + algo_template: AlgoTemplate + ): + """""" + super().__init__() + + self.algo_engine = algo_engine + self.template_name = algo_template.__name__ + self.default_setting = algo_template.default_setting + + self.widgets = {} + + self.init_ui() + + def init_ui(self): + """ + Initialize line edits and form layout based on setting. + """ + self.setMaximumWidth(400) + + form = QtWidgets.QFormLayout() + + for field_name, field_value in self.default_setting.items(): + field_type = type(field_value) + + if field_type == list: + widget = QtWidgets.QComboBox() + widget.addItems(field_value) + else: + widget = QtWidgets.QLineEdit() + + display_name = NAME_DISPLAY_MAP.get(field_name, field_name) + + form.addRow(display_name, widget) + self.widgets[field_name] = (widget, field_type) + + start_algo_button = QtWidgets.QPushButton("启动算法") + start_algo_button.clicked.connect(self.start_algo) + form.addRow(start_algo_button) + + form.addRow(QtWidgets.QLabel("")) + + self.setting_name_line = QtWidgets.QLineEdit() + form.addRow("配置名称", self.setting_name_line) + + save_setting_button = QtWidgets.QPushButton("保存配置") + save_setting_button.clicked.connect(self.save_setting) + form.addRow(save_setting_button) + + self.setLayout(form) + + def get_setting(self): + """ + Get setting value from line edits. + """ + setting = {"template_name": self.template_name} + + for field_name, tp in self.widgets.items(): + widget, field_type = tp + if field_type == list: + field_value = str(widget.currentText()) + else: + try: + field_value = field_type(widget.text()) + except ValueError: + display_name = NAME_DISPLAY_MAP.get(field_name, field_name) + QtWidgets.QMessageBox.warning( + self, + "参数错误", + f"{display_name}参数类型应为{field_type},请检查!" + ) + return None + + setting[field_name] = field_value + + return setting + + def start_algo(self): + """ + Start algo trading. + """ + setting = self.get_setting() + if setting: + self.algo_engine.start_algo(setting) + + def update_setting(self, setting_name: str, setting: dict): + """ + Update setting into widgets. + """ + self.setting_name_line.setText(setting_name) + + for name, tp in self.widgets.items(): + widget, _ = tp + value = setting[name] + + if isinstance(widget, QtWidgets.QLineEdit): + widget.setText(str(value)) + elif isinstance(widget, QtWidgets.QComboBox): + ix = widget.findText(value) + widget.setCurrentIndex(ix) + + def save_setting(self): + """ + Save algo setting + """ + setting_name = self.setting_name_line.text() + if not setting_name: + return + + setting = self.get_setting() + if setting: + self.algo_engine.update_algo_setting(setting_name, setting) + + +class AlgoMonitor(QtWidgets.QTableWidget): + """""" + parameters_signal = QtCore.pyqtSignal(Event) + variables_signal = QtCore.pyqtSignal(Event) + + def __init__( + self, + algo_engine: AlgoEngine, + event_engine: EventEngine, + mode_active: bool + ): + """""" + super().__init__() + + self.algo_engine = algo_engine + self.event_engine = event_engine + self.mode_active = mode_active + + self.algo_cells = {} + + self.init_ui() + self.register_event() + + def init_ui(self): + """""" + labels = [ + "", + "算法", + "参数", + "状态" + ] + self.setColumnCount(len(labels)) + self.setHorizontalHeaderLabels(labels) + self.verticalHeader().setVisible(False) + self.setEditTriggers(self.NoEditTriggers) + + self.verticalHeader().setSectionResizeMode( + QtWidgets.QHeaderView.ResizeToContents + ) + + for column in range(2, 4): + self.horizontalHeader().setSectionResizeMode( + column, + QtWidgets.QHeaderView.Stretch + ) + self.setWordWrap(True) + + if not self.mode_active: + self.hideColumn(0) + + def register_event(self): + """""" + self.parameters_signal.connect(self.process_parameters_event) + self.variables_signal.connect(self.process_variables_event) + + self.event_engine.register( + EVENT_ALGO_PARAMETERS, self.parameters_signal.emit) + self.event_engine.register( + EVENT_ALGO_VARIABLES, self.variables_signal.emit) + + def process_parameters_event(self, event): + """""" + data = event.data + algo_name = data["algo_name"] + parameters = data["parameters"] + + cells = self.get_algo_cells(algo_name) + text = to_text(parameters) + cells["parameters"].setText(text) + + def process_variables_event(self, event): + """""" + data = event.data + algo_name = data["algo_name"] + variables = data["variables"] + + cells = self.get_algo_cells(algo_name) + variables_cell = cells["variables"] + text = to_text(variables) + variables_cell.setText(text) + + row = self.row(variables_cell) + active = variables["active"] + + if self.mode_active: + if active: + self.showRow(row) + else: + self.hideRow(row) + else: + if active: + self.hideRow(row) + else: + self.showRow(row) + + def stop_algo(self, algo_name: str): + """""" + self.algo_engine.stop_algo(algo_name) + + def get_algo_cells(self, algo_name: str): + """""" + cells = self.algo_cells.get(algo_name, None) + + if not cells: + stop_func = partial(self.stop_algo, algo_name=algo_name) + stop_button = QtWidgets.QPushButton("停止") + stop_button.clicked.connect(stop_func) + + name_cell = QtWidgets.QTableWidgetItem(algo_name) + parameters_cell = QtWidgets.QTableWidgetItem() + variables_cell = QtWidgets.QTableWidgetItem() + + self.insertRow(0) + self.setCellWidget(0, 0, stop_button) + self.setItem(0, 1, name_cell) + self.setItem(0, 2, parameters_cell) + self.setItem(0, 3, variables_cell) + + cells = { + "name": name_cell, + "parameters": parameters_cell, + "variables": variables_cell + } + self.algo_cells[algo_name] = cells + + return cells + + +class ActiveAlgoMonitor(AlgoMonitor): + """ + Monitor for active algos. + """ + + def __init__(self, algo_engine: AlgoEngine, event_engine: EventEngine): + """""" + super().__init__(algo_engine, event_engine, True) + + +class InactiveAlgoMonitor(AlgoMonitor): + """ + Monitor for inactive algos. + """ + + def __init__(self, algo_engine: AlgoEngine, event_engine: EventEngine): + """""" + super().__init__(algo_engine, event_engine, False) + + +class SettingMonitor(QtWidgets.QTableWidget): + """""" + setting_signal = QtCore.pyqtSignal(Event) + use_signal = QtCore.pyqtSignal(dict) + + def __init__(self, algo_engine: AlgoEngine, event_engine: EventEngine): + """""" + super().__init__() + + self.algo_engine = algo_engine + self.event_engine = event_engine + + self.settings = {} + self.setting_cells = {} + + self.init_ui() + self.register_event() + + def init_ui(self): + """""" + labels = [ + "", + "", + "名称", + "配置" + ] + self.setColumnCount(len(labels)) + self.setHorizontalHeaderLabels(labels) + self.verticalHeader().setVisible(False) + self.setEditTriggers(self.NoEditTriggers) + + self.verticalHeader().setSectionResizeMode( + QtWidgets.QHeaderView.ResizeToContents + ) + + self.horizontalHeader().setSectionResizeMode( + 3, + QtWidgets.QHeaderView.Stretch + ) + self.setWordWrap(True) + + def register_event(self): + """""" + self.setting_signal.connect(self.process_setting_event) + + self.event_engine.register( + EVENT_ALGO_SETTING, self.setting_signal.emit) + + def process_setting_event(self, event): + """""" + data = event.data + setting_name = data["setting_name"] + setting = data["setting"] + cells = self.get_setting_cells(setting_name) + + if setting: + self.settings[setting_name] = setting + + cells["setting"].setText(to_text(setting)) + else: + if setting_name in self.settings: + self.settings.pop(setting_name) + + row = self.row(cells["setting"]) + self.removeRow(row) + + self.setting_cells.pop(setting_name) + + def get_setting_cells(self, setting_name: str): + """""" + cells = self.setting_cells.get(setting_name, None) + + if not cells: + use_func = partial(self.use_setting, setting_name=setting_name) + use_button = QtWidgets.QPushButton("使用") + use_button.clicked.connect(use_func) + + remove_func = partial(self.remove_setting, + setting_name=setting_name) + remove_button = QtWidgets.QPushButton("移除") + remove_button.clicked.connect(remove_func) + + name_cell = QtWidgets.QTableWidgetItem(setting_name) + setting_cell = QtWidgets.QTableWidgetItem() + + self.insertRow(0) + self.setCellWidget(0, 0, use_button) + self.setCellWidget(0, 1, remove_button) + self.setItem(0, 2, name_cell) + self.setItem(0, 3, setting_cell) + + cells = { + "name": name_cell, + "setting": setting_cell + } + self.setting_cells[setting_name] = cells + + return cells + + def use_setting(self, setting_name: str): + """""" + setting = self.settings[setting_name] + setting["setting_name"] = setting_name + self.use_signal.emit(setting) + + def remove_setting(self, setting_name: str): + """""" + self.algo_engine.remove_algo_setting(setting_name) + + +class LogMonitor(QtWidgets.QTableWidget): + """""" + signal = QtCore.pyqtSignal(Event) + + def __init__(self, event_engine: EventEngine): + """""" + super().__init__() + + self.event_engine = event_engine + + self.init_ui() + self.register_event() + + def init_ui(self): + """""" + labels = [ + "时间", + "信息" + ] + self.setColumnCount(len(labels)) + self.setHorizontalHeaderLabels(labels) + self.verticalHeader().setVisible(False) + self.setEditTriggers(self.NoEditTriggers) + + self.verticalHeader().setSectionResizeMode( + QtWidgets.QHeaderView.ResizeToContents + ) + + self.horizontalHeader().setSectionResizeMode( + 1, + QtWidgets.QHeaderView.Stretch + ) + self.setWordWrap(True) + + def register_event(self): + """""" + self.signal.connect(self.process_log_event) + + self.event_engine.register(EVENT_ALGO_LOG, self.signal.emit) + + def process_log_event(self, event): + """""" + msg = event.data + timestamp = datetime.now().strftime("%H:%M:%S") + + timestamp_cell = QtWidgets.QTableWidgetItem(timestamp) + msg_cell = QtWidgets.QTableWidgetItem(msg) + + self.insertRow(0) + self.setItem(0, 0, timestamp_cell) + self.setItem(0, 1, msg_cell) + + +class AlgoManager(QtWidgets.QWidget): + """""" + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """""" + super().__init__() + + self.main_engine = main_engine + self.event_engine = event_engine + self.algo_engine = main_engine.get_engine(APP_NAME) + + self.algo_widgets = {} + + self.init_ui() + self.algo_engine.init_engine() + + def init_ui(self): + """""" + self.setWindowTitle("算法交易") + + # Left side control widgets + self.template_combo = QtWidgets.QComboBox() + self.template_combo.currentIndexChanged.connect(self.show_algo_widget) + + form = QtWidgets.QFormLayout() + form.addRow("算法", self.template_combo) + widget = QtWidgets.QWidget() + widget.setLayout(form) + + vbox = QtWidgets.QVBoxLayout() + vbox.addWidget(widget) + + for algo_template in self.algo_engine.algo_templates.values(): + widget = AlgoWidget(self.algo_engine, algo_template) + vbox.addWidget(widget) + + template_name = algo_template.__name__ + display_name = algo_template.display_name + + self.algo_widgets[template_name] = widget + self.template_combo.addItem(display_name, template_name) + + vbox.addStretch() + + stop_all_button = QtWidgets.QPushButton("全部停止") + stop_all_button.setFixedHeight(stop_all_button.sizeHint().height() * 2) + stop_all_button.clicked.connect(self.algo_engine.stop_all) + + vbox.addWidget(stop_all_button) + + # Right side monitor widgets + active_algo_monitor = ActiveAlgoMonitor( + self.algo_engine, self.event_engine + ) + inactive_algo_monitor = InactiveAlgoMonitor( + self.algo_engine, self.event_engine + ) + tab1 = QtWidgets.QTabWidget() + tab1.addTab(active_algo_monitor, "执行中") + tab1.addTab(inactive_algo_monitor, "已结束") + + log_monitor = LogMonitor(self.event_engine) + tab2 = QtWidgets.QTabWidget() + tab2.addTab(log_monitor, "日志") + + setting_monitor = SettingMonitor(self.algo_engine, self.event_engine) + setting_monitor.use_signal.connect(self.use_setting) + tab3 = QtWidgets.QTabWidget() + tab3.addTab(setting_monitor, "配置") + + grid = QtWidgets.QGridLayout() + grid.addWidget(tab1, 0, 0, 1, 2) + grid.addWidget(tab2, 1, 0) + grid.addWidget(tab3, 1, 1) + + hbox2 = QtWidgets.QHBoxLayout() + hbox2.addLayout(vbox) + hbox2.addLayout(grid) + self.setLayout(hbox2) + + self.show_algo_widget() + + def show_algo_widget(self): + """""" + ix = self.template_combo.currentIndex() + current_name = self.template_combo.itemData(ix) + + for template_name, widget in self.algo_widgets.items(): + if template_name == current_name: + widget.show() + else: + widget.hide() + + def use_setting(self, setting: dict): + """""" + setting_name = setting["setting_name"] + template_name = setting["template_name"] + + widget = self.algo_widgets[template_name] + widget.update_setting(setting_name, setting) + + ix = self.template_combo.findData(template_name) + self.template_combo.setCurrentIndex(ix) + self.show_algo_widget() + + def show(self): + """""" + self.showMaximized() + + +def to_text(data: dict): + """ + Convert dict data into string. + """ + buf = [] + for key, value in data.items(): + key = NAME_DISPLAY_MAP.get(key, key) + buf.append(f"{key}:{value}") + text = ",".join(buf) + return text diff --git a/vnpy/app/csv_loader/engine.py b/vnpy/app/csv_loader/engine.py index af4a49a8..fa88afe7 100644 --- a/vnpy/app/csv_loader/engine.py +++ b/vnpy/app/csv_loader/engine.py @@ -23,9 +23,11 @@ Sample csv file: import csv from datetime import datetime +from peewee import chunked + from vnpy.event import EventEngine from vnpy.trader.constant import Exchange, Interval -from vnpy.trader.database import DbBarData +from vnpy.trader.database import DbBarData, DB from vnpy.trader.engine import BaseEngine, MainEngine @@ -39,17 +41,17 @@ class CsvLoaderEngine(BaseEngine): """""" super().__init__(main_engine, event_engine, APP_NAME) - self.file_path: str = '' + self.file_path: str = "" self.symbol: str = "" self.exchange: Exchange = Exchange.SSE self.interval: Interval = Interval.MINUTE - self.datetime_head: str = '' - self.open_head: str = '' - self.close_head: str = '' - self.low_head: str = '' - self.high_head: str = '' - self.volume_head: str = '' + self.datetime_head: str = "" + self.open_head: str = "" + self.close_head: str = "" + self.low_head: str = "" + self.high_head: str = "" + self.volume_head: str = "" def load( self, @@ -72,33 +74,40 @@ class CsvLoaderEngine(BaseEngine): end = None count = 0 - with open(file_path, 'rt') as f: + with open(file_path, "rt") as f: reader = csv.DictReader(f) + db_bars = [] + for item in reader: - db_bar = DbBarData() + dt = datetime.strptime(item[datetime_head], datetime_format) - db_bar.symbol = symbol - db_bar.exchange = exchange.value - db_bar.datetime = datetime.strptime( - item[datetime_head], datetime_format - ) - db_bar.interval = interval.value - db_bar.volume = item[volume_head] - db_bar.open_price = item[open_head] - db_bar.high_price = item[high_head] - db_bar.low_price = item[low_head] - db_bar.close_price = item[close_head] - db_bar.vt_symbol = vt_symbol - db_bar.gateway_name = "DB" + db_bar = { + "symbol": symbol, + "exchange": exchange.value, + "datetime": dt, + "interval": interval.value, + "volume": item[volume_head], + "open_price": item[open_head], + "high_price": item[high_head], + "low_price": item[low_head], + "close_price": item[close_head], + "vt_symbol": vt_symbol, + "gateway_name": "DB" + } - db_bar.replace() + db_bars.append(db_bar) # do some statistics count += 1 if not start: - start = db_bar.datetime + start = db_bar["datetime"] - end = db_bar.datetime + end = db_bar["datetime"] + + # Insert into DB + with DB.atomic(): + for batch in chunked(db_bars, 50): + DbBarData.insert_many(batch).on_conflict_replace().execute() return start, end, count diff --git a/vnpy/app/csv_loader/ui/widget.py b/vnpy/app/csv_loader/ui/widget.py index 6e89741c..79ef9e0c 100644 --- a/vnpy/app/csv_loader/ui/widget.py +++ b/vnpy/app/csv_loader/ui/widget.py @@ -90,7 +90,8 @@ class CsvLoaderWidget(QtWidgets.QWidget): def select_file(self): """""" - result: str = QtWidgets.QFileDialog.getOpenFileName(self) + result: str = QtWidgets.QFileDialog.getOpenFileName( + self, filter="CSV (*.csv)") filename = result[0] if filename: self.file_edit.setText(filename) diff --git a/vnpy/app/cta_backtester/__init__.py b/vnpy/app/cta_backtester/__init__.py new file mode 100644 index 00000000..c589e02f --- /dev/null +++ b/vnpy/app/cta_backtester/__init__.py @@ -0,0 +1,17 @@ +from pathlib import Path + +from vnpy.trader.app import BaseApp + +from .engine import BacktesterEngine, APP_NAME + + +class CtaBacktesterApp(BaseApp): + """""" + + app_name = APP_NAME + app_module = __module__ + app_path = Path(__file__).parent + display_name = "CTA回测" + engine_class = BacktesterEngine + widget_name = "BacktesterManager" + icon_name = "backtester.ico" diff --git a/vnpy/app/cta_backtester/engine.py b/vnpy/app/cta_backtester/engine.py new file mode 100644 index 00000000..4dc80589 --- /dev/null +++ b/vnpy/app/cta_backtester/engine.py @@ -0,0 +1,200 @@ +import os +import importlib +import traceback +from datetime import datetime +from threading import Thread +from pathlib import Path + +from vnpy.event import Event, EventEngine +from vnpy.trader.engine import BaseEngine, MainEngine +from vnpy.app.cta_strategy import ( + CtaTemplate, + BacktestingEngine +) + + +APP_NAME = "CtaBacktester" + +EVENT_BACKTESTER_LOG = "eBacktesterLog" +EVENT_BACKTESTER_BACKTESTING_FINISHED = "eBacktesterBacktestingFinished" +EVENT_BACKTESTER_OPTIMIZATION_FINISHED = "eBacktesterOptimizationFinished" + + +class BacktesterEngine(BaseEngine): + """ + For running CTA strategy backtesting. + """ + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """""" + super().__init__(main_engine, event_engine, APP_NAME) + + self.classes = {} + self.backtesting_engine = None + self.thread = None + + self.result_df = None + self.result_statistics = None + + self.load_strategy_class() + + def init_engine(self): + """""" + self.write_log("初始化CTA回测引擎") + + self.backtesting_engine = BacktestingEngine() + # Redirect log from backtesting engine outside. + self.backtesting_engine.output = self.write_log + + self.write_log("策略文件加载完成") + + def write_log(self, msg: str): + """""" + event = Event(EVENT_BACKTESTER_LOG) + event.data = msg + self.event_engine.put(event) + + def load_strategy_class(self): + """ + Load strategy class from source code. + """ + app_path = Path(__file__).parent.parent + path1 = app_path.joinpath("cta_strategy", "strategies") + self.load_strategy_class_from_folder( + path1, "vnpy.app.cta_strategy.strategies") + + path2 = Path.cwd().joinpath("strategies") + self.load_strategy_class_from_folder(path2, "strategies") + + def load_strategy_class_from_folder(self, path: Path, module_name: str = ""): + """ + Load strategy class from certain folder. + """ + for dirpath, dirnames, filenames in os.walk(path): + for filename in filenames: + if filename.endswith(".py"): + strategy_module_name = ".".join( + [module_name, filename.replace(".py", "")]) + self.load_strategy_class_from_module(strategy_module_name) + + def load_strategy_class_from_module(self, module_name: str): + """ + Load strategy class from module file. + """ + try: + module = importlib.import_module(module_name) + + for name in dir(module): + value = getattr(module, name) + if (isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate): + self.classes[value.__name__] = value + except: # noqa + msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}" + self.write_log(msg) + + def get_strategy_class_names(self): + """""" + return list(self.classes.keys()) + + def run_backtesting( + self, + class_name: str, + vt_symbol: str, + interval: str, + start: datetime, + end: datetime, + rate: float, + slippage: float, + size: int, + pricetick: float, + capital: int, + setting: dict + ): + """""" + self.result_df = None + self.result_statistics = None + + engine = self.backtesting_engine + engine.clear_data() + + engine.set_parameters( + vt_symbol=vt_symbol, + interval=interval, + start=start, + end=end, + rate=rate, + slippage=slippage, + size=size, + pricetick=pricetick, + capital=capital + ) + + strategy_class = self.classes[class_name] + engine.add_strategy( + strategy_class, + setting + ) + + engine.load_data() + engine.run_backtesting() + self.result_df = engine.calculate_result() + self.result_statistics = engine.calculate_statistics(output=False) + + # Clear thread object handler. + self.thread = None + + # Put backtesting done event + event = Event(EVENT_BACKTESTER_BACKTESTING_FINISHED) + self.event_engine.put(event) + + def start_backtesting( + self, + class_name: str, + vt_symbol: str, + interval: str, + start: datetime, + end: datetime, + rate: float, + slippage: float, + size: int, + pricetick: float, + capital: int, + setting: dict + ): + if self.thread: + self.write_log("已有回测在运行中,请等待完成") + return False + + self.write_log("-" * 40) + self.thread = Thread( + target=self.run_backtesting, + args=( + class_name, + vt_symbol, + interval, + start, + end, + rate, + slippage, + size, + pricetick, + capital, + setting + ) + ) + self.thread.start() + + return True + + def get_result_df(self): + """""" + return self.result_df + + def get_result_statistics(self): + """""" + return self.result_statistics + + def get_default_setting(self, class_name: str): + """""" + strategy_class = self.classes[class_name] + return strategy_class.get_class_parameters() diff --git a/vnpy/app/cta_backtester/ui/__init__.py b/vnpy/app/cta_backtester/ui/__init__.py new file mode 100644 index 00000000..a02dafb3 --- /dev/null +++ b/vnpy/app/cta_backtester/ui/__init__.py @@ -0,0 +1 @@ +from .widget import BacktesterManager diff --git a/vnpy/app/cta_backtester/ui/backtester.ico b/vnpy/app/cta_backtester/ui/backtester.ico new file mode 100644 index 00000000..647b8ca9 Binary files /dev/null and b/vnpy/app/cta_backtester/ui/backtester.ico differ diff --git a/vnpy/app/cta_backtester/ui/widget.py b/vnpy/app/cta_backtester/ui/widget.py new file mode 100644 index 00000000..d8cde991 --- /dev/null +++ b/vnpy/app/cta_backtester/ui/widget.py @@ -0,0 +1,476 @@ +from datetime import datetime, timedelta + +import pyqtgraph as pg +import numpy as np + +from vnpy.event import Event, EventEngine +from vnpy.trader.ui import QtCore, QtWidgets, QtGui +from vnpy.trader.engine import MainEngine +from vnpy.trader.constant import Interval + +from ..engine import ( + APP_NAME, + EVENT_BACKTESTER_LOG, + EVENT_BACKTESTER_BACKTESTING_FINISHED, + EVENT_BACKTESTER_OPTIMIZATION_FINISHED +) + + +class BacktesterManager(QtWidgets.QWidget): + """""" + + signal_log = QtCore.pyqtSignal(Event) + signal_backtesting_finished = QtCore.pyqtSignal(Event) + signal_optimization_finished = QtCore.pyqtSignal(Event) + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """""" + super().__init__() + + self.main_engine = main_engine + self.event_engine = event_engine + + self.backtester_engine = main_engine.get_engine(APP_NAME) + self.class_names = [] + self.settings = {} + + self.init_strategy_settings() + self.init_ui() + self.register_event() + self.backtester_engine.init_engine() + + def init_strategy_settings(self): + """""" + self.class_names = self.backtester_engine.get_strategy_class_names() + + for class_name in self.class_names: + setting = self.backtester_engine.get_default_setting(class_name) + self.settings[class_name] = setting + + def init_ui(self): + """""" + self.setWindowTitle("CTA回测") + + # Setting Part + self.class_combo = QtWidgets.QComboBox() + self.class_combo.addItems(self.class_names) + + self.symbol_line = QtWidgets.QLineEdit("IF88.CFFEX") + + self.interval_combo = QtWidgets.QComboBox() + for inteval in Interval: + self.interval_combo.addItem(inteval.value) + + end_dt = datetime.now() + start_dt = end_dt - timedelta(days=3 * 365) + + self.start_date_edit = QtWidgets.QDateEdit( + QtCore.QDate( + start_dt.year, + start_dt.month, + start_dt.day + ) + ) + self.end_date_edit = QtWidgets.QDateEdit( + QtCore.QDate.currentDate() + ) + + self.rate_line = QtWidgets.QLineEdit("0.000025") + self.slippage_line = QtWidgets.QLineEdit("0.2") + self.size_line = QtWidgets.QLineEdit("300") + self.pricetick_line = QtWidgets.QLineEdit("0.2") + self.capital_line = QtWidgets.QLineEdit("1000000") + + start_button = QtWidgets.QPushButton("开始回测") + start_button.clicked.connect(self.start_backtesting) + + form = QtWidgets.QFormLayout() + form.addRow("交易策略", self.class_combo) + form.addRow("本地代码", self.symbol_line) + form.addRow("K线周期", self.interval_combo) + form.addRow("开始日期", self.start_date_edit) + form.addRow("结束日期", self.end_date_edit) + form.addRow("手续费率", self.rate_line) + form.addRow("交易滑点", self.slippage_line) + form.addRow("合约乘数", self.size_line) + form.addRow("价格跳动", self.pricetick_line) + form.addRow("回测资金", self.capital_line) + form.addRow(start_button) + + # Result part + self.statistics_monitor = StatisticsMonitor() + + self.log_monitor = QtWidgets.QTextEdit() + self.log_monitor.setMaximumHeight(400) + + self.chart = BacktesterChart() + self.chart.setMinimumWidth(1000) + + # Layout + vbox = QtWidgets.QVBoxLayout() + vbox.addWidget(self.statistics_monitor) + vbox.addWidget(self.log_monitor) + + hbox = QtWidgets.QHBoxLayout() + hbox.addLayout(form) + hbox.addLayout(vbox) + hbox.addWidget(self.chart) + self.setLayout(hbox) + + def register_event(self): + """""" + self.signal_log.connect(self.process_log_event) + self.signal_backtesting_finished.connect( + self.process_backtesting_finished_event) + self.signal_optimization_finished.connect( + self.process_optimization_finished_event) + + self.event_engine.register(EVENT_BACKTESTER_LOG, self.signal_log.emit) + self.event_engine.register( + EVENT_BACKTESTER_BACKTESTING_FINISHED, self.signal_backtesting_finished.emit) + self.event_engine.register( + EVENT_BACKTESTER_OPTIMIZATION_FINISHED, self.signal_optimization_finished.emit) + + def process_log_event(self, event: Event): + """""" + msg = event.data + timestamp = datetime.now().strftime("%H:%M:%S") + msg = f"{timestamp}\t{msg}" + self.log_monitor.append(msg) + + def process_backtesting_finished_event(self, event: Event): + """""" + statistics = self.backtester_engine.get_result_statistics() + self.statistics_monitor.set_data(statistics) + + df = self.backtester_engine.get_result_df() + self.chart.set_data(df) + + def process_optimization_finished_event(self, event: Event): + """""" + pass + + def start_backtesting(self): + """""" + class_name = self.class_combo.currentText() + vt_symbol = self.symbol_line.text() + interval = self.interval_combo.currentText() + start = self.start_date_edit.date().toPyDate() + end = self.end_date_edit.date().toPyDate() + rate = float(self.rate_line.text()) + slippage = float(self.slippage_line.text()) + size = float(self.size_line.text()) + pricetick = float(self.pricetick_line.text()) + capital = float(self.capital_line.text()) + + old_setting = self.settings[class_name] + dialog = SettingEditor(class_name, old_setting) + i = dialog.exec() + if i != dialog.Accepted: + return + + new_setting = dialog.get_setting() + self.settings[class_name] = new_setting + + result = self.backtester_engine.start_backtesting( + class_name, + vt_symbol, + interval, + start, + end, + rate, + slippage, + size, + pricetick, + capital, + new_setting + ) + + if result: + self.statistics_monitor.clear_data() + self.chart.clear_data() + + def show(self): + """""" + self.showMaximized() + + +class StatisticsMonitor(QtWidgets.QTableWidget): + """""" + KEY_NAME_MAP = { + "start_date": "首个交易日", + "end_date": "最后交易日", + + "total_days": "总交易日", + "profit_days": "盈利交易日", + "loss_days": "亏损交易日", + + "capital": "起始资金", + "end_balance": "结束资金", + + "total_return": "总收益率", + "annual_return": "年化收益", + "max_drawdown": "最大回撤", + "max_ddpercent": "百分比最大回撤", + + "total_net_pnl": "总盈亏", + "total_commission": "总手续费", + "total_slippage": "总滑点", + "total_turnover": "总成交额", + "total_trade_count": "总成交笔数", + + "daily_net_pnl": "日均盈亏", + "daily_commission": "日均手续费", + "daily_slippage": "日均滑点", + "daily_turnover": "日均成交额", + "daily_trade_count": "日均成交笔数", + + "daily_return": "日均收益率", + "return_std": "收益标准差", + "sharpe_ratio": "夏普比率", + "return_drawdown_ratio": "收益回撤比" + } + + def __init__(self): + """""" + super().__init__() + + self.cells = {} + + self.init_ui() + + def init_ui(self): + """""" + self.setRowCount(len(self.KEY_NAME_MAP)) + self.setVerticalHeaderLabels(list(self.KEY_NAME_MAP.values())) + + self.setColumnCount(1) + self.horizontalHeader().setVisible(False) + self.horizontalHeader().setSectionResizeMode( + QtWidgets.QHeaderView.Stretch + ) + + for row, key in enumerate(self.KEY_NAME_MAP.keys()): + cell = QtWidgets.QTableWidgetItem() + self.setItem(row, 0, cell) + self.cells[key] = cell + + def clear_data(self): + """""" + for cell in self.cells.values(): + cell.setText("") + + def set_data(self, data: dict): + """""" + data["capital"] = f"{data['capital']:,.2f}" + data["end_balance"] = f"{data['end_balance']:,.2f}" + data["total_return"] = f"{data['total_return']:,.2f}%" + data["annual_return"] = f"{data['annual_return']:,.2f}%" + data["max_drawdown"] = f"{data['max_drawdown']:,.2f}" + data["max_ddpercent"] = f"{data['max_ddpercent']:,.2f}%" + data["total_net_pnl"] = f"{data['total_net_pnl']:,.2f}" + data["total_commission"] = f"{data['total_commission']:,.2f}" + data["total_slippage"] = f"{data['total_slippage']:,.2f}" + data["total_turnover"] = f"{data['total_turnover']:,.2f}" + data["daily_net_pnl"] = f"{data['daily_net_pnl']:,.2f}" + data["daily_commission"] = f"{data['daily_commission']:,.2f}" + data["daily_slippage"] = f"{data['daily_slippage']:,.2f}" + data["daily_turnover"] = f"{data['daily_turnover']:,.2f}" + data["daily_return"] = f"{data['daily_return']:,.2f}%" + data["return_std"] = f"{data['return_std']:,.2f}%" + data["sharpe_ratio"] = f"{data['sharpe_ratio']:,.2f}" + data["return_drawdown_ratio"] = f"{data['return_drawdown_ratio']:,.2f}" + + for key, cell in self.cells.items(): + value = data.get(key, "") + cell.setText(str(value)) + + +class SettingEditor(QtWidgets.QDialog): + """ + For creating new strategy and editing strategy parameters. + """ + + def __init__( + self, class_name: str, parameters: dict + ): + """""" + super(SettingEditor, self).__init__() + + self.class_name = class_name + self.parameters = parameters + self.edits = {} + + self.init_ui() + + def init_ui(self): + """""" + form = QtWidgets.QFormLayout() + + # Add vt_symbol and name edit if add new strategy + self.setWindowTitle(f"策略参数配置:{self.class_name}") + button_text = "确定" + parameters = self.parameters + + for name, value in parameters.items(): + type_ = type(value) + + edit = QtWidgets.QLineEdit(str(value)) + if type_ is int: + validator = QtGui.QIntValidator() + edit.setValidator(validator) + elif type_ is float: + validator = QtGui.QDoubleValidator() + edit.setValidator(validator) + + form.addRow(f"{name} {type_}", edit) + + self.edits[name] = (edit, type_) + + button = QtWidgets.QPushButton(button_text) + button.clicked.connect(self.accept) + form.addRow(button) + + self.setLayout(form) + + def get_setting(self): + """""" + setting = {} + + for name, tp in self.edits.items(): + edit, type_ = tp + value_text = edit.text() + + if type_ == bool: + if value_text == "True": + value = True + else: + value = False + else: + value = type_(value_text) + + setting[name] = value + + return setting + + +class BacktesterChart(pg.GraphicsWindow): + """""" + + def __init__(self): + """""" + super().__init__(title="Backtester Chart") + + self.dates = {} + + self.init_ui() + + def init_ui(self): + """""" + pg.setConfigOptions(antialias=True) + + # Create plot widgets + self.balance_plot = self.addPlot( + title="账户净值", + axisItems={"bottom": DateAxis(self.dates, orientation="bottom")} + ) + self.nextRow() + + self.drawdown_plot = self.addPlot( + title="净值回撤", + axisItems={"bottom": DateAxis(self.dates, orientation="bottom")} + ) + self.nextRow() + + self.pnl_plot = self.addPlot( + title="每日盈亏", + axisItems={"bottom": DateAxis(self.dates, orientation="bottom")} + ) + self.nextRow() + + self.distribution_plot = self.addPlot(title="盈亏分布") + + # Add curves and bars on plot widgets + self.balance_curve = self.balance_plot.plot( + pen=pg.mkPen("#ffc107", width=3) + ) + + dd_color = "#303f9f" + self.drawdown_curve = self.drawdown_plot.plot( + fillLevel=-0.3, brush=dd_color, pen=dd_color + ) + + profit_color = 'r' + loss_color = 'g' + self.profit_pnl_bar = pg.BarGraphItem( + x=[], height=[], width=0.3, brush=profit_color, pen=profit_color + ) + self.loss_pnl_bar = pg.BarGraphItem( + x=[], height=[], width=0.3, brush=loss_color, pen=loss_color + ) + self.pnl_plot.addItem(self.profit_pnl_bar) + self.pnl_plot.addItem(self.loss_pnl_bar) + + distribution_color = "#6d4c41" + self.distribution_curve = self.distribution_plot.plot( + fillLevel=-0.3, brush=distribution_color, pen=distribution_color + ) + + def clear_data(self): + """""" + self.balance_curve.setData([], []) + self.drawdown_curve.setData([], []) + self.profit_pnl_bar.setOpts(x=[], height=[]) + self.loss_pnl_bar.setOpts(x=[], height=[]) + self.distribution_curve.setData([], []) + + def set_data(self, df): + """""" + count = len(df) + + self.dates.clear() + for n, date in enumerate(df.index): + self.dates[n] = date + + # Set data for curve of balance and drawdown + self.balance_curve.setData(df["balance"]) + self.drawdown_curve.setData(df["drawdown"]) + + # Set data for daily pnl bar + profit_pnl_x = [] + profit_pnl_height = [] + loss_pnl_x = [] + loss_pnl_height = [] + + for count, pnl in enumerate(df["net_pnl"]): + if pnl >= 0: + profit_pnl_height.append(pnl) + profit_pnl_x.append(count) + else: + loss_pnl_height.append(pnl) + loss_pnl_x.append(count) + + self.profit_pnl_bar.setOpts(x=profit_pnl_x, height=profit_pnl_height) + self.loss_pnl_bar.setOpts(x=loss_pnl_x, height=loss_pnl_height) + + # Set data for pnl distribution + hist, x = np.histogram(df["net_pnl"], bins="auto") + x = x[:-1] + self.distribution_curve.setData(x, hist) + + +class DateAxis(pg.AxisItem): + """Axis for showing date data""" + + def __init__(self, dates: dict, *args, **kwargs): + """""" + super().__init__(*args, **kwargs) + self.dates = dates + + def tickStrings(self, values, scale, spacing): + """""" + strings = [] + for v in values: + dt = self.dates.get(v, "") + strings.append(str(dt)) + return strings diff --git a/vnpy/app/cta_strategy/__init__.py b/vnpy/app/cta_strategy/__init__.py index 9d079f9b..e753746b 100644 --- a/vnpy/app/cta_strategy/__init__.py +++ b/vnpy/app/cta_strategy/__init__.py @@ -7,6 +7,7 @@ from vnpy.trader.utility import BarGenerator, ArrayManager from .base import APP_NAME, StopOrder from .engine import CtaEngine +from .backtesting import BacktestingEngine, OptimizationSetting from .template import CtaTemplate, CtaSignal, TargetPosTemplate diff --git a/vnpy/app/cta_strategy/backtesting.py b/vnpy/app/cta_strategy/backtesting.py index 4b5077aa..c1ba8e67 100644 --- a/vnpy/app/cta_strategy/backtesting.py +++ b/vnpy/app/cta_strategy/backtesting.py @@ -2,6 +2,7 @@ from collections import defaultdict from datetime import date, datetime from typing import Callable from itertools import product +from functools import lru_cache import multiprocessing import numpy as np @@ -197,28 +198,18 @@ class BacktestingEngine: self.output("开始加载历史数据") if self.mode == BacktestingMode.BAR: - s = ( - DbBarData.select() - .where( - (DbBarData.vt_symbol == self.vt_symbol) - & (DbBarData.interval == self.interval) - & (DbBarData.datetime >= self.start) - & (DbBarData.datetime <= self.end) - ) - .order_by(DbBarData.datetime) + self.history_data = load_bar_data( + self.vt_symbol, + self.interval, + self.start, + self.end ) - self.history_data = [db_bar.to_bar() for db_bar in s] else: - s = ( - DbTickData.select() - .where( - (DbTickData.vt_symbol == self.vt_symbol) - & (DbTickData.datetime >= self.start) - & (DbTickData.datetime <= self.end) - ) - .order_by(DbTickData.datetime) + self.history_data = load_tick_data( + self.vt_symbol, + self.start, + self.end ) - self.history_data = [db_tick.to_tick() for db_tick in s] self.output(f"历史数据加载完成,数据量:{len(self.history_data)}") @@ -293,7 +284,7 @@ class BacktestingEngine: self.output("逐日盯市盈亏计算完成") return self.daily_df - def calculate_statistics(self, df: DataFrame = None): + def calculate_statistics(self, df: DataFrame = None, output=True): """""" self.output("开始计算策略统计指标") @@ -325,6 +316,7 @@ class BacktestingEngine: daily_return = 0 return_std = 0 sharpe_ratio = 0 + return_drawdown_ratio = 0 else: # Calculate balance related time series data df["balance"] = df["net_pnl"].cumsum() + self.capital @@ -373,38 +365,42 @@ class BacktestingEngine: else: sharpe_ratio = 0 + return_drawdown_ratio = -total_return / max_ddpercent + # Output - self.output("-" * 30) - self.output(f"首个交易日:\t{start_date}") - self.output(f"最后交易日:\t{end_date}") + if output: + self.output("-" * 30) + self.output(f"首个交易日:\t{start_date}") + self.output(f"最后交易日:\t{end_date}") - self.output(f"总交易日:\t{total_days}") - self.output(f"盈利交易日:\t{profit_days}") - self.output(f"亏损交易日:\t{loss_days}") + self.output(f"总交易日:\t{total_days}") + self.output(f"盈利交易日:\t{profit_days}") + self.output(f"亏损交易日:\t{loss_days}") - self.output(f"起始资金:\t{self.capital:,.2f}") - self.output(f"结束资金:\t{end_balance:,.2f}") + self.output(f"起始资金:\t{self.capital:,.2f}") + self.output(f"结束资金:\t{end_balance:,.2f}") - self.output(f"总收益率:\t{total_return:,.2f}%") - self.output(f"年化收益:\t{annual_return:,.2f}%") - self.output(f"最大回撤: \t{max_drawdown:,.2f}") - self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%") + self.output(f"总收益率:\t{total_return:,.2f}%") + self.output(f"年化收益:\t{annual_return:,.2f}%") + self.output(f"最大回撤: \t{max_drawdown:,.2f}") + self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%") - self.output(f"总盈亏:\t{total_net_pnl:,.2f}") - self.output(f"总手续费:\t{total_commission:,.2f}") - self.output(f"总滑点:\t{total_slippage:,.2f}") - self.output(f"总成交金额:\t{total_turnover:,.2f}") - self.output(f"总成交笔数:\t{total_trade_count}") + self.output(f"总盈亏:\t{total_net_pnl:,.2f}") + self.output(f"总手续费:\t{total_commission:,.2f}") + self.output(f"总滑点:\t{total_slippage:,.2f}") + self.output(f"总成交金额:\t{total_turnover:,.2f}") + self.output(f"总成交笔数:\t{total_trade_count}") - self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}") - self.output(f"日均手续费:\t{daily_commission:,.2f}") - self.output(f"日均滑点:\t{daily_slippage:,.2f}") - self.output(f"日均成交金额:\t{daily_turnover:,.2f}") - self.output(f"日均成交笔数:\t{daily_trade_count}") + self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}") + self.output(f"日均手续费:\t{daily_commission:,.2f}") + self.output(f"日均滑点:\t{daily_slippage:,.2f}") + self.output(f"日均成交金额:\t{daily_turnover:,.2f}") + self.output(f"日均成交笔数:\t{daily_trade_count}") - self.output(f"日均收益率:\t{daily_return:,.2f}%") - self.output(f"收益标准差:\t{return_std:,.2f}%") - self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}") + self.output(f"日均收益率:\t{daily_return:,.2f}%") + self.output(f"收益标准差:\t{return_std:,.2f}%") + self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}") + self.output(f"收益回撤比:\t{return_drawdown_ratio:,.2f}") statistics = { "start_date": start_date, @@ -412,6 +408,7 @@ class BacktestingEngine: "total_days": total_days, "profit_days": profit_days, "loss_days": loss_days, + "capital": self.capital, "end_balance": end_balance, "max_drawdown": max_drawdown, "max_ddpercent": max_ddpercent, @@ -430,6 +427,7 @@ class BacktestingEngine: "daily_return": daily_return, "return_std": return_std, "sharpe_ratio": sharpe_ratio, + "return_drawdown_ratio": return_drawdown_ratio, } return statistics @@ -963,3 +961,45 @@ def optimize( target_value = statistics[target_name] return (str(setting), target_value, statistics) + + +@lru_cache(maxsize=10) +def load_bar_data( + vt_symbol: str, + interval: str, + start: datetime, + end: datetime +): + """""" + s = ( + DbBarData.select() + .where( + (DbBarData.vt_symbol == vt_symbol) + & (DbBarData.interval == interval) + & (DbBarData.datetime >= start) + & (DbBarData.datetime <= end) + ) + .order_by(DbBarData.datetime) + ) + data = [db_bar.to_bar() for db_bar in s] + return data + + +@lru_cache(maxsize=10) +def load_tick_data( + vt_symbol: str, + start: datetime, + end: datetime +): + """""" + s = ( + DbTickData.select() + .where( + (DbTickData.vt_symbol == vt_symbol) + & (DbTickData.datetime >= start) + & (DbTickData.datetime <= end) + ) + .order_by(DbTickData.datetime) + ) + data = [db_tick.db_tick() for db_tick in s] + return data \ No newline at end of file diff --git a/vnpy/app/cta_strategy/engine.py b/vnpy/app/cta_strategy/engine.py index cb29a652..37e5973a 100644 --- a/vnpy/app/cta_strategy/engine.py +++ b/vnpy/app/cta_strategy/engine.py @@ -40,6 +40,7 @@ from vnpy.trader.database import DbTickData, DbBarData from vnpy.trader.setting import SETTINGS from .base import ( + APP_NAME, EVENT_CTA_LOG, EVENT_CTA_STRATEGY, EVENT_CTA_STOPORDER, @@ -73,7 +74,7 @@ class CtaEngine(BaseEngine): def __init__(self, main_engine: MainEngine, event_engine: EventEngine): """""" super(CtaEngine, self).__init__( - main_engine, event_engine, "CtaStrategy") + main_engine, event_engine, APP_NAME) self.strategy_setting = {} # strategy_name: dict self.strategy_data = {} # strategy_name: dict @@ -541,7 +542,7 @@ class CtaEngine(BaseEngine): DbBarData.select() .where( (DbBarData.vt_symbol == vt_symbol) - & (DbBarData.interval == interval) + & (DbBarData.interval == interval.value) & (DbBarData.datetime >= start) & (DbBarData.datetime <= end) ) diff --git a/vnpy/app/cta_strategy/template.py b/vnpy/app/cta_strategy/template.py index 20cf88d2..0fde32c0 100644 --- a/vnpy/app/cta_strategy/template.py +++ b/vnpy/app/cta_strategy/template.py @@ -4,6 +4,7 @@ from typing import Any, Callable from vnpy.trader.constant import Interval, Direction, Offset from vnpy.trader.object import BarData, TickData, OrderData, TradeData +from vnpy.trader.utility import virtual from .base import StopOrder, EngineType @@ -87,48 +88,56 @@ class CtaTemplate(ABC): } return strategy_data + @virtual def on_init(self): """ Callback when strategy is inited. """ pass + @virtual def on_start(self): """ Callback when strategy is started. """ pass + @virtual def on_stop(self): """ Callback when strategy is stopped. """ pass + @virtual def on_tick(self, tick: TickData): """ Callback of new tick data update. """ pass + @virtual def on_bar(self, bar: BarData): """ Callback of new bar data update. """ pass + @virtual def on_trade(self, trade: TradeData): """ Callback of new trade data update. """ pass + @virtual def on_order(self, order: OrderData): """ Callback of new order data update. """ pass + @virtual def on_stop_order(self, stop_order: StopOrder): """ Callback of stop order update. @@ -255,12 +264,14 @@ class CtaSignal(ABC): """""" self.signal_pos = 0 + @virtual def on_tick(self, tick: TickData): """ Callback of new tick data update. """ pass + @virtual def on_bar(self, bar: BarData): """ Callback of new bar data update. @@ -292,6 +303,7 @@ class TargetPosTemplate(CtaTemplate): ) self.variables.append("target_pos") + @virtual def on_tick(self, tick: TickData): """ Callback of new tick data update. @@ -301,12 +313,14 @@ class TargetPosTemplate(CtaTemplate): if self.trading: self.trade() + @virtual def on_bar(self, bar: BarData): """ Callback of new bar data update. """ self.last_bar = bar + @virtual def on_order(self, order: OrderData): """ Callback of new order data update. diff --git a/vnpy/gateway/bitmex/bitmex_gateway.py b/vnpy/gateway/bitmex/bitmex_gateway.py index 3c3e6879..11579c4b 100644 --- a/vnpy/gateway/bitmex/bitmex_gateway.py +++ b/vnpy/gateway/bitmex/bitmex_gateway.py @@ -71,8 +71,8 @@ class BitmexGateway(BaseGateway): "Secret": "", "会话数": 3, "服务器": ["REAL", "TESTNET"], - "代理地址": "127.0.0.1", - "代理端口": 1080, + "代理地址": "", + "代理端口": "", } def __init__(self, event_engine): @@ -91,6 +91,11 @@ class BitmexGateway(BaseGateway): proxy_host = setting["代理地址"] proxy_port = setting["代理端口"] + if proxy_port.isdigit(): + proxy_port = int(proxy_port) + else: + proxy_port = 0 + self.rest_api.connect(key, secret, session_number, server, proxy_host, proxy_port) diff --git a/vnpy/gateway/ctp/ctp_gateway.py b/vnpy/gateway/ctp/ctp_gateway.py index fe2784a7..6dd92450 100644 --- a/vnpy/gateway/ctp/ctp_gateway.py +++ b/vnpy/gateway/ctp/ctp_gateway.py @@ -405,7 +405,7 @@ class CtpTdApi(TdApi): """""" if not error['ErrorID']: self.authStatus = True - self.writeLog("交易授权验证成功") + self.gateway.write_log("交易授权验证成功") self.login() else: self.gateway.write_error("交易授权验证失败", error) @@ -418,7 +418,7 @@ class CtpTdApi(TdApi): self.login_status = True self.gateway.write_log("交易登录成功") - # Confirm settelment + # Confirm settlement req = { "BrokerID": self.brokerid, "InvestorID": self.userid @@ -487,11 +487,6 @@ class CtpTdApi(TdApi): ) self.positions[key] = position - # Get contract size, return if size value not collected - size = symbol_size_map.get(position.symbol, None) - if not size: - return - # For SHFE position data update if position.exchange == Exchange.SHFE: if data["YdPosition"] and not data["TodayPosition"]: @@ -500,6 +495,9 @@ class CtpTdApi(TdApi): else: position.yd_volume = data["Position"] - data["TodayPosition"] + # Get contract size (spread contract has no size value) + size = symbol_size_map.get(position.symbol, 0) + # Calculate previous position cost cost = position.price * position.volume * size @@ -508,7 +506,7 @@ class CtpTdApi(TdApi): position.pnl += data["PositionProfit"] # Calculate average position price - if position.volume: + if position.volume and size: cost += data["PositionCost"] position.price = cost / (position.volume * size) @@ -549,13 +547,16 @@ class CtpTdApi(TdApi): product=product, size=data["VolumeMultiple"], pricetick=data["PriceTick"], - option_underlying=data["UnderlyingInstrID"], - option_type=OPTIONTYPE_CTP2VT.get(data["OptionsType"], None), - option_strike=data["StrikePrice"], - option_expiry=datetime.strptime(data["ExpireDate"], "%Y%m%d"), gateway_name=self.gateway_name ) + # For option only + if contract.product == Product.OPTION: + contract.option_underlying = data["UnderlyingInstrID"], + contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None), + contract.option_strike = data["StrikePrice"], + contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"), + self.gateway.on_contract(contract) symbol_exchange_map[contract.symbol] = contract.exchange @@ -662,7 +663,7 @@ class CtpTdApi(TdApi): "UserID": self.userid, "BrokerID": self.brokerid, "AuthCode": self.auth_code, - "ProductInfo": self.product_info + "UserProductInfo": self.product_info } self.reqid += 1 @@ -678,7 +679,8 @@ class CtpTdApi(TdApi): req = { "UserID": self.userid, "Password": self.password, - "BrokerID": self.brokerid + "BrokerID": self.brokerid, + "UserProductInfo": self.product_info } self.reqid += 1 diff --git a/vnpy/gateway/huobi/__init__.py b/vnpy/gateway/huobi/__init__.py new file mode 100644 index 00000000..9e3d74e8 --- /dev/null +++ b/vnpy/gateway/huobi/__init__.py @@ -0,0 +1,4 @@ +from .huobi_gateway import HuobiGateway + + + diff --git a/vnpy/gateway/huobi/huobi_gateway.py b/vnpy/gateway/huobi/huobi_gateway.py new file mode 100644 index 00000000..474e6f69 --- /dev/null +++ b/vnpy/gateway/huobi/huobi_gateway.py @@ -0,0 +1,747 @@ +# encoding: UTF-8 + +""" +火币交易接口 +""" + +import re +import urllib +import base64 +import json +import zlib +import hashlib +import hmac +from copy import copy +from datetime import datetime + +from vnpy.event import Event +from vnpy.api.rest import RestClient, Request +from vnpy.api.websocket import WebsocketClient +from vnpy.trader.constant import ( + Direction, + Exchange, + Product, + Status, + OrderType +) +from vnpy.trader.gateway import BaseGateway, LocalOrderManager +from vnpy.trader.object import ( + TickData, + OrderData, + TradeData, + AccountData, + ContractData, + OrderRequest, + CancelRequest, + SubscribeRequest +) +from vnpy.trader.event import EVENT_TIMER + + +REST_HOST = "https://api.huobipro.com" +WEBSOCKET_DATA_HOST = "wss://api.huobi.pro/ws" # Market Data +WEBSOCKET_TRADE_HOST = "wss://api.huobi.pro/ws/v1" # Account and Order + +STATUS_HUOBI2VT = { + "submitted": Status.NOTTRADED, + "partial-filled": Status.PARTTRADED, + "filled": Status.ALLTRADED, + "cancelling": Status.CANCELLED, + "partial-canceled": Status.CANCELLED, + "canceled": Status.CANCELLED, +} + +ORDERTYPE_VT2HUOBI = { + (Direction.LONG, OrderType.MARKET): "buy-market", + (Direction.SHORT, OrderType.MARKET): "sell-market", + (Direction.LONG, OrderType.LIMIT): "buy-limit", + (Direction.SHORT, OrderType.LIMIT): "sell-limit", +} +ORDERTYPE_HUOBI2VT = {v: k for k, v in ORDERTYPE_VT2HUOBI.items()} + + +huobi_symbols = set() +symbol_name_map = {} + + +class HuobiGateway(BaseGateway): + """ + VN Trader Gateway for Huobi connection. + """ + + default_setting = { + "API Key": "", + "Secret Key": "", + "会话数": 3, + "代理地址": "", + "代理端口": "", + } + + def __init__(self, event_engine): + """Constructor""" + super(HuobiGateway, self).__init__(event_engine, "HUOBI") + + self.order_manager = LocalOrderManager(self) + + self.rest_api = HuobiRestApi(self) + self.trade_ws_api = HuobiTradeWebsocketApi(self) + self.market_ws_api = HuobiDataWebsocketApi(self) + + def connect(self, setting: dict): + """""" + key = setting["API Key"] + secret = setting["Secret Key"] + session_number = setting["会话数"] + proxy_host = setting["代理地址"] + proxy_port = setting["代理端口"] + + if proxy_port.isdigit(): + proxy_port = int(proxy_port) + else: + proxy_port = 0 + + self.rest_api.connect(key, secret, session_number, + proxy_host, proxy_port) + self.trade_ws_api.connect(key, secret, proxy_host, proxy_port) + self.market_ws_api.connect(key, secret, proxy_host, proxy_port) + + self.init_query() + + def subscribe(self, req: SubscribeRequest): + """""" + self.market_ws_api.subscribe(req) + self.trade_ws_api.subscribe(req) + + def send_order(self, req: OrderRequest): + """""" + return self.rest_api.send_order(req) + + def cancel_order(self, req: CancelRequest): + """""" + self.rest_api.cancel_order(req) + + def query_account(self): + """""" + self.rest_api.query_account_balance() + + def query_position(self): + """""" + pass + + def close(self): + """""" + self.rest_api.stop() + self.trade_ws_api.stop() + self.market_ws_api.stop() + + def process_timer_event(self, event: Event): + """""" + self.count += 1 + if self.count < 3: + return + + self.query_account() + + def init_query(self): + """""" + self.count = 0 + self.event_engine.register(EVENT_TIMER, self.process_timer_event) + + +class HuobiRestApi(RestClient): + """ + HUOBI REST API + """ + + def __init__(self, gateway: BaseGateway): + """""" + super(HuobiRestApi, self).__init__() + + self.gateway = gateway + self.gateway_name = gateway.gateway_name + self.order_manager = gateway.order_manager + + self.host = "" + self.key = "" + self.secret = "" + self.account_id = "" + + self.cancel_requests = {} + self.orders = {} + + def sign(self, request): + """ + Generate HUOBI signature. + """ + request.headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.71 Safari/537.36" + } + params_with_signature = create_signature( + self.key, + request.method, + self.host, + request.path, + self.secret, + request.params + ) + request.params = params_with_signature + + if request.method == "POST": + request.headers["Content-Type"] = "application/json" + + if request.data: + request.data = json.dumps(request.data) + + return request + + def connect( + self, + key: str, + secret: str, + session_number: int, + proxy_host: str, + proxy_port: int + ): + """ + Initialize connection to REST server. + """ + self.key = key + self.secret = secret + + self.host, _ = _split_url(REST_HOST) + + self.init(REST_HOST, proxy_host, proxy_port) + self.start(session_number) + + self.gateway.write_log("REST API启动成功") + + self.query_contract() + self.query_account() + self.query_order() + + def query_account(self): + """""" + self.add_request( + method="GET", + path="/v1/account/accounts", + callback=self.on_query_account + ) + + def query_account_balance(self): + """""" + path = f"/v1/account/accounts/{self.account_id}/balance" + self.add_request( + method="GET", + path=path, + callback=self.on_query_account_balance + ) + + def query_order(self): + """""" + self.add_request( + method="GET", + path="/v1/order/openOrders", + callback=self.on_query_order + ) + + def query_contract(self): + """""" + self.add_request( + method="GET", + path="/v1/common/symbols", + callback=self.on_query_contract + ) + + def send_order(self, req: OrderRequest): + """""" + huobi_type = ORDERTYPE_VT2HUOBI.get( + (req.direction, req.type), "" + ) + + local_orderid = self.order_manager.new_local_orderid() + order = req.create_order_data( + local_orderid, + self.gateway_name + ) + order.time = datetime.now().strftime("%H:%M:%S") + + data = { + "account-id": self.account_id, + "amount": str(req.volume), + "symbol": req.symbol, + "type": huobi_type, + "price": str(req.price), + "source": "api" + } + + self.add_request( + method="POST", + path="/v1/order/orders/place", + callback=self.on_send_order, + data=data, + extra=order, + on_error=self.on_send_order_error, + on_failed=self.on_send_order_failed + ) + + self.order_manager.on_order(order) + return order.vt_orderid + + def cancel_order(self, req: CancelRequest): + """""" + sys_orderid = self.order_manager.get_sys_orderid(req.orderid) + + path = f"/v1/order/orders/{sys_orderid}/submitcancel" + self.add_request( + method="POST", + path=path, + callback=self.on_cancel_order, + extra=req + ) + + def on_query_account(self, data, request): + """""" + if self.check_error(data, "查询账户"): + return + + for d in data["data"]: + if d["type"] == "spot": + self.account_id = d["id"] + self.gateway.write_log(f"账户代码{self.account_id}查询成功") + + self.query_account_balance() + + def on_query_account_balance(self, data, request): + """""" + if self.check_error(data, "查询账户资金"): + return + + buf = {} + for d in data["data"]["list"]: + currency = d["currency"] + currency_data = buf.setdefault(currency, {}) + currency_data[d["type"]] = float(d["balance"]) + + for currency, currency_data in buf.items(): + account = AccountData( + accountid=currency, + balance=currency_data["trade"] + currency_data["frozen"], + frozen=currency_data["frozen"], + gateway_name=self.gateway_name, + ) + + if account.balance: + self.gateway.on_account(account) + + def on_query_order(self, data, request): + """""" + if self.check_error(data, "查询委托"): + return + + for d in data["data"]: + sys_orderid = d["id"] + local_orderid = self.order_manager.get_local_orderid(sys_orderid) + + direction, order_type = ORDERTYPE_HUOBI2VT[d["type"]] + dt = datetime.fromtimestamp(d["created-at"] / 1000) + time = dt.strftime("%H:%M:%S") + + order = OrderData( + orderid=local_orderid, + symbol=d["symbol"], + exchange=Exchange.HUOBI, + price=float(d["price"]), + volume=float(d["amount"]), + type=order_type, + direction=direction, + traded=float(d["filled-amount"]), + status=STATUS_HUOBI2VT.get(d["state"], None), + time=time, + gateway_name=self.gateway_name, + ) + + self.order_manager.on_order(order) + + self.gateway.write_log("委托信息查询成功") + + def on_query_contract(self, data, request): # type: (dict, Request)->None + """""" + if self.check_error(data, "查询合约"): + return + + for d in data["data"]: + base_currency = d["base-currency"] + quote_currency = d["quote-currency"] + name = f"{base_currency.upper()}/{quote_currency.upper()}" + pricetick = 1 / pow(10, d["price-precision"]) + size = 1 / pow(10, d["amount-precision"]) + + contract = ContractData( + symbol=d["symbol"], + exchange=Exchange.HUOBI, + name=name, + pricetick=pricetick, + size=size, + product=Product.SPOT, + gateway_name=self.gateway_name, + ) + self.gateway.on_contract(contract) + + huobi_symbols.add(contract.symbol) + symbol_name_map[contract.symbol] = contract.name + + self.gateway.write_log("合约信息查询成功") + + def on_send_order(self, data, request): + """""" + order = request.extra + + if self.check_error(data, "委托"): + order.status = Status.REJECTED + self.order_manager.on_order(order) + return + + sys_orderid = data["data"] + self.order_manager.update_orderid_map(order.orderid, sys_orderid) + + def on_send_order_failed(self, status_code: str, request: Request): + """ + Callback when sending order failed on server. + """ + order = request.extra + order.status = Status.REJECTED + self.gateway.on_order(order) + + msg = f"委托失败,状态码:{status_code},信息:{request.response.text}" + self.gateway.write_log(msg) + + def on_send_order_error( + self, exception_type: type, exception_value: Exception, tb, request: Request + ): + """ + Callback when sending order caused exception. + """ + order = request.extra + order.status = Status.REJECTED + self.gateway.on_order(order) + + # Record exception if not ConnectionError + if not issubclass(exception_type, ConnectionError): + self.on_error(exception_type, exception_value, tb, request) + + def on_cancel_order(self, data, request): + """""" + cancel_request = request.extra + local_orderid = cancel_request.orderid + order = self.order_manager.get_order_with_local_orderid(local_orderid) + + if self.check_error(data, "撤单"): + order.status = Status.REJECTED + else: + order.status = Status.CANCELLED + self.gateway.write_log(f"委托撤单成功:{order.orderid}") + + self.order_manager.on_order(order) + + def check_error(self, data: dict, func: str = ""): + """""" + if data["status"] != "error": + return False + + error_code = data["err-code"] + error_msg = data["err-msg"] + + self.gateway.write_log(f"{func}请求出错,代码:{error_code},信息:{error_msg}") + return True + + +class HuobiWebsocketApiBase(WebsocketClient): + """""" + + def __init__(self, gateway): + """""" + super(HuobiWebsocketApiBase, self).__init__() + + self.gateway = gateway + self.gateway_name = gateway.gateway_name + + self.key = "" + self.secret = "" + self.sign_host = "" + self.path = "" + + def connect( + self, + key: str, + secret: str, + url: str, + proxy_host: str, + proxy_port: int + ): + """""" + self.key = key + self.secret = secret + + host, path = _split_url(url) + self.sign_host = host + self.path = path + + self.init(url, proxy_host, proxy_port) + self.start() + + def login(self): + """""" + params = {"op": "auth"} + params.update(create_signature(self.key, "GET", self.sign_host, self.path, self.secret)) + return self.send_packet(params) + + def on_login(self, packet): + """""" + pass + + @staticmethod + def unpack_data(data): + """""" + return json.loads(zlib.decompress(data, 31)) + + def on_packet(self, packet): + """""" + if "ping" in packet: + self.send_packet({"pong": packet["ping"]}) + elif "err-msg" in packet: + return self.on_error_msg(packet) + elif "op" in packet and packet["op"] == "auth": + return self.on_login() + else: + self.on_data(packet) + + def on_data(self, packet): + """""" + print("data : {}".format(packet)) + + def on_error_msg(self, packet): + """""" + msg = packet["err-msg"] + if msg == "invalid pong": + return + + self.gateway.write_log(packet["err-msg"]) + + +class HuobiTradeWebsocketApi(HuobiWebsocketApiBase): + """""" + def __init__(self, gateway): + """""" + super().__init__(gateway) + + self.order_manager = gateway.order_manager + self.order_manager.push_data_callback = self.on_data + + self.req_id = 0 + + def connect(self, key, secret, proxy_host, proxy_port): + """""" + super().connect(key, secret, WEBSOCKET_TRADE_HOST, proxy_host, proxy_port) + + def subscribe(self, req: SubscribeRequest): + """""" + self.req_id += 1 + req = { + "op": "sub", + "cid": str(self.req_id), + "topic": f"orders.{req.symbol}" + } + self.send_packet(req) + + def on_connected(self): + """""" + self.gateway.write_log("交易Websocket API连接成功") + self.login() + + def on_login(self): + """""" + self.gateway.write_log("交易Websocket API登录成功") + + def on_data(self, packet): # type: (dict)->None + """""" + op = packet.get("op", None) + if op != "notify": + return + + topic = packet["topic"] + if "orders" in topic: + self.on_order(packet["data"]) + + def on_order(self, data: dict): + """""" + sys_orderid = str(data["order-id"]) + + order = self.order_manager.get_order_with_sys_orderid(sys_orderid) + if not order: + self.order_manager.add_push_data(sys_orderid, data) + return + + traded_volume = float(data["filled-amount"]) + + # Push order event + order.traded += traded_volume + order.status = STATUS_HUOBI2VT.get(data["order-state"], None) + self.order_manager.on_order(order) + + # Push trade event + if not traded_volume: + return + + trade = TradeData( + symbol=order.symbol, + exchange=Exchange.HUOBI, + orderid=order.orderid, + tradeid=str(data["seq-id"]), + direction=order.direction, + price=float(data["price"]), + volume=float(data["filled-amount"]), + time=datetime.now().strftime("%H:%M:%S"), + gateway_name=self.gateway_name, + ) + self.gateway.on_trade(trade) + + +class HuobiDataWebsocketApi(HuobiWebsocketApiBase): + """""" + + def __init__(self, gateway): + """""" + super().__init__(gateway) + + self.req_id = 0 + self.ticks = {} + + def connect(self, key: str, secret: str, proxy_host: str, proxy_port: int): + """""" + super().connect(key, secret, WEBSOCKET_DATA_HOST, proxy_host, proxy_port) + + def on_connected(self): + """""" + self.gateway.write_log("行情Websocket API连接成功") + + def subscribe(self, req: SubscribeRequest): + """""" + symbol = req.symbol + + # Create tick data buffer + tick = TickData( + symbol=symbol, + name=symbol_name_map.get(symbol, ""), + exchange=Exchange.HUOBI, + datetime=datetime.now(), + gateway_name=self.gateway_name, + ) + self.ticks[symbol] = tick + + # Subscribe to market depth update + self.req_id += 1 + req = { + "sub": f"market.{symbol}.depth.step0", + "id": str(self.req_id) + } + self.send_packet(req) + + # Subscribe to market detail update + self.req_id += 1 + req = { + "sub": f"market.{symbol}.detail", + "id": str(self.req_id) + } + self.send_packet(req) + + def on_data(self, packet): # type: (dict)->None + """""" + channel = packet.get("ch", None) + if channel: + if "depth.step" in channel: + self.on_market_depth(packet) + elif "detail" in channel: + self.on_market_detail(packet) + elif "err-code" in packet: + code = packet["err-code"] + msg = packet["err-msg"] + self.gateway.write_log(f"错误代码:{code}, 错误信息:{msg}") + + def on_market_depth(self, data): + """行情深度推送 """ + symbol = data["ch"].split(".")[1] + tick = self.ticks[symbol] + tick.datetime = datetime.fromtimestamp(data["ts"] / 1000) + + bids = data["tick"]["bids"] + for n in range(5): + price, volume = bids[n] + tick.__setattr__("bid_price_" + str(n + 1), float(price)) + tick.__setattr__("bid_volume_" + str(n + 1), float(volume)) + + asks = data["tick"]["asks"] + for n in range(5): + price, volume = asks[n] + tick.__setattr__("ask_price_" + str(n + 1), float(price)) + tick.__setattr__("ask_volume_" + str(n + 1), float(volume)) + + if tick.last_price: + self.gateway.on_tick(copy(tick)) + + def on_market_detail(self, data): + """市场细节推送""" + symbol = data["ch"].split(".")[1] + tick = self.ticks[symbol] + tick.datetime = datetime.fromtimestamp(data["ts"] / 1000) + + tick_data = data["tick"] + tick.open_price = float(tick_data["open"]) + tick.high_price = float(tick_data["high"]) + tick.low_price = float(tick_data["low"]) + tick.last_price = float(tick_data["close"]) + tick.volume = float(tick_data["vol"]) + + if tick.bid_price_1: + self.gateway.on_tick(copy(tick)) + + +def _split_url(url): + """ + 将url拆分为host和path + :return: host, path + """ + result = re.match("\w+://([^/]*)(.*)", url) # noqa + if result: + return result.group(1), result.group(2) + + +def create_signature(api_key, method, host, path, secret_key, get_params=None): + """ + 创建签名 + :param get_params: dict 使用GET方法时附带的额外参数(urlparams) + :return: + """ + sorted_params = [ + ("AccessKeyId", api_key), + ("SignatureMethod", "HmacSHA256"), + ("SignatureVersion", "2"), + ("Timestamp", datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S")) + ] + + if get_params: + sorted_params.extend(list(get_params.items())) + sorted_params = list(sorted(sorted_params)) + encode_params = urllib.parse.urlencode(sorted_params) + + payload = [method, host, path, encode_params] + payload = "\n".join(payload) + payload = payload.encode(encoding="UTF8") + + secret_key = secret_key.encode(encoding="UTF8") + + digest = hmac.new(secret_key, payload, digestmod=hashlib.sha256).digest() + signature = base64.b64encode(digest) + + params = dict(sorted_params) + params["Signature"] = signature.decode("UTF8") + return params diff --git a/vnpy/gateway/okex/__init__.py b/vnpy/gateway/okex/__init__.py new file mode 100644 index 00000000..8cb9fb3b --- /dev/null +++ b/vnpy/gateway/okex/__init__.py @@ -0,0 +1 @@ +from .okex_gateway import OkexGateway diff --git a/vnpy/gateway/okex/okex_gateway.py b/vnpy/gateway/okex/okex_gateway.py new file mode 100644 index 00000000..e959430e --- /dev/null +++ b/vnpy/gateway/okex/okex_gateway.py @@ -0,0 +1,725 @@ +# encoding: UTF-8 +""" +""" + +import hashlib +import hmac +import sys +import time +import json +import base64 +import zlib +from copy import copy +from datetime import datetime +from threading import Lock +from urllib.parse import urlencode + +from requests import ConnectionError + +from vnpy.api.rest import Request, RestClient +from vnpy.api.websocket import WebsocketClient +from vnpy.trader.constant import ( + Direction, + Exchange, + OrderType, + Product, + Status +) +from vnpy.trader.gateway import BaseGateway +from vnpy.trader.object import ( + TickData, + OrderData, + TradeData, + AccountData, + ContractData, + OrderRequest, + CancelRequest, + SubscribeRequest, +) + +REST_HOST = "https://www.okex.com" +WEBSOCKET_HOST = "wss://real.okex.com:10442/ws/v3" + +STATUS_OKEX2VT = { + "ordering": Status.SUBMITTING, + "open": Status.NOTTRADED, + "part_filled": Status.PARTTRADED, + "filled": Status.ALLTRADED, + "cancelled": Status.CANCELLED, + "cancelling": Status.CANCELLED, + "failure": Status.REJECTED, +} + +DIRECTION_VT2OKEX = {Direction.LONG: "buy", Direction.SHORT: "sell"} +DIRECTION_OKEX2VT = {v: k for k, v in DIRECTION_VT2OKEX.items()} + +ORDERTYPE_VT2OKEX = { + OrderType.LIMIT: "limit", + OrderType.MARKET: "market" +} +ORDERTYPE_OKEX2VT = {v: k for k, v in ORDERTYPE_VT2OKEX.items()} + + +instruments = set() +currencies = set() + + +class OkexGateway(BaseGateway): + """ + VN Trader Gateway for OKEX connection. + """ + + default_setting = { + "API Key": "", + "Secret Key": "", + "Passphrase": "", + "会话数": 3, + "代理地址": "", + "代理端口": "", + } + + def __init__(self, event_engine): + """Constructor""" + super(OkexGateway, self).__init__(event_engine, "OKEX") + + self.rest_api = OkexRestApi(self) + self.ws_api = OkexWebsocketApi(self) + + self.orders = {} + + def connect(self, setting: dict): + """""" + key = setting["API Key"] + secret = setting["Secret Key"] + passphrase = setting["Passphrase"] + session_number = setting["会话数"] + proxy_host = setting["代理地址"] + proxy_port = setting["代理端口"] + + if proxy_port.isdigit(): + proxy_port = int(proxy_port) + else: + proxy_port = 0 + + self.rest_api.connect(key, secret, passphrase, + session_number, proxy_host, proxy_port) + self.ws_api.connect(key, secret, passphrase, proxy_host, proxy_port) + + def subscribe(self, req: SubscribeRequest): + """""" + self.ws_api.subscribe(req) + + def send_order(self, req: OrderRequest): + """""" + return self.rest_api.send_order(req) + + def cancel_order(self, req: CancelRequest): + """""" + self.rest_api.cancel_order(req) + + def query_account(self): + """""" + pass + + def query_position(self): + """""" + pass + + def close(self): + """""" + self.rest_api.stop() + self.ws_api.stop() + + def on_order(self, order: OrderData): + """""" + self.orders[order.orderid] = order + super().on_order(order) + + def get_order(self, orderid: str): + """""" + return self.orders.get(orderid, None) + + +class OkexRestApi(RestClient): + """ + OKEX REST API + """ + + def __init__(self, gateway: BaseGateway): + """""" + super(OkexRestApi, self).__init__() + + self.gateway = gateway + self.gateway_name = gateway.gateway_name + + self.key = "" + self.secret = "" + self.passphrase = "" + + self.order_count = 10000 + self.order_count_lock = Lock() + + self.connect_time = 0 + + def sign(self, request): + """ + Generate OKEX signature. + """ + # Sign + # timestamp = str(time.time()) + timestamp = get_timestamp() + request.data = json.dumps(request.data) + + if request.params: + path = request.path + '?' + urlencode(request.params) + else: + path = request.path + + msg = timestamp + request.method + path + request.data + signature = generate_signature(msg, self.secret) + + # Add headers + request.headers = { + 'OK-ACCESS-KEY': self.key, + 'OK-ACCESS-SIGN': signature, + 'OK-ACCESS-TIMESTAMP': timestamp, + 'OK-ACCESS-PASSPHRASE': self.passphrase, + 'Content-Type': 'application/json' + } + return request + + def connect( + self, + key: str, + secret: str, + passphrase: str, + session_number: int, + proxy_host: str, + proxy_port: int, + ): + """ + Initialize connection to REST server. + """ + self.key = key + self.secret = secret.encode() + self.passphrase = passphrase + + self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S")) + + self.init(REST_HOST, proxy_host, proxy_port) + self.start(session_number) + self.gateway.write_log("REST API启动成功") + + self.query_time() + self.query_contract() + self.query_account() + self.query_order() + + def _new_order_id(self): + with self.order_count_lock: + self.order_count += 1 + return self.order_count + + def send_order(self, req: OrderRequest): + """""" + orderid = f"a{self.connect_time}{self._new_order_id()}" + + data = { + "client_oid": orderid, + "type": ORDERTYPE_VT2OKEX[req.type], + "side": DIRECTION_VT2OKEX[req.direction], + "instrument_id": req.symbol + } + + if req.type == OrderType.MARKET: + if req.direction == Direction.LONG: + data["notional"] = req.volume + else: + data["size"] = req.volume + else: + data["price"] = req.price + data["size"] = req.volume + + order = req.create_order_data(orderid, self.gateway_name) + + self.add_request( + "POST", + "/api/spot/v3/orders", + callback=self.on_send_order, + data=data, + extra=order, + on_failed=self.on_send_order_failed, + on_error=self.on_send_order_error, + ) + + self.gateway.on_order(order) + return order.vt_orderid + + def cancel_order(self, req: CancelRequest): + """""" + data = { + "instrument_id": req.symbol, + "client_oid": req.orderid + } + + path = "/api/spot/v3/cancel_orders/" + req.orderid + self.add_request( + "POST", + path, + callback=self.on_cancel_order, + data=data, + on_error=self.on_cancel_order_error, + on_failed=self.on_cancel_order_failed, + extra=req + ) + + def query_contract(self): + """""" + self.add_request( + "GET", + "/api/spot/v3/instruments", + callback=self.on_query_contract + ) + + def query_account(self): + """""" + self.add_request( + "GET", + "/api/spot/v3/accounts", + callback=self.on_query_account + ) + + def query_order(self): + """""" + self.add_request( + "GET", + "/api/spot/v3/orders_pending", + callback=self.on_query_order + ) + + def query_time(self): + """""" + self.add_request( + "GET", + "/api/general/v3/time", + callback=self.on_query_time + ) + + def on_query_contract(self, data, request): + """""" + for instrument_data in data: + symbol = instrument_data["instrument_id"] + contract = ContractData( + symbol=symbol, + exchange=Exchange.OKEX, + name=symbol, + product=Product.SPOT, + size=1, + pricetick=float(instrument_data["tick_size"]), + gateway_name=self.gateway_name + ) + self.gateway.on_contract(contract) + + instruments.add(instrument_data["instrument_id"]) + currencies.add(instrument_data["base_currency"]) + currencies.add(instrument_data["quote_currency"]) + + self.gateway.write_log("合约信息查询成功") + + # Start websocket api after instruments data collected + self.gateway.ws_api.start() + + def on_query_account(self, data, request): + """""" + for account_data in data: + account = AccountData( + accountid=account_data["currency"], + balance=float(account_data["balance"]), + frozen=float(account_data["hold"]), + gateway_name=self.gateway_name + ) + self.gateway.on_account(account) + + self.gateway.write_log("账户资金查询成功") + + def on_query_order(self, data, request): + """""" + for order_data in data: + order = OrderData( + symbol=order_data["instrument_id"], + exchange=Exchange.OKEX, + type=ORDERTYPE_OKEX2VT[order_data["type"]], + orderid=order_data["client_oid"], + direction=DIRECTION_OKEX2VT[order_data["side"]], + price=float(order_data["price"]), + volume=float(order_data["size"]), + traded=float(order_data["filled_size"]), + time=order_data["timestamp"][11:19], + status=STATUS_OKEX2VT[order_data["status"]], + gateway_name=self.gateway_name, + ) + self.gateway.on_order(order) + + self.gateway.write_log("委托信息查询成功") + + def on_query_time(self, data, request): + """""" + server_time = data["iso"] + local_time = datetime.utcnow().isoformat() + msg = f"服务器时间:{server_time},本机时间:{local_time}" + self.gateway.write_log(msg) + + def on_send_order_failed(self, status_code: str, request: Request): + """ + Callback when sending order failed on server. + """ + order = request.extra + order.status = Status.REJECTED + self.gateway.on_order(order) + + msg = f"委托失败,状态码:{status_code},信息:{request.response.text}" + self.gateway.write_log(msg) + + def on_send_order_error( + self, exception_type: type, exception_value: Exception, tb, request: Request + ): + """ + Callback when sending order caused exception. + """ + order = request.extra + order.status = Status.REJECTED + self.gateway.on_order(order) + + # Record exception if not ConnectionError + if not issubclass(exception_type, ConnectionError): + self.on_error(exception_type, exception_value, tb, request) + + def on_send_order(self, data, request): + """Websocket will push a new order status""" + order = request.extra + + error_msg = data["error_message"] + if error_msg: + order.status = Status.REJECTED + self.gateway.on_order(order) + + self.gateway.write_log(f"委托失败:{error_msg}") + + def on_cancel_order_error( + self, exception_type: type, exception_value: Exception, tb, request: Request + ): + """ + Callback when cancelling order failed on server. + """ + # Record exception if not ConnectionError + if not issubclass(exception_type, ConnectionError): + self.on_error(exception_type, exception_value, tb, request) + + def on_cancel_order(self, data, request): + """Websocket will push a new order status""" + pass + + def on_cancel_order_failed(self, status_code: int, request: Request): + """If cancel failed, mark order status to be rejected.""" + req = request.extra + order = self.gateway.get_order(req.orderid) + if order: + order.status = Status.REJECTED + self.gateway.on_order(order) + + def on_failed(self, status_code: int, request: Request): + """ + Callback to handle request failed. + """ + msg = f"请求失败,状态码:{status_code},信息:{request.response.text}" + self.gateway.write_log(msg) + + def on_error( + self, exception_type: type, exception_value: Exception, tb, request: Request + ): + """ + Callback to handler request exception. + """ + msg = f"触发异常,状态码:{exception_type},信息:{exception_value}" + self.gateway.write_log(msg) + + sys.stderr.write( + self.exception_detail(exception_type, exception_value, tb, request) + ) + + +class OkexWebsocketApi(WebsocketClient): + """""" + + def __init__(self, gateway): + """""" + super(OkexWebsocketApi, self).__init__() + self.ping_interval = 20 # OKEX use 30 seconds for ping + + self.gateway = gateway + self.gateway_name = gateway.gateway_name + + self.key = "" + self.secret = "" + self.passphrase = "" + + self.trade_count = 10000 + self.connect_time = 0 + + self.callbacks = {} + self.ticks = {} + + def connect( + self, + key: str, + secret: str, + passphrase: str, + proxy_host: str, + proxy_port: int + ): + """""" + self.key = key + self.secret = secret.encode() + self.passphrase = passphrase + + self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S")) + + self.init(WEBSOCKET_HOST, proxy_host, proxy_port) + # self.start() + + def unpack_data(self, data): + """""" + return json.loads(zlib.decompress(data, -zlib.MAX_WBITS)) + + def subscribe(self, req: SubscribeRequest): + """ + Subscribe to tick data upate. + """ + tick = TickData( + symbol=req.symbol, + exchange=req.exchange, + name=req.symbol, + datetime=datetime.now(), + gateway_name=self.gateway_name, + ) + self.ticks[req.symbol] = tick + + channel_ticker = f"spot/ticker:{req.symbol}" + channel_depth = f"spot/depth5:{req.symbol}" + + self.callbacks[channel_ticker] = self.on_ticker + self.callbacks[channel_depth] = self.on_depth + + req = { + "op": "subscribe", + "args": [channel_ticker, channel_depth] + } + self.send_packet(req) + + def on_connected(self): + """""" + self.gateway.write_log("Websocket API连接成功") + self.login() + + def on_disconnected(self): + """""" + self.gateway.write_log("Websocket API连接断开") + + def on_packet(self, packet: dict): + """""" + if "event" in packet: + event = packet["event"] + if event == "subscribe": + return + elif event == "error": + msg = packet["message"] + self.gateway.write_log(f"Websocket API请求异常:{msg}") + elif event == "login": + self.on_login(packet) + else: + channel = packet["table"] + data = packet["data"] + callback = self.callbacks.get(channel, None) + + if callback: + for d in data: + callback(d) + + def on_error(self, exception_type: type, exception_value: Exception, tb): + """""" + msg = f"触发异常,状态码:{exception_type},信息:{exception_value}" + self.gateway.write_log(msg) + + sys.stderr.write(self.exception_detail( + exception_type, exception_value, tb)) + + def login(self): + """ + Need to login befores subscribe to websocket topic. + """ + timestamp = str(time.time()) + + msg = timestamp + 'GET' + '/users/self/verify' + signature = generate_signature(msg, self.secret) + + req = { + "op": "login", + "args": [ + self.key, + self.passphrase, + timestamp, + signature.decode("utf-8") + ] + } + self.send_packet(req) + self.callbacks['login'] = self.on_login + + def subscribe_topic(self): + """ + Subscribe to all private topics. + """ + self.callbacks["spot/ticker"] = self.on_ticker + self.callbacks["spot/depth5"] = self.on_depth + self.callbacks["spot/account"] = self.on_account + self.callbacks["spot/order"] = self.on_order + + # Subscribe to order update + channels = [] + for instrument_id in instruments: + channel = f"spot/order:{instrument_id}" + channels.append(channel) + + req = { + "op": "subscribe", + "args": channels + } + self.send_packet(req) + + # Subscribe to account update + channels = [] + for currency in currencies: + channel = f"spot/account:{currency}" + channels.append(channel) + + req = { + "op": "subscribe", + "args": channels + } + self.send_packet(req) + + # Subscribe to BTC/USDT trade for keep connection alive + req = { + "op": "subscribe", + "args": ["spot/trade:BTC-USDT"] + } + self.send_packet(req) + + def on_login(self, data: dict): + """""" + success = data.get("success", False) + + if success: + self.gateway.write_log("Websocket API登录成功") + self.subscribe_topic() + else: + self.gateway.write_log("Websocket API登录失败") + + def on_ticker(self, d): + """""" + symbol = d["instrument_id"] + tick = self.ticks.get(symbol, None) + if not tick: + return + + tick.last_price = float(d["last"]) + tick.open = float(d["open_24h"]) + tick.high = float(d["high_24h"]) + tick.low = float(d["low_24h"]) + tick.volume = float(d["base_volume_24h"]) + tick.datetime = datetime.strptime( + d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") + self.gateway.on_tick(copy(tick)) + + def on_depth(self, d): + """""" + for tick_data in d: + symbol = d["instrument_id"] + tick = self.ticks.get(symbol, None) + if not tick: + return + + bids = d["bids"] + asks = d["asks"] + for n, buf in enumerate(bids): + price, volume, _ = buf + tick.__setattr__("bid_price_%s" % (n + 1), float(price)) + tick.__setattr__("bid_volume_%s" % (n + 1), float(volume)) + + for n, buf in enumerate(asks): + price, volume, _ = buf + tick.__setattr__("ask_price_%s" % (n + 1), float(price)) + tick.__setattr__("ask_volume_%s" % (n + 1), float(volume)) + + tick.datetime = datetime.strptime( + d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") + self.gateway.on_tick(copy(tick)) + + def on_order(self, d): + """""" + order = OrderData( + symbol=d["instrument_id"], + exchange=Exchange.OKEX, + type=ORDERTYPE_OKEX2VT[d["type"]], + orderid=d["client_oid"], + direction=DIRECTION_OKEX2VT[d["side"]], + price=float(d["price"]), + volume=float(d["size"]), + traded=float(d["filled_size"]), + time=d["timestamp"][11:19], + status=STATUS_OKEX2VT[d["status"]], + gateway_name=self.gateway_name, + ) + self.gateway.on_order(copy(order)) + + trade_volume = d.get("last_fill_qty", 0) + if not trade_volume or float(trade_volume) == 0: + return + + self.trade_count += 1 + tradeid = f"{self.connect_time}{self.trade_count}" + + trade = TradeData( + symbol=order.symbol, + exchange=order.exchange, + orderid=order.orderid, + tradeid=tradeid, + direction=order.direction, + price=float(d["last_fill_px"]), + volume=float(trade_volume), + time=d["last_fill_time"][11:19], + gateway_name=self.gateway_name + ) + self.gateway.on_trade(trade) + + def on_account(self, d): + """""" + account = AccountData( + accountid=d["currency"], + balance=float(d["balance"]), + frozen=float(d["hold"]), + gateway_name=self.gateway_name + ) + + self.gateway.on_account(copy(account)) + + +def generate_signature(msg: str, secret_key: str): + """OKEX V3 signature""" + return base64.b64encode(hmac.new(secret_key, msg.encode(), hashlib.sha256).digest()) + + +def get_timestamp(): + """""" + now = datetime.utcnow() + timestamp = now.isoformat("T", "milliseconds") + return timestamp + "Z" diff --git a/vnpy/gateway/tiger/tiger_gateway.py b/vnpy/gateway/tiger/tiger_gateway.py index eecd9305..997d3f02 100644 --- a/vnpy/gateway/tiger/tiger_gateway.py +++ b/vnpy/gateway/tiger/tiger_gateway.py @@ -1,5 +1,6 @@ # encoding: UTF-8 """ +Author: KeKe Please install tiger-api before use. pip install tigeropen """ @@ -123,6 +124,9 @@ class TigerGateway(BaseGateway): self.contracts = {} self.symbol_names = {} + self.push_connected = False + self.subscribed_symbols = set() + def run(self): """""" while self.active: @@ -203,23 +207,33 @@ class TigerGateway(BaseGateway): """ protocol, host, port = self.client_config.socket_host_port self.push_client = PushClient(host, port, (protocol == 'ssl')) - self.push_client.connect( - self.client_config.tiger_id, self.client_config.private_key) self.push_client.quote_changed = self.on_quote_change self.push_client.asset_changed = self.on_asset_change self.push_client.position_changed = self.on_position_change self.push_client.order_changed = self.on_order_change - self.write_log("推送接口连接成功") + self.push_client.connect( + self.client_config.tiger_id, self.client_config.private_key) def subscribe(self, req: SubscribeRequest): """""" - self.push_client.subscribe_quote([req.symbol]) + self.subscribed_symbols.add(req.symbol) + + if self.push_connected: + self.push_client.subscribe_quote([req.symbol]) + + def on_push_connected(self): + """""" + self.push_connected = True + self.write_log("推送接口连接成功") + self.push_client.subscribe_asset() self.push_client.subscribe_position() self.push_client.subscribe_order() + self.push_client.subscribe_quote(list(self.subscribed_symbols)) + def on_quote_change(self, tiger_symbol: str, data: list, trading: bool): """""" data = dict(data) diff --git a/vnpy/rpc/__init__.py b/vnpy/rpc/__init__.py new file mode 100644 index 00000000..b903d2bb --- /dev/null +++ b/vnpy/rpc/__init__.py @@ -0,0 +1 @@ +from .vnrpc import RpcServer, RpcClient, RemoteException \ No newline at end of file diff --git a/vnpy/rpc/test_client.py b/vnpy/rpc/test_client.py new file mode 100644 index 00000000..7a693a52 --- /dev/null +++ b/vnpy/rpc/test_client.py @@ -0,0 +1,36 @@ +from __future__ import print_function +from __future__ import absolute_import +from time import sleep + +from vnpy.rpc import RpcClient + + +class TestClient(RpcClient): + """ + Test RpcClient + """ + + def __init__(self, req_address, sub_address): + """ + Constructor + """ + super(TestClient, self).__init__(req_address, sub_address) + + def callback(self, topic, data): + """ + Realize callable function + """ + print('client received topic:', topic, ', data:', data) + + +if __name__ == '__main__': + req_address = 'tcp://localhost:2014' + sub_address = 'tcp://localhost:0602' + + tc = TestClient(req_address, sub_address) + tc.subscribeTopic('') + tc.start() + + while 1: + print(tc.add(1, 3)) + sleep(2) diff --git a/vnpy/rpc/test_server.py b/vnpy/rpc/test_server.py new file mode 100644 index 00000000..660168fc --- /dev/null +++ b/vnpy/rpc/test_server.py @@ -0,0 +1,40 @@ +from __future__ import print_function +from __future__ import absolute_import +from time import sleep, time + +from vnpy.rpc import RpcServer + + +class TestServer(RpcServer): + """ + Test RpcServer + """ + + def __init__(self, rep_address, pub_address): + """ + Constructor + """ + super(TestServer, self).__init__(rep_address, pub_address) + + self.register(self.add) + + def add(self, a, b): + """ + Test function + """ + print('receiving: %s, %s' % (a, b)) + return a + b + + +if __name__ == '__main__': + rep_address = 'tcp://*:2014' + pub_address = 'tcp://*:0602' + + ts = TestServer(rep_address, pub_address) + ts.start() + + while 1: + content = 'current server time is %s' % time() + print(content) + ts.publish('test', content) + sleep(2) diff --git a/vnpy/rpc/vnrpc.py b/vnpy/rpc/vnrpc.py new file mode 100644 index 00000000..1cdd835a --- /dev/null +++ b/vnpy/rpc/vnrpc.py @@ -0,0 +1,329 @@ +import threading +import traceback +import signal + +import zmq +from msgpack import packb, unpackb +from json import dumps, loads + +import pickle +p_dumps = pickle.dumps +p_loads = pickle.loads + + +# Achieve Ctrl-c interrupt recv +signal.signal(signal.SIGINT, signal.SIG_DFL) + + +class RpcObject(object): + """ + Referred to serialization of packing and unpacking, we offer 3 tools: + 1) maspack: higher performance, but usually requires the installation of msgpack related tools; + 2) jason: Slightly lower performance but versatility is better, most programming languages have built-in libraries; + 3) cPickle: Lower performance and only can be used in Python, but it is very convenient to transfer Python objects directly. + + Therefore, it is recommended to use msgpack. + Use json, if you want to communicate with some languages without providing msgpack. + Use cPickle, when the data being transferred contains many custom Python objects. + """ + + def __init__(self): + """ + Constructor + Use msgpack as default serialization tool + """ + self.use_msgpack() + + def pack(self, data): + """""" + pass + + def unpack(self, data): + """""" + pass + + def __json_pack(self, data): + """ + Pack with json + """ + return dumps(data) + + def __json_unpack(self, data): + """ + Unpack with json + """ + return loads(data) + + def __msgpack_pack(self, data): + """ + Pack with msgpack + """ + return packb(data) + + def __msgpack_unpack(self, data): + """ + Unpack with msgpack + """ + return unpackb(data) + + def __pickle_pack(self, data): + """ + Pack with cPickle + """ + return p_dumps(data) + + def __pickle_unpack(self, data): + """ + Unpack with cPickle + """ + return p_loads(data) + + def use_json(self): + """ + Use json as serialization tool + """ + self.pack = self.__json_pack + self.unpack = self.__json_unpack + + def use_msgpack(self): + """ + Use msgpack as serialization tool + """ + self.pack = self.__msgpack_pack + self.unpack = self.__msgpack_unpack + + def use_pickle(self): + """ + Use cPickle as serialization tool + """ + self.pack = self.__pickle_pack + self.unpack = self.__pickle_unpack + + +class RpcServer(RpcObject): + """""" + + def __init__(self, rep_address, pub_address): + """ + Constructor + """ + super(RpcServer, self).__init__() + + # Save functions dict: key is fuction name, value is fuction object + self.__functions = {} + + # Zmq port related + self.__context = zmq.Context() + + self.__socket_rep = self.__context.socket( + zmq.REP) # Reply socket (Request–reply pattern) + self.__socket_rep.bind(rep_address) + + # Publish socket (Publish–subscribe pattern) + self.__socket_pub = self.__context.socket(zmq.PUB) + self.__socket_pub.bind(pub_address) + + # Woker thread related + self.__active = False # RpcServer status + self.__thread = threading.Thread(target=self.run) # RpcServer thread + + def start(self): + """ + Start RpcServer + """ + # Start RpcServer status + self.__active = True + + # Start RpcServer thread + if not self.__thread.isAlive(): + self.__thread.start() + + def stop(self, join=False): + """ + Stop RpcServer + """ + # Stop RpcServer status + self.__active = False + + # Wait for RpcServer thread to exit + if join and self.__thread.isAlive(): + self.__thread.join() + + def run(self): + """ + Run RpcServer functions + """ + while self.__active: + # Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds) + if not self.__socket_rep.poll(1000): + continue + + # Receive request data from Reply socket + reqb = self.__socket_rep.recv() + + # Unpack request by deserialization + req = self.unpack(reqb) + + # Get function name and parameters + name, args, kwargs = req + + # Try to get and execute callable function object; capture exception information if it fails + name = name.decode("UTF-8") + + try: + func = self.__functions[name] + r = func(*args, **kwargs) + rep = [True, r] + except Exception as e: # noqa + rep = [False, traceback.format_exc()] + + # Pack response by serialization + repb = self.pack(rep) + + # send callable response by Reply socket + self.__socket_rep.send(repb) + + def publish(self, topic, data): + """ + Publish data + """ + # Serialized data + topic = bytes(topic, "UTF-8") + datab = self.pack(data) + + # Send data by Publish socket + # topci must be ascii encoding + self.__socket_pub.send_multipart([topic, datab]) + + def register(self, func): + """ + Register function + """ + self.__functions[func.__name__] = func + + +class RpcClient(RpcObject): + """""" + + def __init__(self, req_address, sub_address): + """Constructor""" + super(RpcClient, self).__init__() + + # zmq port related + self.__req_address = req_address + self.__sub_address = sub_address + + self.__context = zmq.Context() + # Request socket (Request–reply pattern) + self.__socket_req = self.__context.socket(zmq.REQ) + # Subscribe socket (Publish–subscribe pattern) + self.__socket_sub = self.__context.socket(zmq.SUB) + + # Woker thread relate, used to process data pushed from server + self.__active = False # RpcClient status + self.__thread = threading.Thread( + target=self.run) # RpcClient thread + + def __getattr__(self, name): + """ + Realize remote call function + """ + # Perform remote call task + def dorpc(*args, **kwargs): + # Generate request + req = [name, args, kwargs] + + # Pack request by serialization + reqb = self.pack(req) + + # Send request and wait for response + self.__socket_req.send(reqb) + repb = self.__socket_req.recv() + + # Unpack response by deserialization + rep = self.unpack(repb) + + # Return response if successed; Trigger exception if failed + if rep[0]: + return rep[1] + else: + raise RemoteException(rep[1].decode("UTF-8")) + + return dorpc + + def start(self): + """ + Start RpcClient + """ + # Connect zmq port + self.__socket_req.connect(self.__req_address) + self.__socket_sub.connect(self.__sub_address) + + # Start RpcClient status + self.__active = True + + # Start RpcClient thread + if not self.__thread.isAlive(): + self.__thread.start() + + def stop(self): + """ + Stop RpcClient + """ + # Stop RpcClient status + self.__active = False + + # Wait for RpcClient thread to exit + if self.__thread.isAlive(): + self.__thread.join() + + def run(self): + """ + Run RpcClient function + """ + while self.__active: + # Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds) + if not self.__socket_sub.poll(1000): + continue + + # Receive data from subscribe socket + topic, datab = self.__socket_sub.recv_multipart() + + # Unpack data by deserialization + data = self.unpack(datab) + + # Process data by callable function + topic = topic.decode("UTF-8") + + self.callback(topic, data) + + def callback(self, topic, data): + """ + Callable function + """ + raise NotImplementedError + + def subscribeTopic(self, topic): + """ + Subscribe data + """ + topic = bytes(topic, "UTF-8") + self.__socket_sub.setsockopt(zmq.SUBSCRIBE, topic) + + +class RemoteException(Exception): + """ + RPC remote exception + """ + + def __init__(self, value): + """ + Constructor + """ + self.__value = value + + def __str__(self): + """ + Output error message + """ + return self.__value diff --git a/vnpy/trader/constant.py b/vnpy/trader/constant.py index 92492f5b..f4727ba3 100644 --- a/vnpy/trader/constant.py +++ b/vnpy/trader/constant.py @@ -99,6 +99,8 @@ class Exchange(Enum): # CryptoCurrency BITMEX = "BITMEX" + OKEX = "OKEX" + HUOBI = "HUOBI" class Currency(Enum): diff --git a/vnpy/trader/database.py b/vnpy/trader/database.py index 87fb3eae..1e90dcbc 100644 --- a/vnpy/trader/database.py +++ b/vnpy/trader/database.py @@ -1,13 +1,47 @@ """""" -from peewee import SqliteDatabase, Model, CharField, DateTimeField, FloatField +from peewee import CharField, DateTimeField, FloatField, Model, MySQLDatabase, PostgresqlDatabase, \ + SqliteDatabase from .constant import Exchange, Interval from .object import BarData, TickData -from .utility import get_file_path +from .setting import SETTINGS +from .utility import resolve_path -DB_NAME = "database.db" -DB = SqliteDatabase(str(get_file_path(DB_NAME))) + +def init(): + db_settings = SETTINGS['database'] + driver = db_settings["driver"] + + init_funcs = { + "sqlite": init_sqlite, + "mysql": init_mysql, + "postgresql": init_postgresql, + } + + assert driver in init_funcs + del db_settings['driver'] + return init_funcs[driver](db_settings) + + +def init_sqlite(settings: dict): + global DB + database = settings['database'] + + DB = SqliteDatabase(str(resolve_path(database))) + + +def init_mysql(settings: dict): + global DB + DB = MySQLDatabase(**settings) + + +def init_postgresql(settings: dict): + global DB + DB = PostgresqlDatabase(**settings) + + +init() class DbBarData(Model): diff --git a/vnpy/trader/engine.py b/vnpy/trader/engine.py index 251d2853..2d4a98e0 100644 --- a/vnpy/trader/engine.py +++ b/vnpy/trader/engine.py @@ -24,7 +24,7 @@ from .event import ( from .gateway import BaseGateway from .object import CancelRequest, LogData, OrderRequest, SubscribeRequest from .setting import SETTINGS -from .utility import Singleton, get_folder_path +from .utility import get_folder_path class MainEngine: @@ -52,6 +52,7 @@ class MainEngine: """ engine = engine_class(self, self.event_engine) self.engines[engine.engine_name] = engine + return engine def add_gateway(self, gateway_class: BaseGateway): """ @@ -59,6 +60,7 @@ class MainEngine: """ gateway = gateway_class(self.event_engine) self.gateways[gateway.gateway_name] = gateway + return gateway def add_app(self, app_class: BaseApp): """ @@ -67,7 +69,8 @@ class MainEngine: app = app_class() self.apps[app.app_name] = app - self.add_engine(app.engine_class) + engine = self.add_engine(app.engine_class) + return engine def init_engines(self): """ @@ -199,8 +202,6 @@ class LogEngine(BaseEngine): Processes log event and output with logging module. """ - __metaclass__ = Singleton - def __init__(self, main_engine: MainEngine, event_engine: EventEngine): """""" super(LogEngine, self).__init__(main_engine, event_engine, "log") diff --git a/vnpy/trader/gateway.py b/vnpy/trader/gateway.py index d1a6248f..ce7aa591 100644 --- a/vnpy/trader/gateway.py +++ b/vnpy/trader/gateway.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from typing import Any +from copy import copy from vnpy.event import Event, EventEngine from .event import ( @@ -227,3 +228,124 @@ class BaseGateway(ABC): Return default setting dict. """ return self.default_setting + + +class LocalOrderManager: + """ + Management tool to support use local order id for trading. + """ + + def __init__(self, gateway: BaseGateway): + """""" + self.gateway = gateway + + # For generating local orderid + self.order_prefix = "" + self.order_count = 0 + self.orders = {} # local_orderid:order + + # Map between local and system orderid + self.local_sys_orderid_map = {} + self.sys_local_orderid_map = {} + + # Push order data buf + self.push_data_buf = {} # sys_orderid:data + + # Callback for processing push order data + self.push_data_callback = None + + # Cancel request buf + self.cancel_request_buf = {} # local_orderid:req + + def new_local_orderid(self): + """ + Generate a new local orderid. + """ + self.order_count += 1 + local_orderid = str(self.order_count).rjust(8, "0") + return local_orderid + + def get_local_orderid(self, sys_orderid: str): + """ + Get local orderid with sys orderid. + """ + local_orderid = self.sys_local_orderid_map.get(sys_orderid, "") + + if not local_orderid: + local_orderid = self.new_local_orderid() + self.update_orderid_map(local_orderid, sys_orderid) + + return local_orderid + + def get_sys_orderid(self, local_orderid: str): + """ + Get sys orderid with local orderid. + """ + sys_orderid = self.local_sys_orderid_map.get(local_orderid, "") + return sys_orderid + + def update_orderid_map(self, local_orderid: str, sys_orderid: str): + """ + Update orderid map. + """ + self.sys_local_orderid_map[sys_orderid] = local_orderid + self.local_sys_orderid_map[local_orderid] = sys_orderid + + self.check_cancel_request(local_orderid) + self.check_push_data(sys_orderid) + + def check_push_data(self, sys_orderid: str): + """ + Check if any order push data waiting. + """ + if sys_orderid not in self.push_data_buf: + return + + data = self.push_data_buf.pop(sys_orderid) + if self.push_data_callback: + self.push_data_callback(data) + + def add_push_data(self, sys_orderid: str, data: dict): + """ + Add push data into buf. + """ + self.push_data_buf[sys_orderid] = data + + def get_order_with_sys_orderid(self, sys_orderid: str): + """""" + local_orderid = self.sys_local_orderid_map.get(sys_orderid, None) + if not local_orderid: + return None + else: + return self.get_order_with_local_orderid(local_orderid) + + def get_order_with_local_orderid(self, local_orderid: str): + """""" + order = self.orders[local_orderid] + return copy(order) + + def on_order(self, order: OrderData): + """ + Keep an order buf before pushing it to gateway. + """ + self.orders[order.orderid] = copy(order) + self.gateway.on_order(order) + + def cancel_order(self, req: CancelRequest): + """ + """ + sys_orderid = self.get_sys_orderid(req.orderid) + if not sys_orderid: + self.cancel_request_buf[req.orderid] = req + return + + self.gateway.cancel_order(req) + + def check_cancel_request(self, local_orderid: str): + """ + """ + if local_orderid not in self.cancel_request_buf: + return + + req = self.cancel_request_buf.pop(local_orderid) + self.gateway.cancel_order(req) diff --git a/vnpy/trader/setting.py b/vnpy/trader/setting.py index ae2011f9..b27dbf96 100644 --- a/vnpy/trader/setting.py +++ b/vnpy/trader/setting.py @@ -23,10 +23,17 @@ SETTINGS = { "email.receiver": "", "rqdata.username": "", - "rqdata.password": "" + "rqdata.password": "", + "database": { + "driver": "sqlite", # sqlite, mysql, postgresql + "database": "{VNPY_TEMP}/database.db", # for sqlite, use this as filepath + "host": "localhost", + "port": 3306, + "user": "root", + "password": "" + } } - # Load global setting from json file. SETTING_FILENAME = "vt_setting.json" SETTINGS.update(load_json(SETTING_FILENAME)) diff --git a/vnpy/trader/ui/mainwindow.py b/vnpy/trader/ui/mainwindow.py index d8c6280b..b1972f42 100644 --- a/vnpy/trader/ui/mainwindow.py +++ b/vnpy/trader/ui/mainwindow.py @@ -24,7 +24,7 @@ from .widget import ( AboutDialog, ) from ..engine import MainEngine -from ..utility import get_icon_path, TRADER_PATH +from ..utility import get_icon_path, TRADER_DIR class MainWindow(QtWidgets.QMainWindow): @@ -38,7 +38,7 @@ class MainWindow(QtWidgets.QMainWindow): self.main_engine = main_engine self.event_engine = event_engine - self.window_title = f"VN Trader [{TRADER_PATH}]" + self.window_title = f"VN Trader [{TRADER_DIR}]" self.connect_dialogs = {} self.widgets = {} diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index ca788227..216dcbae 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -3,6 +3,7 @@ General utility functions. """ import json +import os from pathlib import Path from typing import Callable @@ -12,25 +13,13 @@ import talib from .object import BarData, TickData -class Singleton(type): - """ - Singleton metaclass, - - class A: - __metaclass__ = Singleton - """ - _instances = {} - - def __call__(cls, *args, **kwargs): - """""" - if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__( - *args, **kwargs - ) - return cls._instances[cls] +def resolve_path(pattern: str): + env = dict(os.environ) + env.update({"VNPY_TEMP": str(TEMP_DIR)}) + return pattern.format(**env) -def get_path(temp_name: str): +def _get_trader_dir(temp_name: str): """ Get path where trader is running in. """ @@ -53,21 +42,21 @@ def get_path(temp_name: str): return home_path, temp_path -TRADER_PATH, TEMP_PATH = get_path(".vntrader") +TRADER_DIR, TEMP_DIR = _get_trader_dir(".vntrader") def get_file_path(filename: str): """ Get path for temp file with filename. """ - return TEMP_PATH.joinpath(filename) + return TEMP_DIR.joinpath(filename) def get_folder_path(folder_name: str): """ Get path for temp folder with folder name. """ - folder_path = TEMP_PATH.joinpath(folder_name) + folder_path = TEMP_DIR.joinpath(folder_name) if not folder_path.exists(): folder_path.mkdir() return folder_path @@ -385,3 +374,12 @@ class ArrayManager(object): if array: return up, down return up[-1], down[-1] + + +def virtual(func: "callable"): + """ + mark a function as "virtual", which means that this function can be override. + any base class should use this or @abstractmethod to decorate all functions + that can be (re)implemented by subclasses. + """ + return func