diff --git a/README.md b/README.md index ec521d4b..7e5c601c 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# “当你想放弃时,想想你为什么开始。” +# “当你想放弃时,想想你为什么开始。埃隆·马斯克” ###Fork版本主要改进如下 - 1、增加CtaLineBar,CtaPosition,CtaPolicy,UtlSinaClient等基础组件 diff --git a/examples/CtaBacktesting/util_branch_testing.py b/examples/CtaBacktesting/util_branch_testing.py new file mode 100644 index 00000000..fe3dcb14 --- /dev/null +++ b/examples/CtaBacktesting/util_branch_testing.py @@ -0,0 +1,564 @@ +# encoding: UTF-8 + +""" +批量测试相关方法 +# 华富资产 李来佳 +""" + +import sys, os, platform, gc, copy,multiprocessing +from datetime import datetime +from time import sleep +vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) + +sys.path.append(vnpy_root) + +import numpy as np +import pandas as pd +import talib as ta # 科学计算库 +import statsmodels.api as sm # 统计库 +import matplotlib +import matplotlib.pyplot as plt +import math # 数学计算相关 +matplotlib.rcParams['figure.figsize'] = (20.0, 10.0) +import traceback +from vnpy.trader.setup_logger import * +from vnpy.trader.app.ctaStrategy.ctaBacktesting import BacktestingEngine, OptimizationSetting, MINUTE_DB_NAME + + +# 合并回测结果 +def combine_results(results_list): + result_df = None + + # 判断结果集是否有数据 + if len(results_list) < 1: + print('no records') + return None + + effected_results = 0 + + for dict in results_list: + # 测试项目 + test_item = dict['test_item'] + # 测试csv文件 + file_name = dict['result_file'] + + if not os.path.isfile(file_name): + continue + + effected_results += 1 + + # 读取测试文件 + df = pd.read_csv(file_name) + # 修正索引为时间日期索引 + df = df.set_index(pd.DatetimeIndex(df['date'])) + + if result_df is None: + # 首个测试结果,将净值字段设置为周期 + result_df = df['rate'].to_frame(name=test_item) + # 汇总净值 + result_df['rate'] = result_df[test_item] + else: + # 增加新的测试结果数据 + result_df[test_item] = df['rate'] + + # 汇总净值 + result_df['rate'] = result_df['rate'] + result_df[test_item] + + # 释放内存 + l = [df] + del df + del l + + if effected_results > 0: + # 净值平均 + result_df['avg_rate'] = result_df['rate'] / effected_results + # 组合净值累加(仍然按照1个策略的总资金,累加各策略的收益) + result_df['group_rate'] = result_df['rate'] - effected_results + 1 + + # 删除累加的rate + result_df.drop('rate', axis=1, inplace=True) + + return result_df + + +# 计算最大回撤,单日最大回撤 +def calculate_result(result_df, rate_column): + max_rate = 0 + max_loss = 0 + max_rate_date = None + max_loss_date = None + max_loss_info = '-' + + for idx in result_df.index: + # 当前日净值 + cur_rate = result_df[rate_column].loc[idx] + + if cur_rate > max_rate: + max_rate = cur_rate + max_rate_date = idx.strftime('%Y-%m-%d') + + cur_loss = max_rate - cur_rate + + if cur_loss > max_loss: + max_loss = cur_loss + max_loss_date = idx.strftime('%Y-%m-%d') + max_loss_percent = max_loss / max_rate + max_loss_info = u'{} from {} to {},rate {}=>{},max loss rate {}'.format(rate_column, max_rate_date, + max_loss_date, max_rate, cur_rate, + max_loss_percent) + return max_loss_info + +def single_strategy_test(test_settings): + """ + 根据设置参数,执行回测品种, + :param strategyClass: 策略类 + :param test_settings: 策略参数设置 + test_settings['bar_file']: 回测的bar csv文件路径 + test_settings['bar_interval']: 回测的Bar csv周期 + test_settings['report_file']: 净值输出报告的保存路径 + :return: + """ + + from vnpy.trader.vtEvent import EventEngine2 + eventEngine = EventEngine2() + eventEngine.start() + + # 创建回测引擎 + engine = BacktestingEngine(eventEngine=eventEngine) + # 设置回测的策略类 + engine.setStrategyName(test_settings['name']) + # 创建日志 + if 'debug' in test_settings: + engine.createLogger(debug=test_settings['debug']) + else: + engine.createLogger() + + # 设置引擎的回测模式 + if test_settings['mode'] == 'tick': + engine.setBacktestingMode(engine.TICK_MODE) + else: + engine.setBacktestingMode(engine.BAR_MODE) + + strategy_settings = copy.copy(test_settings) + # 设置回测的策略类 + if 'filenamePrefix' in strategy_settings: + engine.setStrategyName(strategy_settings["filenamePrefix"]) + else: + engine.setStrategyName(test_settings['name']) + + if 'is_7x24' in test_settings: + engine.is_7x24 = test_settings['is_7x24'] + + # del strategy_settings['size'] + #del strategy_settings['margin_rate'] + del strategy_settings['initCapital'] + + if 'report_file' in strategy_settings: + del strategy_settings['report_file'] + + # 设置回测用的数据起始日期 + if 'start_date' in strategy_settings: + engine.setStartDate(test_settings['start_date'], initDays = strategy_settings.get('initDays', 10)) + del strategy_settings['start_date'] + else: + engine.setStartDate('20110101',initDays = strategy_settings.get('initDays',10)) + + # 设置回测用的数据结束日期 + if 'end_date' in test_settings: + engine.setEndDate(test_settings['end_date']) + del strategy_settings['end_date'] + else: + engine.setEndDate('20171201') + + # engine.connectMysql() + engine.setDatabase(dbName=MINUTE_DB_NAME, symbol=test_settings['vtSymbol']) + + # 设置产品相关参数 + if 'slippage' in test_settings and test_settings['slippage'] > 0: + engine.setSlippage(test_settings['slippage']) + else: + engine.setSlippage(0) + engine.setRate(test_settings['rate'] if 'rate' in test_settings else float(0.0001)) # 万1 + engine.setSize(test_settings['size']) # 合约大小 + engine.setMinDiff(test_settings['minDiff']) # 合约价格最小跳动 + engine.setMarginRate(test_settings['margin_rate']) # 合约保证金率 + if 'fixCommission' in test_settings: + engine.fixCommission = float(test_settings['fixCommission']) # 固定交易费用(每次开平仓收费) + + # 删除本地json文件 + data_path = os.path.abspath(os.path.join(os.getcwd(),'data')) + if not os.path.exists(data_path): + os.mkdir(data_path) + + logs_path = os.path.abspath(os.path.join(os.getcwd(), 'logs')) + if not os.path.exists(logs_path): + os.mkdir(logs_path) + + up_grid_json_file = os.path.abspath(os.path.join(data_path,'{0}_upGrids.json'.format(test_settings['name']))) + dn_grid_json_file = os.path.abspath(os.path.join(data_path,'{0}_dnGrids.json'.format(test_settings['name']))) + grid_json_file = os.path.abspath(os.path.join(data_path,'{0}_Grids.json'.format(test_settings['name']))) + policy_json_file = os.path.abspath(os.path.join(data_path, '{0}_Policy.json'.format(test_settings['name']))) + + if os.path.isfile(up_grid_json_file): + print(u'{0} exist,remove it'.format(up_grid_json_file)) + try: + os.remove(up_grid_json_file) + except Exception as ex: + print(u'{0}:{1}'.format(Exception, ex)) + return False + + if os.path.isfile(dn_grid_json_file): + print(u'{0}exist,remove it'.format(dn_grid_json_file)) + try: + os.remove(dn_grid_json_file) + except Exception as ex: + print(u'{0}:{1}'.format(Exception, ex)) + return False + + if os.path.isfile(grid_json_file): + print(u'{0}exist,remove it'.format(grid_json_file)) + try: + os.remove(grid_json_file) + except Exception as ex: + print(u'{0}:{1}'.format(Exception, ex)) + return False + + if os.path.isfile(policy_json_file): + print(u'{0}exist,remove it'.format(policy_json_file)) + try: + os.remove(policy_json_file) + except Exception as ex: + print(u'{0}:{1}'.format(Exception, ex)) + return False + + # 在引擎中创建策略对象 + print(u'run {} using:{}'.format(test_settings['strategy'],strategy_settings)) + engine.initStrategy(test_settings['strategy'], setting=strategy_settings) + + # 设置每日净值的报告文档存储路径 + daily_report_file = 'DailyList.csv' if 'report_file' not in test_settings else test_settings['report_file'] + engine.setDailyReportName(daily_report_file) + + # 使用简单复利模式计算 + engine.usageCompounding = False # True时,只针对FINAL_MODE有效 + + # 启用实时计算净值模式REALTIME_MODE / FINAL_MODE 回测结束时统一计算模式 + engine.calculateMode = engine.REALTIME_MODE + engine.capital = test_settings['initCapital'] # 设置期初资金 + engine.initCapital = test_settings['initCapital'] # 设置期初资金 + engine.avaliable = test_settings['initCapital'] # 设置期初资金 + engine.netCapital = test_settings['initCapital'] + engine.maxCapital = test_settings['initCapital'] # 设置期初资金 + engine.maxNetCapital = test_settings['initCapital'] # 设置期初资金 + engine.percentLimit = test_settings['percentLimit'] # 设置资金使用上限比例(%) + engine.barTimeInterval = 60 * test_settings['bar_interval'] # 回测文件中,bar的周期秒数,用于csv文件自动减时间 + + try: + # 前置动作(无参数) + pre_functions = test_settings.get('pre_functions',[]) + for fun_name in pre_functions: + try: + if not isinstance(fun_name,str): + continue + if hasattr(engine.strategy,fun_name): + fun = getattr(engine.strategy,fun_name) + if fun is not None: + fun() + except Exception as ex: + print(u'调用前置动作异常:{},{}'.format(str(ex),traceback.format_exc()),file=sys.stderr) + + # 开始跑回测 + if 'bar_file' in test_settings: + engine.runBackTestingWithBarFile(test_settings['bar_file']) + else: + engine.runBackTestingWithDataSource() + + print('{}finished loop bars'.format(test_settings['name'])) + # 显示回测结果 + engine.showBacktestingResult() + + # 保存策略得内部数据 + engine.saveStrategyData() + + print('{} finished'.format(test_settings['name'])) + + return True + except Exception as ex: + print(u'single_strategy_test exception:{}'.format(str(ex))) + traceback.print_exc() + return False + +def multi_period_test(gid, group_setting): + """ + 多周期回测品种组合 + 1、对group_settings进行分解,分解出各个运行周期的参数设置 + 2、逐一周期运行测试 + 3、添加回测结果 + :param gid: 测试组名 + :param group_setting:dict,包含参数,多周期清单 + 如果多周期,则对每一周期执行回测,并汇总结果。 + :return 回测的每日净值统计文件 + """ + + # 回测的分钟周期 + minutes_interval_list = group_setting['minute_list'] # 回测分钟队列(3,5,10等) + + strategyClass = group_setting['strategy'] + + # 测试批次时间 + test_dt = datetime.now().strftime('%Y%m%d_%H%M') + + # 回测结果队列,对应测试分钟队列 + daily_results = [] + return_results = [] + + # 启动多进程 + pool = multiprocessing.Pool(multiprocessing.cpu_count()) + l = [] + + # 逐一分钟级别进行回测 + for m_i in minutes_interval_list: + settings = copy.copy(group_setting) + del settings['minute_list'] + + settings['name'] = '{}_{}_{}_M{}'.format(gid, strategyClass.className, group_setting['symbol'], m_i) + + # 资金占用比例,根据组合内周期数量,进行平均分配 + settings['percentLimit'] = group_setting['percentLimit']/ len(minutes_interval_list) + settings['vtSymbol'] = group_setting['symbol'] + settings['MinInterval'] = m_i + + settings['mode'] = 'bar' + settings['backtesting'] = True + settings['bar_interval'] = group_setting['bar_interval'] if 'bar_interval' in group_setting else 1 + settings['strategy'] = strategyClass + # 回测报告文件保存路径: 组合,测试实例名称,测试时间 + daily_report_file = os.path.abspath(os.path.join(group_setting['report_folder'], u'{}_daily_{}.csv' + .format(settings['name'], test_dt))) + + settings['report_file'] = daily_report_file + + #if rt: + # 回测报告集登记 + daily_results.append({'test_item': 'M{}'.format(m_i), 'result_file': daily_report_file}) + + l.append(pool.apply_async(single_strategy_test, (settings,))) + #rt = single_strategy_test(test_settings=settings) + + # 执行内存回收 + gc.collect() + sleep(10) + + result_list = [res.get() for res in l] + + for idx, rt in enumerate(result_list): + if rt: + return_results.append(daily_results[idx]) + + pool.close() + pool.join() + return return_results + +def run_multiperiod_test(gid, group_settings): + """ + 运行多周期的组合测试 + :param gid: 组合名称 + :param group_settings: + :return: + """ + m = '_'.join(str(e) for e in group_settings['minute_list']) + if 'report_folder' in group_settings: + final_file = os.path.abspath(os.path.join(group_settings['report_folder'], '{}_{}_Report_{}.csv'.format(gid, group_settings['symbol'], m))) + else: + # 报告所在目录 + report_folder = os.path.abspath(os.path.join(os.getcwd(), 'logs', gid)) + if not os.path.exists(report_folder): + os.mkdir(report_folder) + # 汇总报告文件 + final_file = os.path.abspath( + os.path.join(report_folder, '{}_{}_Report_{}.csv'.format(gid, group_settings['symbol'], m))) + group_settings['report_folder'] = report_folder + + if not os.path.exists(group_settings['report_folder']): + os.makedirs(group_settings['report_folder']) + + # 运行回测方法,统计结果 + daily_results = multi_period_test(gid, group_settings) + + # 统计结果 + backtest_df = combine_results(daily_results) + + # 保存汇总记录到文件 + backtest_df.to_csv(final_file) + + # 显示资金曲线汇总 + fig, ax1 = plt.subplots() + + period_columns = [item['test_item'] for item in daily_results] + + fig.patch.set_facecolor('white') + ax1.plot(backtest_df[period_columns]) + ax1.legend() + + # 释放内存 + l = [backtest_df] + del backtest_df + del l + + # 释放内存 + gc.collect() + print( 'finished run_multiparameter_test') + +def multi_parameter_test(gid, settings_list): + """ + 不同参数的回测组合 + :param gid: 组合ID + :param settings_list: 参数列表 + :return: + """ + # 回测结果队列,对应测试参数队列 + daily_results = [] + return_results = [] + + # 每个测试的参数名称 + para_list = [i['paraName'] for i in settings_list] + + if len(para_list) == 0: + return daily_results + + # 启动多进程 + pool = multiprocessing.Pool(multiprocessing.cpu_count()) + #pool = multiprocessing.Pool(2) + l = [] + + print('multi_parameter_test,total:{}'.format(len(settings_list))) + + for idx, strategy_settings in enumerate(settings_list): + settings = copy.copy(strategy_settings) + # 测试时间 + test_dt = datetime.now().strftime('%Y%m%d_%H%M') + + del settings['log_file'] + if 'minute_list' in settings: + del settings['minute_list'] + settings['vtSymbol'] = settings['symbol'] + settings['mode'] = 'bar' + settings['backtesting'] = True + if 'bar_interval' not in settings: + settings['bar_interval'] = 1 + + # 回测报告文件保存路径: 组合,测试实例名称,测试时间 + daily_report_file = os.path.abspath(os.path.join(settings['report_folder'], u'{}_daily_{}.csv'.format(settings['name'], test_dt))) + + settings['report_file'] = daily_report_file + + l.append(pool.apply_async(single_strategy_test, (settings,))) + + # 回测报告集登记 + daily_results.append({'test_item': settings['paraName'], 'result_file': daily_report_file}) + + # 执行内存回收 + gc.collect() + sleep(10) + + result_list = [res.get() for res in l] + + # 返回结果是正确的,才添加到返回列表中 + for idx, rt in enumerate(result_list): + if rt: + return_results.append(daily_results[idx]) + + pool.close() + pool.join() + + return return_results + +def run_multiparameter_test(gid, settings_list): + """ + 多策略组测试 + :param gid: + :param settings_list: + :return: """ + if len(settings_list) == 0: + raise ReferenceError('Zero settings') + + first_setting = settings_list[0] + + paraNames = '_'.join(i['paraName'] for i in settings_list) + if 'report_folder' in first_setting: + report_folder = first_setting['report_folder'] + final_file = os.path.abspath(os.path.join(first_setting['report_folder'], + '{}_{}_Report_{}.csv'.format(gid, first_setting['symbol'], + paraNames))) + else: + # 报告所在目录 + report_folder = os.path.abspath(os.path.join(os.getcwd(), 'logs')) + final_file = os.path.abspath( + os.path.join(report_folder, '{}_{}_Report_{}.csv'.format(gid, first_setting['symbol'], paraNames))) + + if not os.path.exists(report_folder): + os.makedirs(report_folder) + + for settings in settings_list: + settings['report_folder'] = report_folder + + # 运行回测方法,统计结果 + daily_results = multi_parameter_test(gid, settings_list) + + # 统计结果 + #backtest_df = combine_results(daily_results) + + # 保存汇总记录到文件 + #backtest_df.to_csv(final_file) + + # 显示资金曲线汇总 + ##fig, ax = plt.subplots() + + #period_columns = [item['test_item'] for item in daily_results] + + #fig.patch.set_facecolor('white') + + #for column in period_columns: + # ax.plot(backtest_df[column], label=column) + + #ax.legend() + + #title = u'{} {}'.format(gid, paraNames) + #plt.title(title) + + #fig.savefig(u'{}/rate.png'.format(report_folder)) + + ## 释放内存 + gc.collect() + + print( 'finished run_multiparameter_test') + +def single_func(para): + logger=setup_logger('MyLog', name='my{}'.format(para)) + if para > 5: + print( u'more than 5') + logger.info('More than 5') + return True + else: + print ('less') + logger.info('Less than 5') + return False + +def multi_func(): + + import logging + # 启动多进程 + pool = multiprocessing.Pool(multiprocessing.cpu_count()) + + logger = setup_logger('MyLog') + + logger.info('main process') + l = [] + + for i in range(0,10): + l.append(pool.apply_async(single_func,(i,))) + + results = [res.get() for res in l] + + pool.close() + pool.join() \ No newline at end of file diff --git a/examples/Services/service.py b/examples/Services/service.py index 62d16457..d4800024 100644 --- a/examples/Services/service.py +++ b/examples/Services/service.py @@ -305,7 +305,7 @@ def start(): # 往任务表增加定时计划 operate_crontab("add") # 执行启动 - _start() + #_start() print(u'启动{}服务执行完毕'.format(base_path)) def _stop(): diff --git a/examples/VnTrader/CTA_setting.json b/examples/VnTrader/CTA_setting.json new file mode 100644 index 00000000..4a817f16 --- /dev/null +++ b/examples/VnTrader/CTA_setting.json @@ -0,0 +1,31 @@ +[ + { + "name": "Strategy_TripleMa_01_RB_Min5", + "comment": "螺纹05合约三均线策略", + "auto_init": true, + "auto_start": true, + "className": "Strategy_TripleMa_v01", + "inputSS": 1, + "minDiff": 1, + "mode": "tick", + "shortSymbol": "RB", + "strategy_module": "Strategy_TripleMa_v01", + "symbol": "RB1905", + "vtSymbol": "rb1905" + }, + { + "name": "Strategy_TripleMa_02_HC_Min5", + "comment": "热卷05合约三均线策略", + "auto_init": true, + "auto_start": true, + "className": "Strategy_TripleMa_v02", + "inputSS": 1, + "minDiff": 1, + "mode": "tick", + "shortSymbol": "HC", + "strategy_module": "Strategy_TripleMa_v02", + "symbol": "HC1905", + "vtSymbol": "hc1905" + } + +] \ No newline at end of file diff --git a/examples/VnTrader/run.py b/examples/VnTrader/run.py index 49b2308a..9b3700f1 100644 --- a/examples/VnTrader/run.py +++ b/examples/VnTrader/run.py @@ -23,9 +23,9 @@ from vnpy.trader.uiMainWindow import * # 加载底层接口 from vnpy.trader.gateway import ctpGateway # 初始化的接口模块,以及其指定的名称,CTP是模块,value,是该模块下的多个连接配置文件,如 CTP_JR_connect.json 'CTP_Prod', 'CTP_JR', , 'CTP_JK', 'CTP_02' -init_gateway_names = {'CTP': ['CTP','CTP_001','CTP_002']} +init_gateway_names = {'CTP': ['CTP','CTP01','CTP02']} -from vnpy.trader.app import (ctaStrategy, riskManager, spreadTrading) #,dataRecorder +from vnpy.trader.app import (ctaStrategy, riskManager, spreadTrading,dataRecorder) #, # 文件路径名 path = os.path.abspath(os.path.dirname(__file__)) @@ -60,7 +60,7 @@ def main(): mainEngine.addApp(ctaStrategy) mainEngine.addApp(riskManager) mainEngine.addApp(spreadTrading) - #mainEngine.addApp(dataRecorder) + mainEngine.addApp(dataRecorder) mainWindow = MainWindow(mainEngine, ee) mainWindow.showMaximized() diff --git a/vnpy/trader/app/ctaStrategy/ctaEngine.py b/vnpy/trader/app/ctaStrategy/ctaEngine.py index c5b285fc..658ea55a 100644 --- a/vnpy/trader/app/ctaStrategy/ctaEngine.py +++ b/vnpy/trader/app/ctaStrategy/ctaEngine.py @@ -1432,6 +1432,7 @@ class CtaEngine(object): return for expired_pos in expired_pos_list: + self.writeCtaLog(u'执行仓位处理:{}'.format(expired_pos)) if expired_pos['volume'] == 0: self.writeCtaError(u'clear_dispatch_pos,pos 为空:{},删除'.format(expired_pos)) flt = {'_id': expired_pos['_id']} @@ -1509,7 +1510,7 @@ class CtaEngine(object): else: if curPos.longYd >= expired_pos['volume']: sell_longYd = expired_pos['volume'] - if curPos.longYd == 0: + elif curPos.longYd == 0: sell_longToday = expired_pos['volume'] else: sell_longYd = curPos.longYd diff --git a/vnpy/trader/app/ctaStrategy/ctaPolicy.py b/vnpy/trader/app/ctaStrategy/ctaPolicy.py index d0b31677..b6a44e4f 100644 --- a/vnpy/trader/app/ctaStrategy/ctaPolicy.py +++ b/vnpy/trader/app/ctaStrategy/ctaPolicy.py @@ -253,9 +253,13 @@ class TurtlePolicy(CtaPolicy): def __init__(self, strategy): super(TurtlePolicy, self).__init__(strategy) - self.tns_open_price = 0 # 首次开仓价格 - self.last_open_price = 0 # 最后一次加仓价格 - self.stop_price = 0 # 止损价 + self.allow_add_pos = False # 是否加仓 + self.add_pos_on_pips = EMPTY_INT # 价格超过开仓价多少点时加仓 + + self.tns_open_price = 0 # 首次开仓价格 + self.last_open_price = 0 # 最后一次加仓价格 + self.stop_price = 0 # 止损价 + self.exit_on_last_rtn_pips = 0 # 最高价/最低价回撤多少跳动 self.high_price_in_long = 0 # 多趋势时,最高价 self.low_price_in_short = 0 # 空趋势时,最低价 self.last_under_open_price = 0 # 低于首次开仓价的补仓价格 @@ -279,12 +283,15 @@ class TurtlePolicy(CtaPolicy): j['tns_open_date'] = self.tns_open_date j['tns_open_price'] = self.tns_open_price if self.tns_open_price is not None else 0 + j['allow_add_pos'] = self.allow_add_pos + j['add_pos_on_pips'] = self.add_pos_on_pips + j['exit_on_last_rtn_pips'] = self.exit_on_last_rtn_pips + j['last_open_price'] = self.last_open_price if self.last_open_price is not None else 0 j['stop_price'] = self.stop_price if self.stop_price is not None else 0 j['high_price_in_long'] = self.high_price_in_long if self.high_price_in_long is not None else 0 j['low_price_in_short'] = self.low_price_in_short if self.low_price_in_short is not None else 0 - j[ - 'add_pos_count_under_first_price'] = self.add_pos_count_under_first_price if self.add_pos_count_under_first_price is not None else 0 + j['add_pos_count_under_first_price'] = self.add_pos_count_under_first_price if self.add_pos_count_under_first_price is not None else 0 j['last_under_open_price'] = self.last_under_open_price if self.last_under_open_price is not None else 0 j['max_pos'] = self.max_pos if self.max_pos is not None else 0 @@ -412,6 +419,9 @@ class TurtlePolicy(CtaPolicy): self.writeCtaError(u'解释last_risk_level异常:{}'.format(str(ex))) self.last_risk_level = 0 + self.allow_add_pos=json_data.get('allow_add_pos',False) + self.add_pos_on_pips = json_data.get('add_pos_on_pips',1) + self.exit_on_last_rtn_pips = json_data.get('exit_on_last_rtn_pips',0) def clean(self): """ @@ -432,6 +442,9 @@ class TurtlePolicy(CtaPolicy): self.tns_has_opened = False self.last_risk_level = 0 self.tns_count = 0 + self.allow_add_pos = False + self.add_pos_on_pips = 1 + self.exit_on_last_rtn_pips = 0 class TrendPolicy(CtaPolicy): """ diff --git a/vnpy/trader/app/dataRecorder/drBase.py b/vnpy/trader/app/dataRecorder/drBase.py index 9e16799e..a0a92092 100644 --- a/vnpy/trader/app/dataRecorder/drBase.py +++ b/vnpy/trader/app/dataRecorder/drBase.py @@ -14,6 +14,7 @@ SETTING_DB_NAME = 'VnTrader_Setting_Db' TICK_DB_NAME = 'VnTrader_Tick_Db' DAILY_DB_NAME = 'VnTrader_Daily_Db' MINUTE_DB_NAME = 'VnTrader_1Min_Db' +RENKO_DB_NAME = 'Renko_Db' # CTA引擎中涉及的数据类定义 diff --git a/vnpy/trader/app/dataRecorder/drEngine.py b/vnpy/trader/app/dataRecorder/drEngine.py index cf1aa51b..602ea664 100644 --- a/vnpy/trader/app/dataRecorder/drEngine.py +++ b/vnpy/trader/app/dataRecorder/drEngine.py @@ -14,12 +14,15 @@ from datetime import datetime, timedelta from queue import Queue from threading import Thread + from vnpy.trader.vtEvent import * +from vnpy.trader.vtObject import VtTickData from vnpy.trader.vtGateway import VtSubscribeReq, VtLogData -from vnpy.trader.vtFunction import todayDate,getJsonPath - +from vnpy.trader.vtFunction import todayDate,getJsonPath,getShortSymbol +from vnpy.trader.app.ctaStrategy.ctaRenkoBar import CtaRenkoBar from .drBase import * - +from vnpy.trader.setup_logger import setup_logger +from vnpy.trader.data_source import DataSource ######################################################################## class DrEngine(object): """数据记录引擎""" @@ -44,15 +47,39 @@ class DrEngine(object): # K线对象字典 self.barDict = {} - + + # Renko K线对象字典 + self.renkoDict = {} + # 负责执行数据库插入的单独线程相关 self.active = False # 工作状态 self.queue = Queue() # 队列 self.thread = Thread(target=self.run) # 线程 + self.logger = None + self.createLogger() # 载入设置,订阅行情 self.loadSetting() - + + def createLogger(self): + """ + 创建日志记录 + :return: + """ + currentFolder = os.path.abspath(os.path.join(os.getcwd(), 'logs')) + if os.path.isdir(currentFolder): + # 如果工作目录下,存在data子目录,就使用data子目录 + path = currentFolder + else: + # 否则,使用缺省保存目录 vnpy/trader/app/ctaStrategy/data + path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'logs')) + + if self.logger is None: + filename = os.path.abspath(os.path.join(path, 'drEngine')) + + print(u'create logger:{}'.format(filename)) + self.logger = setup_logger(filename=filename, name='drEngine', debug=True) + #---------------------------------------------------------------------- def loadSetting(self): """载入设置""" @@ -111,7 +138,43 @@ class DrEngine(object): bar = DrBarData() self.barDict[vtSymbol] = bar - + + if 'renko' in drSetting: + l = drSetting.get('renko') + req_set = set() + for setting in l: + # 获取合约,合约短号,renko的高度(多少个跳) + vtSymbol = setting.get('vtSymbol',None) + if vtSymbol is None: + continue + short_symbol = getShortSymbol(vtSymbol).upper() + height = setting.get('height',5) + minDiff = setting.get('minDiff',1) + + + + # 获取vtSymbol的多个renkobar列表,添加新的CtaRenkoBar + renko_list = self.renkoDict.get(vtSymbol,[]) + + bar_setting = {'name':'{}_{}'.format(vtSymbol,height), + 'shortSymbol':short_symbol, + 'vtSymbol':vtSymbol, + 'minDiff':minDiff, + 'height':minDiff*height} + renko_bar = CtaRenkoBar(strategy=None,onBarFunc=self.onRenkoBar,setting=bar_setting) + renko_list.append(renko_bar) + self.renkoDict.update({vtSymbol:renko_list}) + + req = VtSubscribeReq() + req.symbol = vtSymbol + req_set.add((req,setting.get('gateway',None))) + # 更新合约的历史数据 + self.add_gap_ticks() + + # 订阅行情 + for req,gw in req_set: + self.mainEngine.subscribe(req,gw) + if 'active' in drSetting: d = drSetting['active'] @@ -125,7 +188,7 @@ class DrEngine(object): # 注册事件监听 self.registerEvent() - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def procecssTickEvent(self, event): """处理行情推送""" tick = event.dict_['data'] @@ -137,7 +200,7 @@ class DrEngine(object): for key in d.keys(): if key != 'datetime': d[key] = tick.__getattribute__(key) - drTick.datetime = datetime.strptime(' '.join([tick.date, tick.time]), '%Y%m%d %H:%M:%S.%f') + drTick.datetime = datetime.strptime(' '.join([tick.date, tick.time]), '%Y-%m-%d %H:%M:%S.%f') # 更新Tick数据 if vtSymbol in self.tickDict: @@ -190,6 +253,111 @@ class DrEngine(object): bar.close = drTick.lastPrice #---------------------------------------------------------------------- + + # 更新Renko数据 + for renko_bar in self.renkoDict.get(vtSymbol,[]): + renko_bar.onTick(copy.copy(tick)) + + def add_gap_ticks(self): + """ + 补充缺失的分时数据 + :return: + """ + ds = DataSource() + for vtSymbol in self.renkoDict.keys(): + renkobar_list = self.renkoDict.get(vtSymbol,[]) + cache_bars = None + cache_start_date = None + cache_end_date = None + for renko_bar in renkobar_list: + # 通过mongo获取最新一个Renko bar的数据日期close时间 + last_renko_dt = self.get_last_datetime(renko_bar.name) + + # 根据日期+vtSymbol,像datasource获取分钟数据,以close价格,转化为tick,推送到renko_bar中 + if last_renko_dt is not None: + start_date =last_renko_dt.strftime('%Y-%m-%d') + else: + start_date = (datetime.now() - timedelta(days=90)).strftime('%Y-%m-%d') + + end_date = (datetime.now() + timedelta(days=5)).strftime('%Y-%m-%d') + self.writeDrLog(u'从datasource获取{}数据,开始日期:{}'.format(vtSymbol,start_date)) + if cache_bars is None or cache_start_date!=start_date or cache_end_date!=end_date: + fields = ['open', 'close', 'high', 'low', 'volume', 'open_interest', 'limit_up', 'limit_down', + 'trading_date'] + cache_bars = ds.get_price(order_book_id=vtSymbol,start_date=start_date,end_date=end_date, + frequency='1m', fields=fields) + cache_start_date = start_date + cache_end_date = end_date + if cache_bars is not None: + total = len(cache_bars) + self.writeDrLog(u'一共获取{}条{} 1分钟数据'.format(total,vtSymbol)) + + if cache_bars is not None: + self.writeDrLog(u'推送分时数据tick:{}到:{}'.format(vtSymbol,renko_bar.name)) + for idx in cache_bars.index: + row = cache_bars.loc[idx] + tick = VtTickData() + tick.vtSymbol = vtSymbol + tick.symbol = vtSymbol + last_bar_dt = datetime.strptime(str(idx), '%Y-%m-%d %H:%M:00') + tick.datetime = last_bar_dt - timedelta(minutes=1) + tick.date = tick.datetime.strftime('%Y-%m-%d') + tick.time = tick.datetime.strftime('%H:%M:00') + + if tick.datetime.hour >= 21: + if tick.datetime.isoweekday() == 5: + # 星期五=》星期一 + tick.tradingDay = (tick.datetime + timedelta(days=3)).strftime('%Y-%m-%d') + else: + # 第二天 + tick.tradingDay = (tick.datetime + timedelta(days=1)).strftime('%Y-%m-%d') + elif tick.datetime.hour < 8 and tick.datetime.isoweekday() == 6: + # 星期六=>星期一 + tick.tradingDay = (tick.datetime + timedelta(days=2)).strftime('%Y-%m-%d') + else: + tick.tradingDay = tick.date + tick.upperLimit = float(row['limit_up']) + tick.lowerLimit = float(row['limit_down']) + tick.lastPrice = float(row['close']) + tick.askPrice1 = float(row['close']) + tick.bidPrice1 = float(row['close']) + tick.volume = int(row['volume']) + tick.askVolume1 = tick.volume + tick.bidVolume1 = tick.volume + + if last_renko_dt is not None and tick.datetime <= last_renko_dt: + continue + renko_bar.onTick(tick) + + def get_last_datetime(self,renko_name): + """ + 通过mongo获取最新一个bar的数据日期 + :param renko_name: + :return: + """ + qryData = self.mainEngine.dbQueryBySort(dbName=RENKO_DB_NAME, + collectionName=renko_name, + d={}, + sortName='datetime', + sortType=-1, + limitNum=1) + + last_renko_open_dt =None + last_renko_close_dt=None + for d in qryData: + last_renko_open_dt = d.get('datetime',None) + if last_renko_open_dt is not None: + last_renko_close_dt = last_renko_open_dt + timedelta(seconds=d.get('seconds',0)) + + break + return last_renko_close_dt + + def onRenkoBar(self,bar,bar_name): + newBar = copy.copy(bar) + self.insertData(RENKO_DB_NAME, bar_name, newBar) + self.writeDrLog(u'new Renko Bar:{},dt:{},open:{},close:{},high:{},low:{}' + .format(bar_name,bar.datetime,bar.open,bar.close, bar.high, bar.low)) + def registerEvent(self): """注册事件监听""" self.eventEngine.register(EVENT_TICK, self.procecssTickEvent) @@ -206,7 +374,7 @@ class DrEngine(object): try: dbName, collectionName, d = self.queue.get(block=True, timeout=1) self.mainEngine.dbInsert(dbName, collectionName, d) - except Empty: + except Exception as ex: pass #---------------------------------------------------------------------- def start(self): @@ -228,5 +396,7 @@ class DrEngine(object): log.logContent = content event = Event(type_=EVENT_DATARECORDER_LOG) event.dict_['data'] = log - self.eventEngine.put(event) - \ No newline at end of file + self.eventEngine.put(event) + + if self.logger: + self.logger.info(content) diff --git a/vnpy/trader/app/dataRecorder/uiDrWidget.py b/vnpy/trader/app/dataRecorder/uiDrWidget.py index e232521b..0392b4a9 100644 --- a/vnpy/trader/app/dataRecorder/uiDrWidget.py +++ b/vnpy/trader/app/dataRecorder/uiDrWidget.py @@ -29,7 +29,7 @@ class TableCell(QtWidgets.QTableWidgetItem): if text == '0' or text == '0.0': self.setText('') else: - self.setText(text) + self.setText(str(text)) ######################################################################## @@ -60,7 +60,7 @@ class DrEngineManager(QtWidgets.QWidget): self.tickTable.setColumnCount(2) self.tickTable.verticalHeader().setVisible(False) self.tickTable.setEditTriggers(QtWidgets.QTableWidget.NoEditTriggers) - self.tickTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch) + #self.tickTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch) self.tickTable.setAlternatingRowColors(True) self.tickTable.setHorizontalHeaderLabels([u'合约代码', u'接口']) @@ -69,16 +69,25 @@ class DrEngineManager(QtWidgets.QWidget): self.barTable.setColumnCount(2) self.barTable.verticalHeader().setVisible(False) self.barTable.setEditTriggers(QtWidgets.QTableWidget.NoEditTriggers) - self.barTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch) + #self.barTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch) self.barTable.setAlternatingRowColors(True) self.barTable.setHorizontalHeaderLabels([u'合约代码', u'接口']) + renkoLabel = QtWidgets.QLabel(u'RenkoBar') + self.renkoTable = QtWidgets.QTableWidget() + self.renkoTable.setColumnCount(2) + self.renkoTable.verticalHeader().setVisible(False) + self.renkoTable.setEditTriggers(QtWidgets.QTableWidget.NoEditTriggers) + # self.renkoTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch) + self.renkoTable.setAlternatingRowColors(True) + self.renkoTable.setHorizontalHeaderLabels([u'合约代码', u'高度']) + activeLabel = QtWidgets.QLabel(u'主力合约') self.activeTable = QtWidgets.QTableWidget() self.activeTable.setColumnCount(2) self.activeTable.verticalHeader().setVisible(False) self.activeTable.setEditTriggers(QtWidgets.QTableWidget.NoEditTriggers) - self.activeTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch) + #self.activeTable.horizontalHeader().setResizeMode(QtWidgets.QHeaderView.Stretch) self.activeTable.setAlternatingRowColors(True) self.activeTable.setHorizontalHeaderLabels([u'主力代码', u'合约代码']) @@ -92,10 +101,13 @@ class DrEngineManager(QtWidgets.QWidget): grid.addWidget(tickLabel, 0, 0) grid.addWidget(barLabel, 0, 1) - grid.addWidget(activeLabel, 0, 2) + grid.addWidget(renkoLabel, 0, 2) + grid.addWidget(activeLabel, 0, 3) + grid.addWidget(self.tickTable, 1, 0) grid.addWidget(self.barTable, 1, 1) - grid.addWidget(self.activeTable, 1, 2) + grid.addWidget(self.renkoTable, 1, 2) + grid.addWidget(self.activeTable, 1, 3) vbox = QtWidgets.QVBoxLayout() vbox.addLayout(grid) @@ -118,7 +130,7 @@ class DrEngineManager(QtWidgets.QWidget): #---------------------------------------------------------------------- def updateSetting(self): """显示引擎行情记录配置""" - with open(self.drEngine.settingFileName) as f: + with open(self.drEngine.settingFilePath) as f: drSetting = json.load(f) if 'tick' in drSetting: @@ -135,8 +147,15 @@ class DrEngineManager(QtWidgets.QWidget): for setting in l: self.barTable.insertRow(0) self.barTable.setItem(0, 0, TableCell(setting[0])) - self.barTable.setItem(0, 1, TableCell(setting[1])) - + self.barTable.setItem(0, 1, TableCell(setting[1])) + + if 'renko' in drSetting: + l = drSetting['renko'] + + for setting in l: + self.renkoTable.insertRow(0) + self.renkoTable.setItem(0, 0, TableCell(setting.get('vtSymbol',''))) + self.renkoTable.setItem(0, 1, TableCell(setting.get('height',0))) if 'active' in drSetting: d = drSetting['active'] diff --git a/vnpy/trader/setup_logger.py b/vnpy/trader/setup_logger.py index 5fb93e81..608cb7cd 100644 --- a/vnpy/trader/setup_logger.py +++ b/vnpy/trader/setup_logger.py @@ -244,7 +244,8 @@ class MultiprocessHandler(logging.FileHandler): if not self.suffix: raise ValueError(u"指定的日期间隔单位无效: %s" % self.when) #拼接文件路径 格式化字符串 - self.filefmt = "%s_%s.log" % (self.prefix,self.suffix) + #self.filefmt = "%s_%s.log" % (self.prefix,self.suffix) + self.filefmt = u'{}_{}.log'.format(self.prefix, self.suffix) #使用当前时间,格式化文件格式化字符串 self.filePath = datetime.now().strftime(self.filefmt) #获得文件夹路径 diff --git a/vnpy/trader/vtEvent.py b/vnpy/trader/vtEvent.py index d3b3cd56..7e0ecc92 100644 --- a/vnpy/trader/vtEvent.py +++ b/vnpy/trader/vtEvent.py @@ -38,6 +38,10 @@ EVENT_NOTIFICATION = 'eNotification' # 全局通知 EVENT_SIGNAL = 'eSignal' # 信号通知 EVENT_STATUS = 'eStatus' # 服务状态 +# 股票使用 +EVENT_BAR = 'eBar' # 1分钟Bar 行情 +EVENT_BARDICT = 'eBarDict_' # BarDict事件+策略实例名称 + # CTA模块相关 EVENT_CTA_LOG = 'eCtaLog' # CTA相关的日志事件 EVENT_CTA_STRATEGY = 'eCtaStrategy.' # CTA策略状态变化事件 diff --git a/vnpy/trader/vtGateway.py b/vnpy/trader/vtGateway.py index 0a0b5854..c4f6c9ee 100644 --- a/vnpy/trader/vtGateway.py +++ b/vnpy/trader/vtGateway.py @@ -1,6 +1,6 @@ # encoding: UTF-8 -import time,os,sys +import time,os,sys,copy from datetime import datetime from vnpy.trader.vtEvent import * @@ -23,6 +23,9 @@ class VtGateway(object): self.accountID = 'AccountID' self.createLogger() + # 所有订阅onBar的都会添加 + self.klines = {} + # ---------------------------------------------------------------------- def onTick(self, tick): """市场行情推送""" @@ -35,7 +38,14 @@ class VtGateway(object): event2 = Event(type_=EVENT_TICK+tick.vtSymbol) event2.dict_['data'] = tick self.eventEngine.put(event2) - + + def onBar(self,bar,type=EVENT_BAR): + """市场行情推送""" + # bar, 或者 barDict + event = Event(type_=type) + event.dict_['data'] = bar + self.eventEngine.put(event) + # ---------------------------------------------------------------------- def onTrade(self, trade): """成交信息推送""" diff --git a/vnpy/trader/vtUtility.py b/vnpy/trader/vtUtility.py new file mode 100644 index 00000000..94604421 --- /dev/null +++ b/vnpy/trader/vtUtility.py @@ -0,0 +1,283 @@ +# encoding: UTF-8 + + +import numpy as np +import talib + +from vnpy.trader.vtObject import VtBarData + + +######################################################################## +class BarGenerator(object): + """ + K线合成器,支持: + 1. 基于Tick合成1分钟K线 + 2. 基于1分钟K线合成X分钟K线(X可以是2、3、5、10、15、30 ) + """ + + #---------------------------------------------------------------------- + def __init__(self, onBar, xmin=0, onXminBar=None): + """Constructor""" + self.bar = None # 1分钟K线对象 + self.onBar = onBar # 1分钟K线回调函数 + + self.xminBar = None # X分钟K线对象 + self.xmin = xmin # X的值 + self.onXminBar = onXminBar # X分钟K线的回调函数 + + self.lastTick = None # 上一TICK缓存对象 + + #---------------------------------------------------------------------- + def updateTick(self, tick): + """TICK更新""" + newMinute = False # 默认不是新的一分钟 + + # 尚未创建对象 + if not self.bar: + self.bar = VtBarData() + newMinute = True + # 新的一分钟 + elif self.bar.datetime.minute != tick.datetime.minute: + # 生成上一分钟K线的时间戳 + self.bar.datetime = self.bar.datetime.replace(second=0, microsecond=0) # 将秒和微秒设为0 + self.bar.date = self.bar.datetime.strftime('%Y%m%d') + self.bar.time = self.bar.datetime.strftime('%H:%M:%S.%f') + + # 推送已经结束的上一分钟K线 + self.onBar(self.bar) + + # 创建新的K线对象 + self.bar = VtBarData() + newMinute = True + + # 初始化新一分钟的K线数据 + if newMinute: + self.bar.vtSymbol = tick.vtSymbol + self.bar.symbol = tick.symbol + self.bar.exchange = tick.exchange + + self.bar.open = tick.lastPrice + self.bar.high = tick.lastPrice + self.bar.low = tick.lastPrice + # 累加更新老一分钟的K线数据 + else: + self.bar.high = max(self.bar.high, tick.lastPrice) + self.bar.low = min(self.bar.low, tick.lastPrice) + + # 通用更新部分 + self.bar.close = tick.lastPrice + self.bar.datetime = tick.datetime + self.bar.openInterest = tick.openInterest + + if self.lastTick: + volumeChange = tick.volume - self.lastTick.volume # 当前K线内的成交量 + self.bar.volume += max(volumeChange, 0) # 避免夜盘开盘lastTick.volume为昨日收盘数据,导致成交量变化为负的情况 + + # 缓存Tick + self.lastTick = tick + + #---------------------------------------------------------------------- + def updateBar(self, bar): + """1分钟K线更新""" + # 尚未创建对象 + if not self.xminBar: + self.xminBar = VtBarData() + + self.xminBar.vtSymbol = bar.vtSymbol + self.xminBar.symbol = bar.symbol + self.xminBar.exchange = bar.exchange + + self.xminBar.open = bar.open + self.xminBar.high = bar.high + self.xminBar.low = bar.low + + self.xminBar.datetime = bar.datetime # 以第一根分钟K线的开始时间戳作为X分钟线的时间戳 + # 累加老K线 + else: + self.xminBar.high = max(self.xminBar.high, bar.high) + self.xminBar.low = min(self.xminBar.low, bar.low) + + # 通用部分 + self.xminBar.close = bar.close + self.xminBar.openInterest = bar.openInterest + self.xminBar.volume += int(bar.volume) + + # X分钟已经走完 + if not (bar.datetime.minute + 1) % self.xmin: # 可以用X整除 + # 生成上一X分钟K线的时间戳 + self.xminBar.datetime = self.xminBar.datetime.replace(second=0, microsecond=0) # 将秒和微秒设为0 + self.xminBar.date = self.xminBar.datetime.strftime('%Y%m%d') + self.xminBar.time = self.xminBar.datetime.strftime('%H:%M:%S.%f') + + # 推送 + self.onXminBar(self.xminBar) + + # 清空老K线缓存对象 + self.xminBar = None + + #---------------------------------------------------------------------- + def generate(self): + """手动强制立即完成K线合成""" + self.onBar(self.bar) + self.bar = None + + + +######################################################################## +class ArrayManager(object): + """ + K线序列管理工具,负责: + 1. K线时间序列的维护 + 2. 常用技术指标的计算 + """ + + #---------------------------------------------------------------------- + def __init__(self, size=100): + """Constructor""" + self.count = 0 # 缓存计数 + self.size = size # 缓存大小 + self.inited = False # True if count>=size + + self.openArray = np.zeros(size) # OHLC + self.highArray = np.zeros(size) + self.lowArray = np.zeros(size) + self.closeArray = np.zeros(size) + self.volumeArray = np.zeros(size) + + #---------------------------------------------------------------------- + def updateBar(self, bar): + """更新K线""" + self.count += 1 + if not self.inited and self.count >= self.size: + self.inited = True + + self.openArray[:-1] = self.openArray[1:] + self.highArray[:-1] = self.highArray[1:] + self.lowArray[:-1] = self.lowArray[1:] + self.closeArray[:-1] = self.closeArray[1:] + self.volumeArray[:-1] = self.volumeArray[1:] + + self.openArray[-1] = bar.open + self.highArray[-1] = bar.high + self.lowArray[-1] = bar.low + self.closeArray[-1] = bar.close + self.volumeArray[-1] = bar.volume + + #---------------------------------------------------------------------- + @property + def open(self): + """获取开盘价序列""" + return self.openArray + + #---------------------------------------------------------------------- + @property + def high(self): + """获取最高价序列""" + return self.highArray + + #---------------------------------------------------------------------- + @property + def low(self): + """获取最低价序列""" + return self.lowArray + + #---------------------------------------------------------------------- + @property + def close(self): + """获取收盘价序列""" + return self.closeArray + + #---------------------------------------------------------------------- + @property + def volume(self): + """获取成交量序列""" + return self.volumeArray + + #---------------------------------------------------------------------- + def sma(self, n, array=False): + """简单均线""" + result = talib.SMA(self.close, n) + if array: + return result + return result[-1] + + #---------------------------------------------------------------------- + def std(self, n, array=False): + """标准差""" + result = talib.STDDEV(self.close, n) + if array: + return result + return result[-1] + + #---------------------------------------------------------------------- + def cci(self, n, array=False): + """CCI指标""" + result = talib.CCI(self.high, self.low, self.close, n) + if array: + return result + return result[-1] + + #---------------------------------------------------------------------- + def atr(self, n, array=False): + """ATR指标""" + result = talib.ATR(self.high, self.low, self.close, n) + if array: + return result + return result[-1] + + #---------------------------------------------------------------------- + def rsi(self, n, array=False): + """RSI指标""" + result = talib.RSI(self.close, n) + if array: + return result + return result[-1] + + #---------------------------------------------------------------------- + def macd(self, fastPeriod, slowPeriod, signalPeriod, array=False): + """MACD指标""" + macd, signal, hist = talib.MACD(self.close, fastPeriod, + slowPeriod, signalPeriod) + if array: + return macd, signal, hist + return macd[-1], signal[-1], hist[-1] + + #---------------------------------------------------------------------- + def adx(self, n, array=False): + """ADX指标""" + result = talib.ADX(self.high, self.low, self.close, n) + if array: + return result + return result[-1] + + #---------------------------------------------------------------------- + def boll(self, n, dev, array=False): + """布林通道""" + mid = self.sma(n, array) + std = self.std(n, array) + + up = mid + std * dev + down = mid - std * dev + + return up, down + + #---------------------------------------------------------------------- + def keltner(self, n, dev, array=False): + """肯特纳通道""" + mid = self.sma(n, array) + atr = self.atr(n, array) + + up = mid + atr * dev + down = mid - atr * dev + + return up, down + + #---------------------------------------------------------------------- + def donchian(self, n, array=False): + """唐奇安通道""" + up = talib.MAX(self.high, n) + down = talib.MIN(self.low, n) + + if array: + return up, down + return up[-1], down[-1]