From fdf2d4cf136de6d75f8f7dd765288a4ebc2f9f13 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Sat, 26 Jan 2019 19:45:23 +0800 Subject: [PATCH] [Fix] some code mistakes caused by previous merge --- vnpy/app/cta_strategy/backtesting.py | 118 +++++++++++ vnpy/app/cta_strategy/engine.py | 3 + vnpy/gateway/bitmex/bitmex_gateway.py | 8 +- vnpy/gateway/futu/futu_gateway.py | 4 +- vnpy/trader/engine.py | 19 +- vnpy/trader/gateway.py | 7 +- vnpy/trader/ui/mainwindow.py | 7 +- vnpy/trader/ui/widget.py | 2 + vnpy/trader/utility.py | 279 ++++++++++++++++++++++++++ 9 files changed, 436 insertions(+), 11 deletions(-) diff --git a/vnpy/app/cta_strategy/backtesting.py b/vnpy/app/cta_strategy/backtesting.py index 49c4c929..049f3d67 100644 --- a/vnpy/app/cta_strategy/backtesting.py +++ b/vnpy/app/cta_strategy/backtesting.py @@ -1,6 +1,8 @@ from collections import defaultdict from datetime import date, datetime from typing import Callable +from itertools import product +import multiprocessing import matplotlib.pyplot as plt import seaborn as sns @@ -45,6 +47,7 @@ class BacktestingEngine: self.capital = 1_000_000 self.mode = BacktestingMode.BAR + self.strategy_class = None self.strategy = None self.tick = None self.bar = None @@ -126,6 +129,7 @@ class BacktestingEngine: def add_strategy(self, strategy_class: type, setting: dict): """""" + self.strategy_class = strategy_class self.strategy = strategy_class( self, strategy_class.__name__, self.vt_symbol, setting ) @@ -372,6 +376,58 @@ class BacktestingEngine: df["net_pnl"].hist(bins=50) plt.show() + + def run_optimization(self, optimization_setting: OptimizationSetting): + """""" + # Get optimization setting and target + settings = optimization_setting.generate_setting() + target_name = optimization_setting.target_name + + if not settings: + self.output("优化参数组合为空,请检查") + return + + if not target_name: + self.output("优化目标为设置,请检查") + return + + # Use multiprocessing pool for running backtesting with different setting + pool = multiprocessing.Pool(multiprocessing.cpu_count()) + + results = [] + for setting in settings: + result = (pool.apply_async(optimize, ( + target_name, + self.strategy_class, + setting, + self.vt_symbol, + self.interval, + self.start, + self.rate, + self.slippage, + self.size, + self.pricetick, + self.capital, + self.end, + self.mode + ))) + results.append(result) + + pool.close() + pool.join() + + # Sort results and output + result_values = [result.get() for result in results] + result_values.sort(reverse=True, key=lambda result:result[1]) + + for value in result_values: + msg = f"参数:{value[0]}, 目标:{value[1]}" + self.output(msg) + + return result_values + + return resultList + def update_daily_close(self, price: float): """""" @@ -788,3 +844,65 @@ class OptimizationSetting: value += step self.params[name] = value_list + + def set_target(self, target: str): + """""" + self.target = target + + def generate_setting(self): + """""" + keys = self.params.keys() + values = self.params.values() + products = list(product(*values)) + + settings = [] + for product in products: + setting = dict(zip(keys, product)) + settings.append(setting) + + return settings + + + +def optimize( + target_name: str, + strategy_class: CtaTemplate, + setting: dict, + vt_symbol: str, + interval: Interval, + start: datetime, + rate: float, + slippage: float, + size: float, + pricetick: float, + capital: int, + end: datetime, + mode: BacktestingMode, +): + """ + Function for running in multiprocessing.pool + """ + engine = BacktestingEngine() + engine.set_parameters( + vt_symbol=vt_symbol, + interval=interval, + start=start + rate=rate, + slippage=slippage, + size=size, + pricetick=pricetick, + capital=capital, + end=end, + mode=mode + ) + + engine.add_strategy(strategy_class, setting) + engine.load_data() + engine.run_backtesting() + engine.calculate_result() + statistics = engine.calculate_statistics() + + target_value = result[target_name] + return (str(setting), target_value, statistics) + + diff --git a/vnpy/app/cta_strategy/engine.py b/vnpy/app/cta_strategy/engine.py index 9dc12a77..2ab5664e 100644 --- a/vnpy/app/cta_strategy/engine.py +++ b/vnpy/app/cta_strategy/engine.py @@ -23,6 +23,9 @@ from vnpy.trader.constant import Direction, Offset, Exchange, PriceType, Interva from vnpy.trader.utility import get_temp_path from .base import ( CtaOrderType, + EngineType, + StopOrder, + StopOrderStatus, EVENT_CTA_LOG, EVENT_CTA_STOPORDER, EVENT_CTA_STRATEGY, diff --git a/vnpy/gateway/bitmex/bitmex_gateway.py b/vnpy/gateway/bitmex/bitmex_gateway.py index 6d81f284..f237badf 100644 --- a/vnpy/gateway/bitmex/bitmex_gateway.py +++ b/vnpy/gateway/bitmex/bitmex_gateway.py @@ -25,13 +25,15 @@ from vnpy.trader.constant import ( ) from vnpy.trader.gateway import BaseGateway from vnpy.trader.object import ( - AccountData, - CancelRequest, - ContractData, + TickData, OrderData, + TradeData, PositionData, AccountData, ContractData, + OrderRequest, + CancelRequest, + SubscribeRequest, ) REST_HOST = "https://www.bitmex.com/api/v1" diff --git a/vnpy/gateway/futu/futu_gateway.py b/vnpy/gateway/futu/futu_gateway.py index a2749b73..46b67061 100644 --- a/vnpy/gateway/futu/futu_gateway.py +++ b/vnpy/gateway/futu/futu_gateway.py @@ -10,6 +10,8 @@ from time import sleep from futu import ( ModifyOrderOp, + TrdSide, + TrdEnv, OpenHKTradeContext, OpenQuoteContext, OpenUSTradeContext, @@ -21,7 +23,7 @@ from futu import ( StockQuoteHandlerBase, TradeDealHandlerBase, TradeOrderHandlerBase, - TradeDealHandlerBase, + TradeDealHandlerBase ) from vnpy.trader.constant import Direction, Exchange, Product, Status diff --git a/vnpy/trader/engine.py b/vnpy/trader/engine.py index fd4765e6..f0701f85 100644 --- a/vnpy/trader/engine.py +++ b/vnpy/trader/engine.py @@ -20,6 +20,8 @@ from .event import ( EVENT_POSITION, EVENT_ACCOUNT, EVENT_CONTRACT, + EVENT_TICK, + EVENT_TRADE, ) from .gateway import BaseGateway from .object import CancelRequest, LogData, OrderRequest, SubscribeRequest @@ -179,7 +181,10 @@ class BaseEngine(ABC): """ def __init__( - self, main_engine: MainEngine, event_engine: EventEngine, engine_name: str + self, + main_engine: MainEngine, + event_engine: EventEngine, + engine_name: str, ): """""" self.main_engine = main_engine @@ -207,7 +212,9 @@ class LogEngine(BaseEngine): self.level = SETTINGS["log.level"] self.logger = logging.getLogger("VN Trader") - self.formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s") + self.formatter = logging.Formatter( + "%(asctime)s %(levelname)s: %(message)s" + ) self.add_null_handler() @@ -243,7 +250,9 @@ class LogEngine(BaseEngine): filename = f"vt_{today_date}.log" file_path = get_temp_path(filename) - file_handler = logging.FileHandler(file_path, mode="w", encoding="utf8") + file_handler = logging.FileHandler( + file_path, mode="w", encoding="utf8" + ) file_handler.setLevel(self.level) file_handler.setFormatter(self.formatter) self.logger.addHandler(file_handler) @@ -474,7 +483,9 @@ class EmailEngine(BaseEngine): with smtplib.SMTP_SSL( SETTINGS["email.server"], SETTINGS["email.port"] ) as smtp: - smtp.login(SETTINGS["email.username"], SETTINGS["email.password"]) + smtp.login( + SETTINGS["email.username"], SETTINGS["email.password"] + ) smtp.send_message(msg) except Empty: pass diff --git a/vnpy/trader/gateway.py b/vnpy/trader/gateway.py index b43aa7a5..4308a528 100644 --- a/vnpy/trader/gateway.py +++ b/vnpy/trader/gateway.py @@ -8,13 +8,16 @@ from typing import Any from vnpy.event import Event, EventEngine from .event import EVENT_ACCOUNT, EVENT_CONTRACT, EVENT_LOG, EVENT_CONTRACT from .object import ( + TickData, + OrderData, + TradeData, + PositionData, AccountData, - CancelRequest, ContractData, LogData, - OrderData, OrderRequest, CancelRequest, + SubscribeRequest ) diff --git a/vnpy/trader/ui/mainwindow.py b/vnpy/trader/ui/mainwindow.py index fc6f7847..bafc4411 100644 --- a/vnpy/trader/ui/mainwindow.py +++ b/vnpy/trader/ui/mainwindow.py @@ -11,11 +11,16 @@ from PyQt5 import QtCore, QtGui, QtWidgets from vnpy.event import EventEngine from .widget import ( AboutDialog, + TickMonitor, + OrderMonitor, + TradeMonitor, + PositionMonitor, AccountMonitor, + LogMonitor, ActiveOrderMonitor, ConnectDialog, ContractManager, - AboutDialog, + TradingWidget, ) from ..engine import MainEngine from ..utility import get_icon_path, get_trader_path diff --git a/vnpy/trader/ui/widget.py b/vnpy/trader/ui/widget.py index eb65640f..3c8111bd 100644 --- a/vnpy/trader/ui/widget.py +++ b/vnpy/trader/ui/widget.py @@ -18,6 +18,8 @@ from ..event import ( EVENT_POSITION, EVENT_CONTRACT, EVENT_LOG, + EVENT_TICK, + EVENT_TRADE ) from ..object import OrderRequest, SubscribeRequest from ..utility import load_setting, save_setting diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 28221334..39262068 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -4,6 +4,12 @@ General utility functions. import shelve from pathlib import Path +from typing import Callable + +import numpy as np +import talib + +from .object import BarData, TickData class Singleton(type): @@ -85,3 +91,276 @@ def round_to_pricetick(price: float, pricetick: float): """ rounded = round(price / pricetick, 0) * pricetick return rounded + + +class BarGenerator: + """ + For: + 1. generating 1 minute bar data from tick data + 2. generateing x minute bar data from 1 minute data + """ + + def __init__( + self, on_bar: Callable, xmin: int = 0, on_xmin_bar: Callable = None + ): + """Constructor""" + self.bar = None + self.on_bar = on_bar + + self.xmin = xmin + self.xmin_bar = None + self.on_xmin_bar = on_xmin_bar + + self.last_tick = None + + def update_tick(self, tick: TickData): + """ + Update new tick data into generator. + """ + new_minute = False + + if not self.bar: + self.bar = BarData() + new_minute = True + elif self.bar.datetime.minute != tick.datetime.minute: + self.bar.datetime = self.bar.datetime.replace( + second=0, microsecond=0 + ) + self.on_bar(self.bar) + + self.bar = BarData() + new_minute = True + + if new_minute: + self.bar.vt_symbol = tick.vt_symbol + self.bar.symbol = tick.symbol + self.bar.exchange = tick.exchange + + self.bar.open = tick.last_price + self.bar.high = tick.last_price + self.bar.low = tick.last_price + else: + self.bar.high = max(self.bar.high, tick.last_price) + self.bar.low = min(self.bar.low, tick.last_price) + + self.bar.close = tick.last_price + self.bar.datetime = tick.datetime + + if self.last_tick: + volume_change = tick.volume - self.last_tick.volume + self.bar.volume += max(volume_change, 0) + + self.last_tick = tick + + def update_bar(self, bar: BarData): + """ + Update 1 minute bar into generator + """ + if not self.xmin_bar: + self.xmin_bar = BarData() + + self.xmin_bar.vt_symbol = bar.vt_symbol + self.xmin_bar.symbol = bar.symbol + self.xmin_bar.exchange = bar.exchange + + self.xmin_bar.open = bar.open + self.xmin_bar.high = bar.high + self.xmin_bar.low = bar.low + + self.xmin_bar.datetime = bar.datetime + else: + self.xmin_bar.high = max(self.xmin_bar.high, bar.high) + self.xmin_bar.low = min(self.xmin_bar.low, bar.low) + + self.xmin_bar.close = bar.close + self.xmin_bar.volume += int(bar.volume) + + if not (bar.datetime.minute + 1) % self.xmin: + self.xmin_bar.datetime = self.xmin_bar.datetime.replace( + second=0, microsecond=0 + ) + self.on_xmin_bar(self.xmin_bar) + + self.xmin_bar = None + + def generate(self): + """ + Generate the bar data and call callback immediately. + """ + self.on_bar(self.bar) + self.bar = None + + +class ArrayManager(object): + """ + For: + 1. time series container of bar data + 2. calculating technical indicator value + """ + + def __init__(self, size=100): + """Constructor""" + self.count = 0 + self.size = size + self.inited = False + + self.open_array = np.zeros(size) + self.high_array = np.zeros(size) + self.low_array = np.zeros(size) + self.close_array = np.zeros(size) + self.volume_array = np.zeros(size) + + def update_bar(self, bar): + """ + Update new bar data into array manager. + """ + self.count += 1 + if not self.inited and self.count >= self.size: + self.inited = True + + self.open_array[:-1] = self.open_array[1:] + self.high_array[:-1] = self.high_array[1:] + self.low_array[:-1] = self.low_array[1:] + self.close_array[:-1] = self.close_array[1:] + self.volume_array[:-1] = self.volume_array[1:] + + self.open_array[-1] = bar.open + self.high_array[-1] = bar.high + self.low_array[-1] = bar.low + self.close_array[-1] = bar.close + self.volume_array[-1] = bar.volume + + @property + def open(self): + """ + Get open price time series. + """ + return self.open_array + + @property + def high(self): + """ + Get high price time series. + """ + return self.high_array + + @property + def low(self): + """ + Get low price time series. + """ + return self.low_array + + @property + def close(self): + """ + Get close price time series. + """ + return self.close_array + + @property + def volume(self): + """ + Get trading volume time series. + """ + return self.volume_array + + def sma(self, n, array=False): + """ + Simple moving average. + """ + result = talib.SMA(self.close, n) + if array: + return result + return result[-1] + + def std(self, n, array=False): + """ + Standard deviation + """ + result = talib.STDDEV(self.close, n) + if array: + return result + return result[-1] + + def cci(self, n, array=False): + """ + Commodity Channel Index (CCI). + """ + result = talib.CCI(self.high, self.low, self.close, n) + if array: + return result + return result[-1] + + def atr(self, n, array=False): + """ + Average True Range (ATR). + """ + result = talib.ATR(self.high, self.low, self.close, n) + if array: + return result + return result[-1] + + def rsi(self, n, array=False): + """ + Relative Strenght Index (RSI). + """ + result = talib.RSI(self.close, n) + if array: + return result + return result[-1] + + def macd(self, fast_period, slow_period, signal_period, array=False): + """ + MACD. + """ + macd, signal, hist = talib.MACD( + self.close, fast_period, slow_period, signal_period + ) + if array: + return macd, signal, hist + return macd[-1], signal[-1], hist[-1] + + def adx(self, n, array=False): + """ + ADX. + """ + result = talib.ADX(self.high, self.low, self.close, n) + if array: + return result + return result[-1] + + def boll(self, n, dev, array=False): + """ + Bollinger Channel. + """ + mid = self.sma(n, array) + std = self.std(n, array) + + up = mid + std * dev + down = mid - std * dev + + return up, down + + def keltner(self, n, dev, array=False): + """ + Keltner Channel. + """ + mid = self.sma(n, array) + atr = self.atr(n, array) + + up = mid + atr * dev + down = mid - atr * dev + + return up, down + + def donchian(self, n, array=False): + """ + Donchian Channel. + """ + up = talib.MAX(self.high, n) + down = talib.MIN(self.low, n) + + if array: + return up, down + return up[-1], down[-1]