From cf3ccfe3b844273a3e4b3b27b6739a5159beea4b Mon Sep 17 00:00:00 2001 From: msincenselee Date: Wed, 10 Jun 2020 09:25:43 +0800 Subject: [PATCH] =?UTF-8?q?[improved]=20bug=20fix=20=E5=8F=AF=E7=94=A8?= =?UTF-8?q?=E8=B5=84=E9=87=91,=E5=A2=9E=E5=8A=A0=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E6=B8=AF=E8=82=A1=E3=80=81=E7=BE=8E=E8=82=A1=EF=BC=8C=E5=B8=81?= =?UTF-8?q?=E5=AE=89=E7=8E=B0=E8=B4=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vnpy/app/cta_strategy_pro/back_testing.py | 25 +- vnpy/app/cta_strategy_pro/engine.py | 54 ++++- .../app/cta_strategy_pro/portfolio_testing.py | 3 + vnpy/data/binance/binance_spot_data.py | 229 ++++++++++++++++++ vnpy/data/tdx/tdx_future_data.py | 2 +- vnpy/trader/engine.py | 9 +- vnpy/trader/utility.py | 2 +- 7 files changed, 309 insertions(+), 15 deletions(-) create mode 100644 vnpy/data/binance/binance_spot_data.py diff --git a/vnpy/app/cta_strategy_pro/back_testing.py b/vnpy/app/cta_strategy_pro/back_testing.py index e87be15c..0c844d05 100644 --- a/vnpy/app/cta_strategy_pro/back_testing.py +++ b/vnpy/app/cta_strategy_pro/back_testing.py @@ -106,6 +106,7 @@ class BackTestingEngine(object): self.fix_commission = {} # 每手固定手续费 self.size = {} # 合约大小,默认为1 self.price_tick = {} # 价格最小变动 + self.volume_tick = {} # 合约委托单最小单位 self.margin_rate = {} # 回测合约的保证金比率 self.price_dict = {} # 登记vt_symbol对应的最新价 self.contract_dict = {} # 登记vt_symbol得对应合约信息 @@ -161,7 +162,7 @@ class BackTestingEngine(object): self.net_capital = self.init_capital # 实时资金净值(每日根据capital和持仓浮盈计算) self.max_capital = self.init_capital # 资金最高净值 self.max_net_capital = self.init_capital - self.avaliable = self.init_capital + self.available = self.init_capital self.max_pnl = 0 # 最高盈利 self.min_pnl = 0 # 最大亏损 @@ -256,7 +257,7 @@ class BackTestingEngine(object): if self.net_capital == 0.0: self.percent = 0.0 - return self.net_capital, self.avaliable, self.percent, self.percent_limit + return self.net_capital, self.available, self.percent, self.percent_limit def set_test_start_date(self, start_date: str = '20100416', init_days: int = 10): """设置回测的启动日期""" @@ -289,7 +290,7 @@ class BackTestingEngine(object): self.net_capital = capital # 实时资金净值(每日根据capital和持仓浮盈计算) self.max_capital = capital # 资金最高净值 self.max_net_capital = capital - self.avaliable = capital + self.available = capital self.init_capital = capital def set_margin_rate(self, vt_symbol: str, margin_rate: float): @@ -345,8 +346,15 @@ class BackTestingEngine(object): def get_price_tick(self, vt_symbol: str): return self.price_tick.get(vt_symbol, 1) + def set_volume_tick(self, vt_symbol: str, volume_tick: float): + """设置委托单最小单位""" + self.volume_tick.update({vt_symbol: volume_tick}) + + def get_volume_tick(self, vt_symbol: str): + return self.volume_tick.get(vt_symbol, 1) + def set_contract(self, symbol: str, exchange: Exchange, product: Product, name: str, size: int, - price_tick: float, margin_rate: float = 0.1): + price_tick: float, volume_tick: float = 1, margin_rate: float = 0.1): """设置合约信息""" vt_symbol = '.'.join([symbol, exchange.value]) if vt_symbol not in self.contract_dict: @@ -364,6 +372,7 @@ class BackTestingEngine(object): self.set_size(vt_symbol, size) self.set_margin_rate(vt_symbol, margin_rate) self.set_price_tick(vt_symbol, price_tick) + self.set_volume_tick(vt_symbol, volume_tick) self.symbol_exchange_dict.update({symbol: exchange}) @lru_cache() @@ -528,7 +537,8 @@ class BackTestingEngine(object): for symbol, symbol_data in data_dict.items(): self.write_log(u'配置{}数据:{}'.format(symbol, symbol_data)) self.set_price_tick(symbol, symbol_data.get('price_tick', 1)) - + volume_tick = symbol_data.get('min_volume', symbol_data.get('volume_tick', 1)) + self.set_volume_tick(symbol, volume_tick) self.set_slippage(symbol, symbol_data.get('slippage', 0)) self.set_size(symbol, symbol_data.get('symbol_size', 10)) @@ -544,6 +554,7 @@ class BackTestingEngine(object): product=Product(symbol_data.get('product', "期货")), size=symbol_data.get('symbol_size', 10), price_tick=symbol_data.get('price_tick', 1), + volume_tick=volume_tick, margin_rate=margin_rate ) @@ -1786,7 +1797,7 @@ class BackTestingEngine(object): occupy_short_money_dict.get(underly_symbol, 0)) # 可用资金 = 当前净值 - 占用保证金 - self.avaliable = self.net_capital - occupy_money + self.available = self.net_capital - occupy_money # 当前保证金占比 self.percent = round(float(occupy_money * 100 / self.net_capital), 2) # 更新最大保证金占比 @@ -1840,7 +1851,7 @@ class BackTestingEngine(object): self.write_log(msg) # 重新计算一次avaliable - self.avaliable = self.net_capital - occupy_money + self.available = self.net_capital - occupy_money self.percent = round(float(occupy_money * 100 / self.net_capital), 2) def saving_daily_data(self, d, c, m, commission, benchmark=0): diff --git a/vnpy/app/cta_strategy_pro/engine.py b/vnpy/app/cta_strategy_pro/engine.py index 6eee0676..367e5ba0 100644 --- a/vnpy/app/cta_strategy_pro/engine.py +++ b/vnpy/app/cta_strategy_pro/engine.py @@ -28,7 +28,10 @@ from vnpy.trader.object import ( SubscribeRequest, LogData, TickData, - ContractData + ContractData, + HistoryRequest, + Interval, + BarData ) from vnpy.trader.event import ( EVENT_TIMER, @@ -351,6 +354,7 @@ class CtaEngine(BaseEngine): # Update GUI self.put_strategy_event(strategy) + # 如果配置文件 cta_stock_config.json中,有trade_2_wx的设置项,则发送微信通知 if self.engine_config.get('trade_2_wx', False): accountid = self.engine_config.get('accountid', 'XXX') d = { @@ -371,7 +375,6 @@ class CtaEngine(BaseEngine): self.offset_converter.update_position(position) - def check_unsubscribed_symbols(self): """检查未订阅合约""" @@ -809,11 +812,22 @@ class CtaEngine(BaseEngine): """查询价格最小跳动""" contract = self.main_engine.get_contract(vt_symbol) if contract is None: - self.write_error(f'查询不到{vt_symbol}合约信息') + self.write_error(f'查询不到{vt_symbol}合约信息,缺省使用0.1作为价格跳动') return 0.1 return contract.pricetick + @lru_cache() + def get_volume_tick(self, vt_symbol: str): + """查询合约的最小成交数量""" + contract = self.main_engine.get_contract(vt_symbol) + if contract is None: + self.write_error(f'查询不到{vt_symbol}合约信息,缺省使用1作为最小成交数量') + return 1 + + return contract.min_volume + + def get_tick(self, vt_symbol: str): """获取合约得最新tick""" return self.main_engine.get_tick(vt_symbol) @@ -878,6 +892,40 @@ class CtaEngine(BaseEngine): log_path = os.path.abspath(os.path.join(TRADER_DIR, 'log')) return log_path + def load_bar( + self, + vt_symbol: str, + days: int, + interval: Interval, + callback: Callable[[BarData], None], + interval_num: int = 1 + ): + """获取历史记录""" + symbol, exchange = extract_vt_symbol(vt_symbol) + end = datetime.now() + start = end - timedelta(days) + bars = [] + + # Query bars from gateway if available + contract = self.main_engine.get_contract(vt_symbol) + + if contract and contract.history_data: + req = HistoryRequest( + symbol=symbol, + exchange=exchange, + interval=interval, + interval_num=interval_num, + start=start, + end=end + ) + bars = self.main_engine.query_history(req, contract.gateway_name) + + for bar in bars: + if bar.trading_day: + bar.trading_day = bar.datetime.strftime('%Y-%m-%d') + + callback(bar) + def call_strategy_func( self, strategy: CtaTemplate, func: Callable, params: Any = None ): diff --git a/vnpy/app/cta_strategy_pro/portfolio_testing.py b/vnpy/app/cta_strategy_pro/portfolio_testing.py index 05de1207..86119f7a 100644 --- a/vnpy/app/cta_strategy_pro/portfolio_testing.py +++ b/vnpy/app/cta_strategy_pro/portfolio_testing.py @@ -131,6 +131,9 @@ class PortfolioTestingEngine(BackTestingEngine): :return: """ self.output('comine_df') + if len(self.bar_df_dict) == 0: + self.output(f'无加载任何数据,请检查bar文件路径配置') + self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index() self.bar_df_dict.clear() diff --git a/vnpy/data/binance/binance_spot_data.py b/vnpy/data/binance/binance_spot_data.py new file mode 100644 index 00000000..7e791f0a --- /dev/null +++ b/vnpy/data/binance/binance_spot_data.py @@ -0,0 +1,229 @@ +# 币安现货数据 + +import os +import json +from typing import Dict, List, Any +from datetime import datetime, timedelta +from vnpy.api.rest.rest_client import RestClient +from vnpy.trader.object import ( + Interval, + Exchange, + Product, + BarData, + HistoryRequest +) +from vnpy.trader.utility import save_json, load_json + +BINANCE_INTERVALS = ["1m", "3m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "8h", "12h", "1d", "3d", "1w", "1M"] + +INTERVAL_VT2BINANCEF: Dict[Interval, str] = { + Interval.MINUTE: "1m", + Interval.HOUR: "1h", + Interval.DAILY: "1d", +} + +TIMEDELTA_MAP: Dict[Interval, timedelta] = { + Interval.MINUTE: timedelta(minutes=1), + Interval.HOUR: timedelta(hours=1), + Interval.DAILY: timedelta(days=1), +} + +REST_HOST: str = "https://api.binance.com" + + +class BinanceSpotData(RestClient): + """现货数据接口""" + + def __init__(self, parent=None): + super().__init__() + self.parent = parent + self.init(url_base=REST_HOST) + + def write_log(self, msg): + """日志""" + if self.parent and hasattr(self.parent, 'write_log'): + func = getattr(self.parent, 'write_log') + func(msg) + else: + print(msg) + + def get_interval(self, interval, interval_num): + """ =》K线间隔""" + t = interval[-1] + b_interval = f'{interval_num}{t}' + if b_interval not in BINANCE_INTERVALS: + return interval + else: + return b_interval + + def get_bars(self, + req: HistoryRequest, + return_dict=True, + ) -> List[Any]: + """获取历史kline""" + bars = [] + limit = 1000 + start_time = int(datetime.timestamp(req.start)) + b_interval = self.get_interval(INTERVAL_VT2BINANCEF[req.interval], req.interval_num) + while True: + # Create query params + params = { + "symbol": req.symbol, + "interval": b_interval, + "limit": limit, + "startTime": start_time * 1000, # convert to millisecond + } + + # Add end time if specified + if req.end: + end_time = int(datetime.timestamp(req.end)) + params["endTime"] = end_time * 1000 # convert to millisecond + + # Get response from server + resp = self.request( + "GET", + "/api/v3/klines", + data={}, + params=params + ) + + # Break if request failed with other status code + if resp.status_code // 100 != 2: + msg = f"获取历史数据失败,状态码:{resp.status_code},信息:{resp.text}" + self.write_log(msg) + break + else: + datas = resp.json() + if not datas: + msg = f"获取历史数据为空,开始时间:{start_time}" + self.write_log(msg) + break + + buf = [] + begin, end = None, None + for data in datas: + dt = datetime.fromtimestamp(data[0] / 1000) # convert to second + if not begin: + begin = dt + end = dt + if return_dict: + bar = { + "datetime": dt, + "symbol": req.symbol, + "exchange": req.exchange.value, + "vt_symbol": f'{req.symbol}.{req.exchange.value}', + "interval": req.interval.value, + "volume": float(data[5]), + "open": float(data[1]), + "high": float(data[2]), + "low": float(data[3]), + "close": float(data[4]), + "gateway_name": "", + "open_interest": 0, + "trading_day": dt.strftime('%Y-%m-%d') + } + else: + bar = BarData( + symbol=req.symbol, + exchange=req.exchange, + datetime=dt, + trading_day=dt.strftime('%Y-%m-%d'), + interval=req.interval, + volume=float(data[5]), + open_price=float(data[1]), + high_price=float(data[2]), + low_price=float(data[3]), + close_price=float(data[4]), + gateway_name=self.gateway_name + ) + buf.append(bar) + + bars.extend(buf) + + msg = f"获取历史数据成功,{req.symbol} - {b_interval},{begin} - {end}" + self.write_log(msg) + + # Break if total data count less than limit (latest date collected) + if len(datas) < limit: + break + + # Update start time + start_dt = end + TIMEDELTA_MAP[req.interval] * req.interval_num + start_time = int(datetime.timestamp(start_dt)) + + return bars + + def export_to(self, bars, file_name): + """导出bar到文件""" + if len(bars) == 0: + self.write_log('not data in bars') + return + + import pandas as pd + df = pd.DataFrame(bars) + df = df.set_index('datetime') + df.index.name = 'datetime' + df.to_csv(file_name, index=True) + self.write_log('保存成功') + + def get_contracts(self): + + contracts = {} + # Get response from server + resp = self.request( + "GET", + "/api/v3/exchangeInfo", + data={} + ) + if resp.status_code // 100 != 2: + msg = f"获取交易所失败,状态码:{resp.status_code},信息:{resp.text}" + self.write_log(msg) + else: + data = resp.json() + for d in data["symbols"]: + self.write_log(json.dumps(d, indent=2)) + base_currency = d["baseAsset"] + quote_currency = d["quoteAsset"] + name = f"{base_currency.upper()}/{quote_currency.upper()}" + + pricetick = 1 + min_volume = 1 + + for f in d["filters"]: + if f["filterType"] == "PRICE_FILTER": + pricetick = float(f["tickSize"]) + elif f["filterType"] == "LOT_SIZE": + min_volume = float(f["stepSize"]) + + contract = { + "symbol": d["symbol"], + "exchange": Exchange.BINANCE.value, + "vt_symbol": d["symbol"] + '.' + Exchange.BINANCE.value, + "name": name, + "price_tick": pricetick, + "symbol_size": 20, + "margin_rate": 1, + "min_volume": min_volume, + "product": Product.SPOT.value, + "commission_rate": 0.005 + } + + contracts.update({contract.get('vt_symbol'): contract}) + + return contracts + + @classmethod + def load_contracts(self): + """读取本地配置文件获取期货合约配置""" + f = os.path.abspath(os.path.join(os.path.dirname(__file__), 'spot_contracts.json')) + contracts = load_json(f, auto_save=False) + return contracts + + def save_contracts(self): + """保存合约配置""" + contracts = self.get_contracts() + + if len(contracts) > 0: + f = os.path.abspath(os.path.join(os.path.dirname(__file__), 'spot_contracts.json')) + save_json(f, contracts) + self.write_log(f'保存合约配置=>{f}') diff --git a/vnpy/data/tdx/tdx_future_data.py b/vnpy/data/tdx/tdx_future_data.py index 43053caa..693f4971 100644 --- a/vnpy/data/tdx/tdx_future_data.py +++ b/vnpy/data/tdx/tdx_future_data.py @@ -398,7 +398,7 @@ class TdxFutureData(object): add_bar.low_price = float(row['low']) add_bar.close_price = float(row['close']) add_bar.volume = float(row['volume']) - add_bar.openInterest = float(row['open_interest']) + add_bar.open_interest = float(row['open_interest']) except Exception as ex: self.write_error( 'error when convert bar:{},ex:{},t:{}'.format(row, str(ex), traceback.format_exc())) diff --git a/vnpy/trader/engine.py b/vnpy/trader/engine.py index 519c4d00..4db59f6a 100644 --- a/vnpy/trader/engine.py +++ b/vnpy/trader/engine.py @@ -453,9 +453,12 @@ class OmsEngine(BaseEngine): contract_file_name = 'vn_contract.pkb2' if not os.path.exists(contract_file_name): return - with bz2.BZ2File(contract_file_name, 'rb') as f: - self.contracts = pickle.load(f) - self.write_log(f'加载缓存合约字典:{contract_file_name}') + try: + with bz2.BZ2File(contract_file_name, 'rb') as f: + self.contracts = pickle.load(f) + self.write_log(f'加载缓存合约字典:{contract_file_name}') + except Exception as ex: + self.write_log(f'加载缓存合约异常:{str(ex)}') # 更新自定义合约 custom_contracts = self.get_all_custom_contracts() diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index fef74bca..f37b826a 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -74,7 +74,7 @@ def get_underlying_symbol(symbol: str): p = re.compile(r"([A-Z]+)[0-9]+", re.I) underlying_symbol = p.match(symbol) - if underlying_symbol is None: + if underlying_symbol is None or len(underlying_symbol) == 0: return symbol return underlying_symbol.group(1)