[Add]run ga optimization in CtaBacktester

This commit is contained in:
vn.py 2019-05-03 17:01:47 +08:00
parent 4cd84b45a5
commit a11e6be3ce
5 changed files with 121 additions and 34 deletions

View File

@ -5,3 +5,4 @@ ignore =
W503 line break before binary operator W503 line break before binary operator
W293 blank line contains whitespace W293 blank line contains whitespace
W291 trailing whitespace W291 trailing whitespace
W391 blank line at end of file

View File

@ -2,7 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -16,7 +16,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -54,11 +54,62 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 3,
"metadata": { "metadata": {
"scrolled": true "scrolled": true
}, },
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2019-05-03 16:19:04.193703\t开始运行遗传算法每代族群总数11, 优良品种筛选个数8迭代次数30交叉概率0.95突变概率0.050000000000000044\n",
"gen\tnevals\tmean \tstd \tmin \tmax \n",
"0 \t11 \t[0.58423524]\t[0.30377007]\t[0.13231977]\t[1.2382818]\n",
"1 \t11 \t[0.90248989]\t[0.15747112]\t[0.68707859]\t[1.2382818]\n",
"2 \t11 \t[1.09406088]\t[0.18860523]\t[0.86284921]\t[1.46762684]\n",
"3 \t11 \t[1.21413386]\t[0.12138014]\t[1.02072108]\t[1.46762684]\n",
"4 \t11 \t[1.29561806]\t[0.09930932]\t[1.2382818] \t[1.46762684]\n",
"5 \t11 \t[1.41029058]\t[0.09930932]\t[1.2382818] \t[1.46762684]\n",
"6 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"7 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"8 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"9 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"10 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"11 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"12 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"13 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"14 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"15 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"16 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"17 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"18 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"19 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"20 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"21 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"22 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"23 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"24 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"25 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"26 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"27 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"28 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"29 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"30 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"2019-05-03 16:19:58.256354\t遗传算法优化完成耗时54秒\n"
]
},
{
"data": {
"text/plain": [
"[({'atr_length': 38, 'atr_ma_length': 25}, 1.4676268402266743)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"setting = OptimizationSetting()\n", "setting = OptimizationSetting()\n",
"setting.set_target(\"sharpe_ratio\")\n", "setting.set_target(\"sharpe_ratio\")\n",

View File

@ -222,20 +222,25 @@ class BacktesterEngine(BaseEngine):
return strategy_class.get_class_parameters() return strategy_class.get_class_parameters()
def run_optimization( def run_optimization(
self, self,
class_name: str, class_name: str,
vt_symbol: str, vt_symbol: str,
interval: str, interval: str,
start: datetime, start: datetime,
end: datetime, end: datetime,
rate: float, rate: float,
slippage: float, slippage: float,
size: int, size: int,
pricetick: float, pricetick: float,
capital: int, capital: int,
optimization_setting: OptimizationSetting): optimization_setting: OptimizationSetting,
use_ga: bool
):
"""""" """"""
self.write_log("开始多进程参数优化") if use_ga:
self.write_log("开始遗传算法参数优化")
else:
self.write_log("开始多进程参数优化")
self.result_values = None self.result_values = None
@ -260,10 +265,16 @@ class BacktesterEngine(BaseEngine):
{} {}
) )
self.result_values = engine.run_optimization( if use_ga:
optimization_setting, self.result_values = engine.run_ga_optimization(
output=False optimization_setting,
) output=False
)
else:
self.result_values = engine.run_optimization(
optimization_setting,
output=False
)
# Clear thread object handler. # Clear thread object handler.
self.thread = None self.thread = None
@ -285,7 +296,8 @@ class BacktesterEngine(BaseEngine):
size: int, size: int,
pricetick: float, pricetick: float,
capital: int, capital: int,
optimization_setting: OptimizationSetting optimization_setting: OptimizationSetting,
use_ga: bool
): ):
if self.thread: if self.thread:
self.write_log("已有任务在运行中,请等待完成") self.write_log("已有任务在运行中,请等待完成")
@ -305,7 +317,8 @@ class BacktesterEngine(BaseEngine):
size, size,
pricetick, pricetick,
capital, capital,
optimization_setting optimization_setting,
use_ga
) )
) )
self.thread.start() self.thread.start()

View File

@ -240,7 +240,7 @@ class BacktesterManager(QtWidgets.QWidget):
if i != dialog.Accepted: if i != dialog.Accepted:
return return
optimization_setting = dialog.get_setting() optimization_setting, use_ga = dialog.get_setting()
self.target_display = dialog.target_display self.target_display = dialog.target_display
self.backtester_engine.start_optimization( self.backtester_engine.start_optimization(
@ -254,7 +254,8 @@ class BacktesterManager(QtWidgets.QWidget):
size, size,
pricetick, pricetick,
capital, capital,
optimization_setting optimization_setting,
use_ga
) )
self.result_button.setEnabled(False) self.result_button.setEnabled(False)
@ -592,6 +593,7 @@ class OptimizationSettingEditor(QtWidgets.QDialog):
self.edits = {} self.edits = {}
self.optimization_setting = None self.optimization_setting = None
self.use_ga = False
self.init_ui() self.init_ui()
@ -642,12 +644,27 @@ class OptimizationSettingEditor(QtWidgets.QDialog):
row += 1 row += 1
button = QtWidgets.QPushButton("确定") parallel_button = QtWidgets.QPushButton("多进程优化")
button.clicked.connect(self.generate_setting) parallel_button.clicked.connect(self.generate_parallel_setting)
grid.addWidget(button, row, 0, 1, 4) grid.addWidget(parallel_button, row, 0, 1, 4)
row += 1
ga_button = QtWidgets.QPushButton("遗传算法优化")
ga_button.clicked.connect(self.generate_ga_setting)
grid.addWidget(ga_button, row, 0, 1, 4)
self.setLayout(grid) self.setLayout(grid)
def generate_ga_setting(self):
""""""
self.use_ga = True
self.generate_setting()
def generate_parallel_setting(self):
""""""
self.use_ga = False
self.generate_setting()
def generate_setting(self): def generate_setting(self):
"""""" """"""
self.optimization_setting = OptimizationSetting() self.optimization_setting = OptimizationSetting()
@ -676,7 +693,7 @@ class OptimizationSettingEditor(QtWidgets.QDialog):
def get_setting(self): def get_setting(self):
"""""" """"""
return self.optimization_setting return self.optimization_setting, self.use_ga
class OptimizationResultMonitor(QtWidgets.QDialog): class OptimizationResultMonitor(QtWidgets.QDialog):

View File

@ -601,8 +601,12 @@ class BacktestingEngine:
# toolbox.register("map", pool.map) # toolbox.register("map", pool.map)
# Run ga optimization # Run ga optimization
msg = "开始运行遗传算法,每代族群总数:%s, 优良品种筛选个数:%s,迭代次数:%s,交叉概率:%s,突变概率:%s" % (pop_size, mu, ngen, cxpb, mutpb) self.output(f"参数优化空间:{total_size}")
self.output(msg) self.output(f"每代族群总数:{pop_size}")
self.output(f"优良筛选个数:{mu}")
self.output(f"迭代次数:{ngen}")
self.output(f"交叉概率:{cxpb:.0%}")
self.output(f"突变概率:{mutpb:.0%}")
start = time() start = time()
@ -617,7 +621,7 @@ class BacktestingEngine:
stats, stats,
halloffame=hof halloffame=hof
) )
end = time() end = time()
cost = int((end - start)) cost = int((end - start))
@ -630,7 +634,7 @@ class BacktestingEngine:
for parameter_values in hof: for parameter_values in hof:
setting = dict(zip(parameter_keys, parameter_values)) setting = dict(zip(parameter_keys, parameter_values))
target_value = ga_optimize(parameter_values)[0] target_value = ga_optimize(parameter_values)[0]
results.append((setting, target_value)) results.append((setting, target_value, {}))
return results return results
@ -1059,12 +1063,13 @@ def optimize(
pricetick: float, pricetick: float,
capital: int, capital: int,
end: datetime, end: datetime,
mode: BacktestingMode, mode: BacktestingMode
): ):
""" """
Function for running in multiprocessing.pool Function for running in multiprocessing.pool
""" """
engine = BacktestingEngine() engine = BacktestingEngine()
engine.set_parameters( engine.set_parameters(
vt_symbol=vt_symbol, vt_symbol=vt_symbol,
interval=interval, interval=interval,