[新功能] 股票Cta App
This commit is contained in:
parent
8e45b0f4b3
commit
f8b24c9be1
34
vnpy/app/cta_stock/__init__.py
Normal file
34
vnpy/app/cta_stock/__init__.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from vnpy.trader.app import BaseApp
|
||||||
|
from .base import APP_NAME, StopOrder
|
||||||
|
|
||||||
|
from .engine import CtaEngine
|
||||||
|
|
||||||
|
from .template import (
|
||||||
|
Exchange,
|
||||||
|
Direction,
|
||||||
|
Offset,
|
||||||
|
Status,
|
||||||
|
Color,
|
||||||
|
Interval,
|
||||||
|
TickData,
|
||||||
|
BarData,
|
||||||
|
TradeData,
|
||||||
|
OrderData,
|
||||||
|
CtaPolicy,
|
||||||
|
StockPolicy,
|
||||||
|
CtaTemplate, CtaStockTemplate) # noqa
|
||||||
|
|
||||||
|
from vnpy.trader.utility import BarGenerator, ArrayManager # noqa
|
||||||
|
|
||||||
|
|
||||||
|
class CtaStockApp(BaseApp):
|
||||||
|
""""""
|
||||||
|
app_name = APP_NAME
|
||||||
|
app_module = __module__
|
||||||
|
app_path = Path(__file__).parent
|
||||||
|
display_name = "股票CTA策略"
|
||||||
|
engine_class = CtaEngine
|
||||||
|
widget_name = "CtaManager"
|
||||||
|
icon_name = "cta.ico"
|
2070
vnpy/app/cta_stock/back_testing.py
Normal file
2070
vnpy/app/cta_stock/back_testing.py
Normal file
File diff suppressed because it is too large
Load Diff
53
vnpy/app/cta_stock/base.py
Normal file
53
vnpy/app/cta_stock/base.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
"""
|
||||||
|
Defines constants and objects used in CtaCrypto App.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from datetime import timedelta
|
||||||
|
from vnpy.trader.constant import Direction, Offset, Interval
|
||||||
|
|
||||||
|
APP_NAME = "CtaStock"
|
||||||
|
STOPORDER_PREFIX = "STOP"
|
||||||
|
|
||||||
|
|
||||||
|
class StopOrderStatus(Enum):
|
||||||
|
WAITING = "等待中"
|
||||||
|
CANCELLED = "已撤销"
|
||||||
|
TRIGGERED = "已触发"
|
||||||
|
|
||||||
|
|
||||||
|
class EngineType(Enum):
|
||||||
|
LIVE = "实盘"
|
||||||
|
BACKTESTING = "回测"
|
||||||
|
|
||||||
|
|
||||||
|
class BacktestingMode(Enum):
|
||||||
|
BAR = 1
|
||||||
|
TICK = 2
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StopOrder:
|
||||||
|
vt_symbol: str
|
||||||
|
direction: Direction
|
||||||
|
offset: Offset
|
||||||
|
price: float
|
||||||
|
volume: float
|
||||||
|
stop_orderid: str
|
||||||
|
strategy_name: str
|
||||||
|
lock: bool = False
|
||||||
|
vt_orderids: list = field(default_factory=list)
|
||||||
|
status: StopOrderStatus = StopOrderStatus.WAITING
|
||||||
|
gateway_name: str = None
|
||||||
|
|
||||||
|
|
||||||
|
EVENT_CTA_LOG = "eCtaLog"
|
||||||
|
EVENT_CTA_STRATEGY = "eCtaStrategy"
|
||||||
|
EVENT_CTA_STOPORDER = "eCtaStopOrder"
|
||||||
|
|
||||||
|
INTERVAL_DELTA_MAP = {
|
||||||
|
Interval.MINUTE: timedelta(minutes=1),
|
||||||
|
Interval.HOUR: timedelta(hours=1),
|
||||||
|
Interval.DAILY: timedelta(days=1),
|
||||||
|
}
|
1717
vnpy/app/cta_stock/engine.py
Normal file
1717
vnpy/app/cta_stock/engine.py
Normal file
File diff suppressed because it is too large
Load Diff
484
vnpy/app/cta_stock/portfolio_testing.py
Normal file
484
vnpy/app/cta_stock/portfolio_testing.py
Normal file
@ -0,0 +1,484 @@
|
|||||||
|
# encoding: UTF-8
|
||||||
|
|
||||||
|
'''
|
||||||
|
本文件中包含的是CTA模块的组合回测引擎,回测引擎的API和CTA引擎一致,
|
||||||
|
可以使用和实盘相同的代码进行回测。
|
||||||
|
华富资产 李来佳
|
||||||
|
'''
|
||||||
|
from __future__ import division
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import gc
|
||||||
|
import pandas as pd
|
||||||
|
import traceback
|
||||||
|
import random
|
||||||
|
import bz2
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
from vnpy.trader.object import (
|
||||||
|
TickData,
|
||||||
|
BarData,
|
||||||
|
RenkoBarData,
|
||||||
|
)
|
||||||
|
from vnpy.trader.constant import (
|
||||||
|
Exchange,
|
||||||
|
)
|
||||||
|
|
||||||
|
from vnpy.trader.utility import (
|
||||||
|
get_trading_date,
|
||||||
|
extract_vt_symbol,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .back_testing import BackTestingEngine
|
||||||
|
|
||||||
|
|
||||||
|
class PortfolioTestingEngine(BackTestingEngine):
|
||||||
|
"""
|
||||||
|
CTA组合回测引擎, 使用回测引擎作为父类
|
||||||
|
函数接口和策略引擎保持一样,
|
||||||
|
从而实现同一套代码从回测到实盘。
|
||||||
|
针对1分钟bar的回测 或者tick回测
|
||||||
|
导入CTA_Settings
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, event_engine=None):
|
||||||
|
"""Constructor"""
|
||||||
|
super().__init__(event_engine)
|
||||||
|
|
||||||
|
self.bar_csv_file = {}
|
||||||
|
self.bar_df_dict = {} # 历史数据的df,回测用
|
||||||
|
self.bar_df = None # 历史数据的df,时间+symbol作为组合索引
|
||||||
|
self.bar_interval_seconds = 60 # bar csv文件,属于K线类型,K线的周期(秒数),缺省是1分钟
|
||||||
|
|
||||||
|
self.tick_path = None # tick级别回测, 路径
|
||||||
|
|
||||||
|
def load_bar_csv_to_df(self, vt_symbol, bar_file, data_start_date=None, data_end_date=None):
|
||||||
|
"""
|
||||||
|
加载回测bar数据到DataFrame
|
||||||
|
1. 增加前复权/后复权
|
||||||
|
:param vt_symbol:
|
||||||
|
:param bar_file:
|
||||||
|
:param data_start_date:
|
||||||
|
:param data_end_date:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
self.output(u'loading {} from {}'.format(vt_symbol, bar_file))
|
||||||
|
if vt_symbol in self.bar_df_dict:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if bar_file is None or not os.path.exists(bar_file):
|
||||||
|
self.write_error(u'回测时,{}对应的csv bar文件{}不存在'.format(vt_symbol, bar_file))
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
data_types = {
|
||||||
|
"datetime": str,
|
||||||
|
"open": float,
|
||||||
|
"high": float,
|
||||||
|
"low": float,
|
||||||
|
"close": float,
|
||||||
|
"open_interest": float,
|
||||||
|
"volume": float,
|
||||||
|
"instrument_id": str,
|
||||||
|
"symbol": str,
|
||||||
|
"total_turnover": float,
|
||||||
|
"limit_down": float,
|
||||||
|
"limit_up": float,
|
||||||
|
"trading_day": str,
|
||||||
|
"date": str,
|
||||||
|
"time": str
|
||||||
|
}
|
||||||
|
# 加载csv文件 =》 dateframe
|
||||||
|
symbol_df = pd.read_csv(bar_file, dtype=data_types)
|
||||||
|
# 转换时间,str =》 datetime
|
||||||
|
symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S")
|
||||||
|
# 设置时间为索引
|
||||||
|
symbol_df = symbol_df.set_index("datetime")
|
||||||
|
|
||||||
|
# 裁剪数据
|
||||||
|
symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date]
|
||||||
|
|
||||||
|
# 复权转换
|
||||||
|
adj_list = self.adjust_factors.get(vt_symbol, [])
|
||||||
|
# 按照结束日期,裁剪复权记录
|
||||||
|
adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= self.test_end_date]
|
||||||
|
|
||||||
|
if adj_list:
|
||||||
|
self.write_log(f'需要对{vt_symbol}进行前复权处理')
|
||||||
|
for row in adj_list:
|
||||||
|
row.update({'dividOperateDate': row.get('dividOperateDate') + ' 09:31:00'})
|
||||||
|
# list -> dataframe, 转换复权日期格式
|
||||||
|
adj_data = pd.DataFrame(adj_list)
|
||||||
|
adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], format="%Y-%m-%d %H:%M:%S")
|
||||||
|
adj_data = adj_data.set_index("dividOperateDate")
|
||||||
|
# 调用转换方法,对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权
|
||||||
|
symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore')
|
||||||
|
|
||||||
|
# 添加到待合并dataframe dict中
|
||||||
|
self.bar_df_dict.update({vt_symbol: symbol_df})
|
||||||
|
|
||||||
|
except Exception as ex:
|
||||||
|
self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex))
|
||||||
|
self.output(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex))
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def comine_bar_df(self):
|
||||||
|
"""
|
||||||
|
合并所有回测合约的bar DataFrame =》集中的DataFrame
|
||||||
|
把bar_df_dict =》bar_df
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
self.output('comine_df')
|
||||||
|
self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index()
|
||||||
|
self.bar_df_dict.clear()
|
||||||
|
|
||||||
|
def prepare_env(self, test_setting):
|
||||||
|
"""
|
||||||
|
回测环境准备
|
||||||
|
:param test_setting:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
self.output(f'准备组合回测环境')
|
||||||
|
|
||||||
|
# 调用父类回测环境
|
||||||
|
super().prepare_env(test_setting)
|
||||||
|
|
||||||
|
def prepare_data(self, data_dict):
|
||||||
|
"""
|
||||||
|
准备组合数据
|
||||||
|
:param data_dict: 合约得配置参数
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# 调用回测引擎,跟新合约得数据
|
||||||
|
super().prepare_data(data_dict)
|
||||||
|
|
||||||
|
if len(data_dict) == 0:
|
||||||
|
self.write_log(u'请指定回测数据和文件')
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.mode == 'tick':
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查/更新需要回测的bar文件
|
||||||
|
for vt_symbol, symbol_info in data_dict.items():
|
||||||
|
self.write_log(u'配置{}数据:{}'.format(vt_symbol, symbol_info))
|
||||||
|
|
||||||
|
bar_file = symbol_info.get('bar_file', None)
|
||||||
|
|
||||||
|
if bar_file is None:
|
||||||
|
self.write_error(u'{}没有配置数据文件')
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not os.path.isfile(bar_file):
|
||||||
|
self.write_log(u'{0}文件不存在'.format(bar_file))
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.bar_csv_file.update({vt_symbol: bar_file})
|
||||||
|
|
||||||
|
def run_portfolio_test(self, strategy_setting: dict = {}):
|
||||||
|
"""
|
||||||
|
运行组合回测
|
||||||
|
"""
|
||||||
|
if not self.strategy_start_date:
|
||||||
|
self.write_error(u'回测开始日期未设置。')
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(strategy_setting) == 0:
|
||||||
|
self.write_error('未提供有效配置策略实例')
|
||||||
|
return
|
||||||
|
|
||||||
|
self.cur_capital = self.init_capital # 更新设置期初资金
|
||||||
|
if not self.data_end_date:
|
||||||
|
self.data_end_date = datetime.today()
|
||||||
|
|
||||||
|
# 保存回测设置/策略设置/任务ID至数据库
|
||||||
|
self.save_setting_to_mongo()
|
||||||
|
|
||||||
|
self.write_log(u'开始组合回测')
|
||||||
|
|
||||||
|
for strategy_name, strategy_setting in strategy_setting.items():
|
||||||
|
self.load_strategy(strategy_name, strategy_setting)
|
||||||
|
|
||||||
|
self.write_log(u'策略初始化完成')
|
||||||
|
|
||||||
|
self.write_log(u'开始回放数据')
|
||||||
|
|
||||||
|
self.write_log(u'开始回测:{} ~ {}'.format(self.data_start_date, self.data_end_date))
|
||||||
|
|
||||||
|
if self.mode == 'bar':
|
||||||
|
self.run_bar_test()
|
||||||
|
else:
|
||||||
|
self.run_tick_test()
|
||||||
|
|
||||||
|
def run_bar_test(self):
|
||||||
|
"""使用bar进行组合回测"""
|
||||||
|
testdays = (self.data_end_date - self.data_start_date).days
|
||||||
|
|
||||||
|
if testdays < 1:
|
||||||
|
self.write_log(u'回测时间不足')
|
||||||
|
return
|
||||||
|
|
||||||
|
# 加载数据
|
||||||
|
for vt_symbol in self.symbol_strategy_map.keys():
|
||||||
|
self.load_bar_csv_to_df(vt_symbol, self.bar_csv_file.get(vt_symbol))
|
||||||
|
|
||||||
|
# 合并数据
|
||||||
|
self.comine_bar_df()
|
||||||
|
|
||||||
|
last_trading_day = None
|
||||||
|
bars_dt = None
|
||||||
|
bars_same_dt = []
|
||||||
|
|
||||||
|
gc_collect_days = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
for (dt, vt_symbol), bar_data in self.bar_df.iterrows():
|
||||||
|
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||||
|
if symbol.startswith('future_renko'):
|
||||||
|
bar_datetime = dt
|
||||||
|
bar = RenkoBarData(
|
||||||
|
gateway_name='backtesting',
|
||||||
|
symbol=symbol,
|
||||||
|
exchange=exchange,
|
||||||
|
datetime=bar_datetime
|
||||||
|
)
|
||||||
|
bar.seconds = float(bar_data.get('seconds', 0))
|
||||||
|
bar.high_seconds = float(bar_data.get('high_seconds', 0)) # 当前Bar的上限秒数
|
||||||
|
bar.low_seconds = float(bar_data.get('low_seconds', 0)) # 当前bar的下限秒数
|
||||||
|
bar.height = float(bar_data.get('height', 0)) # 当前Bar的高度限制
|
||||||
|
bar.up_band = float(bar_data.get('up_band', 0)) # 高位区域的基线
|
||||||
|
bar.down_band = float(bar_data.get('down_band', 0)) # 低位区域的基线
|
||||||
|
bar.low_time = bar_data.get('low_time', None) # 最后一次进入低位区域的时间
|
||||||
|
bar.high_time = bar_data.get('high_time', None) # 最后一次进入高位区域的时间
|
||||||
|
else:
|
||||||
|
bar_datetime = dt - timedelta(seconds=self.bar_interval_seconds)
|
||||||
|
|
||||||
|
bar = BarData(
|
||||||
|
gateway_name='backtesting',
|
||||||
|
symbol=symbol,
|
||||||
|
exchange=exchange,
|
||||||
|
datetime=bar_datetime
|
||||||
|
)
|
||||||
|
if 'open' in bar_data:
|
||||||
|
bar.open_price = float(bar_data['open'])
|
||||||
|
bar.close_price = float(bar_data['close'])
|
||||||
|
bar.high_price = float(bar_data['high'])
|
||||||
|
bar.low_price = float(bar_data['low'])
|
||||||
|
else:
|
||||||
|
bar.open_price = float(bar_data['open_price'])
|
||||||
|
bar.close_price = float(bar_data['close_price'])
|
||||||
|
bar.high_price = float(bar_data['high_price'])
|
||||||
|
bar.low_price = float(bar_data['low_price'])
|
||||||
|
|
||||||
|
bar.volume = int(bar_data['volume'])
|
||||||
|
bar.date = dt.strftime('%Y-%m-%d')
|
||||||
|
bar.time = dt.strftime('%H:%M:%S')
|
||||||
|
str_td = str(bar_data.get('trading_day', ''))
|
||||||
|
if len(str_td) == 8:
|
||||||
|
bar.trading_day = str_td[0:4] + '-' + str_td[4:6] + '-' + str_td[6:8]
|
||||||
|
else:
|
||||||
|
bar.trading_day = bar.date
|
||||||
|
|
||||||
|
if last_trading_day != bar.trading_day:
|
||||||
|
self.output(u'回测数据日期:{},资金:{}'.format(bar.trading_day, self.net_capital))
|
||||||
|
if self.strategy_start_date > bar.datetime:
|
||||||
|
last_trading_day = bar.trading_day
|
||||||
|
|
||||||
|
# bar时间与队列时间一致,添加到队列中
|
||||||
|
if dt == bars_dt:
|
||||||
|
bars_same_dt.append(bar)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# bar时间与队列时间不一致,先推送队列的bars
|
||||||
|
random.shuffle(bars_same_dt)
|
||||||
|
for _bar_ in bars_same_dt:
|
||||||
|
self.new_bar(_bar_)
|
||||||
|
|
||||||
|
# 创建新的队列
|
||||||
|
bars_same_dt = [bar]
|
||||||
|
bars_dt = dt
|
||||||
|
|
||||||
|
# 更新每日净值
|
||||||
|
if self.strategy_start_date <= dt <= self.data_end_date:
|
||||||
|
if last_trading_day != bar.trading_day:
|
||||||
|
if last_trading_day is not None:
|
||||||
|
self.saving_daily_data(datetime.strptime(last_trading_day, '%Y-%m-%d'), self.cur_capital,
|
||||||
|
self.max_net_capital, self.total_commission)
|
||||||
|
last_trading_day = bar.trading_day
|
||||||
|
|
||||||
|
# 第二个交易日,撤单
|
||||||
|
self.cancel_orders()
|
||||||
|
# 更新持仓缓存
|
||||||
|
self.update_position_yd()
|
||||||
|
|
||||||
|
gc_collect_days += 1
|
||||||
|
if gc_collect_days >= 10:
|
||||||
|
# 执行内存回收
|
||||||
|
gc.collect()
|
||||||
|
sleep(1)
|
||||||
|
gc_collect_days = 0
|
||||||
|
|
||||||
|
if self.net_capital < 0:
|
||||||
|
self.write_error(u'净值低于0,回测停止')
|
||||||
|
self.output(u'净值低于0,回测停止')
|
||||||
|
return
|
||||||
|
|
||||||
|
self.write_log(u'bar数据回放完成')
|
||||||
|
if last_trading_day is not None:
|
||||||
|
self.saving_daily_data(datetime.strptime(last_trading_day, '%Y-%m-%d'), self.cur_capital,
|
||||||
|
self.max_net_capital, self.total_commission)
|
||||||
|
except Exception as ex:
|
||||||
|
self.write_error(u'回测异常导致停止:{}'.format(str(ex)))
|
||||||
|
self.write_error(u'{},{}'.format(str(ex), traceback.format_exc()))
|
||||||
|
print(str(ex), file=sys.stderr)
|
||||||
|
traceback.print_exc()
|
||||||
|
return
|
||||||
|
|
||||||
|
def load_bz2_cache(self, cache_folder, cache_symbol, cache_date):
|
||||||
|
"""加载缓存bz2数据"""
|
||||||
|
if not os.path.exists(cache_folder):
|
||||||
|
self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder))
|
||||||
|
return None
|
||||||
|
cache_folder_year_month = os.path.join(cache_folder, cache_date[:6])
|
||||||
|
if not os.path.exists(cache_folder_year_month):
|
||||||
|
self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder_year_month))
|
||||||
|
return None
|
||||||
|
|
||||||
|
cache_file = os.path.join(cache_folder_year_month, '{}_{}.pkb2'.format(cache_symbol, cache_date))
|
||||||
|
if not os.path.isfile(cache_file):
|
||||||
|
cache_file = os.path.join(cache_folder_year_month, '{}_{}.pkz2'.format(cache_symbol, cache_date))
|
||||||
|
if not os.path.isfile(cache_file):
|
||||||
|
self.write_error('缓存文件:{}不存在,不能读取'.format(cache_file))
|
||||||
|
return None
|
||||||
|
|
||||||
|
with bz2.BZ2File(cache_file, 'rb') as f:
|
||||||
|
data = pickle.load(f)
|
||||||
|
return data
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_day_tick_df(self, test_day):
|
||||||
|
"""获取某一天得所有合约tick"""
|
||||||
|
tick_data_dict = {}
|
||||||
|
|
||||||
|
for vt_symbol in list(self.symbol_strategy_map.keys()):
|
||||||
|
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||||
|
tick_list = self.load_bz2_cache(cache_folder=self.tick_path,
|
||||||
|
cache_symbol=symbol,
|
||||||
|
cache_date=test_day.strftime('%Y%m%d'))
|
||||||
|
if not tick_list or len(tick_list) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
symbol_tick_df = pd.DataFrame(tick_list)
|
||||||
|
# 缓存文件中,datetime字段,已经是datetime格式
|
||||||
|
# 暂时根据时间去重,没有汇总volume
|
||||||
|
symbol_tick_df.drop_duplicates(subset=['datetime'], keep='first', inplace=True)
|
||||||
|
symbol_tick_df.set_index('datetime', inplace=True)
|
||||||
|
|
||||||
|
tick_data_dict.update({vt_symbol: symbol_tick_df})
|
||||||
|
|
||||||
|
if len(tick_data_dict) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tick_df = pd.concat(tick_data_dict, axis=0).swaplevel(0, 1).sort_index()
|
||||||
|
|
||||||
|
return tick_df
|
||||||
|
|
||||||
|
def run_tick_test(self):
|
||||||
|
"""运行tick级别组合回测"""
|
||||||
|
testdays = (self.data_end_date - self.data_start_date).days
|
||||||
|
|
||||||
|
if testdays < 1:
|
||||||
|
self.write_log(u'回测时间不足')
|
||||||
|
return
|
||||||
|
|
||||||
|
gc_collect_days = 0
|
||||||
|
|
||||||
|
# 循环每一天
|
||||||
|
for i in range(0, testdays):
|
||||||
|
test_day = self.data_start_date + timedelta(days=i)
|
||||||
|
|
||||||
|
combined_df = self.get_day_tick_df(test_day)
|
||||||
|
|
||||||
|
if combined_df is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
for (dt, vt_symbol), tick_data in combined_df.iterrows():
|
||||||
|
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||||
|
tick = TickData(
|
||||||
|
gateway_name='backtesting',
|
||||||
|
symbol=symbol,
|
||||||
|
exchange=exchange,
|
||||||
|
datetime=dt,
|
||||||
|
date=dt.strftime('%Y-%m-%d'),
|
||||||
|
time=dt.strftime('%H:%M:%S.%f'),
|
||||||
|
trading_day=test_day.strftime('%Y-%m-%d'),
|
||||||
|
last_price=tick_data['price'],
|
||||||
|
volume=tick_data['volume']
|
||||||
|
)
|
||||||
|
|
||||||
|
self.new_tick(tick)
|
||||||
|
|
||||||
|
# 结束一个交易日后,更新每日净值
|
||||||
|
self.saving_daily_data(test_day,
|
||||||
|
self.cur_capital,
|
||||||
|
self.max_net_capital,
|
||||||
|
self.total_commission)
|
||||||
|
|
||||||
|
self.cancel_orders()
|
||||||
|
# 更新持仓缓存
|
||||||
|
self.update_position_yd()
|
||||||
|
|
||||||
|
gc_collect_days += 1
|
||||||
|
if gc_collect_days >= 10:
|
||||||
|
# 执行内存回收
|
||||||
|
gc.collect()
|
||||||
|
sleep(1)
|
||||||
|
gc_collect_days = 0
|
||||||
|
|
||||||
|
if self.net_capital < 0:
|
||||||
|
self.write_error(u'净值低于0,回测停止')
|
||||||
|
self.output(u'净值低于0,回测停止')
|
||||||
|
return
|
||||||
|
|
||||||
|
except Exception as ex:
|
||||||
|
self.write_error(u'回测异常导致停止:{}'.format(str(ex)))
|
||||||
|
self.write_error(u'{},{}'.format(str(ex), traceback.format_exc()))
|
||||||
|
print(str(ex), file=sys.stderr)
|
||||||
|
traceback.print_exc()
|
||||||
|
return
|
||||||
|
|
||||||
|
self.write_log(u'tick数据回放完成')
|
||||||
|
|
||||||
|
|
||||||
|
def single_test(test_setting: dict, strategy_setting: dict):
|
||||||
|
"""
|
||||||
|
单一回测
|
||||||
|
: test_setting, 组合回测所需的配置,包括合约信息,数据bar信息,回测时间,资金等。
|
||||||
|
:strategy_setting, dict, 一个或多个策略配置
|
||||||
|
"""
|
||||||
|
# 创建组合回测引擎
|
||||||
|
engine = PortfolioTestingEngine()
|
||||||
|
|
||||||
|
engine.prepare_env(test_setting)
|
||||||
|
try:
|
||||||
|
engine.run_portfolio_test(strategy_setting)
|
||||||
|
# 回测结果,保存
|
||||||
|
engine.show_backtesting_result()
|
||||||
|
|
||||||
|
except Exception as ex:
|
||||||
|
print('组合回测异常{}'.format(str(ex)))
|
||||||
|
traceback.print_exc()
|
||||||
|
engine.save_fail_to_mongo(f'回测异常{str(ex)}')
|
||||||
|
return False
|
||||||
|
|
||||||
|
print('测试结束')
|
||||||
|
return True
|
0
vnpy/app/cta_stock/strategies/__init__.py
Normal file
0
vnpy/app/cta_stock/strategies/__init__.py
Normal file
35
vnpy/app/cta_stock/strategies/readme.md
Normal file
35
vnpy/app/cta_stock/strategies/readme.md
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
策略加密
|
||||||
|
|
||||||
|
#windows 下加密并运行
|
||||||
|
|
||||||
|
1.安装Visual StudioComunity 2017,下载地址:
|
||||||
|
|
||||||
|
https://visualstudio.microsoft.com/zh-hans/vs/older-downloads/
|
||||||
|
安装时请勾选“使用C++的桌面开发”。
|
||||||
|
|
||||||
|
2. 在Python环境中安装Cython,打开cmd后输入运行pip install cython即可。
|
||||||
|
|
||||||
|
3. 在”管理员”模式的命令行窗口,在策略所在目录,运行:
|
||||||
|
|
||||||
|
cythonize -i demo_strategy.py
|
||||||
|
|
||||||
|
编译完成后,Demo文件夹下会多出2个新的文件,其中就有已加密的策略文件demo_strategy.cp37-win_amd64.pyd
|
||||||
|
|
||||||
|
改名=> demo_strategy.pyd
|
||||||
|
|
||||||
|
放置 demo_strategy.pyd到windows 生产环境的 strateies目录下。
|
||||||
|
|
||||||
|
#centos/ubuntu 下加密并运行
|
||||||
|
|
||||||
|
|
||||||
|
1. 在Python环境中安装Cython,运行pip install cython即可。
|
||||||
|
|
||||||
|
3. 在策略所在目录,运行:
|
||||||
|
|
||||||
|
cythonize -i demo_strategy.py
|
||||||
|
|
||||||
|
编译完成后,Demo文件夹下会多出2个新的文件,其中就有已加密的策略文件demo_strategy.cp37-win_amd64.so
|
||||||
|
|
||||||
|
改名=> demo_strategy.so
|
||||||
|
|
||||||
|
放置 demo_strategy.so 到centos/ubuntu 生产环境的 strateies目录下。
|
1261
vnpy/app/cta_stock/template.py
Normal file
1261
vnpy/app/cta_stock/template.py
Normal file
File diff suppressed because it is too large
Load Diff
1
vnpy/app/cta_stock/ui/__init__.py
Normal file
1
vnpy/app/cta_stock/ui/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .widget import CtaManager
|
BIN
vnpy/app/cta_stock/ui/cta.ico
Normal file
BIN
vnpy/app/cta_stock/ui/cta.ico
Normal file
Binary file not shown.
After Width: | Height: | Size: 66 KiB |
464
vnpy/app/cta_stock/ui/widget.py
Normal file
464
vnpy/app/cta_stock/ui/widget.py
Normal file
@ -0,0 +1,464 @@
|
|||||||
|
from vnpy.event import Event, EventEngine
|
||||||
|
from vnpy.trader.engine import MainEngine
|
||||||
|
from vnpy.trader.ui import QtCore, QtGui, QtWidgets
|
||||||
|
from vnpy.trader.ui.widget import (
|
||||||
|
BaseCell,
|
||||||
|
EnumCell,
|
||||||
|
MsgCell,
|
||||||
|
TimeCell,
|
||||||
|
BaseMonitor
|
||||||
|
)
|
||||||
|
from ..base import (
|
||||||
|
APP_NAME,
|
||||||
|
EVENT_CTA_LOG,
|
||||||
|
EVENT_CTA_STOPORDER,
|
||||||
|
EVENT_CTA_STRATEGY
|
||||||
|
)
|
||||||
|
from ..engine import CtaEngine
|
||||||
|
|
||||||
|
|
||||||
|
class CtaManager(QtWidgets.QWidget):
|
||||||
|
""""""
|
||||||
|
|
||||||
|
signal_log = QtCore.pyqtSignal(Event)
|
||||||
|
signal_strategy = QtCore.pyqtSignal(Event)
|
||||||
|
|
||||||
|
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||||
|
super(CtaManager, self).__init__()
|
||||||
|
|
||||||
|
self.main_engine = main_engine
|
||||||
|
self.event_engine = event_engine
|
||||||
|
self.cta_engine = main_engine.get_engine(APP_NAME)
|
||||||
|
|
||||||
|
self.managers = {}
|
||||||
|
|
||||||
|
self.init_ui()
|
||||||
|
self.register_event()
|
||||||
|
self.cta_engine.init_engine()
|
||||||
|
self.update_class_combo()
|
||||||
|
|
||||||
|
def init_ui(self):
|
||||||
|
""""""
|
||||||
|
self.setWindowTitle("CTA策略")
|
||||||
|
|
||||||
|
# Create widgets
|
||||||
|
self.class_combo = QtWidgets.QComboBox()
|
||||||
|
|
||||||
|
add_button = QtWidgets.QPushButton("添加策略")
|
||||||
|
add_button.clicked.connect(self.add_strategy)
|
||||||
|
|
||||||
|
init_button = QtWidgets.QPushButton("全部初始化")
|
||||||
|
init_button.clicked.connect(self.cta_engine.init_all_strategies)
|
||||||
|
|
||||||
|
start_button = QtWidgets.QPushButton("全部启动")
|
||||||
|
start_button.clicked.connect(self.cta_engine.start_all_strategies)
|
||||||
|
|
||||||
|
stop_button = QtWidgets.QPushButton("全部停止")
|
||||||
|
stop_button.clicked.connect(self.cta_engine.stop_all_strategies)
|
||||||
|
|
||||||
|
clear_button = QtWidgets.QPushButton("清空日志")
|
||||||
|
clear_button.clicked.connect(self.clear_log)
|
||||||
|
|
||||||
|
self.scroll_layout = QtWidgets.QVBoxLayout()
|
||||||
|
self.scroll_layout.addStretch()
|
||||||
|
|
||||||
|
scroll_widget = QtWidgets.QWidget()
|
||||||
|
scroll_widget.setLayout(self.scroll_layout)
|
||||||
|
|
||||||
|
scroll_area = QtWidgets.QScrollArea()
|
||||||
|
scroll_area.setWidgetResizable(True)
|
||||||
|
scroll_area.setWidget(scroll_widget)
|
||||||
|
|
||||||
|
self.log_monitor = LogMonitor(self.main_engine, self.event_engine)
|
||||||
|
|
||||||
|
self.stop_order_monitor = StopOrderMonitor(
|
||||||
|
self.main_engine, self.event_engine
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set layout
|
||||||
|
hbox1 = QtWidgets.QHBoxLayout()
|
||||||
|
hbox1.addWidget(self.class_combo)
|
||||||
|
hbox1.addWidget(add_button)
|
||||||
|
hbox1.addStretch()
|
||||||
|
hbox1.addWidget(init_button)
|
||||||
|
hbox1.addWidget(start_button)
|
||||||
|
hbox1.addWidget(stop_button)
|
||||||
|
hbox1.addWidget(clear_button)
|
||||||
|
|
||||||
|
grid = QtWidgets.QGridLayout()
|
||||||
|
grid.addWidget(scroll_area, 0, 0, 2, 1)
|
||||||
|
grid.addWidget(self.stop_order_monitor, 0, 1)
|
||||||
|
grid.addWidget(self.log_monitor, 1, 1)
|
||||||
|
|
||||||
|
vbox = QtWidgets.QVBoxLayout()
|
||||||
|
vbox.addLayout(hbox1)
|
||||||
|
vbox.addLayout(grid)
|
||||||
|
|
||||||
|
self.setLayout(vbox)
|
||||||
|
|
||||||
|
def update_class_combo(self):
|
||||||
|
""""""
|
||||||
|
self.class_combo.addItems(
|
||||||
|
self.cta_engine.get_all_strategy_class_names()
|
||||||
|
)
|
||||||
|
|
||||||
|
def register_event(self):
|
||||||
|
""""""
|
||||||
|
self.signal_strategy.connect(self.process_strategy_event)
|
||||||
|
|
||||||
|
self.event_engine.register(
|
||||||
|
EVENT_CTA_STRATEGY, self.signal_strategy.emit
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_strategy_event(self, event):
|
||||||
|
"""
|
||||||
|
Update strategy status onto its monitor.
|
||||||
|
"""
|
||||||
|
data = event.data
|
||||||
|
strategy_name = data["strategy_name"]
|
||||||
|
|
||||||
|
if strategy_name in self.managers:
|
||||||
|
manager = self.managers[strategy_name]
|
||||||
|
manager.update_data(data)
|
||||||
|
else:
|
||||||
|
manager = StrategyManager(self, self.cta_engine, data)
|
||||||
|
self.scroll_layout.insertWidget(0, manager)
|
||||||
|
self.managers[strategy_name] = manager
|
||||||
|
|
||||||
|
def remove_strategy(self, strategy_name):
|
||||||
|
""""""
|
||||||
|
manager = self.managers.pop(strategy_name)
|
||||||
|
manager.deleteLater()
|
||||||
|
|
||||||
|
def add_strategy(self):
|
||||||
|
""""""
|
||||||
|
class_name = str(self.class_combo.currentText())
|
||||||
|
if not class_name:
|
||||||
|
return
|
||||||
|
|
||||||
|
parameters = self.cta_engine.get_strategy_class_parameters(class_name)
|
||||||
|
editor = SettingEditor(parameters, class_name=class_name)
|
||||||
|
n = editor.exec_()
|
||||||
|
|
||||||
|
if n == editor.Accepted:
|
||||||
|
setting = editor.get_setting()
|
||||||
|
vt_symbol = setting.pop("vt_symbol")
|
||||||
|
strategy_name = setting.pop("strategy_name")
|
||||||
|
auto_init = setting.pop("auto_init", False)
|
||||||
|
auto_start = setting.pop("auto_start", False)
|
||||||
|
self.cta_engine.add_strategy(
|
||||||
|
class_name, strategy_name, vt_symbol, setting, auto_init, auto_start
|
||||||
|
)
|
||||||
|
|
||||||
|
def clear_log(self):
|
||||||
|
""""""
|
||||||
|
self.log_monitor.setRowCount(0)
|
||||||
|
|
||||||
|
def show(self):
|
||||||
|
""""""
|
||||||
|
self.showMaximized()
|
||||||
|
|
||||||
|
|
||||||
|
class StrategyManager(QtWidgets.QFrame):
|
||||||
|
"""
|
||||||
|
Manager for a strategy
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, cta_manager: CtaManager, cta_engine: CtaEngine, data: dict
|
||||||
|
):
|
||||||
|
""""""
|
||||||
|
super(StrategyManager, self).__init__()
|
||||||
|
|
||||||
|
self.cta_manager = cta_manager
|
||||||
|
self.cta_engine = cta_engine
|
||||||
|
|
||||||
|
self.strategy_name = data["strategy_name"]
|
||||||
|
self._data = data
|
||||||
|
|
||||||
|
self.init_ui()
|
||||||
|
|
||||||
|
def init_ui(self):
|
||||||
|
""""""
|
||||||
|
self.setFixedHeight(300)
|
||||||
|
self.setFrameShape(self.Box)
|
||||||
|
self.setLineWidth(1)
|
||||||
|
|
||||||
|
init_button = QtWidgets.QPushButton("初始化")
|
||||||
|
init_button.clicked.connect(self.init_strategy)
|
||||||
|
|
||||||
|
start_button = QtWidgets.QPushButton("启动")
|
||||||
|
start_button.clicked.connect(self.start_strategy)
|
||||||
|
|
||||||
|
stop_button = QtWidgets.QPushButton("停止")
|
||||||
|
stop_button.clicked.connect(self.stop_strategy)
|
||||||
|
|
||||||
|
edit_button = QtWidgets.QPushButton("编辑")
|
||||||
|
edit_button.clicked.connect(self.edit_strategy)
|
||||||
|
|
||||||
|
remove_button = QtWidgets.QPushButton("移除")
|
||||||
|
remove_button.clicked.connect(self.remove_strategy)
|
||||||
|
|
||||||
|
reload_button = QtWidgets.QPushButton("重载")
|
||||||
|
reload_button.clicked.connect(self.reload_strategy)
|
||||||
|
|
||||||
|
save_button = QtWidgets.QPushButton("保存")
|
||||||
|
save_button.clicked.connect(self.save_strategy)
|
||||||
|
|
||||||
|
strategy_name = self._data["strategy_name"]
|
||||||
|
vt_symbol = self._data["vt_symbol"]
|
||||||
|
class_name = self._data["class_name"]
|
||||||
|
author = self._data["author"]
|
||||||
|
|
||||||
|
label_text = (
|
||||||
|
f"{strategy_name} - {vt_symbol} ({class_name} by {author})"
|
||||||
|
)
|
||||||
|
label = QtWidgets.QLabel(label_text)
|
||||||
|
label.setAlignment(QtCore.Qt.AlignCenter)
|
||||||
|
|
||||||
|
self.parameters_monitor = DataMonitor(self._data["parameters"])
|
||||||
|
self.variables_monitor = DataMonitor(self._data["variables"])
|
||||||
|
|
||||||
|
hbox = QtWidgets.QHBoxLayout()
|
||||||
|
hbox.addWidget(init_button)
|
||||||
|
hbox.addWidget(start_button)
|
||||||
|
hbox.addWidget(stop_button)
|
||||||
|
hbox.addWidget(edit_button)
|
||||||
|
hbox.addWidget(remove_button)
|
||||||
|
hbox.addWidget(reload_button)
|
||||||
|
hbox.addWidget(save_button)
|
||||||
|
|
||||||
|
vbox = QtWidgets.QVBoxLayout()
|
||||||
|
vbox.addWidget(label)
|
||||||
|
vbox.addLayout(hbox)
|
||||||
|
vbox.addWidget(self.parameters_monitor)
|
||||||
|
vbox.addWidget(self.variables_monitor)
|
||||||
|
self.setLayout(vbox)
|
||||||
|
|
||||||
|
def update_data(self, data: dict):
|
||||||
|
""""""
|
||||||
|
self._data = data
|
||||||
|
|
||||||
|
self.parameters_monitor.update_data(data["parameters"])
|
||||||
|
self.variables_monitor.update_data(data["variables"])
|
||||||
|
|
||||||
|
def init_strategy(self):
|
||||||
|
""""""
|
||||||
|
self.cta_engine.init_strategy(self.strategy_name)
|
||||||
|
|
||||||
|
def start_strategy(self):
|
||||||
|
""""""
|
||||||
|
self.cta_engine.start_strategy(self.strategy_name)
|
||||||
|
|
||||||
|
def stop_strategy(self):
|
||||||
|
""""""
|
||||||
|
self.cta_engine.stop_strategy(self.strategy_name)
|
||||||
|
|
||||||
|
def edit_strategy(self):
|
||||||
|
""""""
|
||||||
|
strategy_name = self._data["strategy_name"]
|
||||||
|
|
||||||
|
parameters = self.cta_engine.get_strategy_parameters(strategy_name)
|
||||||
|
editor = SettingEditor(parameters, strategy_name=strategy_name)
|
||||||
|
n = editor.exec_()
|
||||||
|
|
||||||
|
if n == editor.Accepted:
|
||||||
|
setting = editor.get_setting()
|
||||||
|
self.cta_engine.edit_strategy(strategy_name, setting)
|
||||||
|
|
||||||
|
def remove_strategy(self):
|
||||||
|
""""""
|
||||||
|
result = self.cta_engine.remove_strategy(self.strategy_name)
|
||||||
|
|
||||||
|
# Only remove strategy gui manager if it has been removed from engine
|
||||||
|
if result:
|
||||||
|
self.cta_manager.remove_strategy(self.strategy_name)
|
||||||
|
|
||||||
|
def reload_strategy(self):
|
||||||
|
"""重新加载策略"""
|
||||||
|
self.cta_engine.reload_strategy(self.strategy_name)
|
||||||
|
|
||||||
|
def save_strategy(self):
|
||||||
|
self.cta_engine.save_strategy_data(self.strategy_name)
|
||||||
|
|
||||||
|
|
||||||
|
class DataMonitor(QtWidgets.QTableWidget):
|
||||||
|
"""
|
||||||
|
Table monitor for parameters and variables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data: dict):
|
||||||
|
""""""
|
||||||
|
super(DataMonitor, self).__init__()
|
||||||
|
|
||||||
|
self._data = data
|
||||||
|
self.cells = {}
|
||||||
|
|
||||||
|
self.init_ui()
|
||||||
|
|
||||||
|
def init_ui(self):
|
||||||
|
""""""
|
||||||
|
labels = list(self._data.keys())
|
||||||
|
self.setColumnCount(len(labels))
|
||||||
|
self.setHorizontalHeaderLabels(labels)
|
||||||
|
|
||||||
|
self.setRowCount(1)
|
||||||
|
self.verticalHeader().setSectionResizeMode(
|
||||||
|
QtWidgets.QHeaderView.Stretch
|
||||||
|
)
|
||||||
|
self.verticalHeader().setVisible(False)
|
||||||
|
self.setEditTriggers(self.NoEditTriggers)
|
||||||
|
|
||||||
|
for column, name in enumerate(self._data.keys()):
|
||||||
|
value = self._data[name]
|
||||||
|
|
||||||
|
cell = QtWidgets.QTableWidgetItem(str(value))
|
||||||
|
cell.setTextAlignment(QtCore.Qt.AlignCenter)
|
||||||
|
|
||||||
|
self.setItem(0, column, cell)
|
||||||
|
self.cells[name] = cell
|
||||||
|
|
||||||
|
def update_data(self, data: dict):
|
||||||
|
""""""
|
||||||
|
for name, value in data.items():
|
||||||
|
cell = self.cells[name]
|
||||||
|
cell.setText(str(value))
|
||||||
|
|
||||||
|
|
||||||
|
class StopOrderMonitor(BaseMonitor):
|
||||||
|
"""
|
||||||
|
Monitor for local stop order.
|
||||||
|
"""
|
||||||
|
|
||||||
|
event_type = EVENT_CTA_STOPORDER
|
||||||
|
data_key = "stop_orderid"
|
||||||
|
sorting = True
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"stop_orderid": {
|
||||||
|
"display": "停止委托号",
|
||||||
|
"cell": BaseCell,
|
||||||
|
"update": False,
|
||||||
|
},
|
||||||
|
"vt_orderids": {"display": "限价委托号", "cell": BaseCell, "update": True},
|
||||||
|
"vt_symbol": {"display": "本地代码", "cell": BaseCell, "update": False},
|
||||||
|
"direction": {"display": "方向", "cell": EnumCell, "update": False},
|
||||||
|
"offset": {"display": "开平", "cell": EnumCell, "update": False},
|
||||||
|
"price": {"display": "价格", "cell": BaseCell, "update": False},
|
||||||
|
"volume": {"display": "数量", "cell": BaseCell, "update": False},
|
||||||
|
"status": {"display": "状态", "cell": EnumCell, "update": True},
|
||||||
|
"lock": {"display": "锁仓", "cell": BaseCell, "update": False},
|
||||||
|
"strategy_name": {"display": "策略名", "cell": BaseCell, "update": False},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class LogMonitor(BaseMonitor):
|
||||||
|
"""
|
||||||
|
Monitor for log data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
event_type = EVENT_CTA_LOG
|
||||||
|
data_key = ""
|
||||||
|
sorting = False
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"time": {"display": "时间", "cell": TimeCell, "update": False},
|
||||||
|
"msg": {"display": "信息", "cell": MsgCell, "update": False},
|
||||||
|
}
|
||||||
|
|
||||||
|
def init_ui(self):
|
||||||
|
"""
|
||||||
|
Stretch last column.
|
||||||
|
"""
|
||||||
|
super(LogMonitor, self).init_ui()
|
||||||
|
|
||||||
|
self.horizontalHeader().setSectionResizeMode(
|
||||||
|
1, QtWidgets.QHeaderView.Stretch
|
||||||
|
)
|
||||||
|
|
||||||
|
def insert_new_row(self, data):
|
||||||
|
"""
|
||||||
|
Insert a new row at the top of table.
|
||||||
|
"""
|
||||||
|
super(LogMonitor, self).insert_new_row(data)
|
||||||
|
self.resizeRowToContents(0)
|
||||||
|
|
||||||
|
|
||||||
|
class SettingEditor(QtWidgets.QDialog):
|
||||||
|
"""
|
||||||
|
For creating new strategy and editing strategy parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, parameters: dict, strategy_name: str = "", class_name: str = ""
|
||||||
|
):
|
||||||
|
""""""
|
||||||
|
super(SettingEditor, self).__init__()
|
||||||
|
|
||||||
|
self.parameters = parameters
|
||||||
|
self.strategy_name = strategy_name
|
||||||
|
self.class_name = class_name
|
||||||
|
|
||||||
|
self.edits = {}
|
||||||
|
|
||||||
|
self.init_ui()
|
||||||
|
|
||||||
|
def init_ui(self):
|
||||||
|
""""""
|
||||||
|
form = QtWidgets.QFormLayout()
|
||||||
|
|
||||||
|
# Add vt_symbol and name edit if add new strategy
|
||||||
|
if self.class_name:
|
||||||
|
self.setWindowTitle(f"添加策略:{self.class_name}")
|
||||||
|
button_text = "添加"
|
||||||
|
parameters = {"strategy_name": "", "vt_symbol": "", "auto_init": True, "auto_start": True}
|
||||||
|
parameters.update(self.parameters)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.setWindowTitle(f"参数编辑:{self.strategy_name}")
|
||||||
|
button_text = "确定"
|
||||||
|
parameters = self.parameters
|
||||||
|
|
||||||
|
for name, value in parameters.items():
|
||||||
|
type_ = type(value)
|
||||||
|
|
||||||
|
edit = QtWidgets.QLineEdit(str(value))
|
||||||
|
if type_ is int:
|
||||||
|
validator = QtGui.QIntValidator()
|
||||||
|
edit.setValidator(validator)
|
||||||
|
elif type_ is float:
|
||||||
|
validator = QtGui.QDoubleValidator()
|
||||||
|
edit.setValidator(validator)
|
||||||
|
|
||||||
|
form.addRow(f"{name} {type_}", edit)
|
||||||
|
|
||||||
|
self.edits[name] = (edit, type_)
|
||||||
|
|
||||||
|
button = QtWidgets.QPushButton(button_text)
|
||||||
|
button.clicked.connect(self.accept)
|
||||||
|
form.addRow(button)
|
||||||
|
|
||||||
|
self.setLayout(form)
|
||||||
|
|
||||||
|
def get_setting(self):
|
||||||
|
""""""
|
||||||
|
setting = {}
|
||||||
|
|
||||||
|
if self.class_name:
|
||||||
|
setting["class_name"] = self.class_name
|
||||||
|
|
||||||
|
for name, tp in self.edits.items():
|
||||||
|
edit, type_ = tp
|
||||||
|
value_text = edit.text()
|
||||||
|
|
||||||
|
if type_ == bool:
|
||||||
|
if value_text == "True":
|
||||||
|
value = True
|
||||||
|
else:
|
||||||
|
value = False
|
||||||
|
else:
|
||||||
|
value = type_(value_text)
|
||||||
|
|
||||||
|
setting[name] = value
|
||||||
|
|
||||||
|
return setting
|
Loading…
Reference in New Issue
Block a user