[Fix] some code mistakes caused by previous merge
This commit is contained in:
parent
91678e0de2
commit
fdf2d4cf13
@ -1,6 +1,8 @@
|
||||
from collections import defaultdict
|
||||
from datetime import date, datetime
|
||||
from typing import Callable
|
||||
from itertools import product
|
||||
import multiprocessing
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
@ -45,6 +47,7 @@ class BacktestingEngine:
|
||||
self.capital = 1_000_000
|
||||
self.mode = BacktestingMode.BAR
|
||||
|
||||
self.strategy_class = None
|
||||
self.strategy = None
|
||||
self.tick = None
|
||||
self.bar = None
|
||||
@ -126,6 +129,7 @@ class BacktestingEngine:
|
||||
|
||||
def add_strategy(self, strategy_class: type, setting: dict):
|
||||
""""""
|
||||
self.strategy_class = strategy_class
|
||||
self.strategy = strategy_class(
|
||||
self, strategy_class.__name__, self.vt_symbol, setting
|
||||
)
|
||||
@ -373,6 +377,58 @@ class BacktestingEngine:
|
||||
|
||||
plt.show()
|
||||
|
||||
def run_optimization(self, optimization_setting: OptimizationSetting):
|
||||
""""""
|
||||
# Get optimization setting and target
|
||||
settings = optimization_setting.generate_setting()
|
||||
target_name = optimization_setting.target_name
|
||||
|
||||
if not settings:
|
||||
self.output("优化参数组合为空,请检查")
|
||||
return
|
||||
|
||||
if not target_name:
|
||||
self.output("优化目标为设置,请检查")
|
||||
return
|
||||
|
||||
# Use multiprocessing pool for running backtesting with different setting
|
||||
pool = multiprocessing.Pool(multiprocessing.cpu_count())
|
||||
|
||||
results = []
|
||||
for setting in settings:
|
||||
result = (pool.apply_async(optimize, (
|
||||
target_name,
|
||||
self.strategy_class,
|
||||
setting,
|
||||
self.vt_symbol,
|
||||
self.interval,
|
||||
self.start,
|
||||
self.rate,
|
||||
self.slippage,
|
||||
self.size,
|
||||
self.pricetick,
|
||||
self.capital,
|
||||
self.end,
|
||||
self.mode
|
||||
)))
|
||||
results.append(result)
|
||||
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
# Sort results and output
|
||||
result_values = [result.get() for result in results]
|
||||
result_values.sort(reverse=True, key=lambda result:result[1])
|
||||
|
||||
for value in result_values:
|
||||
msg = f"参数:{value[0]}, 目标:{value[1]}"
|
||||
self.output(msg)
|
||||
|
||||
return result_values
|
||||
|
||||
return resultList
|
||||
|
||||
|
||||
def update_daily_close(self, price: float):
|
||||
""""""
|
||||
d = self.datetime.date()
|
||||
@ -788,3 +844,65 @@ class OptimizationSetting:
|
||||
value += step
|
||||
|
||||
self.params[name] = value_list
|
||||
|
||||
def set_target(self, target: str):
|
||||
""""""
|
||||
self.target = target
|
||||
|
||||
def generate_setting(self):
|
||||
""""""
|
||||
keys = self.params.keys()
|
||||
values = self.params.values()
|
||||
products = list(product(*values))
|
||||
|
||||
settings = []
|
||||
for product in products:
|
||||
setting = dict(zip(keys, product))
|
||||
settings.append(setting)
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
|
||||
def optimize(
|
||||
target_name: str,
|
||||
strategy_class: CtaTemplate,
|
||||
setting: dict,
|
||||
vt_symbol: str,
|
||||
interval: Interval,
|
||||
start: datetime,
|
||||
rate: float,
|
||||
slippage: float,
|
||||
size: float,
|
||||
pricetick: float,
|
||||
capital: int,
|
||||
end: datetime,
|
||||
mode: BacktestingMode,
|
||||
):
|
||||
"""
|
||||
Function for running in multiprocessing.pool
|
||||
"""
|
||||
engine = BacktestingEngine()
|
||||
engine.set_parameters(
|
||||
vt_symbol=vt_symbol,
|
||||
interval=interval,
|
||||
start=start
|
||||
rate=rate,
|
||||
slippage=slippage,
|
||||
size=size,
|
||||
pricetick=pricetick,
|
||||
capital=capital,
|
||||
end=end,
|
||||
mode=mode
|
||||
)
|
||||
|
||||
engine.add_strategy(strategy_class, setting)
|
||||
engine.load_data()
|
||||
engine.run_backtesting()
|
||||
engine.calculate_result()
|
||||
statistics = engine.calculate_statistics()
|
||||
|
||||
target_value = result[target_name]
|
||||
return (str(setting), target_value, statistics)
|
||||
|
||||
|
||||
|
@ -23,6 +23,9 @@ from vnpy.trader.constant import Direction, Offset, Exchange, PriceType, Interva
|
||||
from vnpy.trader.utility import get_temp_path
|
||||
from .base import (
|
||||
CtaOrderType,
|
||||
EngineType,
|
||||
StopOrder,
|
||||
StopOrderStatus,
|
||||
EVENT_CTA_LOG,
|
||||
EVENT_CTA_STOPORDER,
|
||||
EVENT_CTA_STRATEGY,
|
||||
|
@ -25,13 +25,15 @@ from vnpy.trader.constant import (
|
||||
)
|
||||
from vnpy.trader.gateway import BaseGateway
|
||||
from vnpy.trader.object import (
|
||||
AccountData,
|
||||
CancelRequest,
|
||||
ContractData,
|
||||
TickData,
|
||||
OrderData,
|
||||
TradeData,
|
||||
PositionData,
|
||||
AccountData,
|
||||
ContractData,
|
||||
OrderRequest,
|
||||
CancelRequest,
|
||||
SubscribeRequest,
|
||||
)
|
||||
|
||||
REST_HOST = "https://www.bitmex.com/api/v1"
|
||||
|
@ -10,6 +10,8 @@ from time import sleep
|
||||
|
||||
from futu import (
|
||||
ModifyOrderOp,
|
||||
TrdSide,
|
||||
TrdEnv,
|
||||
OpenHKTradeContext,
|
||||
OpenQuoteContext,
|
||||
OpenUSTradeContext,
|
||||
@ -21,7 +23,7 @@ from futu import (
|
||||
StockQuoteHandlerBase,
|
||||
TradeDealHandlerBase,
|
||||
TradeOrderHandlerBase,
|
||||
TradeDealHandlerBase,
|
||||
TradeDealHandlerBase
|
||||
)
|
||||
|
||||
from vnpy.trader.constant import Direction, Exchange, Product, Status
|
||||
|
@ -20,6 +20,8 @@ from .event import (
|
||||
EVENT_POSITION,
|
||||
EVENT_ACCOUNT,
|
||||
EVENT_CONTRACT,
|
||||
EVENT_TICK,
|
||||
EVENT_TRADE,
|
||||
)
|
||||
from .gateway import BaseGateway
|
||||
from .object import CancelRequest, LogData, OrderRequest, SubscribeRequest
|
||||
@ -179,7 +181,10 @@ class BaseEngine(ABC):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, main_engine: MainEngine, event_engine: EventEngine, engine_name: str
|
||||
self,
|
||||
main_engine: MainEngine,
|
||||
event_engine: EventEngine,
|
||||
engine_name: str,
|
||||
):
|
||||
""""""
|
||||
self.main_engine = main_engine
|
||||
@ -207,7 +212,9 @@ class LogEngine(BaseEngine):
|
||||
|
||||
self.level = SETTINGS["log.level"]
|
||||
self.logger = logging.getLogger("VN Trader")
|
||||
self.formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s")
|
||||
self.formatter = logging.Formatter(
|
||||
"%(asctime)s %(levelname)s: %(message)s"
|
||||
)
|
||||
|
||||
self.add_null_handler()
|
||||
|
||||
@ -243,7 +250,9 @@ class LogEngine(BaseEngine):
|
||||
filename = f"vt_{today_date}.log"
|
||||
file_path = get_temp_path(filename)
|
||||
|
||||
file_handler = logging.FileHandler(file_path, mode="w", encoding="utf8")
|
||||
file_handler = logging.FileHandler(
|
||||
file_path, mode="w", encoding="utf8"
|
||||
)
|
||||
file_handler.setLevel(self.level)
|
||||
file_handler.setFormatter(self.formatter)
|
||||
self.logger.addHandler(file_handler)
|
||||
@ -474,7 +483,9 @@ class EmailEngine(BaseEngine):
|
||||
with smtplib.SMTP_SSL(
|
||||
SETTINGS["email.server"], SETTINGS["email.port"]
|
||||
) as smtp:
|
||||
smtp.login(SETTINGS["email.username"], SETTINGS["email.password"])
|
||||
smtp.login(
|
||||
SETTINGS["email.username"], SETTINGS["email.password"]
|
||||
)
|
||||
smtp.send_message(msg)
|
||||
except Empty:
|
||||
pass
|
||||
|
@ -8,13 +8,16 @@ from typing import Any
|
||||
from vnpy.event import Event, EventEngine
|
||||
from .event import EVENT_ACCOUNT, EVENT_CONTRACT, EVENT_LOG, EVENT_CONTRACT
|
||||
from .object import (
|
||||
TickData,
|
||||
OrderData,
|
||||
TradeData,
|
||||
PositionData,
|
||||
AccountData,
|
||||
CancelRequest,
|
||||
ContractData,
|
||||
LogData,
|
||||
OrderData,
|
||||
OrderRequest,
|
||||
CancelRequest,
|
||||
SubscribeRequest
|
||||
)
|
||||
|
||||
|
||||
|
@ -11,11 +11,16 @@ from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
from vnpy.event import EventEngine
|
||||
from .widget import (
|
||||
AboutDialog,
|
||||
TickMonitor,
|
||||
OrderMonitor,
|
||||
TradeMonitor,
|
||||
PositionMonitor,
|
||||
AccountMonitor,
|
||||
LogMonitor,
|
||||
ActiveOrderMonitor,
|
||||
ConnectDialog,
|
||||
ContractManager,
|
||||
AboutDialog,
|
||||
TradingWidget,
|
||||
)
|
||||
from ..engine import MainEngine
|
||||
from ..utility import get_icon_path, get_trader_path
|
||||
|
@ -18,6 +18,8 @@ from ..event import (
|
||||
EVENT_POSITION,
|
||||
EVENT_CONTRACT,
|
||||
EVENT_LOG,
|
||||
EVENT_TICK,
|
||||
EVENT_TRADE
|
||||
)
|
||||
from ..object import OrderRequest, SubscribeRequest
|
||||
from ..utility import load_setting, save_setting
|
||||
|
@ -4,6 +4,12 @@ General utility functions.
|
||||
|
||||
import shelve
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import talib
|
||||
|
||||
from .object import BarData, TickData
|
||||
|
||||
|
||||
class Singleton(type):
|
||||
@ -85,3 +91,276 @@ def round_to_pricetick(price: float, pricetick: float):
|
||||
"""
|
||||
rounded = round(price / pricetick, 0) * pricetick
|
||||
return rounded
|
||||
|
||||
|
||||
class BarGenerator:
|
||||
"""
|
||||
For:
|
||||
1. generating 1 minute bar data from tick data
|
||||
2. generateing x minute bar data from 1 minute data
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, on_bar: Callable, xmin: int = 0, on_xmin_bar: Callable = None
|
||||
):
|
||||
"""Constructor"""
|
||||
self.bar = None
|
||||
self.on_bar = on_bar
|
||||
|
||||
self.xmin = xmin
|
||||
self.xmin_bar = None
|
||||
self.on_xmin_bar = on_xmin_bar
|
||||
|
||||
self.last_tick = None
|
||||
|
||||
def update_tick(self, tick: TickData):
|
||||
"""
|
||||
Update new tick data into generator.
|
||||
"""
|
||||
new_minute = False
|
||||
|
||||
if not self.bar:
|
||||
self.bar = BarData()
|
||||
new_minute = True
|
||||
elif self.bar.datetime.minute != tick.datetime.minute:
|
||||
self.bar.datetime = self.bar.datetime.replace(
|
||||
second=0, microsecond=0
|
||||
)
|
||||
self.on_bar(self.bar)
|
||||
|
||||
self.bar = BarData()
|
||||
new_minute = True
|
||||
|
||||
if new_minute:
|
||||
self.bar.vt_symbol = tick.vt_symbol
|
||||
self.bar.symbol = tick.symbol
|
||||
self.bar.exchange = tick.exchange
|
||||
|
||||
self.bar.open = tick.last_price
|
||||
self.bar.high = tick.last_price
|
||||
self.bar.low = tick.last_price
|
||||
else:
|
||||
self.bar.high = max(self.bar.high, tick.last_price)
|
||||
self.bar.low = min(self.bar.low, tick.last_price)
|
||||
|
||||
self.bar.close = tick.last_price
|
||||
self.bar.datetime = tick.datetime
|
||||
|
||||
if self.last_tick:
|
||||
volume_change = tick.volume - self.last_tick.volume
|
||||
self.bar.volume += max(volume_change, 0)
|
||||
|
||||
self.last_tick = tick
|
||||
|
||||
def update_bar(self, bar: BarData):
|
||||
"""
|
||||
Update 1 minute bar into generator
|
||||
"""
|
||||
if not self.xmin_bar:
|
||||
self.xmin_bar = BarData()
|
||||
|
||||
self.xmin_bar.vt_symbol = bar.vt_symbol
|
||||
self.xmin_bar.symbol = bar.symbol
|
||||
self.xmin_bar.exchange = bar.exchange
|
||||
|
||||
self.xmin_bar.open = bar.open
|
||||
self.xmin_bar.high = bar.high
|
||||
self.xmin_bar.low = bar.low
|
||||
|
||||
self.xmin_bar.datetime = bar.datetime
|
||||
else:
|
||||
self.xmin_bar.high = max(self.xmin_bar.high, bar.high)
|
||||
self.xmin_bar.low = min(self.xmin_bar.low, bar.low)
|
||||
|
||||
self.xmin_bar.close = bar.close
|
||||
self.xmin_bar.volume += int(bar.volume)
|
||||
|
||||
if not (bar.datetime.minute + 1) % self.xmin:
|
||||
self.xmin_bar.datetime = self.xmin_bar.datetime.replace(
|
||||
second=0, microsecond=0
|
||||
)
|
||||
self.on_xmin_bar(self.xmin_bar)
|
||||
|
||||
self.xmin_bar = None
|
||||
|
||||
def generate(self):
|
||||
"""
|
||||
Generate the bar data and call callback immediately.
|
||||
"""
|
||||
self.on_bar(self.bar)
|
||||
self.bar = None
|
||||
|
||||
|
||||
class ArrayManager(object):
|
||||
"""
|
||||
For:
|
||||
1. time series container of bar data
|
||||
2. calculating technical indicator value
|
||||
"""
|
||||
|
||||
def __init__(self, size=100):
|
||||
"""Constructor"""
|
||||
self.count = 0
|
||||
self.size = size
|
||||
self.inited = False
|
||||
|
||||
self.open_array = np.zeros(size)
|
||||
self.high_array = np.zeros(size)
|
||||
self.low_array = np.zeros(size)
|
||||
self.close_array = np.zeros(size)
|
||||
self.volume_array = np.zeros(size)
|
||||
|
||||
def update_bar(self, bar):
|
||||
"""
|
||||
Update new bar data into array manager.
|
||||
"""
|
||||
self.count += 1
|
||||
if not self.inited and self.count >= self.size:
|
||||
self.inited = True
|
||||
|
||||
self.open_array[:-1] = self.open_array[1:]
|
||||
self.high_array[:-1] = self.high_array[1:]
|
||||
self.low_array[:-1] = self.low_array[1:]
|
||||
self.close_array[:-1] = self.close_array[1:]
|
||||
self.volume_array[:-1] = self.volume_array[1:]
|
||||
|
||||
self.open_array[-1] = bar.open
|
||||
self.high_array[-1] = bar.high
|
||||
self.low_array[-1] = bar.low
|
||||
self.close_array[-1] = bar.close
|
||||
self.volume_array[-1] = bar.volume
|
||||
|
||||
@property
|
||||
def open(self):
|
||||
"""
|
||||
Get open price time series.
|
||||
"""
|
||||
return self.open_array
|
||||
|
||||
@property
|
||||
def high(self):
|
||||
"""
|
||||
Get high price time series.
|
||||
"""
|
||||
return self.high_array
|
||||
|
||||
@property
|
||||
def low(self):
|
||||
"""
|
||||
Get low price time series.
|
||||
"""
|
||||
return self.low_array
|
||||
|
||||
@property
|
||||
def close(self):
|
||||
"""
|
||||
Get close price time series.
|
||||
"""
|
||||
return self.close_array
|
||||
|
||||
@property
|
||||
def volume(self):
|
||||
"""
|
||||
Get trading volume time series.
|
||||
"""
|
||||
return self.volume_array
|
||||
|
||||
def sma(self, n, array=False):
|
||||
"""
|
||||
Simple moving average.
|
||||
"""
|
||||
result = talib.SMA(self.close, n)
|
||||
if array:
|
||||
return result
|
||||
return result[-1]
|
||||
|
||||
def std(self, n, array=False):
|
||||
"""
|
||||
Standard deviation
|
||||
"""
|
||||
result = talib.STDDEV(self.close, n)
|
||||
if array:
|
||||
return result
|
||||
return result[-1]
|
||||
|
||||
def cci(self, n, array=False):
|
||||
"""
|
||||
Commodity Channel Index (CCI).
|
||||
"""
|
||||
result = talib.CCI(self.high, self.low, self.close, n)
|
||||
if array:
|
||||
return result
|
||||
return result[-1]
|
||||
|
||||
def atr(self, n, array=False):
|
||||
"""
|
||||
Average True Range (ATR).
|
||||
"""
|
||||
result = talib.ATR(self.high, self.low, self.close, n)
|
||||
if array:
|
||||
return result
|
||||
return result[-1]
|
||||
|
||||
def rsi(self, n, array=False):
|
||||
"""
|
||||
Relative Strenght Index (RSI).
|
||||
"""
|
||||
result = talib.RSI(self.close, n)
|
||||
if array:
|
||||
return result
|
||||
return result[-1]
|
||||
|
||||
def macd(self, fast_period, slow_period, signal_period, array=False):
|
||||
"""
|
||||
MACD.
|
||||
"""
|
||||
macd, signal, hist = talib.MACD(
|
||||
self.close, fast_period, slow_period, signal_period
|
||||
)
|
||||
if array:
|
||||
return macd, signal, hist
|
||||
return macd[-1], signal[-1], hist[-1]
|
||||
|
||||
def adx(self, n, array=False):
|
||||
"""
|
||||
ADX.
|
||||
"""
|
||||
result = talib.ADX(self.high, self.low, self.close, n)
|
||||
if array:
|
||||
return result
|
||||
return result[-1]
|
||||
|
||||
def boll(self, n, dev, array=False):
|
||||
"""
|
||||
Bollinger Channel.
|
||||
"""
|
||||
mid = self.sma(n, array)
|
||||
std = self.std(n, array)
|
||||
|
||||
up = mid + std * dev
|
||||
down = mid - std * dev
|
||||
|
||||
return up, down
|
||||
|
||||
def keltner(self, n, dev, array=False):
|
||||
"""
|
||||
Keltner Channel.
|
||||
"""
|
||||
mid = self.sma(n, array)
|
||||
atr = self.atr(n, array)
|
||||
|
||||
up = mid + atr * dev
|
||||
down = mid - atr * dev
|
||||
|
||||
return up, down
|
||||
|
||||
def donchian(self, n, array=False):
|
||||
"""
|
||||
Donchian Channel.
|
||||
"""
|
||||
up = talib.MAX(self.high, n)
|
||||
down = talib.MIN(self.low, n)
|
||||
|
||||
if array:
|
||||
return up, down
|
||||
return up[-1], down[-1]
|
||||
|
Loading…
Reference in New Issue
Block a user