[Fix] some code mistakes caused by previous merge

This commit is contained in:
vn.py 2019-01-26 19:45:23 +08:00
parent 91678e0de2
commit fdf2d4cf13
9 changed files with 436 additions and 11 deletions

View File

@ -1,6 +1,8 @@
from collections import defaultdict
from datetime import date, datetime
from typing import Callable
from itertools import product
import multiprocessing
import matplotlib.pyplot as plt
import seaborn as sns
@ -45,6 +47,7 @@ class BacktestingEngine:
self.capital = 1_000_000
self.mode = BacktestingMode.BAR
self.strategy_class = None
self.strategy = None
self.tick = None
self.bar = None
@ -126,6 +129,7 @@ class BacktestingEngine:
def add_strategy(self, strategy_class: type, setting: dict):
""""""
self.strategy_class = strategy_class
self.strategy = strategy_class(
self, strategy_class.__name__, self.vt_symbol, setting
)
@ -373,6 +377,58 @@ class BacktestingEngine:
plt.show()
def run_optimization(self, optimization_setting: OptimizationSetting):
""""""
# Get optimization setting and target
settings = optimization_setting.generate_setting()
target_name = optimization_setting.target_name
if not settings:
self.output("优化参数组合为空,请检查")
return
if not target_name:
self.output("优化目标为设置,请检查")
return
# Use multiprocessing pool for running backtesting with different setting
pool = multiprocessing.Pool(multiprocessing.cpu_count())
results = []
for setting in settings:
result = (pool.apply_async(optimize, (
target_name,
self.strategy_class,
setting,
self.vt_symbol,
self.interval,
self.start,
self.rate,
self.slippage,
self.size,
self.pricetick,
self.capital,
self.end,
self.mode
)))
results.append(result)
pool.close()
pool.join()
# Sort results and output
result_values = [result.get() for result in results]
result_values.sort(reverse=True, key=lambda result:result[1])
for value in result_values:
msg = f"参数:{value[0]}, 目标:{value[1]}"
self.output(msg)
return result_values
return resultList
def update_daily_close(self, price: float):
""""""
d = self.datetime.date()
@ -788,3 +844,65 @@ class OptimizationSetting:
value += step
self.params[name] = value_list
def set_target(self, target: str):
""""""
self.target = target
def generate_setting(self):
""""""
keys = self.params.keys()
values = self.params.values()
products = list(product(*values))
settings = []
for product in products:
setting = dict(zip(keys, product))
settings.append(setting)
return settings
def optimize(
target_name: str,
strategy_class: CtaTemplate,
setting: dict,
vt_symbol: str,
interval: Interval,
start: datetime,
rate: float,
slippage: float,
size: float,
pricetick: float,
capital: int,
end: datetime,
mode: BacktestingMode,
):
"""
Function for running in multiprocessing.pool
"""
engine = BacktestingEngine()
engine.set_parameters(
vt_symbol=vt_symbol,
interval=interval,
start=start
rate=rate,
slippage=slippage,
size=size,
pricetick=pricetick,
capital=capital,
end=end,
mode=mode
)
engine.add_strategy(strategy_class, setting)
engine.load_data()
engine.run_backtesting()
engine.calculate_result()
statistics = engine.calculate_statistics()
target_value = result[target_name]
return (str(setting), target_value, statistics)

View File

@ -23,6 +23,9 @@ from vnpy.trader.constant import Direction, Offset, Exchange, PriceType, Interva
from vnpy.trader.utility import get_temp_path
from .base import (
CtaOrderType,
EngineType,
StopOrder,
StopOrderStatus,
EVENT_CTA_LOG,
EVENT_CTA_STOPORDER,
EVENT_CTA_STRATEGY,

View File

@ -25,13 +25,15 @@ from vnpy.trader.constant import (
)
from vnpy.trader.gateway import BaseGateway
from vnpy.trader.object import (
AccountData,
CancelRequest,
ContractData,
TickData,
OrderData,
TradeData,
PositionData,
AccountData,
ContractData,
OrderRequest,
CancelRequest,
SubscribeRequest,
)
REST_HOST = "https://www.bitmex.com/api/v1"

View File

@ -10,6 +10,8 @@ from time import sleep
from futu import (
ModifyOrderOp,
TrdSide,
TrdEnv,
OpenHKTradeContext,
OpenQuoteContext,
OpenUSTradeContext,
@ -21,7 +23,7 @@ from futu import (
StockQuoteHandlerBase,
TradeDealHandlerBase,
TradeOrderHandlerBase,
TradeDealHandlerBase,
TradeDealHandlerBase
)
from vnpy.trader.constant import Direction, Exchange, Product, Status

View File

@ -20,6 +20,8 @@ from .event import (
EVENT_POSITION,
EVENT_ACCOUNT,
EVENT_CONTRACT,
EVENT_TICK,
EVENT_TRADE,
)
from .gateway import BaseGateway
from .object import CancelRequest, LogData, OrderRequest, SubscribeRequest
@ -179,7 +181,10 @@ class BaseEngine(ABC):
"""
def __init__(
self, main_engine: MainEngine, event_engine: EventEngine, engine_name: str
self,
main_engine: MainEngine,
event_engine: EventEngine,
engine_name: str,
):
""""""
self.main_engine = main_engine
@ -207,7 +212,9 @@ class LogEngine(BaseEngine):
self.level = SETTINGS["log.level"]
self.logger = logging.getLogger("VN Trader")
self.formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s")
self.formatter = logging.Formatter(
"%(asctime)s %(levelname)s: %(message)s"
)
self.add_null_handler()
@ -243,7 +250,9 @@ class LogEngine(BaseEngine):
filename = f"vt_{today_date}.log"
file_path = get_temp_path(filename)
file_handler = logging.FileHandler(file_path, mode="w", encoding="utf8")
file_handler = logging.FileHandler(
file_path, mode="w", encoding="utf8"
)
file_handler.setLevel(self.level)
file_handler.setFormatter(self.formatter)
self.logger.addHandler(file_handler)
@ -474,7 +483,9 @@ class EmailEngine(BaseEngine):
with smtplib.SMTP_SSL(
SETTINGS["email.server"], SETTINGS["email.port"]
) as smtp:
smtp.login(SETTINGS["email.username"], SETTINGS["email.password"])
smtp.login(
SETTINGS["email.username"], SETTINGS["email.password"]
)
smtp.send_message(msg)
except Empty:
pass

View File

@ -8,13 +8,16 @@ from typing import Any
from vnpy.event import Event, EventEngine
from .event import EVENT_ACCOUNT, EVENT_CONTRACT, EVENT_LOG, EVENT_CONTRACT
from .object import (
TickData,
OrderData,
TradeData,
PositionData,
AccountData,
CancelRequest,
ContractData,
LogData,
OrderData,
OrderRequest,
CancelRequest,
SubscribeRequest
)

View File

@ -11,11 +11,16 @@ from PyQt5 import QtCore, QtGui, QtWidgets
from vnpy.event import EventEngine
from .widget import (
AboutDialog,
TickMonitor,
OrderMonitor,
TradeMonitor,
PositionMonitor,
AccountMonitor,
LogMonitor,
ActiveOrderMonitor,
ConnectDialog,
ContractManager,
AboutDialog,
TradingWidget,
)
from ..engine import MainEngine
from ..utility import get_icon_path, get_trader_path

View File

@ -18,6 +18,8 @@ from ..event import (
EVENT_POSITION,
EVENT_CONTRACT,
EVENT_LOG,
EVENT_TICK,
EVENT_TRADE
)
from ..object import OrderRequest, SubscribeRequest
from ..utility import load_setting, save_setting

View File

@ -4,6 +4,12 @@ General utility functions.
import shelve
from pathlib import Path
from typing import Callable
import numpy as np
import talib
from .object import BarData, TickData
class Singleton(type):
@ -85,3 +91,276 @@ def round_to_pricetick(price: float, pricetick: float):
"""
rounded = round(price / pricetick, 0) * pricetick
return rounded
class BarGenerator:
"""
For:
1. generating 1 minute bar data from tick data
2. generateing x minute bar data from 1 minute data
"""
def __init__(
self, on_bar: Callable, xmin: int = 0, on_xmin_bar: Callable = None
):
"""Constructor"""
self.bar = None
self.on_bar = on_bar
self.xmin = xmin
self.xmin_bar = None
self.on_xmin_bar = on_xmin_bar
self.last_tick = None
def update_tick(self, tick: TickData):
"""
Update new tick data into generator.
"""
new_minute = False
if not self.bar:
self.bar = BarData()
new_minute = True
elif self.bar.datetime.minute != tick.datetime.minute:
self.bar.datetime = self.bar.datetime.replace(
second=0, microsecond=0
)
self.on_bar(self.bar)
self.bar = BarData()
new_minute = True
if new_minute:
self.bar.vt_symbol = tick.vt_symbol
self.bar.symbol = tick.symbol
self.bar.exchange = tick.exchange
self.bar.open = tick.last_price
self.bar.high = tick.last_price
self.bar.low = tick.last_price
else:
self.bar.high = max(self.bar.high, tick.last_price)
self.bar.low = min(self.bar.low, tick.last_price)
self.bar.close = tick.last_price
self.bar.datetime = tick.datetime
if self.last_tick:
volume_change = tick.volume - self.last_tick.volume
self.bar.volume += max(volume_change, 0)
self.last_tick = tick
def update_bar(self, bar: BarData):
"""
Update 1 minute bar into generator
"""
if not self.xmin_bar:
self.xmin_bar = BarData()
self.xmin_bar.vt_symbol = bar.vt_symbol
self.xmin_bar.symbol = bar.symbol
self.xmin_bar.exchange = bar.exchange
self.xmin_bar.open = bar.open
self.xmin_bar.high = bar.high
self.xmin_bar.low = bar.low
self.xmin_bar.datetime = bar.datetime
else:
self.xmin_bar.high = max(self.xmin_bar.high, bar.high)
self.xmin_bar.low = min(self.xmin_bar.low, bar.low)
self.xmin_bar.close = bar.close
self.xmin_bar.volume += int(bar.volume)
if not (bar.datetime.minute + 1) % self.xmin:
self.xmin_bar.datetime = self.xmin_bar.datetime.replace(
second=0, microsecond=0
)
self.on_xmin_bar(self.xmin_bar)
self.xmin_bar = None
def generate(self):
"""
Generate the bar data and call callback immediately.
"""
self.on_bar(self.bar)
self.bar = None
class ArrayManager(object):
"""
For:
1. time series container of bar data
2. calculating technical indicator value
"""
def __init__(self, size=100):
"""Constructor"""
self.count = 0
self.size = size
self.inited = False
self.open_array = np.zeros(size)
self.high_array = np.zeros(size)
self.low_array = np.zeros(size)
self.close_array = np.zeros(size)
self.volume_array = np.zeros(size)
def update_bar(self, bar):
"""
Update new bar data into array manager.
"""
self.count += 1
if not self.inited and self.count >= self.size:
self.inited = True
self.open_array[:-1] = self.open_array[1:]
self.high_array[:-1] = self.high_array[1:]
self.low_array[:-1] = self.low_array[1:]
self.close_array[:-1] = self.close_array[1:]
self.volume_array[:-1] = self.volume_array[1:]
self.open_array[-1] = bar.open
self.high_array[-1] = bar.high
self.low_array[-1] = bar.low
self.close_array[-1] = bar.close
self.volume_array[-1] = bar.volume
@property
def open(self):
"""
Get open price time series.
"""
return self.open_array
@property
def high(self):
"""
Get high price time series.
"""
return self.high_array
@property
def low(self):
"""
Get low price time series.
"""
return self.low_array
@property
def close(self):
"""
Get close price time series.
"""
return self.close_array
@property
def volume(self):
"""
Get trading volume time series.
"""
return self.volume_array
def sma(self, n, array=False):
"""
Simple moving average.
"""
result = talib.SMA(self.close, n)
if array:
return result
return result[-1]
def std(self, n, array=False):
"""
Standard deviation
"""
result = talib.STDDEV(self.close, n)
if array:
return result
return result[-1]
def cci(self, n, array=False):
"""
Commodity Channel Index (CCI).
"""
result = talib.CCI(self.high, self.low, self.close, n)
if array:
return result
return result[-1]
def atr(self, n, array=False):
"""
Average True Range (ATR).
"""
result = talib.ATR(self.high, self.low, self.close, n)
if array:
return result
return result[-1]
def rsi(self, n, array=False):
"""
Relative Strenght Index (RSI).
"""
result = talib.RSI(self.close, n)
if array:
return result
return result[-1]
def macd(self, fast_period, slow_period, signal_period, array=False):
"""
MACD.
"""
macd, signal, hist = talib.MACD(
self.close, fast_period, slow_period, signal_period
)
if array:
return macd, signal, hist
return macd[-1], signal[-1], hist[-1]
def adx(self, n, array=False):
"""
ADX.
"""
result = talib.ADX(self.high, self.low, self.close, n)
if array:
return result
return result[-1]
def boll(self, n, dev, array=False):
"""
Bollinger Channel.
"""
mid = self.sma(n, array)
std = self.std(n, array)
up = mid + std * dev
down = mid - std * dev
return up, down
def keltner(self, n, dev, array=False):
"""
Keltner Channel.
"""
mid = self.sma(n, array)
atr = self.atr(n, array)
up = mid + atr * dev
down = mid - atr * dev
return up, down
def donchian(self, n, array=False):
"""
Donchian Channel.
"""
up = talib.MAX(self.high, n)
down = talib.MIN(self.low, n)
if array:
return up, down
return up[-1], down[-1]