[新功能] 股票Cta App
This commit is contained in:
parent
8e45b0f4b3
commit
f8b24c9be1
34
vnpy/app/cta_stock/__init__.py
Normal file
34
vnpy/app/cta_stock/__init__.py
Normal file
@ -0,0 +1,34 @@
|
||||
from pathlib import Path
|
||||
|
||||
from vnpy.trader.app import BaseApp
|
||||
from .base import APP_NAME, StopOrder
|
||||
|
||||
from .engine import CtaEngine
|
||||
|
||||
from .template import (
|
||||
Exchange,
|
||||
Direction,
|
||||
Offset,
|
||||
Status,
|
||||
Color,
|
||||
Interval,
|
||||
TickData,
|
||||
BarData,
|
||||
TradeData,
|
||||
OrderData,
|
||||
CtaPolicy,
|
||||
StockPolicy,
|
||||
CtaTemplate, CtaStockTemplate) # noqa
|
||||
|
||||
from vnpy.trader.utility import BarGenerator, ArrayManager # noqa
|
||||
|
||||
|
||||
class CtaStockApp(BaseApp):
|
||||
""""""
|
||||
app_name = APP_NAME
|
||||
app_module = __module__
|
||||
app_path = Path(__file__).parent
|
||||
display_name = "股票CTA策略"
|
||||
engine_class = CtaEngine
|
||||
widget_name = "CtaManager"
|
||||
icon_name = "cta.ico"
|
2070
vnpy/app/cta_stock/back_testing.py
Normal file
2070
vnpy/app/cta_stock/back_testing.py
Normal file
File diff suppressed because it is too large
Load Diff
53
vnpy/app/cta_stock/base.py
Normal file
53
vnpy/app/cta_stock/base.py
Normal file
@ -0,0 +1,53 @@
|
||||
"""
|
||||
Defines constants and objects used in CtaCrypto App.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from datetime import timedelta
|
||||
from vnpy.trader.constant import Direction, Offset, Interval
|
||||
|
||||
APP_NAME = "CtaStock"
|
||||
STOPORDER_PREFIX = "STOP"
|
||||
|
||||
|
||||
class StopOrderStatus(Enum):
|
||||
WAITING = "等待中"
|
||||
CANCELLED = "已撤销"
|
||||
TRIGGERED = "已触发"
|
||||
|
||||
|
||||
class EngineType(Enum):
|
||||
LIVE = "实盘"
|
||||
BACKTESTING = "回测"
|
||||
|
||||
|
||||
class BacktestingMode(Enum):
|
||||
BAR = 1
|
||||
TICK = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class StopOrder:
|
||||
vt_symbol: str
|
||||
direction: Direction
|
||||
offset: Offset
|
||||
price: float
|
||||
volume: float
|
||||
stop_orderid: str
|
||||
strategy_name: str
|
||||
lock: bool = False
|
||||
vt_orderids: list = field(default_factory=list)
|
||||
status: StopOrderStatus = StopOrderStatus.WAITING
|
||||
gateway_name: str = None
|
||||
|
||||
|
||||
EVENT_CTA_LOG = "eCtaLog"
|
||||
EVENT_CTA_STRATEGY = "eCtaStrategy"
|
||||
EVENT_CTA_STOPORDER = "eCtaStopOrder"
|
||||
|
||||
INTERVAL_DELTA_MAP = {
|
||||
Interval.MINUTE: timedelta(minutes=1),
|
||||
Interval.HOUR: timedelta(hours=1),
|
||||
Interval.DAILY: timedelta(days=1),
|
||||
}
|
1717
vnpy/app/cta_stock/engine.py
Normal file
1717
vnpy/app/cta_stock/engine.py
Normal file
File diff suppressed because it is too large
Load Diff
484
vnpy/app/cta_stock/portfolio_testing.py
Normal file
484
vnpy/app/cta_stock/portfolio_testing.py
Normal file
@ -0,0 +1,484 @@
|
||||
# encoding: UTF-8
|
||||
|
||||
'''
|
||||
本文件中包含的是CTA模块的组合回测引擎,回测引擎的API和CTA引擎一致,
|
||||
可以使用和实盘相同的代码进行回测。
|
||||
华富资产 李来佳
|
||||
'''
|
||||
from __future__ import division
|
||||
|
||||
import sys
|
||||
import os
|
||||
import gc
|
||||
import pandas as pd
|
||||
import traceback
|
||||
import random
|
||||
import bz2
|
||||
import pickle
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from time import sleep
|
||||
|
||||
from vnpy.trader.object import (
|
||||
TickData,
|
||||
BarData,
|
||||
RenkoBarData,
|
||||
)
|
||||
from vnpy.trader.constant import (
|
||||
Exchange,
|
||||
)
|
||||
|
||||
from vnpy.trader.utility import (
|
||||
get_trading_date,
|
||||
extract_vt_symbol,
|
||||
)
|
||||
|
||||
from .back_testing import BackTestingEngine
|
||||
|
||||
|
||||
class PortfolioTestingEngine(BackTestingEngine):
|
||||
"""
|
||||
CTA组合回测引擎, 使用回测引擎作为父类
|
||||
函数接口和策略引擎保持一样,
|
||||
从而实现同一套代码从回测到实盘。
|
||||
针对1分钟bar的回测 或者tick回测
|
||||
导入CTA_Settings
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, event_engine=None):
|
||||
"""Constructor"""
|
||||
super().__init__(event_engine)
|
||||
|
||||
self.bar_csv_file = {}
|
||||
self.bar_df_dict = {} # 历史数据的df,回测用
|
||||
self.bar_df = None # 历史数据的df,时间+symbol作为组合索引
|
||||
self.bar_interval_seconds = 60 # bar csv文件,属于K线类型,K线的周期(秒数),缺省是1分钟
|
||||
|
||||
self.tick_path = None # tick级别回测, 路径
|
||||
|
||||
def load_bar_csv_to_df(self, vt_symbol, bar_file, data_start_date=None, data_end_date=None):
|
||||
"""
|
||||
加载回测bar数据到DataFrame
|
||||
1. 增加前复权/后复权
|
||||
:param vt_symbol:
|
||||
:param bar_file:
|
||||
:param data_start_date:
|
||||
:param data_end_date:
|
||||
:return:
|
||||
"""
|
||||
self.output(u'loading {} from {}'.format(vt_symbol, bar_file))
|
||||
if vt_symbol in self.bar_df_dict:
|
||||
return True
|
||||
|
||||
if bar_file is None or not os.path.exists(bar_file):
|
||||
self.write_error(u'回测时,{}对应的csv bar文件{}不存在'.format(vt_symbol, bar_file))
|
||||
return False
|
||||
|
||||
try:
|
||||
data_types = {
|
||||
"datetime": str,
|
||||
"open": float,
|
||||
"high": float,
|
||||
"low": float,
|
||||
"close": float,
|
||||
"open_interest": float,
|
||||
"volume": float,
|
||||
"instrument_id": str,
|
||||
"symbol": str,
|
||||
"total_turnover": float,
|
||||
"limit_down": float,
|
||||
"limit_up": float,
|
||||
"trading_day": str,
|
||||
"date": str,
|
||||
"time": str
|
||||
}
|
||||
# 加载csv文件 =》 dateframe
|
||||
symbol_df = pd.read_csv(bar_file, dtype=data_types)
|
||||
# 转换时间,str =》 datetime
|
||||
symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S")
|
||||
# 设置时间为索引
|
||||
symbol_df = symbol_df.set_index("datetime")
|
||||
|
||||
# 裁剪数据
|
||||
symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date]
|
||||
|
||||
# 复权转换
|
||||
adj_list = self.adjust_factors.get(vt_symbol, [])
|
||||
# 按照结束日期,裁剪复权记录
|
||||
adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= self.test_end_date]
|
||||
|
||||
if adj_list:
|
||||
self.write_log(f'需要对{vt_symbol}进行前复权处理')
|
||||
for row in adj_list:
|
||||
row.update({'dividOperateDate': row.get('dividOperateDate') + ' 09:31:00'})
|
||||
# list -> dataframe, 转换复权日期格式
|
||||
adj_data = pd.DataFrame(adj_list)
|
||||
adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], format="%Y-%m-%d %H:%M:%S")
|
||||
adj_data = adj_data.set_index("dividOperateDate")
|
||||
# 调用转换方法,对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权
|
||||
symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore')
|
||||
|
||||
# 添加到待合并dataframe dict中
|
||||
self.bar_df_dict.update({vt_symbol: symbol_df})
|
||||
|
||||
except Exception as ex:
|
||||
self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex))
|
||||
self.output(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex))
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def comine_bar_df(self):
|
||||
"""
|
||||
合并所有回测合约的bar DataFrame =》集中的DataFrame
|
||||
把bar_df_dict =》bar_df
|
||||
:return:
|
||||
"""
|
||||
self.output('comine_df')
|
||||
self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index()
|
||||
self.bar_df_dict.clear()
|
||||
|
||||
def prepare_env(self, test_setting):
|
||||
"""
|
||||
回测环境准备
|
||||
:param test_setting:
|
||||
:return:
|
||||
"""
|
||||
self.output(f'准备组合回测环境')
|
||||
|
||||
# 调用父类回测环境
|
||||
super().prepare_env(test_setting)
|
||||
|
||||
def prepare_data(self, data_dict):
|
||||
"""
|
||||
准备组合数据
|
||||
:param data_dict: 合约得配置参数
|
||||
:return:
|
||||
"""
|
||||
# 调用回测引擎,跟新合约得数据
|
||||
super().prepare_data(data_dict)
|
||||
|
||||
if len(data_dict) == 0:
|
||||
self.write_log(u'请指定回测数据和文件')
|
||||
return
|
||||
|
||||
if self.mode == 'tick':
|
||||
return
|
||||
|
||||
# 检查/更新需要回测的bar文件
|
||||
for vt_symbol, symbol_info in data_dict.items():
|
||||
self.write_log(u'配置{}数据:{}'.format(vt_symbol, symbol_info))
|
||||
|
||||
bar_file = symbol_info.get('bar_file', None)
|
||||
|
||||
if bar_file is None:
|
||||
self.write_error(u'{}没有配置数据文件')
|
||||
continue
|
||||
|
||||
if not os.path.isfile(bar_file):
|
||||
self.write_log(u'{0}文件不存在'.format(bar_file))
|
||||
continue
|
||||
|
||||
self.bar_csv_file.update({vt_symbol: bar_file})
|
||||
|
||||
def run_portfolio_test(self, strategy_setting: dict = {}):
|
||||
"""
|
||||
运行组合回测
|
||||
"""
|
||||
if not self.strategy_start_date:
|
||||
self.write_error(u'回测开始日期未设置。')
|
||||
return
|
||||
|
||||
if len(strategy_setting) == 0:
|
||||
self.write_error('未提供有效配置策略实例')
|
||||
return
|
||||
|
||||
self.cur_capital = self.init_capital # 更新设置期初资金
|
||||
if not self.data_end_date:
|
||||
self.data_end_date = datetime.today()
|
||||
|
||||
# 保存回测设置/策略设置/任务ID至数据库
|
||||
self.save_setting_to_mongo()
|
||||
|
||||
self.write_log(u'开始组合回测')
|
||||
|
||||
for strategy_name, strategy_setting in strategy_setting.items():
|
||||
self.load_strategy(strategy_name, strategy_setting)
|
||||
|
||||
self.write_log(u'策略初始化完成')
|
||||
|
||||
self.write_log(u'开始回放数据')
|
||||
|
||||
self.write_log(u'开始回测:{} ~ {}'.format(self.data_start_date, self.data_end_date))
|
||||
|
||||
if self.mode == 'bar':
|
||||
self.run_bar_test()
|
||||
else:
|
||||
self.run_tick_test()
|
||||
|
||||
def run_bar_test(self):
|
||||
"""使用bar进行组合回测"""
|
||||
testdays = (self.data_end_date - self.data_start_date).days
|
||||
|
||||
if testdays < 1:
|
||||
self.write_log(u'回测时间不足')
|
||||
return
|
||||
|
||||
# 加载数据
|
||||
for vt_symbol in self.symbol_strategy_map.keys():
|
||||
self.load_bar_csv_to_df(vt_symbol, self.bar_csv_file.get(vt_symbol))
|
||||
|
||||
# 合并数据
|
||||
self.comine_bar_df()
|
||||
|
||||
last_trading_day = None
|
||||
bars_dt = None
|
||||
bars_same_dt = []
|
||||
|
||||
gc_collect_days = 0
|
||||
|
||||
try:
|
||||
for (dt, vt_symbol), bar_data in self.bar_df.iterrows():
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
if symbol.startswith('future_renko'):
|
||||
bar_datetime = dt
|
||||
bar = RenkoBarData(
|
||||
gateway_name='backtesting',
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
datetime=bar_datetime
|
||||
)
|
||||
bar.seconds = float(bar_data.get('seconds', 0))
|
||||
bar.high_seconds = float(bar_data.get('high_seconds', 0)) # 当前Bar的上限秒数
|
||||
bar.low_seconds = float(bar_data.get('low_seconds', 0)) # 当前bar的下限秒数
|
||||
bar.height = float(bar_data.get('height', 0)) # 当前Bar的高度限制
|
||||
bar.up_band = float(bar_data.get('up_band', 0)) # 高位区域的基线
|
||||
bar.down_band = float(bar_data.get('down_band', 0)) # 低位区域的基线
|
||||
bar.low_time = bar_data.get('low_time', None) # 最后一次进入低位区域的时间
|
||||
bar.high_time = bar_data.get('high_time', None) # 最后一次进入高位区域的时间
|
||||
else:
|
||||
bar_datetime = dt - timedelta(seconds=self.bar_interval_seconds)
|
||||
|
||||
bar = BarData(
|
||||
gateway_name='backtesting',
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
datetime=bar_datetime
|
||||
)
|
||||
if 'open' in bar_data:
|
||||
bar.open_price = float(bar_data['open'])
|
||||
bar.close_price = float(bar_data['close'])
|
||||
bar.high_price = float(bar_data['high'])
|
||||
bar.low_price = float(bar_data['low'])
|
||||
else:
|
||||
bar.open_price = float(bar_data['open_price'])
|
||||
bar.close_price = float(bar_data['close_price'])
|
||||
bar.high_price = float(bar_data['high_price'])
|
||||
bar.low_price = float(bar_data['low_price'])
|
||||
|
||||
bar.volume = int(bar_data['volume'])
|
||||
bar.date = dt.strftime('%Y-%m-%d')
|
||||
bar.time = dt.strftime('%H:%M:%S')
|
||||
str_td = str(bar_data.get('trading_day', ''))
|
||||
if len(str_td) == 8:
|
||||
bar.trading_day = str_td[0:4] + '-' + str_td[4:6] + '-' + str_td[6:8]
|
||||
else:
|
||||
bar.trading_day = bar.date
|
||||
|
||||
if last_trading_day != bar.trading_day:
|
||||
self.output(u'回测数据日期:{},资金:{}'.format(bar.trading_day, self.net_capital))
|
||||
if self.strategy_start_date > bar.datetime:
|
||||
last_trading_day = bar.trading_day
|
||||
|
||||
# bar时间与队列时间一致,添加到队列中
|
||||
if dt == bars_dt:
|
||||
bars_same_dt.append(bar)
|
||||
continue
|
||||
else:
|
||||
# bar时间与队列时间不一致,先推送队列的bars
|
||||
random.shuffle(bars_same_dt)
|
||||
for _bar_ in bars_same_dt:
|
||||
self.new_bar(_bar_)
|
||||
|
||||
# 创建新的队列
|
||||
bars_same_dt = [bar]
|
||||
bars_dt = dt
|
||||
|
||||
# 更新每日净值
|
||||
if self.strategy_start_date <= dt <= self.data_end_date:
|
||||
if last_trading_day != bar.trading_day:
|
||||
if last_trading_day is not None:
|
||||
self.saving_daily_data(datetime.strptime(last_trading_day, '%Y-%m-%d'), self.cur_capital,
|
||||
self.max_net_capital, self.total_commission)
|
||||
last_trading_day = bar.trading_day
|
||||
|
||||
# 第二个交易日,撤单
|
||||
self.cancel_orders()
|
||||
# 更新持仓缓存
|
||||
self.update_position_yd()
|
||||
|
||||
gc_collect_days += 1
|
||||
if gc_collect_days >= 10:
|
||||
# 执行内存回收
|
||||
gc.collect()
|
||||
sleep(1)
|
||||
gc_collect_days = 0
|
||||
|
||||
if self.net_capital < 0:
|
||||
self.write_error(u'净值低于0,回测停止')
|
||||
self.output(u'净值低于0,回测停止')
|
||||
return
|
||||
|
||||
self.write_log(u'bar数据回放完成')
|
||||
if last_trading_day is not None:
|
||||
self.saving_daily_data(datetime.strptime(last_trading_day, '%Y-%m-%d'), self.cur_capital,
|
||||
self.max_net_capital, self.total_commission)
|
||||
except Exception as ex:
|
||||
self.write_error(u'回测异常导致停止:{}'.format(str(ex)))
|
||||
self.write_error(u'{},{}'.format(str(ex), traceback.format_exc()))
|
||||
print(str(ex), file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
def load_bz2_cache(self, cache_folder, cache_symbol, cache_date):
|
||||
"""加载缓存bz2数据"""
|
||||
if not os.path.exists(cache_folder):
|
||||
self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder))
|
||||
return None
|
||||
cache_folder_year_month = os.path.join(cache_folder, cache_date[:6])
|
||||
if not os.path.exists(cache_folder_year_month):
|
||||
self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder_year_month))
|
||||
return None
|
||||
|
||||
cache_file = os.path.join(cache_folder_year_month, '{}_{}.pkb2'.format(cache_symbol, cache_date))
|
||||
if not os.path.isfile(cache_file):
|
||||
cache_file = os.path.join(cache_folder_year_month, '{}_{}.pkz2'.format(cache_symbol, cache_date))
|
||||
if not os.path.isfile(cache_file):
|
||||
self.write_error('缓存文件:{}不存在,不能读取'.format(cache_file))
|
||||
return None
|
||||
|
||||
with bz2.BZ2File(cache_file, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
return data
|
||||
|
||||
return None
|
||||
|
||||
def get_day_tick_df(self, test_day):
|
||||
"""获取某一天得所有合约tick"""
|
||||
tick_data_dict = {}
|
||||
|
||||
for vt_symbol in list(self.symbol_strategy_map.keys()):
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
tick_list = self.load_bz2_cache(cache_folder=self.tick_path,
|
||||
cache_symbol=symbol,
|
||||
cache_date=test_day.strftime('%Y%m%d'))
|
||||
if not tick_list or len(tick_list) == 0:
|
||||
continue
|
||||
|
||||
symbol_tick_df = pd.DataFrame(tick_list)
|
||||
# 缓存文件中,datetime字段,已经是datetime格式
|
||||
# 暂时根据时间去重,没有汇总volume
|
||||
symbol_tick_df.drop_duplicates(subset=['datetime'], keep='first', inplace=True)
|
||||
symbol_tick_df.set_index('datetime', inplace=True)
|
||||
|
||||
tick_data_dict.update({vt_symbol: symbol_tick_df})
|
||||
|
||||
if len(tick_data_dict) == 0:
|
||||
return None
|
||||
|
||||
tick_df = pd.concat(tick_data_dict, axis=0).swaplevel(0, 1).sort_index()
|
||||
|
||||
return tick_df
|
||||
|
||||
def run_tick_test(self):
|
||||
"""运行tick级别组合回测"""
|
||||
testdays = (self.data_end_date - self.data_start_date).days
|
||||
|
||||
if testdays < 1:
|
||||
self.write_log(u'回测时间不足')
|
||||
return
|
||||
|
||||
gc_collect_days = 0
|
||||
|
||||
# 循环每一天
|
||||
for i in range(0, testdays):
|
||||
test_day = self.data_start_date + timedelta(days=i)
|
||||
|
||||
combined_df = self.get_day_tick_df(test_day)
|
||||
|
||||
if combined_df is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
for (dt, vt_symbol), tick_data in combined_df.iterrows():
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
tick = TickData(
|
||||
gateway_name='backtesting',
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
datetime=dt,
|
||||
date=dt.strftime('%Y-%m-%d'),
|
||||
time=dt.strftime('%H:%M:%S.%f'),
|
||||
trading_day=test_day.strftime('%Y-%m-%d'),
|
||||
last_price=tick_data['price'],
|
||||
volume=tick_data['volume']
|
||||
)
|
||||
|
||||
self.new_tick(tick)
|
||||
|
||||
# 结束一个交易日后,更新每日净值
|
||||
self.saving_daily_data(test_day,
|
||||
self.cur_capital,
|
||||
self.max_net_capital,
|
||||
self.total_commission)
|
||||
|
||||
self.cancel_orders()
|
||||
# 更新持仓缓存
|
||||
self.update_position_yd()
|
||||
|
||||
gc_collect_days += 1
|
||||
if gc_collect_days >= 10:
|
||||
# 执行内存回收
|
||||
gc.collect()
|
||||
sleep(1)
|
||||
gc_collect_days = 0
|
||||
|
||||
if self.net_capital < 0:
|
||||
self.write_error(u'净值低于0,回测停止')
|
||||
self.output(u'净值低于0,回测停止')
|
||||
return
|
||||
|
||||
except Exception as ex:
|
||||
self.write_error(u'回测异常导致停止:{}'.format(str(ex)))
|
||||
self.write_error(u'{},{}'.format(str(ex), traceback.format_exc()))
|
||||
print(str(ex), file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
self.write_log(u'tick数据回放完成')
|
||||
|
||||
|
||||
def single_test(test_setting: dict, strategy_setting: dict):
|
||||
"""
|
||||
单一回测
|
||||
: test_setting, 组合回测所需的配置,包括合约信息,数据bar信息,回测时间,资金等。
|
||||
:strategy_setting, dict, 一个或多个策略配置
|
||||
"""
|
||||
# 创建组合回测引擎
|
||||
engine = PortfolioTestingEngine()
|
||||
|
||||
engine.prepare_env(test_setting)
|
||||
try:
|
||||
engine.run_portfolio_test(strategy_setting)
|
||||
# 回测结果,保存
|
||||
engine.show_backtesting_result()
|
||||
|
||||
except Exception as ex:
|
||||
print('组合回测异常{}'.format(str(ex)))
|
||||
traceback.print_exc()
|
||||
engine.save_fail_to_mongo(f'回测异常{str(ex)}')
|
||||
return False
|
||||
|
||||
print('测试结束')
|
||||
return True
|
0
vnpy/app/cta_stock/strategies/__init__.py
Normal file
0
vnpy/app/cta_stock/strategies/__init__.py
Normal file
35
vnpy/app/cta_stock/strategies/readme.md
Normal file
35
vnpy/app/cta_stock/strategies/readme.md
Normal file
@ -0,0 +1,35 @@
|
||||
策略加密
|
||||
|
||||
#windows 下加密并运行
|
||||
|
||||
1.安装Visual StudioComunity 2017,下载地址:
|
||||
|
||||
https://visualstudio.microsoft.com/zh-hans/vs/older-downloads/
|
||||
安装时请勾选“使用C++的桌面开发”。
|
||||
|
||||
2. 在Python环境中安装Cython,打开cmd后输入运行pip install cython即可。
|
||||
|
||||
3. 在”管理员”模式的命令行窗口,在策略所在目录,运行:
|
||||
|
||||
cythonize -i demo_strategy.py
|
||||
|
||||
编译完成后,Demo文件夹下会多出2个新的文件,其中就有已加密的策略文件demo_strategy.cp37-win_amd64.pyd
|
||||
|
||||
改名=> demo_strategy.pyd
|
||||
|
||||
放置 demo_strategy.pyd到windows 生产环境的 strateies目录下。
|
||||
|
||||
#centos/ubuntu 下加密并运行
|
||||
|
||||
|
||||
1. 在Python环境中安装Cython,运行pip install cython即可。
|
||||
|
||||
3. 在策略所在目录,运行:
|
||||
|
||||
cythonize -i demo_strategy.py
|
||||
|
||||
编译完成后,Demo文件夹下会多出2个新的文件,其中就有已加密的策略文件demo_strategy.cp37-win_amd64.so
|
||||
|
||||
改名=> demo_strategy.so
|
||||
|
||||
放置 demo_strategy.so 到centos/ubuntu 生产环境的 strateies目录下。
|
1261
vnpy/app/cta_stock/template.py
Normal file
1261
vnpy/app/cta_stock/template.py
Normal file
File diff suppressed because it is too large
Load Diff
1
vnpy/app/cta_stock/ui/__init__.py
Normal file
1
vnpy/app/cta_stock/ui/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .widget import CtaManager
|
BIN
vnpy/app/cta_stock/ui/cta.ico
Normal file
BIN
vnpy/app/cta_stock/ui/cta.ico
Normal file
Binary file not shown.
After Width: | Height: | Size: 66 KiB |
464
vnpy/app/cta_stock/ui/widget.py
Normal file
464
vnpy/app/cta_stock/ui/widget.py
Normal file
@ -0,0 +1,464 @@
|
||||
from vnpy.event import Event, EventEngine
|
||||
from vnpy.trader.engine import MainEngine
|
||||
from vnpy.trader.ui import QtCore, QtGui, QtWidgets
|
||||
from vnpy.trader.ui.widget import (
|
||||
BaseCell,
|
||||
EnumCell,
|
||||
MsgCell,
|
||||
TimeCell,
|
||||
BaseMonitor
|
||||
)
|
||||
from ..base import (
|
||||
APP_NAME,
|
||||
EVENT_CTA_LOG,
|
||||
EVENT_CTA_STOPORDER,
|
||||
EVENT_CTA_STRATEGY
|
||||
)
|
||||
from ..engine import CtaEngine
|
||||
|
||||
|
||||
class CtaManager(QtWidgets.QWidget):
|
||||
""""""
|
||||
|
||||
signal_log = QtCore.pyqtSignal(Event)
|
||||
signal_strategy = QtCore.pyqtSignal(Event)
|
||||
|
||||
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||
super(CtaManager, self).__init__()
|
||||
|
||||
self.main_engine = main_engine
|
||||
self.event_engine = event_engine
|
||||
self.cta_engine = main_engine.get_engine(APP_NAME)
|
||||
|
||||
self.managers = {}
|
||||
|
||||
self.init_ui()
|
||||
self.register_event()
|
||||
self.cta_engine.init_engine()
|
||||
self.update_class_combo()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
self.setWindowTitle("CTA策略")
|
||||
|
||||
# Create widgets
|
||||
self.class_combo = QtWidgets.QComboBox()
|
||||
|
||||
add_button = QtWidgets.QPushButton("添加策略")
|
||||
add_button.clicked.connect(self.add_strategy)
|
||||
|
||||
init_button = QtWidgets.QPushButton("全部初始化")
|
||||
init_button.clicked.connect(self.cta_engine.init_all_strategies)
|
||||
|
||||
start_button = QtWidgets.QPushButton("全部启动")
|
||||
start_button.clicked.connect(self.cta_engine.start_all_strategies)
|
||||
|
||||
stop_button = QtWidgets.QPushButton("全部停止")
|
||||
stop_button.clicked.connect(self.cta_engine.stop_all_strategies)
|
||||
|
||||
clear_button = QtWidgets.QPushButton("清空日志")
|
||||
clear_button.clicked.connect(self.clear_log)
|
||||
|
||||
self.scroll_layout = QtWidgets.QVBoxLayout()
|
||||
self.scroll_layout.addStretch()
|
||||
|
||||
scroll_widget = QtWidgets.QWidget()
|
||||
scroll_widget.setLayout(self.scroll_layout)
|
||||
|
||||
scroll_area = QtWidgets.QScrollArea()
|
||||
scroll_area.setWidgetResizable(True)
|
||||
scroll_area.setWidget(scroll_widget)
|
||||
|
||||
self.log_monitor = LogMonitor(self.main_engine, self.event_engine)
|
||||
|
||||
self.stop_order_monitor = StopOrderMonitor(
|
||||
self.main_engine, self.event_engine
|
||||
)
|
||||
|
||||
# Set layout
|
||||
hbox1 = QtWidgets.QHBoxLayout()
|
||||
hbox1.addWidget(self.class_combo)
|
||||
hbox1.addWidget(add_button)
|
||||
hbox1.addStretch()
|
||||
hbox1.addWidget(init_button)
|
||||
hbox1.addWidget(start_button)
|
||||
hbox1.addWidget(stop_button)
|
||||
hbox1.addWidget(clear_button)
|
||||
|
||||
grid = QtWidgets.QGridLayout()
|
||||
grid.addWidget(scroll_area, 0, 0, 2, 1)
|
||||
grid.addWidget(self.stop_order_monitor, 0, 1)
|
||||
grid.addWidget(self.log_monitor, 1, 1)
|
||||
|
||||
vbox = QtWidgets.QVBoxLayout()
|
||||
vbox.addLayout(hbox1)
|
||||
vbox.addLayout(grid)
|
||||
|
||||
self.setLayout(vbox)
|
||||
|
||||
def update_class_combo(self):
|
||||
""""""
|
||||
self.class_combo.addItems(
|
||||
self.cta_engine.get_all_strategy_class_names()
|
||||
)
|
||||
|
||||
def register_event(self):
|
||||
""""""
|
||||
self.signal_strategy.connect(self.process_strategy_event)
|
||||
|
||||
self.event_engine.register(
|
||||
EVENT_CTA_STRATEGY, self.signal_strategy.emit
|
||||
)
|
||||
|
||||
def process_strategy_event(self, event):
|
||||
"""
|
||||
Update strategy status onto its monitor.
|
||||
"""
|
||||
data = event.data
|
||||
strategy_name = data["strategy_name"]
|
||||
|
||||
if strategy_name in self.managers:
|
||||
manager = self.managers[strategy_name]
|
||||
manager.update_data(data)
|
||||
else:
|
||||
manager = StrategyManager(self, self.cta_engine, data)
|
||||
self.scroll_layout.insertWidget(0, manager)
|
||||
self.managers[strategy_name] = manager
|
||||
|
||||
def remove_strategy(self, strategy_name):
|
||||
""""""
|
||||
manager = self.managers.pop(strategy_name)
|
||||
manager.deleteLater()
|
||||
|
||||
def add_strategy(self):
|
||||
""""""
|
||||
class_name = str(self.class_combo.currentText())
|
||||
if not class_name:
|
||||
return
|
||||
|
||||
parameters = self.cta_engine.get_strategy_class_parameters(class_name)
|
||||
editor = SettingEditor(parameters, class_name=class_name)
|
||||
n = editor.exec_()
|
||||
|
||||
if n == editor.Accepted:
|
||||
setting = editor.get_setting()
|
||||
vt_symbol = setting.pop("vt_symbol")
|
||||
strategy_name = setting.pop("strategy_name")
|
||||
auto_init = setting.pop("auto_init", False)
|
||||
auto_start = setting.pop("auto_start", False)
|
||||
self.cta_engine.add_strategy(
|
||||
class_name, strategy_name, vt_symbol, setting, auto_init, auto_start
|
||||
)
|
||||
|
||||
def clear_log(self):
|
||||
""""""
|
||||
self.log_monitor.setRowCount(0)
|
||||
|
||||
def show(self):
|
||||
""""""
|
||||
self.showMaximized()
|
||||
|
||||
|
||||
class StrategyManager(QtWidgets.QFrame):
|
||||
"""
|
||||
Manager for a strategy
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, cta_manager: CtaManager, cta_engine: CtaEngine, data: dict
|
||||
):
|
||||
""""""
|
||||
super(StrategyManager, self).__init__()
|
||||
|
||||
self.cta_manager = cta_manager
|
||||
self.cta_engine = cta_engine
|
||||
|
||||
self.strategy_name = data["strategy_name"]
|
||||
self._data = data
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
self.setFixedHeight(300)
|
||||
self.setFrameShape(self.Box)
|
||||
self.setLineWidth(1)
|
||||
|
||||
init_button = QtWidgets.QPushButton("初始化")
|
||||
init_button.clicked.connect(self.init_strategy)
|
||||
|
||||
start_button = QtWidgets.QPushButton("启动")
|
||||
start_button.clicked.connect(self.start_strategy)
|
||||
|
||||
stop_button = QtWidgets.QPushButton("停止")
|
||||
stop_button.clicked.connect(self.stop_strategy)
|
||||
|
||||
edit_button = QtWidgets.QPushButton("编辑")
|
||||
edit_button.clicked.connect(self.edit_strategy)
|
||||
|
||||
remove_button = QtWidgets.QPushButton("移除")
|
||||
remove_button.clicked.connect(self.remove_strategy)
|
||||
|
||||
reload_button = QtWidgets.QPushButton("重载")
|
||||
reload_button.clicked.connect(self.reload_strategy)
|
||||
|
||||
save_button = QtWidgets.QPushButton("保存")
|
||||
save_button.clicked.connect(self.save_strategy)
|
||||
|
||||
strategy_name = self._data["strategy_name"]
|
||||
vt_symbol = self._data["vt_symbol"]
|
||||
class_name = self._data["class_name"]
|
||||
author = self._data["author"]
|
||||
|
||||
label_text = (
|
||||
f"{strategy_name} - {vt_symbol} ({class_name} by {author})"
|
||||
)
|
||||
label = QtWidgets.QLabel(label_text)
|
||||
label.setAlignment(QtCore.Qt.AlignCenter)
|
||||
|
||||
self.parameters_monitor = DataMonitor(self._data["parameters"])
|
||||
self.variables_monitor = DataMonitor(self._data["variables"])
|
||||
|
||||
hbox = QtWidgets.QHBoxLayout()
|
||||
hbox.addWidget(init_button)
|
||||
hbox.addWidget(start_button)
|
||||
hbox.addWidget(stop_button)
|
||||
hbox.addWidget(edit_button)
|
||||
hbox.addWidget(remove_button)
|
||||
hbox.addWidget(reload_button)
|
||||
hbox.addWidget(save_button)
|
||||
|
||||
vbox = QtWidgets.QVBoxLayout()
|
||||
vbox.addWidget(label)
|
||||
vbox.addLayout(hbox)
|
||||
vbox.addWidget(self.parameters_monitor)
|
||||
vbox.addWidget(self.variables_monitor)
|
||||
self.setLayout(vbox)
|
||||
|
||||
def update_data(self, data: dict):
|
||||
""""""
|
||||
self._data = data
|
||||
|
||||
self.parameters_monitor.update_data(data["parameters"])
|
||||
self.variables_monitor.update_data(data["variables"])
|
||||
|
||||
def init_strategy(self):
|
||||
""""""
|
||||
self.cta_engine.init_strategy(self.strategy_name)
|
||||
|
||||
def start_strategy(self):
|
||||
""""""
|
||||
self.cta_engine.start_strategy(self.strategy_name)
|
||||
|
||||
def stop_strategy(self):
|
||||
""""""
|
||||
self.cta_engine.stop_strategy(self.strategy_name)
|
||||
|
||||
def edit_strategy(self):
|
||||
""""""
|
||||
strategy_name = self._data["strategy_name"]
|
||||
|
||||
parameters = self.cta_engine.get_strategy_parameters(strategy_name)
|
||||
editor = SettingEditor(parameters, strategy_name=strategy_name)
|
||||
n = editor.exec_()
|
||||
|
||||
if n == editor.Accepted:
|
||||
setting = editor.get_setting()
|
||||
self.cta_engine.edit_strategy(strategy_name, setting)
|
||||
|
||||
def remove_strategy(self):
|
||||
""""""
|
||||
result = self.cta_engine.remove_strategy(self.strategy_name)
|
||||
|
||||
# Only remove strategy gui manager if it has been removed from engine
|
||||
if result:
|
||||
self.cta_manager.remove_strategy(self.strategy_name)
|
||||
|
||||
def reload_strategy(self):
|
||||
"""重新加载策略"""
|
||||
self.cta_engine.reload_strategy(self.strategy_name)
|
||||
|
||||
def save_strategy(self):
|
||||
self.cta_engine.save_strategy_data(self.strategy_name)
|
||||
|
||||
|
||||
class DataMonitor(QtWidgets.QTableWidget):
|
||||
"""
|
||||
Table monitor for parameters and variables.
|
||||
"""
|
||||
|
||||
def __init__(self, data: dict):
|
||||
""""""
|
||||
super(DataMonitor, self).__init__()
|
||||
|
||||
self._data = data
|
||||
self.cells = {}
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
labels = list(self._data.keys())
|
||||
self.setColumnCount(len(labels))
|
||||
self.setHorizontalHeaderLabels(labels)
|
||||
|
||||
self.setRowCount(1)
|
||||
self.verticalHeader().setSectionResizeMode(
|
||||
QtWidgets.QHeaderView.Stretch
|
||||
)
|
||||
self.verticalHeader().setVisible(False)
|
||||
self.setEditTriggers(self.NoEditTriggers)
|
||||
|
||||
for column, name in enumerate(self._data.keys()):
|
||||
value = self._data[name]
|
||||
|
||||
cell = QtWidgets.QTableWidgetItem(str(value))
|
||||
cell.setTextAlignment(QtCore.Qt.AlignCenter)
|
||||
|
||||
self.setItem(0, column, cell)
|
||||
self.cells[name] = cell
|
||||
|
||||
def update_data(self, data: dict):
|
||||
""""""
|
||||
for name, value in data.items():
|
||||
cell = self.cells[name]
|
||||
cell.setText(str(value))
|
||||
|
||||
|
||||
class StopOrderMonitor(BaseMonitor):
|
||||
"""
|
||||
Monitor for local stop order.
|
||||
"""
|
||||
|
||||
event_type = EVENT_CTA_STOPORDER
|
||||
data_key = "stop_orderid"
|
||||
sorting = True
|
||||
|
||||
headers = {
|
||||
"stop_orderid": {
|
||||
"display": "停止委托号",
|
||||
"cell": BaseCell,
|
||||
"update": False,
|
||||
},
|
||||
"vt_orderids": {"display": "限价委托号", "cell": BaseCell, "update": True},
|
||||
"vt_symbol": {"display": "本地代码", "cell": BaseCell, "update": False},
|
||||
"direction": {"display": "方向", "cell": EnumCell, "update": False},
|
||||
"offset": {"display": "开平", "cell": EnumCell, "update": False},
|
||||
"price": {"display": "价格", "cell": BaseCell, "update": False},
|
||||
"volume": {"display": "数量", "cell": BaseCell, "update": False},
|
||||
"status": {"display": "状态", "cell": EnumCell, "update": True},
|
||||
"lock": {"display": "锁仓", "cell": BaseCell, "update": False},
|
||||
"strategy_name": {"display": "策略名", "cell": BaseCell, "update": False},
|
||||
}
|
||||
|
||||
|
||||
class LogMonitor(BaseMonitor):
|
||||
"""
|
||||
Monitor for log data.
|
||||
"""
|
||||
|
||||
event_type = EVENT_CTA_LOG
|
||||
data_key = ""
|
||||
sorting = False
|
||||
|
||||
headers = {
|
||||
"time": {"display": "时间", "cell": TimeCell, "update": False},
|
||||
"msg": {"display": "信息", "cell": MsgCell, "update": False},
|
||||
}
|
||||
|
||||
def init_ui(self):
|
||||
"""
|
||||
Stretch last column.
|
||||
"""
|
||||
super(LogMonitor, self).init_ui()
|
||||
|
||||
self.horizontalHeader().setSectionResizeMode(
|
||||
1, QtWidgets.QHeaderView.Stretch
|
||||
)
|
||||
|
||||
def insert_new_row(self, data):
|
||||
"""
|
||||
Insert a new row at the top of table.
|
||||
"""
|
||||
super(LogMonitor, self).insert_new_row(data)
|
||||
self.resizeRowToContents(0)
|
||||
|
||||
|
||||
class SettingEditor(QtWidgets.QDialog):
|
||||
"""
|
||||
For creating new strategy and editing strategy parameters.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, parameters: dict, strategy_name: str = "", class_name: str = ""
|
||||
):
|
||||
""""""
|
||||
super(SettingEditor, self).__init__()
|
||||
|
||||
self.parameters = parameters
|
||||
self.strategy_name = strategy_name
|
||||
self.class_name = class_name
|
||||
|
||||
self.edits = {}
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
form = QtWidgets.QFormLayout()
|
||||
|
||||
# Add vt_symbol and name edit if add new strategy
|
||||
if self.class_name:
|
||||
self.setWindowTitle(f"添加策略:{self.class_name}")
|
||||
button_text = "添加"
|
||||
parameters = {"strategy_name": "", "vt_symbol": "", "auto_init": True, "auto_start": True}
|
||||
parameters.update(self.parameters)
|
||||
|
||||
else:
|
||||
self.setWindowTitle(f"参数编辑:{self.strategy_name}")
|
||||
button_text = "确定"
|
||||
parameters = self.parameters
|
||||
|
||||
for name, value in parameters.items():
|
||||
type_ = type(value)
|
||||
|
||||
edit = QtWidgets.QLineEdit(str(value))
|
||||
if type_ is int:
|
||||
validator = QtGui.QIntValidator()
|
||||
edit.setValidator(validator)
|
||||
elif type_ is float:
|
||||
validator = QtGui.QDoubleValidator()
|
||||
edit.setValidator(validator)
|
||||
|
||||
form.addRow(f"{name} {type_}", edit)
|
||||
|
||||
self.edits[name] = (edit, type_)
|
||||
|
||||
button = QtWidgets.QPushButton(button_text)
|
||||
button.clicked.connect(self.accept)
|
||||
form.addRow(button)
|
||||
|
||||
self.setLayout(form)
|
||||
|
||||
def get_setting(self):
|
||||
""""""
|
||||
setting = {}
|
||||
|
||||
if self.class_name:
|
||||
setting["class_name"] = self.class_name
|
||||
|
||||
for name, tp in self.edits.items():
|
||||
edit, type_ = tp
|
||||
value_text = edit.text()
|
||||
|
||||
if type_ == bool:
|
||||
if value_text == "True":
|
||||
value = True
|
||||
else:
|
||||
value = False
|
||||
else:
|
||||
value = type_(value_text)
|
||||
|
||||
setting[name] = value
|
||||
|
||||
return setting
|
Loading…
Reference in New Issue
Block a user