Merge branch 'DEV' into master

This commit is contained in:
vn.py 2019-04-15 11:07:28 +08:00 committed by GitHub
commit fbb0249187
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
45 changed files with 4572 additions and 129 deletions

View File

@ -7,6 +7,7 @@
### 使用VNConda
#### 1.下载VNConda Python 3.7 64位
下载地址如下:[VNConda-2.0.1-Windows-x86_64](https://conda.vnpy.com/VNConda-2.0.1-Windows-x86_64.exe)
@ -44,6 +45,7 @@
 
 
### 手动安装
#### 1.下载并安装最新版Anaconda3.7 64位

View File

@ -1,4 +1,5 @@
PyQt5<5.12
pyqtgraph
dataclasses; python_version<="3.6"
qdarkstyle
requests

View File

@ -0,0 +1,189 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import multiprocessing\n",
"import numpy as np\n",
"from deap import creator, base, tools, algorithms\n",
"from vnpy.app.cta_strategy.backtesting import BacktestingEngine\n",
"from boll_channel_strategy import BollChannelStrategy\n",
"from datetime import datetime\n",
"import multiprocessing #多进程\n",
"from scoop import futures #多进程"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def parameter_generate():\n",
" '''\n",
" 根据设置的起始值,终止值和步进,随机生成待优化的策略参数\n",
" '''\n",
" parameter_list = []\n",
" p1 = random.randrange(4,50,2) #布林带窗口\n",
" p2 = random.randrange(4,50,2) #布林带通道阈值\n",
" p3 = random.randrange(4,50,2) #CCI窗口\n",
" p4 = random.randrange(18,40,2) #ATR窗口 \n",
"\n",
" parameter_list.append(p1)\n",
" parameter_list.append(p2)\n",
" parameter_list.append(p3)\n",
" parameter_list.append(p4)\n",
"\n",
" return parameter_list"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def object_func(strategy_avg):\n",
" \"\"\"\n",
" 本函数为优化目标函数根据随机生成的策略参数运行回测后自动返回2个结果指标收益回撤比和夏普比率\n",
" \"\"\"\n",
" # 创建回测引擎对象\n",
" engine = BacktestingEngine()\n",
" engine.set_parameters(\n",
" vt_symbol=\"IF88.CFFEX\",\n",
" interval=\"1m\",\n",
" start=datetime(2018, 9, 1),\n",
" end=datetime(2019, 1,1),\n",
" rate=0,\n",
" slippage=0,\n",
" size=300,\n",
" pricetick=0.2,\n",
" capital=1_000_000,\n",
" )\n",
"\n",
" setting = {'boll_window': strategy_avg[0], #布林带窗口\n",
" 'boll_dev': strategy_avg[1], #布林带通道阈值\n",
" 'cci_window': strategy_avg[2], #CCI窗口\n",
" 'atr_window': strategy_avg[3],} #ATR窗口 \n",
"\n",
" #加载策略 \n",
" #engine.initStrategy(TurtleTradingStrategy, setting)\n",
" engine.add_strategy(BollChannelStrategy, setting)\n",
" engine.load_data()\n",
" engine.run_backtesting()\n",
" engine.calculate_result()\n",
" result = engine.calculate_statistics(Output=False)\n",
"\n",
" return_drawdown_ratio = round(result['return_drawdown_ratio'],2) #收益回撤比\n",
" sharpe_ratio= round(result['sharpe_ratio'],2) #夏普比率\n",
" return return_drawdown_ratio , sharpe_ratio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"#设置优化方向:最大化收益回撤比,最大化夏普比率\n",
"creator.create(\"FitnessMulti\", base.Fitness, weights=(1.0, 1.0)) # 1.0 求最大值;-1.0 求最小值\n",
"creator.create(\"Individual\", list, fitness=creator.FitnessMulti)\n",
"\n",
"def optimize():\n",
" \"\"\"\"\"\" \n",
" toolbox = base.Toolbox() #Toolbox是deap库内置的工具箱里面包含遗传算法中所用到的各种函数\n",
"\n",
" # 初始化 \n",
" toolbox.register(\"individual\", tools.initIterate, creator.Individual,parameter_generate) # 注册个体随机生成的策略参数parameter_generate() \n",
" toolbox.register(\"population\", tools.initRepeat, list, toolbox.individual) #注册种群:个体形成种群 \n",
" toolbox.register(\"mate\", tools.cxTwoPoint) #注册交叉:两点交叉 \n",
" toolbox.register(\"mutate\", tools.mutUniformInt,low = 4,up = 40,indpb=0.6) #注册变异:随机生成一定区间内的整数\n",
" toolbox.register(\"evaluate\", object_func) #注册评估优化目标函数object_func() \n",
" toolbox.register(\"select\", tools.selNSGA2) #注册选择:NSGA-II(带精英策略的非支配排序的遗传算法)\n",
" #pool = multiprocessing.Pool()\n",
" #toolbox.register(\"map\", pool.map)\n",
" #toolbox.register(\"map\", futures.map)\n",
"\n",
" #遗传算法参数设置\n",
" MU = 40 #设置每一代选择的个体数\n",
" LAMBDA = 160 #设置每一代产生的子女数\n",
" pop = toolbox.population(400) #设置族群里面的个体数量\n",
" CXPB, MUTPB, NGEN = 0.5, 0.35,40 #分别为种群内部个体的交叉概率、变异概率、产生种群代数\n",
" hof = tools.ParetoFront() #解的集合:帕累托前沿(非占优最优集)\n",
"\n",
" #解的集合的描述统计信息\n",
" #集合内平均值,标准差,最小值,最大值可以体现集合的收敛程度\n",
" #收敛程度低可以增加算法的迭代次数\n",
" stats = tools.Statistics(lambda ind: ind.fitness.values)\n",
" np.set_printoptions(suppress=True) #对numpy默认输出的科学计数法转换\n",
" stats.register(\"mean\", np.mean, axis=0) #统计目标优化函数结果的平均值\n",
" stats.register(\"std\", np.std, axis=0) #统计目标优化函数结果的标准差\n",
" stats.register(\"min\", np.min, axis=0) #统计目标优化函数结果的最小值\n",
" stats.register(\"max\", np.max, axis=0) #统计目标优化函数结果的最大值\n",
"\n",
" #运行算法\n",
" algorithms.eaMuPlusLambda(pop, toolbox, MU, LAMBDA, CXPB, MUTPB, NGEN, stats,\n",
" halloffame=hof) #esMuPlusLambda是一种基于(μ+λ)选择策略的多目标优化分段遗传算法\n",
"\n",
" return pop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"optimize()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -10,9 +10,13 @@ from vnpy.gateway.ib import IbGateway
from vnpy.gateway.ctp import CtpGateway
from vnpy.gateway.tiger import TigerGateway
from vnpy.gateway.oes import OesGateway
from vnpy.gateway.okex import OkexGateway
from vnpy.gateway.huobi import HuobiGateway
from vnpy.app.cta_strategy import CtaStrategyApp
from vnpy.app.csv_loader import CsvLoaderApp
from vnpy.app.algo_trading import AlgoTradingApp
from vnpy.app.cta_backtester import CtaBacktesterApp
def main():
@ -28,9 +32,13 @@ def main():
main_engine.add_gateway(BitmexGateway)
main_engine.add_gateway(TigerGateway)
main_engine.add_gateway(OesGateway)
main_engine.add_gateway(OkexGateway)
main_engine.add_gateway(HuobiGateway)
main_engine.add_app(CtaStrategyApp)
main_engine.add_app(CtaBacktesterApp)
main_engine.add_app(CsvLoaderApp)
main_engine.add_app(AlgoTradingApp)
main_window = MainWindow(main_engine, event_engine)
main_window.showMaximized()

View File

@ -48,14 +48,18 @@ class WebsocketClient(object):
self.proxy_host = None
self.proxy_port = None
self.ping_interval = 60 # seconds
# For debugging
self._last_sent_text = None
self._last_received_text = None
def init(self, host: str, proxy_host: str = "", proxy_port: int = 0):
""""""
def init(self, host: str, proxy_host: str = "", proxy_port: int = 0, ping_interval: int = 60):
"""
:param ping_interval: unit: seconds, type: int
"""
self.host = host
self.ping_interval = ping_interval # seconds
if proxy_host and proxy_port:
self.proxy_host = proxy_host
@ -206,7 +210,7 @@ class WebsocketClient(object):
et, ev, tb = sys.exc_info()
self.on_error(et, ev, tb)
self._reconnect()
for i in range(60):
for i in range(self.ping_interval):
if not self._active:
break
sleep(1)

View File

@ -0,0 +1,18 @@
from pathlib import Path
from vnpy.trader.app import BaseApp
from .engine import AlgoEngine, APP_NAME
from .template import AlgoTemplate
class AlgoTradingApp(BaseApp):
""""""
app_name = APP_NAME
app_module = __module__
app_path = Path(__file__).parent
display_name = "算法交易"
engine_class = AlgoEngine
widget_name = "AlgoManager"
icon_name = "algo.ico"

View File

View File

@ -0,0 +1,137 @@
from vnpy.trader.constant import Offset, Direction
from vnpy.trader.object import TradeData, OrderData, TickData
from vnpy.trader.engine import BaseEngine
from vnpy.app.algo_trading import AlgoTemplate
class IcebergAlgo(AlgoTemplate):
""""""
display_name = "Iceberg 冰山"
default_setting = {
"vt_symbol": "",
"direction": [Direction.LONG.value, Direction.SHORT.value],
"price": 0.0,
"volume": 0.0,
"display_volume": 0.0,
"interval": 0,
"offset": [
Offset.NONE.value,
Offset.OPEN.value,
Offset.CLOSE.value,
Offset.CLOSETODAY.value,
Offset.CLOSEYESTERDAY.value
]
}
variables = [
"traded",
"timer_count",
"vt_orderid"
]
def __init__(
self,
algo_engine: BaseEngine,
algo_name: str,
setting: dict
):
""""""
super().__init__(algo_engine, algo_name, setting)
# Parameters
self.vt_symbol = setting["vt_symbol"]
self.direction = Direction(setting["direction"])
self.price = setting["price"]
self.volume = setting["volume"]
self.display_volume = setting["display_volume"]
self.interval = setting["interval"]
self.offset = Offset(setting["offset"])
# Variables
self.timer_count = 0
self.vt_orderid = ""
self.traded = 0
self.last_tick = None
self.subscribe(self.vt_symbol)
self.put_parameters_event()
self.put_variables_event()
def on_stop(self):
""""""
self.write_log("停止算法")
def on_tick(self, tick: TickData):
""""""
self.last_tick = tick
def on_order(self, order: OrderData):
""""""
msg = f"委托号:{order.vt_orderid},委托状态:{order.status.value}"
self.write_log(msg)
if not order.is_active():
self.vt_orderid = ""
self.put_variables_event()
def on_trade(self, trade: TradeData):
""""""
self.traded += trade.volume
if self.traded >= self.volume:
self.write_log(f"已交易数量:{self.traded},总数量:{self.volume}")
self.stop()
else:
self.put_variables_event()
def on_timer(self):
""""""
self.timer_count += 1
if self.timer_count < self.interval:
self.put_variables_event()
return
self.timer_count = 0
contract = self.get_contract(self.vt_symbol)
if not contract:
return
# If order already finished, just send new order
if not self.vt_orderid:
order_volume = self.volume - self.traded
order_volume = min(order_volume, self.display_volume)
if self.direction == Direction.LONG:
self.vt_orderid = self.buy(
self.vt_symbol,
self.price,
order_volume,
offset=self.offset
)
else:
self.vt_orderid = self.sell(
self.vt_symbol,
self.price,
order_volume,
offset=self.offset
)
# Otherwise check for cancel
else:
if self.direction == Direction.LONG:
if self.last_tick.ask_price_1 <= self.price:
self.cancel_order(self.vt_orderid)
self.vt_orderid = ""
self.write_log(u"最新Tick卖一价低于买入委托价格之前委托可能丢失强制撤单")
else:
if self.last_tick.bid_price_1 >= self.price:
self.cancel_order(self.vt_orderid)
self.vt_orderid = ""
self.write_log(u"最新Tick买一价高于卖出委托价格之前委托可能丢失强制撤单")
self.put_variables_event()

View File

@ -0,0 +1,101 @@
from vnpy.trader.constant import Offset, Direction
from vnpy.trader.object import TradeData, OrderData, TickData
from vnpy.trader.engine import BaseEngine
from vnpy.app.algo_trading import AlgoTemplate
class SniperAlgo(AlgoTemplate):
""""""
display_name = "Sniper 狙击手"
default_setting = {
"vt_symbol": "",
"direction": [Direction.LONG.value, Direction.SHORT.value],
"price": 0.0,
"volume": 0.0,
"offset": [
Offset.NONE.value,
Offset.OPEN.value,
Offset.CLOSE.value,
Offset.CLOSETODAY.value,
Offset.CLOSEYESTERDAY.value
]
}
variables = [
"traded",
"vt_orderid"
]
def __init__(
self,
algo_engine: BaseEngine,
algo_name: str,
setting: dict
):
""""""
super().__init__(algo_engine, algo_name, setting)
# Parameters
self.vt_symbol = setting["vt_symbol"]
self.direction = Direction(setting["direction"])
self.price = setting["price"]
self.volume = setting["volume"]
self.offset = Offset(setting["offset"])
# Variables
self.vt_orderid = ""
self.traded = 0
self.subscribe(self.vt_symbol)
self.put_parameters_event()
self.put_variables_event()
def on_tick(self, tick: TickData):
""""""
if self.vt_orderid:
self.cancel_all()
return
if self.direction == Direction.LONG:
if tick.ask_price_1 <= self.price:
order_volume = self.volume - self.traded
order_volume = min(order_volume, tick.ask_volume_1)
self.vt_orderid = self.buy(
self.vt_symbol,
self.price,
order_volume,
offset=self.offset
)
else:
if tick.bid_price_1 >= self.price:
order_volume = self.volume - self.traded
order_volume = min(order_volume, tick.bid_volume_1)
self.vt_orderid = self.sell(
self.vt_symbol,
self.price,
order_volume,
offset=self.offset
)
self.put_variables_event()
def on_order(self, order: OrderData):
""""""
if not order.is_active():
self.vt_orderid = ""
self.put_variables_event()
def on_trade(self, trade: TradeData):
""""""
self.traded += trade.volume
if self.traded >= self.volume:
self.write_log(f"已交易数量:{self.traded},总数量:{self.volume}")
self.stop()
else:
self.put_variables_event()

View File

@ -0,0 +1,105 @@
from vnpy.trader.constant import Offset, Direction
from vnpy.trader.object import TradeData
from vnpy.trader.engine import BaseEngine
from vnpy.app.algo_trading import AlgoTemplate
class TwapAlgo(AlgoTemplate):
""""""
display_name = "TWAP 时间加权平均"
default_setting = {
"vt_symbol": "",
"direction": [Direction.LONG.value, Direction.SHORT.value],
"price": 0.0,
"volume": 0.0,
"time": 600,
"interval": 60,
"offset": [
Offset.NONE.value,
Offset.OPEN.value,
Offset.CLOSE.value,
Offset.CLOSETODAY.value,
Offset.CLOSEYESTERDAY.value
]
}
variables = [
"traded",
"order_volume",
"timer_count",
"total_count"
]
def __init__(
self,
algo_engine: BaseEngine,
algo_name: str,
setting: dict
):
""""""
super().__init__(algo_engine, algo_name, setting)
# Parameters
self.vt_symbol = setting["vt_symbol"]
self.direction = Direction(setting["direction"])
self.price = setting["price"]
self.volume = setting["volume"]
self.time = setting["time"]
self.interval = setting["interval"]
self.offset = Offset(setting["offset"])
# Variables
self.order_volume = self.volume / (self.time / self.interval)
self.timer_count = 0
self.total_count = 0
self.traded = 0
self.subscribe(self.vt_symbol)
self.put_parameters_event()
self.put_variables_event()
def on_trade(self, trade: TradeData):
""""""
self.traded += trade.volume
if self.traded >= self.volume:
self.write_log(f"已交易数量:{self.traded},总数量:{self.volume}")
self.stop()
else:
self.put_variables_event()
def on_timer(self):
""""""
self.timer_count += 1
self.total_count += 1
self.put_variables_event()
if self.total_count >= self.time:
self.write_log("执行时间已结束,停止算法")
self.stop()
return
if self.timer_count < self.interval:
return
self.timer_count = 0
tick = self.get_tick(self.vt_symbol)
if not tick:
return
self.cancel_all()
left_volume = self.volume - self.traded
order_volume = min(self.order_volume, left_volume)
if self.direction == Direction.LONG:
if tick.ask_price_1 <= self.price:
self.buy(self.vt_symbol, self.price,
order_volume, offset=self.offset)
else:
if tick.bid_price_1 >= self.price:
self.sell(self.vt_symbol, self.price,
order_volume, offset=self.offset)

View File

@ -0,0 +1,265 @@
from vnpy.event import EventEngine, Event
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.event import (
EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE)
from vnpy.trader.constant import (Direction, Offset, OrderType)
from vnpy.trader.object import (SubscribeRequest, OrderRequest)
from vnpy.trader.utility import load_json, save_json
from .template import AlgoTemplate
APP_NAME = "AlgoTrading"
EVENT_ALGO_LOG = "eAlgoLog"
EVENT_ALGO_SETTING = "eAlgoSetting"
EVENT_ALGO_VARIABLES = "eAlgoVariables"
EVENT_ALGO_PARAMETERS = "eAlgoParameters"
class AlgoEngine(BaseEngine):
""""""
setting_filename = "algo_trading_setting.json"
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
"""Constructor"""
super().__init__(main_engine, event_engine, APP_NAME)
self.algos = {}
self.symbol_algo_map = {}
self.orderid_algo_map = {}
self.algo_templates = {}
self.algo_settings = {}
self.load_algo_template()
self.register_event()
def init_engine(self):
""""""
self.write_log("算法交易引擎启动")
self.load_algo_setting()
def load_algo_template(self):
""""""
from .algos.twap_algo import TwapAlgo
from .algos.iceberg_algo import IcebergAlgo
from .algos.sniper_algo import SniperAlgo
self.add_algo_template(TwapAlgo)
self.add_algo_template(IcebergAlgo)
self.add_algo_template(SniperAlgo)
def add_algo_template(self, template: AlgoTemplate):
""""""
self.algo_templates[template.__name__] = template
def load_algo_setting(self):
""""""
self.algo_settings = load_json(self.setting_filename)
for setting_name, setting in self.algo_settings.items():
self.put_setting_event(setting_name, setting)
self.write_log("算法配置载入成功")
def save_algo_setting(self):
""""""
save_json(self.setting_filename, self.algo_settings)
def register_event(self):
""""""
self.event_engine.register(EVENT_TICK, self.process_tick_event)
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
self.event_engine.register(EVENT_ORDER, self.process_order_event)
self.event_engine.register(EVENT_TRADE, self.process_trade_event)
def process_tick_event(self, event: Event):
""""""
tick = event.data
algos = self.symbol_algo_map.get(tick.vt_symbol, None)
if algos:
for algo in algos:
algo.update_tick(tick)
def process_timer_event(self, event: Event):
""""""
for algo in self.algos.values():
algo.update_timer()
def process_trade_event(self, event: Event):
""""""
trade = event.data
algo = self.orderid_algo_map.get(trade.vt_orderid, None)
if algo:
algo.update_trade(trade)
def process_order_event(self, event: Event):
""""""
order = event.data
algo = self.orderid_algo_map.get(order.vt_orderid, None)
if algo:
algo.update_order(order)
def start_algo(self, setting: dict):
""""""
template_name = setting["template_name"]
algo_template = self.algo_templates[template_name]
algo = algo_template.new(self, setting)
algo.start()
self.algos[algo.algo_name] = algo
return algo.algo_name
def stop_algo(self, algo_name: str):
""""""
algo = self.algos.get(algo_name, None)
if algo:
algo.stop()
self.algos.pop(algo_name)
def stop_all(self):
""""""
for algo_name in list(self.algos.keys()):
self.stop_algo(algo_name)
def subscribe(self, algo: AlgoTemplate, vt_symbol: str):
""""""
contract = self.main_engine.get_contract(vt_symbol)
if not contract:
self.write_log(f'订阅行情失败,找不到合约:{vt_symbol}', algo)
return
algos = self.symbol_algo_map.setdefault(vt_symbol, set())
if not algos:
req = SubscribeRequest(
symbol=contract.symbol,
exchange=contract.exchange
)
self.main_engine.subscribe(req, contract.gateway_name)
algos.add(algo)
def send_order(
self,
algo: AlgoTemplate,
vt_symbol: str,
direction: Direction,
price: float,
volume: float,
order_type: OrderType,
offset: Offset
):
""""""
contract = self.main_engine.get_contract(vt_symbol)
if not contract:
self.write_log(f'委托下单失败,找不到合约:{vt_symbol}', algo)
return
req = OrderRequest(
symbol=contract.symbol,
exchange=contract.exchange,
direction=direction,
type=order_type,
volume=volume,
price=price,
offset=offset
)
vt_orderid = self.main_engine.send_order(req, contract.gateway_name)
self.orderid_algo_map[vt_orderid] = algo
return vt_orderid
def cancel_order(self, algo: AlgoTemplate, vt_orderid: str):
""""""
order = self.main_engine.get_order(vt_orderid)
if not order:
self.write_log(f"委托撤单失败,找不到委托:{vt_orderid}", algo)
return
req = order.create_cancel_request()
self.main_engine.cancel_order(req, order.gateway_name)
def get_tick(self, algo: AlgoTemplate, vt_symbol: str):
""""""
tick = self.main_engine.get_tick(vt_symbol)
if not tick:
self.write_log(f"查询行情失败,找不到行情:{vt_symbol}", algo)
return tick
def get_contract(self, algo: AlgoTemplate, vt_symbol: str):
""""""
contract = self.main_engine.get_contract(vt_symbol)
if not contract:
self.write_log(f"查询合约失败,找不到合约:{vt_symbol}", algo)
return contract
def write_log(self, msg: str, algo: AlgoTemplate = None):
""""""
if algo:
msg = f"{algo.algo_name}{msg}"
event = Event(EVENT_ALGO_LOG)
event.data = msg
self.event_engine.put(event)
def put_setting_event(self, setting_name: str, setting: dict):
""""""
event = Event(EVENT_ALGO_SETTING)
event.data = {
"setting_name": setting_name,
"setting": setting
}
self.event_engine.put(event)
def update_algo_setting(self, setting_name: str, setting: dict):
""""""
self.algo_settings[setting_name] = setting
self.save_algo_setting()
self.put_setting_event(setting_name, setting)
def remove_algo_setting(self, setting_name: str):
""""""
if setting_name not in self.algo_settings:
return
self.algo_settings.pop(setting_name)
event = Event(EVENT_ALGO_SETTING)
event.data = {
"setting_name": setting_name,
"setting": None
}
self.event_engine.put(event)
self.save_algo_setting()
def put_parameters_event(self, algo: AlgoTemplate, parameters: dict):
""""""
event = Event(EVENT_ALGO_PARAMETERS)
event.data = {
"algo_name": algo.algo_name,
"parameters": parameters
}
self.event_engine.put(event)
def put_variables_event(self, algo: AlgoTemplate, variables: dict):
""""""
event = Event(EVENT_ALGO_VARIABLES)
event.data = {
"algo_name": algo.algo_name,
"variables": variables
}
self.event_engine.put(event)

View File

@ -0,0 +1,193 @@
from vnpy.trader.engine import BaseEngine
from vnpy.trader.object import TickData, OrderData, TradeData
from vnpy.trader.constant import OrderType, Offset, Direction
from vnpy.trader.utility import virtual
class AlgoTemplate:
""""""
_count = 0
display_name = ""
default_setting = {}
variables = []
def __init__(
self,
algo_engine: BaseEngine,
algo_name: str,
setting: dict
):
"""Constructor"""
self.algo_engine = algo_engine
self.algo_name = algo_name
self.active = False
self.active_orders = {} # vt_orderid:order
self.variables.insert(0, "active")
@classmethod
def new(cls, algo_engine: BaseEngine, setting: dict):
"""Create new algo instance"""
cls._count += 1
algo_name = f"{cls.__name__}_{cls._count}"
algo = cls(algo_engine, algo_name, setting)
return algo
def update_tick(self, tick: TickData):
""""""
if self.active:
self.on_tick(tick)
def update_order(self, order: OrderData):
""""""
if self.active:
if order.is_active():
self.active_orders[order.vt_orderid] = order
elif order.vt_orderid in self.active_orders:
self.active_orders.pop(order.vt_orderid)
self.on_order(order)
def update_trade(self, trade: TradeData):
""""""
if self.active:
self.on_trade(trade)
def update_timer(self):
""""""
if self.active:
self.on_timer()
def on_start(self):
""""""
pass
@virtual
def on_stop(self):
""""""
pass
@virtual
def on_tick(self, tick: TickData):
""""""
pass
@virtual
def on_order(self, order: OrderData):
""""""
pass
@virtual
def on_trade(self, trade: TradeData):
""""""
pass
@virtual
def on_timer(self):
""""""
pass
def start(self):
""""""
self.active = True
self.on_start()
self.put_variables_event()
def stop(self):
""""""
self.active = False
self.cancel_all()
self.on_stop()
self.put_variables_event()
self.write_log("停止算法")
def subscribe(self, vt_symbol):
""""""
self.algo_engine.subscribe(self, vt_symbol)
def buy(
self,
vt_symbol,
price,
volume,
order_type: OrderType = OrderType.LIMIT,
offset: Offset = Offset.NONE
):
""""""
msg = f"委托买入{vt_symbol}{volume}@{price}"
self.write_log(msg)
return self.algo_engine.send_order(
self,
vt_symbol,
Direction.LONG,
price,
volume,
order_type,
offset
)
def sell(
self,
vt_symbol,
price,
volume,
order_type: OrderType = OrderType.LIMIT,
offset: Offset = Offset.NONE
):
""""""
msg = f"委托卖出{vt_symbol}{volume}@{price}"
self.write_log(msg)
return self.algo_engine.send_order(
self,
vt_symbol,
Direction.SHORT,
price,
volume,
order_type,
offset
)
def cancel_order(self, vt_orderid: str):
""""""
self.algo_engine.cancel_order(self, vt_orderid)
def cancel_all(self):
""""""
if not self.active_orders:
return
for vt_orderid in self.active_orders.keys():
self.cancel_order(vt_orderid)
def get_tick(self, vt_symbol: str):
""""""
return self.algo_engine.get_tick(self, vt_symbol)
def get_contract(self, vt_symbol: str):
""""""
return self.algo_engine.get_contract(self, vt_symbol)
def write_log(self, msg: str):
""""""
self.algo_engine.write_log(msg, self)
def put_parameters_event(self):
""""""
parameters = {}
for name in self.default_setting.keys():
parameters[name] = getattr(self, name)
self.algo_engine.put_parameters_event(self, parameters)
def put_variables_event(self):
""""""
variables = {}
for name in self.variables:
variables[name] = getattr(self, name)
self.algo_engine.put_variables_event(self, variables)

View File

@ -0,0 +1 @@
from .widget import AlgoManager

Binary file not shown.

After

Width:  |  Height:  |  Size: 66 KiB

View File

@ -0,0 +1,16 @@
NAME_DISPLAY_MAP = {
"vt_symbol": "本地代码",
"direction": "方向",
"price": "价格",
"volume": "数量",
"time": "执行时间(秒)",
"interval": "每轮间隔(秒)",
"offset": "开平",
"active": "算法状态",
"traded": "成交数量",
"order_volume": "单笔委托",
"timer_count": "本轮读秒",
"total_count": "累计读秒",
"template_name": "算法模板",
"display_volume": "挂出数量"
}

View File

@ -0,0 +1,575 @@
"""
Widget for algo trading.
"""
from functools import partial
from datetime import datetime
from vnpy.event import EventEngine, Event
from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import QtWidgets, QtCore
from ..engine import (
AlgoEngine,
AlgoTemplate,
APP_NAME,
EVENT_ALGO_LOG,
EVENT_ALGO_PARAMETERS,
EVENT_ALGO_VARIABLES,
EVENT_ALGO_SETTING
)
from .display import NAME_DISPLAY_MAP
class AlgoWidget(QtWidgets.QWidget):
"""
Start connection of a certain gateway.
"""
def __init__(
self,
algo_engine: AlgoEngine,
algo_template: AlgoTemplate
):
""""""
super().__init__()
self.algo_engine = algo_engine
self.template_name = algo_template.__name__
self.default_setting = algo_template.default_setting
self.widgets = {}
self.init_ui()
def init_ui(self):
"""
Initialize line edits and form layout based on setting.
"""
self.setMaximumWidth(400)
form = QtWidgets.QFormLayout()
for field_name, field_value in self.default_setting.items():
field_type = type(field_value)
if field_type == list:
widget = QtWidgets.QComboBox()
widget.addItems(field_value)
else:
widget = QtWidgets.QLineEdit()
display_name = NAME_DISPLAY_MAP.get(field_name, field_name)
form.addRow(display_name, widget)
self.widgets[field_name] = (widget, field_type)
start_algo_button = QtWidgets.QPushButton("启动算法")
start_algo_button.clicked.connect(self.start_algo)
form.addRow(start_algo_button)
form.addRow(QtWidgets.QLabel(""))
self.setting_name_line = QtWidgets.QLineEdit()
form.addRow("配置名称", self.setting_name_line)
save_setting_button = QtWidgets.QPushButton("保存配置")
save_setting_button.clicked.connect(self.save_setting)
form.addRow(save_setting_button)
self.setLayout(form)
def get_setting(self):
"""
Get setting value from line edits.
"""
setting = {"template_name": self.template_name}
for field_name, tp in self.widgets.items():
widget, field_type = tp
if field_type == list:
field_value = str(widget.currentText())
else:
try:
field_value = field_type(widget.text())
except ValueError:
display_name = NAME_DISPLAY_MAP.get(field_name, field_name)
QtWidgets.QMessageBox.warning(
self,
"参数错误",
f"{display_name}参数类型应为{field_type},请检查!"
)
return None
setting[field_name] = field_value
return setting
def start_algo(self):
"""
Start algo trading.
"""
setting = self.get_setting()
if setting:
self.algo_engine.start_algo(setting)
def update_setting(self, setting_name: str, setting: dict):
"""
Update setting into widgets.
"""
self.setting_name_line.setText(setting_name)
for name, tp in self.widgets.items():
widget, _ = tp
value = setting[name]
if isinstance(widget, QtWidgets.QLineEdit):
widget.setText(str(value))
elif isinstance(widget, QtWidgets.QComboBox):
ix = widget.findText(value)
widget.setCurrentIndex(ix)
def save_setting(self):
"""
Save algo setting
"""
setting_name = self.setting_name_line.text()
if not setting_name:
return
setting = self.get_setting()
if setting:
self.algo_engine.update_algo_setting(setting_name, setting)
class AlgoMonitor(QtWidgets.QTableWidget):
""""""
parameters_signal = QtCore.pyqtSignal(Event)
variables_signal = QtCore.pyqtSignal(Event)
def __init__(
self,
algo_engine: AlgoEngine,
event_engine: EventEngine,
mode_active: bool
):
""""""
super().__init__()
self.algo_engine = algo_engine
self.event_engine = event_engine
self.mode_active = mode_active
self.algo_cells = {}
self.init_ui()
self.register_event()
def init_ui(self):
""""""
labels = [
"",
"算法",
"参数",
"状态"
]
self.setColumnCount(len(labels))
self.setHorizontalHeaderLabels(labels)
self.verticalHeader().setVisible(False)
self.setEditTriggers(self.NoEditTriggers)
self.verticalHeader().setSectionResizeMode(
QtWidgets.QHeaderView.ResizeToContents
)
for column in range(2, 4):
self.horizontalHeader().setSectionResizeMode(
column,
QtWidgets.QHeaderView.Stretch
)
self.setWordWrap(True)
if not self.mode_active:
self.hideColumn(0)
def register_event(self):
""""""
self.parameters_signal.connect(self.process_parameters_event)
self.variables_signal.connect(self.process_variables_event)
self.event_engine.register(
EVENT_ALGO_PARAMETERS, self.parameters_signal.emit)
self.event_engine.register(
EVENT_ALGO_VARIABLES, self.variables_signal.emit)
def process_parameters_event(self, event):
""""""
data = event.data
algo_name = data["algo_name"]
parameters = data["parameters"]
cells = self.get_algo_cells(algo_name)
text = to_text(parameters)
cells["parameters"].setText(text)
def process_variables_event(self, event):
""""""
data = event.data
algo_name = data["algo_name"]
variables = data["variables"]
cells = self.get_algo_cells(algo_name)
variables_cell = cells["variables"]
text = to_text(variables)
variables_cell.setText(text)
row = self.row(variables_cell)
active = variables["active"]
if self.mode_active:
if active:
self.showRow(row)
else:
self.hideRow(row)
else:
if active:
self.hideRow(row)
else:
self.showRow(row)
def stop_algo(self, algo_name: str):
""""""
self.algo_engine.stop_algo(algo_name)
def get_algo_cells(self, algo_name: str):
""""""
cells = self.algo_cells.get(algo_name, None)
if not cells:
stop_func = partial(self.stop_algo, algo_name=algo_name)
stop_button = QtWidgets.QPushButton("停止")
stop_button.clicked.connect(stop_func)
name_cell = QtWidgets.QTableWidgetItem(algo_name)
parameters_cell = QtWidgets.QTableWidgetItem()
variables_cell = QtWidgets.QTableWidgetItem()
self.insertRow(0)
self.setCellWidget(0, 0, stop_button)
self.setItem(0, 1, name_cell)
self.setItem(0, 2, parameters_cell)
self.setItem(0, 3, variables_cell)
cells = {
"name": name_cell,
"parameters": parameters_cell,
"variables": variables_cell
}
self.algo_cells[algo_name] = cells
return cells
class ActiveAlgoMonitor(AlgoMonitor):
"""
Monitor for active algos.
"""
def __init__(self, algo_engine: AlgoEngine, event_engine: EventEngine):
""""""
super().__init__(algo_engine, event_engine, True)
class InactiveAlgoMonitor(AlgoMonitor):
"""
Monitor for inactive algos.
"""
def __init__(self, algo_engine: AlgoEngine, event_engine: EventEngine):
""""""
super().__init__(algo_engine, event_engine, False)
class SettingMonitor(QtWidgets.QTableWidget):
""""""
setting_signal = QtCore.pyqtSignal(Event)
use_signal = QtCore.pyqtSignal(dict)
def __init__(self, algo_engine: AlgoEngine, event_engine: EventEngine):
""""""
super().__init__()
self.algo_engine = algo_engine
self.event_engine = event_engine
self.settings = {}
self.setting_cells = {}
self.init_ui()
self.register_event()
def init_ui(self):
""""""
labels = [
"",
"",
"名称",
"配置"
]
self.setColumnCount(len(labels))
self.setHorizontalHeaderLabels(labels)
self.verticalHeader().setVisible(False)
self.setEditTriggers(self.NoEditTriggers)
self.verticalHeader().setSectionResizeMode(
QtWidgets.QHeaderView.ResizeToContents
)
self.horizontalHeader().setSectionResizeMode(
3,
QtWidgets.QHeaderView.Stretch
)
self.setWordWrap(True)
def register_event(self):
""""""
self.setting_signal.connect(self.process_setting_event)
self.event_engine.register(
EVENT_ALGO_SETTING, self.setting_signal.emit)
def process_setting_event(self, event):
""""""
data = event.data
setting_name = data["setting_name"]
setting = data["setting"]
cells = self.get_setting_cells(setting_name)
if setting:
self.settings[setting_name] = setting
cells["setting"].setText(to_text(setting))
else:
if setting_name in self.settings:
self.settings.pop(setting_name)
row = self.row(cells["setting"])
self.removeRow(row)
self.setting_cells.pop(setting_name)
def get_setting_cells(self, setting_name: str):
""""""
cells = self.setting_cells.get(setting_name, None)
if not cells:
use_func = partial(self.use_setting, setting_name=setting_name)
use_button = QtWidgets.QPushButton("使用")
use_button.clicked.connect(use_func)
remove_func = partial(self.remove_setting,
setting_name=setting_name)
remove_button = QtWidgets.QPushButton("移除")
remove_button.clicked.connect(remove_func)
name_cell = QtWidgets.QTableWidgetItem(setting_name)
setting_cell = QtWidgets.QTableWidgetItem()
self.insertRow(0)
self.setCellWidget(0, 0, use_button)
self.setCellWidget(0, 1, remove_button)
self.setItem(0, 2, name_cell)
self.setItem(0, 3, setting_cell)
cells = {
"name": name_cell,
"setting": setting_cell
}
self.setting_cells[setting_name] = cells
return cells
def use_setting(self, setting_name: str):
""""""
setting = self.settings[setting_name]
setting["setting_name"] = setting_name
self.use_signal.emit(setting)
def remove_setting(self, setting_name: str):
""""""
self.algo_engine.remove_algo_setting(setting_name)
class LogMonitor(QtWidgets.QTableWidget):
""""""
signal = QtCore.pyqtSignal(Event)
def __init__(self, event_engine: EventEngine):
""""""
super().__init__()
self.event_engine = event_engine
self.init_ui()
self.register_event()
def init_ui(self):
""""""
labels = [
"时间",
"信息"
]
self.setColumnCount(len(labels))
self.setHorizontalHeaderLabels(labels)
self.verticalHeader().setVisible(False)
self.setEditTriggers(self.NoEditTriggers)
self.verticalHeader().setSectionResizeMode(
QtWidgets.QHeaderView.ResizeToContents
)
self.horizontalHeader().setSectionResizeMode(
1,
QtWidgets.QHeaderView.Stretch
)
self.setWordWrap(True)
def register_event(self):
""""""
self.signal.connect(self.process_log_event)
self.event_engine.register(EVENT_ALGO_LOG, self.signal.emit)
def process_log_event(self, event):
""""""
msg = event.data
timestamp = datetime.now().strftime("%H:%M:%S")
timestamp_cell = QtWidgets.QTableWidgetItem(timestamp)
msg_cell = QtWidgets.QTableWidgetItem(msg)
self.insertRow(0)
self.setItem(0, 0, timestamp_cell)
self.setItem(0, 1, msg_cell)
class AlgoManager(QtWidgets.QWidget):
""""""
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
super().__init__()
self.main_engine = main_engine
self.event_engine = event_engine
self.algo_engine = main_engine.get_engine(APP_NAME)
self.algo_widgets = {}
self.init_ui()
self.algo_engine.init_engine()
def init_ui(self):
""""""
self.setWindowTitle("算法交易")
# Left side control widgets
self.template_combo = QtWidgets.QComboBox()
self.template_combo.currentIndexChanged.connect(self.show_algo_widget)
form = QtWidgets.QFormLayout()
form.addRow("算法", self.template_combo)
widget = QtWidgets.QWidget()
widget.setLayout(form)
vbox = QtWidgets.QVBoxLayout()
vbox.addWidget(widget)
for algo_template in self.algo_engine.algo_templates.values():
widget = AlgoWidget(self.algo_engine, algo_template)
vbox.addWidget(widget)
template_name = algo_template.__name__
display_name = algo_template.display_name
self.algo_widgets[template_name] = widget
self.template_combo.addItem(display_name, template_name)
vbox.addStretch()
stop_all_button = QtWidgets.QPushButton("全部停止")
stop_all_button.setFixedHeight(stop_all_button.sizeHint().height() * 2)
stop_all_button.clicked.connect(self.algo_engine.stop_all)
vbox.addWidget(stop_all_button)
# Right side monitor widgets
active_algo_monitor = ActiveAlgoMonitor(
self.algo_engine, self.event_engine
)
inactive_algo_monitor = InactiveAlgoMonitor(
self.algo_engine, self.event_engine
)
tab1 = QtWidgets.QTabWidget()
tab1.addTab(active_algo_monitor, "执行中")
tab1.addTab(inactive_algo_monitor, "已结束")
log_monitor = LogMonitor(self.event_engine)
tab2 = QtWidgets.QTabWidget()
tab2.addTab(log_monitor, "日志")
setting_monitor = SettingMonitor(self.algo_engine, self.event_engine)
setting_monitor.use_signal.connect(self.use_setting)
tab3 = QtWidgets.QTabWidget()
tab3.addTab(setting_monitor, "配置")
grid = QtWidgets.QGridLayout()
grid.addWidget(tab1, 0, 0, 1, 2)
grid.addWidget(tab2, 1, 0)
grid.addWidget(tab3, 1, 1)
hbox2 = QtWidgets.QHBoxLayout()
hbox2.addLayout(vbox)
hbox2.addLayout(grid)
self.setLayout(hbox2)
self.show_algo_widget()
def show_algo_widget(self):
""""""
ix = self.template_combo.currentIndex()
current_name = self.template_combo.itemData(ix)
for template_name, widget in self.algo_widgets.items():
if template_name == current_name:
widget.show()
else:
widget.hide()
def use_setting(self, setting: dict):
""""""
setting_name = setting["setting_name"]
template_name = setting["template_name"]
widget = self.algo_widgets[template_name]
widget.update_setting(setting_name, setting)
ix = self.template_combo.findData(template_name)
self.template_combo.setCurrentIndex(ix)
self.show_algo_widget()
def show(self):
""""""
self.showMaximized()
def to_text(data: dict):
"""
Convert dict data into string.
"""
buf = []
for key, value in data.items():
key = NAME_DISPLAY_MAP.get(key, key)
buf.append(f"{key}{value}")
text = "".join(buf)
return text

View File

@ -23,9 +23,11 @@ Sample csv file:
import csv
from datetime import datetime
from peewee import chunked
from vnpy.event import EventEngine
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.database import DbBarData
from vnpy.trader.database import DbBarData, DB
from vnpy.trader.engine import BaseEngine, MainEngine
@ -39,17 +41,17 @@ class CsvLoaderEngine(BaseEngine):
""""""
super().__init__(main_engine, event_engine, APP_NAME)
self.file_path: str = ''
self.file_path: str = ""
self.symbol: str = ""
self.exchange: Exchange = Exchange.SSE
self.interval: Interval = Interval.MINUTE
self.datetime_head: str = ''
self.open_head: str = ''
self.close_head: str = ''
self.low_head: str = ''
self.high_head: str = ''
self.volume_head: str = ''
self.datetime_head: str = ""
self.open_head: str = ""
self.close_head: str = ""
self.low_head: str = ""
self.high_head: str = ""
self.volume_head: str = ""
def load(
self,
@ -72,33 +74,40 @@ class CsvLoaderEngine(BaseEngine):
end = None
count = 0
with open(file_path, 'rt') as f:
with open(file_path, "rt") as f:
reader = csv.DictReader(f)
db_bars = []
for item in reader:
db_bar = DbBarData()
dt = datetime.strptime(item[datetime_head], datetime_format)
db_bar.symbol = symbol
db_bar.exchange = exchange.value
db_bar.datetime = datetime.strptime(
item[datetime_head], datetime_format
)
db_bar.interval = interval.value
db_bar.volume = item[volume_head]
db_bar.open_price = item[open_head]
db_bar.high_price = item[high_head]
db_bar.low_price = item[low_head]
db_bar.close_price = item[close_head]
db_bar.vt_symbol = vt_symbol
db_bar.gateway_name = "DB"
db_bar = {
"symbol": symbol,
"exchange": exchange.value,
"datetime": dt,
"interval": interval.value,
"volume": item[volume_head],
"open_price": item[open_head],
"high_price": item[high_head],
"low_price": item[low_head],
"close_price": item[close_head],
"vt_symbol": vt_symbol,
"gateway_name": "DB"
}
db_bar.replace()
db_bars.append(db_bar)
# do some statistics
count += 1
if not start:
start = db_bar.datetime
start = db_bar["datetime"]
end = db_bar.datetime
end = db_bar["datetime"]
# Insert into DB
with DB.atomic():
for batch in chunked(db_bars, 50):
DbBarData.insert_many(batch).on_conflict_replace().execute()
return start, end, count

View File

@ -90,7 +90,8 @@ class CsvLoaderWidget(QtWidgets.QWidget):
def select_file(self):
""""""
result: str = QtWidgets.QFileDialog.getOpenFileName(self)
result: str = QtWidgets.QFileDialog.getOpenFileName(
self, filter="CSV (*.csv)")
filename = result[0]
if filename:
self.file_edit.setText(filename)

View File

@ -0,0 +1,17 @@
from pathlib import Path
from vnpy.trader.app import BaseApp
from .engine import BacktesterEngine, APP_NAME
class CtaBacktesterApp(BaseApp):
""""""
app_name = APP_NAME
app_module = __module__
app_path = Path(__file__).parent
display_name = "CTA回测"
engine_class = BacktesterEngine
widget_name = "BacktesterManager"
icon_name = "backtester.ico"

View File

@ -0,0 +1,200 @@
import os
import importlib
import traceback
from datetime import datetime
from threading import Thread
from pathlib import Path
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.app.cta_strategy import (
CtaTemplate,
BacktestingEngine
)
APP_NAME = "CtaBacktester"
EVENT_BACKTESTER_LOG = "eBacktesterLog"
EVENT_BACKTESTER_BACKTESTING_FINISHED = "eBacktesterBacktestingFinished"
EVENT_BACKTESTER_OPTIMIZATION_FINISHED = "eBacktesterOptimizationFinished"
class BacktesterEngine(BaseEngine):
"""
For running CTA strategy backtesting.
"""
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
super().__init__(main_engine, event_engine, APP_NAME)
self.classes = {}
self.backtesting_engine = None
self.thread = None
self.result_df = None
self.result_statistics = None
self.load_strategy_class()
def init_engine(self):
""""""
self.write_log("初始化CTA回测引擎")
self.backtesting_engine = BacktestingEngine()
# Redirect log from backtesting engine outside.
self.backtesting_engine.output = self.write_log
self.write_log("策略文件加载完成")
def write_log(self, msg: str):
""""""
event = Event(EVENT_BACKTESTER_LOG)
event.data = msg
self.event_engine.put(event)
def load_strategy_class(self):
"""
Load strategy class from source code.
"""
app_path = Path(__file__).parent.parent
path1 = app_path.joinpath("cta_strategy", "strategies")
self.load_strategy_class_from_folder(
path1, "vnpy.app.cta_strategy.strategies")
path2 = Path.cwd().joinpath("strategies")
self.load_strategy_class_from_folder(path2, "strategies")
def load_strategy_class_from_folder(self, path: Path, module_name: str = ""):
"""
Load strategy class from certain folder.
"""
for dirpath, dirnames, filenames in os.walk(path):
for filename in filenames:
if filename.endswith(".py"):
strategy_module_name = ".".join(
[module_name, filename.replace(".py", "")])
self.load_strategy_class_from_module(strategy_module_name)
def load_strategy_class_from_module(self, module_name: str):
"""
Load strategy class from module file.
"""
try:
module = importlib.import_module(module_name)
for name in dir(module):
value = getattr(module, name)
if (isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate):
self.classes[value.__name__] = value
except: # noqa
msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}"
self.write_log(msg)
def get_strategy_class_names(self):
""""""
return list(self.classes.keys())
def run_backtesting(
self,
class_name: str,
vt_symbol: str,
interval: str,
start: datetime,
end: datetime,
rate: float,
slippage: float,
size: int,
pricetick: float,
capital: int,
setting: dict
):
""""""
self.result_df = None
self.result_statistics = None
engine = self.backtesting_engine
engine.clear_data()
engine.set_parameters(
vt_symbol=vt_symbol,
interval=interval,
start=start,
end=end,
rate=rate,
slippage=slippage,
size=size,
pricetick=pricetick,
capital=capital
)
strategy_class = self.classes[class_name]
engine.add_strategy(
strategy_class,
setting
)
engine.load_data()
engine.run_backtesting()
self.result_df = engine.calculate_result()
self.result_statistics = engine.calculate_statistics(output=False)
# Clear thread object handler.
self.thread = None
# Put backtesting done event
event = Event(EVENT_BACKTESTER_BACKTESTING_FINISHED)
self.event_engine.put(event)
def start_backtesting(
self,
class_name: str,
vt_symbol: str,
interval: str,
start: datetime,
end: datetime,
rate: float,
slippage: float,
size: int,
pricetick: float,
capital: int,
setting: dict
):
if self.thread:
self.write_log("已有回测在运行中,请等待完成")
return False
self.write_log("-" * 40)
self.thread = Thread(
target=self.run_backtesting,
args=(
class_name,
vt_symbol,
interval,
start,
end,
rate,
slippage,
size,
pricetick,
capital,
setting
)
)
self.thread.start()
return True
def get_result_df(self):
""""""
return self.result_df
def get_result_statistics(self):
""""""
return self.result_statistics
def get_default_setting(self, class_name: str):
""""""
strategy_class = self.classes[class_name]
return strategy_class.get_class_parameters()

View File

@ -0,0 +1 @@
from .widget import BacktesterManager

Binary file not shown.

After

Width:  |  Height:  |  Size: 63 KiB

View File

@ -0,0 +1,476 @@
from datetime import datetime, timedelta
import pyqtgraph as pg
import numpy as np
from vnpy.event import Event, EventEngine
from vnpy.trader.ui import QtCore, QtWidgets, QtGui
from vnpy.trader.engine import MainEngine
from vnpy.trader.constant import Interval
from ..engine import (
APP_NAME,
EVENT_BACKTESTER_LOG,
EVENT_BACKTESTER_BACKTESTING_FINISHED,
EVENT_BACKTESTER_OPTIMIZATION_FINISHED
)
class BacktesterManager(QtWidgets.QWidget):
""""""
signal_log = QtCore.pyqtSignal(Event)
signal_backtesting_finished = QtCore.pyqtSignal(Event)
signal_optimization_finished = QtCore.pyqtSignal(Event)
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
super().__init__()
self.main_engine = main_engine
self.event_engine = event_engine
self.backtester_engine = main_engine.get_engine(APP_NAME)
self.class_names = []
self.settings = {}
self.init_strategy_settings()
self.init_ui()
self.register_event()
self.backtester_engine.init_engine()
def init_strategy_settings(self):
""""""
self.class_names = self.backtester_engine.get_strategy_class_names()
for class_name in self.class_names:
setting = self.backtester_engine.get_default_setting(class_name)
self.settings[class_name] = setting
def init_ui(self):
""""""
self.setWindowTitle("CTA回测")
# Setting Part
self.class_combo = QtWidgets.QComboBox()
self.class_combo.addItems(self.class_names)
self.symbol_line = QtWidgets.QLineEdit("IF88.CFFEX")
self.interval_combo = QtWidgets.QComboBox()
for inteval in Interval:
self.interval_combo.addItem(inteval.value)
end_dt = datetime.now()
start_dt = end_dt - timedelta(days=3 * 365)
self.start_date_edit = QtWidgets.QDateEdit(
QtCore.QDate(
start_dt.year,
start_dt.month,
start_dt.day
)
)
self.end_date_edit = QtWidgets.QDateEdit(
QtCore.QDate.currentDate()
)
self.rate_line = QtWidgets.QLineEdit("0.000025")
self.slippage_line = QtWidgets.QLineEdit("0.2")
self.size_line = QtWidgets.QLineEdit("300")
self.pricetick_line = QtWidgets.QLineEdit("0.2")
self.capital_line = QtWidgets.QLineEdit("1000000")
start_button = QtWidgets.QPushButton("开始回测")
start_button.clicked.connect(self.start_backtesting)
form = QtWidgets.QFormLayout()
form.addRow("交易策略", self.class_combo)
form.addRow("本地代码", self.symbol_line)
form.addRow("K线周期", self.interval_combo)
form.addRow("开始日期", self.start_date_edit)
form.addRow("结束日期", self.end_date_edit)
form.addRow("手续费率", self.rate_line)
form.addRow("交易滑点", self.slippage_line)
form.addRow("合约乘数", self.size_line)
form.addRow("价格跳动", self.pricetick_line)
form.addRow("回测资金", self.capital_line)
form.addRow(start_button)
# Result part
self.statistics_monitor = StatisticsMonitor()
self.log_monitor = QtWidgets.QTextEdit()
self.log_monitor.setMaximumHeight(400)
self.chart = BacktesterChart()
self.chart.setMinimumWidth(1000)
# Layout
vbox = QtWidgets.QVBoxLayout()
vbox.addWidget(self.statistics_monitor)
vbox.addWidget(self.log_monitor)
hbox = QtWidgets.QHBoxLayout()
hbox.addLayout(form)
hbox.addLayout(vbox)
hbox.addWidget(self.chart)
self.setLayout(hbox)
def register_event(self):
""""""
self.signal_log.connect(self.process_log_event)
self.signal_backtesting_finished.connect(
self.process_backtesting_finished_event)
self.signal_optimization_finished.connect(
self.process_optimization_finished_event)
self.event_engine.register(EVENT_BACKTESTER_LOG, self.signal_log.emit)
self.event_engine.register(
EVENT_BACKTESTER_BACKTESTING_FINISHED, self.signal_backtesting_finished.emit)
self.event_engine.register(
EVENT_BACKTESTER_OPTIMIZATION_FINISHED, self.signal_optimization_finished.emit)
def process_log_event(self, event: Event):
""""""
msg = event.data
timestamp = datetime.now().strftime("%H:%M:%S")
msg = f"{timestamp}\t{msg}"
self.log_monitor.append(msg)
def process_backtesting_finished_event(self, event: Event):
""""""
statistics = self.backtester_engine.get_result_statistics()
self.statistics_monitor.set_data(statistics)
df = self.backtester_engine.get_result_df()
self.chart.set_data(df)
def process_optimization_finished_event(self, event: Event):
""""""
pass
def start_backtesting(self):
""""""
class_name = self.class_combo.currentText()
vt_symbol = self.symbol_line.text()
interval = self.interval_combo.currentText()
start = self.start_date_edit.date().toPyDate()
end = self.end_date_edit.date().toPyDate()
rate = float(self.rate_line.text())
slippage = float(self.slippage_line.text())
size = float(self.size_line.text())
pricetick = float(self.pricetick_line.text())
capital = float(self.capital_line.text())
old_setting = self.settings[class_name]
dialog = SettingEditor(class_name, old_setting)
i = dialog.exec()
if i != dialog.Accepted:
return
new_setting = dialog.get_setting()
self.settings[class_name] = new_setting
result = self.backtester_engine.start_backtesting(
class_name,
vt_symbol,
interval,
start,
end,
rate,
slippage,
size,
pricetick,
capital,
new_setting
)
if result:
self.statistics_monitor.clear_data()
self.chart.clear_data()
def show(self):
""""""
self.showMaximized()
class StatisticsMonitor(QtWidgets.QTableWidget):
""""""
KEY_NAME_MAP = {
"start_date": "首个交易日",
"end_date": "最后交易日",
"total_days": "总交易日",
"profit_days": "盈利交易日",
"loss_days": "亏损交易日",
"capital": "起始资金",
"end_balance": "结束资金",
"total_return": "总收益率",
"annual_return": "年化收益",
"max_drawdown": "最大回撤",
"max_ddpercent": "百分比最大回撤",
"total_net_pnl": "总盈亏",
"total_commission": "总手续费",
"total_slippage": "总滑点",
"total_turnover": "总成交额",
"total_trade_count": "总成交笔数",
"daily_net_pnl": "日均盈亏",
"daily_commission": "日均手续费",
"daily_slippage": "日均滑点",
"daily_turnover": "日均成交额",
"daily_trade_count": "日均成交笔数",
"daily_return": "日均收益率",
"return_std": "收益标准差",
"sharpe_ratio": "夏普比率",
"return_drawdown_ratio": "收益回撤比"
}
def __init__(self):
""""""
super().__init__()
self.cells = {}
self.init_ui()
def init_ui(self):
""""""
self.setRowCount(len(self.KEY_NAME_MAP))
self.setVerticalHeaderLabels(list(self.KEY_NAME_MAP.values()))
self.setColumnCount(1)
self.horizontalHeader().setVisible(False)
self.horizontalHeader().setSectionResizeMode(
QtWidgets.QHeaderView.Stretch
)
for row, key in enumerate(self.KEY_NAME_MAP.keys()):
cell = QtWidgets.QTableWidgetItem()
self.setItem(row, 0, cell)
self.cells[key] = cell
def clear_data(self):
""""""
for cell in self.cells.values():
cell.setText("")
def set_data(self, data: dict):
""""""
data["capital"] = f"{data['capital']:,.2f}"
data["end_balance"] = f"{data['end_balance']:,.2f}"
data["total_return"] = f"{data['total_return']:,.2f}%"
data["annual_return"] = f"{data['annual_return']:,.2f}%"
data["max_drawdown"] = f"{data['max_drawdown']:,.2f}"
data["max_ddpercent"] = f"{data['max_ddpercent']:,.2f}%"
data["total_net_pnl"] = f"{data['total_net_pnl']:,.2f}"
data["total_commission"] = f"{data['total_commission']:,.2f}"
data["total_slippage"] = f"{data['total_slippage']:,.2f}"
data["total_turnover"] = f"{data['total_turnover']:,.2f}"
data["daily_net_pnl"] = f"{data['daily_net_pnl']:,.2f}"
data["daily_commission"] = f"{data['daily_commission']:,.2f}"
data["daily_slippage"] = f"{data['daily_slippage']:,.2f}"
data["daily_turnover"] = f"{data['daily_turnover']:,.2f}"
data["daily_return"] = f"{data['daily_return']:,.2f}%"
data["return_std"] = f"{data['return_std']:,.2f}%"
data["sharpe_ratio"] = f"{data['sharpe_ratio']:,.2f}"
data["return_drawdown_ratio"] = f"{data['return_drawdown_ratio']:,.2f}"
for key, cell in self.cells.items():
value = data.get(key, "")
cell.setText(str(value))
class SettingEditor(QtWidgets.QDialog):
"""
For creating new strategy and editing strategy parameters.
"""
def __init__(
self, class_name: str, parameters: dict
):
""""""
super(SettingEditor, self).__init__()
self.class_name = class_name
self.parameters = parameters
self.edits = {}
self.init_ui()
def init_ui(self):
""""""
form = QtWidgets.QFormLayout()
# Add vt_symbol and name edit if add new strategy
self.setWindowTitle(f"策略参数配置:{self.class_name}")
button_text = "确定"
parameters = self.parameters
for name, value in parameters.items():
type_ = type(value)
edit = QtWidgets.QLineEdit(str(value))
if type_ is int:
validator = QtGui.QIntValidator()
edit.setValidator(validator)
elif type_ is float:
validator = QtGui.QDoubleValidator()
edit.setValidator(validator)
form.addRow(f"{name} {type_}", edit)
self.edits[name] = (edit, type_)
button = QtWidgets.QPushButton(button_text)
button.clicked.connect(self.accept)
form.addRow(button)
self.setLayout(form)
def get_setting(self):
""""""
setting = {}
for name, tp in self.edits.items():
edit, type_ = tp
value_text = edit.text()
if type_ == bool:
if value_text == "True":
value = True
else:
value = False
else:
value = type_(value_text)
setting[name] = value
return setting
class BacktesterChart(pg.GraphicsWindow):
""""""
def __init__(self):
""""""
super().__init__(title="Backtester Chart")
self.dates = {}
self.init_ui()
def init_ui(self):
""""""
pg.setConfigOptions(antialias=True)
# Create plot widgets
self.balance_plot = self.addPlot(
title="账户净值",
axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
)
self.nextRow()
self.drawdown_plot = self.addPlot(
title="净值回撤",
axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
)
self.nextRow()
self.pnl_plot = self.addPlot(
title="每日盈亏",
axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
)
self.nextRow()
self.distribution_plot = self.addPlot(title="盈亏分布")
# Add curves and bars on plot widgets
self.balance_curve = self.balance_plot.plot(
pen=pg.mkPen("#ffc107", width=3)
)
dd_color = "#303f9f"
self.drawdown_curve = self.drawdown_plot.plot(
fillLevel=-0.3, brush=dd_color, pen=dd_color
)
profit_color = 'r'
loss_color = 'g'
self.profit_pnl_bar = pg.BarGraphItem(
x=[], height=[], width=0.3, brush=profit_color, pen=profit_color
)
self.loss_pnl_bar = pg.BarGraphItem(
x=[], height=[], width=0.3, brush=loss_color, pen=loss_color
)
self.pnl_plot.addItem(self.profit_pnl_bar)
self.pnl_plot.addItem(self.loss_pnl_bar)
distribution_color = "#6d4c41"
self.distribution_curve = self.distribution_plot.plot(
fillLevel=-0.3, brush=distribution_color, pen=distribution_color
)
def clear_data(self):
""""""
self.balance_curve.setData([], [])
self.drawdown_curve.setData([], [])
self.profit_pnl_bar.setOpts(x=[], height=[])
self.loss_pnl_bar.setOpts(x=[], height=[])
self.distribution_curve.setData([], [])
def set_data(self, df):
""""""
count = len(df)
self.dates.clear()
for n, date in enumerate(df.index):
self.dates[n] = date
# Set data for curve of balance and drawdown
self.balance_curve.setData(df["balance"])
self.drawdown_curve.setData(df["drawdown"])
# Set data for daily pnl bar
profit_pnl_x = []
profit_pnl_height = []
loss_pnl_x = []
loss_pnl_height = []
for count, pnl in enumerate(df["net_pnl"]):
if pnl >= 0:
profit_pnl_height.append(pnl)
profit_pnl_x.append(count)
else:
loss_pnl_height.append(pnl)
loss_pnl_x.append(count)
self.profit_pnl_bar.setOpts(x=profit_pnl_x, height=profit_pnl_height)
self.loss_pnl_bar.setOpts(x=loss_pnl_x, height=loss_pnl_height)
# Set data for pnl distribution
hist, x = np.histogram(df["net_pnl"], bins="auto")
x = x[:-1]
self.distribution_curve.setData(x, hist)
class DateAxis(pg.AxisItem):
"""Axis for showing date data"""
def __init__(self, dates: dict, *args, **kwargs):
""""""
super().__init__(*args, **kwargs)
self.dates = dates
def tickStrings(self, values, scale, spacing):
""""""
strings = []
for v in values:
dt = self.dates.get(v, "")
strings.append(str(dt))
return strings

View File

@ -7,6 +7,7 @@ from vnpy.trader.utility import BarGenerator, ArrayManager
from .base import APP_NAME, StopOrder
from .engine import CtaEngine
from .backtesting import BacktestingEngine, OptimizationSetting
from .template import CtaTemplate, CtaSignal, TargetPosTemplate

View File

@ -2,6 +2,7 @@ from collections import defaultdict
from datetime import date, datetime
from typing import Callable
from itertools import product
from functools import lru_cache
import multiprocessing
import numpy as np
@ -197,28 +198,18 @@ class BacktestingEngine:
self.output("开始加载历史数据")
if self.mode == BacktestingMode.BAR:
s = (
DbBarData.select()
.where(
(DbBarData.vt_symbol == self.vt_symbol)
& (DbBarData.interval == self.interval)
& (DbBarData.datetime >= self.start)
& (DbBarData.datetime <= self.end)
)
.order_by(DbBarData.datetime)
self.history_data = load_bar_data(
self.vt_symbol,
self.interval,
self.start,
self.end
)
self.history_data = [db_bar.to_bar() for db_bar in s]
else:
s = (
DbTickData.select()
.where(
(DbTickData.vt_symbol == self.vt_symbol)
& (DbTickData.datetime >= self.start)
& (DbTickData.datetime <= self.end)
)
.order_by(DbTickData.datetime)
self.history_data = load_tick_data(
self.vt_symbol,
self.start,
self.end
)
self.history_data = [db_tick.to_tick() for db_tick in s]
self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
@ -293,7 +284,7 @@ class BacktestingEngine:
self.output("逐日盯市盈亏计算完成")
return self.daily_df
def calculate_statistics(self, df: DataFrame = None):
def calculate_statistics(self, df: DataFrame = None, output=True):
""""""
self.output("开始计算策略统计指标")
@ -325,6 +316,7 @@ class BacktestingEngine:
daily_return = 0
return_std = 0
sharpe_ratio = 0
return_drawdown_ratio = 0
else:
# Calculate balance related time series data
df["balance"] = df["net_pnl"].cumsum() + self.capital
@ -373,38 +365,42 @@ class BacktestingEngine:
else:
sharpe_ratio = 0
return_drawdown_ratio = -total_return / max_ddpercent
# Output
self.output("-" * 30)
self.output(f"首个交易日:\t{start_date}")
self.output(f"最后交易日:\t{end_date}")
if output:
self.output("-" * 30)
self.output(f"首个交易日:\t{start_date}")
self.output(f"最后交易日:\t{end_date}")
self.output(f"总交易日:\t{total_days}")
self.output(f"盈利交易日:\t{profit_days}")
self.output(f"亏损交易日:\t{loss_days}")
self.output(f"总交易日:\t{total_days}")
self.output(f"盈利交易日:\t{profit_days}")
self.output(f"亏损交易日:\t{loss_days}")
self.output(f"起始资金:\t{self.capital:,.2f}")
self.output(f"结束资金:\t{end_balance:,.2f}")
self.output(f"起始资金:\t{self.capital:,.2f}")
self.output(f"结束资金:\t{end_balance:,.2f}")
self.output(f"总收益率:\t{total_return:,.2f}%")
self.output(f"年化收益:\t{annual_return:,.2f}%")
self.output(f"最大回撤: \t{max_drawdown:,.2f}")
self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%")
self.output(f"总收益率:\t{total_return:,.2f}%")
self.output(f"年化收益:\t{annual_return:,.2f}%")
self.output(f"最大回撤: \t{max_drawdown:,.2f}")
self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%")
self.output(f"总盈亏:\t{total_net_pnl:,.2f}")
self.output(f"总手续费:\t{total_commission:,.2f}")
self.output(f"总滑点:\t{total_slippage:,.2f}")
self.output(f"总成交金额:\t{total_turnover:,.2f}")
self.output(f"总成交笔数:\t{total_trade_count}")
self.output(f"总盈亏:\t{total_net_pnl:,.2f}")
self.output(f"总手续费:\t{total_commission:,.2f}")
self.output(f"总滑点:\t{total_slippage:,.2f}")
self.output(f"总成交金额:\t{total_turnover:,.2f}")
self.output(f"总成交笔数:\t{total_trade_count}")
self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}")
self.output(f"日均手续费:\t{daily_commission:,.2f}")
self.output(f"日均滑点:\t{daily_slippage:,.2f}")
self.output(f"日均成交金额:\t{daily_turnover:,.2f}")
self.output(f"日均成交笔数:\t{daily_trade_count}")
self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}")
self.output(f"日均手续费:\t{daily_commission:,.2f}")
self.output(f"日均滑点:\t{daily_slippage:,.2f}")
self.output(f"日均成交金额:\t{daily_turnover:,.2f}")
self.output(f"日均成交笔数:\t{daily_trade_count}")
self.output(f"日均收益率:\t{daily_return:,.2f}%")
self.output(f"收益标准差:\t{return_std:,.2f}%")
self.output(f"Sharpe Ratio\t{sharpe_ratio:,.2f}")
self.output(f"日均收益率:\t{daily_return:,.2f}%")
self.output(f"收益标准差:\t{return_std:,.2f}%")
self.output(f"Sharpe Ratio\t{sharpe_ratio:,.2f}")
self.output(f"收益回撤比:\t{return_drawdown_ratio:,.2f}")
statistics = {
"start_date": start_date,
@ -412,6 +408,7 @@ class BacktestingEngine:
"total_days": total_days,
"profit_days": profit_days,
"loss_days": loss_days,
"capital": self.capital,
"end_balance": end_balance,
"max_drawdown": max_drawdown,
"max_ddpercent": max_ddpercent,
@ -430,6 +427,7 @@ class BacktestingEngine:
"daily_return": daily_return,
"return_std": return_std,
"sharpe_ratio": sharpe_ratio,
"return_drawdown_ratio": return_drawdown_ratio,
}
return statistics
@ -963,3 +961,45 @@ def optimize(
target_value = statistics[target_name]
return (str(setting), target_value, statistics)
@lru_cache(maxsize=10)
def load_bar_data(
vt_symbol: str,
interval: str,
start: datetime,
end: datetime
):
""""""
s = (
DbBarData.select()
.where(
(DbBarData.vt_symbol == vt_symbol)
& (DbBarData.interval == interval)
& (DbBarData.datetime >= start)
& (DbBarData.datetime <= end)
)
.order_by(DbBarData.datetime)
)
data = [db_bar.to_bar() for db_bar in s]
return data
@lru_cache(maxsize=10)
def load_tick_data(
vt_symbol: str,
start: datetime,
end: datetime
):
""""""
s = (
DbTickData.select()
.where(
(DbTickData.vt_symbol == vt_symbol)
& (DbTickData.datetime >= start)
& (DbTickData.datetime <= end)
)
.order_by(DbTickData.datetime)
)
data = [db_tick.db_tick() for db_tick in s]
return data

View File

@ -40,6 +40,7 @@ from vnpy.trader.database import DbTickData, DbBarData
from vnpy.trader.setting import SETTINGS
from .base import (
APP_NAME,
EVENT_CTA_LOG,
EVENT_CTA_STRATEGY,
EVENT_CTA_STOPORDER,
@ -73,7 +74,7 @@ class CtaEngine(BaseEngine):
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
super(CtaEngine, self).__init__(
main_engine, event_engine, "CtaStrategy")
main_engine, event_engine, APP_NAME)
self.strategy_setting = {} # strategy_name: dict
self.strategy_data = {} # strategy_name: dict
@ -541,7 +542,7 @@ class CtaEngine(BaseEngine):
DbBarData.select()
.where(
(DbBarData.vt_symbol == vt_symbol)
& (DbBarData.interval == interval)
& (DbBarData.interval == interval.value)
& (DbBarData.datetime >= start)
& (DbBarData.datetime <= end)
)

View File

@ -4,6 +4,7 @@ from typing import Any, Callable
from vnpy.trader.constant import Interval, Direction, Offset
from vnpy.trader.object import BarData, TickData, OrderData, TradeData
from vnpy.trader.utility import virtual
from .base import StopOrder, EngineType
@ -87,48 +88,56 @@ class CtaTemplate(ABC):
}
return strategy_data
@virtual
def on_init(self):
"""
Callback when strategy is inited.
"""
pass
@virtual
def on_start(self):
"""
Callback when strategy is started.
"""
pass
@virtual
def on_stop(self):
"""
Callback when strategy is stopped.
"""
pass
@virtual
def on_tick(self, tick: TickData):
"""
Callback of new tick data update.
"""
pass
@virtual
def on_bar(self, bar: BarData):
"""
Callback of new bar data update.
"""
pass
@virtual
def on_trade(self, trade: TradeData):
"""
Callback of new trade data update.
"""
pass
@virtual
def on_order(self, order: OrderData):
"""
Callback of new order data update.
"""
pass
@virtual
def on_stop_order(self, stop_order: StopOrder):
"""
Callback of stop order update.
@ -255,12 +264,14 @@ class CtaSignal(ABC):
""""""
self.signal_pos = 0
@virtual
def on_tick(self, tick: TickData):
"""
Callback of new tick data update.
"""
pass
@virtual
def on_bar(self, bar: BarData):
"""
Callback of new bar data update.
@ -292,6 +303,7 @@ class TargetPosTemplate(CtaTemplate):
)
self.variables.append("target_pos")
@virtual
def on_tick(self, tick: TickData):
"""
Callback of new tick data update.
@ -301,12 +313,14 @@ class TargetPosTemplate(CtaTemplate):
if self.trading:
self.trade()
@virtual
def on_bar(self, bar: BarData):
"""
Callback of new bar data update.
"""
self.last_bar = bar
@virtual
def on_order(self, order: OrderData):
"""
Callback of new order data update.

View File

@ -71,8 +71,8 @@ class BitmexGateway(BaseGateway):
"Secret": "",
"会话数": 3,
"服务器": ["REAL", "TESTNET"],
"代理地址": "127.0.0.1",
"代理端口": 1080,
"代理地址": "",
"代理端口": "",
}
def __init__(self, event_engine):
@ -91,6 +91,11 @@ class BitmexGateway(BaseGateway):
proxy_host = setting["代理地址"]
proxy_port = setting["代理端口"]
if proxy_port.isdigit():
proxy_port = int(proxy_port)
else:
proxy_port = 0
self.rest_api.connect(key, secret, session_number,
server, proxy_host, proxy_port)

View File

@ -405,7 +405,7 @@ class CtpTdApi(TdApi):
""""""
if not error['ErrorID']:
self.authStatus = True
self.writeLog("交易授权验证成功")
self.gateway.write_log("交易授权验证成功")
self.login()
else:
self.gateway.write_error("交易授权验证失败", error)
@ -418,7 +418,7 @@ class CtpTdApi(TdApi):
self.login_status = True
self.gateway.write_log("交易登录成功")
# Confirm settelment
# Confirm settlement
req = {
"BrokerID": self.brokerid,
"InvestorID": self.userid
@ -487,11 +487,6 @@ class CtpTdApi(TdApi):
)
self.positions[key] = position
# Get contract size, return if size value not collected
size = symbol_size_map.get(position.symbol, None)
if not size:
return
# For SHFE position data update
if position.exchange == Exchange.SHFE:
if data["YdPosition"] and not data["TodayPosition"]:
@ -500,6 +495,9 @@ class CtpTdApi(TdApi):
else:
position.yd_volume = data["Position"] - data["TodayPosition"]
# Get contract size (spread contract has no size value)
size = symbol_size_map.get(position.symbol, 0)
# Calculate previous position cost
cost = position.price * position.volume * size
@ -508,7 +506,7 @@ class CtpTdApi(TdApi):
position.pnl += data["PositionProfit"]
# Calculate average position price
if position.volume:
if position.volume and size:
cost += data["PositionCost"]
position.price = cost / (position.volume * size)
@ -549,13 +547,16 @@ class CtpTdApi(TdApi):
product=product,
size=data["VolumeMultiple"],
pricetick=data["PriceTick"],
option_underlying=data["UnderlyingInstrID"],
option_type=OPTIONTYPE_CTP2VT.get(data["OptionsType"], None),
option_strike=data["StrikePrice"],
option_expiry=datetime.strptime(data["ExpireDate"], "%Y%m%d"),
gateway_name=self.gateway_name
)
# For option only
if contract.product == Product.OPTION:
contract.option_underlying = data["UnderlyingInstrID"],
contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None),
contract.option_strike = data["StrikePrice"],
contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"),
self.gateway.on_contract(contract)
symbol_exchange_map[contract.symbol] = contract.exchange
@ -662,7 +663,7 @@ class CtpTdApi(TdApi):
"UserID": self.userid,
"BrokerID": self.brokerid,
"AuthCode": self.auth_code,
"ProductInfo": self.product_info
"UserProductInfo": self.product_info
}
self.reqid += 1
@ -678,7 +679,8 @@ class CtpTdApi(TdApi):
req = {
"UserID": self.userid,
"Password": self.password,
"BrokerID": self.brokerid
"BrokerID": self.brokerid,
"UserProductInfo": self.product_info
}
self.reqid += 1

View File

@ -0,0 +1,4 @@
from .huobi_gateway import HuobiGateway

View File

@ -0,0 +1,747 @@
# encoding: UTF-8
"""
火币交易接口
"""
import re
import urllib
import base64
import json
import zlib
import hashlib
import hmac
from copy import copy
from datetime import datetime
from vnpy.event import Event
from vnpy.api.rest import RestClient, Request
from vnpy.api.websocket import WebsocketClient
from vnpy.trader.constant import (
Direction,
Exchange,
Product,
Status,
OrderType
)
from vnpy.trader.gateway import BaseGateway, LocalOrderManager
from vnpy.trader.object import (
TickData,
OrderData,
TradeData,
AccountData,
ContractData,
OrderRequest,
CancelRequest,
SubscribeRequest
)
from vnpy.trader.event import EVENT_TIMER
REST_HOST = "https://api.huobipro.com"
WEBSOCKET_DATA_HOST = "wss://api.huobi.pro/ws" # Market Data
WEBSOCKET_TRADE_HOST = "wss://api.huobi.pro/ws/v1" # Account and Order
STATUS_HUOBI2VT = {
"submitted": Status.NOTTRADED,
"partial-filled": Status.PARTTRADED,
"filled": Status.ALLTRADED,
"cancelling": Status.CANCELLED,
"partial-canceled": Status.CANCELLED,
"canceled": Status.CANCELLED,
}
ORDERTYPE_VT2HUOBI = {
(Direction.LONG, OrderType.MARKET): "buy-market",
(Direction.SHORT, OrderType.MARKET): "sell-market",
(Direction.LONG, OrderType.LIMIT): "buy-limit",
(Direction.SHORT, OrderType.LIMIT): "sell-limit",
}
ORDERTYPE_HUOBI2VT = {v: k for k, v in ORDERTYPE_VT2HUOBI.items()}
huobi_symbols = set()
symbol_name_map = {}
class HuobiGateway(BaseGateway):
"""
VN Trader Gateway for Huobi connection.
"""
default_setting = {
"API Key": "",
"Secret Key": "",
"会话数": 3,
"代理地址": "",
"代理端口": "",
}
def __init__(self, event_engine):
"""Constructor"""
super(HuobiGateway, self).__init__(event_engine, "HUOBI")
self.order_manager = LocalOrderManager(self)
self.rest_api = HuobiRestApi(self)
self.trade_ws_api = HuobiTradeWebsocketApi(self)
self.market_ws_api = HuobiDataWebsocketApi(self)
def connect(self, setting: dict):
""""""
key = setting["API Key"]
secret = setting["Secret Key"]
session_number = setting["会话数"]
proxy_host = setting["代理地址"]
proxy_port = setting["代理端口"]
if proxy_port.isdigit():
proxy_port = int(proxy_port)
else:
proxy_port = 0
self.rest_api.connect(key, secret, session_number,
proxy_host, proxy_port)
self.trade_ws_api.connect(key, secret, proxy_host, proxy_port)
self.market_ws_api.connect(key, secret, proxy_host, proxy_port)
self.init_query()
def subscribe(self, req: SubscribeRequest):
""""""
self.market_ws_api.subscribe(req)
self.trade_ws_api.subscribe(req)
def send_order(self, req: OrderRequest):
""""""
return self.rest_api.send_order(req)
def cancel_order(self, req: CancelRequest):
""""""
self.rest_api.cancel_order(req)
def query_account(self):
""""""
self.rest_api.query_account_balance()
def query_position(self):
""""""
pass
def close(self):
""""""
self.rest_api.stop()
self.trade_ws_api.stop()
self.market_ws_api.stop()
def process_timer_event(self, event: Event):
""""""
self.count += 1
if self.count < 3:
return
self.query_account()
def init_query(self):
""""""
self.count = 0
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
class HuobiRestApi(RestClient):
"""
HUOBI REST API
"""
def __init__(self, gateway: BaseGateway):
""""""
super(HuobiRestApi, self).__init__()
self.gateway = gateway
self.gateway_name = gateway.gateway_name
self.order_manager = gateway.order_manager
self.host = ""
self.key = ""
self.secret = ""
self.account_id = ""
self.cancel_requests = {}
self.orders = {}
def sign(self, request):
"""
Generate HUOBI signature.
"""
request.headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.71 Safari/537.36"
}
params_with_signature = create_signature(
self.key,
request.method,
self.host,
request.path,
self.secret,
request.params
)
request.params = params_with_signature
if request.method == "POST":
request.headers["Content-Type"] = "application/json"
if request.data:
request.data = json.dumps(request.data)
return request
def connect(
self,
key: str,
secret: str,
session_number: int,
proxy_host: str,
proxy_port: int
):
"""
Initialize connection to REST server.
"""
self.key = key
self.secret = secret
self.host, _ = _split_url(REST_HOST)
self.init(REST_HOST, proxy_host, proxy_port)
self.start(session_number)
self.gateway.write_log("REST API启动成功")
self.query_contract()
self.query_account()
self.query_order()
def query_account(self):
""""""
self.add_request(
method="GET",
path="/v1/account/accounts",
callback=self.on_query_account
)
def query_account_balance(self):
""""""
path = f"/v1/account/accounts/{self.account_id}/balance"
self.add_request(
method="GET",
path=path,
callback=self.on_query_account_balance
)
def query_order(self):
""""""
self.add_request(
method="GET",
path="/v1/order/openOrders",
callback=self.on_query_order
)
def query_contract(self):
""""""
self.add_request(
method="GET",
path="/v1/common/symbols",
callback=self.on_query_contract
)
def send_order(self, req: OrderRequest):
""""""
huobi_type = ORDERTYPE_VT2HUOBI.get(
(req.direction, req.type), ""
)
local_orderid = self.order_manager.new_local_orderid()
order = req.create_order_data(
local_orderid,
self.gateway_name
)
order.time = datetime.now().strftime("%H:%M:%S")
data = {
"account-id": self.account_id,
"amount": str(req.volume),
"symbol": req.symbol,
"type": huobi_type,
"price": str(req.price),
"source": "api"
}
self.add_request(
method="POST",
path="/v1/order/orders/place",
callback=self.on_send_order,
data=data,
extra=order,
on_error=self.on_send_order_error,
on_failed=self.on_send_order_failed
)
self.order_manager.on_order(order)
return order.vt_orderid
def cancel_order(self, req: CancelRequest):
""""""
sys_orderid = self.order_manager.get_sys_orderid(req.orderid)
path = f"/v1/order/orders/{sys_orderid}/submitcancel"
self.add_request(
method="POST",
path=path,
callback=self.on_cancel_order,
extra=req
)
def on_query_account(self, data, request):
""""""
if self.check_error(data, "查询账户"):
return
for d in data["data"]:
if d["type"] == "spot":
self.account_id = d["id"]
self.gateway.write_log(f"账户代码{self.account_id}查询成功")
self.query_account_balance()
def on_query_account_balance(self, data, request):
""""""
if self.check_error(data, "查询账户资金"):
return
buf = {}
for d in data["data"]["list"]:
currency = d["currency"]
currency_data = buf.setdefault(currency, {})
currency_data[d["type"]] = float(d["balance"])
for currency, currency_data in buf.items():
account = AccountData(
accountid=currency,
balance=currency_data["trade"] + currency_data["frozen"],
frozen=currency_data["frozen"],
gateway_name=self.gateway_name,
)
if account.balance:
self.gateway.on_account(account)
def on_query_order(self, data, request):
""""""
if self.check_error(data, "查询委托"):
return
for d in data["data"]:
sys_orderid = d["id"]
local_orderid = self.order_manager.get_local_orderid(sys_orderid)
direction, order_type = ORDERTYPE_HUOBI2VT[d["type"]]
dt = datetime.fromtimestamp(d["created-at"] / 1000)
time = dt.strftime("%H:%M:%S")
order = OrderData(
orderid=local_orderid,
symbol=d["symbol"],
exchange=Exchange.HUOBI,
price=float(d["price"]),
volume=float(d["amount"]),
type=order_type,
direction=direction,
traded=float(d["filled-amount"]),
status=STATUS_HUOBI2VT.get(d["state"], None),
time=time,
gateway_name=self.gateway_name,
)
self.order_manager.on_order(order)
self.gateway.write_log("委托信息查询成功")
def on_query_contract(self, data, request): # type: (dict, Request)->None
""""""
if self.check_error(data, "查询合约"):
return
for d in data["data"]:
base_currency = d["base-currency"]
quote_currency = d["quote-currency"]
name = f"{base_currency.upper()}/{quote_currency.upper()}"
pricetick = 1 / pow(10, d["price-precision"])
size = 1 / pow(10, d["amount-precision"])
contract = ContractData(
symbol=d["symbol"],
exchange=Exchange.HUOBI,
name=name,
pricetick=pricetick,
size=size,
product=Product.SPOT,
gateway_name=self.gateway_name,
)
self.gateway.on_contract(contract)
huobi_symbols.add(contract.symbol)
symbol_name_map[contract.symbol] = contract.name
self.gateway.write_log("合约信息查询成功")
def on_send_order(self, data, request):
""""""
order = request.extra
if self.check_error(data, "委托"):
order.status = Status.REJECTED
self.order_manager.on_order(order)
return
sys_orderid = data["data"]
self.order_manager.update_orderid_map(order.orderid, sys_orderid)
def on_send_order_failed(self, status_code: str, request: Request):
"""
Callback when sending order failed on server.
"""
order = request.extra
order.status = Status.REJECTED
self.gateway.on_order(order)
msg = f"委托失败,状态码:{status_code},信息:{request.response.text}"
self.gateway.write_log(msg)
def on_send_order_error(
self, exception_type: type, exception_value: Exception, tb, request: Request
):
"""
Callback when sending order caused exception.
"""
order = request.extra
order.status = Status.REJECTED
self.gateway.on_order(order)
# Record exception if not ConnectionError
if not issubclass(exception_type, ConnectionError):
self.on_error(exception_type, exception_value, tb, request)
def on_cancel_order(self, data, request):
""""""
cancel_request = request.extra
local_orderid = cancel_request.orderid
order = self.order_manager.get_order_with_local_orderid(local_orderid)
if self.check_error(data, "撤单"):
order.status = Status.REJECTED
else:
order.status = Status.CANCELLED
self.gateway.write_log(f"委托撤单成功:{order.orderid}")
self.order_manager.on_order(order)
def check_error(self, data: dict, func: str = ""):
""""""
if data["status"] != "error":
return False
error_code = data["err-code"]
error_msg = data["err-msg"]
self.gateway.write_log(f"{func}请求出错,代码:{error_code},信息:{error_msg}")
return True
class HuobiWebsocketApiBase(WebsocketClient):
""""""
def __init__(self, gateway):
""""""
super(HuobiWebsocketApiBase, self).__init__()
self.gateway = gateway
self.gateway_name = gateway.gateway_name
self.key = ""
self.secret = ""
self.sign_host = ""
self.path = ""
def connect(
self,
key: str,
secret: str,
url: str,
proxy_host: str,
proxy_port: int
):
""""""
self.key = key
self.secret = secret
host, path = _split_url(url)
self.sign_host = host
self.path = path
self.init(url, proxy_host, proxy_port)
self.start()
def login(self):
""""""
params = {"op": "auth"}
params.update(create_signature(self.key, "GET", self.sign_host, self.path, self.secret))
return self.send_packet(params)
def on_login(self, packet):
""""""
pass
@staticmethod
def unpack_data(data):
""""""
return json.loads(zlib.decompress(data, 31))
def on_packet(self, packet):
""""""
if "ping" in packet:
self.send_packet({"pong": packet["ping"]})
elif "err-msg" in packet:
return self.on_error_msg(packet)
elif "op" in packet and packet["op"] == "auth":
return self.on_login()
else:
self.on_data(packet)
def on_data(self, packet):
""""""
print("data : {}".format(packet))
def on_error_msg(self, packet):
""""""
msg = packet["err-msg"]
if msg == "invalid pong":
return
self.gateway.write_log(packet["err-msg"])
class HuobiTradeWebsocketApi(HuobiWebsocketApiBase):
""""""
def __init__(self, gateway):
""""""
super().__init__(gateway)
self.order_manager = gateway.order_manager
self.order_manager.push_data_callback = self.on_data
self.req_id = 0
def connect(self, key, secret, proxy_host, proxy_port):
""""""
super().connect(key, secret, WEBSOCKET_TRADE_HOST, proxy_host, proxy_port)
def subscribe(self, req: SubscribeRequest):
""""""
self.req_id += 1
req = {
"op": "sub",
"cid": str(self.req_id),
"topic": f"orders.{req.symbol}"
}
self.send_packet(req)
def on_connected(self):
""""""
self.gateway.write_log("交易Websocket API连接成功")
self.login()
def on_login(self):
""""""
self.gateway.write_log("交易Websocket API登录成功")
def on_data(self, packet): # type: (dict)->None
""""""
op = packet.get("op", None)
if op != "notify":
return
topic = packet["topic"]
if "orders" in topic:
self.on_order(packet["data"])
def on_order(self, data: dict):
""""""
sys_orderid = str(data["order-id"])
order = self.order_manager.get_order_with_sys_orderid(sys_orderid)
if not order:
self.order_manager.add_push_data(sys_orderid, data)
return
traded_volume = float(data["filled-amount"])
# Push order event
order.traded += traded_volume
order.status = STATUS_HUOBI2VT.get(data["order-state"], None)
self.order_manager.on_order(order)
# Push trade event
if not traded_volume:
return
trade = TradeData(
symbol=order.symbol,
exchange=Exchange.HUOBI,
orderid=order.orderid,
tradeid=str(data["seq-id"]),
direction=order.direction,
price=float(data["price"]),
volume=float(data["filled-amount"]),
time=datetime.now().strftime("%H:%M:%S"),
gateway_name=self.gateway_name,
)
self.gateway.on_trade(trade)
class HuobiDataWebsocketApi(HuobiWebsocketApiBase):
""""""
def __init__(self, gateway):
""""""
super().__init__(gateway)
self.req_id = 0
self.ticks = {}
def connect(self, key: str, secret: str, proxy_host: str, proxy_port: int):
""""""
super().connect(key, secret, WEBSOCKET_DATA_HOST, proxy_host, proxy_port)
def on_connected(self):
""""""
self.gateway.write_log("行情Websocket API连接成功")
def subscribe(self, req: SubscribeRequest):
""""""
symbol = req.symbol
# Create tick data buffer
tick = TickData(
symbol=symbol,
name=symbol_name_map.get(symbol, ""),
exchange=Exchange.HUOBI,
datetime=datetime.now(),
gateway_name=self.gateway_name,
)
self.ticks[symbol] = tick
# Subscribe to market depth update
self.req_id += 1
req = {
"sub": f"market.{symbol}.depth.step0",
"id": str(self.req_id)
}
self.send_packet(req)
# Subscribe to market detail update
self.req_id += 1
req = {
"sub": f"market.{symbol}.detail",
"id": str(self.req_id)
}
self.send_packet(req)
def on_data(self, packet): # type: (dict)->None
""""""
channel = packet.get("ch", None)
if channel:
if "depth.step" in channel:
self.on_market_depth(packet)
elif "detail" in channel:
self.on_market_detail(packet)
elif "err-code" in packet:
code = packet["err-code"]
msg = packet["err-msg"]
self.gateway.write_log(f"错误代码:{code}, 错误信息:{msg}")
def on_market_depth(self, data):
"""行情深度推送 """
symbol = data["ch"].split(".")[1]
tick = self.ticks[symbol]
tick.datetime = datetime.fromtimestamp(data["ts"] / 1000)
bids = data["tick"]["bids"]
for n in range(5):
price, volume = bids[n]
tick.__setattr__("bid_price_" + str(n + 1), float(price))
tick.__setattr__("bid_volume_" + str(n + 1), float(volume))
asks = data["tick"]["asks"]
for n in range(5):
price, volume = asks[n]
tick.__setattr__("ask_price_" + str(n + 1), float(price))
tick.__setattr__("ask_volume_" + str(n + 1), float(volume))
if tick.last_price:
self.gateway.on_tick(copy(tick))
def on_market_detail(self, data):
"""市场细节推送"""
symbol = data["ch"].split(".")[1]
tick = self.ticks[symbol]
tick.datetime = datetime.fromtimestamp(data["ts"] / 1000)
tick_data = data["tick"]
tick.open_price = float(tick_data["open"])
tick.high_price = float(tick_data["high"])
tick.low_price = float(tick_data["low"])
tick.last_price = float(tick_data["close"])
tick.volume = float(tick_data["vol"])
if tick.bid_price_1:
self.gateway.on_tick(copy(tick))
def _split_url(url):
"""
将url拆分为host和path
:return: host, path
"""
result = re.match("\w+://([^/]*)(.*)", url) # noqa
if result:
return result.group(1), result.group(2)
def create_signature(api_key, method, host, path, secret_key, get_params=None):
"""
创建签名
:param get_params: dict 使用GET方法时附带的额外参数(urlparams)
:return:
"""
sorted_params = [
("AccessKeyId", api_key),
("SignatureMethod", "HmacSHA256"),
("SignatureVersion", "2"),
("Timestamp", datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S"))
]
if get_params:
sorted_params.extend(list(get_params.items()))
sorted_params = list(sorted(sorted_params))
encode_params = urllib.parse.urlencode(sorted_params)
payload = [method, host, path, encode_params]
payload = "\n".join(payload)
payload = payload.encode(encoding="UTF8")
secret_key = secret_key.encode(encoding="UTF8")
digest = hmac.new(secret_key, payload, digestmod=hashlib.sha256).digest()
signature = base64.b64encode(digest)
params = dict(sorted_params)
params["Signature"] = signature.decode("UTF8")
return params

View File

@ -0,0 +1 @@
from .okex_gateway import OkexGateway

View File

@ -0,0 +1,725 @@
# encoding: UTF-8
"""
"""
import hashlib
import hmac
import sys
import time
import json
import base64
import zlib
from copy import copy
from datetime import datetime
from threading import Lock
from urllib.parse import urlencode
from requests import ConnectionError
from vnpy.api.rest import Request, RestClient
from vnpy.api.websocket import WebsocketClient
from vnpy.trader.constant import (
Direction,
Exchange,
OrderType,
Product,
Status
)
from vnpy.trader.gateway import BaseGateway
from vnpy.trader.object import (
TickData,
OrderData,
TradeData,
AccountData,
ContractData,
OrderRequest,
CancelRequest,
SubscribeRequest,
)
REST_HOST = "https://www.okex.com"
WEBSOCKET_HOST = "wss://real.okex.com:10442/ws/v3"
STATUS_OKEX2VT = {
"ordering": Status.SUBMITTING,
"open": Status.NOTTRADED,
"part_filled": Status.PARTTRADED,
"filled": Status.ALLTRADED,
"cancelled": Status.CANCELLED,
"cancelling": Status.CANCELLED,
"failure": Status.REJECTED,
}
DIRECTION_VT2OKEX = {Direction.LONG: "buy", Direction.SHORT: "sell"}
DIRECTION_OKEX2VT = {v: k for k, v in DIRECTION_VT2OKEX.items()}
ORDERTYPE_VT2OKEX = {
OrderType.LIMIT: "limit",
OrderType.MARKET: "market"
}
ORDERTYPE_OKEX2VT = {v: k for k, v in ORDERTYPE_VT2OKEX.items()}
instruments = set()
currencies = set()
class OkexGateway(BaseGateway):
"""
VN Trader Gateway for OKEX connection.
"""
default_setting = {
"API Key": "",
"Secret Key": "",
"Passphrase": "",
"会话数": 3,
"代理地址": "",
"代理端口": "",
}
def __init__(self, event_engine):
"""Constructor"""
super(OkexGateway, self).__init__(event_engine, "OKEX")
self.rest_api = OkexRestApi(self)
self.ws_api = OkexWebsocketApi(self)
self.orders = {}
def connect(self, setting: dict):
""""""
key = setting["API Key"]
secret = setting["Secret Key"]
passphrase = setting["Passphrase"]
session_number = setting["会话数"]
proxy_host = setting["代理地址"]
proxy_port = setting["代理端口"]
if proxy_port.isdigit():
proxy_port = int(proxy_port)
else:
proxy_port = 0
self.rest_api.connect(key, secret, passphrase,
session_number, proxy_host, proxy_port)
self.ws_api.connect(key, secret, passphrase, proxy_host, proxy_port)
def subscribe(self, req: SubscribeRequest):
""""""
self.ws_api.subscribe(req)
def send_order(self, req: OrderRequest):
""""""
return self.rest_api.send_order(req)
def cancel_order(self, req: CancelRequest):
""""""
self.rest_api.cancel_order(req)
def query_account(self):
""""""
pass
def query_position(self):
""""""
pass
def close(self):
""""""
self.rest_api.stop()
self.ws_api.stop()
def on_order(self, order: OrderData):
""""""
self.orders[order.orderid] = order
super().on_order(order)
def get_order(self, orderid: str):
""""""
return self.orders.get(orderid, None)
class OkexRestApi(RestClient):
"""
OKEX REST API
"""
def __init__(self, gateway: BaseGateway):
""""""
super(OkexRestApi, self).__init__()
self.gateway = gateway
self.gateway_name = gateway.gateway_name
self.key = ""
self.secret = ""
self.passphrase = ""
self.order_count = 10000
self.order_count_lock = Lock()
self.connect_time = 0
def sign(self, request):
"""
Generate OKEX signature.
"""
# Sign
# timestamp = str(time.time())
timestamp = get_timestamp()
request.data = json.dumps(request.data)
if request.params:
path = request.path + '?' + urlencode(request.params)
else:
path = request.path
msg = timestamp + request.method + path + request.data
signature = generate_signature(msg, self.secret)
# Add headers
request.headers = {
'OK-ACCESS-KEY': self.key,
'OK-ACCESS-SIGN': signature,
'OK-ACCESS-TIMESTAMP': timestamp,
'OK-ACCESS-PASSPHRASE': self.passphrase,
'Content-Type': 'application/json'
}
return request
def connect(
self,
key: str,
secret: str,
passphrase: str,
session_number: int,
proxy_host: str,
proxy_port: int,
):
"""
Initialize connection to REST server.
"""
self.key = key
self.secret = secret.encode()
self.passphrase = passphrase
self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S"))
self.init(REST_HOST, proxy_host, proxy_port)
self.start(session_number)
self.gateway.write_log("REST API启动成功")
self.query_time()
self.query_contract()
self.query_account()
self.query_order()
def _new_order_id(self):
with self.order_count_lock:
self.order_count += 1
return self.order_count
def send_order(self, req: OrderRequest):
""""""
orderid = f"a{self.connect_time}{self._new_order_id()}"
data = {
"client_oid": orderid,
"type": ORDERTYPE_VT2OKEX[req.type],
"side": DIRECTION_VT2OKEX[req.direction],
"instrument_id": req.symbol
}
if req.type == OrderType.MARKET:
if req.direction == Direction.LONG:
data["notional"] = req.volume
else:
data["size"] = req.volume
else:
data["price"] = req.price
data["size"] = req.volume
order = req.create_order_data(orderid, self.gateway_name)
self.add_request(
"POST",
"/api/spot/v3/orders",
callback=self.on_send_order,
data=data,
extra=order,
on_failed=self.on_send_order_failed,
on_error=self.on_send_order_error,
)
self.gateway.on_order(order)
return order.vt_orderid
def cancel_order(self, req: CancelRequest):
""""""
data = {
"instrument_id": req.symbol,
"client_oid": req.orderid
}
path = "/api/spot/v3/cancel_orders/" + req.orderid
self.add_request(
"POST",
path,
callback=self.on_cancel_order,
data=data,
on_error=self.on_cancel_order_error,
on_failed=self.on_cancel_order_failed,
extra=req
)
def query_contract(self):
""""""
self.add_request(
"GET",
"/api/spot/v3/instruments",
callback=self.on_query_contract
)
def query_account(self):
""""""
self.add_request(
"GET",
"/api/spot/v3/accounts",
callback=self.on_query_account
)
def query_order(self):
""""""
self.add_request(
"GET",
"/api/spot/v3/orders_pending",
callback=self.on_query_order
)
def query_time(self):
""""""
self.add_request(
"GET",
"/api/general/v3/time",
callback=self.on_query_time
)
def on_query_contract(self, data, request):
""""""
for instrument_data in data:
symbol = instrument_data["instrument_id"]
contract = ContractData(
symbol=symbol,
exchange=Exchange.OKEX,
name=symbol,
product=Product.SPOT,
size=1,
pricetick=float(instrument_data["tick_size"]),
gateway_name=self.gateway_name
)
self.gateway.on_contract(contract)
instruments.add(instrument_data["instrument_id"])
currencies.add(instrument_data["base_currency"])
currencies.add(instrument_data["quote_currency"])
self.gateway.write_log("合约信息查询成功")
# Start websocket api after instruments data collected
self.gateway.ws_api.start()
def on_query_account(self, data, request):
""""""
for account_data in data:
account = AccountData(
accountid=account_data["currency"],
balance=float(account_data["balance"]),
frozen=float(account_data["hold"]),
gateway_name=self.gateway_name
)
self.gateway.on_account(account)
self.gateway.write_log("账户资金查询成功")
def on_query_order(self, data, request):
""""""
for order_data in data:
order = OrderData(
symbol=order_data["instrument_id"],
exchange=Exchange.OKEX,
type=ORDERTYPE_OKEX2VT[order_data["type"]],
orderid=order_data["client_oid"],
direction=DIRECTION_OKEX2VT[order_data["side"]],
price=float(order_data["price"]),
volume=float(order_data["size"]),
traded=float(order_data["filled_size"]),
time=order_data["timestamp"][11:19],
status=STATUS_OKEX2VT[order_data["status"]],
gateway_name=self.gateway_name,
)
self.gateway.on_order(order)
self.gateway.write_log("委托信息查询成功")
def on_query_time(self, data, request):
""""""
server_time = data["iso"]
local_time = datetime.utcnow().isoformat()
msg = f"服务器时间:{server_time},本机时间:{local_time}"
self.gateway.write_log(msg)
def on_send_order_failed(self, status_code: str, request: Request):
"""
Callback when sending order failed on server.
"""
order = request.extra
order.status = Status.REJECTED
self.gateway.on_order(order)
msg = f"委托失败,状态码:{status_code},信息:{request.response.text}"
self.gateway.write_log(msg)
def on_send_order_error(
self, exception_type: type, exception_value: Exception, tb, request: Request
):
"""
Callback when sending order caused exception.
"""
order = request.extra
order.status = Status.REJECTED
self.gateway.on_order(order)
# Record exception if not ConnectionError
if not issubclass(exception_type, ConnectionError):
self.on_error(exception_type, exception_value, tb, request)
def on_send_order(self, data, request):
"""Websocket will push a new order status"""
order = request.extra
error_msg = data["error_message"]
if error_msg:
order.status = Status.REJECTED
self.gateway.on_order(order)
self.gateway.write_log(f"委托失败:{error_msg}")
def on_cancel_order_error(
self, exception_type: type, exception_value: Exception, tb, request: Request
):
"""
Callback when cancelling order failed on server.
"""
# Record exception if not ConnectionError
if not issubclass(exception_type, ConnectionError):
self.on_error(exception_type, exception_value, tb, request)
def on_cancel_order(self, data, request):
"""Websocket will push a new order status"""
pass
def on_cancel_order_failed(self, status_code: int, request: Request):
"""If cancel failed, mark order status to be rejected."""
req = request.extra
order = self.gateway.get_order(req.orderid)
if order:
order.status = Status.REJECTED
self.gateway.on_order(order)
def on_failed(self, status_code: int, request: Request):
"""
Callback to handle request failed.
"""
msg = f"请求失败,状态码:{status_code},信息:{request.response.text}"
self.gateway.write_log(msg)
def on_error(
self, exception_type: type, exception_value: Exception, tb, request: Request
):
"""
Callback to handler request exception.
"""
msg = f"触发异常,状态码:{exception_type},信息:{exception_value}"
self.gateway.write_log(msg)
sys.stderr.write(
self.exception_detail(exception_type, exception_value, tb, request)
)
class OkexWebsocketApi(WebsocketClient):
""""""
def __init__(self, gateway):
""""""
super(OkexWebsocketApi, self).__init__()
self.ping_interval = 20 # OKEX use 30 seconds for ping
self.gateway = gateway
self.gateway_name = gateway.gateway_name
self.key = ""
self.secret = ""
self.passphrase = ""
self.trade_count = 10000
self.connect_time = 0
self.callbacks = {}
self.ticks = {}
def connect(
self,
key: str,
secret: str,
passphrase: str,
proxy_host: str,
proxy_port: int
):
""""""
self.key = key
self.secret = secret.encode()
self.passphrase = passphrase
self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S"))
self.init(WEBSOCKET_HOST, proxy_host, proxy_port)
# self.start()
def unpack_data(self, data):
""""""
return json.loads(zlib.decompress(data, -zlib.MAX_WBITS))
def subscribe(self, req: SubscribeRequest):
"""
Subscribe to tick data upate.
"""
tick = TickData(
symbol=req.symbol,
exchange=req.exchange,
name=req.symbol,
datetime=datetime.now(),
gateway_name=self.gateway_name,
)
self.ticks[req.symbol] = tick
channel_ticker = f"spot/ticker:{req.symbol}"
channel_depth = f"spot/depth5:{req.symbol}"
self.callbacks[channel_ticker] = self.on_ticker
self.callbacks[channel_depth] = self.on_depth
req = {
"op": "subscribe",
"args": [channel_ticker, channel_depth]
}
self.send_packet(req)
def on_connected(self):
""""""
self.gateway.write_log("Websocket API连接成功")
self.login()
def on_disconnected(self):
""""""
self.gateway.write_log("Websocket API连接断开")
def on_packet(self, packet: dict):
""""""
if "event" in packet:
event = packet["event"]
if event == "subscribe":
return
elif event == "error":
msg = packet["message"]
self.gateway.write_log(f"Websocket API请求异常{msg}")
elif event == "login":
self.on_login(packet)
else:
channel = packet["table"]
data = packet["data"]
callback = self.callbacks.get(channel, None)
if callback:
for d in data:
callback(d)
def on_error(self, exception_type: type, exception_value: Exception, tb):
""""""
msg = f"触发异常,状态码:{exception_type},信息:{exception_value}"
self.gateway.write_log(msg)
sys.stderr.write(self.exception_detail(
exception_type, exception_value, tb))
def login(self):
"""
Need to login befores subscribe to websocket topic.
"""
timestamp = str(time.time())
msg = timestamp + 'GET' + '/users/self/verify'
signature = generate_signature(msg, self.secret)
req = {
"op": "login",
"args": [
self.key,
self.passphrase,
timestamp,
signature.decode("utf-8")
]
}
self.send_packet(req)
self.callbacks['login'] = self.on_login
def subscribe_topic(self):
"""
Subscribe to all private topics.
"""
self.callbacks["spot/ticker"] = self.on_ticker
self.callbacks["spot/depth5"] = self.on_depth
self.callbacks["spot/account"] = self.on_account
self.callbacks["spot/order"] = self.on_order
# Subscribe to order update
channels = []
for instrument_id in instruments:
channel = f"spot/order:{instrument_id}"
channels.append(channel)
req = {
"op": "subscribe",
"args": channels
}
self.send_packet(req)
# Subscribe to account update
channels = []
for currency in currencies:
channel = f"spot/account:{currency}"
channels.append(channel)
req = {
"op": "subscribe",
"args": channels
}
self.send_packet(req)
# Subscribe to BTC/USDT trade for keep connection alive
req = {
"op": "subscribe",
"args": ["spot/trade:BTC-USDT"]
}
self.send_packet(req)
def on_login(self, data: dict):
""""""
success = data.get("success", False)
if success:
self.gateway.write_log("Websocket API登录成功")
self.subscribe_topic()
else:
self.gateway.write_log("Websocket API登录失败")
def on_ticker(self, d):
""""""
symbol = d["instrument_id"]
tick = self.ticks.get(symbol, None)
if not tick:
return
tick.last_price = float(d["last"])
tick.open = float(d["open_24h"])
tick.high = float(d["high_24h"])
tick.low = float(d["low_24h"])
tick.volume = float(d["base_volume_24h"])
tick.datetime = datetime.strptime(
d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ")
self.gateway.on_tick(copy(tick))
def on_depth(self, d):
""""""
for tick_data in d:
symbol = d["instrument_id"]
tick = self.ticks.get(symbol, None)
if not tick:
return
bids = d["bids"]
asks = d["asks"]
for n, buf in enumerate(bids):
price, volume, _ = buf
tick.__setattr__("bid_price_%s" % (n + 1), float(price))
tick.__setattr__("bid_volume_%s" % (n + 1), float(volume))
for n, buf in enumerate(asks):
price, volume, _ = buf
tick.__setattr__("ask_price_%s" % (n + 1), float(price))
tick.__setattr__("ask_volume_%s" % (n + 1), float(volume))
tick.datetime = datetime.strptime(
d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ")
self.gateway.on_tick(copy(tick))
def on_order(self, d):
""""""
order = OrderData(
symbol=d["instrument_id"],
exchange=Exchange.OKEX,
type=ORDERTYPE_OKEX2VT[d["type"]],
orderid=d["client_oid"],
direction=DIRECTION_OKEX2VT[d["side"]],
price=float(d["price"]),
volume=float(d["size"]),
traded=float(d["filled_size"]),
time=d["timestamp"][11:19],
status=STATUS_OKEX2VT[d["status"]],
gateway_name=self.gateway_name,
)
self.gateway.on_order(copy(order))
trade_volume = d.get("last_fill_qty", 0)
if not trade_volume or float(trade_volume) == 0:
return
self.trade_count += 1
tradeid = f"{self.connect_time}{self.trade_count}"
trade = TradeData(
symbol=order.symbol,
exchange=order.exchange,
orderid=order.orderid,
tradeid=tradeid,
direction=order.direction,
price=float(d["last_fill_px"]),
volume=float(trade_volume),
time=d["last_fill_time"][11:19],
gateway_name=self.gateway_name
)
self.gateway.on_trade(trade)
def on_account(self, d):
""""""
account = AccountData(
accountid=d["currency"],
balance=float(d["balance"]),
frozen=float(d["hold"]),
gateway_name=self.gateway_name
)
self.gateway.on_account(copy(account))
def generate_signature(msg: str, secret_key: str):
"""OKEX V3 signature"""
return base64.b64encode(hmac.new(secret_key, msg.encode(), hashlib.sha256).digest())
def get_timestamp():
""""""
now = datetime.utcnow()
timestamp = now.isoformat("T", "milliseconds")
return timestamp + "Z"

View File

@ -1,5 +1,6 @@
# encoding: UTF-8
"""
Author: KeKe
Please install tiger-api before use.
pip install tigeropen
"""
@ -123,6 +124,9 @@ class TigerGateway(BaseGateway):
self.contracts = {}
self.symbol_names = {}
self.push_connected = False
self.subscribed_symbols = set()
def run(self):
""""""
while self.active:
@ -203,23 +207,33 @@ class TigerGateway(BaseGateway):
"""
protocol, host, port = self.client_config.socket_host_port
self.push_client = PushClient(host, port, (protocol == 'ssl'))
self.push_client.connect(
self.client_config.tiger_id, self.client_config.private_key)
self.push_client.quote_changed = self.on_quote_change
self.push_client.asset_changed = self.on_asset_change
self.push_client.position_changed = self.on_position_change
self.push_client.order_changed = self.on_order_change
self.write_log("推送接口连接成功")
self.push_client.connect(
self.client_config.tiger_id, self.client_config.private_key)
def subscribe(self, req: SubscribeRequest):
""""""
self.push_client.subscribe_quote([req.symbol])
self.subscribed_symbols.add(req.symbol)
if self.push_connected:
self.push_client.subscribe_quote([req.symbol])
def on_push_connected(self):
""""""
self.push_connected = True
self.write_log("推送接口连接成功")
self.push_client.subscribe_asset()
self.push_client.subscribe_position()
self.push_client.subscribe_order()
self.push_client.subscribe_quote(list(self.subscribed_symbols))
def on_quote_change(self, tiger_symbol: str, data: list, trading: bool):
""""""
data = dict(data)

1
vnpy/rpc/__init__.py Normal file
View File

@ -0,0 +1 @@
from .vnrpc import RpcServer, RpcClient, RemoteException

36
vnpy/rpc/test_client.py Normal file
View File

@ -0,0 +1,36 @@
from __future__ import print_function
from __future__ import absolute_import
from time import sleep
from vnpy.rpc import RpcClient
class TestClient(RpcClient):
"""
Test RpcClient
"""
def __init__(self, req_address, sub_address):
"""
Constructor
"""
super(TestClient, self).__init__(req_address, sub_address)
def callback(self, topic, data):
"""
Realize callable function
"""
print('client received topic:', topic, ', data:', data)
if __name__ == '__main__':
req_address = 'tcp://localhost:2014'
sub_address = 'tcp://localhost:0602'
tc = TestClient(req_address, sub_address)
tc.subscribeTopic('')
tc.start()
while 1:
print(tc.add(1, 3))
sleep(2)

40
vnpy/rpc/test_server.py Normal file
View File

@ -0,0 +1,40 @@
from __future__ import print_function
from __future__ import absolute_import
from time import sleep, time
from vnpy.rpc import RpcServer
class TestServer(RpcServer):
"""
Test RpcServer
"""
def __init__(self, rep_address, pub_address):
"""
Constructor
"""
super(TestServer, self).__init__(rep_address, pub_address)
self.register(self.add)
def add(self, a, b):
"""
Test function
"""
print('receiving: %s, %s' % (a, b))
return a + b
if __name__ == '__main__':
rep_address = 'tcp://*:2014'
pub_address = 'tcp://*:0602'
ts = TestServer(rep_address, pub_address)
ts.start()
while 1:
content = 'current server time is %s' % time()
print(content)
ts.publish('test', content)
sleep(2)

329
vnpy/rpc/vnrpc.py Normal file
View File

@ -0,0 +1,329 @@
import threading
import traceback
import signal
import zmq
from msgpack import packb, unpackb
from json import dumps, loads
import pickle
p_dumps = pickle.dumps
p_loads = pickle.loads
# Achieve Ctrl-c interrupt recv
signal.signal(signal.SIGINT, signal.SIG_DFL)
class RpcObject(object):
"""
Referred to serialization of packing and unpacking, we offer 3 tools:
1) maspack: higher performance, but usually requires the installation of msgpack related tools;
2) jason: Slightly lower performance but versatility is better, most programming languages have built-in libraries;
3) cPickle: Lower performance and only can be used in Python, but it is very convenient to transfer Python objects directly.
Therefore, it is recommended to use msgpack.
Use json, if you want to communicate with some languages without providing msgpack.
Use cPickle, when the data being transferred contains many custom Python objects.
"""
def __init__(self):
"""
Constructor
Use msgpack as default serialization tool
"""
self.use_msgpack()
def pack(self, data):
""""""
pass
def unpack(self, data):
""""""
pass
def __json_pack(self, data):
"""
Pack with json
"""
return dumps(data)
def __json_unpack(self, data):
"""
Unpack with json
"""
return loads(data)
def __msgpack_pack(self, data):
"""
Pack with msgpack
"""
return packb(data)
def __msgpack_unpack(self, data):
"""
Unpack with msgpack
"""
return unpackb(data)
def __pickle_pack(self, data):
"""
Pack with cPickle
"""
return p_dumps(data)
def __pickle_unpack(self, data):
"""
Unpack with cPickle
"""
return p_loads(data)
def use_json(self):
"""
Use json as serialization tool
"""
self.pack = self.__json_pack
self.unpack = self.__json_unpack
def use_msgpack(self):
"""
Use msgpack as serialization tool
"""
self.pack = self.__msgpack_pack
self.unpack = self.__msgpack_unpack
def use_pickle(self):
"""
Use cPickle as serialization tool
"""
self.pack = self.__pickle_pack
self.unpack = self.__pickle_unpack
class RpcServer(RpcObject):
""""""
def __init__(self, rep_address, pub_address):
"""
Constructor
"""
super(RpcServer, self).__init__()
# Save functions dict: key is fuction name, value is fuction object
self.__functions = {}
# Zmq port related
self.__context = zmq.Context()
self.__socket_rep = self.__context.socket(
zmq.REP) # Reply socket (Requestreply pattern)
self.__socket_rep.bind(rep_address)
# Publish socket (Publishsubscribe pattern)
self.__socket_pub = self.__context.socket(zmq.PUB)
self.__socket_pub.bind(pub_address)
# Woker thread related
self.__active = False # RpcServer status
self.__thread = threading.Thread(target=self.run) # RpcServer thread
def start(self):
"""
Start RpcServer
"""
# Start RpcServer status
self.__active = True
# Start RpcServer thread
if not self.__thread.isAlive():
self.__thread.start()
def stop(self, join=False):
"""
Stop RpcServer
"""
# Stop RpcServer status
self.__active = False
# Wait for RpcServer thread to exit
if join and self.__thread.isAlive():
self.__thread.join()
def run(self):
"""
Run RpcServer functions
"""
while self.__active:
# Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds)
if not self.__socket_rep.poll(1000):
continue
# Receive request data from Reply socket
reqb = self.__socket_rep.recv()
# Unpack request by deserialization
req = self.unpack(reqb)
# Get function name and parameters
name, args, kwargs = req
# Try to get and execute callable function object; capture exception information if it fails
name = name.decode("UTF-8")
try:
func = self.__functions[name]
r = func(*args, **kwargs)
rep = [True, r]
except Exception as e: # noqa
rep = [False, traceback.format_exc()]
# Pack response by serialization
repb = self.pack(rep)
# send callable response by Reply socket
self.__socket_rep.send(repb)
def publish(self, topic, data):
"""
Publish data
"""
# Serialized data
topic = bytes(topic, "UTF-8")
datab = self.pack(data)
# Send data by Publish socket
# topci must be ascii encoding
self.__socket_pub.send_multipart([topic, datab])
def register(self, func):
"""
Register function
"""
self.__functions[func.__name__] = func
class RpcClient(RpcObject):
""""""
def __init__(self, req_address, sub_address):
"""Constructor"""
super(RpcClient, self).__init__()
# zmq port related
self.__req_address = req_address
self.__sub_address = sub_address
self.__context = zmq.Context()
# Request socket (Requestreply pattern)
self.__socket_req = self.__context.socket(zmq.REQ)
# Subscribe socket (Publishsubscribe pattern)
self.__socket_sub = self.__context.socket(zmq.SUB)
# Woker thread relate, used to process data pushed from server
self.__active = False # RpcClient status
self.__thread = threading.Thread(
target=self.run) # RpcClient thread
def __getattr__(self, name):
"""
Realize remote call function
"""
# Perform remote call task
def dorpc(*args, **kwargs):
# Generate request
req = [name, args, kwargs]
# Pack request by serialization
reqb = self.pack(req)
# Send request and wait for response
self.__socket_req.send(reqb)
repb = self.__socket_req.recv()
# Unpack response by deserialization
rep = self.unpack(repb)
# Return response if successed; Trigger exception if failed
if rep[0]:
return rep[1]
else:
raise RemoteException(rep[1].decode("UTF-8"))
return dorpc
def start(self):
"""
Start RpcClient
"""
# Connect zmq port
self.__socket_req.connect(self.__req_address)
self.__socket_sub.connect(self.__sub_address)
# Start RpcClient status
self.__active = True
# Start RpcClient thread
if not self.__thread.isAlive():
self.__thread.start()
def stop(self):
"""
Stop RpcClient
"""
# Stop RpcClient status
self.__active = False
# Wait for RpcClient thread to exit
if self.__thread.isAlive():
self.__thread.join()
def run(self):
"""
Run RpcClient function
"""
while self.__active:
# Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds)
if not self.__socket_sub.poll(1000):
continue
# Receive data from subscribe socket
topic, datab = self.__socket_sub.recv_multipart()
# Unpack data by deserialization
data = self.unpack(datab)
# Process data by callable function
topic = topic.decode("UTF-8")
self.callback(topic, data)
def callback(self, topic, data):
"""
Callable function
"""
raise NotImplementedError
def subscribeTopic(self, topic):
"""
Subscribe data
"""
topic = bytes(topic, "UTF-8")
self.__socket_sub.setsockopt(zmq.SUBSCRIBE, topic)
class RemoteException(Exception):
"""
RPC remote exception
"""
def __init__(self, value):
"""
Constructor
"""
self.__value = value
def __str__(self):
"""
Output error message
"""
return self.__value

View File

@ -99,6 +99,8 @@ class Exchange(Enum):
# CryptoCurrency
BITMEX = "BITMEX"
OKEX = "OKEX"
HUOBI = "HUOBI"
class Currency(Enum):

View File

@ -1,13 +1,47 @@
""""""
from peewee import SqliteDatabase, Model, CharField, DateTimeField, FloatField
from peewee import CharField, DateTimeField, FloatField, Model, MySQLDatabase, PostgresqlDatabase, \
SqliteDatabase
from .constant import Exchange, Interval
from .object import BarData, TickData
from .utility import get_file_path
from .setting import SETTINGS
from .utility import resolve_path
DB_NAME = "database.db"
DB = SqliteDatabase(str(get_file_path(DB_NAME)))
def init():
db_settings = SETTINGS['database']
driver = db_settings["driver"]
init_funcs = {
"sqlite": init_sqlite,
"mysql": init_mysql,
"postgresql": init_postgresql,
}
assert driver in init_funcs
del db_settings['driver']
return init_funcs[driver](db_settings)
def init_sqlite(settings: dict):
global DB
database = settings['database']
DB = SqliteDatabase(str(resolve_path(database)))
def init_mysql(settings: dict):
global DB
DB = MySQLDatabase(**settings)
def init_postgresql(settings: dict):
global DB
DB = PostgresqlDatabase(**settings)
init()
class DbBarData(Model):

View File

@ -24,7 +24,7 @@ from .event import (
from .gateway import BaseGateway
from .object import CancelRequest, LogData, OrderRequest, SubscribeRequest
from .setting import SETTINGS
from .utility import Singleton, get_folder_path
from .utility import get_folder_path
class MainEngine:
@ -52,6 +52,7 @@ class MainEngine:
"""
engine = engine_class(self, self.event_engine)
self.engines[engine.engine_name] = engine
return engine
def add_gateway(self, gateway_class: BaseGateway):
"""
@ -59,6 +60,7 @@ class MainEngine:
"""
gateway = gateway_class(self.event_engine)
self.gateways[gateway.gateway_name] = gateway
return gateway
def add_app(self, app_class: BaseApp):
"""
@ -67,7 +69,8 @@ class MainEngine:
app = app_class()
self.apps[app.app_name] = app
self.add_engine(app.engine_class)
engine = self.add_engine(app.engine_class)
return engine
def init_engines(self):
"""
@ -199,8 +202,6 @@ class LogEngine(BaseEngine):
Processes log event and output with logging module.
"""
__metaclass__ = Singleton
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
super(LogEngine, self).__init__(main_engine, event_engine, "log")

View File

@ -4,6 +4,7 @@
from abc import ABC, abstractmethod
from typing import Any
from copy import copy
from vnpy.event import Event, EventEngine
from .event import (
@ -227,3 +228,124 @@ class BaseGateway(ABC):
Return default setting dict.
"""
return self.default_setting
class LocalOrderManager:
"""
Management tool to support use local order id for trading.
"""
def __init__(self, gateway: BaseGateway):
""""""
self.gateway = gateway
# For generating local orderid
self.order_prefix = ""
self.order_count = 0
self.orders = {} # local_orderid:order
# Map between local and system orderid
self.local_sys_orderid_map = {}
self.sys_local_orderid_map = {}
# Push order data buf
self.push_data_buf = {} # sys_orderid:data
# Callback for processing push order data
self.push_data_callback = None
# Cancel request buf
self.cancel_request_buf = {} # local_orderid:req
def new_local_orderid(self):
"""
Generate a new local orderid.
"""
self.order_count += 1
local_orderid = str(self.order_count).rjust(8, "0")
return local_orderid
def get_local_orderid(self, sys_orderid: str):
"""
Get local orderid with sys orderid.
"""
local_orderid = self.sys_local_orderid_map.get(sys_orderid, "")
if not local_orderid:
local_orderid = self.new_local_orderid()
self.update_orderid_map(local_orderid, sys_orderid)
return local_orderid
def get_sys_orderid(self, local_orderid: str):
"""
Get sys orderid with local orderid.
"""
sys_orderid = self.local_sys_orderid_map.get(local_orderid, "")
return sys_orderid
def update_orderid_map(self, local_orderid: str, sys_orderid: str):
"""
Update orderid map.
"""
self.sys_local_orderid_map[sys_orderid] = local_orderid
self.local_sys_orderid_map[local_orderid] = sys_orderid
self.check_cancel_request(local_orderid)
self.check_push_data(sys_orderid)
def check_push_data(self, sys_orderid: str):
"""
Check if any order push data waiting.
"""
if sys_orderid not in self.push_data_buf:
return
data = self.push_data_buf.pop(sys_orderid)
if self.push_data_callback:
self.push_data_callback(data)
def add_push_data(self, sys_orderid: str, data: dict):
"""
Add push data into buf.
"""
self.push_data_buf[sys_orderid] = data
def get_order_with_sys_orderid(self, sys_orderid: str):
""""""
local_orderid = self.sys_local_orderid_map.get(sys_orderid, None)
if not local_orderid:
return None
else:
return self.get_order_with_local_orderid(local_orderid)
def get_order_with_local_orderid(self, local_orderid: str):
""""""
order = self.orders[local_orderid]
return copy(order)
def on_order(self, order: OrderData):
"""
Keep an order buf before pushing it to gateway.
"""
self.orders[order.orderid] = copy(order)
self.gateway.on_order(order)
def cancel_order(self, req: CancelRequest):
"""
"""
sys_orderid = self.get_sys_orderid(req.orderid)
if not sys_orderid:
self.cancel_request_buf[req.orderid] = req
return
self.gateway.cancel_order(req)
def check_cancel_request(self, local_orderid: str):
"""
"""
if local_orderid not in self.cancel_request_buf:
return
req = self.cancel_request_buf.pop(local_orderid)
self.gateway.cancel_order(req)

View File

@ -23,10 +23,17 @@ SETTINGS = {
"email.receiver": "",
"rqdata.username": "",
"rqdata.password": ""
"rqdata.password": "",
"database": {
"driver": "sqlite", # sqlite, mysql, postgresql
"database": "{VNPY_TEMP}/database.db", # for sqlite, use this as filepath
"host": "localhost",
"port": 3306,
"user": "root",
"password": ""
}
}
# Load global setting from json file.
SETTING_FILENAME = "vt_setting.json"
SETTINGS.update(load_json(SETTING_FILENAME))

View File

@ -24,7 +24,7 @@ from .widget import (
AboutDialog,
)
from ..engine import MainEngine
from ..utility import get_icon_path, TRADER_PATH
from ..utility import get_icon_path, TRADER_DIR
class MainWindow(QtWidgets.QMainWindow):
@ -38,7 +38,7 @@ class MainWindow(QtWidgets.QMainWindow):
self.main_engine = main_engine
self.event_engine = event_engine
self.window_title = f"VN Trader [{TRADER_PATH}]"
self.window_title = f"VN Trader [{TRADER_DIR}]"
self.connect_dialogs = {}
self.widgets = {}

View File

@ -3,6 +3,7 @@ General utility functions.
"""
import json
import os
from pathlib import Path
from typing import Callable
@ -12,25 +13,13 @@ import talib
from .object import BarData, TickData
class Singleton(type):
"""
Singleton metaclass,
class A:
__metaclass__ = Singleton
"""
_instances = {}
def __call__(cls, *args, **kwargs):
""""""
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(
*args, **kwargs
)
return cls._instances[cls]
def resolve_path(pattern: str):
env = dict(os.environ)
env.update({"VNPY_TEMP": str(TEMP_DIR)})
return pattern.format(**env)
def get_path(temp_name: str):
def _get_trader_dir(temp_name: str):
"""
Get path where trader is running in.
"""
@ -53,21 +42,21 @@ def get_path(temp_name: str):
return home_path, temp_path
TRADER_PATH, TEMP_PATH = get_path(".vntrader")
TRADER_DIR, TEMP_DIR = _get_trader_dir(".vntrader")
def get_file_path(filename: str):
"""
Get path for temp file with filename.
"""
return TEMP_PATH.joinpath(filename)
return TEMP_DIR.joinpath(filename)
def get_folder_path(folder_name: str):
"""
Get path for temp folder with folder name.
"""
folder_path = TEMP_PATH.joinpath(folder_name)
folder_path = TEMP_DIR.joinpath(folder_name)
if not folder_path.exists():
folder_path.mkdir()
return folder_path
@ -385,3 +374,12 @@ class ArrayManager(object):
if array:
return up, down
return up[-1], down[-1]
def virtual(func: "callable"):
"""
mark a function as "virtual", which means that this function can be override.
any base class should use this or @abstractmethod to decorate all functions
that can be (re)implemented by subclasses.
"""
return func