[新功能] 股票Cta App

This commit is contained in:
msincenselee 2020-04-30 00:08:28 +08:00
parent 8e45b0f4b3
commit f8b24c9be1
11 changed files with 6119 additions and 0 deletions

View File

@ -0,0 +1,34 @@
from pathlib import Path
from vnpy.trader.app import BaseApp
from .base import APP_NAME, StopOrder
from .engine import CtaEngine
from .template import (
Exchange,
Direction,
Offset,
Status,
Color,
Interval,
TickData,
BarData,
TradeData,
OrderData,
CtaPolicy,
StockPolicy,
CtaTemplate, CtaStockTemplate) # noqa
from vnpy.trader.utility import BarGenerator, ArrayManager # noqa
class CtaStockApp(BaseApp):
""""""
app_name = APP_NAME
app_module = __module__
app_path = Path(__file__).parent
display_name = "股票CTA策略"
engine_class = CtaEngine
widget_name = "CtaManager"
icon_name = "cta.ico"

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,53 @@
"""
Defines constants and objects used in CtaCrypto App.
"""
from dataclasses import dataclass, field
from enum import Enum
from datetime import timedelta
from vnpy.trader.constant import Direction, Offset, Interval
APP_NAME = "CtaStock"
STOPORDER_PREFIX = "STOP"
class StopOrderStatus(Enum):
WAITING = "等待中"
CANCELLED = "已撤销"
TRIGGERED = "已触发"
class EngineType(Enum):
LIVE = "实盘"
BACKTESTING = "回测"
class BacktestingMode(Enum):
BAR = 1
TICK = 2
@dataclass
class StopOrder:
vt_symbol: str
direction: Direction
offset: Offset
price: float
volume: float
stop_orderid: str
strategy_name: str
lock: bool = False
vt_orderids: list = field(default_factory=list)
status: StopOrderStatus = StopOrderStatus.WAITING
gateway_name: str = None
EVENT_CTA_LOG = "eCtaLog"
EVENT_CTA_STRATEGY = "eCtaStrategy"
EVENT_CTA_STOPORDER = "eCtaStopOrder"
INTERVAL_DELTA_MAP = {
Interval.MINUTE: timedelta(minutes=1),
Interval.HOUR: timedelta(hours=1),
Interval.DAILY: timedelta(days=1),
}

1717
vnpy/app/cta_stock/engine.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,484 @@
# encoding: UTF-8
'''
本文件中包含的是CTA模块的组合回测引擎回测引擎的API和CTA引擎一致
可以使用和实盘相同的代码进行回测
华富资产 李来佳
'''
from __future__ import division
import sys
import os
import gc
import pandas as pd
import traceback
import random
import bz2
import pickle
from datetime import datetime, timedelta
from time import sleep
from vnpy.trader.object import (
TickData,
BarData,
RenkoBarData,
)
from vnpy.trader.constant import (
Exchange,
)
from vnpy.trader.utility import (
get_trading_date,
extract_vt_symbol,
)
from .back_testing import BackTestingEngine
class PortfolioTestingEngine(BackTestingEngine):
"""
CTA组合回测引擎, 使用回测引擎作为父类
函数接口和策略引擎保持一样
从而实现同一套代码从回测到实盘
针对1分钟bar的回测 或者tick回测
导入CTA_Settings
"""
def __init__(self, event_engine=None):
"""Constructor"""
super().__init__(event_engine)
self.bar_csv_file = {}
self.bar_df_dict = {} # 历史数据的df回测用
self.bar_df = None # 历史数据的df时间+symbol作为组合索引
self.bar_interval_seconds = 60 # bar csv文件属于K线类型K线的周期秒数,缺省是1分钟
self.tick_path = None # tick级别回测 路径
def load_bar_csv_to_df(self, vt_symbol, bar_file, data_start_date=None, data_end_date=None):
"""
加载回测bar数据到DataFrame
1. 增加前复权/后复权
:param vt_symbol:
:param bar_file:
:param data_start_date:
:param data_end_date:
:return:
"""
self.output(u'loading {} from {}'.format(vt_symbol, bar_file))
if vt_symbol in self.bar_df_dict:
return True
if bar_file is None or not os.path.exists(bar_file):
self.write_error(u'回测时,{}对应的csv bar文件{}不存在'.format(vt_symbol, bar_file))
return False
try:
data_types = {
"datetime": str,
"open": float,
"high": float,
"low": float,
"close": float,
"open_interest": float,
"volume": float,
"instrument_id": str,
"symbol": str,
"total_turnover": float,
"limit_down": float,
"limit_up": float,
"trading_day": str,
"date": str,
"time": str
}
# 加载csv文件 =》 dateframe
symbol_df = pd.read_csv(bar_file, dtype=data_types)
# 转换时间str =》 datetime
symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S")
# 设置时间为索引
symbol_df = symbol_df.set_index("datetime")
# 裁剪数据
symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date]
# 复权转换
adj_list = self.adjust_factors.get(vt_symbol, [])
# 按照结束日期,裁剪复权记录
adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= self.test_end_date]
if adj_list:
self.write_log(f'需要对{vt_symbol}进行前复权处理')
for row in adj_list:
row.update({'dividOperateDate': row.get('dividOperateDate') + ' 09:31:00'})
# list -> dataframe, 转换复权日期格式
adj_data = pd.DataFrame(adj_list)
adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], format="%Y-%m-%d %H:%M:%S")
adj_data = adj_data.set_index("dividOperateDate")
# 调用转换方法对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权
symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore')
# 添加到待合并dataframe dict中
self.bar_df_dict.update({vt_symbol: symbol_df})
except Exception as ex:
self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex))
self.output(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex))
return False
return True
def comine_bar_df(self):
"""
合并所有回测合约的bar DataFrame =集中的DataFrame
把bar_df_dict =bar_df
:return:
"""
self.output('comine_df')
self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index()
self.bar_df_dict.clear()
def prepare_env(self, test_setting):
"""
回测环境准备
:param test_setting:
:return:
"""
self.output(f'准备组合回测环境')
# 调用父类回测环境
super().prepare_env(test_setting)
def prepare_data(self, data_dict):
"""
准备组合数据
:param data_dict: 合约得配置参数
:return:
"""
# 调用回测引擎,跟新合约得数据
super().prepare_data(data_dict)
if len(data_dict) == 0:
self.write_log(u'请指定回测数据和文件')
return
if self.mode == 'tick':
return
# 检查/更新需要回测的bar文件
for vt_symbol, symbol_info in data_dict.items():
self.write_log(u'配置{}数据:{}'.format(vt_symbol, symbol_info))
bar_file = symbol_info.get('bar_file', None)
if bar_file is None:
self.write_error(u'{}没有配置数据文件')
continue
if not os.path.isfile(bar_file):
self.write_log(u'{0}文件不存在'.format(bar_file))
continue
self.bar_csv_file.update({vt_symbol: bar_file})
def run_portfolio_test(self, strategy_setting: dict = {}):
"""
运行组合回测
"""
if not self.strategy_start_date:
self.write_error(u'回测开始日期未设置。')
return
if len(strategy_setting) == 0:
self.write_error('未提供有效配置策略实例')
return
self.cur_capital = self.init_capital # 更新设置期初资金
if not self.data_end_date:
self.data_end_date = datetime.today()
# 保存回测设置/策略设置/任务ID至数据库
self.save_setting_to_mongo()
self.write_log(u'开始组合回测')
for strategy_name, strategy_setting in strategy_setting.items():
self.load_strategy(strategy_name, strategy_setting)
self.write_log(u'策略初始化完成')
self.write_log(u'开始回放数据')
self.write_log(u'开始回测:{} ~ {}'.format(self.data_start_date, self.data_end_date))
if self.mode == 'bar':
self.run_bar_test()
else:
self.run_tick_test()
def run_bar_test(self):
"""使用bar进行组合回测"""
testdays = (self.data_end_date - self.data_start_date).days
if testdays < 1:
self.write_log(u'回测时间不足')
return
# 加载数据
for vt_symbol in self.symbol_strategy_map.keys():
self.load_bar_csv_to_df(vt_symbol, self.bar_csv_file.get(vt_symbol))
# 合并数据
self.comine_bar_df()
last_trading_day = None
bars_dt = None
bars_same_dt = []
gc_collect_days = 0
try:
for (dt, vt_symbol), bar_data in self.bar_df.iterrows():
symbol, exchange = extract_vt_symbol(vt_symbol)
if symbol.startswith('future_renko'):
bar_datetime = dt
bar = RenkoBarData(
gateway_name='backtesting',
symbol=symbol,
exchange=exchange,
datetime=bar_datetime
)
bar.seconds = float(bar_data.get('seconds', 0))
bar.high_seconds = float(bar_data.get('high_seconds', 0)) # 当前Bar的上限秒数
bar.low_seconds = float(bar_data.get('low_seconds', 0)) # 当前bar的下限秒数
bar.height = float(bar_data.get('height', 0)) # 当前Bar的高度限制
bar.up_band = float(bar_data.get('up_band', 0)) # 高位区域的基线
bar.down_band = float(bar_data.get('down_band', 0)) # 低位区域的基线
bar.low_time = bar_data.get('low_time', None) # 最后一次进入低位区域的时间
bar.high_time = bar_data.get('high_time', None) # 最后一次进入高位区域的时间
else:
bar_datetime = dt - timedelta(seconds=self.bar_interval_seconds)
bar = BarData(
gateway_name='backtesting',
symbol=symbol,
exchange=exchange,
datetime=bar_datetime
)
if 'open' in bar_data:
bar.open_price = float(bar_data['open'])
bar.close_price = float(bar_data['close'])
bar.high_price = float(bar_data['high'])
bar.low_price = float(bar_data['low'])
else:
bar.open_price = float(bar_data['open_price'])
bar.close_price = float(bar_data['close_price'])
bar.high_price = float(bar_data['high_price'])
bar.low_price = float(bar_data['low_price'])
bar.volume = int(bar_data['volume'])
bar.date = dt.strftime('%Y-%m-%d')
bar.time = dt.strftime('%H:%M:%S')
str_td = str(bar_data.get('trading_day', ''))
if len(str_td) == 8:
bar.trading_day = str_td[0:4] + '-' + str_td[4:6] + '-' + str_td[6:8]
else:
bar.trading_day = bar.date
if last_trading_day != bar.trading_day:
self.output(u'回测数据日期:{},资金:{}'.format(bar.trading_day, self.net_capital))
if self.strategy_start_date > bar.datetime:
last_trading_day = bar.trading_day
# bar时间与队列时间一致添加到队列中
if dt == bars_dt:
bars_same_dt.append(bar)
continue
else:
# bar时间与队列时间不一致先推送队列的bars
random.shuffle(bars_same_dt)
for _bar_ in bars_same_dt:
self.new_bar(_bar_)
# 创建新的队列
bars_same_dt = [bar]
bars_dt = dt
# 更新每日净值
if self.strategy_start_date <= dt <= self.data_end_date:
if last_trading_day != bar.trading_day:
if last_trading_day is not None:
self.saving_daily_data(datetime.strptime(last_trading_day, '%Y-%m-%d'), self.cur_capital,
self.max_net_capital, self.total_commission)
last_trading_day = bar.trading_day
# 第二个交易日,撤单
self.cancel_orders()
# 更新持仓缓存
self.update_position_yd()
gc_collect_days += 1
if gc_collect_days >= 10:
# 执行内存回收
gc.collect()
sleep(1)
gc_collect_days = 0
if self.net_capital < 0:
self.write_error(u'净值低于0回测停止')
self.output(u'净值低于0回测停止')
return
self.write_log(u'bar数据回放完成')
if last_trading_day is not None:
self.saving_daily_data(datetime.strptime(last_trading_day, '%Y-%m-%d'), self.cur_capital,
self.max_net_capital, self.total_commission)
except Exception as ex:
self.write_error(u'回测异常导致停止:{}'.format(str(ex)))
self.write_error(u'{},{}'.format(str(ex), traceback.format_exc()))
print(str(ex), file=sys.stderr)
traceback.print_exc()
return
def load_bz2_cache(self, cache_folder, cache_symbol, cache_date):
"""加载缓存bz2数据"""
if not os.path.exists(cache_folder):
self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder))
return None
cache_folder_year_month = os.path.join(cache_folder, cache_date[:6])
if not os.path.exists(cache_folder_year_month):
self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder_year_month))
return None
cache_file = os.path.join(cache_folder_year_month, '{}_{}.pkb2'.format(cache_symbol, cache_date))
if not os.path.isfile(cache_file):
cache_file = os.path.join(cache_folder_year_month, '{}_{}.pkz2'.format(cache_symbol, cache_date))
if not os.path.isfile(cache_file):
self.write_error('缓存文件:{}不存在,不能读取'.format(cache_file))
return None
with bz2.BZ2File(cache_file, 'rb') as f:
data = pickle.load(f)
return data
return None
def get_day_tick_df(self, test_day):
"""获取某一天得所有合约tick"""
tick_data_dict = {}
for vt_symbol in list(self.symbol_strategy_map.keys()):
symbol, exchange = extract_vt_symbol(vt_symbol)
tick_list = self.load_bz2_cache(cache_folder=self.tick_path,
cache_symbol=symbol,
cache_date=test_day.strftime('%Y%m%d'))
if not tick_list or len(tick_list) == 0:
continue
symbol_tick_df = pd.DataFrame(tick_list)
# 缓存文件中datetime字段已经是datetime格式
# 暂时根据时间去重没有汇总volume
symbol_tick_df.drop_duplicates(subset=['datetime'], keep='first', inplace=True)
symbol_tick_df.set_index('datetime', inplace=True)
tick_data_dict.update({vt_symbol: symbol_tick_df})
if len(tick_data_dict) == 0:
return None
tick_df = pd.concat(tick_data_dict, axis=0).swaplevel(0, 1).sort_index()
return tick_df
def run_tick_test(self):
"""运行tick级别组合回测"""
testdays = (self.data_end_date - self.data_start_date).days
if testdays < 1:
self.write_log(u'回测时间不足')
return
gc_collect_days = 0
# 循环每一天
for i in range(0, testdays):
test_day = self.data_start_date + timedelta(days=i)
combined_df = self.get_day_tick_df(test_day)
if combined_df is None:
continue
try:
for (dt, vt_symbol), tick_data in combined_df.iterrows():
symbol, exchange = extract_vt_symbol(vt_symbol)
tick = TickData(
gateway_name='backtesting',
symbol=symbol,
exchange=exchange,
datetime=dt,
date=dt.strftime('%Y-%m-%d'),
time=dt.strftime('%H:%M:%S.%f'),
trading_day=test_day.strftime('%Y-%m-%d'),
last_price=tick_data['price'],
volume=tick_data['volume']
)
self.new_tick(tick)
# 结束一个交易日后,更新每日净值
self.saving_daily_data(test_day,
self.cur_capital,
self.max_net_capital,
self.total_commission)
self.cancel_orders()
# 更新持仓缓存
self.update_position_yd()
gc_collect_days += 1
if gc_collect_days >= 10:
# 执行内存回收
gc.collect()
sleep(1)
gc_collect_days = 0
if self.net_capital < 0:
self.write_error(u'净值低于0回测停止')
self.output(u'净值低于0回测停止')
return
except Exception as ex:
self.write_error(u'回测异常导致停止:{}'.format(str(ex)))
self.write_error(u'{},{}'.format(str(ex), traceback.format_exc()))
print(str(ex), file=sys.stderr)
traceback.print_exc()
return
self.write_log(u'tick数据回放完成')
def single_test(test_setting: dict, strategy_setting: dict):
"""
单一回测
: test_setting, 组合回测所需的配置包括合约信息数据bar信息回测时间资金等
strategy_setting, dict, 一个或多个策略配置
"""
# 创建组合回测引擎
engine = PortfolioTestingEngine()
engine.prepare_env(test_setting)
try:
engine.run_portfolio_test(strategy_setting)
# 回测结果,保存
engine.show_backtesting_result()
except Exception as ex:
print('组合回测异常{}'.format(str(ex)))
traceback.print_exc()
engine.save_fail_to_mongo(f'回测异常{str(ex)}')
return False
print('测试结束')
return True

View File

@ -0,0 +1,35 @@
策略加密
#windows 下加密并运行
1.安装Visual StudioComunity 2017下载地址
https://visualstudio.microsoft.com/zh-hans/vs/older-downloads/
安装时请勾选“使用C++的桌面开发”。
2. 在Python环境中安装Cython打开cmd后输入运行pip install cython即可。
3. 在”管理员”模式的命令行窗口,在策略所在目录,运行:
cythonize -i demo_strategy.py
编译完成后Demo文件夹下会多出2个新的文件其中就有已加密的策略文件demo_strategy.cp37-win_amd64.pyd
改名=> demo_strategy.pyd
放置 demo_strategy.pyd到windows 生产环境的 strateies目录下。
#centos/ubuntu 下加密并运行
1. 在Python环境中安装Cython运行pip install cython即可。
3. 在策略所在目录,运行:
cythonize -i demo_strategy.py
编译完成后Demo文件夹下会多出2个新的文件其中就有已加密的策略文件demo_strategy.cp37-win_amd64.so
改名=> demo_strategy.so
放置 demo_strategy.so 到centos/ubuntu 生产环境的 strateies目录下。

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
from .widget import CtaManager

Binary file not shown.

After

Width:  |  Height:  |  Size: 66 KiB

View File

@ -0,0 +1,464 @@
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import QtCore, QtGui, QtWidgets
from vnpy.trader.ui.widget import (
BaseCell,
EnumCell,
MsgCell,
TimeCell,
BaseMonitor
)
from ..base import (
APP_NAME,
EVENT_CTA_LOG,
EVENT_CTA_STOPORDER,
EVENT_CTA_STRATEGY
)
from ..engine import CtaEngine
class CtaManager(QtWidgets.QWidget):
""""""
signal_log = QtCore.pyqtSignal(Event)
signal_strategy = QtCore.pyqtSignal(Event)
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
super(CtaManager, self).__init__()
self.main_engine = main_engine
self.event_engine = event_engine
self.cta_engine = main_engine.get_engine(APP_NAME)
self.managers = {}
self.init_ui()
self.register_event()
self.cta_engine.init_engine()
self.update_class_combo()
def init_ui(self):
""""""
self.setWindowTitle("CTA策略")
# Create widgets
self.class_combo = QtWidgets.QComboBox()
add_button = QtWidgets.QPushButton("添加策略")
add_button.clicked.connect(self.add_strategy)
init_button = QtWidgets.QPushButton("全部初始化")
init_button.clicked.connect(self.cta_engine.init_all_strategies)
start_button = QtWidgets.QPushButton("全部启动")
start_button.clicked.connect(self.cta_engine.start_all_strategies)
stop_button = QtWidgets.QPushButton("全部停止")
stop_button.clicked.connect(self.cta_engine.stop_all_strategies)
clear_button = QtWidgets.QPushButton("清空日志")
clear_button.clicked.connect(self.clear_log)
self.scroll_layout = QtWidgets.QVBoxLayout()
self.scroll_layout.addStretch()
scroll_widget = QtWidgets.QWidget()
scroll_widget.setLayout(self.scroll_layout)
scroll_area = QtWidgets.QScrollArea()
scroll_area.setWidgetResizable(True)
scroll_area.setWidget(scroll_widget)
self.log_monitor = LogMonitor(self.main_engine, self.event_engine)
self.stop_order_monitor = StopOrderMonitor(
self.main_engine, self.event_engine
)
# Set layout
hbox1 = QtWidgets.QHBoxLayout()
hbox1.addWidget(self.class_combo)
hbox1.addWidget(add_button)
hbox1.addStretch()
hbox1.addWidget(init_button)
hbox1.addWidget(start_button)
hbox1.addWidget(stop_button)
hbox1.addWidget(clear_button)
grid = QtWidgets.QGridLayout()
grid.addWidget(scroll_area, 0, 0, 2, 1)
grid.addWidget(self.stop_order_monitor, 0, 1)
grid.addWidget(self.log_monitor, 1, 1)
vbox = QtWidgets.QVBoxLayout()
vbox.addLayout(hbox1)
vbox.addLayout(grid)
self.setLayout(vbox)
def update_class_combo(self):
""""""
self.class_combo.addItems(
self.cta_engine.get_all_strategy_class_names()
)
def register_event(self):
""""""
self.signal_strategy.connect(self.process_strategy_event)
self.event_engine.register(
EVENT_CTA_STRATEGY, self.signal_strategy.emit
)
def process_strategy_event(self, event):
"""
Update strategy status onto its monitor.
"""
data = event.data
strategy_name = data["strategy_name"]
if strategy_name in self.managers:
manager = self.managers[strategy_name]
manager.update_data(data)
else:
manager = StrategyManager(self, self.cta_engine, data)
self.scroll_layout.insertWidget(0, manager)
self.managers[strategy_name] = manager
def remove_strategy(self, strategy_name):
""""""
manager = self.managers.pop(strategy_name)
manager.deleteLater()
def add_strategy(self):
""""""
class_name = str(self.class_combo.currentText())
if not class_name:
return
parameters = self.cta_engine.get_strategy_class_parameters(class_name)
editor = SettingEditor(parameters, class_name=class_name)
n = editor.exec_()
if n == editor.Accepted:
setting = editor.get_setting()
vt_symbol = setting.pop("vt_symbol")
strategy_name = setting.pop("strategy_name")
auto_init = setting.pop("auto_init", False)
auto_start = setting.pop("auto_start", False)
self.cta_engine.add_strategy(
class_name, strategy_name, vt_symbol, setting, auto_init, auto_start
)
def clear_log(self):
""""""
self.log_monitor.setRowCount(0)
def show(self):
""""""
self.showMaximized()
class StrategyManager(QtWidgets.QFrame):
"""
Manager for a strategy
"""
def __init__(
self, cta_manager: CtaManager, cta_engine: CtaEngine, data: dict
):
""""""
super(StrategyManager, self).__init__()
self.cta_manager = cta_manager
self.cta_engine = cta_engine
self.strategy_name = data["strategy_name"]
self._data = data
self.init_ui()
def init_ui(self):
""""""
self.setFixedHeight(300)
self.setFrameShape(self.Box)
self.setLineWidth(1)
init_button = QtWidgets.QPushButton("初始化")
init_button.clicked.connect(self.init_strategy)
start_button = QtWidgets.QPushButton("启动")
start_button.clicked.connect(self.start_strategy)
stop_button = QtWidgets.QPushButton("停止")
stop_button.clicked.connect(self.stop_strategy)
edit_button = QtWidgets.QPushButton("编辑")
edit_button.clicked.connect(self.edit_strategy)
remove_button = QtWidgets.QPushButton("移除")
remove_button.clicked.connect(self.remove_strategy)
reload_button = QtWidgets.QPushButton("重载")
reload_button.clicked.connect(self.reload_strategy)
save_button = QtWidgets.QPushButton("保存")
save_button.clicked.connect(self.save_strategy)
strategy_name = self._data["strategy_name"]
vt_symbol = self._data["vt_symbol"]
class_name = self._data["class_name"]
author = self._data["author"]
label_text = (
f"{strategy_name} - {vt_symbol} ({class_name} by {author})"
)
label = QtWidgets.QLabel(label_text)
label.setAlignment(QtCore.Qt.AlignCenter)
self.parameters_monitor = DataMonitor(self._data["parameters"])
self.variables_monitor = DataMonitor(self._data["variables"])
hbox = QtWidgets.QHBoxLayout()
hbox.addWidget(init_button)
hbox.addWidget(start_button)
hbox.addWidget(stop_button)
hbox.addWidget(edit_button)
hbox.addWidget(remove_button)
hbox.addWidget(reload_button)
hbox.addWidget(save_button)
vbox = QtWidgets.QVBoxLayout()
vbox.addWidget(label)
vbox.addLayout(hbox)
vbox.addWidget(self.parameters_monitor)
vbox.addWidget(self.variables_monitor)
self.setLayout(vbox)
def update_data(self, data: dict):
""""""
self._data = data
self.parameters_monitor.update_data(data["parameters"])
self.variables_monitor.update_data(data["variables"])
def init_strategy(self):
""""""
self.cta_engine.init_strategy(self.strategy_name)
def start_strategy(self):
""""""
self.cta_engine.start_strategy(self.strategy_name)
def stop_strategy(self):
""""""
self.cta_engine.stop_strategy(self.strategy_name)
def edit_strategy(self):
""""""
strategy_name = self._data["strategy_name"]
parameters = self.cta_engine.get_strategy_parameters(strategy_name)
editor = SettingEditor(parameters, strategy_name=strategy_name)
n = editor.exec_()
if n == editor.Accepted:
setting = editor.get_setting()
self.cta_engine.edit_strategy(strategy_name, setting)
def remove_strategy(self):
""""""
result = self.cta_engine.remove_strategy(self.strategy_name)
# Only remove strategy gui manager if it has been removed from engine
if result:
self.cta_manager.remove_strategy(self.strategy_name)
def reload_strategy(self):
"""重新加载策略"""
self.cta_engine.reload_strategy(self.strategy_name)
def save_strategy(self):
self.cta_engine.save_strategy_data(self.strategy_name)
class DataMonitor(QtWidgets.QTableWidget):
"""
Table monitor for parameters and variables.
"""
def __init__(self, data: dict):
""""""
super(DataMonitor, self).__init__()
self._data = data
self.cells = {}
self.init_ui()
def init_ui(self):
""""""
labels = list(self._data.keys())
self.setColumnCount(len(labels))
self.setHorizontalHeaderLabels(labels)
self.setRowCount(1)
self.verticalHeader().setSectionResizeMode(
QtWidgets.QHeaderView.Stretch
)
self.verticalHeader().setVisible(False)
self.setEditTriggers(self.NoEditTriggers)
for column, name in enumerate(self._data.keys()):
value = self._data[name]
cell = QtWidgets.QTableWidgetItem(str(value))
cell.setTextAlignment(QtCore.Qt.AlignCenter)
self.setItem(0, column, cell)
self.cells[name] = cell
def update_data(self, data: dict):
""""""
for name, value in data.items():
cell = self.cells[name]
cell.setText(str(value))
class StopOrderMonitor(BaseMonitor):
"""
Monitor for local stop order.
"""
event_type = EVENT_CTA_STOPORDER
data_key = "stop_orderid"
sorting = True
headers = {
"stop_orderid": {
"display": "停止委托号",
"cell": BaseCell,
"update": False,
},
"vt_orderids": {"display": "限价委托号", "cell": BaseCell, "update": True},
"vt_symbol": {"display": "本地代码", "cell": BaseCell, "update": False},
"direction": {"display": "方向", "cell": EnumCell, "update": False},
"offset": {"display": "开平", "cell": EnumCell, "update": False},
"price": {"display": "价格", "cell": BaseCell, "update": False},
"volume": {"display": "数量", "cell": BaseCell, "update": False},
"status": {"display": "状态", "cell": EnumCell, "update": True},
"lock": {"display": "锁仓", "cell": BaseCell, "update": False},
"strategy_name": {"display": "策略名", "cell": BaseCell, "update": False},
}
class LogMonitor(BaseMonitor):
"""
Monitor for log data.
"""
event_type = EVENT_CTA_LOG
data_key = ""
sorting = False
headers = {
"time": {"display": "时间", "cell": TimeCell, "update": False},
"msg": {"display": "信息", "cell": MsgCell, "update": False},
}
def init_ui(self):
"""
Stretch last column.
"""
super(LogMonitor, self).init_ui()
self.horizontalHeader().setSectionResizeMode(
1, QtWidgets.QHeaderView.Stretch
)
def insert_new_row(self, data):
"""
Insert a new row at the top of table.
"""
super(LogMonitor, self).insert_new_row(data)
self.resizeRowToContents(0)
class SettingEditor(QtWidgets.QDialog):
"""
For creating new strategy and editing strategy parameters.
"""
def __init__(
self, parameters: dict, strategy_name: str = "", class_name: str = ""
):
""""""
super(SettingEditor, self).__init__()
self.parameters = parameters
self.strategy_name = strategy_name
self.class_name = class_name
self.edits = {}
self.init_ui()
def init_ui(self):
""""""
form = QtWidgets.QFormLayout()
# Add vt_symbol and name edit if add new strategy
if self.class_name:
self.setWindowTitle(f"添加策略:{self.class_name}")
button_text = "添加"
parameters = {"strategy_name": "", "vt_symbol": "", "auto_init": True, "auto_start": True}
parameters.update(self.parameters)
else:
self.setWindowTitle(f"参数编辑:{self.strategy_name}")
button_text = "确定"
parameters = self.parameters
for name, value in parameters.items():
type_ = type(value)
edit = QtWidgets.QLineEdit(str(value))
if type_ is int:
validator = QtGui.QIntValidator()
edit.setValidator(validator)
elif type_ is float:
validator = QtGui.QDoubleValidator()
edit.setValidator(validator)
form.addRow(f"{name} {type_}", edit)
self.edits[name] = (edit, type_)
button = QtWidgets.QPushButton(button_text)
button.clicked.connect(self.accept)
form.addRow(button)
self.setLayout(form)
def get_setting(self):
""""""
setting = {}
if self.class_name:
setting["class_name"] = self.class_name
for name, tp in self.edits.items():
edit, type_ = tp
value_text = edit.text()
if type_ == bool:
if value_text == "True":
value = True
else:
value = False
else:
value = type_(value_text)
setting[name] = value
return setting