Merge branch 'DEV' into master
This commit is contained in:
commit
fbb0249187
@ -7,6 +7,7 @@
|
||||
|
||||
### 使用VNConda
|
||||
|
||||
|
||||
#### 1.下载VNConda (Python 3.7 64位)
|
||||
|
||||
下载地址如下:[VNConda-2.0.1-Windows-x86_64](https://conda.vnpy.com/VNConda-2.0.1-Windows-x86_64.exe)
|
||||
@ -44,6 +45,7 @@
|
||||
|
||||
|
||||
|
||||
|
||||
### 手动安装
|
||||
|
||||
#### 1.下载并安装最新版Anaconda3.7 64位
|
||||
|
@ -1,4 +1,5 @@
|
||||
PyQt5<5.12
|
||||
pyqtgraph
|
||||
dataclasses; python_version<="3.6"
|
||||
qdarkstyle
|
||||
requests
|
||||
|
189
tests/backtesting/genetic_algorithm.ipynb
Normal file
189
tests/backtesting/genetic_algorithm.ipynb
Normal file
@ -0,0 +1,189 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import random\n",
|
||||
"import multiprocessing\n",
|
||||
"import numpy as np\n",
|
||||
"from deap import creator, base, tools, algorithms\n",
|
||||
"from vnpy.app.cta_strategy.backtesting import BacktestingEngine\n",
|
||||
"from boll_channel_strategy import BollChannelStrategy\n",
|
||||
"from datetime import datetime\n",
|
||||
"import multiprocessing #多进程\n",
|
||||
"from scoop import futures #多进程"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def parameter_generate():\n",
|
||||
" '''\n",
|
||||
" 根据设置的起始值,终止值和步进,随机生成待优化的策略参数\n",
|
||||
" '''\n",
|
||||
" parameter_list = []\n",
|
||||
" p1 = random.randrange(4,50,2) #布林带窗口\n",
|
||||
" p2 = random.randrange(4,50,2) #布林带通道阈值\n",
|
||||
" p3 = random.randrange(4,50,2) #CCI窗口\n",
|
||||
" p4 = random.randrange(18,40,2) #ATR窗口 \n",
|
||||
"\n",
|
||||
" parameter_list.append(p1)\n",
|
||||
" parameter_list.append(p2)\n",
|
||||
" parameter_list.append(p3)\n",
|
||||
" parameter_list.append(p4)\n",
|
||||
"\n",
|
||||
" return parameter_list"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def object_func(strategy_avg):\n",
|
||||
" \"\"\"\n",
|
||||
" 本函数为优化目标函数,根据随机生成的策略参数,运行回测后自动返回2个结果指标:收益回撤比和夏普比率\n",
|
||||
" \"\"\"\n",
|
||||
" # 创建回测引擎对象\n",
|
||||
" engine = BacktestingEngine()\n",
|
||||
" engine.set_parameters(\n",
|
||||
" vt_symbol=\"IF88.CFFEX\",\n",
|
||||
" interval=\"1m\",\n",
|
||||
" start=datetime(2018, 9, 1),\n",
|
||||
" end=datetime(2019, 1,1),\n",
|
||||
" rate=0,\n",
|
||||
" slippage=0,\n",
|
||||
" size=300,\n",
|
||||
" pricetick=0.2,\n",
|
||||
" capital=1_000_000,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" setting = {'boll_window': strategy_avg[0], #布林带窗口\n",
|
||||
" 'boll_dev': strategy_avg[1], #布林带通道阈值\n",
|
||||
" 'cci_window': strategy_avg[2], #CCI窗口\n",
|
||||
" 'atr_window': strategy_avg[3],} #ATR窗口 \n",
|
||||
"\n",
|
||||
" #加载策略 \n",
|
||||
" #engine.initStrategy(TurtleTradingStrategy, setting)\n",
|
||||
" engine.add_strategy(BollChannelStrategy, setting)\n",
|
||||
" engine.load_data()\n",
|
||||
" engine.run_backtesting()\n",
|
||||
" engine.calculate_result()\n",
|
||||
" result = engine.calculate_statistics(Output=False)\n",
|
||||
"\n",
|
||||
" return_drawdown_ratio = round(result['return_drawdown_ratio'],2) #收益回撤比\n",
|
||||
" sharpe_ratio= round(result['sharpe_ratio'],2) #夏普比率\n",
|
||||
" return return_drawdown_ratio , sharpe_ratio"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"#设置优化方向:最大化收益回撤比,最大化夏普比率\n",
|
||||
"creator.create(\"FitnessMulti\", base.Fitness, weights=(1.0, 1.0)) # 1.0 求最大值;-1.0 求最小值\n",
|
||||
"creator.create(\"Individual\", list, fitness=creator.FitnessMulti)\n",
|
||||
"\n",
|
||||
"def optimize():\n",
|
||||
" \"\"\"\"\"\" \n",
|
||||
" toolbox = base.Toolbox() #Toolbox是deap库内置的工具箱,里面包含遗传算法中所用到的各种函数\n",
|
||||
"\n",
|
||||
" # 初始化 \n",
|
||||
" toolbox.register(\"individual\", tools.initIterate, creator.Individual,parameter_generate) # 注册个体:随机生成的策略参数parameter_generate() \n",
|
||||
" toolbox.register(\"population\", tools.initRepeat, list, toolbox.individual) #注册种群:个体形成种群 \n",
|
||||
" toolbox.register(\"mate\", tools.cxTwoPoint) #注册交叉:两点交叉 \n",
|
||||
" toolbox.register(\"mutate\", tools.mutUniformInt,low = 4,up = 40,indpb=0.6) #注册变异:随机生成一定区间内的整数\n",
|
||||
" toolbox.register(\"evaluate\", object_func) #注册评估:优化目标函数object_func() \n",
|
||||
" toolbox.register(\"select\", tools.selNSGA2) #注册选择:NSGA-II(带精英策略的非支配排序的遗传算法)\n",
|
||||
" #pool = multiprocessing.Pool()\n",
|
||||
" #toolbox.register(\"map\", pool.map)\n",
|
||||
" #toolbox.register(\"map\", futures.map)\n",
|
||||
"\n",
|
||||
" #遗传算法参数设置\n",
|
||||
" MU = 40 #设置每一代选择的个体数\n",
|
||||
" LAMBDA = 160 #设置每一代产生的子女数\n",
|
||||
" pop = toolbox.population(400) #设置族群里面的个体数量\n",
|
||||
" CXPB, MUTPB, NGEN = 0.5, 0.35,40 #分别为种群内部个体的交叉概率、变异概率、产生种群代数\n",
|
||||
" hof = tools.ParetoFront() #解的集合:帕累托前沿(非占优最优集)\n",
|
||||
"\n",
|
||||
" #解的集合的描述统计信息\n",
|
||||
" #集合内平均值,标准差,最小值,最大值可以体现集合的收敛程度\n",
|
||||
" #收敛程度低可以增加算法的迭代次数\n",
|
||||
" stats = tools.Statistics(lambda ind: ind.fitness.values)\n",
|
||||
" np.set_printoptions(suppress=True) #对numpy默认输出的科学计数法转换\n",
|
||||
" stats.register(\"mean\", np.mean, axis=0) #统计目标优化函数结果的平均值\n",
|
||||
" stats.register(\"std\", np.std, axis=0) #统计目标优化函数结果的标准差\n",
|
||||
" stats.register(\"min\", np.min, axis=0) #统计目标优化函数结果的最小值\n",
|
||||
" stats.register(\"max\", np.max, axis=0) #统计目标优化函数结果的最大值\n",
|
||||
"\n",
|
||||
" #运行算法\n",
|
||||
" algorithms.eaMuPlusLambda(pop, toolbox, MU, LAMBDA, CXPB, MUTPB, NGEN, stats,\n",
|
||||
" halloffame=hof) #esMuPlusLambda是一种基于(μ+λ)选择策略的多目标优化分段遗传算法\n",
|
||||
"\n",
|
||||
" return pop"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"optimize()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.1"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
"nav_menu": {},
|
||||
"number_sections": true,
|
||||
"sideBar": true,
|
||||
"skip_h1_title": false,
|
||||
"title_cell": "Table of Contents",
|
||||
"title_sidebar": "Contents",
|
||||
"toc_cell": false,
|
||||
"toc_position": {},
|
||||
"toc_section_display": true,
|
||||
"toc_window_display": false
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -10,9 +10,13 @@ from vnpy.gateway.ib import IbGateway
|
||||
from vnpy.gateway.ctp import CtpGateway
|
||||
from vnpy.gateway.tiger import TigerGateway
|
||||
from vnpy.gateway.oes import OesGateway
|
||||
from vnpy.gateway.okex import OkexGateway
|
||||
from vnpy.gateway.huobi import HuobiGateway
|
||||
|
||||
from vnpy.app.cta_strategy import CtaStrategyApp
|
||||
from vnpy.app.csv_loader import CsvLoaderApp
|
||||
from vnpy.app.algo_trading import AlgoTradingApp
|
||||
from vnpy.app.cta_backtester import CtaBacktesterApp
|
||||
|
||||
|
||||
def main():
|
||||
@ -28,9 +32,13 @@ def main():
|
||||
main_engine.add_gateway(BitmexGateway)
|
||||
main_engine.add_gateway(TigerGateway)
|
||||
main_engine.add_gateway(OesGateway)
|
||||
main_engine.add_gateway(OkexGateway)
|
||||
main_engine.add_gateway(HuobiGateway)
|
||||
|
||||
main_engine.add_app(CtaStrategyApp)
|
||||
main_engine.add_app(CtaBacktesterApp)
|
||||
main_engine.add_app(CsvLoaderApp)
|
||||
main_engine.add_app(AlgoTradingApp)
|
||||
|
||||
main_window = MainWindow(main_engine, event_engine)
|
||||
main_window.showMaximized()
|
||||
|
@ -48,14 +48,18 @@ class WebsocketClient(object):
|
||||
|
||||
self.proxy_host = None
|
||||
self.proxy_port = None
|
||||
self.ping_interval = 60 # seconds
|
||||
|
||||
# For debugging
|
||||
self._last_sent_text = None
|
||||
self._last_received_text = None
|
||||
|
||||
def init(self, host: str, proxy_host: str = "", proxy_port: int = 0):
|
||||
""""""
|
||||
def init(self, host: str, proxy_host: str = "", proxy_port: int = 0, ping_interval: int = 60):
|
||||
"""
|
||||
:param ping_interval: unit: seconds, type: int
|
||||
"""
|
||||
self.host = host
|
||||
self.ping_interval = ping_interval # seconds
|
||||
|
||||
if proxy_host and proxy_port:
|
||||
self.proxy_host = proxy_host
|
||||
@ -206,7 +210,7 @@ class WebsocketClient(object):
|
||||
et, ev, tb = sys.exc_info()
|
||||
self.on_error(et, ev, tb)
|
||||
self._reconnect()
|
||||
for i in range(60):
|
||||
for i in range(self.ping_interval):
|
||||
if not self._active:
|
||||
break
|
||||
sleep(1)
|
||||
|
18
vnpy/app/algo_trading/__init__.py
Normal file
18
vnpy/app/algo_trading/__init__.py
Normal file
@ -0,0 +1,18 @@
|
||||
from pathlib import Path
|
||||
|
||||
from vnpy.trader.app import BaseApp
|
||||
|
||||
from .engine import AlgoEngine, APP_NAME
|
||||
from .template import AlgoTemplate
|
||||
|
||||
|
||||
class AlgoTradingApp(BaseApp):
|
||||
""""""
|
||||
|
||||
app_name = APP_NAME
|
||||
app_module = __module__
|
||||
app_path = Path(__file__).parent
|
||||
display_name = "算法交易"
|
||||
engine_class = AlgoEngine
|
||||
widget_name = "AlgoManager"
|
||||
icon_name = "algo.ico"
|
0
vnpy/app/algo_trading/algos/__init__.py
Normal file
0
vnpy/app/algo_trading/algos/__init__.py
Normal file
137
vnpy/app/algo_trading/algos/iceberg_algo.py
Normal file
137
vnpy/app/algo_trading/algos/iceberg_algo.py
Normal file
@ -0,0 +1,137 @@
|
||||
from vnpy.trader.constant import Offset, Direction
|
||||
from vnpy.trader.object import TradeData, OrderData, TickData
|
||||
from vnpy.trader.engine import BaseEngine
|
||||
|
||||
from vnpy.app.algo_trading import AlgoTemplate
|
||||
|
||||
|
||||
class IcebergAlgo(AlgoTemplate):
|
||||
""""""
|
||||
|
||||
display_name = "Iceberg 冰山"
|
||||
|
||||
default_setting = {
|
||||
"vt_symbol": "",
|
||||
"direction": [Direction.LONG.value, Direction.SHORT.value],
|
||||
"price": 0.0,
|
||||
"volume": 0.0,
|
||||
"display_volume": 0.0,
|
||||
"interval": 0,
|
||||
"offset": [
|
||||
Offset.NONE.value,
|
||||
Offset.OPEN.value,
|
||||
Offset.CLOSE.value,
|
||||
Offset.CLOSETODAY.value,
|
||||
Offset.CLOSEYESTERDAY.value
|
||||
]
|
||||
}
|
||||
|
||||
variables = [
|
||||
"traded",
|
||||
"timer_count",
|
||||
"vt_orderid"
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
algo_engine: BaseEngine,
|
||||
algo_name: str,
|
||||
setting: dict
|
||||
):
|
||||
""""""
|
||||
super().__init__(algo_engine, algo_name, setting)
|
||||
|
||||
# Parameters
|
||||
self.vt_symbol = setting["vt_symbol"]
|
||||
self.direction = Direction(setting["direction"])
|
||||
self.price = setting["price"]
|
||||
self.volume = setting["volume"]
|
||||
self.display_volume = setting["display_volume"]
|
||||
self.interval = setting["interval"]
|
||||
self.offset = Offset(setting["offset"])
|
||||
|
||||
# Variables
|
||||
self.timer_count = 0
|
||||
self.vt_orderid = ""
|
||||
self.traded = 0
|
||||
|
||||
self.last_tick = None
|
||||
|
||||
self.subscribe(self.vt_symbol)
|
||||
self.put_parameters_event()
|
||||
self.put_variables_event()
|
||||
|
||||
def on_stop(self):
|
||||
""""""
|
||||
self.write_log("停止算法")
|
||||
|
||||
def on_tick(self, tick: TickData):
|
||||
""""""
|
||||
self.last_tick = tick
|
||||
|
||||
def on_order(self, order: OrderData):
|
||||
""""""
|
||||
msg = f"委托号:{order.vt_orderid},委托状态:{order.status.value}"
|
||||
self.write_log(msg)
|
||||
|
||||
if not order.is_active():
|
||||
self.vt_orderid = ""
|
||||
self.put_variables_event()
|
||||
|
||||
def on_trade(self, trade: TradeData):
|
||||
""""""
|
||||
self.traded += trade.volume
|
||||
|
||||
if self.traded >= self.volume:
|
||||
self.write_log(f"已交易数量:{self.traded},总数量:{self.volume}")
|
||||
self.stop()
|
||||
else:
|
||||
self.put_variables_event()
|
||||
|
||||
def on_timer(self):
|
||||
""""""
|
||||
self.timer_count += 1
|
||||
|
||||
if self.timer_count < self.interval:
|
||||
self.put_variables_event()
|
||||
return
|
||||
|
||||
self.timer_count = 0
|
||||
|
||||
contract = self.get_contract(self.vt_symbol)
|
||||
if not contract:
|
||||
return
|
||||
|
||||
# If order already finished, just send new order
|
||||
if not self.vt_orderid:
|
||||
order_volume = self.volume - self.traded
|
||||
order_volume = min(order_volume, self.display_volume)
|
||||
|
||||
if self.direction == Direction.LONG:
|
||||
self.vt_orderid = self.buy(
|
||||
self.vt_symbol,
|
||||
self.price,
|
||||
order_volume,
|
||||
offset=self.offset
|
||||
)
|
||||
else:
|
||||
self.vt_orderid = self.sell(
|
||||
self.vt_symbol,
|
||||
self.price,
|
||||
order_volume,
|
||||
offset=self.offset
|
||||
)
|
||||
# Otherwise check for cancel
|
||||
else:
|
||||
if self.direction == Direction.LONG:
|
||||
if self.last_tick.ask_price_1 <= self.price:
|
||||
self.cancel_order(self.vt_orderid)
|
||||
self.vt_orderid = ""
|
||||
self.write_log(u"最新Tick卖一价,低于买入委托价格,之前委托可能丢失,强制撤单")
|
||||
else:
|
||||
if self.last_tick.bid_price_1 >= self.price:
|
||||
self.cancel_order(self.vt_orderid)
|
||||
self.vt_orderid = ""
|
||||
self.write_log(u"最新Tick买一价,高于卖出委托价格,之前委托可能丢失,强制撤单")
|
||||
|
||||
self.put_variables_event()
|
101
vnpy/app/algo_trading/algos/sniper_algo.py
Normal file
101
vnpy/app/algo_trading/algos/sniper_algo.py
Normal file
@ -0,0 +1,101 @@
|
||||
from vnpy.trader.constant import Offset, Direction
|
||||
from vnpy.trader.object import TradeData, OrderData, TickData
|
||||
from vnpy.trader.engine import BaseEngine
|
||||
|
||||
from vnpy.app.algo_trading import AlgoTemplate
|
||||
|
||||
|
||||
class SniperAlgo(AlgoTemplate):
|
||||
""""""
|
||||
|
||||
display_name = "Sniper 狙击手"
|
||||
|
||||
default_setting = {
|
||||
"vt_symbol": "",
|
||||
"direction": [Direction.LONG.value, Direction.SHORT.value],
|
||||
"price": 0.0,
|
||||
"volume": 0.0,
|
||||
"offset": [
|
||||
Offset.NONE.value,
|
||||
Offset.OPEN.value,
|
||||
Offset.CLOSE.value,
|
||||
Offset.CLOSETODAY.value,
|
||||
Offset.CLOSEYESTERDAY.value
|
||||
]
|
||||
}
|
||||
|
||||
variables = [
|
||||
"traded",
|
||||
"vt_orderid"
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
algo_engine: BaseEngine,
|
||||
algo_name: str,
|
||||
setting: dict
|
||||
):
|
||||
""""""
|
||||
super().__init__(algo_engine, algo_name, setting)
|
||||
|
||||
# Parameters
|
||||
self.vt_symbol = setting["vt_symbol"]
|
||||
self.direction = Direction(setting["direction"])
|
||||
self.price = setting["price"]
|
||||
self.volume = setting["volume"]
|
||||
self.offset = Offset(setting["offset"])
|
||||
|
||||
# Variables
|
||||
self.vt_orderid = ""
|
||||
self.traded = 0
|
||||
|
||||
self.subscribe(self.vt_symbol)
|
||||
self.put_parameters_event()
|
||||
self.put_variables_event()
|
||||
|
||||
def on_tick(self, tick: TickData):
|
||||
""""""
|
||||
if self.vt_orderid:
|
||||
self.cancel_all()
|
||||
return
|
||||
|
||||
if self.direction == Direction.LONG:
|
||||
if tick.ask_price_1 <= self.price:
|
||||
order_volume = self.volume - self.traded
|
||||
order_volume = min(order_volume, tick.ask_volume_1)
|
||||
|
||||
self.vt_orderid = self.buy(
|
||||
self.vt_symbol,
|
||||
self.price,
|
||||
order_volume,
|
||||
offset=self.offset
|
||||
)
|
||||
else:
|
||||
if tick.bid_price_1 >= self.price:
|
||||
order_volume = self.volume - self.traded
|
||||
order_volume = min(order_volume, tick.bid_volume_1)
|
||||
|
||||
self.vt_orderid = self.sell(
|
||||
self.vt_symbol,
|
||||
self.price,
|
||||
order_volume,
|
||||
offset=self.offset
|
||||
)
|
||||
|
||||
self.put_variables_event()
|
||||
|
||||
def on_order(self, order: OrderData):
|
||||
""""""
|
||||
if not order.is_active():
|
||||
self.vt_orderid = ""
|
||||
self.put_variables_event()
|
||||
|
||||
def on_trade(self, trade: TradeData):
|
||||
""""""
|
||||
self.traded += trade.volume
|
||||
|
||||
if self.traded >= self.volume:
|
||||
self.write_log(f"已交易数量:{self.traded},总数量:{self.volume}")
|
||||
self.stop()
|
||||
else:
|
||||
self.put_variables_event()
|
105
vnpy/app/algo_trading/algos/twap_algo.py
Normal file
105
vnpy/app/algo_trading/algos/twap_algo.py
Normal file
@ -0,0 +1,105 @@
|
||||
from vnpy.trader.constant import Offset, Direction
|
||||
from vnpy.trader.object import TradeData
|
||||
from vnpy.trader.engine import BaseEngine
|
||||
|
||||
from vnpy.app.algo_trading import AlgoTemplate
|
||||
|
||||
|
||||
class TwapAlgo(AlgoTemplate):
|
||||
""""""
|
||||
|
||||
display_name = "TWAP 时间加权平均"
|
||||
|
||||
default_setting = {
|
||||
"vt_symbol": "",
|
||||
"direction": [Direction.LONG.value, Direction.SHORT.value],
|
||||
"price": 0.0,
|
||||
"volume": 0.0,
|
||||
"time": 600,
|
||||
"interval": 60,
|
||||
"offset": [
|
||||
Offset.NONE.value,
|
||||
Offset.OPEN.value,
|
||||
Offset.CLOSE.value,
|
||||
Offset.CLOSETODAY.value,
|
||||
Offset.CLOSEYESTERDAY.value
|
||||
]
|
||||
}
|
||||
|
||||
variables = [
|
||||
"traded",
|
||||
"order_volume",
|
||||
"timer_count",
|
||||
"total_count"
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
algo_engine: BaseEngine,
|
||||
algo_name: str,
|
||||
setting: dict
|
||||
):
|
||||
""""""
|
||||
super().__init__(algo_engine, algo_name, setting)
|
||||
|
||||
# Parameters
|
||||
self.vt_symbol = setting["vt_symbol"]
|
||||
self.direction = Direction(setting["direction"])
|
||||
self.price = setting["price"]
|
||||
self.volume = setting["volume"]
|
||||
self.time = setting["time"]
|
||||
self.interval = setting["interval"]
|
||||
self.offset = Offset(setting["offset"])
|
||||
|
||||
# Variables
|
||||
self.order_volume = self.volume / (self.time / self.interval)
|
||||
self.timer_count = 0
|
||||
self.total_count = 0
|
||||
self.traded = 0
|
||||
|
||||
self.subscribe(self.vt_symbol)
|
||||
self.put_parameters_event()
|
||||
self.put_variables_event()
|
||||
|
||||
def on_trade(self, trade: TradeData):
|
||||
""""""
|
||||
self.traded += trade.volume
|
||||
|
||||
if self.traded >= self.volume:
|
||||
self.write_log(f"已交易数量:{self.traded},总数量:{self.volume}")
|
||||
self.stop()
|
||||
else:
|
||||
self.put_variables_event()
|
||||
|
||||
def on_timer(self):
|
||||
""""""
|
||||
self.timer_count += 1
|
||||
self.total_count += 1
|
||||
self.put_variables_event()
|
||||
|
||||
if self.total_count >= self.time:
|
||||
self.write_log("执行时间已结束,停止算法")
|
||||
self.stop()
|
||||
return
|
||||
|
||||
if self.timer_count < self.interval:
|
||||
return
|
||||
self.timer_count = 0
|
||||
|
||||
tick = self.get_tick(self.vt_symbol)
|
||||
if not tick:
|
||||
return
|
||||
|
||||
self.cancel_all()
|
||||
|
||||
left_volume = self.volume - self.traded
|
||||
order_volume = min(self.order_volume, left_volume)
|
||||
|
||||
if self.direction == Direction.LONG:
|
||||
if tick.ask_price_1 <= self.price:
|
||||
self.buy(self.vt_symbol, self.price,
|
||||
order_volume, offset=self.offset)
|
||||
else:
|
||||
if tick.bid_price_1 >= self.price:
|
||||
self.sell(self.vt_symbol, self.price,
|
||||
order_volume, offset=self.offset)
|
265
vnpy/app/algo_trading/engine.py
Normal file
265
vnpy/app/algo_trading/engine.py
Normal file
@ -0,0 +1,265 @@
|
||||
|
||||
from vnpy.event import EventEngine, Event
|
||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||
from vnpy.trader.event import (
|
||||
EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE)
|
||||
from vnpy.trader.constant import (Direction, Offset, OrderType)
|
||||
from vnpy.trader.object import (SubscribeRequest, OrderRequest)
|
||||
from vnpy.trader.utility import load_json, save_json
|
||||
|
||||
from .template import AlgoTemplate
|
||||
|
||||
|
||||
APP_NAME = "AlgoTrading"
|
||||
|
||||
EVENT_ALGO_LOG = "eAlgoLog"
|
||||
EVENT_ALGO_SETTING = "eAlgoSetting"
|
||||
EVENT_ALGO_VARIABLES = "eAlgoVariables"
|
||||
EVENT_ALGO_PARAMETERS = "eAlgoParameters"
|
||||
|
||||
|
||||
class AlgoEngine(BaseEngine):
|
||||
""""""
|
||||
setting_filename = "algo_trading_setting.json"
|
||||
|
||||
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||
"""Constructor"""
|
||||
super().__init__(main_engine, event_engine, APP_NAME)
|
||||
|
||||
self.algos = {}
|
||||
self.symbol_algo_map = {}
|
||||
self.orderid_algo_map = {}
|
||||
|
||||
self.algo_templates = {}
|
||||
self.algo_settings = {}
|
||||
|
||||
self.load_algo_template()
|
||||
self.register_event()
|
||||
|
||||
def init_engine(self):
|
||||
""""""
|
||||
self.write_log("算法交易引擎启动")
|
||||
self.load_algo_setting()
|
||||
|
||||
def load_algo_template(self):
|
||||
""""""
|
||||
from .algos.twap_algo import TwapAlgo
|
||||
from .algos.iceberg_algo import IcebergAlgo
|
||||
from .algos.sniper_algo import SniperAlgo
|
||||
|
||||
self.add_algo_template(TwapAlgo)
|
||||
self.add_algo_template(IcebergAlgo)
|
||||
self.add_algo_template(SniperAlgo)
|
||||
|
||||
def add_algo_template(self, template: AlgoTemplate):
|
||||
""""""
|
||||
self.algo_templates[template.__name__] = template
|
||||
|
||||
def load_algo_setting(self):
|
||||
""""""
|
||||
self.algo_settings = load_json(self.setting_filename)
|
||||
|
||||
for setting_name, setting in self.algo_settings.items():
|
||||
self.put_setting_event(setting_name, setting)
|
||||
|
||||
self.write_log("算法配置载入成功")
|
||||
|
||||
def save_algo_setting(self):
|
||||
""""""
|
||||
save_json(self.setting_filename, self.algo_settings)
|
||||
|
||||
def register_event(self):
|
||||
""""""
|
||||
self.event_engine.register(EVENT_TICK, self.process_tick_event)
|
||||
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
|
||||
self.event_engine.register(EVENT_ORDER, self.process_order_event)
|
||||
self.event_engine.register(EVENT_TRADE, self.process_trade_event)
|
||||
|
||||
def process_tick_event(self, event: Event):
|
||||
""""""
|
||||
tick = event.data
|
||||
|
||||
algos = self.symbol_algo_map.get(tick.vt_symbol, None)
|
||||
if algos:
|
||||
for algo in algos:
|
||||
algo.update_tick(tick)
|
||||
|
||||
def process_timer_event(self, event: Event):
|
||||
""""""
|
||||
for algo in self.algos.values():
|
||||
algo.update_timer()
|
||||
|
||||
def process_trade_event(self, event: Event):
|
||||
""""""
|
||||
trade = event.data
|
||||
|
||||
algo = self.orderid_algo_map.get(trade.vt_orderid, None)
|
||||
if algo:
|
||||
algo.update_trade(trade)
|
||||
|
||||
def process_order_event(self, event: Event):
|
||||
""""""
|
||||
order = event.data
|
||||
|
||||
algo = self.orderid_algo_map.get(order.vt_orderid, None)
|
||||
if algo:
|
||||
algo.update_order(order)
|
||||
|
||||
def start_algo(self, setting: dict):
|
||||
""""""
|
||||
template_name = setting["template_name"]
|
||||
algo_template = self.algo_templates[template_name]
|
||||
|
||||
algo = algo_template.new(self, setting)
|
||||
algo.start()
|
||||
|
||||
self.algos[algo.algo_name] = algo
|
||||
return algo.algo_name
|
||||
|
||||
def stop_algo(self, algo_name: str):
|
||||
""""""
|
||||
algo = self.algos.get(algo_name, None)
|
||||
if algo:
|
||||
algo.stop()
|
||||
self.algos.pop(algo_name)
|
||||
|
||||
def stop_all(self):
|
||||
""""""
|
||||
for algo_name in list(self.algos.keys()):
|
||||
self.stop_algo(algo_name)
|
||||
|
||||
def subscribe(self, algo: AlgoTemplate, vt_symbol: str):
|
||||
""""""
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
if not contract:
|
||||
self.write_log(f'订阅行情失败,找不到合约:{vt_symbol}', algo)
|
||||
return
|
||||
|
||||
algos = self.symbol_algo_map.setdefault(vt_symbol, set())
|
||||
|
||||
if not algos:
|
||||
req = SubscribeRequest(
|
||||
symbol=contract.symbol,
|
||||
exchange=contract.exchange
|
||||
)
|
||||
self.main_engine.subscribe(req, contract.gateway_name)
|
||||
|
||||
algos.add(algo)
|
||||
|
||||
def send_order(
|
||||
self,
|
||||
algo: AlgoTemplate,
|
||||
vt_symbol: str,
|
||||
direction: Direction,
|
||||
price: float,
|
||||
volume: float,
|
||||
order_type: OrderType,
|
||||
offset: Offset
|
||||
):
|
||||
""""""
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
if not contract:
|
||||
self.write_log(f'委托下单失败,找不到合约:{vt_symbol}', algo)
|
||||
return
|
||||
|
||||
req = OrderRequest(
|
||||
symbol=contract.symbol,
|
||||
exchange=contract.exchange,
|
||||
direction=direction,
|
||||
type=order_type,
|
||||
volume=volume,
|
||||
price=price,
|
||||
offset=offset
|
||||
)
|
||||
vt_orderid = self.main_engine.send_order(req, contract.gateway_name)
|
||||
|
||||
self.orderid_algo_map[vt_orderid] = algo
|
||||
return vt_orderid
|
||||
|
||||
def cancel_order(self, algo: AlgoTemplate, vt_orderid: str):
|
||||
""""""
|
||||
order = self.main_engine.get_order(vt_orderid)
|
||||
|
||||
if not order:
|
||||
self.write_log(f"委托撤单失败,找不到委托:{vt_orderid}", algo)
|
||||
return
|
||||
|
||||
req = order.create_cancel_request()
|
||||
self.main_engine.cancel_order(req, order.gateway_name)
|
||||
|
||||
def get_tick(self, algo: AlgoTemplate, vt_symbol: str):
|
||||
""""""
|
||||
tick = self.main_engine.get_tick(vt_symbol)
|
||||
|
||||
if not tick:
|
||||
self.write_log(f"查询行情失败,找不到行情:{vt_symbol}", algo)
|
||||
|
||||
return tick
|
||||
|
||||
def get_contract(self, algo: AlgoTemplate, vt_symbol: str):
|
||||
""""""
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
|
||||
if not contract:
|
||||
self.write_log(f"查询合约失败,找不到合约:{vt_symbol}", algo)
|
||||
|
||||
return contract
|
||||
|
||||
def write_log(self, msg: str, algo: AlgoTemplate = None):
|
||||
""""""
|
||||
if algo:
|
||||
msg = f"{algo.algo_name}:{msg}"
|
||||
|
||||
event = Event(EVENT_ALGO_LOG)
|
||||
event.data = msg
|
||||
self.event_engine.put(event)
|
||||
|
||||
def put_setting_event(self, setting_name: str, setting: dict):
|
||||
""""""
|
||||
event = Event(EVENT_ALGO_SETTING)
|
||||
event.data = {
|
||||
"setting_name": setting_name,
|
||||
"setting": setting
|
||||
}
|
||||
self.event_engine.put(event)
|
||||
|
||||
def update_algo_setting(self, setting_name: str, setting: dict):
|
||||
""""""
|
||||
self.algo_settings[setting_name] = setting
|
||||
|
||||
self.save_algo_setting()
|
||||
|
||||
self.put_setting_event(setting_name, setting)
|
||||
|
||||
def remove_algo_setting(self, setting_name: str):
|
||||
""""""
|
||||
if setting_name not in self.algo_settings:
|
||||
return
|
||||
self.algo_settings.pop(setting_name)
|
||||
|
||||
event = Event(EVENT_ALGO_SETTING)
|
||||
event.data = {
|
||||
"setting_name": setting_name,
|
||||
"setting": None
|
||||
}
|
||||
self.event_engine.put(event)
|
||||
|
||||
self.save_algo_setting()
|
||||
|
||||
def put_parameters_event(self, algo: AlgoTemplate, parameters: dict):
|
||||
""""""
|
||||
event = Event(EVENT_ALGO_PARAMETERS)
|
||||
event.data = {
|
||||
"algo_name": algo.algo_name,
|
||||
"parameters": parameters
|
||||
}
|
||||
self.event_engine.put(event)
|
||||
|
||||
def put_variables_event(self, algo: AlgoTemplate, variables: dict):
|
||||
""""""
|
||||
event = Event(EVENT_ALGO_VARIABLES)
|
||||
event.data = {
|
||||
"algo_name": algo.algo_name,
|
||||
"variables": variables
|
||||
}
|
||||
self.event_engine.put(event)
|
193
vnpy/app/algo_trading/template.py
Normal file
193
vnpy/app/algo_trading/template.py
Normal file
@ -0,0 +1,193 @@
|
||||
from vnpy.trader.engine import BaseEngine
|
||||
from vnpy.trader.object import TickData, OrderData, TradeData
|
||||
from vnpy.trader.constant import OrderType, Offset, Direction
|
||||
from vnpy.trader.utility import virtual
|
||||
|
||||
|
||||
class AlgoTemplate:
|
||||
""""""
|
||||
|
||||
_count = 0
|
||||
display_name = ""
|
||||
default_setting = {}
|
||||
variables = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
algo_engine: BaseEngine,
|
||||
algo_name: str,
|
||||
setting: dict
|
||||
):
|
||||
"""Constructor"""
|
||||
self.algo_engine = algo_engine
|
||||
self.algo_name = algo_name
|
||||
|
||||
self.active = False
|
||||
self.active_orders = {} # vt_orderid:order
|
||||
|
||||
self.variables.insert(0, "active")
|
||||
|
||||
@classmethod
|
||||
def new(cls, algo_engine: BaseEngine, setting: dict):
|
||||
"""Create new algo instance"""
|
||||
cls._count += 1
|
||||
algo_name = f"{cls.__name__}_{cls._count}"
|
||||
algo = cls(algo_engine, algo_name, setting)
|
||||
return algo
|
||||
|
||||
def update_tick(self, tick: TickData):
|
||||
""""""
|
||||
if self.active:
|
||||
self.on_tick(tick)
|
||||
|
||||
def update_order(self, order: OrderData):
|
||||
""""""
|
||||
if self.active:
|
||||
if order.is_active():
|
||||
self.active_orders[order.vt_orderid] = order
|
||||
elif order.vt_orderid in self.active_orders:
|
||||
self.active_orders.pop(order.vt_orderid)
|
||||
|
||||
self.on_order(order)
|
||||
|
||||
def update_trade(self, trade: TradeData):
|
||||
""""""
|
||||
if self.active:
|
||||
self.on_trade(trade)
|
||||
|
||||
def update_timer(self):
|
||||
""""""
|
||||
if self.active:
|
||||
self.on_timer()
|
||||
|
||||
def on_start(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_stop(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_tick(self, tick: TickData):
|
||||
""""""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_order(self, order: OrderData):
|
||||
""""""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_trade(self, trade: TradeData):
|
||||
""""""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_timer(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def start(self):
|
||||
""""""
|
||||
self.active = True
|
||||
self.on_start()
|
||||
self.put_variables_event()
|
||||
|
||||
def stop(self):
|
||||
""""""
|
||||
self.active = False
|
||||
self.cancel_all()
|
||||
self.on_stop()
|
||||
self.put_variables_event()
|
||||
|
||||
self.write_log("停止算法")
|
||||
|
||||
def subscribe(self, vt_symbol):
|
||||
""""""
|
||||
self.algo_engine.subscribe(self, vt_symbol)
|
||||
|
||||
def buy(
|
||||
self,
|
||||
vt_symbol,
|
||||
price,
|
||||
volume,
|
||||
order_type: OrderType = OrderType.LIMIT,
|
||||
offset: Offset = Offset.NONE
|
||||
):
|
||||
""""""
|
||||
msg = f"委托买入{vt_symbol}:{volume}@{price}"
|
||||
self.write_log(msg)
|
||||
|
||||
return self.algo_engine.send_order(
|
||||
self,
|
||||
vt_symbol,
|
||||
Direction.LONG,
|
||||
price,
|
||||
volume,
|
||||
order_type,
|
||||
offset
|
||||
)
|
||||
|
||||
def sell(
|
||||
self,
|
||||
vt_symbol,
|
||||
price,
|
||||
volume,
|
||||
order_type: OrderType = OrderType.LIMIT,
|
||||
offset: Offset = Offset.NONE
|
||||
):
|
||||
""""""
|
||||
msg = f"委托卖出{vt_symbol}:{volume}@{price}"
|
||||
self.write_log(msg)
|
||||
|
||||
return self.algo_engine.send_order(
|
||||
self,
|
||||
vt_symbol,
|
||||
Direction.SHORT,
|
||||
price,
|
||||
volume,
|
||||
order_type,
|
||||
offset
|
||||
)
|
||||
|
||||
def cancel_order(self, vt_orderid: str):
|
||||
""""""
|
||||
self.algo_engine.cancel_order(self, vt_orderid)
|
||||
|
||||
def cancel_all(self):
|
||||
""""""
|
||||
if not self.active_orders:
|
||||
return
|
||||
|
||||
for vt_orderid in self.active_orders.keys():
|
||||
self.cancel_order(vt_orderid)
|
||||
|
||||
def get_tick(self, vt_symbol: str):
|
||||
""""""
|
||||
return self.algo_engine.get_tick(self, vt_symbol)
|
||||
|
||||
def get_contract(self, vt_symbol: str):
|
||||
""""""
|
||||
return self.algo_engine.get_contract(self, vt_symbol)
|
||||
|
||||
def write_log(self, msg: str):
|
||||
""""""
|
||||
self.algo_engine.write_log(msg, self)
|
||||
|
||||
def put_parameters_event(self):
|
||||
""""""
|
||||
parameters = {}
|
||||
for name in self.default_setting.keys():
|
||||
parameters[name] = getattr(self, name)
|
||||
|
||||
self.algo_engine.put_parameters_event(self, parameters)
|
||||
|
||||
def put_variables_event(self):
|
||||
""""""
|
||||
variables = {}
|
||||
for name in self.variables:
|
||||
variables[name] = getattr(self, name)
|
||||
|
||||
self.algo_engine.put_variables_event(self, variables)
|
1
vnpy/app/algo_trading/ui/__init__.py
Normal file
1
vnpy/app/algo_trading/ui/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .widget import AlgoManager
|
BIN
vnpy/app/algo_trading/ui/algo.ico
Normal file
BIN
vnpy/app/algo_trading/ui/algo.ico
Normal file
Binary file not shown.
After Width: | Height: | Size: 66 KiB |
16
vnpy/app/algo_trading/ui/display.py
Normal file
16
vnpy/app/algo_trading/ui/display.py
Normal file
@ -0,0 +1,16 @@
|
||||
NAME_DISPLAY_MAP = {
|
||||
"vt_symbol": "本地代码",
|
||||
"direction": "方向",
|
||||
"price": "价格",
|
||||
"volume": "数量",
|
||||
"time": "执行时间(秒)",
|
||||
"interval": "每轮间隔(秒)",
|
||||
"offset": "开平",
|
||||
"active": "算法状态",
|
||||
"traded": "成交数量",
|
||||
"order_volume": "单笔委托",
|
||||
"timer_count": "本轮读秒",
|
||||
"total_count": "累计读秒",
|
||||
"template_name": "算法模板",
|
||||
"display_volume": "挂出数量"
|
||||
}
|
575
vnpy/app/algo_trading/ui/widget.py
Normal file
575
vnpy/app/algo_trading/ui/widget.py
Normal file
@ -0,0 +1,575 @@
|
||||
"""
|
||||
Widget for algo trading.
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from datetime import datetime
|
||||
|
||||
from vnpy.event import EventEngine, Event
|
||||
from vnpy.trader.engine import MainEngine
|
||||
from vnpy.trader.ui import QtWidgets, QtCore
|
||||
|
||||
from ..engine import (
|
||||
AlgoEngine,
|
||||
AlgoTemplate,
|
||||
APP_NAME,
|
||||
EVENT_ALGO_LOG,
|
||||
EVENT_ALGO_PARAMETERS,
|
||||
EVENT_ALGO_VARIABLES,
|
||||
EVENT_ALGO_SETTING
|
||||
)
|
||||
from .display import NAME_DISPLAY_MAP
|
||||
|
||||
|
||||
class AlgoWidget(QtWidgets.QWidget):
|
||||
"""
|
||||
Start connection of a certain gateway.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
algo_engine: AlgoEngine,
|
||||
algo_template: AlgoTemplate
|
||||
):
|
||||
""""""
|
||||
super().__init__()
|
||||
|
||||
self.algo_engine = algo_engine
|
||||
self.template_name = algo_template.__name__
|
||||
self.default_setting = algo_template.default_setting
|
||||
|
||||
self.widgets = {}
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
"""
|
||||
Initialize line edits and form layout based on setting.
|
||||
"""
|
||||
self.setMaximumWidth(400)
|
||||
|
||||
form = QtWidgets.QFormLayout()
|
||||
|
||||
for field_name, field_value in self.default_setting.items():
|
||||
field_type = type(field_value)
|
||||
|
||||
if field_type == list:
|
||||
widget = QtWidgets.QComboBox()
|
||||
widget.addItems(field_value)
|
||||
else:
|
||||
widget = QtWidgets.QLineEdit()
|
||||
|
||||
display_name = NAME_DISPLAY_MAP.get(field_name, field_name)
|
||||
|
||||
form.addRow(display_name, widget)
|
||||
self.widgets[field_name] = (widget, field_type)
|
||||
|
||||
start_algo_button = QtWidgets.QPushButton("启动算法")
|
||||
start_algo_button.clicked.connect(self.start_algo)
|
||||
form.addRow(start_algo_button)
|
||||
|
||||
form.addRow(QtWidgets.QLabel(""))
|
||||
|
||||
self.setting_name_line = QtWidgets.QLineEdit()
|
||||
form.addRow("配置名称", self.setting_name_line)
|
||||
|
||||
save_setting_button = QtWidgets.QPushButton("保存配置")
|
||||
save_setting_button.clicked.connect(self.save_setting)
|
||||
form.addRow(save_setting_button)
|
||||
|
||||
self.setLayout(form)
|
||||
|
||||
def get_setting(self):
|
||||
"""
|
||||
Get setting value from line edits.
|
||||
"""
|
||||
setting = {"template_name": self.template_name}
|
||||
|
||||
for field_name, tp in self.widgets.items():
|
||||
widget, field_type = tp
|
||||
if field_type == list:
|
||||
field_value = str(widget.currentText())
|
||||
else:
|
||||
try:
|
||||
field_value = field_type(widget.text())
|
||||
except ValueError:
|
||||
display_name = NAME_DISPLAY_MAP.get(field_name, field_name)
|
||||
QtWidgets.QMessageBox.warning(
|
||||
self,
|
||||
"参数错误",
|
||||
f"{display_name}参数类型应为{field_type},请检查!"
|
||||
)
|
||||
return None
|
||||
|
||||
setting[field_name] = field_value
|
||||
|
||||
return setting
|
||||
|
||||
def start_algo(self):
|
||||
"""
|
||||
Start algo trading.
|
||||
"""
|
||||
setting = self.get_setting()
|
||||
if setting:
|
||||
self.algo_engine.start_algo(setting)
|
||||
|
||||
def update_setting(self, setting_name: str, setting: dict):
|
||||
"""
|
||||
Update setting into widgets.
|
||||
"""
|
||||
self.setting_name_line.setText(setting_name)
|
||||
|
||||
for name, tp in self.widgets.items():
|
||||
widget, _ = tp
|
||||
value = setting[name]
|
||||
|
||||
if isinstance(widget, QtWidgets.QLineEdit):
|
||||
widget.setText(str(value))
|
||||
elif isinstance(widget, QtWidgets.QComboBox):
|
||||
ix = widget.findText(value)
|
||||
widget.setCurrentIndex(ix)
|
||||
|
||||
def save_setting(self):
|
||||
"""
|
||||
Save algo setting
|
||||
"""
|
||||
setting_name = self.setting_name_line.text()
|
||||
if not setting_name:
|
||||
return
|
||||
|
||||
setting = self.get_setting()
|
||||
if setting:
|
||||
self.algo_engine.update_algo_setting(setting_name, setting)
|
||||
|
||||
|
||||
class AlgoMonitor(QtWidgets.QTableWidget):
|
||||
""""""
|
||||
parameters_signal = QtCore.pyqtSignal(Event)
|
||||
variables_signal = QtCore.pyqtSignal(Event)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
algo_engine: AlgoEngine,
|
||||
event_engine: EventEngine,
|
||||
mode_active: bool
|
||||
):
|
||||
""""""
|
||||
super().__init__()
|
||||
|
||||
self.algo_engine = algo_engine
|
||||
self.event_engine = event_engine
|
||||
self.mode_active = mode_active
|
||||
|
||||
self.algo_cells = {}
|
||||
|
||||
self.init_ui()
|
||||
self.register_event()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
labels = [
|
||||
"",
|
||||
"算法",
|
||||
"参数",
|
||||
"状态"
|
||||
]
|
||||
self.setColumnCount(len(labels))
|
||||
self.setHorizontalHeaderLabels(labels)
|
||||
self.verticalHeader().setVisible(False)
|
||||
self.setEditTriggers(self.NoEditTriggers)
|
||||
|
||||
self.verticalHeader().setSectionResizeMode(
|
||||
QtWidgets.QHeaderView.ResizeToContents
|
||||
)
|
||||
|
||||
for column in range(2, 4):
|
||||
self.horizontalHeader().setSectionResizeMode(
|
||||
column,
|
||||
QtWidgets.QHeaderView.Stretch
|
||||
)
|
||||
self.setWordWrap(True)
|
||||
|
||||
if not self.mode_active:
|
||||
self.hideColumn(0)
|
||||
|
||||
def register_event(self):
|
||||
""""""
|
||||
self.parameters_signal.connect(self.process_parameters_event)
|
||||
self.variables_signal.connect(self.process_variables_event)
|
||||
|
||||
self.event_engine.register(
|
||||
EVENT_ALGO_PARAMETERS, self.parameters_signal.emit)
|
||||
self.event_engine.register(
|
||||
EVENT_ALGO_VARIABLES, self.variables_signal.emit)
|
||||
|
||||
def process_parameters_event(self, event):
|
||||
""""""
|
||||
data = event.data
|
||||
algo_name = data["algo_name"]
|
||||
parameters = data["parameters"]
|
||||
|
||||
cells = self.get_algo_cells(algo_name)
|
||||
text = to_text(parameters)
|
||||
cells["parameters"].setText(text)
|
||||
|
||||
def process_variables_event(self, event):
|
||||
""""""
|
||||
data = event.data
|
||||
algo_name = data["algo_name"]
|
||||
variables = data["variables"]
|
||||
|
||||
cells = self.get_algo_cells(algo_name)
|
||||
variables_cell = cells["variables"]
|
||||
text = to_text(variables)
|
||||
variables_cell.setText(text)
|
||||
|
||||
row = self.row(variables_cell)
|
||||
active = variables["active"]
|
||||
|
||||
if self.mode_active:
|
||||
if active:
|
||||
self.showRow(row)
|
||||
else:
|
||||
self.hideRow(row)
|
||||
else:
|
||||
if active:
|
||||
self.hideRow(row)
|
||||
else:
|
||||
self.showRow(row)
|
||||
|
||||
def stop_algo(self, algo_name: str):
|
||||
""""""
|
||||
self.algo_engine.stop_algo(algo_name)
|
||||
|
||||
def get_algo_cells(self, algo_name: str):
|
||||
""""""
|
||||
cells = self.algo_cells.get(algo_name, None)
|
||||
|
||||
if not cells:
|
||||
stop_func = partial(self.stop_algo, algo_name=algo_name)
|
||||
stop_button = QtWidgets.QPushButton("停止")
|
||||
stop_button.clicked.connect(stop_func)
|
||||
|
||||
name_cell = QtWidgets.QTableWidgetItem(algo_name)
|
||||
parameters_cell = QtWidgets.QTableWidgetItem()
|
||||
variables_cell = QtWidgets.QTableWidgetItem()
|
||||
|
||||
self.insertRow(0)
|
||||
self.setCellWidget(0, 0, stop_button)
|
||||
self.setItem(0, 1, name_cell)
|
||||
self.setItem(0, 2, parameters_cell)
|
||||
self.setItem(0, 3, variables_cell)
|
||||
|
||||
cells = {
|
||||
"name": name_cell,
|
||||
"parameters": parameters_cell,
|
||||
"variables": variables_cell
|
||||
}
|
||||
self.algo_cells[algo_name] = cells
|
||||
|
||||
return cells
|
||||
|
||||
|
||||
class ActiveAlgoMonitor(AlgoMonitor):
|
||||
"""
|
||||
Monitor for active algos.
|
||||
"""
|
||||
|
||||
def __init__(self, algo_engine: AlgoEngine, event_engine: EventEngine):
|
||||
""""""
|
||||
super().__init__(algo_engine, event_engine, True)
|
||||
|
||||
|
||||
class InactiveAlgoMonitor(AlgoMonitor):
|
||||
"""
|
||||
Monitor for inactive algos.
|
||||
"""
|
||||
|
||||
def __init__(self, algo_engine: AlgoEngine, event_engine: EventEngine):
|
||||
""""""
|
||||
super().__init__(algo_engine, event_engine, False)
|
||||
|
||||
|
||||
class SettingMonitor(QtWidgets.QTableWidget):
|
||||
""""""
|
||||
setting_signal = QtCore.pyqtSignal(Event)
|
||||
use_signal = QtCore.pyqtSignal(dict)
|
||||
|
||||
def __init__(self, algo_engine: AlgoEngine, event_engine: EventEngine):
|
||||
""""""
|
||||
super().__init__()
|
||||
|
||||
self.algo_engine = algo_engine
|
||||
self.event_engine = event_engine
|
||||
|
||||
self.settings = {}
|
||||
self.setting_cells = {}
|
||||
|
||||
self.init_ui()
|
||||
self.register_event()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
labels = [
|
||||
"",
|
||||
"",
|
||||
"名称",
|
||||
"配置"
|
||||
]
|
||||
self.setColumnCount(len(labels))
|
||||
self.setHorizontalHeaderLabels(labels)
|
||||
self.verticalHeader().setVisible(False)
|
||||
self.setEditTriggers(self.NoEditTriggers)
|
||||
|
||||
self.verticalHeader().setSectionResizeMode(
|
||||
QtWidgets.QHeaderView.ResizeToContents
|
||||
)
|
||||
|
||||
self.horizontalHeader().setSectionResizeMode(
|
||||
3,
|
||||
QtWidgets.QHeaderView.Stretch
|
||||
)
|
||||
self.setWordWrap(True)
|
||||
|
||||
def register_event(self):
|
||||
""""""
|
||||
self.setting_signal.connect(self.process_setting_event)
|
||||
|
||||
self.event_engine.register(
|
||||
EVENT_ALGO_SETTING, self.setting_signal.emit)
|
||||
|
||||
def process_setting_event(self, event):
|
||||
""""""
|
||||
data = event.data
|
||||
setting_name = data["setting_name"]
|
||||
setting = data["setting"]
|
||||
cells = self.get_setting_cells(setting_name)
|
||||
|
||||
if setting:
|
||||
self.settings[setting_name] = setting
|
||||
|
||||
cells["setting"].setText(to_text(setting))
|
||||
else:
|
||||
if setting_name in self.settings:
|
||||
self.settings.pop(setting_name)
|
||||
|
||||
row = self.row(cells["setting"])
|
||||
self.removeRow(row)
|
||||
|
||||
self.setting_cells.pop(setting_name)
|
||||
|
||||
def get_setting_cells(self, setting_name: str):
|
||||
""""""
|
||||
cells = self.setting_cells.get(setting_name, None)
|
||||
|
||||
if not cells:
|
||||
use_func = partial(self.use_setting, setting_name=setting_name)
|
||||
use_button = QtWidgets.QPushButton("使用")
|
||||
use_button.clicked.connect(use_func)
|
||||
|
||||
remove_func = partial(self.remove_setting,
|
||||
setting_name=setting_name)
|
||||
remove_button = QtWidgets.QPushButton("移除")
|
||||
remove_button.clicked.connect(remove_func)
|
||||
|
||||
name_cell = QtWidgets.QTableWidgetItem(setting_name)
|
||||
setting_cell = QtWidgets.QTableWidgetItem()
|
||||
|
||||
self.insertRow(0)
|
||||
self.setCellWidget(0, 0, use_button)
|
||||
self.setCellWidget(0, 1, remove_button)
|
||||
self.setItem(0, 2, name_cell)
|
||||
self.setItem(0, 3, setting_cell)
|
||||
|
||||
cells = {
|
||||
"name": name_cell,
|
||||
"setting": setting_cell
|
||||
}
|
||||
self.setting_cells[setting_name] = cells
|
||||
|
||||
return cells
|
||||
|
||||
def use_setting(self, setting_name: str):
|
||||
""""""
|
||||
setting = self.settings[setting_name]
|
||||
setting["setting_name"] = setting_name
|
||||
self.use_signal.emit(setting)
|
||||
|
||||
def remove_setting(self, setting_name: str):
|
||||
""""""
|
||||
self.algo_engine.remove_algo_setting(setting_name)
|
||||
|
||||
|
||||
class LogMonitor(QtWidgets.QTableWidget):
|
||||
""""""
|
||||
signal = QtCore.pyqtSignal(Event)
|
||||
|
||||
def __init__(self, event_engine: EventEngine):
|
||||
""""""
|
||||
super().__init__()
|
||||
|
||||
self.event_engine = event_engine
|
||||
|
||||
self.init_ui()
|
||||
self.register_event()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
labels = [
|
||||
"时间",
|
||||
"信息"
|
||||
]
|
||||
self.setColumnCount(len(labels))
|
||||
self.setHorizontalHeaderLabels(labels)
|
||||
self.verticalHeader().setVisible(False)
|
||||
self.setEditTriggers(self.NoEditTriggers)
|
||||
|
||||
self.verticalHeader().setSectionResizeMode(
|
||||
QtWidgets.QHeaderView.ResizeToContents
|
||||
)
|
||||
|
||||
self.horizontalHeader().setSectionResizeMode(
|
||||
1,
|
||||
QtWidgets.QHeaderView.Stretch
|
||||
)
|
||||
self.setWordWrap(True)
|
||||
|
||||
def register_event(self):
|
||||
""""""
|
||||
self.signal.connect(self.process_log_event)
|
||||
|
||||
self.event_engine.register(EVENT_ALGO_LOG, self.signal.emit)
|
||||
|
||||
def process_log_event(self, event):
|
||||
""""""
|
||||
msg = event.data
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
|
||||
timestamp_cell = QtWidgets.QTableWidgetItem(timestamp)
|
||||
msg_cell = QtWidgets.QTableWidgetItem(msg)
|
||||
|
||||
self.insertRow(0)
|
||||
self.setItem(0, 0, timestamp_cell)
|
||||
self.setItem(0, 1, msg_cell)
|
||||
|
||||
|
||||
class AlgoManager(QtWidgets.QWidget):
|
||||
""""""
|
||||
|
||||
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||
""""""
|
||||
super().__init__()
|
||||
|
||||
self.main_engine = main_engine
|
||||
self.event_engine = event_engine
|
||||
self.algo_engine = main_engine.get_engine(APP_NAME)
|
||||
|
||||
self.algo_widgets = {}
|
||||
|
||||
self.init_ui()
|
||||
self.algo_engine.init_engine()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
self.setWindowTitle("算法交易")
|
||||
|
||||
# Left side control widgets
|
||||
self.template_combo = QtWidgets.QComboBox()
|
||||
self.template_combo.currentIndexChanged.connect(self.show_algo_widget)
|
||||
|
||||
form = QtWidgets.QFormLayout()
|
||||
form.addRow("算法", self.template_combo)
|
||||
widget = QtWidgets.QWidget()
|
||||
widget.setLayout(form)
|
||||
|
||||
vbox = QtWidgets.QVBoxLayout()
|
||||
vbox.addWidget(widget)
|
||||
|
||||
for algo_template in self.algo_engine.algo_templates.values():
|
||||
widget = AlgoWidget(self.algo_engine, algo_template)
|
||||
vbox.addWidget(widget)
|
||||
|
||||
template_name = algo_template.__name__
|
||||
display_name = algo_template.display_name
|
||||
|
||||
self.algo_widgets[template_name] = widget
|
||||
self.template_combo.addItem(display_name, template_name)
|
||||
|
||||
vbox.addStretch()
|
||||
|
||||
stop_all_button = QtWidgets.QPushButton("全部停止")
|
||||
stop_all_button.setFixedHeight(stop_all_button.sizeHint().height() * 2)
|
||||
stop_all_button.clicked.connect(self.algo_engine.stop_all)
|
||||
|
||||
vbox.addWidget(stop_all_button)
|
||||
|
||||
# Right side monitor widgets
|
||||
active_algo_monitor = ActiveAlgoMonitor(
|
||||
self.algo_engine, self.event_engine
|
||||
)
|
||||
inactive_algo_monitor = InactiveAlgoMonitor(
|
||||
self.algo_engine, self.event_engine
|
||||
)
|
||||
tab1 = QtWidgets.QTabWidget()
|
||||
tab1.addTab(active_algo_monitor, "执行中")
|
||||
tab1.addTab(inactive_algo_monitor, "已结束")
|
||||
|
||||
log_monitor = LogMonitor(self.event_engine)
|
||||
tab2 = QtWidgets.QTabWidget()
|
||||
tab2.addTab(log_monitor, "日志")
|
||||
|
||||
setting_monitor = SettingMonitor(self.algo_engine, self.event_engine)
|
||||
setting_monitor.use_signal.connect(self.use_setting)
|
||||
tab3 = QtWidgets.QTabWidget()
|
||||
tab3.addTab(setting_monitor, "配置")
|
||||
|
||||
grid = QtWidgets.QGridLayout()
|
||||
grid.addWidget(tab1, 0, 0, 1, 2)
|
||||
grid.addWidget(tab2, 1, 0)
|
||||
grid.addWidget(tab3, 1, 1)
|
||||
|
||||
hbox2 = QtWidgets.QHBoxLayout()
|
||||
hbox2.addLayout(vbox)
|
||||
hbox2.addLayout(grid)
|
||||
self.setLayout(hbox2)
|
||||
|
||||
self.show_algo_widget()
|
||||
|
||||
def show_algo_widget(self):
|
||||
""""""
|
||||
ix = self.template_combo.currentIndex()
|
||||
current_name = self.template_combo.itemData(ix)
|
||||
|
||||
for template_name, widget in self.algo_widgets.items():
|
||||
if template_name == current_name:
|
||||
widget.show()
|
||||
else:
|
||||
widget.hide()
|
||||
|
||||
def use_setting(self, setting: dict):
|
||||
""""""
|
||||
setting_name = setting["setting_name"]
|
||||
template_name = setting["template_name"]
|
||||
|
||||
widget = self.algo_widgets[template_name]
|
||||
widget.update_setting(setting_name, setting)
|
||||
|
||||
ix = self.template_combo.findData(template_name)
|
||||
self.template_combo.setCurrentIndex(ix)
|
||||
self.show_algo_widget()
|
||||
|
||||
def show(self):
|
||||
""""""
|
||||
self.showMaximized()
|
||||
|
||||
|
||||
def to_text(data: dict):
|
||||
"""
|
||||
Convert dict data into string.
|
||||
"""
|
||||
buf = []
|
||||
for key, value in data.items():
|
||||
key = NAME_DISPLAY_MAP.get(key, key)
|
||||
buf.append(f"{key}:{value}")
|
||||
text = ",".join(buf)
|
||||
return text
|
@ -23,9 +23,11 @@ Sample csv file:
|
||||
import csv
|
||||
from datetime import datetime
|
||||
|
||||
from peewee import chunked
|
||||
|
||||
from vnpy.event import EventEngine
|
||||
from vnpy.trader.constant import Exchange, Interval
|
||||
from vnpy.trader.database import DbBarData
|
||||
from vnpy.trader.database import DbBarData, DB
|
||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||
|
||||
|
||||
@ -39,17 +41,17 @@ class CsvLoaderEngine(BaseEngine):
|
||||
""""""
|
||||
super().__init__(main_engine, event_engine, APP_NAME)
|
||||
|
||||
self.file_path: str = ''
|
||||
self.file_path: str = ""
|
||||
|
||||
self.symbol: str = ""
|
||||
self.exchange: Exchange = Exchange.SSE
|
||||
self.interval: Interval = Interval.MINUTE
|
||||
self.datetime_head: str = ''
|
||||
self.open_head: str = ''
|
||||
self.close_head: str = ''
|
||||
self.low_head: str = ''
|
||||
self.high_head: str = ''
|
||||
self.volume_head: str = ''
|
||||
self.datetime_head: str = ""
|
||||
self.open_head: str = ""
|
||||
self.close_head: str = ""
|
||||
self.low_head: str = ""
|
||||
self.high_head: str = ""
|
||||
self.volume_head: str = ""
|
||||
|
||||
def load(
|
||||
self,
|
||||
@ -72,33 +74,40 @@ class CsvLoaderEngine(BaseEngine):
|
||||
end = None
|
||||
count = 0
|
||||
|
||||
with open(file_path, 'rt') as f:
|
||||
with open(file_path, "rt") as f:
|
||||
reader = csv.DictReader(f)
|
||||
|
||||
db_bars = []
|
||||
|
||||
for item in reader:
|
||||
db_bar = DbBarData()
|
||||
dt = datetime.strptime(item[datetime_head], datetime_format)
|
||||
|
||||
db_bar.symbol = symbol
|
||||
db_bar.exchange = exchange.value
|
||||
db_bar.datetime = datetime.strptime(
|
||||
item[datetime_head], datetime_format
|
||||
)
|
||||
db_bar.interval = interval.value
|
||||
db_bar.volume = item[volume_head]
|
||||
db_bar.open_price = item[open_head]
|
||||
db_bar.high_price = item[high_head]
|
||||
db_bar.low_price = item[low_head]
|
||||
db_bar.close_price = item[close_head]
|
||||
db_bar.vt_symbol = vt_symbol
|
||||
db_bar.gateway_name = "DB"
|
||||
db_bar = {
|
||||
"symbol": symbol,
|
||||
"exchange": exchange.value,
|
||||
"datetime": dt,
|
||||
"interval": interval.value,
|
||||
"volume": item[volume_head],
|
||||
"open_price": item[open_head],
|
||||
"high_price": item[high_head],
|
||||
"low_price": item[low_head],
|
||||
"close_price": item[close_head],
|
||||
"vt_symbol": vt_symbol,
|
||||
"gateway_name": "DB"
|
||||
}
|
||||
|
||||
db_bar.replace()
|
||||
db_bars.append(db_bar)
|
||||
|
||||
# do some statistics
|
||||
count += 1
|
||||
if not start:
|
||||
start = db_bar.datetime
|
||||
start = db_bar["datetime"]
|
||||
|
||||
end = db_bar.datetime
|
||||
end = db_bar["datetime"]
|
||||
|
||||
# Insert into DB
|
||||
with DB.atomic():
|
||||
for batch in chunked(db_bars, 50):
|
||||
DbBarData.insert_many(batch).on_conflict_replace().execute()
|
||||
|
||||
return start, end, count
|
||||
|
@ -90,7 +90,8 @@ class CsvLoaderWidget(QtWidgets.QWidget):
|
||||
|
||||
def select_file(self):
|
||||
""""""
|
||||
result: str = QtWidgets.QFileDialog.getOpenFileName(self)
|
||||
result: str = QtWidgets.QFileDialog.getOpenFileName(
|
||||
self, filter="CSV (*.csv)")
|
||||
filename = result[0]
|
||||
if filename:
|
||||
self.file_edit.setText(filename)
|
||||
|
17
vnpy/app/cta_backtester/__init__.py
Normal file
17
vnpy/app/cta_backtester/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
from pathlib import Path
|
||||
|
||||
from vnpy.trader.app import BaseApp
|
||||
|
||||
from .engine import BacktesterEngine, APP_NAME
|
||||
|
||||
|
||||
class CtaBacktesterApp(BaseApp):
|
||||
""""""
|
||||
|
||||
app_name = APP_NAME
|
||||
app_module = __module__
|
||||
app_path = Path(__file__).parent
|
||||
display_name = "CTA回测"
|
||||
engine_class = BacktesterEngine
|
||||
widget_name = "BacktesterManager"
|
||||
icon_name = "backtester.ico"
|
200
vnpy/app/cta_backtester/engine.py
Normal file
200
vnpy/app/cta_backtester/engine.py
Normal file
@ -0,0 +1,200 @@
|
||||
import os
|
||||
import importlib
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from threading import Thread
|
||||
from pathlib import Path
|
||||
|
||||
from vnpy.event import Event, EventEngine
|
||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||
from vnpy.app.cta_strategy import (
|
||||
CtaTemplate,
|
||||
BacktestingEngine
|
||||
)
|
||||
|
||||
|
||||
APP_NAME = "CtaBacktester"
|
||||
|
||||
EVENT_BACKTESTER_LOG = "eBacktesterLog"
|
||||
EVENT_BACKTESTER_BACKTESTING_FINISHED = "eBacktesterBacktestingFinished"
|
||||
EVENT_BACKTESTER_OPTIMIZATION_FINISHED = "eBacktesterOptimizationFinished"
|
||||
|
||||
|
||||
class BacktesterEngine(BaseEngine):
|
||||
"""
|
||||
For running CTA strategy backtesting.
|
||||
"""
|
||||
|
||||
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||
""""""
|
||||
super().__init__(main_engine, event_engine, APP_NAME)
|
||||
|
||||
self.classes = {}
|
||||
self.backtesting_engine = None
|
||||
self.thread = None
|
||||
|
||||
self.result_df = None
|
||||
self.result_statistics = None
|
||||
|
||||
self.load_strategy_class()
|
||||
|
||||
def init_engine(self):
|
||||
""""""
|
||||
self.write_log("初始化CTA回测引擎")
|
||||
|
||||
self.backtesting_engine = BacktestingEngine()
|
||||
# Redirect log from backtesting engine outside.
|
||||
self.backtesting_engine.output = self.write_log
|
||||
|
||||
self.write_log("策略文件加载完成")
|
||||
|
||||
def write_log(self, msg: str):
|
||||
""""""
|
||||
event = Event(EVENT_BACKTESTER_LOG)
|
||||
event.data = msg
|
||||
self.event_engine.put(event)
|
||||
|
||||
def load_strategy_class(self):
|
||||
"""
|
||||
Load strategy class from source code.
|
||||
"""
|
||||
app_path = Path(__file__).parent.parent
|
||||
path1 = app_path.joinpath("cta_strategy", "strategies")
|
||||
self.load_strategy_class_from_folder(
|
||||
path1, "vnpy.app.cta_strategy.strategies")
|
||||
|
||||
path2 = Path.cwd().joinpath("strategies")
|
||||
self.load_strategy_class_from_folder(path2, "strategies")
|
||||
|
||||
def load_strategy_class_from_folder(self, path: Path, module_name: str = ""):
|
||||
"""
|
||||
Load strategy class from certain folder.
|
||||
"""
|
||||
for dirpath, dirnames, filenames in os.walk(path):
|
||||
for filename in filenames:
|
||||
if filename.endswith(".py"):
|
||||
strategy_module_name = ".".join(
|
||||
[module_name, filename.replace(".py", "")])
|
||||
self.load_strategy_class_from_module(strategy_module_name)
|
||||
|
||||
def load_strategy_class_from_module(self, module_name: str):
|
||||
"""
|
||||
Load strategy class from module file.
|
||||
"""
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
for name in dir(module):
|
||||
value = getattr(module, name)
|
||||
if (isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate):
|
||||
self.classes[value.__name__] = value
|
||||
except: # noqa
|
||||
msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}"
|
||||
self.write_log(msg)
|
||||
|
||||
def get_strategy_class_names(self):
|
||||
""""""
|
||||
return list(self.classes.keys())
|
||||
|
||||
def run_backtesting(
|
||||
self,
|
||||
class_name: str,
|
||||
vt_symbol: str,
|
||||
interval: str,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
rate: float,
|
||||
slippage: float,
|
||||
size: int,
|
||||
pricetick: float,
|
||||
capital: int,
|
||||
setting: dict
|
||||
):
|
||||
""""""
|
||||
self.result_df = None
|
||||
self.result_statistics = None
|
||||
|
||||
engine = self.backtesting_engine
|
||||
engine.clear_data()
|
||||
|
||||
engine.set_parameters(
|
||||
vt_symbol=vt_symbol,
|
||||
interval=interval,
|
||||
start=start,
|
||||
end=end,
|
||||
rate=rate,
|
||||
slippage=slippage,
|
||||
size=size,
|
||||
pricetick=pricetick,
|
||||
capital=capital
|
||||
)
|
||||
|
||||
strategy_class = self.classes[class_name]
|
||||
engine.add_strategy(
|
||||
strategy_class,
|
||||
setting
|
||||
)
|
||||
|
||||
engine.load_data()
|
||||
engine.run_backtesting()
|
||||
self.result_df = engine.calculate_result()
|
||||
self.result_statistics = engine.calculate_statistics(output=False)
|
||||
|
||||
# Clear thread object handler.
|
||||
self.thread = None
|
||||
|
||||
# Put backtesting done event
|
||||
event = Event(EVENT_BACKTESTER_BACKTESTING_FINISHED)
|
||||
self.event_engine.put(event)
|
||||
|
||||
def start_backtesting(
|
||||
self,
|
||||
class_name: str,
|
||||
vt_symbol: str,
|
||||
interval: str,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
rate: float,
|
||||
slippage: float,
|
||||
size: int,
|
||||
pricetick: float,
|
||||
capital: int,
|
||||
setting: dict
|
||||
):
|
||||
if self.thread:
|
||||
self.write_log("已有回测在运行中,请等待完成")
|
||||
return False
|
||||
|
||||
self.write_log("-" * 40)
|
||||
self.thread = Thread(
|
||||
target=self.run_backtesting,
|
||||
args=(
|
||||
class_name,
|
||||
vt_symbol,
|
||||
interval,
|
||||
start,
|
||||
end,
|
||||
rate,
|
||||
slippage,
|
||||
size,
|
||||
pricetick,
|
||||
capital,
|
||||
setting
|
||||
)
|
||||
)
|
||||
self.thread.start()
|
||||
|
||||
return True
|
||||
|
||||
def get_result_df(self):
|
||||
""""""
|
||||
return self.result_df
|
||||
|
||||
def get_result_statistics(self):
|
||||
""""""
|
||||
return self.result_statistics
|
||||
|
||||
def get_default_setting(self, class_name: str):
|
||||
""""""
|
||||
strategy_class = self.classes[class_name]
|
||||
return strategy_class.get_class_parameters()
|
1
vnpy/app/cta_backtester/ui/__init__.py
Normal file
1
vnpy/app/cta_backtester/ui/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .widget import BacktesterManager
|
BIN
vnpy/app/cta_backtester/ui/backtester.ico
Normal file
BIN
vnpy/app/cta_backtester/ui/backtester.ico
Normal file
Binary file not shown.
After Width: | Height: | Size: 63 KiB |
476
vnpy/app/cta_backtester/ui/widget.py
Normal file
476
vnpy/app/cta_backtester/ui/widget.py
Normal file
@ -0,0 +1,476 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pyqtgraph as pg
|
||||
import numpy as np
|
||||
|
||||
from vnpy.event import Event, EventEngine
|
||||
from vnpy.trader.ui import QtCore, QtWidgets, QtGui
|
||||
from vnpy.trader.engine import MainEngine
|
||||
from vnpy.trader.constant import Interval
|
||||
|
||||
from ..engine import (
|
||||
APP_NAME,
|
||||
EVENT_BACKTESTER_LOG,
|
||||
EVENT_BACKTESTER_BACKTESTING_FINISHED,
|
||||
EVENT_BACKTESTER_OPTIMIZATION_FINISHED
|
||||
)
|
||||
|
||||
|
||||
class BacktesterManager(QtWidgets.QWidget):
|
||||
""""""
|
||||
|
||||
signal_log = QtCore.pyqtSignal(Event)
|
||||
signal_backtesting_finished = QtCore.pyqtSignal(Event)
|
||||
signal_optimization_finished = QtCore.pyqtSignal(Event)
|
||||
|
||||
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||
""""""
|
||||
super().__init__()
|
||||
|
||||
self.main_engine = main_engine
|
||||
self.event_engine = event_engine
|
||||
|
||||
self.backtester_engine = main_engine.get_engine(APP_NAME)
|
||||
self.class_names = []
|
||||
self.settings = {}
|
||||
|
||||
self.init_strategy_settings()
|
||||
self.init_ui()
|
||||
self.register_event()
|
||||
self.backtester_engine.init_engine()
|
||||
|
||||
def init_strategy_settings(self):
|
||||
""""""
|
||||
self.class_names = self.backtester_engine.get_strategy_class_names()
|
||||
|
||||
for class_name in self.class_names:
|
||||
setting = self.backtester_engine.get_default_setting(class_name)
|
||||
self.settings[class_name] = setting
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
self.setWindowTitle("CTA回测")
|
||||
|
||||
# Setting Part
|
||||
self.class_combo = QtWidgets.QComboBox()
|
||||
self.class_combo.addItems(self.class_names)
|
||||
|
||||
self.symbol_line = QtWidgets.QLineEdit("IF88.CFFEX")
|
||||
|
||||
self.interval_combo = QtWidgets.QComboBox()
|
||||
for inteval in Interval:
|
||||
self.interval_combo.addItem(inteval.value)
|
||||
|
||||
end_dt = datetime.now()
|
||||
start_dt = end_dt - timedelta(days=3 * 365)
|
||||
|
||||
self.start_date_edit = QtWidgets.QDateEdit(
|
||||
QtCore.QDate(
|
||||
start_dt.year,
|
||||
start_dt.month,
|
||||
start_dt.day
|
||||
)
|
||||
)
|
||||
self.end_date_edit = QtWidgets.QDateEdit(
|
||||
QtCore.QDate.currentDate()
|
||||
)
|
||||
|
||||
self.rate_line = QtWidgets.QLineEdit("0.000025")
|
||||
self.slippage_line = QtWidgets.QLineEdit("0.2")
|
||||
self.size_line = QtWidgets.QLineEdit("300")
|
||||
self.pricetick_line = QtWidgets.QLineEdit("0.2")
|
||||
self.capital_line = QtWidgets.QLineEdit("1000000")
|
||||
|
||||
start_button = QtWidgets.QPushButton("开始回测")
|
||||
start_button.clicked.connect(self.start_backtesting)
|
||||
|
||||
form = QtWidgets.QFormLayout()
|
||||
form.addRow("交易策略", self.class_combo)
|
||||
form.addRow("本地代码", self.symbol_line)
|
||||
form.addRow("K线周期", self.interval_combo)
|
||||
form.addRow("开始日期", self.start_date_edit)
|
||||
form.addRow("结束日期", self.end_date_edit)
|
||||
form.addRow("手续费率", self.rate_line)
|
||||
form.addRow("交易滑点", self.slippage_line)
|
||||
form.addRow("合约乘数", self.size_line)
|
||||
form.addRow("价格跳动", self.pricetick_line)
|
||||
form.addRow("回测资金", self.capital_line)
|
||||
form.addRow(start_button)
|
||||
|
||||
# Result part
|
||||
self.statistics_monitor = StatisticsMonitor()
|
||||
|
||||
self.log_monitor = QtWidgets.QTextEdit()
|
||||
self.log_monitor.setMaximumHeight(400)
|
||||
|
||||
self.chart = BacktesterChart()
|
||||
self.chart.setMinimumWidth(1000)
|
||||
|
||||
# Layout
|
||||
vbox = QtWidgets.QVBoxLayout()
|
||||
vbox.addWidget(self.statistics_monitor)
|
||||
vbox.addWidget(self.log_monitor)
|
||||
|
||||
hbox = QtWidgets.QHBoxLayout()
|
||||
hbox.addLayout(form)
|
||||
hbox.addLayout(vbox)
|
||||
hbox.addWidget(self.chart)
|
||||
self.setLayout(hbox)
|
||||
|
||||
def register_event(self):
|
||||
""""""
|
||||
self.signal_log.connect(self.process_log_event)
|
||||
self.signal_backtesting_finished.connect(
|
||||
self.process_backtesting_finished_event)
|
||||
self.signal_optimization_finished.connect(
|
||||
self.process_optimization_finished_event)
|
||||
|
||||
self.event_engine.register(EVENT_BACKTESTER_LOG, self.signal_log.emit)
|
||||
self.event_engine.register(
|
||||
EVENT_BACKTESTER_BACKTESTING_FINISHED, self.signal_backtesting_finished.emit)
|
||||
self.event_engine.register(
|
||||
EVENT_BACKTESTER_OPTIMIZATION_FINISHED, self.signal_optimization_finished.emit)
|
||||
|
||||
def process_log_event(self, event: Event):
|
||||
""""""
|
||||
msg = event.data
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
msg = f"{timestamp}\t{msg}"
|
||||
self.log_monitor.append(msg)
|
||||
|
||||
def process_backtesting_finished_event(self, event: Event):
|
||||
""""""
|
||||
statistics = self.backtester_engine.get_result_statistics()
|
||||
self.statistics_monitor.set_data(statistics)
|
||||
|
||||
df = self.backtester_engine.get_result_df()
|
||||
self.chart.set_data(df)
|
||||
|
||||
def process_optimization_finished_event(self, event: Event):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def start_backtesting(self):
|
||||
""""""
|
||||
class_name = self.class_combo.currentText()
|
||||
vt_symbol = self.symbol_line.text()
|
||||
interval = self.interval_combo.currentText()
|
||||
start = self.start_date_edit.date().toPyDate()
|
||||
end = self.end_date_edit.date().toPyDate()
|
||||
rate = float(self.rate_line.text())
|
||||
slippage = float(self.slippage_line.text())
|
||||
size = float(self.size_line.text())
|
||||
pricetick = float(self.pricetick_line.text())
|
||||
capital = float(self.capital_line.text())
|
||||
|
||||
old_setting = self.settings[class_name]
|
||||
dialog = SettingEditor(class_name, old_setting)
|
||||
i = dialog.exec()
|
||||
if i != dialog.Accepted:
|
||||
return
|
||||
|
||||
new_setting = dialog.get_setting()
|
||||
self.settings[class_name] = new_setting
|
||||
|
||||
result = self.backtester_engine.start_backtesting(
|
||||
class_name,
|
||||
vt_symbol,
|
||||
interval,
|
||||
start,
|
||||
end,
|
||||
rate,
|
||||
slippage,
|
||||
size,
|
||||
pricetick,
|
||||
capital,
|
||||
new_setting
|
||||
)
|
||||
|
||||
if result:
|
||||
self.statistics_monitor.clear_data()
|
||||
self.chart.clear_data()
|
||||
|
||||
def show(self):
|
||||
""""""
|
||||
self.showMaximized()
|
||||
|
||||
|
||||
class StatisticsMonitor(QtWidgets.QTableWidget):
|
||||
""""""
|
||||
KEY_NAME_MAP = {
|
||||
"start_date": "首个交易日",
|
||||
"end_date": "最后交易日",
|
||||
|
||||
"total_days": "总交易日",
|
||||
"profit_days": "盈利交易日",
|
||||
"loss_days": "亏损交易日",
|
||||
|
||||
"capital": "起始资金",
|
||||
"end_balance": "结束资金",
|
||||
|
||||
"total_return": "总收益率",
|
||||
"annual_return": "年化收益",
|
||||
"max_drawdown": "最大回撤",
|
||||
"max_ddpercent": "百分比最大回撤",
|
||||
|
||||
"total_net_pnl": "总盈亏",
|
||||
"total_commission": "总手续费",
|
||||
"total_slippage": "总滑点",
|
||||
"total_turnover": "总成交额",
|
||||
"total_trade_count": "总成交笔数",
|
||||
|
||||
"daily_net_pnl": "日均盈亏",
|
||||
"daily_commission": "日均手续费",
|
||||
"daily_slippage": "日均滑点",
|
||||
"daily_turnover": "日均成交额",
|
||||
"daily_trade_count": "日均成交笔数",
|
||||
|
||||
"daily_return": "日均收益率",
|
||||
"return_std": "收益标准差",
|
||||
"sharpe_ratio": "夏普比率",
|
||||
"return_drawdown_ratio": "收益回撤比"
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
""""""
|
||||
super().__init__()
|
||||
|
||||
self.cells = {}
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
self.setRowCount(len(self.KEY_NAME_MAP))
|
||||
self.setVerticalHeaderLabels(list(self.KEY_NAME_MAP.values()))
|
||||
|
||||
self.setColumnCount(1)
|
||||
self.horizontalHeader().setVisible(False)
|
||||
self.horizontalHeader().setSectionResizeMode(
|
||||
QtWidgets.QHeaderView.Stretch
|
||||
)
|
||||
|
||||
for row, key in enumerate(self.KEY_NAME_MAP.keys()):
|
||||
cell = QtWidgets.QTableWidgetItem()
|
||||
self.setItem(row, 0, cell)
|
||||
self.cells[key] = cell
|
||||
|
||||
def clear_data(self):
|
||||
""""""
|
||||
for cell in self.cells.values():
|
||||
cell.setText("")
|
||||
|
||||
def set_data(self, data: dict):
|
||||
""""""
|
||||
data["capital"] = f"{data['capital']:,.2f}"
|
||||
data["end_balance"] = f"{data['end_balance']:,.2f}"
|
||||
data["total_return"] = f"{data['total_return']:,.2f}%"
|
||||
data["annual_return"] = f"{data['annual_return']:,.2f}%"
|
||||
data["max_drawdown"] = f"{data['max_drawdown']:,.2f}"
|
||||
data["max_ddpercent"] = f"{data['max_ddpercent']:,.2f}%"
|
||||
data["total_net_pnl"] = f"{data['total_net_pnl']:,.2f}"
|
||||
data["total_commission"] = f"{data['total_commission']:,.2f}"
|
||||
data["total_slippage"] = f"{data['total_slippage']:,.2f}"
|
||||
data["total_turnover"] = f"{data['total_turnover']:,.2f}"
|
||||
data["daily_net_pnl"] = f"{data['daily_net_pnl']:,.2f}"
|
||||
data["daily_commission"] = f"{data['daily_commission']:,.2f}"
|
||||
data["daily_slippage"] = f"{data['daily_slippage']:,.2f}"
|
||||
data["daily_turnover"] = f"{data['daily_turnover']:,.2f}"
|
||||
data["daily_return"] = f"{data['daily_return']:,.2f}%"
|
||||
data["return_std"] = f"{data['return_std']:,.2f}%"
|
||||
data["sharpe_ratio"] = f"{data['sharpe_ratio']:,.2f}"
|
||||
data["return_drawdown_ratio"] = f"{data['return_drawdown_ratio']:,.2f}"
|
||||
|
||||
for key, cell in self.cells.items():
|
||||
value = data.get(key, "")
|
||||
cell.setText(str(value))
|
||||
|
||||
|
||||
class SettingEditor(QtWidgets.QDialog):
|
||||
"""
|
||||
For creating new strategy and editing strategy parameters.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, class_name: str, parameters: dict
|
||||
):
|
||||
""""""
|
||||
super(SettingEditor, self).__init__()
|
||||
|
||||
self.class_name = class_name
|
||||
self.parameters = parameters
|
||||
self.edits = {}
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
form = QtWidgets.QFormLayout()
|
||||
|
||||
# Add vt_symbol and name edit if add new strategy
|
||||
self.setWindowTitle(f"策略参数配置:{self.class_name}")
|
||||
button_text = "确定"
|
||||
parameters = self.parameters
|
||||
|
||||
for name, value in parameters.items():
|
||||
type_ = type(value)
|
||||
|
||||
edit = QtWidgets.QLineEdit(str(value))
|
||||
if type_ is int:
|
||||
validator = QtGui.QIntValidator()
|
||||
edit.setValidator(validator)
|
||||
elif type_ is float:
|
||||
validator = QtGui.QDoubleValidator()
|
||||
edit.setValidator(validator)
|
||||
|
||||
form.addRow(f"{name} {type_}", edit)
|
||||
|
||||
self.edits[name] = (edit, type_)
|
||||
|
||||
button = QtWidgets.QPushButton(button_text)
|
||||
button.clicked.connect(self.accept)
|
||||
form.addRow(button)
|
||||
|
||||
self.setLayout(form)
|
||||
|
||||
def get_setting(self):
|
||||
""""""
|
||||
setting = {}
|
||||
|
||||
for name, tp in self.edits.items():
|
||||
edit, type_ = tp
|
||||
value_text = edit.text()
|
||||
|
||||
if type_ == bool:
|
||||
if value_text == "True":
|
||||
value = True
|
||||
else:
|
||||
value = False
|
||||
else:
|
||||
value = type_(value_text)
|
||||
|
||||
setting[name] = value
|
||||
|
||||
return setting
|
||||
|
||||
|
||||
class BacktesterChart(pg.GraphicsWindow):
|
||||
""""""
|
||||
|
||||
def __init__(self):
|
||||
""""""
|
||||
super().__init__(title="Backtester Chart")
|
||||
|
||||
self.dates = {}
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
pg.setConfigOptions(antialias=True)
|
||||
|
||||
# Create plot widgets
|
||||
self.balance_plot = self.addPlot(
|
||||
title="账户净值",
|
||||
axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
|
||||
)
|
||||
self.nextRow()
|
||||
|
||||
self.drawdown_plot = self.addPlot(
|
||||
title="净值回撤",
|
||||
axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
|
||||
)
|
||||
self.nextRow()
|
||||
|
||||
self.pnl_plot = self.addPlot(
|
||||
title="每日盈亏",
|
||||
axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
|
||||
)
|
||||
self.nextRow()
|
||||
|
||||
self.distribution_plot = self.addPlot(title="盈亏分布")
|
||||
|
||||
# Add curves and bars on plot widgets
|
||||
self.balance_curve = self.balance_plot.plot(
|
||||
pen=pg.mkPen("#ffc107", width=3)
|
||||
)
|
||||
|
||||
dd_color = "#303f9f"
|
||||
self.drawdown_curve = self.drawdown_plot.plot(
|
||||
fillLevel=-0.3, brush=dd_color, pen=dd_color
|
||||
)
|
||||
|
||||
profit_color = 'r'
|
||||
loss_color = 'g'
|
||||
self.profit_pnl_bar = pg.BarGraphItem(
|
||||
x=[], height=[], width=0.3, brush=profit_color, pen=profit_color
|
||||
)
|
||||
self.loss_pnl_bar = pg.BarGraphItem(
|
||||
x=[], height=[], width=0.3, brush=loss_color, pen=loss_color
|
||||
)
|
||||
self.pnl_plot.addItem(self.profit_pnl_bar)
|
||||
self.pnl_plot.addItem(self.loss_pnl_bar)
|
||||
|
||||
distribution_color = "#6d4c41"
|
||||
self.distribution_curve = self.distribution_plot.plot(
|
||||
fillLevel=-0.3, brush=distribution_color, pen=distribution_color
|
||||
)
|
||||
|
||||
def clear_data(self):
|
||||
""""""
|
||||
self.balance_curve.setData([], [])
|
||||
self.drawdown_curve.setData([], [])
|
||||
self.profit_pnl_bar.setOpts(x=[], height=[])
|
||||
self.loss_pnl_bar.setOpts(x=[], height=[])
|
||||
self.distribution_curve.setData([], [])
|
||||
|
||||
def set_data(self, df):
|
||||
""""""
|
||||
count = len(df)
|
||||
|
||||
self.dates.clear()
|
||||
for n, date in enumerate(df.index):
|
||||
self.dates[n] = date
|
||||
|
||||
# Set data for curve of balance and drawdown
|
||||
self.balance_curve.setData(df["balance"])
|
||||
self.drawdown_curve.setData(df["drawdown"])
|
||||
|
||||
# Set data for daily pnl bar
|
||||
profit_pnl_x = []
|
||||
profit_pnl_height = []
|
||||
loss_pnl_x = []
|
||||
loss_pnl_height = []
|
||||
|
||||
for count, pnl in enumerate(df["net_pnl"]):
|
||||
if pnl >= 0:
|
||||
profit_pnl_height.append(pnl)
|
||||
profit_pnl_x.append(count)
|
||||
else:
|
||||
loss_pnl_height.append(pnl)
|
||||
loss_pnl_x.append(count)
|
||||
|
||||
self.profit_pnl_bar.setOpts(x=profit_pnl_x, height=profit_pnl_height)
|
||||
self.loss_pnl_bar.setOpts(x=loss_pnl_x, height=loss_pnl_height)
|
||||
|
||||
# Set data for pnl distribution
|
||||
hist, x = np.histogram(df["net_pnl"], bins="auto")
|
||||
x = x[:-1]
|
||||
self.distribution_curve.setData(x, hist)
|
||||
|
||||
|
||||
class DateAxis(pg.AxisItem):
|
||||
"""Axis for showing date data"""
|
||||
|
||||
def __init__(self, dates: dict, *args, **kwargs):
|
||||
""""""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dates = dates
|
||||
|
||||
def tickStrings(self, values, scale, spacing):
|
||||
""""""
|
||||
strings = []
|
||||
for v in values:
|
||||
dt = self.dates.get(v, "")
|
||||
strings.append(str(dt))
|
||||
return strings
|
@ -7,6 +7,7 @@ from vnpy.trader.utility import BarGenerator, ArrayManager
|
||||
|
||||
from .base import APP_NAME, StopOrder
|
||||
from .engine import CtaEngine
|
||||
from .backtesting import BacktestingEngine, OptimizationSetting
|
||||
from .template import CtaTemplate, CtaSignal, TargetPosTemplate
|
||||
|
||||
|
||||
|
@ -2,6 +2,7 @@ from collections import defaultdict
|
||||
from datetime import date, datetime
|
||||
from typing import Callable
|
||||
from itertools import product
|
||||
from functools import lru_cache
|
||||
import multiprocessing
|
||||
|
||||
import numpy as np
|
||||
@ -197,28 +198,18 @@ class BacktestingEngine:
|
||||
self.output("开始加载历史数据")
|
||||
|
||||
if self.mode == BacktestingMode.BAR:
|
||||
s = (
|
||||
DbBarData.select()
|
||||
.where(
|
||||
(DbBarData.vt_symbol == self.vt_symbol)
|
||||
& (DbBarData.interval == self.interval)
|
||||
& (DbBarData.datetime >= self.start)
|
||||
& (DbBarData.datetime <= self.end)
|
||||
)
|
||||
.order_by(DbBarData.datetime)
|
||||
self.history_data = load_bar_data(
|
||||
self.vt_symbol,
|
||||
self.interval,
|
||||
self.start,
|
||||
self.end
|
||||
)
|
||||
self.history_data = [db_bar.to_bar() for db_bar in s]
|
||||
else:
|
||||
s = (
|
||||
DbTickData.select()
|
||||
.where(
|
||||
(DbTickData.vt_symbol == self.vt_symbol)
|
||||
& (DbTickData.datetime >= self.start)
|
||||
& (DbTickData.datetime <= self.end)
|
||||
)
|
||||
.order_by(DbTickData.datetime)
|
||||
self.history_data = load_tick_data(
|
||||
self.vt_symbol,
|
||||
self.start,
|
||||
self.end
|
||||
)
|
||||
self.history_data = [db_tick.to_tick() for db_tick in s]
|
||||
|
||||
self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
|
||||
|
||||
@ -293,7 +284,7 @@ class BacktestingEngine:
|
||||
self.output("逐日盯市盈亏计算完成")
|
||||
return self.daily_df
|
||||
|
||||
def calculate_statistics(self, df: DataFrame = None):
|
||||
def calculate_statistics(self, df: DataFrame = None, output=True):
|
||||
""""""
|
||||
self.output("开始计算策略统计指标")
|
||||
|
||||
@ -325,6 +316,7 @@ class BacktestingEngine:
|
||||
daily_return = 0
|
||||
return_std = 0
|
||||
sharpe_ratio = 0
|
||||
return_drawdown_ratio = 0
|
||||
else:
|
||||
# Calculate balance related time series data
|
||||
df["balance"] = df["net_pnl"].cumsum() + self.capital
|
||||
@ -373,38 +365,42 @@ class BacktestingEngine:
|
||||
else:
|
||||
sharpe_ratio = 0
|
||||
|
||||
return_drawdown_ratio = -total_return / max_ddpercent
|
||||
|
||||
# Output
|
||||
self.output("-" * 30)
|
||||
self.output(f"首个交易日:\t{start_date}")
|
||||
self.output(f"最后交易日:\t{end_date}")
|
||||
if output:
|
||||
self.output("-" * 30)
|
||||
self.output(f"首个交易日:\t{start_date}")
|
||||
self.output(f"最后交易日:\t{end_date}")
|
||||
|
||||
self.output(f"总交易日:\t{total_days}")
|
||||
self.output(f"盈利交易日:\t{profit_days}")
|
||||
self.output(f"亏损交易日:\t{loss_days}")
|
||||
self.output(f"总交易日:\t{total_days}")
|
||||
self.output(f"盈利交易日:\t{profit_days}")
|
||||
self.output(f"亏损交易日:\t{loss_days}")
|
||||
|
||||
self.output(f"起始资金:\t{self.capital:,.2f}")
|
||||
self.output(f"结束资金:\t{end_balance:,.2f}")
|
||||
self.output(f"起始资金:\t{self.capital:,.2f}")
|
||||
self.output(f"结束资金:\t{end_balance:,.2f}")
|
||||
|
||||
self.output(f"总收益率:\t{total_return:,.2f}%")
|
||||
self.output(f"年化收益:\t{annual_return:,.2f}%")
|
||||
self.output(f"最大回撤: \t{max_drawdown:,.2f}")
|
||||
self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%")
|
||||
self.output(f"总收益率:\t{total_return:,.2f}%")
|
||||
self.output(f"年化收益:\t{annual_return:,.2f}%")
|
||||
self.output(f"最大回撤: \t{max_drawdown:,.2f}")
|
||||
self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%")
|
||||
|
||||
self.output(f"总盈亏:\t{total_net_pnl:,.2f}")
|
||||
self.output(f"总手续费:\t{total_commission:,.2f}")
|
||||
self.output(f"总滑点:\t{total_slippage:,.2f}")
|
||||
self.output(f"总成交金额:\t{total_turnover:,.2f}")
|
||||
self.output(f"总成交笔数:\t{total_trade_count}")
|
||||
self.output(f"总盈亏:\t{total_net_pnl:,.2f}")
|
||||
self.output(f"总手续费:\t{total_commission:,.2f}")
|
||||
self.output(f"总滑点:\t{total_slippage:,.2f}")
|
||||
self.output(f"总成交金额:\t{total_turnover:,.2f}")
|
||||
self.output(f"总成交笔数:\t{total_trade_count}")
|
||||
|
||||
self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}")
|
||||
self.output(f"日均手续费:\t{daily_commission:,.2f}")
|
||||
self.output(f"日均滑点:\t{daily_slippage:,.2f}")
|
||||
self.output(f"日均成交金额:\t{daily_turnover:,.2f}")
|
||||
self.output(f"日均成交笔数:\t{daily_trade_count}")
|
||||
self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}")
|
||||
self.output(f"日均手续费:\t{daily_commission:,.2f}")
|
||||
self.output(f"日均滑点:\t{daily_slippage:,.2f}")
|
||||
self.output(f"日均成交金额:\t{daily_turnover:,.2f}")
|
||||
self.output(f"日均成交笔数:\t{daily_trade_count}")
|
||||
|
||||
self.output(f"日均收益率:\t{daily_return:,.2f}%")
|
||||
self.output(f"收益标准差:\t{return_std:,.2f}%")
|
||||
self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}")
|
||||
self.output(f"日均收益率:\t{daily_return:,.2f}%")
|
||||
self.output(f"收益标准差:\t{return_std:,.2f}%")
|
||||
self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}")
|
||||
self.output(f"收益回撤比:\t{return_drawdown_ratio:,.2f}")
|
||||
|
||||
statistics = {
|
||||
"start_date": start_date,
|
||||
@ -412,6 +408,7 @@ class BacktestingEngine:
|
||||
"total_days": total_days,
|
||||
"profit_days": profit_days,
|
||||
"loss_days": loss_days,
|
||||
"capital": self.capital,
|
||||
"end_balance": end_balance,
|
||||
"max_drawdown": max_drawdown,
|
||||
"max_ddpercent": max_ddpercent,
|
||||
@ -430,6 +427,7 @@ class BacktestingEngine:
|
||||
"daily_return": daily_return,
|
||||
"return_std": return_std,
|
||||
"sharpe_ratio": sharpe_ratio,
|
||||
"return_drawdown_ratio": return_drawdown_ratio,
|
||||
}
|
||||
|
||||
return statistics
|
||||
@ -963,3 +961,45 @@ def optimize(
|
||||
|
||||
target_value = statistics[target_name]
|
||||
return (str(setting), target_value, statistics)
|
||||
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def load_bar_data(
|
||||
vt_symbol: str,
|
||||
interval: str,
|
||||
start: datetime,
|
||||
end: datetime
|
||||
):
|
||||
""""""
|
||||
s = (
|
||||
DbBarData.select()
|
||||
.where(
|
||||
(DbBarData.vt_symbol == vt_symbol)
|
||||
& (DbBarData.interval == interval)
|
||||
& (DbBarData.datetime >= start)
|
||||
& (DbBarData.datetime <= end)
|
||||
)
|
||||
.order_by(DbBarData.datetime)
|
||||
)
|
||||
data = [db_bar.to_bar() for db_bar in s]
|
||||
return data
|
||||
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def load_tick_data(
|
||||
vt_symbol: str,
|
||||
start: datetime,
|
||||
end: datetime
|
||||
):
|
||||
""""""
|
||||
s = (
|
||||
DbTickData.select()
|
||||
.where(
|
||||
(DbTickData.vt_symbol == vt_symbol)
|
||||
& (DbTickData.datetime >= start)
|
||||
& (DbTickData.datetime <= end)
|
||||
)
|
||||
.order_by(DbTickData.datetime)
|
||||
)
|
||||
data = [db_tick.db_tick() for db_tick in s]
|
||||
return data
|
@ -40,6 +40,7 @@ from vnpy.trader.database import DbTickData, DbBarData
|
||||
from vnpy.trader.setting import SETTINGS
|
||||
|
||||
from .base import (
|
||||
APP_NAME,
|
||||
EVENT_CTA_LOG,
|
||||
EVENT_CTA_STRATEGY,
|
||||
EVENT_CTA_STOPORDER,
|
||||
@ -73,7 +74,7 @@ class CtaEngine(BaseEngine):
|
||||
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||
""""""
|
||||
super(CtaEngine, self).__init__(
|
||||
main_engine, event_engine, "CtaStrategy")
|
||||
main_engine, event_engine, APP_NAME)
|
||||
|
||||
self.strategy_setting = {} # strategy_name: dict
|
||||
self.strategy_data = {} # strategy_name: dict
|
||||
@ -541,7 +542,7 @@ class CtaEngine(BaseEngine):
|
||||
DbBarData.select()
|
||||
.where(
|
||||
(DbBarData.vt_symbol == vt_symbol)
|
||||
& (DbBarData.interval == interval)
|
||||
& (DbBarData.interval == interval.value)
|
||||
& (DbBarData.datetime >= start)
|
||||
& (DbBarData.datetime <= end)
|
||||
)
|
||||
|
@ -4,6 +4,7 @@ from typing import Any, Callable
|
||||
|
||||
from vnpy.trader.constant import Interval, Direction, Offset
|
||||
from vnpy.trader.object import BarData, TickData, OrderData, TradeData
|
||||
from vnpy.trader.utility import virtual
|
||||
|
||||
from .base import StopOrder, EngineType
|
||||
|
||||
@ -87,48 +88,56 @@ class CtaTemplate(ABC):
|
||||
}
|
||||
return strategy_data
|
||||
|
||||
@virtual
|
||||
def on_init(self):
|
||||
"""
|
||||
Callback when strategy is inited.
|
||||
"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_start(self):
|
||||
"""
|
||||
Callback when strategy is started.
|
||||
"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_stop(self):
|
||||
"""
|
||||
Callback when strategy is stopped.
|
||||
"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_tick(self, tick: TickData):
|
||||
"""
|
||||
Callback of new tick data update.
|
||||
"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_bar(self, bar: BarData):
|
||||
"""
|
||||
Callback of new bar data update.
|
||||
"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_trade(self, trade: TradeData):
|
||||
"""
|
||||
Callback of new trade data update.
|
||||
"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_order(self, order: OrderData):
|
||||
"""
|
||||
Callback of new order data update.
|
||||
"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_stop_order(self, stop_order: StopOrder):
|
||||
"""
|
||||
Callback of stop order update.
|
||||
@ -255,12 +264,14 @@ class CtaSignal(ABC):
|
||||
""""""
|
||||
self.signal_pos = 0
|
||||
|
||||
@virtual
|
||||
def on_tick(self, tick: TickData):
|
||||
"""
|
||||
Callback of new tick data update.
|
||||
"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_bar(self, bar: BarData):
|
||||
"""
|
||||
Callback of new bar data update.
|
||||
@ -292,6 +303,7 @@ class TargetPosTemplate(CtaTemplate):
|
||||
)
|
||||
self.variables.append("target_pos")
|
||||
|
||||
@virtual
|
||||
def on_tick(self, tick: TickData):
|
||||
"""
|
||||
Callback of new tick data update.
|
||||
@ -301,12 +313,14 @@ class TargetPosTemplate(CtaTemplate):
|
||||
if self.trading:
|
||||
self.trade()
|
||||
|
||||
@virtual
|
||||
def on_bar(self, bar: BarData):
|
||||
"""
|
||||
Callback of new bar data update.
|
||||
"""
|
||||
self.last_bar = bar
|
||||
|
||||
@virtual
|
||||
def on_order(self, order: OrderData):
|
||||
"""
|
||||
Callback of new order data update.
|
||||
|
@ -71,8 +71,8 @@ class BitmexGateway(BaseGateway):
|
||||
"Secret": "",
|
||||
"会话数": 3,
|
||||
"服务器": ["REAL", "TESTNET"],
|
||||
"代理地址": "127.0.0.1",
|
||||
"代理端口": 1080,
|
||||
"代理地址": "",
|
||||
"代理端口": "",
|
||||
}
|
||||
|
||||
def __init__(self, event_engine):
|
||||
@ -91,6 +91,11 @@ class BitmexGateway(BaseGateway):
|
||||
proxy_host = setting["代理地址"]
|
||||
proxy_port = setting["代理端口"]
|
||||
|
||||
if proxy_port.isdigit():
|
||||
proxy_port = int(proxy_port)
|
||||
else:
|
||||
proxy_port = 0
|
||||
|
||||
self.rest_api.connect(key, secret, session_number,
|
||||
server, proxy_host, proxy_port)
|
||||
|
||||
|
@ -405,7 +405,7 @@ class CtpTdApi(TdApi):
|
||||
""""""
|
||||
if not error['ErrorID']:
|
||||
self.authStatus = True
|
||||
self.writeLog("交易授权验证成功")
|
||||
self.gateway.write_log("交易授权验证成功")
|
||||
self.login()
|
||||
else:
|
||||
self.gateway.write_error("交易授权验证失败", error)
|
||||
@ -418,7 +418,7 @@ class CtpTdApi(TdApi):
|
||||
self.login_status = True
|
||||
self.gateway.write_log("交易登录成功")
|
||||
|
||||
# Confirm settelment
|
||||
# Confirm settlement
|
||||
req = {
|
||||
"BrokerID": self.brokerid,
|
||||
"InvestorID": self.userid
|
||||
@ -487,11 +487,6 @@ class CtpTdApi(TdApi):
|
||||
)
|
||||
self.positions[key] = position
|
||||
|
||||
# Get contract size, return if size value not collected
|
||||
size = symbol_size_map.get(position.symbol, None)
|
||||
if not size:
|
||||
return
|
||||
|
||||
# For SHFE position data update
|
||||
if position.exchange == Exchange.SHFE:
|
||||
if data["YdPosition"] and not data["TodayPosition"]:
|
||||
@ -500,6 +495,9 @@ class CtpTdApi(TdApi):
|
||||
else:
|
||||
position.yd_volume = data["Position"] - data["TodayPosition"]
|
||||
|
||||
# Get contract size (spread contract has no size value)
|
||||
size = symbol_size_map.get(position.symbol, 0)
|
||||
|
||||
# Calculate previous position cost
|
||||
cost = position.price * position.volume * size
|
||||
|
||||
@ -508,7 +506,7 @@ class CtpTdApi(TdApi):
|
||||
position.pnl += data["PositionProfit"]
|
||||
|
||||
# Calculate average position price
|
||||
if position.volume:
|
||||
if position.volume and size:
|
||||
cost += data["PositionCost"]
|
||||
position.price = cost / (position.volume * size)
|
||||
|
||||
@ -549,13 +547,16 @@ class CtpTdApi(TdApi):
|
||||
product=product,
|
||||
size=data["VolumeMultiple"],
|
||||
pricetick=data["PriceTick"],
|
||||
option_underlying=data["UnderlyingInstrID"],
|
||||
option_type=OPTIONTYPE_CTP2VT.get(data["OptionsType"], None),
|
||||
option_strike=data["StrikePrice"],
|
||||
option_expiry=datetime.strptime(data["ExpireDate"], "%Y%m%d"),
|
||||
gateway_name=self.gateway_name
|
||||
)
|
||||
|
||||
# For option only
|
||||
if contract.product == Product.OPTION:
|
||||
contract.option_underlying = data["UnderlyingInstrID"],
|
||||
contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None),
|
||||
contract.option_strike = data["StrikePrice"],
|
||||
contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"),
|
||||
|
||||
self.gateway.on_contract(contract)
|
||||
|
||||
symbol_exchange_map[contract.symbol] = contract.exchange
|
||||
@ -662,7 +663,7 @@ class CtpTdApi(TdApi):
|
||||
"UserID": self.userid,
|
||||
"BrokerID": self.brokerid,
|
||||
"AuthCode": self.auth_code,
|
||||
"ProductInfo": self.product_info
|
||||
"UserProductInfo": self.product_info
|
||||
}
|
||||
|
||||
self.reqid += 1
|
||||
@ -678,7 +679,8 @@ class CtpTdApi(TdApi):
|
||||
req = {
|
||||
"UserID": self.userid,
|
||||
"Password": self.password,
|
||||
"BrokerID": self.brokerid
|
||||
"BrokerID": self.brokerid,
|
||||
"UserProductInfo": self.product_info
|
||||
}
|
||||
|
||||
self.reqid += 1
|
||||
|
4
vnpy/gateway/huobi/__init__.py
Normal file
4
vnpy/gateway/huobi/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .huobi_gateway import HuobiGateway
|
||||
|
||||
|
||||
|
747
vnpy/gateway/huobi/huobi_gateway.py
Normal file
747
vnpy/gateway/huobi/huobi_gateway.py
Normal file
@ -0,0 +1,747 @@
|
||||
# encoding: UTF-8
|
||||
|
||||
"""
|
||||
火币交易接口
|
||||
"""
|
||||
|
||||
import re
|
||||
import urllib
|
||||
import base64
|
||||
import json
|
||||
import zlib
|
||||
import hashlib
|
||||
import hmac
|
||||
from copy import copy
|
||||
from datetime import datetime
|
||||
|
||||
from vnpy.event import Event
|
||||
from vnpy.api.rest import RestClient, Request
|
||||
from vnpy.api.websocket import WebsocketClient
|
||||
from vnpy.trader.constant import (
|
||||
Direction,
|
||||
Exchange,
|
||||
Product,
|
||||
Status,
|
||||
OrderType
|
||||
)
|
||||
from vnpy.trader.gateway import BaseGateway, LocalOrderManager
|
||||
from vnpy.trader.object import (
|
||||
TickData,
|
||||
OrderData,
|
||||
TradeData,
|
||||
AccountData,
|
||||
ContractData,
|
||||
OrderRequest,
|
||||
CancelRequest,
|
||||
SubscribeRequest
|
||||
)
|
||||
from vnpy.trader.event import EVENT_TIMER
|
||||
|
||||
|
||||
REST_HOST = "https://api.huobipro.com"
|
||||
WEBSOCKET_DATA_HOST = "wss://api.huobi.pro/ws" # Market Data
|
||||
WEBSOCKET_TRADE_HOST = "wss://api.huobi.pro/ws/v1" # Account and Order
|
||||
|
||||
STATUS_HUOBI2VT = {
|
||||
"submitted": Status.NOTTRADED,
|
||||
"partial-filled": Status.PARTTRADED,
|
||||
"filled": Status.ALLTRADED,
|
||||
"cancelling": Status.CANCELLED,
|
||||
"partial-canceled": Status.CANCELLED,
|
||||
"canceled": Status.CANCELLED,
|
||||
}
|
||||
|
||||
ORDERTYPE_VT2HUOBI = {
|
||||
(Direction.LONG, OrderType.MARKET): "buy-market",
|
||||
(Direction.SHORT, OrderType.MARKET): "sell-market",
|
||||
(Direction.LONG, OrderType.LIMIT): "buy-limit",
|
||||
(Direction.SHORT, OrderType.LIMIT): "sell-limit",
|
||||
}
|
||||
ORDERTYPE_HUOBI2VT = {v: k for k, v in ORDERTYPE_VT2HUOBI.items()}
|
||||
|
||||
|
||||
huobi_symbols = set()
|
||||
symbol_name_map = {}
|
||||
|
||||
|
||||
class HuobiGateway(BaseGateway):
|
||||
"""
|
||||
VN Trader Gateway for Huobi connection.
|
||||
"""
|
||||
|
||||
default_setting = {
|
||||
"API Key": "",
|
||||
"Secret Key": "",
|
||||
"会话数": 3,
|
||||
"代理地址": "",
|
||||
"代理端口": "",
|
||||
}
|
||||
|
||||
def __init__(self, event_engine):
|
||||
"""Constructor"""
|
||||
super(HuobiGateway, self).__init__(event_engine, "HUOBI")
|
||||
|
||||
self.order_manager = LocalOrderManager(self)
|
||||
|
||||
self.rest_api = HuobiRestApi(self)
|
||||
self.trade_ws_api = HuobiTradeWebsocketApi(self)
|
||||
self.market_ws_api = HuobiDataWebsocketApi(self)
|
||||
|
||||
def connect(self, setting: dict):
|
||||
""""""
|
||||
key = setting["API Key"]
|
||||
secret = setting["Secret Key"]
|
||||
session_number = setting["会话数"]
|
||||
proxy_host = setting["代理地址"]
|
||||
proxy_port = setting["代理端口"]
|
||||
|
||||
if proxy_port.isdigit():
|
||||
proxy_port = int(proxy_port)
|
||||
else:
|
||||
proxy_port = 0
|
||||
|
||||
self.rest_api.connect(key, secret, session_number,
|
||||
proxy_host, proxy_port)
|
||||
self.trade_ws_api.connect(key, secret, proxy_host, proxy_port)
|
||||
self.market_ws_api.connect(key, secret, proxy_host, proxy_port)
|
||||
|
||||
self.init_query()
|
||||
|
||||
def subscribe(self, req: SubscribeRequest):
|
||||
""""""
|
||||
self.market_ws_api.subscribe(req)
|
||||
self.trade_ws_api.subscribe(req)
|
||||
|
||||
def send_order(self, req: OrderRequest):
|
||||
""""""
|
||||
return self.rest_api.send_order(req)
|
||||
|
||||
def cancel_order(self, req: CancelRequest):
|
||||
""""""
|
||||
self.rest_api.cancel_order(req)
|
||||
|
||||
def query_account(self):
|
||||
""""""
|
||||
self.rest_api.query_account_balance()
|
||||
|
||||
def query_position(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
""""""
|
||||
self.rest_api.stop()
|
||||
self.trade_ws_api.stop()
|
||||
self.market_ws_api.stop()
|
||||
|
||||
def process_timer_event(self, event: Event):
|
||||
""""""
|
||||
self.count += 1
|
||||
if self.count < 3:
|
||||
return
|
||||
|
||||
self.query_account()
|
||||
|
||||
def init_query(self):
|
||||
""""""
|
||||
self.count = 0
|
||||
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
|
||||
|
||||
|
||||
class HuobiRestApi(RestClient):
|
||||
"""
|
||||
HUOBI REST API
|
||||
"""
|
||||
|
||||
def __init__(self, gateway: BaseGateway):
|
||||
""""""
|
||||
super(HuobiRestApi, self).__init__()
|
||||
|
||||
self.gateway = gateway
|
||||
self.gateway_name = gateway.gateway_name
|
||||
self.order_manager = gateway.order_manager
|
||||
|
||||
self.host = ""
|
||||
self.key = ""
|
||||
self.secret = ""
|
||||
self.account_id = ""
|
||||
|
||||
self.cancel_requests = {}
|
||||
self.orders = {}
|
||||
|
||||
def sign(self, request):
|
||||
"""
|
||||
Generate HUOBI signature.
|
||||
"""
|
||||
request.headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.71 Safari/537.36"
|
||||
}
|
||||
params_with_signature = create_signature(
|
||||
self.key,
|
||||
request.method,
|
||||
self.host,
|
||||
request.path,
|
||||
self.secret,
|
||||
request.params
|
||||
)
|
||||
request.params = params_with_signature
|
||||
|
||||
if request.method == "POST":
|
||||
request.headers["Content-Type"] = "application/json"
|
||||
|
||||
if request.data:
|
||||
request.data = json.dumps(request.data)
|
||||
|
||||
return request
|
||||
|
||||
def connect(
|
||||
self,
|
||||
key: str,
|
||||
secret: str,
|
||||
session_number: int,
|
||||
proxy_host: str,
|
||||
proxy_port: int
|
||||
):
|
||||
"""
|
||||
Initialize connection to REST server.
|
||||
"""
|
||||
self.key = key
|
||||
self.secret = secret
|
||||
|
||||
self.host, _ = _split_url(REST_HOST)
|
||||
|
||||
self.init(REST_HOST, proxy_host, proxy_port)
|
||||
self.start(session_number)
|
||||
|
||||
self.gateway.write_log("REST API启动成功")
|
||||
|
||||
self.query_contract()
|
||||
self.query_account()
|
||||
self.query_order()
|
||||
|
||||
def query_account(self):
|
||||
""""""
|
||||
self.add_request(
|
||||
method="GET",
|
||||
path="/v1/account/accounts",
|
||||
callback=self.on_query_account
|
||||
)
|
||||
|
||||
def query_account_balance(self):
|
||||
""""""
|
||||
path = f"/v1/account/accounts/{self.account_id}/balance"
|
||||
self.add_request(
|
||||
method="GET",
|
||||
path=path,
|
||||
callback=self.on_query_account_balance
|
||||
)
|
||||
|
||||
def query_order(self):
|
||||
""""""
|
||||
self.add_request(
|
||||
method="GET",
|
||||
path="/v1/order/openOrders",
|
||||
callback=self.on_query_order
|
||||
)
|
||||
|
||||
def query_contract(self):
|
||||
""""""
|
||||
self.add_request(
|
||||
method="GET",
|
||||
path="/v1/common/symbols",
|
||||
callback=self.on_query_contract
|
||||
)
|
||||
|
||||
def send_order(self, req: OrderRequest):
|
||||
""""""
|
||||
huobi_type = ORDERTYPE_VT2HUOBI.get(
|
||||
(req.direction, req.type), ""
|
||||
)
|
||||
|
||||
local_orderid = self.order_manager.new_local_orderid()
|
||||
order = req.create_order_data(
|
||||
local_orderid,
|
||||
self.gateway_name
|
||||
)
|
||||
order.time = datetime.now().strftime("%H:%M:%S")
|
||||
|
||||
data = {
|
||||
"account-id": self.account_id,
|
||||
"amount": str(req.volume),
|
||||
"symbol": req.symbol,
|
||||
"type": huobi_type,
|
||||
"price": str(req.price),
|
||||
"source": "api"
|
||||
}
|
||||
|
||||
self.add_request(
|
||||
method="POST",
|
||||
path="/v1/order/orders/place",
|
||||
callback=self.on_send_order,
|
||||
data=data,
|
||||
extra=order,
|
||||
on_error=self.on_send_order_error,
|
||||
on_failed=self.on_send_order_failed
|
||||
)
|
||||
|
||||
self.order_manager.on_order(order)
|
||||
return order.vt_orderid
|
||||
|
||||
def cancel_order(self, req: CancelRequest):
|
||||
""""""
|
||||
sys_orderid = self.order_manager.get_sys_orderid(req.orderid)
|
||||
|
||||
path = f"/v1/order/orders/{sys_orderid}/submitcancel"
|
||||
self.add_request(
|
||||
method="POST",
|
||||
path=path,
|
||||
callback=self.on_cancel_order,
|
||||
extra=req
|
||||
)
|
||||
|
||||
def on_query_account(self, data, request):
|
||||
""""""
|
||||
if self.check_error(data, "查询账户"):
|
||||
return
|
||||
|
||||
for d in data["data"]:
|
||||
if d["type"] == "spot":
|
||||
self.account_id = d["id"]
|
||||
self.gateway.write_log(f"账户代码{self.account_id}查询成功")
|
||||
|
||||
self.query_account_balance()
|
||||
|
||||
def on_query_account_balance(self, data, request):
|
||||
""""""
|
||||
if self.check_error(data, "查询账户资金"):
|
||||
return
|
||||
|
||||
buf = {}
|
||||
for d in data["data"]["list"]:
|
||||
currency = d["currency"]
|
||||
currency_data = buf.setdefault(currency, {})
|
||||
currency_data[d["type"]] = float(d["balance"])
|
||||
|
||||
for currency, currency_data in buf.items():
|
||||
account = AccountData(
|
||||
accountid=currency,
|
||||
balance=currency_data["trade"] + currency_data["frozen"],
|
||||
frozen=currency_data["frozen"],
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
|
||||
if account.balance:
|
||||
self.gateway.on_account(account)
|
||||
|
||||
def on_query_order(self, data, request):
|
||||
""""""
|
||||
if self.check_error(data, "查询委托"):
|
||||
return
|
||||
|
||||
for d in data["data"]:
|
||||
sys_orderid = d["id"]
|
||||
local_orderid = self.order_manager.get_local_orderid(sys_orderid)
|
||||
|
||||
direction, order_type = ORDERTYPE_HUOBI2VT[d["type"]]
|
||||
dt = datetime.fromtimestamp(d["created-at"] / 1000)
|
||||
time = dt.strftime("%H:%M:%S")
|
||||
|
||||
order = OrderData(
|
||||
orderid=local_orderid,
|
||||
symbol=d["symbol"],
|
||||
exchange=Exchange.HUOBI,
|
||||
price=float(d["price"]),
|
||||
volume=float(d["amount"]),
|
||||
type=order_type,
|
||||
direction=direction,
|
||||
traded=float(d["filled-amount"]),
|
||||
status=STATUS_HUOBI2VT.get(d["state"], None),
|
||||
time=time,
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
|
||||
self.order_manager.on_order(order)
|
||||
|
||||
self.gateway.write_log("委托信息查询成功")
|
||||
|
||||
def on_query_contract(self, data, request): # type: (dict, Request)->None
|
||||
""""""
|
||||
if self.check_error(data, "查询合约"):
|
||||
return
|
||||
|
||||
for d in data["data"]:
|
||||
base_currency = d["base-currency"]
|
||||
quote_currency = d["quote-currency"]
|
||||
name = f"{base_currency.upper()}/{quote_currency.upper()}"
|
||||
pricetick = 1 / pow(10, d["price-precision"])
|
||||
size = 1 / pow(10, d["amount-precision"])
|
||||
|
||||
contract = ContractData(
|
||||
symbol=d["symbol"],
|
||||
exchange=Exchange.HUOBI,
|
||||
name=name,
|
||||
pricetick=pricetick,
|
||||
size=size,
|
||||
product=Product.SPOT,
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
self.gateway.on_contract(contract)
|
||||
|
||||
huobi_symbols.add(contract.symbol)
|
||||
symbol_name_map[contract.symbol] = contract.name
|
||||
|
||||
self.gateway.write_log("合约信息查询成功")
|
||||
|
||||
def on_send_order(self, data, request):
|
||||
""""""
|
||||
order = request.extra
|
||||
|
||||
if self.check_error(data, "委托"):
|
||||
order.status = Status.REJECTED
|
||||
self.order_manager.on_order(order)
|
||||
return
|
||||
|
||||
sys_orderid = data["data"]
|
||||
self.order_manager.update_orderid_map(order.orderid, sys_orderid)
|
||||
|
||||
def on_send_order_failed(self, status_code: str, request: Request):
|
||||
"""
|
||||
Callback when sending order failed on server.
|
||||
"""
|
||||
order = request.extra
|
||||
order.status = Status.REJECTED
|
||||
self.gateway.on_order(order)
|
||||
|
||||
msg = f"委托失败,状态码:{status_code},信息:{request.response.text}"
|
||||
self.gateway.write_log(msg)
|
||||
|
||||
def on_send_order_error(
|
||||
self, exception_type: type, exception_value: Exception, tb, request: Request
|
||||
):
|
||||
"""
|
||||
Callback when sending order caused exception.
|
||||
"""
|
||||
order = request.extra
|
||||
order.status = Status.REJECTED
|
||||
self.gateway.on_order(order)
|
||||
|
||||
# Record exception if not ConnectionError
|
||||
if not issubclass(exception_type, ConnectionError):
|
||||
self.on_error(exception_type, exception_value, tb, request)
|
||||
|
||||
def on_cancel_order(self, data, request):
|
||||
""""""
|
||||
cancel_request = request.extra
|
||||
local_orderid = cancel_request.orderid
|
||||
order = self.order_manager.get_order_with_local_orderid(local_orderid)
|
||||
|
||||
if self.check_error(data, "撤单"):
|
||||
order.status = Status.REJECTED
|
||||
else:
|
||||
order.status = Status.CANCELLED
|
||||
self.gateway.write_log(f"委托撤单成功:{order.orderid}")
|
||||
|
||||
self.order_manager.on_order(order)
|
||||
|
||||
def check_error(self, data: dict, func: str = ""):
|
||||
""""""
|
||||
if data["status"] != "error":
|
||||
return False
|
||||
|
||||
error_code = data["err-code"]
|
||||
error_msg = data["err-msg"]
|
||||
|
||||
self.gateway.write_log(f"{func}请求出错,代码:{error_code},信息:{error_msg}")
|
||||
return True
|
||||
|
||||
|
||||
class HuobiWebsocketApiBase(WebsocketClient):
|
||||
""""""
|
||||
|
||||
def __init__(self, gateway):
|
||||
""""""
|
||||
super(HuobiWebsocketApiBase, self).__init__()
|
||||
|
||||
self.gateway = gateway
|
||||
self.gateway_name = gateway.gateway_name
|
||||
|
||||
self.key = ""
|
||||
self.secret = ""
|
||||
self.sign_host = ""
|
||||
self.path = ""
|
||||
|
||||
def connect(
|
||||
self,
|
||||
key: str,
|
||||
secret: str,
|
||||
url: str,
|
||||
proxy_host: str,
|
||||
proxy_port: int
|
||||
):
|
||||
""""""
|
||||
self.key = key
|
||||
self.secret = secret
|
||||
|
||||
host, path = _split_url(url)
|
||||
self.sign_host = host
|
||||
self.path = path
|
||||
|
||||
self.init(url, proxy_host, proxy_port)
|
||||
self.start()
|
||||
|
||||
def login(self):
|
||||
""""""
|
||||
params = {"op": "auth"}
|
||||
params.update(create_signature(self.key, "GET", self.sign_host, self.path, self.secret))
|
||||
return self.send_packet(params)
|
||||
|
||||
def on_login(self, packet):
|
||||
""""""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def unpack_data(data):
|
||||
""""""
|
||||
return json.loads(zlib.decompress(data, 31))
|
||||
|
||||
def on_packet(self, packet):
|
||||
""""""
|
||||
if "ping" in packet:
|
||||
self.send_packet({"pong": packet["ping"]})
|
||||
elif "err-msg" in packet:
|
||||
return self.on_error_msg(packet)
|
||||
elif "op" in packet and packet["op"] == "auth":
|
||||
return self.on_login()
|
||||
else:
|
||||
self.on_data(packet)
|
||||
|
||||
def on_data(self, packet):
|
||||
""""""
|
||||
print("data : {}".format(packet))
|
||||
|
||||
def on_error_msg(self, packet):
|
||||
""""""
|
||||
msg = packet["err-msg"]
|
||||
if msg == "invalid pong":
|
||||
return
|
||||
|
||||
self.gateway.write_log(packet["err-msg"])
|
||||
|
||||
|
||||
class HuobiTradeWebsocketApi(HuobiWebsocketApiBase):
|
||||
""""""
|
||||
def __init__(self, gateway):
|
||||
""""""
|
||||
super().__init__(gateway)
|
||||
|
||||
self.order_manager = gateway.order_manager
|
||||
self.order_manager.push_data_callback = self.on_data
|
||||
|
||||
self.req_id = 0
|
||||
|
||||
def connect(self, key, secret, proxy_host, proxy_port):
|
||||
""""""
|
||||
super().connect(key, secret, WEBSOCKET_TRADE_HOST, proxy_host, proxy_port)
|
||||
|
||||
def subscribe(self, req: SubscribeRequest):
|
||||
""""""
|
||||
self.req_id += 1
|
||||
req = {
|
||||
"op": "sub",
|
||||
"cid": str(self.req_id),
|
||||
"topic": f"orders.{req.symbol}"
|
||||
}
|
||||
self.send_packet(req)
|
||||
|
||||
def on_connected(self):
|
||||
""""""
|
||||
self.gateway.write_log("交易Websocket API连接成功")
|
||||
self.login()
|
||||
|
||||
def on_login(self):
|
||||
""""""
|
||||
self.gateway.write_log("交易Websocket API登录成功")
|
||||
|
||||
def on_data(self, packet): # type: (dict)->None
|
||||
""""""
|
||||
op = packet.get("op", None)
|
||||
if op != "notify":
|
||||
return
|
||||
|
||||
topic = packet["topic"]
|
||||
if "orders" in topic:
|
||||
self.on_order(packet["data"])
|
||||
|
||||
def on_order(self, data: dict):
|
||||
""""""
|
||||
sys_orderid = str(data["order-id"])
|
||||
|
||||
order = self.order_manager.get_order_with_sys_orderid(sys_orderid)
|
||||
if not order:
|
||||
self.order_manager.add_push_data(sys_orderid, data)
|
||||
return
|
||||
|
||||
traded_volume = float(data["filled-amount"])
|
||||
|
||||
# Push order event
|
||||
order.traded += traded_volume
|
||||
order.status = STATUS_HUOBI2VT.get(data["order-state"], None)
|
||||
self.order_manager.on_order(order)
|
||||
|
||||
# Push trade event
|
||||
if not traded_volume:
|
||||
return
|
||||
|
||||
trade = TradeData(
|
||||
symbol=order.symbol,
|
||||
exchange=Exchange.HUOBI,
|
||||
orderid=order.orderid,
|
||||
tradeid=str(data["seq-id"]),
|
||||
direction=order.direction,
|
||||
price=float(data["price"]),
|
||||
volume=float(data["filled-amount"]),
|
||||
time=datetime.now().strftime("%H:%M:%S"),
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
self.gateway.on_trade(trade)
|
||||
|
||||
|
||||
class HuobiDataWebsocketApi(HuobiWebsocketApiBase):
|
||||
""""""
|
||||
|
||||
def __init__(self, gateway):
|
||||
""""""
|
||||
super().__init__(gateway)
|
||||
|
||||
self.req_id = 0
|
||||
self.ticks = {}
|
||||
|
||||
def connect(self, key: str, secret: str, proxy_host: str, proxy_port: int):
|
||||
""""""
|
||||
super().connect(key, secret, WEBSOCKET_DATA_HOST, proxy_host, proxy_port)
|
||||
|
||||
def on_connected(self):
|
||||
""""""
|
||||
self.gateway.write_log("行情Websocket API连接成功")
|
||||
|
||||
def subscribe(self, req: SubscribeRequest):
|
||||
""""""
|
||||
symbol = req.symbol
|
||||
|
||||
# Create tick data buffer
|
||||
tick = TickData(
|
||||
symbol=symbol,
|
||||
name=symbol_name_map.get(symbol, ""),
|
||||
exchange=Exchange.HUOBI,
|
||||
datetime=datetime.now(),
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
self.ticks[symbol] = tick
|
||||
|
||||
# Subscribe to market depth update
|
||||
self.req_id += 1
|
||||
req = {
|
||||
"sub": f"market.{symbol}.depth.step0",
|
||||
"id": str(self.req_id)
|
||||
}
|
||||
self.send_packet(req)
|
||||
|
||||
# Subscribe to market detail update
|
||||
self.req_id += 1
|
||||
req = {
|
||||
"sub": f"market.{symbol}.detail",
|
||||
"id": str(self.req_id)
|
||||
}
|
||||
self.send_packet(req)
|
||||
|
||||
def on_data(self, packet): # type: (dict)->None
|
||||
""""""
|
||||
channel = packet.get("ch", None)
|
||||
if channel:
|
||||
if "depth.step" in channel:
|
||||
self.on_market_depth(packet)
|
||||
elif "detail" in channel:
|
||||
self.on_market_detail(packet)
|
||||
elif "err-code" in packet:
|
||||
code = packet["err-code"]
|
||||
msg = packet["err-msg"]
|
||||
self.gateway.write_log(f"错误代码:{code}, 错误信息:{msg}")
|
||||
|
||||
def on_market_depth(self, data):
|
||||
"""行情深度推送 """
|
||||
symbol = data["ch"].split(".")[1]
|
||||
tick = self.ticks[symbol]
|
||||
tick.datetime = datetime.fromtimestamp(data["ts"] / 1000)
|
||||
|
||||
bids = data["tick"]["bids"]
|
||||
for n in range(5):
|
||||
price, volume = bids[n]
|
||||
tick.__setattr__("bid_price_" + str(n + 1), float(price))
|
||||
tick.__setattr__("bid_volume_" + str(n + 1), float(volume))
|
||||
|
||||
asks = data["tick"]["asks"]
|
||||
for n in range(5):
|
||||
price, volume = asks[n]
|
||||
tick.__setattr__("ask_price_" + str(n + 1), float(price))
|
||||
tick.__setattr__("ask_volume_" + str(n + 1), float(volume))
|
||||
|
||||
if tick.last_price:
|
||||
self.gateway.on_tick(copy(tick))
|
||||
|
||||
def on_market_detail(self, data):
|
||||
"""市场细节推送"""
|
||||
symbol = data["ch"].split(".")[1]
|
||||
tick = self.ticks[symbol]
|
||||
tick.datetime = datetime.fromtimestamp(data["ts"] / 1000)
|
||||
|
||||
tick_data = data["tick"]
|
||||
tick.open_price = float(tick_data["open"])
|
||||
tick.high_price = float(tick_data["high"])
|
||||
tick.low_price = float(tick_data["low"])
|
||||
tick.last_price = float(tick_data["close"])
|
||||
tick.volume = float(tick_data["vol"])
|
||||
|
||||
if tick.bid_price_1:
|
||||
self.gateway.on_tick(copy(tick))
|
||||
|
||||
|
||||
def _split_url(url):
|
||||
"""
|
||||
将url拆分为host和path
|
||||
:return: host, path
|
||||
"""
|
||||
result = re.match("\w+://([^/]*)(.*)", url) # noqa
|
||||
if result:
|
||||
return result.group(1), result.group(2)
|
||||
|
||||
|
||||
def create_signature(api_key, method, host, path, secret_key, get_params=None):
|
||||
"""
|
||||
创建签名
|
||||
:param get_params: dict 使用GET方法时附带的额外参数(urlparams)
|
||||
:return:
|
||||
"""
|
||||
sorted_params = [
|
||||
("AccessKeyId", api_key),
|
||||
("SignatureMethod", "HmacSHA256"),
|
||||
("SignatureVersion", "2"),
|
||||
("Timestamp", datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S"))
|
||||
]
|
||||
|
||||
if get_params:
|
||||
sorted_params.extend(list(get_params.items()))
|
||||
sorted_params = list(sorted(sorted_params))
|
||||
encode_params = urllib.parse.urlencode(sorted_params)
|
||||
|
||||
payload = [method, host, path, encode_params]
|
||||
payload = "\n".join(payload)
|
||||
payload = payload.encode(encoding="UTF8")
|
||||
|
||||
secret_key = secret_key.encode(encoding="UTF8")
|
||||
|
||||
digest = hmac.new(secret_key, payload, digestmod=hashlib.sha256).digest()
|
||||
signature = base64.b64encode(digest)
|
||||
|
||||
params = dict(sorted_params)
|
||||
params["Signature"] = signature.decode("UTF8")
|
||||
return params
|
1
vnpy/gateway/okex/__init__.py
Normal file
1
vnpy/gateway/okex/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .okex_gateway import OkexGateway
|
725
vnpy/gateway/okex/okex_gateway.py
Normal file
725
vnpy/gateway/okex/okex_gateway.py
Normal file
@ -0,0 +1,725 @@
|
||||
# encoding: UTF-8
|
||||
"""
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import base64
|
||||
import zlib
|
||||
from copy import copy
|
||||
from datetime import datetime
|
||||
from threading import Lock
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from requests import ConnectionError
|
||||
|
||||
from vnpy.api.rest import Request, RestClient
|
||||
from vnpy.api.websocket import WebsocketClient
|
||||
from vnpy.trader.constant import (
|
||||
Direction,
|
||||
Exchange,
|
||||
OrderType,
|
||||
Product,
|
||||
Status
|
||||
)
|
||||
from vnpy.trader.gateway import BaseGateway
|
||||
from vnpy.trader.object import (
|
||||
TickData,
|
||||
OrderData,
|
||||
TradeData,
|
||||
AccountData,
|
||||
ContractData,
|
||||
OrderRequest,
|
||||
CancelRequest,
|
||||
SubscribeRequest,
|
||||
)
|
||||
|
||||
REST_HOST = "https://www.okex.com"
|
||||
WEBSOCKET_HOST = "wss://real.okex.com:10442/ws/v3"
|
||||
|
||||
STATUS_OKEX2VT = {
|
||||
"ordering": Status.SUBMITTING,
|
||||
"open": Status.NOTTRADED,
|
||||
"part_filled": Status.PARTTRADED,
|
||||
"filled": Status.ALLTRADED,
|
||||
"cancelled": Status.CANCELLED,
|
||||
"cancelling": Status.CANCELLED,
|
||||
"failure": Status.REJECTED,
|
||||
}
|
||||
|
||||
DIRECTION_VT2OKEX = {Direction.LONG: "buy", Direction.SHORT: "sell"}
|
||||
DIRECTION_OKEX2VT = {v: k for k, v in DIRECTION_VT2OKEX.items()}
|
||||
|
||||
ORDERTYPE_VT2OKEX = {
|
||||
OrderType.LIMIT: "limit",
|
||||
OrderType.MARKET: "market"
|
||||
}
|
||||
ORDERTYPE_OKEX2VT = {v: k for k, v in ORDERTYPE_VT2OKEX.items()}
|
||||
|
||||
|
||||
instruments = set()
|
||||
currencies = set()
|
||||
|
||||
|
||||
class OkexGateway(BaseGateway):
|
||||
"""
|
||||
VN Trader Gateway for OKEX connection.
|
||||
"""
|
||||
|
||||
default_setting = {
|
||||
"API Key": "",
|
||||
"Secret Key": "",
|
||||
"Passphrase": "",
|
||||
"会话数": 3,
|
||||
"代理地址": "",
|
||||
"代理端口": "",
|
||||
}
|
||||
|
||||
def __init__(self, event_engine):
|
||||
"""Constructor"""
|
||||
super(OkexGateway, self).__init__(event_engine, "OKEX")
|
||||
|
||||
self.rest_api = OkexRestApi(self)
|
||||
self.ws_api = OkexWebsocketApi(self)
|
||||
|
||||
self.orders = {}
|
||||
|
||||
def connect(self, setting: dict):
|
||||
""""""
|
||||
key = setting["API Key"]
|
||||
secret = setting["Secret Key"]
|
||||
passphrase = setting["Passphrase"]
|
||||
session_number = setting["会话数"]
|
||||
proxy_host = setting["代理地址"]
|
||||
proxy_port = setting["代理端口"]
|
||||
|
||||
if proxy_port.isdigit():
|
||||
proxy_port = int(proxy_port)
|
||||
else:
|
||||
proxy_port = 0
|
||||
|
||||
self.rest_api.connect(key, secret, passphrase,
|
||||
session_number, proxy_host, proxy_port)
|
||||
self.ws_api.connect(key, secret, passphrase, proxy_host, proxy_port)
|
||||
|
||||
def subscribe(self, req: SubscribeRequest):
|
||||
""""""
|
||||
self.ws_api.subscribe(req)
|
||||
|
||||
def send_order(self, req: OrderRequest):
|
||||
""""""
|
||||
return self.rest_api.send_order(req)
|
||||
|
||||
def cancel_order(self, req: CancelRequest):
|
||||
""""""
|
||||
self.rest_api.cancel_order(req)
|
||||
|
||||
def query_account(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def query_position(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
""""""
|
||||
self.rest_api.stop()
|
||||
self.ws_api.stop()
|
||||
|
||||
def on_order(self, order: OrderData):
|
||||
""""""
|
||||
self.orders[order.orderid] = order
|
||||
super().on_order(order)
|
||||
|
||||
def get_order(self, orderid: str):
|
||||
""""""
|
||||
return self.orders.get(orderid, None)
|
||||
|
||||
|
||||
class OkexRestApi(RestClient):
|
||||
"""
|
||||
OKEX REST API
|
||||
"""
|
||||
|
||||
def __init__(self, gateway: BaseGateway):
|
||||
""""""
|
||||
super(OkexRestApi, self).__init__()
|
||||
|
||||
self.gateway = gateway
|
||||
self.gateway_name = gateway.gateway_name
|
||||
|
||||
self.key = ""
|
||||
self.secret = ""
|
||||
self.passphrase = ""
|
||||
|
||||
self.order_count = 10000
|
||||
self.order_count_lock = Lock()
|
||||
|
||||
self.connect_time = 0
|
||||
|
||||
def sign(self, request):
|
||||
"""
|
||||
Generate OKEX signature.
|
||||
"""
|
||||
# Sign
|
||||
# timestamp = str(time.time())
|
||||
timestamp = get_timestamp()
|
||||
request.data = json.dumps(request.data)
|
||||
|
||||
if request.params:
|
||||
path = request.path + '?' + urlencode(request.params)
|
||||
else:
|
||||
path = request.path
|
||||
|
||||
msg = timestamp + request.method + path + request.data
|
||||
signature = generate_signature(msg, self.secret)
|
||||
|
||||
# Add headers
|
||||
request.headers = {
|
||||
'OK-ACCESS-KEY': self.key,
|
||||
'OK-ACCESS-SIGN': signature,
|
||||
'OK-ACCESS-TIMESTAMP': timestamp,
|
||||
'OK-ACCESS-PASSPHRASE': self.passphrase,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
return request
|
||||
|
||||
def connect(
|
||||
self,
|
||||
key: str,
|
||||
secret: str,
|
||||
passphrase: str,
|
||||
session_number: int,
|
||||
proxy_host: str,
|
||||
proxy_port: int,
|
||||
):
|
||||
"""
|
||||
Initialize connection to REST server.
|
||||
"""
|
||||
self.key = key
|
||||
self.secret = secret.encode()
|
||||
self.passphrase = passphrase
|
||||
|
||||
self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S"))
|
||||
|
||||
self.init(REST_HOST, proxy_host, proxy_port)
|
||||
self.start(session_number)
|
||||
self.gateway.write_log("REST API启动成功")
|
||||
|
||||
self.query_time()
|
||||
self.query_contract()
|
||||
self.query_account()
|
||||
self.query_order()
|
||||
|
||||
def _new_order_id(self):
|
||||
with self.order_count_lock:
|
||||
self.order_count += 1
|
||||
return self.order_count
|
||||
|
||||
def send_order(self, req: OrderRequest):
|
||||
""""""
|
||||
orderid = f"a{self.connect_time}{self._new_order_id()}"
|
||||
|
||||
data = {
|
||||
"client_oid": orderid,
|
||||
"type": ORDERTYPE_VT2OKEX[req.type],
|
||||
"side": DIRECTION_VT2OKEX[req.direction],
|
||||
"instrument_id": req.symbol
|
||||
}
|
||||
|
||||
if req.type == OrderType.MARKET:
|
||||
if req.direction == Direction.LONG:
|
||||
data["notional"] = req.volume
|
||||
else:
|
||||
data["size"] = req.volume
|
||||
else:
|
||||
data["price"] = req.price
|
||||
data["size"] = req.volume
|
||||
|
||||
order = req.create_order_data(orderid, self.gateway_name)
|
||||
|
||||
self.add_request(
|
||||
"POST",
|
||||
"/api/spot/v3/orders",
|
||||
callback=self.on_send_order,
|
||||
data=data,
|
||||
extra=order,
|
||||
on_failed=self.on_send_order_failed,
|
||||
on_error=self.on_send_order_error,
|
||||
)
|
||||
|
||||
self.gateway.on_order(order)
|
||||
return order.vt_orderid
|
||||
|
||||
def cancel_order(self, req: CancelRequest):
|
||||
""""""
|
||||
data = {
|
||||
"instrument_id": req.symbol,
|
||||
"client_oid": req.orderid
|
||||
}
|
||||
|
||||
path = "/api/spot/v3/cancel_orders/" + req.orderid
|
||||
self.add_request(
|
||||
"POST",
|
||||
path,
|
||||
callback=self.on_cancel_order,
|
||||
data=data,
|
||||
on_error=self.on_cancel_order_error,
|
||||
on_failed=self.on_cancel_order_failed,
|
||||
extra=req
|
||||
)
|
||||
|
||||
def query_contract(self):
|
||||
""""""
|
||||
self.add_request(
|
||||
"GET",
|
||||
"/api/spot/v3/instruments",
|
||||
callback=self.on_query_contract
|
||||
)
|
||||
|
||||
def query_account(self):
|
||||
""""""
|
||||
self.add_request(
|
||||
"GET",
|
||||
"/api/spot/v3/accounts",
|
||||
callback=self.on_query_account
|
||||
)
|
||||
|
||||
def query_order(self):
|
||||
""""""
|
||||
self.add_request(
|
||||
"GET",
|
||||
"/api/spot/v3/orders_pending",
|
||||
callback=self.on_query_order
|
||||
)
|
||||
|
||||
def query_time(self):
|
||||
""""""
|
||||
self.add_request(
|
||||
"GET",
|
||||
"/api/general/v3/time",
|
||||
callback=self.on_query_time
|
||||
)
|
||||
|
||||
def on_query_contract(self, data, request):
|
||||
""""""
|
||||
for instrument_data in data:
|
||||
symbol = instrument_data["instrument_id"]
|
||||
contract = ContractData(
|
||||
symbol=symbol,
|
||||
exchange=Exchange.OKEX,
|
||||
name=symbol,
|
||||
product=Product.SPOT,
|
||||
size=1,
|
||||
pricetick=float(instrument_data["tick_size"]),
|
||||
gateway_name=self.gateway_name
|
||||
)
|
||||
self.gateway.on_contract(contract)
|
||||
|
||||
instruments.add(instrument_data["instrument_id"])
|
||||
currencies.add(instrument_data["base_currency"])
|
||||
currencies.add(instrument_data["quote_currency"])
|
||||
|
||||
self.gateway.write_log("合约信息查询成功")
|
||||
|
||||
# Start websocket api after instruments data collected
|
||||
self.gateway.ws_api.start()
|
||||
|
||||
def on_query_account(self, data, request):
|
||||
""""""
|
||||
for account_data in data:
|
||||
account = AccountData(
|
||||
accountid=account_data["currency"],
|
||||
balance=float(account_data["balance"]),
|
||||
frozen=float(account_data["hold"]),
|
||||
gateway_name=self.gateway_name
|
||||
)
|
||||
self.gateway.on_account(account)
|
||||
|
||||
self.gateway.write_log("账户资金查询成功")
|
||||
|
||||
def on_query_order(self, data, request):
|
||||
""""""
|
||||
for order_data in data:
|
||||
order = OrderData(
|
||||
symbol=order_data["instrument_id"],
|
||||
exchange=Exchange.OKEX,
|
||||
type=ORDERTYPE_OKEX2VT[order_data["type"]],
|
||||
orderid=order_data["client_oid"],
|
||||
direction=DIRECTION_OKEX2VT[order_data["side"]],
|
||||
price=float(order_data["price"]),
|
||||
volume=float(order_data["size"]),
|
||||
traded=float(order_data["filled_size"]),
|
||||
time=order_data["timestamp"][11:19],
|
||||
status=STATUS_OKEX2VT[order_data["status"]],
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
self.gateway.on_order(order)
|
||||
|
||||
self.gateway.write_log("委托信息查询成功")
|
||||
|
||||
def on_query_time(self, data, request):
|
||||
""""""
|
||||
server_time = data["iso"]
|
||||
local_time = datetime.utcnow().isoformat()
|
||||
msg = f"服务器时间:{server_time},本机时间:{local_time}"
|
||||
self.gateway.write_log(msg)
|
||||
|
||||
def on_send_order_failed(self, status_code: str, request: Request):
|
||||
"""
|
||||
Callback when sending order failed on server.
|
||||
"""
|
||||
order = request.extra
|
||||
order.status = Status.REJECTED
|
||||
self.gateway.on_order(order)
|
||||
|
||||
msg = f"委托失败,状态码:{status_code},信息:{request.response.text}"
|
||||
self.gateway.write_log(msg)
|
||||
|
||||
def on_send_order_error(
|
||||
self, exception_type: type, exception_value: Exception, tb, request: Request
|
||||
):
|
||||
"""
|
||||
Callback when sending order caused exception.
|
||||
"""
|
||||
order = request.extra
|
||||
order.status = Status.REJECTED
|
||||
self.gateway.on_order(order)
|
||||
|
||||
# Record exception if not ConnectionError
|
||||
if not issubclass(exception_type, ConnectionError):
|
||||
self.on_error(exception_type, exception_value, tb, request)
|
||||
|
||||
def on_send_order(self, data, request):
|
||||
"""Websocket will push a new order status"""
|
||||
order = request.extra
|
||||
|
||||
error_msg = data["error_message"]
|
||||
if error_msg:
|
||||
order.status = Status.REJECTED
|
||||
self.gateway.on_order(order)
|
||||
|
||||
self.gateway.write_log(f"委托失败:{error_msg}")
|
||||
|
||||
def on_cancel_order_error(
|
||||
self, exception_type: type, exception_value: Exception, tb, request: Request
|
||||
):
|
||||
"""
|
||||
Callback when cancelling order failed on server.
|
||||
"""
|
||||
# Record exception if not ConnectionError
|
||||
if not issubclass(exception_type, ConnectionError):
|
||||
self.on_error(exception_type, exception_value, tb, request)
|
||||
|
||||
def on_cancel_order(self, data, request):
|
||||
"""Websocket will push a new order status"""
|
||||
pass
|
||||
|
||||
def on_cancel_order_failed(self, status_code: int, request: Request):
|
||||
"""If cancel failed, mark order status to be rejected."""
|
||||
req = request.extra
|
||||
order = self.gateway.get_order(req.orderid)
|
||||
if order:
|
||||
order.status = Status.REJECTED
|
||||
self.gateway.on_order(order)
|
||||
|
||||
def on_failed(self, status_code: int, request: Request):
|
||||
"""
|
||||
Callback to handle request failed.
|
||||
"""
|
||||
msg = f"请求失败,状态码:{status_code},信息:{request.response.text}"
|
||||
self.gateway.write_log(msg)
|
||||
|
||||
def on_error(
|
||||
self, exception_type: type, exception_value: Exception, tb, request: Request
|
||||
):
|
||||
"""
|
||||
Callback to handler request exception.
|
||||
"""
|
||||
msg = f"触发异常,状态码:{exception_type},信息:{exception_value}"
|
||||
self.gateway.write_log(msg)
|
||||
|
||||
sys.stderr.write(
|
||||
self.exception_detail(exception_type, exception_value, tb, request)
|
||||
)
|
||||
|
||||
|
||||
class OkexWebsocketApi(WebsocketClient):
|
||||
""""""
|
||||
|
||||
def __init__(self, gateway):
|
||||
""""""
|
||||
super(OkexWebsocketApi, self).__init__()
|
||||
self.ping_interval = 20 # OKEX use 30 seconds for ping
|
||||
|
||||
self.gateway = gateway
|
||||
self.gateway_name = gateway.gateway_name
|
||||
|
||||
self.key = ""
|
||||
self.secret = ""
|
||||
self.passphrase = ""
|
||||
|
||||
self.trade_count = 10000
|
||||
self.connect_time = 0
|
||||
|
||||
self.callbacks = {}
|
||||
self.ticks = {}
|
||||
|
||||
def connect(
|
||||
self,
|
||||
key: str,
|
||||
secret: str,
|
||||
passphrase: str,
|
||||
proxy_host: str,
|
||||
proxy_port: int
|
||||
):
|
||||
""""""
|
||||
self.key = key
|
||||
self.secret = secret.encode()
|
||||
self.passphrase = passphrase
|
||||
|
||||
self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S"))
|
||||
|
||||
self.init(WEBSOCKET_HOST, proxy_host, proxy_port)
|
||||
# self.start()
|
||||
|
||||
def unpack_data(self, data):
|
||||
""""""
|
||||
return json.loads(zlib.decompress(data, -zlib.MAX_WBITS))
|
||||
|
||||
def subscribe(self, req: SubscribeRequest):
|
||||
"""
|
||||
Subscribe to tick data upate.
|
||||
"""
|
||||
tick = TickData(
|
||||
symbol=req.symbol,
|
||||
exchange=req.exchange,
|
||||
name=req.symbol,
|
||||
datetime=datetime.now(),
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
self.ticks[req.symbol] = tick
|
||||
|
||||
channel_ticker = f"spot/ticker:{req.symbol}"
|
||||
channel_depth = f"spot/depth5:{req.symbol}"
|
||||
|
||||
self.callbacks[channel_ticker] = self.on_ticker
|
||||
self.callbacks[channel_depth] = self.on_depth
|
||||
|
||||
req = {
|
||||
"op": "subscribe",
|
||||
"args": [channel_ticker, channel_depth]
|
||||
}
|
||||
self.send_packet(req)
|
||||
|
||||
def on_connected(self):
|
||||
""""""
|
||||
self.gateway.write_log("Websocket API连接成功")
|
||||
self.login()
|
||||
|
||||
def on_disconnected(self):
|
||||
""""""
|
||||
self.gateway.write_log("Websocket API连接断开")
|
||||
|
||||
def on_packet(self, packet: dict):
|
||||
""""""
|
||||
if "event" in packet:
|
||||
event = packet["event"]
|
||||
if event == "subscribe":
|
||||
return
|
||||
elif event == "error":
|
||||
msg = packet["message"]
|
||||
self.gateway.write_log(f"Websocket API请求异常:{msg}")
|
||||
elif event == "login":
|
||||
self.on_login(packet)
|
||||
else:
|
||||
channel = packet["table"]
|
||||
data = packet["data"]
|
||||
callback = self.callbacks.get(channel, None)
|
||||
|
||||
if callback:
|
||||
for d in data:
|
||||
callback(d)
|
||||
|
||||
def on_error(self, exception_type: type, exception_value: Exception, tb):
|
||||
""""""
|
||||
msg = f"触发异常,状态码:{exception_type},信息:{exception_value}"
|
||||
self.gateway.write_log(msg)
|
||||
|
||||
sys.stderr.write(self.exception_detail(
|
||||
exception_type, exception_value, tb))
|
||||
|
||||
def login(self):
|
||||
"""
|
||||
Need to login befores subscribe to websocket topic.
|
||||
"""
|
||||
timestamp = str(time.time())
|
||||
|
||||
msg = timestamp + 'GET' + '/users/self/verify'
|
||||
signature = generate_signature(msg, self.secret)
|
||||
|
||||
req = {
|
||||
"op": "login",
|
||||
"args": [
|
||||
self.key,
|
||||
self.passphrase,
|
||||
timestamp,
|
||||
signature.decode("utf-8")
|
||||
]
|
||||
}
|
||||
self.send_packet(req)
|
||||
self.callbacks['login'] = self.on_login
|
||||
|
||||
def subscribe_topic(self):
|
||||
"""
|
||||
Subscribe to all private topics.
|
||||
"""
|
||||
self.callbacks["spot/ticker"] = self.on_ticker
|
||||
self.callbacks["spot/depth5"] = self.on_depth
|
||||
self.callbacks["spot/account"] = self.on_account
|
||||
self.callbacks["spot/order"] = self.on_order
|
||||
|
||||
# Subscribe to order update
|
||||
channels = []
|
||||
for instrument_id in instruments:
|
||||
channel = f"spot/order:{instrument_id}"
|
||||
channels.append(channel)
|
||||
|
||||
req = {
|
||||
"op": "subscribe",
|
||||
"args": channels
|
||||
}
|
||||
self.send_packet(req)
|
||||
|
||||
# Subscribe to account update
|
||||
channels = []
|
||||
for currency in currencies:
|
||||
channel = f"spot/account:{currency}"
|
||||
channels.append(channel)
|
||||
|
||||
req = {
|
||||
"op": "subscribe",
|
||||
"args": channels
|
||||
}
|
||||
self.send_packet(req)
|
||||
|
||||
# Subscribe to BTC/USDT trade for keep connection alive
|
||||
req = {
|
||||
"op": "subscribe",
|
||||
"args": ["spot/trade:BTC-USDT"]
|
||||
}
|
||||
self.send_packet(req)
|
||||
|
||||
def on_login(self, data: dict):
|
||||
""""""
|
||||
success = data.get("success", False)
|
||||
|
||||
if success:
|
||||
self.gateway.write_log("Websocket API登录成功")
|
||||
self.subscribe_topic()
|
||||
else:
|
||||
self.gateway.write_log("Websocket API登录失败")
|
||||
|
||||
def on_ticker(self, d):
|
||||
""""""
|
||||
symbol = d["instrument_id"]
|
||||
tick = self.ticks.get(symbol, None)
|
||||
if not tick:
|
||||
return
|
||||
|
||||
tick.last_price = float(d["last"])
|
||||
tick.open = float(d["open_24h"])
|
||||
tick.high = float(d["high_24h"])
|
||||
tick.low = float(d["low_24h"])
|
||||
tick.volume = float(d["base_volume_24h"])
|
||||
tick.datetime = datetime.strptime(
|
||||
d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ")
|
||||
self.gateway.on_tick(copy(tick))
|
||||
|
||||
def on_depth(self, d):
|
||||
""""""
|
||||
for tick_data in d:
|
||||
symbol = d["instrument_id"]
|
||||
tick = self.ticks.get(symbol, None)
|
||||
if not tick:
|
||||
return
|
||||
|
||||
bids = d["bids"]
|
||||
asks = d["asks"]
|
||||
for n, buf in enumerate(bids):
|
||||
price, volume, _ = buf
|
||||
tick.__setattr__("bid_price_%s" % (n + 1), float(price))
|
||||
tick.__setattr__("bid_volume_%s" % (n + 1), float(volume))
|
||||
|
||||
for n, buf in enumerate(asks):
|
||||
price, volume, _ = buf
|
||||
tick.__setattr__("ask_price_%s" % (n + 1), float(price))
|
||||
tick.__setattr__("ask_volume_%s" % (n + 1), float(volume))
|
||||
|
||||
tick.datetime = datetime.strptime(
|
||||
d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ")
|
||||
self.gateway.on_tick(copy(tick))
|
||||
|
||||
def on_order(self, d):
|
||||
""""""
|
||||
order = OrderData(
|
||||
symbol=d["instrument_id"],
|
||||
exchange=Exchange.OKEX,
|
||||
type=ORDERTYPE_OKEX2VT[d["type"]],
|
||||
orderid=d["client_oid"],
|
||||
direction=DIRECTION_OKEX2VT[d["side"]],
|
||||
price=float(d["price"]),
|
||||
volume=float(d["size"]),
|
||||
traded=float(d["filled_size"]),
|
||||
time=d["timestamp"][11:19],
|
||||
status=STATUS_OKEX2VT[d["status"]],
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
self.gateway.on_order(copy(order))
|
||||
|
||||
trade_volume = d.get("last_fill_qty", 0)
|
||||
if not trade_volume or float(trade_volume) == 0:
|
||||
return
|
||||
|
||||
self.trade_count += 1
|
||||
tradeid = f"{self.connect_time}{self.trade_count}"
|
||||
|
||||
trade = TradeData(
|
||||
symbol=order.symbol,
|
||||
exchange=order.exchange,
|
||||
orderid=order.orderid,
|
||||
tradeid=tradeid,
|
||||
direction=order.direction,
|
||||
price=float(d["last_fill_px"]),
|
||||
volume=float(trade_volume),
|
||||
time=d["last_fill_time"][11:19],
|
||||
gateway_name=self.gateway_name
|
||||
)
|
||||
self.gateway.on_trade(trade)
|
||||
|
||||
def on_account(self, d):
|
||||
""""""
|
||||
account = AccountData(
|
||||
accountid=d["currency"],
|
||||
balance=float(d["balance"]),
|
||||
frozen=float(d["hold"]),
|
||||
gateway_name=self.gateway_name
|
||||
)
|
||||
|
||||
self.gateway.on_account(copy(account))
|
||||
|
||||
|
||||
def generate_signature(msg: str, secret_key: str):
|
||||
"""OKEX V3 signature"""
|
||||
return base64.b64encode(hmac.new(secret_key, msg.encode(), hashlib.sha256).digest())
|
||||
|
||||
|
||||
def get_timestamp():
|
||||
""""""
|
||||
now = datetime.utcnow()
|
||||
timestamp = now.isoformat("T", "milliseconds")
|
||||
return timestamp + "Z"
|
@ -1,5 +1,6 @@
|
||||
# encoding: UTF-8
|
||||
"""
|
||||
Author: KeKe
|
||||
Please install tiger-api before use.
|
||||
pip install tigeropen
|
||||
"""
|
||||
@ -123,6 +124,9 @@ class TigerGateway(BaseGateway):
|
||||
self.contracts = {}
|
||||
self.symbol_names = {}
|
||||
|
||||
self.push_connected = False
|
||||
self.subscribed_symbols = set()
|
||||
|
||||
def run(self):
|
||||
""""""
|
||||
while self.active:
|
||||
@ -203,23 +207,33 @@ class TigerGateway(BaseGateway):
|
||||
"""
|
||||
protocol, host, port = self.client_config.socket_host_port
|
||||
self.push_client = PushClient(host, port, (protocol == 'ssl'))
|
||||
self.push_client.connect(
|
||||
self.client_config.tiger_id, self.client_config.private_key)
|
||||
|
||||
self.push_client.quote_changed = self.on_quote_change
|
||||
self.push_client.asset_changed = self.on_asset_change
|
||||
self.push_client.position_changed = self.on_position_change
|
||||
self.push_client.order_changed = self.on_order_change
|
||||
|
||||
self.write_log("推送接口连接成功")
|
||||
self.push_client.connect(
|
||||
self.client_config.tiger_id, self.client_config.private_key)
|
||||
|
||||
def subscribe(self, req: SubscribeRequest):
|
||||
""""""
|
||||
self.push_client.subscribe_quote([req.symbol])
|
||||
self.subscribed_symbols.add(req.symbol)
|
||||
|
||||
if self.push_connected:
|
||||
self.push_client.subscribe_quote([req.symbol])
|
||||
|
||||
def on_push_connected(self):
|
||||
""""""
|
||||
self.push_connected = True
|
||||
self.write_log("推送接口连接成功")
|
||||
|
||||
self.push_client.subscribe_asset()
|
||||
self.push_client.subscribe_position()
|
||||
self.push_client.subscribe_order()
|
||||
|
||||
self.push_client.subscribe_quote(list(self.subscribed_symbols))
|
||||
|
||||
def on_quote_change(self, tiger_symbol: str, data: list, trading: bool):
|
||||
""""""
|
||||
data = dict(data)
|
||||
|
1
vnpy/rpc/__init__.py
Normal file
1
vnpy/rpc/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .vnrpc import RpcServer, RpcClient, RemoteException
|
36
vnpy/rpc/test_client.py
Normal file
36
vnpy/rpc/test_client.py
Normal file
@ -0,0 +1,36 @@
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import
|
||||
from time import sleep
|
||||
|
||||
from vnpy.rpc import RpcClient
|
||||
|
||||
|
||||
class TestClient(RpcClient):
|
||||
"""
|
||||
Test RpcClient
|
||||
"""
|
||||
|
||||
def __init__(self, req_address, sub_address):
|
||||
"""
|
||||
Constructor
|
||||
"""
|
||||
super(TestClient, self).__init__(req_address, sub_address)
|
||||
|
||||
def callback(self, topic, data):
|
||||
"""
|
||||
Realize callable function
|
||||
"""
|
||||
print('client received topic:', topic, ', data:', data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
req_address = 'tcp://localhost:2014'
|
||||
sub_address = 'tcp://localhost:0602'
|
||||
|
||||
tc = TestClient(req_address, sub_address)
|
||||
tc.subscribeTopic('')
|
||||
tc.start()
|
||||
|
||||
while 1:
|
||||
print(tc.add(1, 3))
|
||||
sleep(2)
|
40
vnpy/rpc/test_server.py
Normal file
40
vnpy/rpc/test_server.py
Normal file
@ -0,0 +1,40 @@
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import
|
||||
from time import sleep, time
|
||||
|
||||
from vnpy.rpc import RpcServer
|
||||
|
||||
|
||||
class TestServer(RpcServer):
|
||||
"""
|
||||
Test RpcServer
|
||||
"""
|
||||
|
||||
def __init__(self, rep_address, pub_address):
|
||||
"""
|
||||
Constructor
|
||||
"""
|
||||
super(TestServer, self).__init__(rep_address, pub_address)
|
||||
|
||||
self.register(self.add)
|
||||
|
||||
def add(self, a, b):
|
||||
"""
|
||||
Test function
|
||||
"""
|
||||
print('receiving: %s, %s' % (a, b))
|
||||
return a + b
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
rep_address = 'tcp://*:2014'
|
||||
pub_address = 'tcp://*:0602'
|
||||
|
||||
ts = TestServer(rep_address, pub_address)
|
||||
ts.start()
|
||||
|
||||
while 1:
|
||||
content = 'current server time is %s' % time()
|
||||
print(content)
|
||||
ts.publish('test', content)
|
||||
sleep(2)
|
329
vnpy/rpc/vnrpc.py
Normal file
329
vnpy/rpc/vnrpc.py
Normal file
@ -0,0 +1,329 @@
|
||||
import threading
|
||||
import traceback
|
||||
import signal
|
||||
|
||||
import zmq
|
||||
from msgpack import packb, unpackb
|
||||
from json import dumps, loads
|
||||
|
||||
import pickle
|
||||
p_dumps = pickle.dumps
|
||||
p_loads = pickle.loads
|
||||
|
||||
|
||||
# Achieve Ctrl-c interrupt recv
|
||||
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||
|
||||
|
||||
class RpcObject(object):
|
||||
"""
|
||||
Referred to serialization of packing and unpacking, we offer 3 tools:
|
||||
1) maspack: higher performance, but usually requires the installation of msgpack related tools;
|
||||
2) jason: Slightly lower performance but versatility is better, most programming languages have built-in libraries;
|
||||
3) cPickle: Lower performance and only can be used in Python, but it is very convenient to transfer Python objects directly.
|
||||
|
||||
Therefore, it is recommended to use msgpack.
|
||||
Use json, if you want to communicate with some languages without providing msgpack.
|
||||
Use cPickle, when the data being transferred contains many custom Python objects.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Constructor
|
||||
Use msgpack as default serialization tool
|
||||
"""
|
||||
self.use_msgpack()
|
||||
|
||||
def pack(self, data):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def unpack(self, data):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def __json_pack(self, data):
|
||||
"""
|
||||
Pack with json
|
||||
"""
|
||||
return dumps(data)
|
||||
|
||||
def __json_unpack(self, data):
|
||||
"""
|
||||
Unpack with json
|
||||
"""
|
||||
return loads(data)
|
||||
|
||||
def __msgpack_pack(self, data):
|
||||
"""
|
||||
Pack with msgpack
|
||||
"""
|
||||
return packb(data)
|
||||
|
||||
def __msgpack_unpack(self, data):
|
||||
"""
|
||||
Unpack with msgpack
|
||||
"""
|
||||
return unpackb(data)
|
||||
|
||||
def __pickle_pack(self, data):
|
||||
"""
|
||||
Pack with cPickle
|
||||
"""
|
||||
return p_dumps(data)
|
||||
|
||||
def __pickle_unpack(self, data):
|
||||
"""
|
||||
Unpack with cPickle
|
||||
"""
|
||||
return p_loads(data)
|
||||
|
||||
def use_json(self):
|
||||
"""
|
||||
Use json as serialization tool
|
||||
"""
|
||||
self.pack = self.__json_pack
|
||||
self.unpack = self.__json_unpack
|
||||
|
||||
def use_msgpack(self):
|
||||
"""
|
||||
Use msgpack as serialization tool
|
||||
"""
|
||||
self.pack = self.__msgpack_pack
|
||||
self.unpack = self.__msgpack_unpack
|
||||
|
||||
def use_pickle(self):
|
||||
"""
|
||||
Use cPickle as serialization tool
|
||||
"""
|
||||
self.pack = self.__pickle_pack
|
||||
self.unpack = self.__pickle_unpack
|
||||
|
||||
|
||||
class RpcServer(RpcObject):
|
||||
""""""
|
||||
|
||||
def __init__(self, rep_address, pub_address):
|
||||
"""
|
||||
Constructor
|
||||
"""
|
||||
super(RpcServer, self).__init__()
|
||||
|
||||
# Save functions dict: key is fuction name, value is fuction object
|
||||
self.__functions = {}
|
||||
|
||||
# Zmq port related
|
||||
self.__context = zmq.Context()
|
||||
|
||||
self.__socket_rep = self.__context.socket(
|
||||
zmq.REP) # Reply socket (Request–reply pattern)
|
||||
self.__socket_rep.bind(rep_address)
|
||||
|
||||
# Publish socket (Publish–subscribe pattern)
|
||||
self.__socket_pub = self.__context.socket(zmq.PUB)
|
||||
self.__socket_pub.bind(pub_address)
|
||||
|
||||
# Woker thread related
|
||||
self.__active = False # RpcServer status
|
||||
self.__thread = threading.Thread(target=self.run) # RpcServer thread
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
Start RpcServer
|
||||
"""
|
||||
# Start RpcServer status
|
||||
self.__active = True
|
||||
|
||||
# Start RpcServer thread
|
||||
if not self.__thread.isAlive():
|
||||
self.__thread.start()
|
||||
|
||||
def stop(self, join=False):
|
||||
"""
|
||||
Stop RpcServer
|
||||
"""
|
||||
# Stop RpcServer status
|
||||
self.__active = False
|
||||
|
||||
# Wait for RpcServer thread to exit
|
||||
if join and self.__thread.isAlive():
|
||||
self.__thread.join()
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Run RpcServer functions
|
||||
"""
|
||||
while self.__active:
|
||||
# Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds)
|
||||
if not self.__socket_rep.poll(1000):
|
||||
continue
|
||||
|
||||
# Receive request data from Reply socket
|
||||
reqb = self.__socket_rep.recv()
|
||||
|
||||
# Unpack request by deserialization
|
||||
req = self.unpack(reqb)
|
||||
|
||||
# Get function name and parameters
|
||||
name, args, kwargs = req
|
||||
|
||||
# Try to get and execute callable function object; capture exception information if it fails
|
||||
name = name.decode("UTF-8")
|
||||
|
||||
try:
|
||||
func = self.__functions[name]
|
||||
r = func(*args, **kwargs)
|
||||
rep = [True, r]
|
||||
except Exception as e: # noqa
|
||||
rep = [False, traceback.format_exc()]
|
||||
|
||||
# Pack response by serialization
|
||||
repb = self.pack(rep)
|
||||
|
||||
# send callable response by Reply socket
|
||||
self.__socket_rep.send(repb)
|
||||
|
||||
def publish(self, topic, data):
|
||||
"""
|
||||
Publish data
|
||||
"""
|
||||
# Serialized data
|
||||
topic = bytes(topic, "UTF-8")
|
||||
datab = self.pack(data)
|
||||
|
||||
# Send data by Publish socket
|
||||
# topci must be ascii encoding
|
||||
self.__socket_pub.send_multipart([topic, datab])
|
||||
|
||||
def register(self, func):
|
||||
"""
|
||||
Register function
|
||||
"""
|
||||
self.__functions[func.__name__] = func
|
||||
|
||||
|
||||
class RpcClient(RpcObject):
|
||||
""""""
|
||||
|
||||
def __init__(self, req_address, sub_address):
|
||||
"""Constructor"""
|
||||
super(RpcClient, self).__init__()
|
||||
|
||||
# zmq port related
|
||||
self.__req_address = req_address
|
||||
self.__sub_address = sub_address
|
||||
|
||||
self.__context = zmq.Context()
|
||||
# Request socket (Request–reply pattern)
|
||||
self.__socket_req = self.__context.socket(zmq.REQ)
|
||||
# Subscribe socket (Publish–subscribe pattern)
|
||||
self.__socket_sub = self.__context.socket(zmq.SUB)
|
||||
|
||||
# Woker thread relate, used to process data pushed from server
|
||||
self.__active = False # RpcClient status
|
||||
self.__thread = threading.Thread(
|
||||
target=self.run) # RpcClient thread
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""
|
||||
Realize remote call function
|
||||
"""
|
||||
# Perform remote call task
|
||||
def dorpc(*args, **kwargs):
|
||||
# Generate request
|
||||
req = [name, args, kwargs]
|
||||
|
||||
# Pack request by serialization
|
||||
reqb = self.pack(req)
|
||||
|
||||
# Send request and wait for response
|
||||
self.__socket_req.send(reqb)
|
||||
repb = self.__socket_req.recv()
|
||||
|
||||
# Unpack response by deserialization
|
||||
rep = self.unpack(repb)
|
||||
|
||||
# Return response if successed; Trigger exception if failed
|
||||
if rep[0]:
|
||||
return rep[1]
|
||||
else:
|
||||
raise RemoteException(rep[1].decode("UTF-8"))
|
||||
|
||||
return dorpc
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
Start RpcClient
|
||||
"""
|
||||
# Connect zmq port
|
||||
self.__socket_req.connect(self.__req_address)
|
||||
self.__socket_sub.connect(self.__sub_address)
|
||||
|
||||
# Start RpcClient status
|
||||
self.__active = True
|
||||
|
||||
# Start RpcClient thread
|
||||
if not self.__thread.isAlive():
|
||||
self.__thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Stop RpcClient
|
||||
"""
|
||||
# Stop RpcClient status
|
||||
self.__active = False
|
||||
|
||||
# Wait for RpcClient thread to exit
|
||||
if self.__thread.isAlive():
|
||||
self.__thread.join()
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Run RpcClient function
|
||||
"""
|
||||
while self.__active:
|
||||
# Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds)
|
||||
if not self.__socket_sub.poll(1000):
|
||||
continue
|
||||
|
||||
# Receive data from subscribe socket
|
||||
topic, datab = self.__socket_sub.recv_multipart()
|
||||
|
||||
# Unpack data by deserialization
|
||||
data = self.unpack(datab)
|
||||
|
||||
# Process data by callable function
|
||||
topic = topic.decode("UTF-8")
|
||||
|
||||
self.callback(topic, data)
|
||||
|
||||
def callback(self, topic, data):
|
||||
"""
|
||||
Callable function
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def subscribeTopic(self, topic):
|
||||
"""
|
||||
Subscribe data
|
||||
"""
|
||||
topic = bytes(topic, "UTF-8")
|
||||
self.__socket_sub.setsockopt(zmq.SUBSCRIBE, topic)
|
||||
|
||||
|
||||
class RemoteException(Exception):
|
||||
"""
|
||||
RPC remote exception
|
||||
"""
|
||||
|
||||
def __init__(self, value):
|
||||
"""
|
||||
Constructor
|
||||
"""
|
||||
self.__value = value
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
Output error message
|
||||
"""
|
||||
return self.__value
|
@ -99,6 +99,8 @@ class Exchange(Enum):
|
||||
|
||||
# CryptoCurrency
|
||||
BITMEX = "BITMEX"
|
||||
OKEX = "OKEX"
|
||||
HUOBI = "HUOBI"
|
||||
|
||||
|
||||
class Currency(Enum):
|
||||
|
@ -1,13 +1,47 @@
|
||||
""""""
|
||||
|
||||
from peewee import SqliteDatabase, Model, CharField, DateTimeField, FloatField
|
||||
from peewee import CharField, DateTimeField, FloatField, Model, MySQLDatabase, PostgresqlDatabase, \
|
||||
SqliteDatabase
|
||||
|
||||
from .constant import Exchange, Interval
|
||||
from .object import BarData, TickData
|
||||
from .utility import get_file_path
|
||||
from .setting import SETTINGS
|
||||
from .utility import resolve_path
|
||||
|
||||
DB_NAME = "database.db"
|
||||
DB = SqliteDatabase(str(get_file_path(DB_NAME)))
|
||||
|
||||
def init():
|
||||
db_settings = SETTINGS['database']
|
||||
driver = db_settings["driver"]
|
||||
|
||||
init_funcs = {
|
||||
"sqlite": init_sqlite,
|
||||
"mysql": init_mysql,
|
||||
"postgresql": init_postgresql,
|
||||
}
|
||||
|
||||
assert driver in init_funcs
|
||||
del db_settings['driver']
|
||||
return init_funcs[driver](db_settings)
|
||||
|
||||
|
||||
def init_sqlite(settings: dict):
|
||||
global DB
|
||||
database = settings['database']
|
||||
|
||||
DB = SqliteDatabase(str(resolve_path(database)))
|
||||
|
||||
|
||||
def init_mysql(settings: dict):
|
||||
global DB
|
||||
DB = MySQLDatabase(**settings)
|
||||
|
||||
|
||||
def init_postgresql(settings: dict):
|
||||
global DB
|
||||
DB = PostgresqlDatabase(**settings)
|
||||
|
||||
|
||||
init()
|
||||
|
||||
|
||||
class DbBarData(Model):
|
||||
|
@ -24,7 +24,7 @@ from .event import (
|
||||
from .gateway import BaseGateway
|
||||
from .object import CancelRequest, LogData, OrderRequest, SubscribeRequest
|
||||
from .setting import SETTINGS
|
||||
from .utility import Singleton, get_folder_path
|
||||
from .utility import get_folder_path
|
||||
|
||||
|
||||
class MainEngine:
|
||||
@ -52,6 +52,7 @@ class MainEngine:
|
||||
"""
|
||||
engine = engine_class(self, self.event_engine)
|
||||
self.engines[engine.engine_name] = engine
|
||||
return engine
|
||||
|
||||
def add_gateway(self, gateway_class: BaseGateway):
|
||||
"""
|
||||
@ -59,6 +60,7 @@ class MainEngine:
|
||||
"""
|
||||
gateway = gateway_class(self.event_engine)
|
||||
self.gateways[gateway.gateway_name] = gateway
|
||||
return gateway
|
||||
|
||||
def add_app(self, app_class: BaseApp):
|
||||
"""
|
||||
@ -67,7 +69,8 @@ class MainEngine:
|
||||
app = app_class()
|
||||
self.apps[app.app_name] = app
|
||||
|
||||
self.add_engine(app.engine_class)
|
||||
engine = self.add_engine(app.engine_class)
|
||||
return engine
|
||||
|
||||
def init_engines(self):
|
||||
"""
|
||||
@ -199,8 +202,6 @@ class LogEngine(BaseEngine):
|
||||
Processes log event and output with logging module.
|
||||
"""
|
||||
|
||||
__metaclass__ = Singleton
|
||||
|
||||
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||
""""""
|
||||
super(LogEngine, self).__init__(main_engine, event_engine, "log")
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from copy import copy
|
||||
|
||||
from vnpy.event import Event, EventEngine
|
||||
from .event import (
|
||||
@ -227,3 +228,124 @@ class BaseGateway(ABC):
|
||||
Return default setting dict.
|
||||
"""
|
||||
return self.default_setting
|
||||
|
||||
|
||||
class LocalOrderManager:
|
||||
"""
|
||||
Management tool to support use local order id for trading.
|
||||
"""
|
||||
|
||||
def __init__(self, gateway: BaseGateway):
|
||||
""""""
|
||||
self.gateway = gateway
|
||||
|
||||
# For generating local orderid
|
||||
self.order_prefix = ""
|
||||
self.order_count = 0
|
||||
self.orders = {} # local_orderid:order
|
||||
|
||||
# Map between local and system orderid
|
||||
self.local_sys_orderid_map = {}
|
||||
self.sys_local_orderid_map = {}
|
||||
|
||||
# Push order data buf
|
||||
self.push_data_buf = {} # sys_orderid:data
|
||||
|
||||
# Callback for processing push order data
|
||||
self.push_data_callback = None
|
||||
|
||||
# Cancel request buf
|
||||
self.cancel_request_buf = {} # local_orderid:req
|
||||
|
||||
def new_local_orderid(self):
|
||||
"""
|
||||
Generate a new local orderid.
|
||||
"""
|
||||
self.order_count += 1
|
||||
local_orderid = str(self.order_count).rjust(8, "0")
|
||||
return local_orderid
|
||||
|
||||
def get_local_orderid(self, sys_orderid: str):
|
||||
"""
|
||||
Get local orderid with sys orderid.
|
||||
"""
|
||||
local_orderid = self.sys_local_orderid_map.get(sys_orderid, "")
|
||||
|
||||
if not local_orderid:
|
||||
local_orderid = self.new_local_orderid()
|
||||
self.update_orderid_map(local_orderid, sys_orderid)
|
||||
|
||||
return local_orderid
|
||||
|
||||
def get_sys_orderid(self, local_orderid: str):
|
||||
"""
|
||||
Get sys orderid with local orderid.
|
||||
"""
|
||||
sys_orderid = self.local_sys_orderid_map.get(local_orderid, "")
|
||||
return sys_orderid
|
||||
|
||||
def update_orderid_map(self, local_orderid: str, sys_orderid: str):
|
||||
"""
|
||||
Update orderid map.
|
||||
"""
|
||||
self.sys_local_orderid_map[sys_orderid] = local_orderid
|
||||
self.local_sys_orderid_map[local_orderid] = sys_orderid
|
||||
|
||||
self.check_cancel_request(local_orderid)
|
||||
self.check_push_data(sys_orderid)
|
||||
|
||||
def check_push_data(self, sys_orderid: str):
|
||||
"""
|
||||
Check if any order push data waiting.
|
||||
"""
|
||||
if sys_orderid not in self.push_data_buf:
|
||||
return
|
||||
|
||||
data = self.push_data_buf.pop(sys_orderid)
|
||||
if self.push_data_callback:
|
||||
self.push_data_callback(data)
|
||||
|
||||
def add_push_data(self, sys_orderid: str, data: dict):
|
||||
"""
|
||||
Add push data into buf.
|
||||
"""
|
||||
self.push_data_buf[sys_orderid] = data
|
||||
|
||||
def get_order_with_sys_orderid(self, sys_orderid: str):
|
||||
""""""
|
||||
local_orderid = self.sys_local_orderid_map.get(sys_orderid, None)
|
||||
if not local_orderid:
|
||||
return None
|
||||
else:
|
||||
return self.get_order_with_local_orderid(local_orderid)
|
||||
|
||||
def get_order_with_local_orderid(self, local_orderid: str):
|
||||
""""""
|
||||
order = self.orders[local_orderid]
|
||||
return copy(order)
|
||||
|
||||
def on_order(self, order: OrderData):
|
||||
"""
|
||||
Keep an order buf before pushing it to gateway.
|
||||
"""
|
||||
self.orders[order.orderid] = copy(order)
|
||||
self.gateway.on_order(order)
|
||||
|
||||
def cancel_order(self, req: CancelRequest):
|
||||
"""
|
||||
"""
|
||||
sys_orderid = self.get_sys_orderid(req.orderid)
|
||||
if not sys_orderid:
|
||||
self.cancel_request_buf[req.orderid] = req
|
||||
return
|
||||
|
||||
self.gateway.cancel_order(req)
|
||||
|
||||
def check_cancel_request(self, local_orderid: str):
|
||||
"""
|
||||
"""
|
||||
if local_orderid not in self.cancel_request_buf:
|
||||
return
|
||||
|
||||
req = self.cancel_request_buf.pop(local_orderid)
|
||||
self.gateway.cancel_order(req)
|
||||
|
@ -23,10 +23,17 @@ SETTINGS = {
|
||||
"email.receiver": "",
|
||||
|
||||
"rqdata.username": "",
|
||||
"rqdata.password": ""
|
||||
"rqdata.password": "",
|
||||
"database": {
|
||||
"driver": "sqlite", # sqlite, mysql, postgresql
|
||||
"database": "{VNPY_TEMP}/database.db", # for sqlite, use this as filepath
|
||||
"host": "localhost",
|
||||
"port": 3306,
|
||||
"user": "root",
|
||||
"password": ""
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Load global setting from json file.
|
||||
SETTING_FILENAME = "vt_setting.json"
|
||||
SETTINGS.update(load_json(SETTING_FILENAME))
|
||||
|
@ -24,7 +24,7 @@ from .widget import (
|
||||
AboutDialog,
|
||||
)
|
||||
from ..engine import MainEngine
|
||||
from ..utility import get_icon_path, TRADER_PATH
|
||||
from ..utility import get_icon_path, TRADER_DIR
|
||||
|
||||
|
||||
class MainWindow(QtWidgets.QMainWindow):
|
||||
@ -38,7 +38,7 @@ class MainWindow(QtWidgets.QMainWindow):
|
||||
self.main_engine = main_engine
|
||||
self.event_engine = event_engine
|
||||
|
||||
self.window_title = f"VN Trader [{TRADER_PATH}]"
|
||||
self.window_title = f"VN Trader [{TRADER_DIR}]"
|
||||
|
||||
self.connect_dialogs = {}
|
||||
self.widgets = {}
|
||||
|
@ -3,6 +3,7 @@ General utility functions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
@ -12,25 +13,13 @@ import talib
|
||||
from .object import BarData, TickData
|
||||
|
||||
|
||||
class Singleton(type):
|
||||
"""
|
||||
Singleton metaclass,
|
||||
|
||||
class A:
|
||||
__metaclass__ = Singleton
|
||||
"""
|
||||
_instances = {}
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
""""""
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(Singleton, cls).__call__(
|
||||
*args, **kwargs
|
||||
)
|
||||
return cls._instances[cls]
|
||||
def resolve_path(pattern: str):
|
||||
env = dict(os.environ)
|
||||
env.update({"VNPY_TEMP": str(TEMP_DIR)})
|
||||
return pattern.format(**env)
|
||||
|
||||
|
||||
def get_path(temp_name: str):
|
||||
def _get_trader_dir(temp_name: str):
|
||||
"""
|
||||
Get path where trader is running in.
|
||||
"""
|
||||
@ -53,21 +42,21 @@ def get_path(temp_name: str):
|
||||
return home_path, temp_path
|
||||
|
||||
|
||||
TRADER_PATH, TEMP_PATH = get_path(".vntrader")
|
||||
TRADER_DIR, TEMP_DIR = _get_trader_dir(".vntrader")
|
||||
|
||||
|
||||
def get_file_path(filename: str):
|
||||
"""
|
||||
Get path for temp file with filename.
|
||||
"""
|
||||
return TEMP_PATH.joinpath(filename)
|
||||
return TEMP_DIR.joinpath(filename)
|
||||
|
||||
|
||||
def get_folder_path(folder_name: str):
|
||||
"""
|
||||
Get path for temp folder with folder name.
|
||||
"""
|
||||
folder_path = TEMP_PATH.joinpath(folder_name)
|
||||
folder_path = TEMP_DIR.joinpath(folder_name)
|
||||
if not folder_path.exists():
|
||||
folder_path.mkdir()
|
||||
return folder_path
|
||||
@ -385,3 +374,12 @@ class ArrayManager(object):
|
||||
if array:
|
||||
return up, down
|
||||
return up[-1], down[-1]
|
||||
|
||||
|
||||
def virtual(func: "callable"):
|
||||
"""
|
||||
mark a function as "virtual", which means that this function can be override.
|
||||
any base class should use this or @abstractmethod to decorate all functions
|
||||
that can be (re)implemented by subclasses.
|
||||
"""
|
||||
return func
|
||||
|
Loading…
Reference in New Issue
Block a user