[improved] bug fix 可用资金,增加支持港股、美股,币安现货

This commit is contained in:
msincenselee 2020-06-10 09:25:43 +08:00
parent f218c676b1
commit cf3ccfe3b8
7 changed files with 309 additions and 15 deletions

View File

@ -106,6 +106,7 @@ class BackTestingEngine(object):
self.fix_commission = {} # 每手固定手续费
self.size = {} # 合约大小默认为1
self.price_tick = {} # 价格最小变动
self.volume_tick = {} # 合约委托单最小单位
self.margin_rate = {} # 回测合约的保证金比率
self.price_dict = {} # 登记vt_symbol对应的最新价
self.contract_dict = {} # 登记vt_symbol得对应合约信息
@ -161,7 +162,7 @@ class BackTestingEngine(object):
self.net_capital = self.init_capital # 实时资金净值每日根据capital和持仓浮盈计算
self.max_capital = self.init_capital # 资金最高净值
self.max_net_capital = self.init_capital
self.avaliable = self.init_capital
self.available = self.init_capital
self.max_pnl = 0 # 最高盈利
self.min_pnl = 0 # 最大亏损
@ -256,7 +257,7 @@ class BackTestingEngine(object):
if self.net_capital == 0.0:
self.percent = 0.0
return self.net_capital, self.avaliable, self.percent, self.percent_limit
return self.net_capital, self.available, self.percent, self.percent_limit
def set_test_start_date(self, start_date: str = '20100416', init_days: int = 10):
"""设置回测的启动日期"""
@ -289,7 +290,7 @@ class BackTestingEngine(object):
self.net_capital = capital # 实时资金净值每日根据capital和持仓浮盈计算
self.max_capital = capital # 资金最高净值
self.max_net_capital = capital
self.avaliable = capital
self.available = capital
self.init_capital = capital
def set_margin_rate(self, vt_symbol: str, margin_rate: float):
@ -345,8 +346,15 @@ class BackTestingEngine(object):
def get_price_tick(self, vt_symbol: str):
return self.price_tick.get(vt_symbol, 1)
def set_volume_tick(self, vt_symbol: str, volume_tick: float):
"""设置委托单最小单位"""
self.volume_tick.update({vt_symbol: volume_tick})
def get_volume_tick(self, vt_symbol: str):
return self.volume_tick.get(vt_symbol, 1)
def set_contract(self, symbol: str, exchange: Exchange, product: Product, name: str, size: int,
price_tick: float, margin_rate: float = 0.1):
price_tick: float, volume_tick: float = 1, margin_rate: float = 0.1):
"""设置合约信息"""
vt_symbol = '.'.join([symbol, exchange.value])
if vt_symbol not in self.contract_dict:
@ -364,6 +372,7 @@ class BackTestingEngine(object):
self.set_size(vt_symbol, size)
self.set_margin_rate(vt_symbol, margin_rate)
self.set_price_tick(vt_symbol, price_tick)
self.set_volume_tick(vt_symbol, volume_tick)
self.symbol_exchange_dict.update({symbol: exchange})
@lru_cache()
@ -528,7 +537,8 @@ class BackTestingEngine(object):
for symbol, symbol_data in data_dict.items():
self.write_log(u'配置{}数据:{}'.format(symbol, symbol_data))
self.set_price_tick(symbol, symbol_data.get('price_tick', 1))
volume_tick = symbol_data.get('min_volume', symbol_data.get('volume_tick', 1))
self.set_volume_tick(symbol, volume_tick)
self.set_slippage(symbol, symbol_data.get('slippage', 0))
self.set_size(symbol, symbol_data.get('symbol_size', 10))
@ -544,6 +554,7 @@ class BackTestingEngine(object):
product=Product(symbol_data.get('product', "期货")),
size=symbol_data.get('symbol_size', 10),
price_tick=symbol_data.get('price_tick', 1),
volume_tick=volume_tick,
margin_rate=margin_rate
)
@ -1786,7 +1797,7 @@ class BackTestingEngine(object):
occupy_short_money_dict.get(underly_symbol, 0))
# 可用资金 = 当前净值 - 占用保证金
self.avaliable = self.net_capital - occupy_money
self.available = self.net_capital - occupy_money
# 当前保证金占比
self.percent = round(float(occupy_money * 100 / self.net_capital), 2)
# 更新最大保证金占比
@ -1840,7 +1851,7 @@ class BackTestingEngine(object):
self.write_log(msg)
# 重新计算一次avaliable
self.avaliable = self.net_capital - occupy_money
self.available = self.net_capital - occupy_money
self.percent = round(float(occupy_money * 100 / self.net_capital), 2)
def saving_daily_data(self, d, c, m, commission, benchmark=0):

View File

@ -28,7 +28,10 @@ from vnpy.trader.object import (
SubscribeRequest,
LogData,
TickData,
ContractData
ContractData,
HistoryRequest,
Interval,
BarData
)
from vnpy.trader.event import (
EVENT_TIMER,
@ -351,6 +354,7 @@ class CtaEngine(BaseEngine):
# Update GUI
self.put_strategy_event(strategy)
# 如果配置文件 cta_stock_config.json中有trade_2_wx的设置项则发送微信通知
if self.engine_config.get('trade_2_wx', False):
accountid = self.engine_config.get('accountid', 'XXX')
d = {
@ -371,7 +375,6 @@ class CtaEngine(BaseEngine):
self.offset_converter.update_position(position)
def check_unsubscribed_symbols(self):
"""检查未订阅合约"""
@ -809,11 +812,22 @@ class CtaEngine(BaseEngine):
"""查询价格最小跳动"""
contract = self.main_engine.get_contract(vt_symbol)
if contract is None:
self.write_error(f'查询不到{vt_symbol}合约信息')
self.write_error(f'查询不到{vt_symbol}合约信息缺省使用0.1作为价格跳动')
return 0.1
return contract.pricetick
@lru_cache()
def get_volume_tick(self, vt_symbol: str):
"""查询合约的最小成交数量"""
contract = self.main_engine.get_contract(vt_symbol)
if contract is None:
self.write_error(f'查询不到{vt_symbol}合约信息,缺省使用1作为最小成交数量')
return 1
return contract.min_volume
def get_tick(self, vt_symbol: str):
"""获取合约得最新tick"""
return self.main_engine.get_tick(vt_symbol)
@ -878,6 +892,40 @@ class CtaEngine(BaseEngine):
log_path = os.path.abspath(os.path.join(TRADER_DIR, 'log'))
return log_path
def load_bar(
self,
vt_symbol: str,
days: int,
interval: Interval,
callback: Callable[[BarData], None],
interval_num: int = 1
):
"""获取历史记录"""
symbol, exchange = extract_vt_symbol(vt_symbol)
end = datetime.now()
start = end - timedelta(days)
bars = []
# Query bars from gateway if available
contract = self.main_engine.get_contract(vt_symbol)
if contract and contract.history_data:
req = HistoryRequest(
symbol=symbol,
exchange=exchange,
interval=interval,
interval_num=interval_num,
start=start,
end=end
)
bars = self.main_engine.query_history(req, contract.gateway_name)
for bar in bars:
if bar.trading_day:
bar.trading_day = bar.datetime.strftime('%Y-%m-%d')
callback(bar)
def call_strategy_func(
self, strategy: CtaTemplate, func: Callable, params: Any = None
):

View File

@ -131,6 +131,9 @@ class PortfolioTestingEngine(BackTestingEngine):
:return:
"""
self.output('comine_df')
if len(self.bar_df_dict) == 0:
self.output(f'无加载任何数据,请检查bar文件路径配置')
self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index()
self.bar_df_dict.clear()

View File

@ -0,0 +1,229 @@
# 币安现货数据
import os
import json
from typing import Dict, List, Any
from datetime import datetime, timedelta
from vnpy.api.rest.rest_client import RestClient
from vnpy.trader.object import (
Interval,
Exchange,
Product,
BarData,
HistoryRequest
)
from vnpy.trader.utility import save_json, load_json
BINANCE_INTERVALS = ["1m", "3m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "8h", "12h", "1d", "3d", "1w", "1M"]
INTERVAL_VT2BINANCEF: Dict[Interval, str] = {
Interval.MINUTE: "1m",
Interval.HOUR: "1h",
Interval.DAILY: "1d",
}
TIMEDELTA_MAP: Dict[Interval, timedelta] = {
Interval.MINUTE: timedelta(minutes=1),
Interval.HOUR: timedelta(hours=1),
Interval.DAILY: timedelta(days=1),
}
REST_HOST: str = "https://api.binance.com"
class BinanceSpotData(RestClient):
"""现货数据接口"""
def __init__(self, parent=None):
super().__init__()
self.parent = parent
self.init(url_base=REST_HOST)
def write_log(self, msg):
"""日志"""
if self.parent and hasattr(self.parent, 'write_log'):
func = getattr(self.parent, 'write_log')
func(msg)
else:
print(msg)
def get_interval(self, interval, interval_num):
""" =》K线间隔"""
t = interval[-1]
b_interval = f'{interval_num}{t}'
if b_interval not in BINANCE_INTERVALS:
return interval
else:
return b_interval
def get_bars(self,
req: HistoryRequest,
return_dict=True,
) -> List[Any]:
"""获取历史kline"""
bars = []
limit = 1000
start_time = int(datetime.timestamp(req.start))
b_interval = self.get_interval(INTERVAL_VT2BINANCEF[req.interval], req.interval_num)
while True:
# Create query params
params = {
"symbol": req.symbol,
"interval": b_interval,
"limit": limit,
"startTime": start_time * 1000, # convert to millisecond
}
# Add end time if specified
if req.end:
end_time = int(datetime.timestamp(req.end))
params["endTime"] = end_time * 1000 # convert to millisecond
# Get response from server
resp = self.request(
"GET",
"/api/v3/klines",
data={},
params=params
)
# Break if request failed with other status code
if resp.status_code // 100 != 2:
msg = f"获取历史数据失败,状态码:{resp.status_code},信息:{resp.text}"
self.write_log(msg)
break
else:
datas = resp.json()
if not datas:
msg = f"获取历史数据为空,开始时间:{start_time}"
self.write_log(msg)
break
buf = []
begin, end = None, None
for data in datas:
dt = datetime.fromtimestamp(data[0] / 1000) # convert to second
if not begin:
begin = dt
end = dt
if return_dict:
bar = {
"datetime": dt,
"symbol": req.symbol,
"exchange": req.exchange.value,
"vt_symbol": f'{req.symbol}.{req.exchange.value}',
"interval": req.interval.value,
"volume": float(data[5]),
"open": float(data[1]),
"high": float(data[2]),
"low": float(data[3]),
"close": float(data[4]),
"gateway_name": "",
"open_interest": 0,
"trading_day": dt.strftime('%Y-%m-%d')
}
else:
bar = BarData(
symbol=req.symbol,
exchange=req.exchange,
datetime=dt,
trading_day=dt.strftime('%Y-%m-%d'),
interval=req.interval,
volume=float(data[5]),
open_price=float(data[1]),
high_price=float(data[2]),
low_price=float(data[3]),
close_price=float(data[4]),
gateway_name=self.gateway_name
)
buf.append(bar)
bars.extend(buf)
msg = f"获取历史数据成功,{req.symbol} - {b_interval}{begin} - {end}"
self.write_log(msg)
# Break if total data count less than limit (latest date collected)
if len(datas) < limit:
break
# Update start time
start_dt = end + TIMEDELTA_MAP[req.interval] * req.interval_num
start_time = int(datetime.timestamp(start_dt))
return bars
def export_to(self, bars, file_name):
"""导出bar到文件"""
if len(bars) == 0:
self.write_log('not data in bars')
return
import pandas as pd
df = pd.DataFrame(bars)
df = df.set_index('datetime')
df.index.name = 'datetime'
df.to_csv(file_name, index=True)
self.write_log('保存成功')
def get_contracts(self):
contracts = {}
# Get response from server
resp = self.request(
"GET",
"/api/v3/exchangeInfo",
data={}
)
if resp.status_code // 100 != 2:
msg = f"获取交易所失败,状态码:{resp.status_code},信息:{resp.text}"
self.write_log(msg)
else:
data = resp.json()
for d in data["symbols"]:
self.write_log(json.dumps(d, indent=2))
base_currency = d["baseAsset"]
quote_currency = d["quoteAsset"]
name = f"{base_currency.upper()}/{quote_currency.upper()}"
pricetick = 1
min_volume = 1
for f in d["filters"]:
if f["filterType"] == "PRICE_FILTER":
pricetick = float(f["tickSize"])
elif f["filterType"] == "LOT_SIZE":
min_volume = float(f["stepSize"])
contract = {
"symbol": d["symbol"],
"exchange": Exchange.BINANCE.value,
"vt_symbol": d["symbol"] + '.' + Exchange.BINANCE.value,
"name": name,
"price_tick": pricetick,
"symbol_size": 20,
"margin_rate": 1,
"min_volume": min_volume,
"product": Product.SPOT.value,
"commission_rate": 0.005
}
contracts.update({contract.get('vt_symbol'): contract})
return contracts
@classmethod
def load_contracts(self):
"""读取本地配置文件获取期货合约配置"""
f = os.path.abspath(os.path.join(os.path.dirname(__file__), 'spot_contracts.json'))
contracts = load_json(f, auto_save=False)
return contracts
def save_contracts(self):
"""保存合约配置"""
contracts = self.get_contracts()
if len(contracts) > 0:
f = os.path.abspath(os.path.join(os.path.dirname(__file__), 'spot_contracts.json'))
save_json(f, contracts)
self.write_log(f'保存合约配置=>{f}')

View File

@ -398,7 +398,7 @@ class TdxFutureData(object):
add_bar.low_price = float(row['low'])
add_bar.close_price = float(row['close'])
add_bar.volume = float(row['volume'])
add_bar.openInterest = float(row['open_interest'])
add_bar.open_interest = float(row['open_interest'])
except Exception as ex:
self.write_error(
'error when convert bar:{},ex:{},t:{}'.format(row, str(ex), traceback.format_exc()))

View File

@ -453,9 +453,12 @@ class OmsEngine(BaseEngine):
contract_file_name = 'vn_contract.pkb2'
if not os.path.exists(contract_file_name):
return
with bz2.BZ2File(contract_file_name, 'rb') as f:
self.contracts = pickle.load(f)
self.write_log(f'加载缓存合约字典:{contract_file_name}')
try:
with bz2.BZ2File(contract_file_name, 'rb') as f:
self.contracts = pickle.load(f)
self.write_log(f'加载缓存合约字典:{contract_file_name}')
except Exception as ex:
self.write_log(f'加载缓存合约异常:{str(ex)}')
# 更新自定义合约
custom_contracts = self.get_all_custom_contracts()

View File

@ -74,7 +74,7 @@ def get_underlying_symbol(symbol: str):
p = re.compile(r"([A-Z]+)[0-9]+", re.I)
underlying_symbol = p.match(symbol)
if underlying_symbol is None:
if underlying_symbol is None or len(underlying_symbol) == 0:
return symbol
return underlying_symbol.group(1)