This commit is contained in:
msincenselee 2019-01-20 23:31:22 +08:00
parent ec89edebe5
commit 6a0f8476e1
14 changed files with 1131 additions and 34 deletions

View File

@ -1,4 +1,4 @@
# “当你想放弃时,想想你为什么开始。”
# “当你想放弃时,想想你为什么开始。埃隆·马斯克
###Fork版本主要改进如下
- 1、增加CtaLineBarCtaPositionCtaPolicy,UtlSinaClient等基础组件

View File

@ -0,0 +1,564 @@
# encoding: UTF-8
"""
批量测试相关方法
# 华富资产 李来佳
"""
import sys, os, platform, gc, copy,multiprocessing
from datetime import datetime
from time import sleep
vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
sys.path.append(vnpy_root)
import numpy as np
import pandas as pd
import talib as ta # 科学计算库
import statsmodels.api as sm # 统计库
import matplotlib
import matplotlib.pyplot as plt
import math # 数学计算相关
matplotlib.rcParams['figure.figsize'] = (20.0, 10.0)
import traceback
from vnpy.trader.setup_logger import *
from vnpy.trader.app.ctaStrategy.ctaBacktesting import BacktestingEngine, OptimizationSetting, MINUTE_DB_NAME
# 合并回测结果
def combine_results(results_list):
result_df = None
# 判断结果集是否有数据
if len(results_list) < 1:
print('no records')
return None
effected_results = 0
for dict in results_list:
# 测试项目
test_item = dict['test_item']
# 测试csv文件
file_name = dict['result_file']
if not os.path.isfile(file_name):
continue
effected_results += 1
# 读取测试文件
df = pd.read_csv(file_name)
# 修正索引为时间日期索引
df = df.set_index(pd.DatetimeIndex(df['date']))
if result_df is None:
# 首个测试结果,将净值字段设置为周期
result_df = df['rate'].to_frame(name=test_item)
# 汇总净值
result_df['rate'] = result_df[test_item]
else:
# 增加新的测试结果数据
result_df[test_item] = df['rate']
# 汇总净值
result_df['rate'] = result_df['rate'] + result_df[test_item]
# 释放内存
l = [df]
del df
del l
if effected_results > 0:
# 净值平均
result_df['avg_rate'] = result_df['rate'] / effected_results
# 组合净值累加仍然按照1个策略的总资金累加各策略的收益
result_df['group_rate'] = result_df['rate'] - effected_results + 1
# 删除累加的rate
result_df.drop('rate', axis=1, inplace=True)
return result_df
# 计算最大回撤,单日最大回撤
def calculate_result(result_df, rate_column):
max_rate = 0
max_loss = 0
max_rate_date = None
max_loss_date = None
max_loss_info = '-'
for idx in result_df.index:
# 当前日净值
cur_rate = result_df[rate_column].loc[idx]
if cur_rate > max_rate:
max_rate = cur_rate
max_rate_date = idx.strftime('%Y-%m-%d')
cur_loss = max_rate - cur_rate
if cur_loss > max_loss:
max_loss = cur_loss
max_loss_date = idx.strftime('%Y-%m-%d')
max_loss_percent = max_loss / max_rate
max_loss_info = u'{} from {} to {},rate {}=>{},max loss rate {}'.format(rate_column, max_rate_date,
max_loss_date, max_rate, cur_rate,
max_loss_percent)
return max_loss_info
def single_strategy_test(test_settings):
"""
根据设置参数执行回测品种
:param strategyClass: 策略类
:param test_settings: 策略参数设置
test_settings['bar_file']: 回测的bar csv文件路径
test_settings['bar_interval']: 回测的Bar csv周期
test_settings['report_file']: 净值输出报告的保存路径
:return:
"""
from vnpy.trader.vtEvent import EventEngine2
eventEngine = EventEngine2()
eventEngine.start()
# 创建回测引擎
engine = BacktestingEngine(eventEngine=eventEngine)
# 设置回测的策略类
engine.setStrategyName(test_settings['name'])
# 创建日志
if 'debug' in test_settings:
engine.createLogger(debug=test_settings['debug'])
else:
engine.createLogger()
# 设置引擎的回测模式
if test_settings['mode'] == 'tick':
engine.setBacktestingMode(engine.TICK_MODE)
else:
engine.setBacktestingMode(engine.BAR_MODE)
strategy_settings = copy.copy(test_settings)
# 设置回测的策略类
if 'filenamePrefix' in strategy_settings:
engine.setStrategyName(strategy_settings["filenamePrefix"])
else:
engine.setStrategyName(test_settings['name'])
if 'is_7x24' in test_settings:
engine.is_7x24 = test_settings['is_7x24']
# del strategy_settings['size']
#del strategy_settings['margin_rate']
del strategy_settings['initCapital']
if 'report_file' in strategy_settings:
del strategy_settings['report_file']
# 设置回测用的数据起始日期
if 'start_date' in strategy_settings:
engine.setStartDate(test_settings['start_date'], initDays = strategy_settings.get('initDays', 10))
del strategy_settings['start_date']
else:
engine.setStartDate('20110101',initDays = strategy_settings.get('initDays',10))
# 设置回测用的数据结束日期
if 'end_date' in test_settings:
engine.setEndDate(test_settings['end_date'])
del strategy_settings['end_date']
else:
engine.setEndDate('20171201')
# engine.connectMysql()
engine.setDatabase(dbName=MINUTE_DB_NAME, symbol=test_settings['vtSymbol'])
# 设置产品相关参数
if 'slippage' in test_settings and test_settings['slippage'] > 0:
engine.setSlippage(test_settings['slippage'])
else:
engine.setSlippage(0)
engine.setRate(test_settings['rate'] if 'rate' in test_settings else float(0.0001)) # 万1
engine.setSize(test_settings['size']) # 合约大小
engine.setMinDiff(test_settings['minDiff']) # 合约价格最小跳动
engine.setMarginRate(test_settings['margin_rate']) # 合约保证金率
if 'fixCommission' in test_settings:
engine.fixCommission = float(test_settings['fixCommission']) # 固定交易费用(每次开平仓收费)
# 删除本地json文件
data_path = os.path.abspath(os.path.join(os.getcwd(),'data'))
if not os.path.exists(data_path):
os.mkdir(data_path)
logs_path = os.path.abspath(os.path.join(os.getcwd(), 'logs'))
if not os.path.exists(logs_path):
os.mkdir(logs_path)
up_grid_json_file = os.path.abspath(os.path.join(data_path,'{0}_upGrids.json'.format(test_settings['name'])))
dn_grid_json_file = os.path.abspath(os.path.join(data_path,'{0}_dnGrids.json'.format(test_settings['name'])))
grid_json_file = os.path.abspath(os.path.join(data_path,'{0}_Grids.json'.format(test_settings['name'])))
policy_json_file = os.path.abspath(os.path.join(data_path, '{0}_Policy.json'.format(test_settings['name'])))
if os.path.isfile(up_grid_json_file):
print(u'{0} exist,remove it'.format(up_grid_json_file))
try:
os.remove(up_grid_json_file)
except Exception as ex:
print(u'{0}{1}'.format(Exception, ex))
return False
if os.path.isfile(dn_grid_json_file):
print(u'{0}exist,remove it'.format(dn_grid_json_file))
try:
os.remove(dn_grid_json_file)
except Exception as ex:
print(u'{0}{1}'.format(Exception, ex))
return False
if os.path.isfile(grid_json_file):
print(u'{0}exist,remove it'.format(grid_json_file))
try:
os.remove(grid_json_file)
except Exception as ex:
print(u'{0}{1}'.format(Exception, ex))
return False
if os.path.isfile(policy_json_file):
print(u'{0}exist,remove it'.format(policy_json_file))
try:
os.remove(policy_json_file)
except Exception as ex:
print(u'{0}{1}'.format(Exception, ex))
return False
# 在引擎中创建策略对象
print(u'run {} using:{}'.format(test_settings['strategy'],strategy_settings))
engine.initStrategy(test_settings['strategy'], setting=strategy_settings)
# 设置每日净值的报告文档存储路径
daily_report_file = 'DailyList.csv' if 'report_file' not in test_settings else test_settings['report_file']
engine.setDailyReportName(daily_report_file)
# 使用简单复利模式计算
engine.usageCompounding = False # True时只针对FINAL_MODE有效
# 启用实时计算净值模式REALTIME_MODE / FINAL_MODE 回测结束时统一计算模式
engine.calculateMode = engine.REALTIME_MODE
engine.capital = test_settings['initCapital'] # 设置期初资金
engine.initCapital = test_settings['initCapital'] # 设置期初资金
engine.avaliable = test_settings['initCapital'] # 设置期初资金
engine.netCapital = test_settings['initCapital']
engine.maxCapital = test_settings['initCapital'] # 设置期初资金
engine.maxNetCapital = test_settings['initCapital'] # 设置期初资金
engine.percentLimit = test_settings['percentLimit'] # 设置资金使用上限比例(%)
engine.barTimeInterval = 60 * test_settings['bar_interval'] # 回测文件中bar的周期秒数用于csv文件自动减时间
try:
# 前置动作(无参数)
pre_functions = test_settings.get('pre_functions',[])
for fun_name in pre_functions:
try:
if not isinstance(fun_name,str):
continue
if hasattr(engine.strategy,fun_name):
fun = getattr(engine.strategy,fun_name)
if fun is not None:
fun()
except Exception as ex:
print(u'调用前置动作异常:{},{}'.format(str(ex),traceback.format_exc()),file=sys.stderr)
# 开始跑回测
if 'bar_file' in test_settings:
engine.runBackTestingWithBarFile(test_settings['bar_file'])
else:
engine.runBackTestingWithDataSource()
print('{}finished loop bars'.format(test_settings['name']))
# 显示回测结果
engine.showBacktestingResult()
# 保存策略得内部数据
engine.saveStrategyData()
print('{} finished'.format(test_settings['name']))
return True
except Exception as ex:
print(u'single_strategy_test exception:{}'.format(str(ex)))
traceback.print_exc()
return False
def multi_period_test(gid, group_setting):
"""
多周期回测品种组合
1对group_settings进行分解分解出各个运行周期的参数设置
2逐一周期运行测试
3添加回测结果
:param gid: 测试组名
:param group_settingdict包含参数多周期清单
如果多周期则对每一周期执行回测并汇总结果
:return 回测的每日净值统计文件
"""
# 回测的分钟周期
minutes_interval_list = group_setting['minute_list'] # 回测分钟队列3510等
strategyClass = group_setting['strategy']
# 测试批次时间
test_dt = datetime.now().strftime('%Y%m%d_%H%M')
# 回测结果队列,对应测试分钟队列
daily_results = []
return_results = []
# 启动多进程
pool = multiprocessing.Pool(multiprocessing.cpu_count())
l = []
# 逐一分钟级别进行回测
for m_i in minutes_interval_list:
settings = copy.copy(group_setting)
del settings['minute_list']
settings['name'] = '{}_{}_{}_M{}'.format(gid, strategyClass.className, group_setting['symbol'], m_i)
# 资金占用比例,根据组合内周期数量,进行平均分配
settings['percentLimit'] = group_setting['percentLimit']/ len(minutes_interval_list)
settings['vtSymbol'] = group_setting['symbol']
settings['MinInterval'] = m_i
settings['mode'] = 'bar'
settings['backtesting'] = True
settings['bar_interval'] = group_setting['bar_interval'] if 'bar_interval' in group_setting else 1
settings['strategy'] = strategyClass
# 回测报告文件保存路径: 组合,测试实例名称,测试时间
daily_report_file = os.path.abspath(os.path.join(group_setting['report_folder'], u'{}_daily_{}.csv'
.format(settings['name'], test_dt)))
settings['report_file'] = daily_report_file
#if rt:
# 回测报告集登记
daily_results.append({'test_item': 'M{}'.format(m_i), 'result_file': daily_report_file})
l.append(pool.apply_async(single_strategy_test, (settings,)))
#rt = single_strategy_test(test_settings=settings)
# 执行内存回收
gc.collect()
sleep(10)
result_list = [res.get() for res in l]
for idx, rt in enumerate(result_list):
if rt:
return_results.append(daily_results[idx])
pool.close()
pool.join()
return return_results
def run_multiperiod_test(gid, group_settings):
"""
运行多周期的组合测试
:param gid: 组合名称
:param group_settings:
:return:
"""
m = '_'.join(str(e) for e in group_settings['minute_list'])
if 'report_folder' in group_settings:
final_file = os.path.abspath(os.path.join(group_settings['report_folder'], '{}_{}_Report_{}.csv'.format(gid, group_settings['symbol'], m)))
else:
# 报告所在目录
report_folder = os.path.abspath(os.path.join(os.getcwd(), 'logs', gid))
if not os.path.exists(report_folder):
os.mkdir(report_folder)
# 汇总报告文件
final_file = os.path.abspath(
os.path.join(report_folder, '{}_{}_Report_{}.csv'.format(gid, group_settings['symbol'], m)))
group_settings['report_folder'] = report_folder
if not os.path.exists(group_settings['report_folder']):
os.makedirs(group_settings['report_folder'])
# 运行回测方法,统计结果
daily_results = multi_period_test(gid, group_settings)
# 统计结果
backtest_df = combine_results(daily_results)
# 保存汇总记录到文件
backtest_df.to_csv(final_file)
# 显示资金曲线汇总
fig, ax1 = plt.subplots()
period_columns = [item['test_item'] for item in daily_results]
fig.patch.set_facecolor('white')
ax1.plot(backtest_df[period_columns])
ax1.legend()
# 释放内存
l = [backtest_df]
del backtest_df
del l
# 释放内存
gc.collect()
print( 'finished run_multiparameter_test')
def multi_parameter_test(gid, settings_list):
"""
不同参数的回测组合
:param gid: 组合ID
:param settings_list: 参数列表
:return:
"""
# 回测结果队列,对应测试参数队列
daily_results = []
return_results = []
# 每个测试的参数名称
para_list = [i['paraName'] for i in settings_list]
if len(para_list) == 0:
return daily_results
# 启动多进程
pool = multiprocessing.Pool(multiprocessing.cpu_count())
#pool = multiprocessing.Pool(2)
l = []
print('multi_parameter_test,total:{}'.format(len(settings_list)))
for idx, strategy_settings in enumerate(settings_list):
settings = copy.copy(strategy_settings)
# 测试时间
test_dt = datetime.now().strftime('%Y%m%d_%H%M')
del settings['log_file']
if 'minute_list' in settings:
del settings['minute_list']
settings['vtSymbol'] = settings['symbol']
settings['mode'] = 'bar'
settings['backtesting'] = True
if 'bar_interval' not in settings:
settings['bar_interval'] = 1
# 回测报告文件保存路径: 组合,测试实例名称,测试时间
daily_report_file = os.path.abspath(os.path.join(settings['report_folder'], u'{}_daily_{}.csv'.format(settings['name'], test_dt)))
settings['report_file'] = daily_report_file
l.append(pool.apply_async(single_strategy_test, (settings,)))
# 回测报告集登记
daily_results.append({'test_item': settings['paraName'], 'result_file': daily_report_file})
# 执行内存回收
gc.collect()
sleep(10)
result_list = [res.get() for res in l]
# 返回结果是正确的,才添加到返回列表中
for idx, rt in enumerate(result_list):
if rt:
return_results.append(daily_results[idx])
pool.close()
pool.join()
return return_results
def run_multiparameter_test(gid, settings_list):
"""
多策略组测试
:param gid:
:param settings_list:
:return: """
if len(settings_list) == 0:
raise ReferenceError('Zero settings')
first_setting = settings_list[0]
paraNames = '_'.join(i['paraName'] for i in settings_list)
if 'report_folder' in first_setting:
report_folder = first_setting['report_folder']
final_file = os.path.abspath(os.path.join(first_setting['report_folder'],
'{}_{}_Report_{}.csv'.format(gid, first_setting['symbol'],
paraNames)))
else:
# 报告所在目录
report_folder = os.path.abspath(os.path.join(os.getcwd(), 'logs'))
final_file = os.path.abspath(
os.path.join(report_folder, '{}_{}_Report_{}.csv'.format(gid, first_setting['symbol'], paraNames)))
if not os.path.exists(report_folder):
os.makedirs(report_folder)
for settings in settings_list:
settings['report_folder'] = report_folder
# 运行回测方法,统计结果
daily_results = multi_parameter_test(gid, settings_list)
# 统计结果
#backtest_df = combine_results(daily_results)
# 保存汇总记录到文件
#backtest_df.to_csv(final_file)
# 显示资金曲线汇总
##fig, ax = plt.subplots()
#period_columns = [item['test_item'] for item in daily_results]
#fig.patch.set_facecolor('white')
#for column in period_columns:
# ax.plot(backtest_df[column], label=column)
#ax.legend()
#title = u'{} {}'.format(gid, paraNames)
#plt.title(title)
#fig.savefig(u'{}/rate.png'.format(report_folder))
## 释放内存
gc.collect()
print( 'finished run_multiparameter_test')
def single_func(para):
logger=setup_logger('MyLog', name='my{}'.format(para))
if para > 5:
print( u'more than 5')
logger.info('More than 5')
return True
else:
print ('less')
logger.info('Less than 5')
return False
def multi_func():
import logging
# 启动多进程
pool = multiprocessing.Pool(multiprocessing.cpu_count())
logger = setup_logger('MyLog')
logger.info('main process')
l = []
for i in range(0,10):
l.append(pool.apply_async(single_func,(i,)))
results = [res.get() for res in l]
pool.close()
pool.join()

View File

@ -305,7 +305,7 @@ def start():
# 往任务表增加定时计划
operate_crontab("add")
# 执行启动
_start()
#_start()
print(u'启动{}服务执行完毕'.format(base_path))
def _stop():

View File

@ -0,0 +1,31 @@
[
{
"name": "Strategy_TripleMa_01_RB_Min5",
"comment": "螺纹05合约三均线策略",
"auto_init": true,
"auto_start": true,
"className": "Strategy_TripleMa_v01",
"inputSS": 1,
"minDiff": 1,
"mode": "tick",
"shortSymbol": "RB",
"strategy_module": "Strategy_TripleMa_v01",
"symbol": "RB1905",
"vtSymbol": "rb1905"
},
{
"name": "Strategy_TripleMa_02_HC_Min5",
"comment": "热卷05合约三均线策略",
"auto_init": true,
"auto_start": true,
"className": "Strategy_TripleMa_v02",
"inputSS": 1,
"minDiff": 1,
"mode": "tick",
"shortSymbol": "HC",
"strategy_module": "Strategy_TripleMa_v02",
"symbol": "HC1905",
"vtSymbol": "hc1905"
}
]

View File

@ -23,9 +23,9 @@ from vnpy.trader.uiMainWindow import *
# 加载底层接口
from vnpy.trader.gateway import ctpGateway
# 初始化的接口模块,以及其指定的名称,CTP是模块value是该模块下的多个连接配置文件,如 CTP_JR_connect.json 'CTP_Prod', 'CTP_JR', , 'CTP_JK', 'CTP_02'
init_gateway_names = {'CTP': ['CTP','CTP_001','CTP_002']}
init_gateway_names = {'CTP': ['CTP','CTP01','CTP02']}
from vnpy.trader.app import (ctaStrategy, riskManager, spreadTrading) #,dataRecorder
from vnpy.trader.app import (ctaStrategy, riskManager, spreadTrading,dataRecorder) #,
# 文件路径名
path = os.path.abspath(os.path.dirname(__file__))
@ -60,7 +60,7 @@ def main():
mainEngine.addApp(ctaStrategy)
mainEngine.addApp(riskManager)
mainEngine.addApp(spreadTrading)
#mainEngine.addApp(dataRecorder)
mainEngine.addApp(dataRecorder)
mainWindow = MainWindow(mainEngine, ee)
mainWindow.showMaximized()

View File

@ -1432,6 +1432,7 @@ class CtaEngine(object):
return
for expired_pos in expired_pos_list:
self.writeCtaLog(u'执行仓位处理:{}'.format(expired_pos))
if expired_pos['volume'] == 0:
self.writeCtaError(u'clear_dispatch_pospos 为空:{},删除'.format(expired_pos))
flt = {'_id': expired_pos['_id']}
@ -1509,7 +1510,7 @@ class CtaEngine(object):
else:
if curPos.longYd >= expired_pos['volume']:
sell_longYd = expired_pos['volume']
if curPos.longYd == 0:
elif curPos.longYd == 0:
sell_longToday = expired_pos['volume']
else:
sell_longYd = curPos.longYd

View File

@ -253,9 +253,13 @@ class TurtlePolicy(CtaPolicy):
def __init__(self, strategy):
super(TurtlePolicy, self).__init__(strategy)
self.tns_open_price = 0 # 首次开仓价格
self.last_open_price = 0 # 最后一次加仓价格
self.stop_price = 0 # 止损价
self.allow_add_pos = False # 是否加仓
self.add_pos_on_pips = EMPTY_INT # 价格超过开仓价多少点时加仓
self.tns_open_price = 0 # 首次开仓价格
self.last_open_price = 0 # 最后一次加仓价格
self.stop_price = 0 # 止损价
self.exit_on_last_rtn_pips = 0 # 最高价/最低价回撤多少跳动
self.high_price_in_long = 0 # 多趋势时,最高价
self.low_price_in_short = 0 # 空趋势时,最低价
self.last_under_open_price = 0 # 低于首次开仓价的补仓价格
@ -279,12 +283,15 @@ class TurtlePolicy(CtaPolicy):
j['tns_open_date'] = self.tns_open_date
j['tns_open_price'] = self.tns_open_price if self.tns_open_price is not None else 0
j['allow_add_pos'] = self.allow_add_pos
j['add_pos_on_pips'] = self.add_pos_on_pips
j['exit_on_last_rtn_pips'] = self.exit_on_last_rtn_pips
j['last_open_price'] = self.last_open_price if self.last_open_price is not None else 0
j['stop_price'] = self.stop_price if self.stop_price is not None else 0
j['high_price_in_long'] = self.high_price_in_long if self.high_price_in_long is not None else 0
j['low_price_in_short'] = self.low_price_in_short if self.low_price_in_short is not None else 0
j[
'add_pos_count_under_first_price'] = self.add_pos_count_under_first_price if self.add_pos_count_under_first_price is not None else 0
j['add_pos_count_under_first_price'] = self.add_pos_count_under_first_price if self.add_pos_count_under_first_price is not None else 0
j['last_under_open_price'] = self.last_under_open_price if self.last_under_open_price is not None else 0
j['max_pos'] = self.max_pos if self.max_pos is not None else 0
@ -412,6 +419,9 @@ class TurtlePolicy(CtaPolicy):
self.writeCtaError(u'解释last_risk_level异常:{}'.format(str(ex)))
self.last_risk_level = 0
self.allow_add_pos=json_data.get('allow_add_pos',False)
self.add_pos_on_pips = json_data.get('add_pos_on_pips',1)
self.exit_on_last_rtn_pips = json_data.get('exit_on_last_rtn_pips',0)
def clean(self):
"""
@ -432,6 +442,9 @@ class TurtlePolicy(CtaPolicy):
self.tns_has_opened = False
self.last_risk_level = 0
self.tns_count = 0
self.allow_add_pos = False
self.add_pos_on_pips = 1
self.exit_on_last_rtn_pips = 0
class TrendPolicy(CtaPolicy):
"""

View File

@ -14,6 +14,7 @@ SETTING_DB_NAME = 'VnTrader_Setting_Db'
TICK_DB_NAME = 'VnTrader_Tick_Db'
DAILY_DB_NAME = 'VnTrader_Daily_Db'
MINUTE_DB_NAME = 'VnTrader_1Min_Db'
RENKO_DB_NAME = 'Renko_Db'
# CTA引擎中涉及的数据类定义

View File

@ -14,12 +14,15 @@ from datetime import datetime, timedelta
from queue import Queue
from threading import Thread
from vnpy.trader.vtEvent import *
from vnpy.trader.vtObject import VtTickData
from vnpy.trader.vtGateway import VtSubscribeReq, VtLogData
from vnpy.trader.vtFunction import todayDate,getJsonPath
from vnpy.trader.vtFunction import todayDate,getJsonPath,getShortSymbol
from vnpy.trader.app.ctaStrategy.ctaRenkoBar import CtaRenkoBar
from .drBase import *
from vnpy.trader.setup_logger import setup_logger
from vnpy.trader.data_source import DataSource
########################################################################
class DrEngine(object):
"""数据记录引擎"""
@ -44,15 +47,39 @@ class DrEngine(object):
# K线对象字典
self.barDict = {}
# Renko K线对象字典
self.renkoDict = {}
# 负责执行数据库插入的单独线程相关
self.active = False # 工作状态
self.queue = Queue() # 队列
self.thread = Thread(target=self.run) # 线程
self.logger = None
self.createLogger()
# 载入设置,订阅行情
self.loadSetting()
def createLogger(self):
"""
创建日志记录
:return:
"""
currentFolder = os.path.abspath(os.path.join(os.getcwd(), 'logs'))
if os.path.isdir(currentFolder):
# 如果工作目录下存在data子目录就使用data子目录
path = currentFolder
else:
# 否则,使用缺省保存目录 vnpy/trader/app/ctaStrategy/data
path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'logs'))
if self.logger is None:
filename = os.path.abspath(os.path.join(path, 'drEngine'))
print(u'create logger:{}'.format(filename))
self.logger = setup_logger(filename=filename, name='drEngine', debug=True)
#----------------------------------------------------------------------
def loadSetting(self):
"""载入设置"""
@ -111,7 +138,43 @@ class DrEngine(object):
bar = DrBarData()
self.barDict[vtSymbol] = bar
if 'renko' in drSetting:
l = drSetting.get('renko')
req_set = set()
for setting in l:
# 获取合约合约短号renko的高度多少个跳
vtSymbol = setting.get('vtSymbol',None)
if vtSymbol is None:
continue
short_symbol = getShortSymbol(vtSymbol).upper()
height = setting.get('height',5)
minDiff = setting.get('minDiff',1)
# 获取vtSymbol的多个renkobar列表添加新的CtaRenkoBar
renko_list = self.renkoDict.get(vtSymbol,[])
bar_setting = {'name':'{}_{}'.format(vtSymbol,height),
'shortSymbol':short_symbol,
'vtSymbol':vtSymbol,
'minDiff':minDiff,
'height':minDiff*height}
renko_bar = CtaRenkoBar(strategy=None,onBarFunc=self.onRenkoBar,setting=bar_setting)
renko_list.append(renko_bar)
self.renkoDict.update({vtSymbol:renko_list})
req = VtSubscribeReq()
req.symbol = vtSymbol
req_set.add((req,setting.get('gateway',None)))
# 更新合约的历史数据
self.add_gap_ticks()
# 订阅行情
for req,gw in req_set:
self.mainEngine.subscribe(req,gw)
if 'active' in drSetting:
d = drSetting['active']
@ -125,7 +188,7 @@ class DrEngine(object):
# 注册事件监听
self.registerEvent()
#----------------------------------------------------------------------
# ----------------------------------------------------------------------
def procecssTickEvent(self, event):
"""处理行情推送"""
tick = event.dict_['data']
@ -137,7 +200,7 @@ class DrEngine(object):
for key in d.keys():
if key != 'datetime':
d[key] = tick.__getattribute__(key)
drTick.datetime = datetime.strptime(' '.join([tick.date, tick.time]), '%Y%m%d %H:%M:%S.%f')
drTick.datetime = datetime.strptime(' '.join([tick.date, tick.time]), '%Y-%m-%d %H:%M:%S.%f')
# 更新Tick数据
if vtSymbol in self.tickDict:
@ -190,6 +253,111 @@ class DrEngine(object):
bar.close = drTick.lastPrice
#----------------------------------------------------------------------
# 更新Renko数据
for renko_bar in self.renkoDict.get(vtSymbol,[]):
renko_bar.onTick(copy.copy(tick))
def add_gap_ticks(self):
"""
补充缺失的分时数据
:return:
"""
ds = DataSource()
for vtSymbol in self.renkoDict.keys():
renkobar_list = self.renkoDict.get(vtSymbol,[])
cache_bars = None
cache_start_date = None
cache_end_date = None
for renko_bar in renkobar_list:
# 通过mongo获取最新一个Renko bar的数据日期close时间
last_renko_dt = self.get_last_datetime(renko_bar.name)
# 根据日期+vtSymbol像datasource获取分钟数据以close价格转化为tick推送到renko_bar中
if last_renko_dt is not None:
start_date =last_renko_dt.strftime('%Y-%m-%d')
else:
start_date = (datetime.now() - timedelta(days=90)).strftime('%Y-%m-%d')
end_date = (datetime.now() + timedelta(days=5)).strftime('%Y-%m-%d')
self.writeDrLog(u'从datasource获取{}数据,开始日期:{}'.format(vtSymbol,start_date))
if cache_bars is None or cache_start_date!=start_date or cache_end_date!=end_date:
fields = ['open', 'close', 'high', 'low', 'volume', 'open_interest', 'limit_up', 'limit_down',
'trading_date']
cache_bars = ds.get_price(order_book_id=vtSymbol,start_date=start_date,end_date=end_date,
frequency='1m', fields=fields)
cache_start_date = start_date
cache_end_date = end_date
if cache_bars is not None:
total = len(cache_bars)
self.writeDrLog(u'一共获取{}{} 1分钟数据'.format(total,vtSymbol))
if cache_bars is not None:
self.writeDrLog(u'推送分时数据tick:{}到:{}'.format(vtSymbol,renko_bar.name))
for idx in cache_bars.index:
row = cache_bars.loc[idx]
tick = VtTickData()
tick.vtSymbol = vtSymbol
tick.symbol = vtSymbol
last_bar_dt = datetime.strptime(str(idx), '%Y-%m-%d %H:%M:00')
tick.datetime = last_bar_dt - timedelta(minutes=1)
tick.date = tick.datetime.strftime('%Y-%m-%d')
tick.time = tick.datetime.strftime('%H:%M:00')
if tick.datetime.hour >= 21:
if tick.datetime.isoweekday() == 5:
# 星期五=》星期一
tick.tradingDay = (tick.datetime + timedelta(days=3)).strftime('%Y-%m-%d')
else:
# 第二天
tick.tradingDay = (tick.datetime + timedelta(days=1)).strftime('%Y-%m-%d')
elif tick.datetime.hour < 8 and tick.datetime.isoweekday() == 6:
# 星期六=>星期一
tick.tradingDay = (tick.datetime + timedelta(days=2)).strftime('%Y-%m-%d')
else:
tick.tradingDay = tick.date
tick.upperLimit = float(row['limit_up'])
tick.lowerLimit = float(row['limit_down'])
tick.lastPrice = float(row['close'])
tick.askPrice1 = float(row['close'])
tick.bidPrice1 = float(row['close'])
tick.volume = int(row['volume'])
tick.askVolume1 = tick.volume
tick.bidVolume1 = tick.volume
if last_renko_dt is not None and tick.datetime <= last_renko_dt:
continue
renko_bar.onTick(tick)
def get_last_datetime(self,renko_name):
"""
通过mongo获取最新一个bar的数据日期
:param renko_name:
:return:
"""
qryData = self.mainEngine.dbQueryBySort(dbName=RENKO_DB_NAME,
collectionName=renko_name,
d={},
sortName='datetime',
sortType=-1,
limitNum=1)
last_renko_open_dt =None
last_renko_close_dt=None
for d in qryData:
last_renko_open_dt = d.get('datetime',None)
if last_renko_open_dt is not None:
last_renko_close_dt = last_renko_open_dt + timedelta(seconds=d.get('seconds',0))
break
return last_renko_close_dt
def onRenkoBar(self,bar,bar_name):
newBar = copy.copy(bar)
self.insertData(RENKO_DB_NAME, bar_name, newBar)
self.writeDrLog(u'new Renko Bar:{},dt:{},open:{},close:{},high:{},low:{}'
.format(bar_name,bar.datetime,bar.open,bar.close, bar.high, bar.low))
def registerEvent(self):
"""注册事件监听"""
self.eventEngine.register(EVENT_TICK, self.procecssTickEvent)
@ -206,7 +374,7 @@ class DrEngine(object):
try:
dbName, collectionName, d = self.queue.get(block=True, timeout=1)
self.mainEngine.dbInsert(dbName, collectionName, d)
except Empty:
except Exception as ex:
pass
#----------------------------------------------------------------------
def start(self):
@ -228,5 +396,7 @@ class DrEngine(object):
log.logContent = content
event = Event(type_=EVENT_DATARECORDER_LOG)
event.dict_['data'] = log
self.eventEngine.put(event)
self.eventEngine.put(event)
if self.logger:
self.logger.info(content)

View File

@ -29,7 +29,7 @@ class TableCell(QtWidgets.QTableWidgetItem):
if text == '0' or text == '0.0':
self.setText('')
else:
self.setText(text)
self.setText(str(text))
########################################################################
@ -60,7 +60,7 @@ class DrEngineManager(QtWidgets.QWidget):
self.tickTable.setColumnCount(2)
self.tickTable.verticalHeader().setVisible(False)
self.tickTable.setEditTriggers(QtWidgets.QTableWidget.NoEditTriggers)
self.tickTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch)
#self.tickTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch)
self.tickTable.setAlternatingRowColors(True)
self.tickTable.setHorizontalHeaderLabels([u'合约代码', u'接口'])
@ -69,16 +69,25 @@ class DrEngineManager(QtWidgets.QWidget):
self.barTable.setColumnCount(2)
self.barTable.verticalHeader().setVisible(False)
self.barTable.setEditTriggers(QtWidgets.QTableWidget.NoEditTriggers)
self.barTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch)
#self.barTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch)
self.barTable.setAlternatingRowColors(True)
self.barTable.setHorizontalHeaderLabels([u'合约代码', u'接口'])
renkoLabel = QtWidgets.QLabel(u'RenkoBar')
self.renkoTable = QtWidgets.QTableWidget()
self.renkoTable.setColumnCount(2)
self.renkoTable.verticalHeader().setVisible(False)
self.renkoTable.setEditTriggers(QtWidgets.QTableWidget.NoEditTriggers)
# self.renkoTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch)
self.renkoTable.setAlternatingRowColors(True)
self.renkoTable.setHorizontalHeaderLabels([u'合约代码', u'高度'])
activeLabel = QtWidgets.QLabel(u'主力合约')
self.activeTable = QtWidgets.QTableWidget()
self.activeTable.setColumnCount(2)
self.activeTable.verticalHeader().setVisible(False)
self.activeTable.setEditTriggers(QtWidgets.QTableWidget.NoEditTriggers)
self.activeTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch)
#self.activeTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch)
self.activeTable.setAlternatingRowColors(True)
self.activeTable.setHorizontalHeaderLabels([u'主力代码', u'合约代码'])
@ -92,10 +101,13 @@ class DrEngineManager(QtWidgets.QWidget):
grid.addWidget(tickLabel, 0, 0)
grid.addWidget(barLabel, 0, 1)
grid.addWidget(activeLabel, 0, 2)
grid.addWidget(renkoLabel, 0, 2)
grid.addWidget(activeLabel, 0, 3)
grid.addWidget(self.tickTable, 1, 0)
grid.addWidget(self.barTable, 1, 1)
grid.addWidget(self.activeTable, 1, 2)
grid.addWidget(self.renkoTable, 1, 2)
grid.addWidget(self.activeTable, 1, 3)
vbox = QtWidgets.QVBoxLayout()
vbox.addLayout(grid)
@ -118,7 +130,7 @@ class DrEngineManager(QtWidgets.QWidget):
#----------------------------------------------------------------------
def updateSetting(self):
"""显示引擎行情记录配置"""
with open(self.drEngine.settingFileName) as f:
with open(self.drEngine.settingFilePath) as f:
drSetting = json.load(f)
if 'tick' in drSetting:
@ -135,8 +147,15 @@ class DrEngineManager(QtWidgets.QWidget):
for setting in l:
self.barTable.insertRow(0)
self.barTable.setItem(0, 0, TableCell(setting[0]))
self.barTable.setItem(0, 1, TableCell(setting[1]))
self.barTable.setItem(0, 1, TableCell(setting[1]))
if 'renko' in drSetting:
l = drSetting['renko']
for setting in l:
self.renkoTable.insertRow(0)
self.renkoTable.setItem(0, 0, TableCell(setting.get('vtSymbol','')))
self.renkoTable.setItem(0, 1, TableCell(setting.get('height',0)))
if 'active' in drSetting:
d = drSetting['active']

View File

@ -244,7 +244,8 @@ class MultiprocessHandler(logging.FileHandler):
if not self.suffix:
raise ValueError(u"指定的日期间隔单位无效: %s" % self.when)
#拼接文件路径 格式化字符串
self.filefmt = "%s_%s.log" % (self.prefix,self.suffix)
#self.filefmt = "%s_%s.log" % (self.prefix,self.suffix)
self.filefmt = u'{}_{}.log'.format(self.prefix, self.suffix)
#使用当前时间,格式化文件格式化字符串
self.filePath = datetime.now().strftime(self.filefmt)
#获得文件夹路径

View File

@ -38,6 +38,10 @@ EVENT_NOTIFICATION = 'eNotification' # 全局通知
EVENT_SIGNAL = 'eSignal' # 信号通知
EVENT_STATUS = 'eStatus' # 服务状态
# 股票使用
EVENT_BAR = 'eBar' # 1分钟Bar 行情
EVENT_BARDICT = 'eBarDict_' # BarDict事件+策略实例名称
# CTA模块相关
EVENT_CTA_LOG = 'eCtaLog' # CTA相关的日志事件
EVENT_CTA_STRATEGY = 'eCtaStrategy.' # CTA策略状态变化事件

View File

@ -1,6 +1,6 @@
# encoding: UTF-8
import time,os,sys
import time,os,sys,copy
from datetime import datetime
from vnpy.trader.vtEvent import *
@ -23,6 +23,9 @@ class VtGateway(object):
self.accountID = 'AccountID'
self.createLogger()
# 所有订阅onBar的都会添加
self.klines = {}
# ----------------------------------------------------------------------
def onTick(self, tick):
"""市场行情推送"""
@ -35,7 +38,14 @@ class VtGateway(object):
event2 = Event(type_=EVENT_TICK+tick.vtSymbol)
event2.dict_['data'] = tick
self.eventEngine.put(event2)
def onBar(self,bar,type=EVENT_BAR):
"""市场行情推送"""
# bar, 或者 barDict
event = Event(type_=type)
event.dict_['data'] = bar
self.eventEngine.put(event)
# ----------------------------------------------------------------------
def onTrade(self, trade):
"""成交信息推送"""

283
vnpy/trader/vtUtility.py Normal file
View File

@ -0,0 +1,283 @@
# encoding: UTF-8
import numpy as np
import talib
from vnpy.trader.vtObject import VtBarData
########################################################################
class BarGenerator(object):
"""
K线合成器支持
1. 基于Tick合成1分钟K线
2. 基于1分钟K线合成X分钟K线X可以是235101530
"""
#----------------------------------------------------------------------
def __init__(self, onBar, xmin=0, onXminBar=None):
"""Constructor"""
self.bar = None # 1分钟K线对象
self.onBar = onBar # 1分钟K线回调函数
self.xminBar = None # X分钟K线对象
self.xmin = xmin # X的值
self.onXminBar = onXminBar # X分钟K线的回调函数
self.lastTick = None # 上一TICK缓存对象
#----------------------------------------------------------------------
def updateTick(self, tick):
"""TICK更新"""
newMinute = False # 默认不是新的一分钟
# 尚未创建对象
if not self.bar:
self.bar = VtBarData()
newMinute = True
# 新的一分钟
elif self.bar.datetime.minute != tick.datetime.minute:
# 生成上一分钟K线的时间戳
self.bar.datetime = self.bar.datetime.replace(second=0, microsecond=0) # 将秒和微秒设为0
self.bar.date = self.bar.datetime.strftime('%Y%m%d')
self.bar.time = self.bar.datetime.strftime('%H:%M:%S.%f')
# 推送已经结束的上一分钟K线
self.onBar(self.bar)
# 创建新的K线对象
self.bar = VtBarData()
newMinute = True
# 初始化新一分钟的K线数据
if newMinute:
self.bar.vtSymbol = tick.vtSymbol
self.bar.symbol = tick.symbol
self.bar.exchange = tick.exchange
self.bar.open = tick.lastPrice
self.bar.high = tick.lastPrice
self.bar.low = tick.lastPrice
# 累加更新老一分钟的K线数据
else:
self.bar.high = max(self.bar.high, tick.lastPrice)
self.bar.low = min(self.bar.low, tick.lastPrice)
# 通用更新部分
self.bar.close = tick.lastPrice
self.bar.datetime = tick.datetime
self.bar.openInterest = tick.openInterest
if self.lastTick:
volumeChange = tick.volume - self.lastTick.volume # 当前K线内的成交量
self.bar.volume += max(volumeChange, 0) # 避免夜盘开盘lastTick.volume为昨日收盘数据导致成交量变化为负的情况
# 缓存Tick
self.lastTick = tick
#----------------------------------------------------------------------
def updateBar(self, bar):
"""1分钟K线更新"""
# 尚未创建对象
if not self.xminBar:
self.xminBar = VtBarData()
self.xminBar.vtSymbol = bar.vtSymbol
self.xminBar.symbol = bar.symbol
self.xminBar.exchange = bar.exchange
self.xminBar.open = bar.open
self.xminBar.high = bar.high
self.xminBar.low = bar.low
self.xminBar.datetime = bar.datetime # 以第一根分钟K线的开始时间戳作为X分钟线的时间戳
# 累加老K线
else:
self.xminBar.high = max(self.xminBar.high, bar.high)
self.xminBar.low = min(self.xminBar.low, bar.low)
# 通用部分
self.xminBar.close = bar.close
self.xminBar.openInterest = bar.openInterest
self.xminBar.volume += int(bar.volume)
# X分钟已经走完
if not (bar.datetime.minute + 1) % self.xmin: # 可以用X整除
# 生成上一X分钟K线的时间戳
self.xminBar.datetime = self.xminBar.datetime.replace(second=0, microsecond=0) # 将秒和微秒设为0
self.xminBar.date = self.xminBar.datetime.strftime('%Y%m%d')
self.xminBar.time = self.xminBar.datetime.strftime('%H:%M:%S.%f')
# 推送
self.onXminBar(self.xminBar)
# 清空老K线缓存对象
self.xminBar = None
#----------------------------------------------------------------------
def generate(self):
"""手动强制立即完成K线合成"""
self.onBar(self.bar)
self.bar = None
########################################################################
class ArrayManager(object):
"""
K线序列管理工具负责
1. K线时间序列的维护
2. 常用技术指标的计算
"""
#----------------------------------------------------------------------
def __init__(self, size=100):
"""Constructor"""
self.count = 0 # 缓存计数
self.size = size # 缓存大小
self.inited = False # True if count>=size
self.openArray = np.zeros(size) # OHLC
self.highArray = np.zeros(size)
self.lowArray = np.zeros(size)
self.closeArray = np.zeros(size)
self.volumeArray = np.zeros(size)
#----------------------------------------------------------------------
def updateBar(self, bar):
"""更新K线"""
self.count += 1
if not self.inited and self.count >= self.size:
self.inited = True
self.openArray[:-1] = self.openArray[1:]
self.highArray[:-1] = self.highArray[1:]
self.lowArray[:-1] = self.lowArray[1:]
self.closeArray[:-1] = self.closeArray[1:]
self.volumeArray[:-1] = self.volumeArray[1:]
self.openArray[-1] = bar.open
self.highArray[-1] = bar.high
self.lowArray[-1] = bar.low
self.closeArray[-1] = bar.close
self.volumeArray[-1] = bar.volume
#----------------------------------------------------------------------
@property
def open(self):
"""获取开盘价序列"""
return self.openArray
#----------------------------------------------------------------------
@property
def high(self):
"""获取最高价序列"""
return self.highArray
#----------------------------------------------------------------------
@property
def low(self):
"""获取最低价序列"""
return self.lowArray
#----------------------------------------------------------------------
@property
def close(self):
"""获取收盘价序列"""
return self.closeArray
#----------------------------------------------------------------------
@property
def volume(self):
"""获取成交量序列"""
return self.volumeArray
#----------------------------------------------------------------------
def sma(self, n, array=False):
"""简单均线"""
result = talib.SMA(self.close, n)
if array:
return result
return result[-1]
#----------------------------------------------------------------------
def std(self, n, array=False):
"""标准差"""
result = talib.STDDEV(self.close, n)
if array:
return result
return result[-1]
#----------------------------------------------------------------------
def cci(self, n, array=False):
"""CCI指标"""
result = talib.CCI(self.high, self.low, self.close, n)
if array:
return result
return result[-1]
#----------------------------------------------------------------------
def atr(self, n, array=False):
"""ATR指标"""
result = talib.ATR(self.high, self.low, self.close, n)
if array:
return result
return result[-1]
#----------------------------------------------------------------------
def rsi(self, n, array=False):
"""RSI指标"""
result = talib.RSI(self.close, n)
if array:
return result
return result[-1]
#----------------------------------------------------------------------
def macd(self, fastPeriod, slowPeriod, signalPeriod, array=False):
"""MACD指标"""
macd, signal, hist = talib.MACD(self.close, fastPeriod,
slowPeriod, signalPeriod)
if array:
return macd, signal, hist
return macd[-1], signal[-1], hist[-1]
#----------------------------------------------------------------------
def adx(self, n, array=False):
"""ADX指标"""
result = talib.ADX(self.high, self.low, self.close, n)
if array:
return result
return result[-1]
#----------------------------------------------------------------------
def boll(self, n, dev, array=False):
"""布林通道"""
mid = self.sma(n, array)
std = self.std(n, array)
up = mid + std * dev
down = mid - std * dev
return up, down
#----------------------------------------------------------------------
def keltner(self, n, dev, array=False):
"""肯特纳通道"""
mid = self.sma(n, array)
atr = self.atr(n, array)
up = mid + atr * dev
down = mid - atr * dev
return up, down
#----------------------------------------------------------------------
def donchian(self, n, array=False):
"""唐奇安通道"""
up = talib.MAX(self.high, n)
down = talib.MIN(self.low, n)
if array:
return up, down
return up[-1], down[-1]