[Add] CTA engine and template for creating strategies

This commit is contained in:
vn.py 2019-01-19 13:12:29 +08:00
parent d8dae9e680
commit 499ffd9491
9 changed files with 784 additions and 18 deletions

0
vnpy/app/__init__.py Normal file
View File

View File

@ -1,7 +1,7 @@
from pathlib import Path from pathlib import Path
from vnpy.trader.app import BaseApp from vnpy.trader.app import BaseApp
from .cta_engine import CtaEngine from .engine import CtaEngine
class CtaStrategyApp(BaseApp): class CtaStrategyApp(BaseApp):

View File

@ -0,0 +1,59 @@
"""
Defines constants and objects used in CtaStrategy App.
"""
from enum import Enum
from dataclasses import dataclass
from typing import Any
from vnpy.trader.constant import Direction, Offset
STOPORDER_PREFIX = "STOP."
class CtaOrderType(Enum):
BUY = "买开"
SELL = "买开"
SHORT = "买开"
COVER = "买开"
class StopOrderStatus(Enum):
WAITING = "等待中"
CANCELLED = "已撤销"
TRIGGERED = "已触发"
class EngineType(Enum):
LIVE = "实盘"
BACKTESTING = "回测"
@dataclass
class StopOrder:
vt_symbol: str
order_type: CtaOrderType
direction: Direction
offset: Offset
price: float
volume: float
stop_orderid: str
strategy: Any
status: StopOrderStatus = StopOrderStatus.WAITING
vt_orderid: str = ""
EVENT_CTA_LOG = 'eCtaLog'
EVENT_CTA_STRATEGY = 'eCtaStrategy'
EVENT_CTA_STOPORDER = 'eCtaStopOrder'
ORDER_CTA2VT = {
CtaOrderType.BUY: (Direction.LONG,
Offset.OPEN),
CtaOrderType.SELL: (Direction.SHORT,
Offset.CLOSE),
CtaOrderType.SHORT: (Direction.SHORT,
Offset.OPEN),
CtaOrderType.COVER: (Direction.LONG,
Offset.CLOSE),
}

View File

@ -1,13 +1,584 @@
"""""" """"""
from vnpy.event import EventEngine import os
import importlib
import traceback
import shelve
from typing import Callable, Any
from collections import defaultdict
from pathlib import Path
from vnpy.event import EventEngine, Event
from vnpy.trader.engine import BaseEngine, MainEngine from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.object import (
OrderRequest,
CancelRequest,
SubscribeRequest,
LogData,
TickData
)
from vnpy.trader.event import EVENT_TICK, EVENT_ORDER, EVENT_TRADE
from vnpy.trader.constant import Direction, Offset, Exchange, PriceType
from .template import CtaTemplate
from .base import (
STOPORDER_PREFIX,
CtaOrderType,
EngineType,
StopOrderStatus,
StopOrder,
ORDER_CTA2VT,
EVENT_CTA_LOG,
EVENT_CTA_STRATEGY,
EVENT_CTA_STOPORDER
)
class CtaEngine(BaseEngine): class CtaEngine(BaseEngine):
""""""
filename = "CtaStrategy.vt"
def __init__(self, main_engine: MainEngine, event_engine: EventEngine): def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
super(CtaEngine, super(CtaEngine,
self).__init__(main_engine, self).__init__(main_engine,
event_engine, event_engine,
"CtaStrategy") "CtaStrategy")
self._engine_type = EngineType.LIVE # live trading engine
self.setting_file = None # setting file object
self._strategy_classes = {} # class_name: stategy_class
self._strategies = {} # name: strategy
self._symbol_strategy_map = defaultdict(list) # vt_symbol: strategy list
self._orderid_strategy_map = {} # vt_orderid: strategy
self._active_orderids = defaultdict(set) # name: active orderid list
self._stop_order_count = 0 # for generating stop_orderid
self._stop_orders = {} # stop_orderid: stop_order
self.load_strategy_class()
self.load_setting()
self.register_event()
def close(self):
""""""
self.save_setting()
def register_event(self):
""""""
self.event_engine.register(EVENT_TICK, self.process_tick_event)
self.event_engine.register(EVENT_ORDER, self.process_order_event)
self.event_engine.register(EVENT_TRADE, self.process_trade_event)
def process_tick_event(self, event: Event):
""""""
tick = event.data
strategies = self._symbol_strategy_map[tick.vt_symbol]
if not strategies:
return
self.check_stop_order(tick)
for strategy in strategies:
if strategy._inited:
self.call_strategy_func(strategy, strategy.on_tick, tick)
def process_order_event(self, event: Event):
""""""
order = event.data
strategy = self._orderid_strategy_map.get(order.vt_orderid, None)
if not strategy:
return
# Remove vt_orderid if order is no longer active.
vt_orderids = self._active_orderids[strategy.name]
if order.vt_orderid in vt_orderids and not order.is_active():
vt_orderids.remove(order.vt_orderid)
self.call_strategy_func(strategy, strategy.on_order, order)
def process_trade_event(self, event: Event):
""""""
trade = event.data
strategy = self._orderid_strategy_map.get(trade.vt_orderid, None)
if not strategy:
return
if trade.direction == Direction.LONG:
strategy._pos += trade.volume
else:
strategy._pos -= trade.volume
self.call_strategy_func(strategy, strategy.on_trade, trade)
def check_stop_order(self, tick: TickData):
""""""
for stop_order in self._stop_orders.values():
if stop_order.vt_symbol != tick.vt_symbol:
continue
long_triggered = (
so.direction == Direction.LONG
and tick.last_price >= stop_order.price
)
short_triggered = (
so.direction == Direction.SHORT
and tick.last_price <= stop_order.price
)
if long_triggered or short_triggered:
strategy = stop_order.strategy
# To get excuted immediately after stop order is
# triggered, use limit price if available, otherwise
# use ask_price_5 or bid_price_5
if so.direction == Direction.LONG:
if tick.limit_up:
price = tick.limit_up
else:
price = tick.ask_price_5
else:
if tick.limit_down:
price = tick.limit_down
else:
price = tick.bid_price_5
vt_orderid = self.send_limit_order(
strategy,
stop_order.order_type,
price,
stop_order.volume
)
# Update stop order status if placed successfully
if vt_orderid:
# Remove from relation map.
self._stop_orders.pop(stop_order.stop_orderid)
vt_orderids = self._active_orderids[strategy.name]
if stop_orderid in vt_orderids:
vt_orderids.remove(stop_orderid)
# Change stop order status to cancelled and update to strategy.
stop_order.status = StopOrderStatus.TRIGGERED
stop_order.vt_orderid = vt_orderid
self.call_strategy_func(
strategy,
strategy.on_stop_order,
stop_order
)
def send_limit_order(
self,
strategy: CtaTemplate,
order_type: CtaOrderType,
price: float,
volume: float
):
"""
Send a new order.
"""
contract = self.main_engine.get_contract(strategy.vt_symbol)
if not contract:
self.write_log(f"委托失败,找不到合约:{strategy.vt_symbol}", strategy)
return ""
direction, offset = ORDER_CTA2VT[order_type]
# Create request and send order.
req = OrderRequest(
symbol=contract.symbol,
exchange=contract.exchange,
dierction=direction,
offset=offset,
price_type=PriceType.LIMIT,
price=price,
volume=volume
)
vt_orderid = self.main_engine.send_limit_order(
req,
contract.gateway_name
)
# Save relationship between orderid and strategy.
self._orderid_strategy_map[vt_orderid] = strategy
vt_orderids = self._active_orderids[strategy.name]
vt_orderids.add(vt_orderid)
return vt_orderid
def send_stop_order(
self,
strategy: CtaTemplate,
order_type: CtaOrderType,
price: float,
volume: float
):
"""
Send a new order.
"""
self._stop_order_count += 1
direction, offset = ORDER_CTA2VT[order_type]
stop_orderid = f"{STOPORDER_PREFIX}.{self._stop_order_count}"
stop_order = StopOrder(
vt_symbol=strategy.vt_symbol,
direction=direction,
offset=offset,
price=price,
volume=volume,
stop_orderid=stop_orderid,
strategy=strategy
)
self._stop_orders[stop_orderid] = stop_order
vt_orderids = self._active_orderids[strategy.name]
vt_orderids.add(stop_orderid)
self.call_strategy_func(strategy, strategy.on_stop_order, stop_order)
return stop_orderid
def cancel_limit_order(self, vt_orderid: str):
"""
Cancel existing order by vt_orderid.
"""
order = self.main_engine.get_order(vt_orderid)
if not order:
self.write_log(f"撤单失败,找不到委托{vt_orderid}", strategy)
return
req = order.create_cancel_request()
self.main_engine.cancel_limit_order(req, order.gateway_name)
def cancel_stop_order(self, stop_orderid: str):
"""
Cancel a local stop order.
"""
stop_order = self._stop_orders.get(stop_orderid, None)
if not stop_order:
return
strategy = stop_order.strategy
# Remove from relation map.
self._stop_orders.pop(stop_orderid)
vt_orderids = self._active_orderids[strategy.name]
if stop_orderid in vt_orderids:
vt_orderids.remove(stop_orderid)
# Change stop order status to cancelled and update to strategy.
stop_order.status = StopOrderStatus.CANCELLED
self.call_strategy_func(strategy, strategy.on_stop_order, stop_order)
def send_order(
self,
strategy: CtaTemplate,
order_type: CtaOrderType,
price: float,
volume: float,
stop: bool
):
"""
"""
if stop:
return self.send_stop_order(strategy, order_type, price, volume)
else:
return self.send_limit_order(strategy, order_type, price, volume)
def cancel_order(self, vt_orderid: str):
"""
"""
if vt_orderid.startswith(STOPORDER_PREFIX):
self.cancel_stop_order(vt_orderid)
else:
self.cancel_limit_order(vt_orderid)
def cancel_all(self, strategy: CtaTemplate):
"""
Cancel all active orders of a strategy.
"""
vt_orderids = self._active_orderids[strategy.name]
if not vt_orderids:
return
for vt_orderid in vt_orderids:
self.cancel_limit_order(vt_orderid)
def get_engine_type(self):
""""""
return self._engine_type
def call_strategy_func(
self,
strategy: CtaTemplate,
func: Callable,
params: Any = None
):
"""
Call function of a strategy and catch any exception raised.
"""
try:
if params:
func(params)
else:
func()
except Exception:
strategy._trading = False
strategy._inited = False
msg = f"触发异常已停止\n{traceback.format_exc()}"
self.write_log(msg, strategy)
def add_strategy(self, setting):
"""
Add a new strategy.
"""
name = setting["name"]
if name in self._strategies:
self.write_log(f"创建策略失败,存在重名{name}")
return
class_name = setting["class_name"]
strategy_class = self._strategy_classes[class_name]
strategy = strategy_class(self, setting)
self._strategies[name] = strategy
# Add vt_symbol to strategy map.
strategies = self._symbol_strategy_map[strategy.vt_symbol]
strategies.append(strategy)
# Update to setting file.
self.update_setting(setting)
self.put_strategy_event()
def init_strategy(self, name):
"""
Init a strategy.
"""
strategy = self._strategies[name]
self.call_strategy_func(strategy, strategy.on_init)
strategy._inited = True
# Subscribe market data
contract = self.main_engine.get_contract(strategy.vt_symbol)
if not contract:
self.write_log(f"行情订阅失败,找不到合约{strategy.vt_symbol}", strategy)
self.put_strategy_event()
def start_strategy(self, name):
"""
Start a strategy.
"""
strategy = self._strategies[name]
self.call_strategy_func(strategy, strategy.on_start)
strategy._trading = True
self.put_strategy_event()
def stop_strategy(self, name):
"""
Stop a strategy.
"""
strategy = self._strategies[name]
self.call_strategy_func(strategy, strategy.on_start)
strategy._trading = False
self.put_strategy_event()
def edit_strategy(self, setting):
"""
Edit parameters of a strategy.
"""
name = setting["name"]
strategy = self._strategies[name]
for name in strategy.parameters:
setattr(strategy, name, setting[name])
self.put_strategy_event(strategy)
def remove_strategy(self, name):
"""
Remove a strategy.
"""
# Remove setting
self.remove_setting(name)
# Remove from symbol strategy map
strategy = self._strategies[name]
strategies = self._symbol_strategy_map[strategy.vt_symbol]
strategies.remove(strategy)
# Remove from active orderid map
if name in self._active_orderids:
vt_orderids = self._active_orderids.pop(name)
# Remove vt_orderid strategy map
for vt_orderid in vt_orderids:
self._orderid_strategy_map.pop(vt_orderid)
# Remove from strategies
self._strategies.pop(name)
def load_strategy_class(self):
"""
Load strategy class from source code.
"""
path1 = Path(__file__).parent.joinpath("strategies")
self.load_strategy_class_from_folder(path1, __module__)
path2 = Path.cwd().joinpath("strategies")
self.load_strategy_class_from_folder(path2)
def load_strategy_class_from_folder(
self,
path: Path,
module_name: str = ""
):
"""
Load strategy class from certain folder.
"""
for dirpath, dirnames, filenames in os.walk(path):
for name in filenames:
module_name = ".".join([module_name, name.replace(".py", "")])
self.load_strategy_class_from_module(module_name)
def load_strategy_class_from_module(self, module_name: str):
"""
Load strategy class from module file.
"""
try:
module = importlib.import_module(module_name)
for name in dir(module):
value = getattr(module, name)
if isinstance(value, CtaTemplate):
self._strategy_classes[value.__name__] = value
except:
msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}"
self.write_log(msg)
def get_all_strategy_class_names(self):
"""
Return names of strategy classes loaded.
"""
return list(self._strategy_classes.keys())
def get_strategy_class_parameters(self, class_name: str):
"""
Get default parameters of a strategy.
"""
strategy_class = self._strategy_classes[class_name]
parameters = {}
for name in strategy_class.parameters:
parameters[name] = getattr(strategy_class, name)
return parameters
def init_all_strategies(self):
"""
"""
for name in self._strategies.keys():
self.init_strategy(name)
def start_all_strategies(self):
"""
"""
for name in self._strategies.keys():
self.start_strategy(name)
def stop_all_strategies(self):
"""
"""
for name in self._strategies.keys():
self.stop_strategy(name)
def load_setting(self):
"""
Load setting file.
"""
self.setting_file = shelve.open(self.filename)
for setting in list(self.setting_file.values()):
self.add_strategy(setting)
def update_setting(self, setting: dict):
"""
Update setting file.
"""
self.setting_file[new_setting["name"]] = new_setting
self.setting_file.sync()
def remove_setting(self, name: str):
"""
Update setting file.
"""
if name not in self.setting_file:
return
self.setting_file.pop(name)
self.setting_file.sync()
def save_setting(self):
"""
Save and close setting file.
"""
self.setting_file.close()
def put_stop_order_event(self, stop_order: StopOrder):
"""
Put an event to update stop order status.
"""
event = Event(EVENT_CTA_STOPORDER, stop_order)
self.event_engine.put(event)
def put_strategy_event(self, strategy: CtaTemplate):
"""
Put an event to update strategy status.
"""
parameters = {}
for name in strategy.parameters:
parameters[name] = getattr(strategy, name)
variables = {}
for name in strategy.variables:
variables[name] = getattr(strategy, name)
data = {
"name": name,
"inited": strategy._inited,
"trading": strategy._trading,
"pos": strategy._pos,
"author": strategy.author,
"vt_symbol": strategy.vt_symbol,
"parameters": parameters,
"variables": variables
}
event = Event(EVENT_CTA_STRATEGY, data)
self.event_engine.put(event)
def write_log(self, msg: str, strategy: CtaTemplate = None):
"""
Create cta engine log event.
"""
if strategy:
msg = f"{strategy.name}: {msg}"
log = LogData(msg=msg, gateway_name="CtaStrategy")
event = Event(type=EVENT_CTA_LOG, data=log)
self.event_engine.put(event)

View File

@ -2,10 +2,138 @@
from abc import ABC from abc import ABC
from vnpy.trader.engine import BaseEngine
from vnpy.trader.object import TickData, OrderData, TradeData, BarData
from .base import CtaOrderType, StopOrder
class CtaTemplate(ABC): class CtaTemplate(ABC):
"""""" """"""
def __init__(self, engine): _inited = False
_trading = False
_pos = 0
author = ""
vt_symbol = ""
parameters = []
variables = []
def __init__(self, engine: BaseEngine, setting: dict):
"""""" """"""
self.engine = engine self.engine = engine
self.vt_symbol = setting["vt_symbol"]
for name in self.parameters:
if name in setting:
setattr(self, name, setting[name])
def on_init(self):
"""
Callback when strategy is inited.
"""
pass
def on_start(self):
"""
Callback when strategy is started.
"""
pass
def on_tick(self, tick: TickData):
"""
Callback of new tick data update.
"""
pass
def on_trade(self, trade: TradeData):
"""
Callback of new trade data update.
"""
pass
def on_order(self, order: OrderData):
"""
Callback of new order data update.
"""
pass
def on_stop_order(self, stop_order: StopOrder):
"""
Callback of stop order update.
"""
pass
def on_bar(self, bar: BarData):
"""
Callback of new bar data update.
"""
pass
def buy(self, price: float, volume: float, stop: bool = False):
"""
Send buy order to open a long position.
"""
return self.send_order(CtaOrderType.BUY, price, volume, stop)
def sell(self, price: float, volume: float, stop: bool = False):
"""
Send sell order to close a long position.
"""
return self.send_order(CtaOrderType.SELL, price, volume, stop)
def short(self, price: float, volume: float, stop: bool = False):
"""
Send short order to open as short position.
"""
return self.send_order(CtaOrderType.SHORT, price, volume, stop)
def cover(self, price: float, volume: float, stop: bool = False):
"""
Send cover order to close a short position.
"""
return self.send_order(CtaOrderType.COVER, price, volume, stop)
def send_order(
self,
order_type: CtaOrderType,
price: float,
volume: float,
stop: bool = False
):
"""
Send a new order.
"""
return self.engine.send_order(self, order_type, price, volume, stop)
def cancel_order(self, vt_orderid):
"""
Cancel an existing order.
"""
self.engine.cancel_order(vt_orderid)
def cancel_all(self):
"""
Cancel all orders sent by strategy.
"""
self.engine.cancel_all(self)
def write_log(self, msg):
"""
Write a log message.
"""
self.engine.write_log(self, msg)
def get_engine_type(self):
"""
Return whether the engine is backtesting or live trading.
"""
return self.engine.get_engine_type()
def get_pos(self):
"""
Return current net position of the strategy.
"""
return self._pos

View File

@ -253,10 +253,10 @@ class OrderRequest:
Request sending to specific gateway for creating a new order. Request sending to specific gateway for creating a new order.
""" """
symbol: str symbol: str
exchange: Exchange
direction: Direction direction: Direction
price_type: str price_type: str
volume: float volume: float
exchange: Exchange
price: float = 0 price: float = 0
offset: Offset = Offset.NONE offset: Offset = Offset.NONE

View File

@ -679,7 +679,7 @@ class ConnectDialog(QtWidgets.QDialog):
self.main_engine = main_engine self.main_engine = main_engine
self.gateway_name = gateway_name self.gateway_name = gateway_name
self.file_name = f"Connect{gateway_name}.vt" self.filename = f"Connect{gateway_name}.vt"
self.widgets = {} self.widgets = {}
@ -695,7 +695,7 @@ class ConnectDialog(QtWidgets.QDialog):
) )
# Saved setting provides field data used last time. # Saved setting provides field data used last time.
loaded_setting = load_setting(self.file_name) loaded_setting = load_setting(self.filename)
# Initialize line edits and form layout based on setting. # Initialize line edits and form layout based on setting.
form = QtWidgets.QFormLayout() form = QtWidgets.QFormLayout()
@ -742,7 +742,7 @@ class ConnectDialog(QtWidgets.QDialog):
self.main_engine.connect(setting, self.gateway_name) self.main_engine.connect(setting, self.gateway_name)
save_setting(self.file_name, setting) save_setting(self.filename, setting)
self.accept() self.accept()

View File

@ -34,9 +34,9 @@ def get_trader_path():
return home_path return home_path
def get_temp_path(file_name: str): def get_temp_path(filename: str):
""" """
Get path for temp file with file_name. Get path for temp file with filename.
""" """
trader_path = get_trader_path() trader_path = get_trader_path()
temp_path = trader_path.joinpath('.vntrader') temp_path = trader_path.joinpath('.vntrader')
@ -44,35 +44,43 @@ def get_temp_path(file_name: str):
if not temp_path.exists(): if not temp_path.exists():
temp_path.mkdir() temp_path.mkdir()
return temp_path.joinpath(file_name) return temp_path.joinpath(filename)
def get_icon_path(file_path: str, ico_name: str): def get_icon_path(filepath: str, ico_name: str):
""" """
Get path for icon file with ico name. Get path for icon file with ico name.
""" """
ui_path = Path(file_path).parent ui_path = Path(filepath).parent
icon_path = ui_path.joinpath("ico", ico_name) icon_path = ui_path.joinpath("ico", ico_name)
return str(icon_path) return str(icon_path)
def load_setting(file_name: str): def load_setting(filename: str):
""" """
Load setting from shelve file in temp path. Load setting from shelve file in temp path.
""" """
file_path = get_temp_path(file_name) filepath = get_temp_path(filename)
f = shelve.open(str(file_path)) f = shelve.open(str(filepath))
setting = dict(f) setting = dict(f)
f.close() f.close()
return setting return setting
def save_setting(file_name: str, setting: dict): def save_setting(filename: str, setting: dict):
""" """
Save setting into shelve file in temp path. Save setting into shelve file in temp path.
""" """
file_path = get_temp_path(file_name) filepath = get_temp_path(filename)
f = shelve.open(str(file_path)) f = shelve.open(str(filepath))
for k, v in setting.items(): for k, v in setting.items():
f[k] = v f[k] = v
f.close() f.close()
def round_to_pricetick(price: float, pricetick: float):
"""
Round price to price tick value.
"""
rounded = round(price / pricetick, 0) * pricetick
return rounded