更新
This commit is contained in:
parent
ec89edebe5
commit
6a0f8476e1
@ -1,4 +1,4 @@
|
|||||||
# “当你想放弃时,想想你为什么开始。”
|
# “当你想放弃时,想想你为什么开始。埃隆·马斯克”
|
||||||
|
|
||||||
###Fork版本主要改进如下
|
###Fork版本主要改进如下
|
||||||
- 1、增加CtaLineBar,CtaPosition,CtaPolicy,UtlSinaClient等基础组件
|
- 1、增加CtaLineBar,CtaPosition,CtaPolicy,UtlSinaClient等基础组件
|
||||||
|
564
examples/CtaBacktesting/util_branch_testing.py
Normal file
564
examples/CtaBacktesting/util_branch_testing.py
Normal file
@ -0,0 +1,564 @@
|
|||||||
|
# encoding: UTF-8
|
||||||
|
|
||||||
|
"""
|
||||||
|
批量测试相关方法
|
||||||
|
# 华富资产 李来佳
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys, os, platform, gc, copy,multiprocessing
|
||||||
|
from datetime import datetime
|
||||||
|
from time import sleep
|
||||||
|
vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||||
|
|
||||||
|
sys.path.append(vnpy_root)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import talib as ta # 科学计算库
|
||||||
|
import statsmodels.api as sm # 统计库
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import math # 数学计算相关
|
||||||
|
matplotlib.rcParams['figure.figsize'] = (20.0, 10.0)
|
||||||
|
import traceback
|
||||||
|
from vnpy.trader.setup_logger import *
|
||||||
|
from vnpy.trader.app.ctaStrategy.ctaBacktesting import BacktestingEngine, OptimizationSetting, MINUTE_DB_NAME
|
||||||
|
|
||||||
|
|
||||||
|
# 合并回测结果
|
||||||
|
def combine_results(results_list):
|
||||||
|
result_df = None
|
||||||
|
|
||||||
|
# 判断结果集是否有数据
|
||||||
|
if len(results_list) < 1:
|
||||||
|
print('no records')
|
||||||
|
return None
|
||||||
|
|
||||||
|
effected_results = 0
|
||||||
|
|
||||||
|
for dict in results_list:
|
||||||
|
# 测试项目
|
||||||
|
test_item = dict['test_item']
|
||||||
|
# 测试csv文件
|
||||||
|
file_name = dict['result_file']
|
||||||
|
|
||||||
|
if not os.path.isfile(file_name):
|
||||||
|
continue
|
||||||
|
|
||||||
|
effected_results += 1
|
||||||
|
|
||||||
|
# 读取测试文件
|
||||||
|
df = pd.read_csv(file_name)
|
||||||
|
# 修正索引为时间日期索引
|
||||||
|
df = df.set_index(pd.DatetimeIndex(df['date']))
|
||||||
|
|
||||||
|
if result_df is None:
|
||||||
|
# 首个测试结果,将净值字段设置为周期
|
||||||
|
result_df = df['rate'].to_frame(name=test_item)
|
||||||
|
# 汇总净值
|
||||||
|
result_df['rate'] = result_df[test_item]
|
||||||
|
else:
|
||||||
|
# 增加新的测试结果数据
|
||||||
|
result_df[test_item] = df['rate']
|
||||||
|
|
||||||
|
# 汇总净值
|
||||||
|
result_df['rate'] = result_df['rate'] + result_df[test_item]
|
||||||
|
|
||||||
|
# 释放内存
|
||||||
|
l = [df]
|
||||||
|
del df
|
||||||
|
del l
|
||||||
|
|
||||||
|
if effected_results > 0:
|
||||||
|
# 净值平均
|
||||||
|
result_df['avg_rate'] = result_df['rate'] / effected_results
|
||||||
|
# 组合净值累加(仍然按照1个策略的总资金,累加各策略的收益)
|
||||||
|
result_df['group_rate'] = result_df['rate'] - effected_results + 1
|
||||||
|
|
||||||
|
# 删除累加的rate
|
||||||
|
result_df.drop('rate', axis=1, inplace=True)
|
||||||
|
|
||||||
|
return result_df
|
||||||
|
|
||||||
|
|
||||||
|
# 计算最大回撤,单日最大回撤
|
||||||
|
def calculate_result(result_df, rate_column):
|
||||||
|
max_rate = 0
|
||||||
|
max_loss = 0
|
||||||
|
max_rate_date = None
|
||||||
|
max_loss_date = None
|
||||||
|
max_loss_info = '-'
|
||||||
|
|
||||||
|
for idx in result_df.index:
|
||||||
|
# 当前日净值
|
||||||
|
cur_rate = result_df[rate_column].loc[idx]
|
||||||
|
|
||||||
|
if cur_rate > max_rate:
|
||||||
|
max_rate = cur_rate
|
||||||
|
max_rate_date = idx.strftime('%Y-%m-%d')
|
||||||
|
|
||||||
|
cur_loss = max_rate - cur_rate
|
||||||
|
|
||||||
|
if cur_loss > max_loss:
|
||||||
|
max_loss = cur_loss
|
||||||
|
max_loss_date = idx.strftime('%Y-%m-%d')
|
||||||
|
max_loss_percent = max_loss / max_rate
|
||||||
|
max_loss_info = u'{} from {} to {},rate {}=>{},max loss rate {}'.format(rate_column, max_rate_date,
|
||||||
|
max_loss_date, max_rate, cur_rate,
|
||||||
|
max_loss_percent)
|
||||||
|
return max_loss_info
|
||||||
|
|
||||||
|
def single_strategy_test(test_settings):
|
||||||
|
"""
|
||||||
|
根据设置参数,执行回测品种,
|
||||||
|
:param strategyClass: 策略类
|
||||||
|
:param test_settings: 策略参数设置
|
||||||
|
test_settings['bar_file']: 回测的bar csv文件路径
|
||||||
|
test_settings['bar_interval']: 回测的Bar csv周期
|
||||||
|
test_settings['report_file']: 净值输出报告的保存路径
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
from vnpy.trader.vtEvent import EventEngine2
|
||||||
|
eventEngine = EventEngine2()
|
||||||
|
eventEngine.start()
|
||||||
|
|
||||||
|
# 创建回测引擎
|
||||||
|
engine = BacktestingEngine(eventEngine=eventEngine)
|
||||||
|
# 设置回测的策略类
|
||||||
|
engine.setStrategyName(test_settings['name'])
|
||||||
|
# 创建日志
|
||||||
|
if 'debug' in test_settings:
|
||||||
|
engine.createLogger(debug=test_settings['debug'])
|
||||||
|
else:
|
||||||
|
engine.createLogger()
|
||||||
|
|
||||||
|
# 设置引擎的回测模式
|
||||||
|
if test_settings['mode'] == 'tick':
|
||||||
|
engine.setBacktestingMode(engine.TICK_MODE)
|
||||||
|
else:
|
||||||
|
engine.setBacktestingMode(engine.BAR_MODE)
|
||||||
|
|
||||||
|
strategy_settings = copy.copy(test_settings)
|
||||||
|
# 设置回测的策略类
|
||||||
|
if 'filenamePrefix' in strategy_settings:
|
||||||
|
engine.setStrategyName(strategy_settings["filenamePrefix"])
|
||||||
|
else:
|
||||||
|
engine.setStrategyName(test_settings['name'])
|
||||||
|
|
||||||
|
if 'is_7x24' in test_settings:
|
||||||
|
engine.is_7x24 = test_settings['is_7x24']
|
||||||
|
|
||||||
|
# del strategy_settings['size']
|
||||||
|
#del strategy_settings['margin_rate']
|
||||||
|
del strategy_settings['initCapital']
|
||||||
|
|
||||||
|
if 'report_file' in strategy_settings:
|
||||||
|
del strategy_settings['report_file']
|
||||||
|
|
||||||
|
# 设置回测用的数据起始日期
|
||||||
|
if 'start_date' in strategy_settings:
|
||||||
|
engine.setStartDate(test_settings['start_date'], initDays = strategy_settings.get('initDays', 10))
|
||||||
|
del strategy_settings['start_date']
|
||||||
|
else:
|
||||||
|
engine.setStartDate('20110101',initDays = strategy_settings.get('initDays',10))
|
||||||
|
|
||||||
|
# 设置回测用的数据结束日期
|
||||||
|
if 'end_date' in test_settings:
|
||||||
|
engine.setEndDate(test_settings['end_date'])
|
||||||
|
del strategy_settings['end_date']
|
||||||
|
else:
|
||||||
|
engine.setEndDate('20171201')
|
||||||
|
|
||||||
|
# engine.connectMysql()
|
||||||
|
engine.setDatabase(dbName=MINUTE_DB_NAME, symbol=test_settings['vtSymbol'])
|
||||||
|
|
||||||
|
# 设置产品相关参数
|
||||||
|
if 'slippage' in test_settings and test_settings['slippage'] > 0:
|
||||||
|
engine.setSlippage(test_settings['slippage'])
|
||||||
|
else:
|
||||||
|
engine.setSlippage(0)
|
||||||
|
engine.setRate(test_settings['rate'] if 'rate' in test_settings else float(0.0001)) # 万1
|
||||||
|
engine.setSize(test_settings['size']) # 合约大小
|
||||||
|
engine.setMinDiff(test_settings['minDiff']) # 合约价格最小跳动
|
||||||
|
engine.setMarginRate(test_settings['margin_rate']) # 合约保证金率
|
||||||
|
if 'fixCommission' in test_settings:
|
||||||
|
engine.fixCommission = float(test_settings['fixCommission']) # 固定交易费用(每次开平仓收费)
|
||||||
|
|
||||||
|
# 删除本地json文件
|
||||||
|
data_path = os.path.abspath(os.path.join(os.getcwd(),'data'))
|
||||||
|
if not os.path.exists(data_path):
|
||||||
|
os.mkdir(data_path)
|
||||||
|
|
||||||
|
logs_path = os.path.abspath(os.path.join(os.getcwd(), 'logs'))
|
||||||
|
if not os.path.exists(logs_path):
|
||||||
|
os.mkdir(logs_path)
|
||||||
|
|
||||||
|
up_grid_json_file = os.path.abspath(os.path.join(data_path,'{0}_upGrids.json'.format(test_settings['name'])))
|
||||||
|
dn_grid_json_file = os.path.abspath(os.path.join(data_path,'{0}_dnGrids.json'.format(test_settings['name'])))
|
||||||
|
grid_json_file = os.path.abspath(os.path.join(data_path,'{0}_Grids.json'.format(test_settings['name'])))
|
||||||
|
policy_json_file = os.path.abspath(os.path.join(data_path, '{0}_Policy.json'.format(test_settings['name'])))
|
||||||
|
|
||||||
|
if os.path.isfile(up_grid_json_file):
|
||||||
|
print(u'{0} exist,remove it'.format(up_grid_json_file))
|
||||||
|
try:
|
||||||
|
os.remove(up_grid_json_file)
|
||||||
|
except Exception as ex:
|
||||||
|
print(u'{0}:{1}'.format(Exception, ex))
|
||||||
|
return False
|
||||||
|
|
||||||
|
if os.path.isfile(dn_grid_json_file):
|
||||||
|
print(u'{0}exist,remove it'.format(dn_grid_json_file))
|
||||||
|
try:
|
||||||
|
os.remove(dn_grid_json_file)
|
||||||
|
except Exception as ex:
|
||||||
|
print(u'{0}:{1}'.format(Exception, ex))
|
||||||
|
return False
|
||||||
|
|
||||||
|
if os.path.isfile(grid_json_file):
|
||||||
|
print(u'{0}exist,remove it'.format(grid_json_file))
|
||||||
|
try:
|
||||||
|
os.remove(grid_json_file)
|
||||||
|
except Exception as ex:
|
||||||
|
print(u'{0}:{1}'.format(Exception, ex))
|
||||||
|
return False
|
||||||
|
|
||||||
|
if os.path.isfile(policy_json_file):
|
||||||
|
print(u'{0}exist,remove it'.format(policy_json_file))
|
||||||
|
try:
|
||||||
|
os.remove(policy_json_file)
|
||||||
|
except Exception as ex:
|
||||||
|
print(u'{0}:{1}'.format(Exception, ex))
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 在引擎中创建策略对象
|
||||||
|
print(u'run {} using:{}'.format(test_settings['strategy'],strategy_settings))
|
||||||
|
engine.initStrategy(test_settings['strategy'], setting=strategy_settings)
|
||||||
|
|
||||||
|
# 设置每日净值的报告文档存储路径
|
||||||
|
daily_report_file = 'DailyList.csv' if 'report_file' not in test_settings else test_settings['report_file']
|
||||||
|
engine.setDailyReportName(daily_report_file)
|
||||||
|
|
||||||
|
# 使用简单复利模式计算
|
||||||
|
engine.usageCompounding = False # True时,只针对FINAL_MODE有效
|
||||||
|
|
||||||
|
# 启用实时计算净值模式REALTIME_MODE / FINAL_MODE 回测结束时统一计算模式
|
||||||
|
engine.calculateMode = engine.REALTIME_MODE
|
||||||
|
engine.capital = test_settings['initCapital'] # 设置期初资金
|
||||||
|
engine.initCapital = test_settings['initCapital'] # 设置期初资金
|
||||||
|
engine.avaliable = test_settings['initCapital'] # 设置期初资金
|
||||||
|
engine.netCapital = test_settings['initCapital']
|
||||||
|
engine.maxCapital = test_settings['initCapital'] # 设置期初资金
|
||||||
|
engine.maxNetCapital = test_settings['initCapital'] # 设置期初资金
|
||||||
|
engine.percentLimit = test_settings['percentLimit'] # 设置资金使用上限比例(%)
|
||||||
|
engine.barTimeInterval = 60 * test_settings['bar_interval'] # 回测文件中,bar的周期秒数,用于csv文件自动减时间
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 前置动作(无参数)
|
||||||
|
pre_functions = test_settings.get('pre_functions',[])
|
||||||
|
for fun_name in pre_functions:
|
||||||
|
try:
|
||||||
|
if not isinstance(fun_name,str):
|
||||||
|
continue
|
||||||
|
if hasattr(engine.strategy,fun_name):
|
||||||
|
fun = getattr(engine.strategy,fun_name)
|
||||||
|
if fun is not None:
|
||||||
|
fun()
|
||||||
|
except Exception as ex:
|
||||||
|
print(u'调用前置动作异常:{},{}'.format(str(ex),traceback.format_exc()),file=sys.stderr)
|
||||||
|
|
||||||
|
# 开始跑回测
|
||||||
|
if 'bar_file' in test_settings:
|
||||||
|
engine.runBackTestingWithBarFile(test_settings['bar_file'])
|
||||||
|
else:
|
||||||
|
engine.runBackTestingWithDataSource()
|
||||||
|
|
||||||
|
print('{}finished loop bars'.format(test_settings['name']))
|
||||||
|
# 显示回测结果
|
||||||
|
engine.showBacktestingResult()
|
||||||
|
|
||||||
|
# 保存策略得内部数据
|
||||||
|
engine.saveStrategyData()
|
||||||
|
|
||||||
|
print('{} finished'.format(test_settings['name']))
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception as ex:
|
||||||
|
print(u'single_strategy_test exception:{}'.format(str(ex)))
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
def multi_period_test(gid, group_setting):
|
||||||
|
"""
|
||||||
|
多周期回测品种组合
|
||||||
|
1、对group_settings进行分解,分解出各个运行周期的参数设置
|
||||||
|
2、逐一周期运行测试
|
||||||
|
3、添加回测结果
|
||||||
|
:param gid: 测试组名
|
||||||
|
:param group_setting:dict,包含参数,多周期清单
|
||||||
|
如果多周期,则对每一周期执行回测,并汇总结果。
|
||||||
|
:return 回测的每日净值统计文件
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 回测的分钟周期
|
||||||
|
minutes_interval_list = group_setting['minute_list'] # 回测分钟队列(3,5,10等)
|
||||||
|
|
||||||
|
strategyClass = group_setting['strategy']
|
||||||
|
|
||||||
|
# 测试批次时间
|
||||||
|
test_dt = datetime.now().strftime('%Y%m%d_%H%M')
|
||||||
|
|
||||||
|
# 回测结果队列,对应测试分钟队列
|
||||||
|
daily_results = []
|
||||||
|
return_results = []
|
||||||
|
|
||||||
|
# 启动多进程
|
||||||
|
pool = multiprocessing.Pool(multiprocessing.cpu_count())
|
||||||
|
l = []
|
||||||
|
|
||||||
|
# 逐一分钟级别进行回测
|
||||||
|
for m_i in minutes_interval_list:
|
||||||
|
settings = copy.copy(group_setting)
|
||||||
|
del settings['minute_list']
|
||||||
|
|
||||||
|
settings['name'] = '{}_{}_{}_M{}'.format(gid, strategyClass.className, group_setting['symbol'], m_i)
|
||||||
|
|
||||||
|
# 资金占用比例,根据组合内周期数量,进行平均分配
|
||||||
|
settings['percentLimit'] = group_setting['percentLimit']/ len(minutes_interval_list)
|
||||||
|
settings['vtSymbol'] = group_setting['symbol']
|
||||||
|
settings['MinInterval'] = m_i
|
||||||
|
|
||||||
|
settings['mode'] = 'bar'
|
||||||
|
settings['backtesting'] = True
|
||||||
|
settings['bar_interval'] = group_setting['bar_interval'] if 'bar_interval' in group_setting else 1
|
||||||
|
settings['strategy'] = strategyClass
|
||||||
|
# 回测报告文件保存路径: 组合,测试实例名称,测试时间
|
||||||
|
daily_report_file = os.path.abspath(os.path.join(group_setting['report_folder'], u'{}_daily_{}.csv'
|
||||||
|
.format(settings['name'], test_dt)))
|
||||||
|
|
||||||
|
settings['report_file'] = daily_report_file
|
||||||
|
|
||||||
|
#if rt:
|
||||||
|
# 回测报告集登记
|
||||||
|
daily_results.append({'test_item': 'M{}'.format(m_i), 'result_file': daily_report_file})
|
||||||
|
|
||||||
|
l.append(pool.apply_async(single_strategy_test, (settings,)))
|
||||||
|
#rt = single_strategy_test(test_settings=settings)
|
||||||
|
|
||||||
|
# 执行内存回收
|
||||||
|
gc.collect()
|
||||||
|
sleep(10)
|
||||||
|
|
||||||
|
result_list = [res.get() for res in l]
|
||||||
|
|
||||||
|
for idx, rt in enumerate(result_list):
|
||||||
|
if rt:
|
||||||
|
return_results.append(daily_results[idx])
|
||||||
|
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
||||||
|
return return_results
|
||||||
|
|
||||||
|
def run_multiperiod_test(gid, group_settings):
|
||||||
|
"""
|
||||||
|
运行多周期的组合测试
|
||||||
|
:param gid: 组合名称
|
||||||
|
:param group_settings:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
m = '_'.join(str(e) for e in group_settings['minute_list'])
|
||||||
|
if 'report_folder' in group_settings:
|
||||||
|
final_file = os.path.abspath(os.path.join(group_settings['report_folder'], '{}_{}_Report_{}.csv'.format(gid, group_settings['symbol'], m)))
|
||||||
|
else:
|
||||||
|
# 报告所在目录
|
||||||
|
report_folder = os.path.abspath(os.path.join(os.getcwd(), 'logs', gid))
|
||||||
|
if not os.path.exists(report_folder):
|
||||||
|
os.mkdir(report_folder)
|
||||||
|
# 汇总报告文件
|
||||||
|
final_file = os.path.abspath(
|
||||||
|
os.path.join(report_folder, '{}_{}_Report_{}.csv'.format(gid, group_settings['symbol'], m)))
|
||||||
|
group_settings['report_folder'] = report_folder
|
||||||
|
|
||||||
|
if not os.path.exists(group_settings['report_folder']):
|
||||||
|
os.makedirs(group_settings['report_folder'])
|
||||||
|
|
||||||
|
# 运行回测方法,统计结果
|
||||||
|
daily_results = multi_period_test(gid, group_settings)
|
||||||
|
|
||||||
|
# 统计结果
|
||||||
|
backtest_df = combine_results(daily_results)
|
||||||
|
|
||||||
|
# 保存汇总记录到文件
|
||||||
|
backtest_df.to_csv(final_file)
|
||||||
|
|
||||||
|
# 显示资金曲线汇总
|
||||||
|
fig, ax1 = plt.subplots()
|
||||||
|
|
||||||
|
period_columns = [item['test_item'] for item in daily_results]
|
||||||
|
|
||||||
|
fig.patch.set_facecolor('white')
|
||||||
|
ax1.plot(backtest_df[period_columns])
|
||||||
|
ax1.legend()
|
||||||
|
|
||||||
|
# 释放内存
|
||||||
|
l = [backtest_df]
|
||||||
|
del backtest_df
|
||||||
|
del l
|
||||||
|
|
||||||
|
# 释放内存
|
||||||
|
gc.collect()
|
||||||
|
print( 'finished run_multiparameter_test')
|
||||||
|
|
||||||
|
def multi_parameter_test(gid, settings_list):
|
||||||
|
"""
|
||||||
|
不同参数的回测组合
|
||||||
|
:param gid: 组合ID
|
||||||
|
:param settings_list: 参数列表
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# 回测结果队列,对应测试参数队列
|
||||||
|
daily_results = []
|
||||||
|
return_results = []
|
||||||
|
|
||||||
|
# 每个测试的参数名称
|
||||||
|
para_list = [i['paraName'] for i in settings_list]
|
||||||
|
|
||||||
|
if len(para_list) == 0:
|
||||||
|
return daily_results
|
||||||
|
|
||||||
|
# 启动多进程
|
||||||
|
pool = multiprocessing.Pool(multiprocessing.cpu_count())
|
||||||
|
#pool = multiprocessing.Pool(2)
|
||||||
|
l = []
|
||||||
|
|
||||||
|
print('multi_parameter_test,total:{}'.format(len(settings_list)))
|
||||||
|
|
||||||
|
for idx, strategy_settings in enumerate(settings_list):
|
||||||
|
settings = copy.copy(strategy_settings)
|
||||||
|
# 测试时间
|
||||||
|
test_dt = datetime.now().strftime('%Y%m%d_%H%M')
|
||||||
|
|
||||||
|
del settings['log_file']
|
||||||
|
if 'minute_list' in settings:
|
||||||
|
del settings['minute_list']
|
||||||
|
settings['vtSymbol'] = settings['symbol']
|
||||||
|
settings['mode'] = 'bar'
|
||||||
|
settings['backtesting'] = True
|
||||||
|
if 'bar_interval' not in settings:
|
||||||
|
settings['bar_interval'] = 1
|
||||||
|
|
||||||
|
# 回测报告文件保存路径: 组合,测试实例名称,测试时间
|
||||||
|
daily_report_file = os.path.abspath(os.path.join(settings['report_folder'], u'{}_daily_{}.csv'.format(settings['name'], test_dt)))
|
||||||
|
|
||||||
|
settings['report_file'] = daily_report_file
|
||||||
|
|
||||||
|
l.append(pool.apply_async(single_strategy_test, (settings,)))
|
||||||
|
|
||||||
|
# 回测报告集登记
|
||||||
|
daily_results.append({'test_item': settings['paraName'], 'result_file': daily_report_file})
|
||||||
|
|
||||||
|
# 执行内存回收
|
||||||
|
gc.collect()
|
||||||
|
sleep(10)
|
||||||
|
|
||||||
|
result_list = [res.get() for res in l]
|
||||||
|
|
||||||
|
# 返回结果是正确的,才添加到返回列表中
|
||||||
|
for idx, rt in enumerate(result_list):
|
||||||
|
if rt:
|
||||||
|
return_results.append(daily_results[idx])
|
||||||
|
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
||||||
|
|
||||||
|
return return_results
|
||||||
|
|
||||||
|
def run_multiparameter_test(gid, settings_list):
|
||||||
|
"""
|
||||||
|
多策略组测试
|
||||||
|
:param gid:
|
||||||
|
:param settings_list:
|
||||||
|
:return: """
|
||||||
|
if len(settings_list) == 0:
|
||||||
|
raise ReferenceError('Zero settings')
|
||||||
|
|
||||||
|
first_setting = settings_list[0]
|
||||||
|
|
||||||
|
paraNames = '_'.join(i['paraName'] for i in settings_list)
|
||||||
|
if 'report_folder' in first_setting:
|
||||||
|
report_folder = first_setting['report_folder']
|
||||||
|
final_file = os.path.abspath(os.path.join(first_setting['report_folder'],
|
||||||
|
'{}_{}_Report_{}.csv'.format(gid, first_setting['symbol'],
|
||||||
|
paraNames)))
|
||||||
|
else:
|
||||||
|
# 报告所在目录
|
||||||
|
report_folder = os.path.abspath(os.path.join(os.getcwd(), 'logs'))
|
||||||
|
final_file = os.path.abspath(
|
||||||
|
os.path.join(report_folder, '{}_{}_Report_{}.csv'.format(gid, first_setting['symbol'], paraNames)))
|
||||||
|
|
||||||
|
if not os.path.exists(report_folder):
|
||||||
|
os.makedirs(report_folder)
|
||||||
|
|
||||||
|
for settings in settings_list:
|
||||||
|
settings['report_folder'] = report_folder
|
||||||
|
|
||||||
|
# 运行回测方法,统计结果
|
||||||
|
daily_results = multi_parameter_test(gid, settings_list)
|
||||||
|
|
||||||
|
# 统计结果
|
||||||
|
#backtest_df = combine_results(daily_results)
|
||||||
|
|
||||||
|
# 保存汇总记录到文件
|
||||||
|
#backtest_df.to_csv(final_file)
|
||||||
|
|
||||||
|
# 显示资金曲线汇总
|
||||||
|
##fig, ax = plt.subplots()
|
||||||
|
|
||||||
|
#period_columns = [item['test_item'] for item in daily_results]
|
||||||
|
|
||||||
|
#fig.patch.set_facecolor('white')
|
||||||
|
|
||||||
|
#for column in period_columns:
|
||||||
|
# ax.plot(backtest_df[column], label=column)
|
||||||
|
|
||||||
|
#ax.legend()
|
||||||
|
|
||||||
|
#title = u'{} {}'.format(gid, paraNames)
|
||||||
|
#plt.title(title)
|
||||||
|
|
||||||
|
#fig.savefig(u'{}/rate.png'.format(report_folder))
|
||||||
|
|
||||||
|
## 释放内存
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
print( 'finished run_multiparameter_test')
|
||||||
|
|
||||||
|
def single_func(para):
|
||||||
|
logger=setup_logger('MyLog', name='my{}'.format(para))
|
||||||
|
if para > 5:
|
||||||
|
print( u'more than 5')
|
||||||
|
logger.info('More than 5')
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print ('less')
|
||||||
|
logger.info('Less than 5')
|
||||||
|
return False
|
||||||
|
|
||||||
|
def multi_func():
|
||||||
|
|
||||||
|
import logging
|
||||||
|
# 启动多进程
|
||||||
|
pool = multiprocessing.Pool(multiprocessing.cpu_count())
|
||||||
|
|
||||||
|
logger = setup_logger('MyLog')
|
||||||
|
|
||||||
|
logger.info('main process')
|
||||||
|
l = []
|
||||||
|
|
||||||
|
for i in range(0,10):
|
||||||
|
l.append(pool.apply_async(single_func,(i,)))
|
||||||
|
|
||||||
|
results = [res.get() for res in l]
|
||||||
|
|
||||||
|
pool.close()
|
||||||
|
pool.join()
|
@ -305,7 +305,7 @@ def start():
|
|||||||
# 往任务表增加定时计划
|
# 往任务表增加定时计划
|
||||||
operate_crontab("add")
|
operate_crontab("add")
|
||||||
# 执行启动
|
# 执行启动
|
||||||
_start()
|
#_start()
|
||||||
print(u'启动{}服务执行完毕'.format(base_path))
|
print(u'启动{}服务执行完毕'.format(base_path))
|
||||||
|
|
||||||
def _stop():
|
def _stop():
|
||||||
|
31
examples/VnTrader/CTA_setting.json
Normal file
31
examples/VnTrader/CTA_setting.json
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "Strategy_TripleMa_01_RB_Min5",
|
||||||
|
"comment": "螺纹05合约三均线策略",
|
||||||
|
"auto_init": true,
|
||||||
|
"auto_start": true,
|
||||||
|
"className": "Strategy_TripleMa_v01",
|
||||||
|
"inputSS": 1,
|
||||||
|
"minDiff": 1,
|
||||||
|
"mode": "tick",
|
||||||
|
"shortSymbol": "RB",
|
||||||
|
"strategy_module": "Strategy_TripleMa_v01",
|
||||||
|
"symbol": "RB1905",
|
||||||
|
"vtSymbol": "rb1905"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Strategy_TripleMa_02_HC_Min5",
|
||||||
|
"comment": "热卷05合约三均线策略",
|
||||||
|
"auto_init": true,
|
||||||
|
"auto_start": true,
|
||||||
|
"className": "Strategy_TripleMa_v02",
|
||||||
|
"inputSS": 1,
|
||||||
|
"minDiff": 1,
|
||||||
|
"mode": "tick",
|
||||||
|
"shortSymbol": "HC",
|
||||||
|
"strategy_module": "Strategy_TripleMa_v02",
|
||||||
|
"symbol": "HC1905",
|
||||||
|
"vtSymbol": "hc1905"
|
||||||
|
}
|
||||||
|
|
||||||
|
]
|
@ -23,9 +23,9 @@ from vnpy.trader.uiMainWindow import *
|
|||||||
# 加载底层接口
|
# 加载底层接口
|
||||||
from vnpy.trader.gateway import ctpGateway
|
from vnpy.trader.gateway import ctpGateway
|
||||||
# 初始化的接口模块,以及其指定的名称,CTP是模块,value,是该模块下的多个连接配置文件,如 CTP_JR_connect.json 'CTP_Prod', 'CTP_JR', , 'CTP_JK', 'CTP_02'
|
# 初始化的接口模块,以及其指定的名称,CTP是模块,value,是该模块下的多个连接配置文件,如 CTP_JR_connect.json 'CTP_Prod', 'CTP_JR', , 'CTP_JK', 'CTP_02'
|
||||||
init_gateway_names = {'CTP': ['CTP','CTP_001','CTP_002']}
|
init_gateway_names = {'CTP': ['CTP','CTP01','CTP02']}
|
||||||
|
|
||||||
from vnpy.trader.app import (ctaStrategy, riskManager, spreadTrading) #,dataRecorder
|
from vnpy.trader.app import (ctaStrategy, riskManager, spreadTrading,dataRecorder) #,
|
||||||
|
|
||||||
# 文件路径名
|
# 文件路径名
|
||||||
path = os.path.abspath(os.path.dirname(__file__))
|
path = os.path.abspath(os.path.dirname(__file__))
|
||||||
@ -60,7 +60,7 @@ def main():
|
|||||||
mainEngine.addApp(ctaStrategy)
|
mainEngine.addApp(ctaStrategy)
|
||||||
mainEngine.addApp(riskManager)
|
mainEngine.addApp(riskManager)
|
||||||
mainEngine.addApp(spreadTrading)
|
mainEngine.addApp(spreadTrading)
|
||||||
#mainEngine.addApp(dataRecorder)
|
mainEngine.addApp(dataRecorder)
|
||||||
|
|
||||||
mainWindow = MainWindow(mainEngine, ee)
|
mainWindow = MainWindow(mainEngine, ee)
|
||||||
mainWindow.showMaximized()
|
mainWindow.showMaximized()
|
||||||
|
@ -1432,6 +1432,7 @@ class CtaEngine(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
for expired_pos in expired_pos_list:
|
for expired_pos in expired_pos_list:
|
||||||
|
self.writeCtaLog(u'执行仓位处理:{}'.format(expired_pos))
|
||||||
if expired_pos['volume'] == 0:
|
if expired_pos['volume'] == 0:
|
||||||
self.writeCtaError(u'clear_dispatch_pos,pos 为空:{},删除'.format(expired_pos))
|
self.writeCtaError(u'clear_dispatch_pos,pos 为空:{},删除'.format(expired_pos))
|
||||||
flt = {'_id': expired_pos['_id']}
|
flt = {'_id': expired_pos['_id']}
|
||||||
@ -1509,7 +1510,7 @@ class CtaEngine(object):
|
|||||||
else:
|
else:
|
||||||
if curPos.longYd >= expired_pos['volume']:
|
if curPos.longYd >= expired_pos['volume']:
|
||||||
sell_longYd = expired_pos['volume']
|
sell_longYd = expired_pos['volume']
|
||||||
if curPos.longYd == 0:
|
elif curPos.longYd == 0:
|
||||||
sell_longToday = expired_pos['volume']
|
sell_longToday = expired_pos['volume']
|
||||||
else:
|
else:
|
||||||
sell_longYd = curPos.longYd
|
sell_longYd = curPos.longYd
|
||||||
|
@ -253,9 +253,13 @@ class TurtlePolicy(CtaPolicy):
|
|||||||
def __init__(self, strategy):
|
def __init__(self, strategy):
|
||||||
super(TurtlePolicy, self).__init__(strategy)
|
super(TurtlePolicy, self).__init__(strategy)
|
||||||
|
|
||||||
self.tns_open_price = 0 # 首次开仓价格
|
self.allow_add_pos = False # 是否加仓
|
||||||
self.last_open_price = 0 # 最后一次加仓价格
|
self.add_pos_on_pips = EMPTY_INT # 价格超过开仓价多少点时加仓
|
||||||
self.stop_price = 0 # 止损价
|
|
||||||
|
self.tns_open_price = 0 # 首次开仓价格
|
||||||
|
self.last_open_price = 0 # 最后一次加仓价格
|
||||||
|
self.stop_price = 0 # 止损价
|
||||||
|
self.exit_on_last_rtn_pips = 0 # 最高价/最低价回撤多少跳动
|
||||||
self.high_price_in_long = 0 # 多趋势时,最高价
|
self.high_price_in_long = 0 # 多趋势时,最高价
|
||||||
self.low_price_in_short = 0 # 空趋势时,最低价
|
self.low_price_in_short = 0 # 空趋势时,最低价
|
||||||
self.last_under_open_price = 0 # 低于首次开仓价的补仓价格
|
self.last_under_open_price = 0 # 低于首次开仓价的补仓价格
|
||||||
@ -279,12 +283,15 @@ class TurtlePolicy(CtaPolicy):
|
|||||||
j['tns_open_date'] = self.tns_open_date
|
j['tns_open_date'] = self.tns_open_date
|
||||||
j['tns_open_price'] = self.tns_open_price if self.tns_open_price is not None else 0
|
j['tns_open_price'] = self.tns_open_price if self.tns_open_price is not None else 0
|
||||||
|
|
||||||
|
j['allow_add_pos'] = self.allow_add_pos
|
||||||
|
j['add_pos_on_pips'] = self.add_pos_on_pips
|
||||||
|
j['exit_on_last_rtn_pips'] = self.exit_on_last_rtn_pips
|
||||||
|
|
||||||
j['last_open_price'] = self.last_open_price if self.last_open_price is not None else 0
|
j['last_open_price'] = self.last_open_price if self.last_open_price is not None else 0
|
||||||
j['stop_price'] = self.stop_price if self.stop_price is not None else 0
|
j['stop_price'] = self.stop_price if self.stop_price is not None else 0
|
||||||
j['high_price_in_long'] = self.high_price_in_long if self.high_price_in_long is not None else 0
|
j['high_price_in_long'] = self.high_price_in_long if self.high_price_in_long is not None else 0
|
||||||
j['low_price_in_short'] = self.low_price_in_short if self.low_price_in_short is not None else 0
|
j['low_price_in_short'] = self.low_price_in_short if self.low_price_in_short is not None else 0
|
||||||
j[
|
j['add_pos_count_under_first_price'] = self.add_pos_count_under_first_price if self.add_pos_count_under_first_price is not None else 0
|
||||||
'add_pos_count_under_first_price'] = self.add_pos_count_under_first_price if self.add_pos_count_under_first_price is not None else 0
|
|
||||||
j['last_under_open_price'] = self.last_under_open_price if self.last_under_open_price is not None else 0
|
j['last_under_open_price'] = self.last_under_open_price if self.last_under_open_price is not None else 0
|
||||||
|
|
||||||
j['max_pos'] = self.max_pos if self.max_pos is not None else 0
|
j['max_pos'] = self.max_pos if self.max_pos is not None else 0
|
||||||
@ -412,6 +419,9 @@ class TurtlePolicy(CtaPolicy):
|
|||||||
self.writeCtaError(u'解释last_risk_level异常:{}'.format(str(ex)))
|
self.writeCtaError(u'解释last_risk_level异常:{}'.format(str(ex)))
|
||||||
self.last_risk_level = 0
|
self.last_risk_level = 0
|
||||||
|
|
||||||
|
self.allow_add_pos=json_data.get('allow_add_pos',False)
|
||||||
|
self.add_pos_on_pips = json_data.get('add_pos_on_pips',1)
|
||||||
|
self.exit_on_last_rtn_pips = json_data.get('exit_on_last_rtn_pips',0)
|
||||||
|
|
||||||
def clean(self):
|
def clean(self):
|
||||||
"""
|
"""
|
||||||
@ -432,6 +442,9 @@ class TurtlePolicy(CtaPolicy):
|
|||||||
self.tns_has_opened = False
|
self.tns_has_opened = False
|
||||||
self.last_risk_level = 0
|
self.last_risk_level = 0
|
||||||
self.tns_count = 0
|
self.tns_count = 0
|
||||||
|
self.allow_add_pos = False
|
||||||
|
self.add_pos_on_pips = 1
|
||||||
|
self.exit_on_last_rtn_pips = 0
|
||||||
|
|
||||||
class TrendPolicy(CtaPolicy):
|
class TrendPolicy(CtaPolicy):
|
||||||
"""
|
"""
|
||||||
|
@ -14,6 +14,7 @@ SETTING_DB_NAME = 'VnTrader_Setting_Db'
|
|||||||
TICK_DB_NAME = 'VnTrader_Tick_Db'
|
TICK_DB_NAME = 'VnTrader_Tick_Db'
|
||||||
DAILY_DB_NAME = 'VnTrader_Daily_Db'
|
DAILY_DB_NAME = 'VnTrader_Daily_Db'
|
||||||
MINUTE_DB_NAME = 'VnTrader_1Min_Db'
|
MINUTE_DB_NAME = 'VnTrader_1Min_Db'
|
||||||
|
RENKO_DB_NAME = 'Renko_Db'
|
||||||
|
|
||||||
|
|
||||||
# CTA引擎中涉及的数据类定义
|
# CTA引擎中涉及的数据类定义
|
||||||
|
@ -14,12 +14,15 @@ from datetime import datetime, timedelta
|
|||||||
from queue import Queue
|
from queue import Queue
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
|
|
||||||
from vnpy.trader.vtEvent import *
|
from vnpy.trader.vtEvent import *
|
||||||
|
from vnpy.trader.vtObject import VtTickData
|
||||||
from vnpy.trader.vtGateway import VtSubscribeReq, VtLogData
|
from vnpy.trader.vtGateway import VtSubscribeReq, VtLogData
|
||||||
from vnpy.trader.vtFunction import todayDate,getJsonPath
|
from vnpy.trader.vtFunction import todayDate,getJsonPath,getShortSymbol
|
||||||
|
from vnpy.trader.app.ctaStrategy.ctaRenkoBar import CtaRenkoBar
|
||||||
from .drBase import *
|
from .drBase import *
|
||||||
|
from vnpy.trader.setup_logger import setup_logger
|
||||||
|
from vnpy.trader.data_source import DataSource
|
||||||
########################################################################
|
########################################################################
|
||||||
class DrEngine(object):
|
class DrEngine(object):
|
||||||
"""数据记录引擎"""
|
"""数据记录引擎"""
|
||||||
@ -44,15 +47,39 @@ class DrEngine(object):
|
|||||||
|
|
||||||
# K线对象字典
|
# K线对象字典
|
||||||
self.barDict = {}
|
self.barDict = {}
|
||||||
|
|
||||||
|
# Renko K线对象字典
|
||||||
|
self.renkoDict = {}
|
||||||
|
|
||||||
# 负责执行数据库插入的单独线程相关
|
# 负责执行数据库插入的单独线程相关
|
||||||
self.active = False # 工作状态
|
self.active = False # 工作状态
|
||||||
self.queue = Queue() # 队列
|
self.queue = Queue() # 队列
|
||||||
self.thread = Thread(target=self.run) # 线程
|
self.thread = Thread(target=self.run) # 线程
|
||||||
|
|
||||||
|
self.logger = None
|
||||||
|
self.createLogger()
|
||||||
# 载入设置,订阅行情
|
# 载入设置,订阅行情
|
||||||
self.loadSetting()
|
self.loadSetting()
|
||||||
|
|
||||||
|
def createLogger(self):
|
||||||
|
"""
|
||||||
|
创建日志记录
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
currentFolder = os.path.abspath(os.path.join(os.getcwd(), 'logs'))
|
||||||
|
if os.path.isdir(currentFolder):
|
||||||
|
# 如果工作目录下,存在data子目录,就使用data子目录
|
||||||
|
path = currentFolder
|
||||||
|
else:
|
||||||
|
# 否则,使用缺省保存目录 vnpy/trader/app/ctaStrategy/data
|
||||||
|
path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'logs'))
|
||||||
|
|
||||||
|
if self.logger is None:
|
||||||
|
filename = os.path.abspath(os.path.join(path, 'drEngine'))
|
||||||
|
|
||||||
|
print(u'create logger:{}'.format(filename))
|
||||||
|
self.logger = setup_logger(filename=filename, name='drEngine', debug=True)
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def loadSetting(self):
|
def loadSetting(self):
|
||||||
"""载入设置"""
|
"""载入设置"""
|
||||||
@ -111,7 +138,43 @@ class DrEngine(object):
|
|||||||
|
|
||||||
bar = DrBarData()
|
bar = DrBarData()
|
||||||
self.barDict[vtSymbol] = bar
|
self.barDict[vtSymbol] = bar
|
||||||
|
|
||||||
|
if 'renko' in drSetting:
|
||||||
|
l = drSetting.get('renko')
|
||||||
|
req_set = set()
|
||||||
|
for setting in l:
|
||||||
|
# 获取合约,合约短号,renko的高度(多少个跳)
|
||||||
|
vtSymbol = setting.get('vtSymbol',None)
|
||||||
|
if vtSymbol is None:
|
||||||
|
continue
|
||||||
|
short_symbol = getShortSymbol(vtSymbol).upper()
|
||||||
|
height = setting.get('height',5)
|
||||||
|
minDiff = setting.get('minDiff',1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 获取vtSymbol的多个renkobar列表,添加新的CtaRenkoBar
|
||||||
|
renko_list = self.renkoDict.get(vtSymbol,[])
|
||||||
|
|
||||||
|
bar_setting = {'name':'{}_{}'.format(vtSymbol,height),
|
||||||
|
'shortSymbol':short_symbol,
|
||||||
|
'vtSymbol':vtSymbol,
|
||||||
|
'minDiff':minDiff,
|
||||||
|
'height':minDiff*height}
|
||||||
|
renko_bar = CtaRenkoBar(strategy=None,onBarFunc=self.onRenkoBar,setting=bar_setting)
|
||||||
|
renko_list.append(renko_bar)
|
||||||
|
self.renkoDict.update({vtSymbol:renko_list})
|
||||||
|
|
||||||
|
req = VtSubscribeReq()
|
||||||
|
req.symbol = vtSymbol
|
||||||
|
req_set.add((req,setting.get('gateway',None)))
|
||||||
|
# 更新合约的历史数据
|
||||||
|
self.add_gap_ticks()
|
||||||
|
|
||||||
|
# 订阅行情
|
||||||
|
for req,gw in req_set:
|
||||||
|
self.mainEngine.subscribe(req,gw)
|
||||||
|
|
||||||
if 'active' in drSetting:
|
if 'active' in drSetting:
|
||||||
d = drSetting['active']
|
d = drSetting['active']
|
||||||
|
|
||||||
@ -125,7 +188,7 @@ class DrEngine(object):
|
|||||||
# 注册事件监听
|
# 注册事件监听
|
||||||
self.registerEvent()
|
self.registerEvent()
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
# ----------------------------------------------------------------------
|
||||||
def procecssTickEvent(self, event):
|
def procecssTickEvent(self, event):
|
||||||
"""处理行情推送"""
|
"""处理行情推送"""
|
||||||
tick = event.dict_['data']
|
tick = event.dict_['data']
|
||||||
@ -137,7 +200,7 @@ class DrEngine(object):
|
|||||||
for key in d.keys():
|
for key in d.keys():
|
||||||
if key != 'datetime':
|
if key != 'datetime':
|
||||||
d[key] = tick.__getattribute__(key)
|
d[key] = tick.__getattribute__(key)
|
||||||
drTick.datetime = datetime.strptime(' '.join([tick.date, tick.time]), '%Y%m%d %H:%M:%S.%f')
|
drTick.datetime = datetime.strptime(' '.join([tick.date, tick.time]), '%Y-%m-%d %H:%M:%S.%f')
|
||||||
|
|
||||||
# 更新Tick数据
|
# 更新Tick数据
|
||||||
if vtSymbol in self.tickDict:
|
if vtSymbol in self.tickDict:
|
||||||
@ -190,6 +253,111 @@ class DrEngine(object):
|
|||||||
bar.close = drTick.lastPrice
|
bar.close = drTick.lastPrice
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
|
|
||||||
|
# 更新Renko数据
|
||||||
|
for renko_bar in self.renkoDict.get(vtSymbol,[]):
|
||||||
|
renko_bar.onTick(copy.copy(tick))
|
||||||
|
|
||||||
|
def add_gap_ticks(self):
|
||||||
|
"""
|
||||||
|
补充缺失的分时数据
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
ds = DataSource()
|
||||||
|
for vtSymbol in self.renkoDict.keys():
|
||||||
|
renkobar_list = self.renkoDict.get(vtSymbol,[])
|
||||||
|
cache_bars = None
|
||||||
|
cache_start_date = None
|
||||||
|
cache_end_date = None
|
||||||
|
for renko_bar in renkobar_list:
|
||||||
|
# 通过mongo获取最新一个Renko bar的数据日期close时间
|
||||||
|
last_renko_dt = self.get_last_datetime(renko_bar.name)
|
||||||
|
|
||||||
|
# 根据日期+vtSymbol,像datasource获取分钟数据,以close价格,转化为tick,推送到renko_bar中
|
||||||
|
if last_renko_dt is not None:
|
||||||
|
start_date =last_renko_dt.strftime('%Y-%m-%d')
|
||||||
|
else:
|
||||||
|
start_date = (datetime.now() - timedelta(days=90)).strftime('%Y-%m-%d')
|
||||||
|
|
||||||
|
end_date = (datetime.now() + timedelta(days=5)).strftime('%Y-%m-%d')
|
||||||
|
self.writeDrLog(u'从datasource获取{}数据,开始日期:{}'.format(vtSymbol,start_date))
|
||||||
|
if cache_bars is None or cache_start_date!=start_date or cache_end_date!=end_date:
|
||||||
|
fields = ['open', 'close', 'high', 'low', 'volume', 'open_interest', 'limit_up', 'limit_down',
|
||||||
|
'trading_date']
|
||||||
|
cache_bars = ds.get_price(order_book_id=vtSymbol,start_date=start_date,end_date=end_date,
|
||||||
|
frequency='1m', fields=fields)
|
||||||
|
cache_start_date = start_date
|
||||||
|
cache_end_date = end_date
|
||||||
|
if cache_bars is not None:
|
||||||
|
total = len(cache_bars)
|
||||||
|
self.writeDrLog(u'一共获取{}条{} 1分钟数据'.format(total,vtSymbol))
|
||||||
|
|
||||||
|
if cache_bars is not None:
|
||||||
|
self.writeDrLog(u'推送分时数据tick:{}到:{}'.format(vtSymbol,renko_bar.name))
|
||||||
|
for idx in cache_bars.index:
|
||||||
|
row = cache_bars.loc[idx]
|
||||||
|
tick = VtTickData()
|
||||||
|
tick.vtSymbol = vtSymbol
|
||||||
|
tick.symbol = vtSymbol
|
||||||
|
last_bar_dt = datetime.strptime(str(idx), '%Y-%m-%d %H:%M:00')
|
||||||
|
tick.datetime = last_bar_dt - timedelta(minutes=1)
|
||||||
|
tick.date = tick.datetime.strftime('%Y-%m-%d')
|
||||||
|
tick.time = tick.datetime.strftime('%H:%M:00')
|
||||||
|
|
||||||
|
if tick.datetime.hour >= 21:
|
||||||
|
if tick.datetime.isoweekday() == 5:
|
||||||
|
# 星期五=》星期一
|
||||||
|
tick.tradingDay = (tick.datetime + timedelta(days=3)).strftime('%Y-%m-%d')
|
||||||
|
else:
|
||||||
|
# 第二天
|
||||||
|
tick.tradingDay = (tick.datetime + timedelta(days=1)).strftime('%Y-%m-%d')
|
||||||
|
elif tick.datetime.hour < 8 and tick.datetime.isoweekday() == 6:
|
||||||
|
# 星期六=>星期一
|
||||||
|
tick.tradingDay = (tick.datetime + timedelta(days=2)).strftime('%Y-%m-%d')
|
||||||
|
else:
|
||||||
|
tick.tradingDay = tick.date
|
||||||
|
tick.upperLimit = float(row['limit_up'])
|
||||||
|
tick.lowerLimit = float(row['limit_down'])
|
||||||
|
tick.lastPrice = float(row['close'])
|
||||||
|
tick.askPrice1 = float(row['close'])
|
||||||
|
tick.bidPrice1 = float(row['close'])
|
||||||
|
tick.volume = int(row['volume'])
|
||||||
|
tick.askVolume1 = tick.volume
|
||||||
|
tick.bidVolume1 = tick.volume
|
||||||
|
|
||||||
|
if last_renko_dt is not None and tick.datetime <= last_renko_dt:
|
||||||
|
continue
|
||||||
|
renko_bar.onTick(tick)
|
||||||
|
|
||||||
|
def get_last_datetime(self,renko_name):
|
||||||
|
"""
|
||||||
|
通过mongo获取最新一个bar的数据日期
|
||||||
|
:param renko_name:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
qryData = self.mainEngine.dbQueryBySort(dbName=RENKO_DB_NAME,
|
||||||
|
collectionName=renko_name,
|
||||||
|
d={},
|
||||||
|
sortName='datetime',
|
||||||
|
sortType=-1,
|
||||||
|
limitNum=1)
|
||||||
|
|
||||||
|
last_renko_open_dt =None
|
||||||
|
last_renko_close_dt=None
|
||||||
|
for d in qryData:
|
||||||
|
last_renko_open_dt = d.get('datetime',None)
|
||||||
|
if last_renko_open_dt is not None:
|
||||||
|
last_renko_close_dt = last_renko_open_dt + timedelta(seconds=d.get('seconds',0))
|
||||||
|
|
||||||
|
break
|
||||||
|
return last_renko_close_dt
|
||||||
|
|
||||||
|
def onRenkoBar(self,bar,bar_name):
|
||||||
|
newBar = copy.copy(bar)
|
||||||
|
self.insertData(RENKO_DB_NAME, bar_name, newBar)
|
||||||
|
self.writeDrLog(u'new Renko Bar:{},dt:{},open:{},close:{},high:{},low:{}'
|
||||||
|
.format(bar_name,bar.datetime,bar.open,bar.close, bar.high, bar.low))
|
||||||
|
|
||||||
def registerEvent(self):
|
def registerEvent(self):
|
||||||
"""注册事件监听"""
|
"""注册事件监听"""
|
||||||
self.eventEngine.register(EVENT_TICK, self.procecssTickEvent)
|
self.eventEngine.register(EVENT_TICK, self.procecssTickEvent)
|
||||||
@ -206,7 +374,7 @@ class DrEngine(object):
|
|||||||
try:
|
try:
|
||||||
dbName, collectionName, d = self.queue.get(block=True, timeout=1)
|
dbName, collectionName, d = self.queue.get(block=True, timeout=1)
|
||||||
self.mainEngine.dbInsert(dbName, collectionName, d)
|
self.mainEngine.dbInsert(dbName, collectionName, d)
|
||||||
except Empty:
|
except Exception as ex:
|
||||||
pass
|
pass
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def start(self):
|
def start(self):
|
||||||
@ -228,5 +396,7 @@ class DrEngine(object):
|
|||||||
log.logContent = content
|
log.logContent = content
|
||||||
event = Event(type_=EVENT_DATARECORDER_LOG)
|
event = Event(type_=EVENT_DATARECORDER_LOG)
|
||||||
event.dict_['data'] = log
|
event.dict_['data'] = log
|
||||||
self.eventEngine.put(event)
|
self.eventEngine.put(event)
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.info(content)
|
||||||
|
@ -29,7 +29,7 @@ class TableCell(QtWidgets.QTableWidgetItem):
|
|||||||
if text == '0' or text == '0.0':
|
if text == '0' or text == '0.0':
|
||||||
self.setText('')
|
self.setText('')
|
||||||
else:
|
else:
|
||||||
self.setText(text)
|
self.setText(str(text))
|
||||||
|
|
||||||
|
|
||||||
########################################################################
|
########################################################################
|
||||||
@ -60,7 +60,7 @@ class DrEngineManager(QtWidgets.QWidget):
|
|||||||
self.tickTable.setColumnCount(2)
|
self.tickTable.setColumnCount(2)
|
||||||
self.tickTable.verticalHeader().setVisible(False)
|
self.tickTable.verticalHeader().setVisible(False)
|
||||||
self.tickTable.setEditTriggers(QtWidgets.QTableWidget.NoEditTriggers)
|
self.tickTable.setEditTriggers(QtWidgets.QTableWidget.NoEditTriggers)
|
||||||
self.tickTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch)
|
#self.tickTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch)
|
||||||
self.tickTable.setAlternatingRowColors(True)
|
self.tickTable.setAlternatingRowColors(True)
|
||||||
self.tickTable.setHorizontalHeaderLabels([u'合约代码', u'接口'])
|
self.tickTable.setHorizontalHeaderLabels([u'合约代码', u'接口'])
|
||||||
|
|
||||||
@ -69,16 +69,25 @@ class DrEngineManager(QtWidgets.QWidget):
|
|||||||
self.barTable.setColumnCount(2)
|
self.barTable.setColumnCount(2)
|
||||||
self.barTable.verticalHeader().setVisible(False)
|
self.barTable.verticalHeader().setVisible(False)
|
||||||
self.barTable.setEditTriggers(QtWidgets.QTableWidget.NoEditTriggers)
|
self.barTable.setEditTriggers(QtWidgets.QTableWidget.NoEditTriggers)
|
||||||
self.barTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch)
|
#self.barTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch)
|
||||||
self.barTable.setAlternatingRowColors(True)
|
self.barTable.setAlternatingRowColors(True)
|
||||||
self.barTable.setHorizontalHeaderLabels([u'合约代码', u'接口'])
|
self.barTable.setHorizontalHeaderLabels([u'合约代码', u'接口'])
|
||||||
|
|
||||||
|
renkoLabel = QtWidgets.QLabel(u'RenkoBar')
|
||||||
|
self.renkoTable = QtWidgets.QTableWidget()
|
||||||
|
self.renkoTable.setColumnCount(2)
|
||||||
|
self.renkoTable.verticalHeader().setVisible(False)
|
||||||
|
self.renkoTable.setEditTriggers(QtWidgets.QTableWidget.NoEditTriggers)
|
||||||
|
# self.renkoTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch)
|
||||||
|
self.renkoTable.setAlternatingRowColors(True)
|
||||||
|
self.renkoTable.setHorizontalHeaderLabels([u'合约代码', u'高度'])
|
||||||
|
|
||||||
activeLabel = QtWidgets.QLabel(u'主力合约')
|
activeLabel = QtWidgets.QLabel(u'主力合约')
|
||||||
self.activeTable = QtWidgets.QTableWidget()
|
self.activeTable = QtWidgets.QTableWidget()
|
||||||
self.activeTable.setColumnCount(2)
|
self.activeTable.setColumnCount(2)
|
||||||
self.activeTable.verticalHeader().setVisible(False)
|
self.activeTable.verticalHeader().setVisible(False)
|
||||||
self.activeTable.setEditTriggers(QtWidgets.QTableWidget.NoEditTriggers)
|
self.activeTable.setEditTriggers(QtWidgets.QTableWidget.NoEditTriggers)
|
||||||
self.activeTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch)
|
#self.activeTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch)
|
||||||
self.activeTable.setAlternatingRowColors(True)
|
self.activeTable.setAlternatingRowColors(True)
|
||||||
self.activeTable.setHorizontalHeaderLabels([u'主力代码', u'合约代码'])
|
self.activeTable.setHorizontalHeaderLabels([u'主力代码', u'合约代码'])
|
||||||
|
|
||||||
@ -92,10 +101,13 @@ class DrEngineManager(QtWidgets.QWidget):
|
|||||||
|
|
||||||
grid.addWidget(tickLabel, 0, 0)
|
grid.addWidget(tickLabel, 0, 0)
|
||||||
grid.addWidget(barLabel, 0, 1)
|
grid.addWidget(barLabel, 0, 1)
|
||||||
grid.addWidget(activeLabel, 0, 2)
|
grid.addWidget(renkoLabel, 0, 2)
|
||||||
|
grid.addWidget(activeLabel, 0, 3)
|
||||||
|
|
||||||
grid.addWidget(self.tickTable, 1, 0)
|
grid.addWidget(self.tickTable, 1, 0)
|
||||||
grid.addWidget(self.barTable, 1, 1)
|
grid.addWidget(self.barTable, 1, 1)
|
||||||
grid.addWidget(self.activeTable, 1, 2)
|
grid.addWidget(self.renkoTable, 1, 2)
|
||||||
|
grid.addWidget(self.activeTable, 1, 3)
|
||||||
|
|
||||||
vbox = QtWidgets.QVBoxLayout()
|
vbox = QtWidgets.QVBoxLayout()
|
||||||
vbox.addLayout(grid)
|
vbox.addLayout(grid)
|
||||||
@ -118,7 +130,7 @@ class DrEngineManager(QtWidgets.QWidget):
|
|||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def updateSetting(self):
|
def updateSetting(self):
|
||||||
"""显示引擎行情记录配置"""
|
"""显示引擎行情记录配置"""
|
||||||
with open(self.drEngine.settingFileName) as f:
|
with open(self.drEngine.settingFilePath) as f:
|
||||||
drSetting = json.load(f)
|
drSetting = json.load(f)
|
||||||
|
|
||||||
if 'tick' in drSetting:
|
if 'tick' in drSetting:
|
||||||
@ -135,8 +147,15 @@ class DrEngineManager(QtWidgets.QWidget):
|
|||||||
for setting in l:
|
for setting in l:
|
||||||
self.barTable.insertRow(0)
|
self.barTable.insertRow(0)
|
||||||
self.barTable.setItem(0, 0, TableCell(setting[0]))
|
self.barTable.setItem(0, 0, TableCell(setting[0]))
|
||||||
self.barTable.setItem(0, 1, TableCell(setting[1]))
|
self.barTable.setItem(0, 1, TableCell(setting[1]))
|
||||||
|
|
||||||
|
if 'renko' in drSetting:
|
||||||
|
l = drSetting['renko']
|
||||||
|
|
||||||
|
for setting in l:
|
||||||
|
self.renkoTable.insertRow(0)
|
||||||
|
self.renkoTable.setItem(0, 0, TableCell(setting.get('vtSymbol','')))
|
||||||
|
self.renkoTable.setItem(0, 1, TableCell(setting.get('height',0)))
|
||||||
if 'active' in drSetting:
|
if 'active' in drSetting:
|
||||||
d = drSetting['active']
|
d = drSetting['active']
|
||||||
|
|
||||||
|
@ -244,7 +244,8 @@ class MultiprocessHandler(logging.FileHandler):
|
|||||||
if not self.suffix:
|
if not self.suffix:
|
||||||
raise ValueError(u"指定的日期间隔单位无效: %s" % self.when)
|
raise ValueError(u"指定的日期间隔单位无效: %s" % self.when)
|
||||||
#拼接文件路径 格式化字符串
|
#拼接文件路径 格式化字符串
|
||||||
self.filefmt = "%s_%s.log" % (self.prefix,self.suffix)
|
#self.filefmt = "%s_%s.log" % (self.prefix,self.suffix)
|
||||||
|
self.filefmt = u'{}_{}.log'.format(self.prefix, self.suffix)
|
||||||
#使用当前时间,格式化文件格式化字符串
|
#使用当前时间,格式化文件格式化字符串
|
||||||
self.filePath = datetime.now().strftime(self.filefmt)
|
self.filePath = datetime.now().strftime(self.filefmt)
|
||||||
#获得文件夹路径
|
#获得文件夹路径
|
||||||
|
@ -38,6 +38,10 @@ EVENT_NOTIFICATION = 'eNotification' # 全局通知
|
|||||||
EVENT_SIGNAL = 'eSignal' # 信号通知
|
EVENT_SIGNAL = 'eSignal' # 信号通知
|
||||||
EVENT_STATUS = 'eStatus' # 服务状态
|
EVENT_STATUS = 'eStatus' # 服务状态
|
||||||
|
|
||||||
|
# 股票使用
|
||||||
|
EVENT_BAR = 'eBar' # 1分钟Bar 行情
|
||||||
|
EVENT_BARDICT = 'eBarDict_' # BarDict事件+策略实例名称
|
||||||
|
|
||||||
# CTA模块相关
|
# CTA模块相关
|
||||||
EVENT_CTA_LOG = 'eCtaLog' # CTA相关的日志事件
|
EVENT_CTA_LOG = 'eCtaLog' # CTA相关的日志事件
|
||||||
EVENT_CTA_STRATEGY = 'eCtaStrategy.' # CTA策略状态变化事件
|
EVENT_CTA_STRATEGY = 'eCtaStrategy.' # CTA策略状态变化事件
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# encoding: UTF-8
|
# encoding: UTF-8
|
||||||
|
|
||||||
import time,os,sys
|
import time,os,sys,copy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from vnpy.trader.vtEvent import *
|
from vnpy.trader.vtEvent import *
|
||||||
@ -23,6 +23,9 @@ class VtGateway(object):
|
|||||||
self.accountID = 'AccountID'
|
self.accountID = 'AccountID'
|
||||||
self.createLogger()
|
self.createLogger()
|
||||||
|
|
||||||
|
# 所有订阅onBar的都会添加
|
||||||
|
self.klines = {}
|
||||||
|
|
||||||
# ----------------------------------------------------------------------
|
# ----------------------------------------------------------------------
|
||||||
def onTick(self, tick):
|
def onTick(self, tick):
|
||||||
"""市场行情推送"""
|
"""市场行情推送"""
|
||||||
@ -35,7 +38,14 @@ class VtGateway(object):
|
|||||||
event2 = Event(type_=EVENT_TICK+tick.vtSymbol)
|
event2 = Event(type_=EVENT_TICK+tick.vtSymbol)
|
||||||
event2.dict_['data'] = tick
|
event2.dict_['data'] = tick
|
||||||
self.eventEngine.put(event2)
|
self.eventEngine.put(event2)
|
||||||
|
|
||||||
|
def onBar(self,bar,type=EVENT_BAR):
|
||||||
|
"""市场行情推送"""
|
||||||
|
# bar, 或者 barDict
|
||||||
|
event = Event(type_=type)
|
||||||
|
event.dict_['data'] = bar
|
||||||
|
self.eventEngine.put(event)
|
||||||
|
|
||||||
# ----------------------------------------------------------------------
|
# ----------------------------------------------------------------------
|
||||||
def onTrade(self, trade):
|
def onTrade(self, trade):
|
||||||
"""成交信息推送"""
|
"""成交信息推送"""
|
||||||
|
283
vnpy/trader/vtUtility.py
Normal file
283
vnpy/trader/vtUtility.py
Normal file
@ -0,0 +1,283 @@
|
|||||||
|
# encoding: UTF-8
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import talib
|
||||||
|
|
||||||
|
from vnpy.trader.vtObject import VtBarData
|
||||||
|
|
||||||
|
|
||||||
|
########################################################################
|
||||||
|
class BarGenerator(object):
|
||||||
|
"""
|
||||||
|
K线合成器,支持:
|
||||||
|
1. 基于Tick合成1分钟K线
|
||||||
|
2. 基于1分钟K线合成X分钟K线(X可以是2、3、5、10、15、30 )
|
||||||
|
"""
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def __init__(self, onBar, xmin=0, onXminBar=None):
|
||||||
|
"""Constructor"""
|
||||||
|
self.bar = None # 1分钟K线对象
|
||||||
|
self.onBar = onBar # 1分钟K线回调函数
|
||||||
|
|
||||||
|
self.xminBar = None # X分钟K线对象
|
||||||
|
self.xmin = xmin # X的值
|
||||||
|
self.onXminBar = onXminBar # X分钟K线的回调函数
|
||||||
|
|
||||||
|
self.lastTick = None # 上一TICK缓存对象
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def updateTick(self, tick):
|
||||||
|
"""TICK更新"""
|
||||||
|
newMinute = False # 默认不是新的一分钟
|
||||||
|
|
||||||
|
# 尚未创建对象
|
||||||
|
if not self.bar:
|
||||||
|
self.bar = VtBarData()
|
||||||
|
newMinute = True
|
||||||
|
# 新的一分钟
|
||||||
|
elif self.bar.datetime.minute != tick.datetime.minute:
|
||||||
|
# 生成上一分钟K线的时间戳
|
||||||
|
self.bar.datetime = self.bar.datetime.replace(second=0, microsecond=0) # 将秒和微秒设为0
|
||||||
|
self.bar.date = self.bar.datetime.strftime('%Y%m%d')
|
||||||
|
self.bar.time = self.bar.datetime.strftime('%H:%M:%S.%f')
|
||||||
|
|
||||||
|
# 推送已经结束的上一分钟K线
|
||||||
|
self.onBar(self.bar)
|
||||||
|
|
||||||
|
# 创建新的K线对象
|
||||||
|
self.bar = VtBarData()
|
||||||
|
newMinute = True
|
||||||
|
|
||||||
|
# 初始化新一分钟的K线数据
|
||||||
|
if newMinute:
|
||||||
|
self.bar.vtSymbol = tick.vtSymbol
|
||||||
|
self.bar.symbol = tick.symbol
|
||||||
|
self.bar.exchange = tick.exchange
|
||||||
|
|
||||||
|
self.bar.open = tick.lastPrice
|
||||||
|
self.bar.high = tick.lastPrice
|
||||||
|
self.bar.low = tick.lastPrice
|
||||||
|
# 累加更新老一分钟的K线数据
|
||||||
|
else:
|
||||||
|
self.bar.high = max(self.bar.high, tick.lastPrice)
|
||||||
|
self.bar.low = min(self.bar.low, tick.lastPrice)
|
||||||
|
|
||||||
|
# 通用更新部分
|
||||||
|
self.bar.close = tick.lastPrice
|
||||||
|
self.bar.datetime = tick.datetime
|
||||||
|
self.bar.openInterest = tick.openInterest
|
||||||
|
|
||||||
|
if self.lastTick:
|
||||||
|
volumeChange = tick.volume - self.lastTick.volume # 当前K线内的成交量
|
||||||
|
self.bar.volume += max(volumeChange, 0) # 避免夜盘开盘lastTick.volume为昨日收盘数据,导致成交量变化为负的情况
|
||||||
|
|
||||||
|
# 缓存Tick
|
||||||
|
self.lastTick = tick
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def updateBar(self, bar):
|
||||||
|
"""1分钟K线更新"""
|
||||||
|
# 尚未创建对象
|
||||||
|
if not self.xminBar:
|
||||||
|
self.xminBar = VtBarData()
|
||||||
|
|
||||||
|
self.xminBar.vtSymbol = bar.vtSymbol
|
||||||
|
self.xminBar.symbol = bar.symbol
|
||||||
|
self.xminBar.exchange = bar.exchange
|
||||||
|
|
||||||
|
self.xminBar.open = bar.open
|
||||||
|
self.xminBar.high = bar.high
|
||||||
|
self.xminBar.low = bar.low
|
||||||
|
|
||||||
|
self.xminBar.datetime = bar.datetime # 以第一根分钟K线的开始时间戳作为X分钟线的时间戳
|
||||||
|
# 累加老K线
|
||||||
|
else:
|
||||||
|
self.xminBar.high = max(self.xminBar.high, bar.high)
|
||||||
|
self.xminBar.low = min(self.xminBar.low, bar.low)
|
||||||
|
|
||||||
|
# 通用部分
|
||||||
|
self.xminBar.close = bar.close
|
||||||
|
self.xminBar.openInterest = bar.openInterest
|
||||||
|
self.xminBar.volume += int(bar.volume)
|
||||||
|
|
||||||
|
# X分钟已经走完
|
||||||
|
if not (bar.datetime.minute + 1) % self.xmin: # 可以用X整除
|
||||||
|
# 生成上一X分钟K线的时间戳
|
||||||
|
self.xminBar.datetime = self.xminBar.datetime.replace(second=0, microsecond=0) # 将秒和微秒设为0
|
||||||
|
self.xminBar.date = self.xminBar.datetime.strftime('%Y%m%d')
|
||||||
|
self.xminBar.time = self.xminBar.datetime.strftime('%H:%M:%S.%f')
|
||||||
|
|
||||||
|
# 推送
|
||||||
|
self.onXminBar(self.xminBar)
|
||||||
|
|
||||||
|
# 清空老K线缓存对象
|
||||||
|
self.xminBar = None
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def generate(self):
|
||||||
|
"""手动强制立即完成K线合成"""
|
||||||
|
self.onBar(self.bar)
|
||||||
|
self.bar = None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
########################################################################
|
||||||
|
class ArrayManager(object):
|
||||||
|
"""
|
||||||
|
K线序列管理工具,负责:
|
||||||
|
1. K线时间序列的维护
|
||||||
|
2. 常用技术指标的计算
|
||||||
|
"""
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def __init__(self, size=100):
|
||||||
|
"""Constructor"""
|
||||||
|
self.count = 0 # 缓存计数
|
||||||
|
self.size = size # 缓存大小
|
||||||
|
self.inited = False # True if count>=size
|
||||||
|
|
||||||
|
self.openArray = np.zeros(size) # OHLC
|
||||||
|
self.highArray = np.zeros(size)
|
||||||
|
self.lowArray = np.zeros(size)
|
||||||
|
self.closeArray = np.zeros(size)
|
||||||
|
self.volumeArray = np.zeros(size)
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def updateBar(self, bar):
|
||||||
|
"""更新K线"""
|
||||||
|
self.count += 1
|
||||||
|
if not self.inited and self.count >= self.size:
|
||||||
|
self.inited = True
|
||||||
|
|
||||||
|
self.openArray[:-1] = self.openArray[1:]
|
||||||
|
self.highArray[:-1] = self.highArray[1:]
|
||||||
|
self.lowArray[:-1] = self.lowArray[1:]
|
||||||
|
self.closeArray[:-1] = self.closeArray[1:]
|
||||||
|
self.volumeArray[:-1] = self.volumeArray[1:]
|
||||||
|
|
||||||
|
self.openArray[-1] = bar.open
|
||||||
|
self.highArray[-1] = bar.high
|
||||||
|
self.lowArray[-1] = bar.low
|
||||||
|
self.closeArray[-1] = bar.close
|
||||||
|
self.volumeArray[-1] = bar.volume
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
@property
|
||||||
|
def open(self):
|
||||||
|
"""获取开盘价序列"""
|
||||||
|
return self.openArray
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
@property
|
||||||
|
def high(self):
|
||||||
|
"""获取最高价序列"""
|
||||||
|
return self.highArray
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
@property
|
||||||
|
def low(self):
|
||||||
|
"""获取最低价序列"""
|
||||||
|
return self.lowArray
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
@property
|
||||||
|
def close(self):
|
||||||
|
"""获取收盘价序列"""
|
||||||
|
return self.closeArray
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
@property
|
||||||
|
def volume(self):
|
||||||
|
"""获取成交量序列"""
|
||||||
|
return self.volumeArray
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def sma(self, n, array=False):
|
||||||
|
"""简单均线"""
|
||||||
|
result = talib.SMA(self.close, n)
|
||||||
|
if array:
|
||||||
|
return result
|
||||||
|
return result[-1]
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def std(self, n, array=False):
|
||||||
|
"""标准差"""
|
||||||
|
result = talib.STDDEV(self.close, n)
|
||||||
|
if array:
|
||||||
|
return result
|
||||||
|
return result[-1]
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def cci(self, n, array=False):
|
||||||
|
"""CCI指标"""
|
||||||
|
result = talib.CCI(self.high, self.low, self.close, n)
|
||||||
|
if array:
|
||||||
|
return result
|
||||||
|
return result[-1]
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def atr(self, n, array=False):
|
||||||
|
"""ATR指标"""
|
||||||
|
result = talib.ATR(self.high, self.low, self.close, n)
|
||||||
|
if array:
|
||||||
|
return result
|
||||||
|
return result[-1]
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def rsi(self, n, array=False):
|
||||||
|
"""RSI指标"""
|
||||||
|
result = talib.RSI(self.close, n)
|
||||||
|
if array:
|
||||||
|
return result
|
||||||
|
return result[-1]
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def macd(self, fastPeriod, slowPeriod, signalPeriod, array=False):
|
||||||
|
"""MACD指标"""
|
||||||
|
macd, signal, hist = talib.MACD(self.close, fastPeriod,
|
||||||
|
slowPeriod, signalPeriod)
|
||||||
|
if array:
|
||||||
|
return macd, signal, hist
|
||||||
|
return macd[-1], signal[-1], hist[-1]
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def adx(self, n, array=False):
|
||||||
|
"""ADX指标"""
|
||||||
|
result = talib.ADX(self.high, self.low, self.close, n)
|
||||||
|
if array:
|
||||||
|
return result
|
||||||
|
return result[-1]
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def boll(self, n, dev, array=False):
|
||||||
|
"""布林通道"""
|
||||||
|
mid = self.sma(n, array)
|
||||||
|
std = self.std(n, array)
|
||||||
|
|
||||||
|
up = mid + std * dev
|
||||||
|
down = mid - std * dev
|
||||||
|
|
||||||
|
return up, down
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def keltner(self, n, dev, array=False):
|
||||||
|
"""肯特纳通道"""
|
||||||
|
mid = self.sma(n, array)
|
||||||
|
atr = self.atr(n, array)
|
||||||
|
|
||||||
|
up = mid + atr * dev
|
||||||
|
down = mid - atr * dev
|
||||||
|
|
||||||
|
return up, down
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def donchian(self, n, array=False):
|
||||||
|
"""唐奇安通道"""
|
||||||
|
up = talib.MAX(self.high, n)
|
||||||
|
down = talib.MIN(self.low, n)
|
||||||
|
|
||||||
|
if array:
|
||||||
|
return up, down
|
||||||
|
return up[-1], down[-1]
|
Loading…
Reference in New Issue
Block a user