Merge pull request #1658 from vnpy/genetic_algorithm

Genetic algorithm
This commit is contained in:
vn.py 2019-05-03 17:02:50 +08:00 committed by GitHub
commit 1b10ee4713
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 339 additions and 235 deletions

View File

@ -5,3 +5,4 @@ ignore =
W503 line break before binary operator
W293 blank line contains whitespace
W291 trailing whitespace
W391 blank line at end of file

View File

@ -17,3 +17,4 @@ tigeropen
rqdatac
ta-lib
ibapi
deap

View File

@ -10,28 +10,19 @@
"import multiprocessing\n",
"import numpy as np\n",
"from deap import creator, base, tools, algorithms\n",
"from backtesting import BacktestingEngine,OptimizationSetting\n",
"from boll_channel_strategy import BollChannelStrategy\n",
"from atr_rsi_strategy import AtrRsiStrategy\n",
"from vnpy.app.cta_strategy.backtesting import BacktestingEngine,OptimizationSetting\n",
"from vnpy.app.cta_strategy.strategies.boll_channel_strategy import BollChannelStrategy\n",
"from vnpy.app.cta_strategy.strategies.atr_rsi_strategy import AtrRsiStrategy\n",
"from datetime import datetime\n",
"import multiprocessing #多进程\n",
"from scoop import futures #多进程\n",
"from functools import lru_cache"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"数据总体: 13824\n"
]
}
],
"outputs": [],
"source": [
"setting = OptimizationSetting()\n",
"#setting.add_parameter('atr_length', 10, 50, 2)\n",
@ -51,20 +42,9 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['boll_window', 'cci_window', 'atr_window'])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"setting_names = random.choice(local_setting).keys()\n",
"setting_names"
@ -72,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -83,40 +63,18 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[22, 28, 22]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"parameter_generate()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'boll_window': 24, 'cci_window': 14, 'atr_window': 28}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"setting=dict(zip(setting_names,parameter_generate()))\n",
"setting"
@ -124,7 +82,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -168,27 +126,16 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(-0.51, -0.28)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"object_func(parameter_generate())"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -203,7 +150,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -224,8 +171,8 @@
" toolbox.register(\"mutate\", tools.mutUniformInt,low = 4,up = 40,indpb=1) \n",
" toolbox.register(\"evaluate\", object_func) \n",
" toolbox.register(\"select\", tools.selNSGA2) \n",
" #pool = multiprocessing.Pool()\n",
" #toolbox.register(\"map\", pool.map)\n",
" pool = multiprocessing.Pool(multiprocessing.cpu_count())\n",
" toolbox.register(\"map\", pool.map)\n",
" #toolbox.register(\"map\", futures.map)\n",
" \n",
" \n",
@ -266,55 +213,11 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"开始运行遗传算法每代族群总数34, 优良品种筛选个数27迭代次数30交叉概率0.95突变概率0.05\n",
"gen\tnevals\tmean \tstd \tmin \tmax \n",
"0 \t34 \t[0.08852941 0.00352941]\t[0.5373362 0.29107188]\t[-0.7 -0.63]\t[1.51 0.5 ]\n",
"1 \t34 \t[0.60148148 0.27518519]\t[0.31013383 0.08573734]\t[0.32 0.18] \t[1.51 0.5 ]\n",
"2 \t34 \t[0.79333333 0.33851852]\t[0.27758215 0.06742369]\t[0.47 0.25] \t[1.54 0.5 ]\n",
"3 \t34 \t[1.00888889 0.39777778]\t[0.3147525 0.06214281]\t[0.7 0.33] \t[1.54 0.5 ]\n",
"4 \t34 \t[1.41074074 0.47444444]\t[0.22881217 0.04661373]\t[0.96 0.36] \t[1.92 0.57]\n",
"5 \t34 \t[1.59666667 0.51222222]\t[0.14714568 0.0255797 ]\t[1.51 0.49] \t[1.92 0.57]\n",
"6 \t34 \t[1.66259259 0.52185185]\t[0.16585564 0.02981884]\t[1.52 0.49] \t[1.92 0.57]\n",
"7 \t34 \t[1.8737037 0.55666667]\t[0.07713135 0.01763834]\t[1.75 0.53] \t[1.95 0.57]\n",
"8 \t34 \t[1.93666667 0.57 ]\t[0.01490712 0. ]\t[1.92 0.57] \t[1.95 0.57]\n",
"9 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"10 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"11 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"12 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"13 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"14 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"15 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"16 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"17 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"18 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"19 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"20 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"21 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"22 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"23 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"24 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"25 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"26 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"27 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"28 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"29 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"30 \t34 \t[1.95 0.57] \t[0. 0.] \t[1.95 0.57] \t[1.95 0.57]\n",
"遗传算法优化完成耗时309秒\n",
"输出帕累托前沿解集:\n",
"{'boll_window': 48, 'cci_window': 40, 'atr_window': 22, 'return_drawdown_ratio': 1.95, 'sharpe_ratio': 0.57}\n",
"{'boll_window': 48, 'cci_window': 50, 'atr_window': 22, 'return_drawdown_ratio': 1.95, 'sharpe_ratio': 0.57}\n"
]
}
],
"outputs": [],
"source": [
"optimize()"
]

View File

@ -0,0 +1,27 @@
from vnpy.app.cta_strategy.backtesting import BacktestingEngine, OptimizationSetting
from vnpy.app.cta_strategy.strategies.atr_rsi_strategy import (
AtrRsiStrategy,
)
from datetime import datetime
if __name__ == "__main__":
engine = BacktestingEngine()
engine.set_parameters(
vt_symbol="IF88.CFFEX",
interval="1m",
start=datetime(2019, 1, 1),
end=datetime(2019, 4, 30),
rate=0.3 / 10000,
slippage=0.2,
size=300,
pricetick=0.2,
capital=1_000_000,
)
engine.add_strategy(AtrRsiStrategy, {})
setting = OptimizationSetting()
setting.set_target("sharpe_ratio")
setting.add_parameter("atr_length", 3, 39, 1)
setting.add_parameter("atr_ma_length", 10, 30, 1)
engine.run_ga_optimization(setting)

View File

@ -16,7 +16,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@ -54,115 +54,87 @@
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2019-04-15 22:19:49.696835\t参数{'atr_length': 22}, 目标121.19996051999999\n",
"2019-04-15 22:19:49.709531\t参数{'atr_length': 23}, 目标116.54901966000013\n",
"2019-04-15 22:19:49.710507\t参数{'atr_length': 24}, 目标113.29820520000014\n"
"2019-05-03 16:19:04.193703\t开始运行遗传算法每代族群总数11, 优良品种筛选个数8迭代次数30交叉概率0.95突变概率0.050000000000000044\n",
"gen\tnevals\tmean \tstd \tmin \tmax \n",
"0 \t11 \t[0.58423524]\t[0.30377007]\t[0.13231977]\t[1.2382818]\n",
"1 \t11 \t[0.90248989]\t[0.15747112]\t[0.68707859]\t[1.2382818]\n",
"2 \t11 \t[1.09406088]\t[0.18860523]\t[0.86284921]\t[1.46762684]\n",
"3 \t11 \t[1.21413386]\t[0.12138014]\t[1.02072108]\t[1.46762684]\n",
"4 \t11 \t[1.29561806]\t[0.09930932]\t[1.2382818] \t[1.46762684]\n",
"5 \t11 \t[1.41029058]\t[0.09930932]\t[1.2382818] \t[1.46762684]\n",
"6 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"7 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"8 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"9 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"10 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"11 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"12 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"13 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"14 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"15 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"16 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"17 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"18 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"19 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"20 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"21 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"22 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"23 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"24 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"25 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"26 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"27 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"28 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"29 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"30 \t11 \t[1.46762684]\t[0.] \t[1.46762684]\t[1.46762684]\n",
"2019-05-03 16:19:58.256354\t遗传算法优化完成耗时54秒\n"
]
},
{
"data": {
"text/plain": [
"[(\"{'atr_length': 22}\",\n",
" 121.19996051999999,\n",
" {'start_date': datetime.date(2013, 1, 18),\n",
" 'end_date': datetime.date(2019, 4, 11),\n",
" 'total_days': 1514,\n",
" 'profit_days': 763,\n",
" 'loss_days': 750,\n",
" 'capital': 1000000,\n",
" 'end_balance': 2211999.6052,\n",
" 'max_drawdown': -248787.6971999996,\n",
" 'max_ddpercent': -12.636908338002794,\n",
" 'total_net_pnl': 1211999.6052000003,\n",
" 'daily_net_pnl': 800.5281408190227,\n",
" 'total_commission': 242400.39479999998,\n",
" 'daily_commission': 160.10594108322323,\n",
" 'total_slippage': 481860.0,\n",
" 'daily_slippage': 318.2694848084544,\n",
" 'total_turnover': 8080013160.0,\n",
" 'daily_turnover': 5336864.702774108,\n",
" 'total_trade_count': 8031,\n",
" 'daily_trade_count': 5.30449141347424,\n",
" 'total_return': 121.19996051999999,\n",
" 'annual_return': 19.212675379656538,\n",
" 'daily_return': 0.052348808029058974,\n",
" 'return_std': 0.9487639654919149,\n",
" 'sharpe_ratio': 0.854779772691872,\n",
" 'return_drawdown_ratio': 9.590950355754112}),\n",
" (\"{'atr_length': 23}\",\n",
" 116.54901966000013,\n",
" {'start_date': datetime.date(2013, 1, 18),\n",
" 'end_date': datetime.date(2019, 4, 11),\n",
" 'total_days': 1514,\n",
" 'profit_days': 759,\n",
" 'loss_days': 754,\n",
" 'capital': 1000000,\n",
" 'end_balance': 2165490.1966000013,\n",
" 'max_drawdown': -232904.1239999996,\n",
" 'max_ddpercent': -13.536251422505968,\n",
" 'total_net_pnl': 1165490.1966000004,\n",
" 'daily_net_pnl': 769.8085842800531,\n",
" 'total_commission': 242769.80339999998,\n",
" 'daily_commission': 160.34993619550858,\n",
" 'total_slippage': 482700.0,\n",
" 'daily_slippage': 318.82430647291943,\n",
" 'total_turnover': 8092326780.0,\n",
" 'daily_turnover': 5344997.873183619,\n",
" 'total_trade_count': 8045,\n",
" 'daily_trade_count': 5.313738441215324,\n",
" 'total_return': 116.54901966000013,\n",
" 'annual_return': 18.475406022721288,\n",
" 'daily_return': 0.0509452313711608,\n",
" 'return_std': 0.961380153488665,\n",
" 'sharpe_ratio': 0.8209448965768181,\n",
" 'return_drawdown_ratio': 8.610139987960078}),\n",
" (\"{'atr_length': 24}\",\n",
" 113.29820520000014,\n",
" {'start_date': datetime.date(2013, 1, 18),\n",
" 'end_date': datetime.date(2019, 4, 11),\n",
" 'total_days': 1514,\n",
" 'profit_days': 760,\n",
" 'loss_days': 753,\n",
" 'capital': 1000000,\n",
" 'end_balance': 2132982.0520000015,\n",
" 'max_drawdown': -236503.9475999996,\n",
" 'max_ddpercent': -13.23872340727957,\n",
" 'total_net_pnl': 1132982.0520000013,\n",
" 'daily_net_pnl': 748.3368903566719,\n",
" 'total_commission': 242817.948,\n",
" 'daily_commission': 160.3817357992074,\n",
" 'total_slippage': 482700.0,\n",
" 'daily_slippage': 318.82430647291943,\n",
" 'total_turnover': 8093931600.0,\n",
" 'daily_turnover': 5346057.85997358,\n",
" 'total_trade_count': 8045,\n",
" 'daily_trade_count': 5.313738441215324,\n",
" 'total_return': 113.29820520000014,\n",
" 'annual_return': 17.96008536856013,\n",
" 'daily_return': 0.049946173936258026,\n",
" 'return_std': 0.959328411709829,\n",
" 'sharpe_ratio': 0.8065671672003681,\n",
" 'return_drawdown_ratio': 8.558091419728651})]"
"[({'atr_length': 38, 'atr_ma_length': 25}, 1.4676268402266743)]"
]
},
"execution_count": 5,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"setting = OptimizationSetting()\n",
"setting.set_target(\"total_return\")\n",
"setting.add_parameter(\"atr_length\", 22, 24, 1)\n",
"setting.set_target(\"sharpe_ratio\")\n",
"setting.add_parameter(\"atr_length\", 3, 39, 1)\n",
"setting.add_parameter(\"atr_ma_length\", 10, 30, 1)\n",
"\n",
"engine.run_optimization(setting)"
"engine.run_ga_optimization(setting)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"result = _"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(result)"
]
},
{

View File

@ -222,20 +222,25 @@ class BacktesterEngine(BaseEngine):
return strategy_class.get_class_parameters()
def run_optimization(
self,
class_name: str,
vt_symbol: str,
interval: str,
start: datetime,
end: datetime,
rate: float,
slippage: float,
size: int,
pricetick: float,
capital: int,
optimization_setting: OptimizationSetting):
self,
class_name: str,
vt_symbol: str,
interval: str,
start: datetime,
end: datetime,
rate: float,
slippage: float,
size: int,
pricetick: float,
capital: int,
optimization_setting: OptimizationSetting,
use_ga: bool
):
""""""
self.write_log("开始多进程参数优化")
if use_ga:
self.write_log("开始遗传算法参数优化")
else:
self.write_log("开始多进程参数优化")
self.result_values = None
@ -260,10 +265,16 @@ class BacktesterEngine(BaseEngine):
{}
)
self.result_values = engine.run_optimization(
optimization_setting,
output=False
)
if use_ga:
self.result_values = engine.run_ga_optimization(
optimization_setting,
output=False
)
else:
self.result_values = engine.run_optimization(
optimization_setting,
output=False
)
# Clear thread object handler.
self.thread = None
@ -285,7 +296,8 @@ class BacktesterEngine(BaseEngine):
size: int,
pricetick: float,
capital: int,
optimization_setting: OptimizationSetting
optimization_setting: OptimizationSetting,
use_ga: bool
):
if self.thread:
self.write_log("已有任务在运行中,请等待完成")
@ -305,7 +317,8 @@ class BacktesterEngine(BaseEngine):
size,
pricetick,
capital,
optimization_setting
optimization_setting,
use_ga
)
)
self.thread.start()

View File

@ -240,7 +240,7 @@ class BacktesterManager(QtWidgets.QWidget):
if i != dialog.Accepted:
return
optimization_setting = dialog.get_setting()
optimization_setting, use_ga = dialog.get_setting()
self.target_display = dialog.target_display
self.backtester_engine.start_optimization(
@ -254,7 +254,8 @@ class BacktesterManager(QtWidgets.QWidget):
size,
pricetick,
capital,
optimization_setting
optimization_setting,
use_ga
)
self.result_button.setEnabled(False)
@ -592,6 +593,7 @@ class OptimizationSettingEditor(QtWidgets.QDialog):
self.edits = {}
self.optimization_setting = None
self.use_ga = False
self.init_ui()
@ -642,12 +644,27 @@ class OptimizationSettingEditor(QtWidgets.QDialog):
row += 1
button = QtWidgets.QPushButton("确定")
button.clicked.connect(self.generate_setting)
grid.addWidget(button, row, 0, 1, 4)
parallel_button = QtWidgets.QPushButton("多进程优化")
parallel_button.clicked.connect(self.generate_parallel_setting)
grid.addWidget(parallel_button, row, 0, 1, 4)
row += 1
ga_button = QtWidgets.QPushButton("遗传算法优化")
ga_button.clicked.connect(self.generate_ga_setting)
grid.addWidget(ga_button, row, 0, 1, 4)
self.setLayout(grid)
def generate_ga_setting(self):
""""""
self.use_ga = True
self.generate_setting()
def generate_parallel_setting(self):
""""""
self.use_ga = False
self.generate_setting()
def generate_setting(self):
""""""
self.optimization_setting = OptimizationSetting()
@ -676,7 +693,7 @@ class OptimizationSettingEditor(QtWidgets.QDialog):
def get_setting(self):
""""""
return self.optimization_setting
return self.optimization_setting, self.use_ga
class OptimizationResultMonitor(QtWidgets.QDialog):

View File

@ -3,12 +3,16 @@ from datetime import date, datetime
from typing import Callable
from itertools import product
from functools import lru_cache
from time import time
import multiprocessing
import random
import math
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pandas import DataFrame
from deap import creator, base, tools, algorithms
from vnpy.trader.constant import (Direction, Offset, Exchange,
Interval, Status)
@ -26,6 +30,8 @@ from .base import (
from .template import CtaTemplate
sns.set_style("whitegrid")
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
creator.create("Individual", list, fitness=creator.FitnessMax)
class OptimizationSetting:
@ -514,6 +520,124 @@ class BacktestingEngine:
return result_values
def run_ga_optimization(self, optimization_setting: OptimizationSetting, output=True):
""""""
# Get optimization setting and target
settings = optimization_setting.generate_setting()
target_name = optimization_setting.target_name
if not settings:
self.output("优化参数组合为空,请检查")
return
if not target_name:
self.output("优化目标未设置,请检查")
return
# Define parameter generation function
def generate_parameter():
""""""
return list(random.choice(settings).values())
# Create ga object function
global ga_target_name
global ga_strategy_class
global ga_setting
global ga_vt_symbol
global ga_interval
global ga_start
global ga_rate
global ga_slippage
global ga_size
global ga_pricetick
global ga_capital
global ga_end
global ga_mode
ga_target_name = target_name
ga_strategy_class = self.strategy_class
ga_setting = settings[0]
ga_vt_symbol = self.vt_symbol
ga_interval = self.interval
ga_start = self.start
ga_rate = self.rate
ga_slippage = self.slippage
ga_size = self.size
ga_pricetick = self.pricetick
ga_capital = self.capital
ga_end = self.end
ga_mode = self.mode
# Set up genetic algorithem
toolbox = base.Toolbox()
toolbox.register("individual", tools.initIterate, creator.Individual, generate_parameter)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
toolbox.register("mate", tools.cxTwoPoint)
toolbox.register("mutate", tools.mutUniformInt, low=4, up=40, indpb=1)
toolbox.register("evaluate", ga_optimize)
toolbox.register("select", tools.selNSGA2)
total_size = len(settings)
pop_size = int(pow(total_size, 1 / math.e)) # number of individuals in each generation
lambda_ = pop_size # number of children to produce at each generation
mu = int(pop_size * 0.8) # number of individuals to select for the next generation
cxpb = 0.95 # probability that an offspring is produced by crossover
mutpb = 1 - cxpb # probability that an offspring is produced by mutation
ngen = 30 # number of generation
pop = toolbox.population(pop_size)
hof = tools.ParetoFront() # end result of pareto front
stats = tools.Statistics(lambda ind: ind.fitness.values)
np.set_printoptions(suppress=True)
stats.register("mean", np.mean, axis=0)
stats.register("std", np.std, axis=0)
stats.register("min", np.min, axis=0)
stats.register("max", np.max, axis=0)
# Multiprocessing is not supported yet.
# pool = multiprocessing.Pool(multiprocessing.cpu_count())
# toolbox.register("map", pool.map)
# Run ga optimization
self.output(f"参数优化空间:{total_size}")
self.output(f"每代族群总数:{pop_size}")
self.output(f"优良筛选个数:{mu}")
self.output(f"迭代次数:{ngen}")
self.output(f"交叉概率:{cxpb:.0%}")
self.output(f"突变概率:{mutpb:.0%}")
start = time()
algorithms.eaMuPlusLambda(
pop,
toolbox,
mu,
lambda_,
cxpb,
mutpb,
ngen,
stats,
halloffame=hof
)
end = time()
cost = int((end - start))
self.output(f"遗传算法优化完成,耗时{cost}")
# Return result list
results = []
parameter_keys = list(ga_setting.keys())
for parameter_values in hof:
setting = dict(zip(parameter_keys, parameter_values))
target_value = ga_optimize(parameter_values)[0]
results.append((setting, target_value, {}))
return results
def update_daily_close(self, price: float):
""""""
d = self.datetime.date()
@ -939,12 +1063,13 @@ def optimize(
pricetick: float,
capital: int,
end: datetime,
mode: BacktestingMode,
mode: BacktestingMode
):
"""
Function for running in multiprocessing.pool
"""
engine = BacktestingEngine()
engine.set_parameters(
vt_symbol=vt_symbol,
interval=interval,
@ -968,6 +1093,35 @@ def optimize(
return (str(setting), target_value, statistics)
@lru_cache(maxsize=1000000)
def _ga_optimizae(parameter_values: tuple):
""""""
parameter_keys = list(ga_setting.keys())
setting = dict(zip(parameter_keys, parameter_values))
result = optimize(
ga_target_name,
ga_strategy_class,
setting,
ga_vt_symbol,
ga_interval,
ga_start,
ga_rate,
ga_slippage,
ga_size,
ga_pricetick,
ga_capital,
ga_end,
ga_mode
)
return (result[1],)
def ga_optimize(parameter_values: list):
""""""
return _ga_optimizae(tuple(parameter_values))
@lru_cache(maxsize=10)
def load_bar_data(
symbol: str,
@ -993,3 +1147,19 @@ def load_tick_data(
return database_manager.load_tick_data(
symbol, exchange, start, end
)
# GA related global value
ga_end = None
ga_mode = None
ga_target_name = None
ga_strategy_class = None
ga_setting = None
ga_vt_symbol = None
ga_interval = None
ga_start = None
ga_rate = None
ga_slippage = None
ga_size = None
ga_pricetick = None
ga_capital = None