[improved] bug fix 可用资金,增加支持港股、美股,币安现货
This commit is contained in:
parent
f218c676b1
commit
cf3ccfe3b8
@ -106,6 +106,7 @@ class BackTestingEngine(object):
|
||||
self.fix_commission = {} # 每手固定手续费
|
||||
self.size = {} # 合约大小,默认为1
|
||||
self.price_tick = {} # 价格最小变动
|
||||
self.volume_tick = {} # 合约委托单最小单位
|
||||
self.margin_rate = {} # 回测合约的保证金比率
|
||||
self.price_dict = {} # 登记vt_symbol对应的最新价
|
||||
self.contract_dict = {} # 登记vt_symbol得对应合约信息
|
||||
@ -161,7 +162,7 @@ class BackTestingEngine(object):
|
||||
self.net_capital = self.init_capital # 实时资金净值(每日根据capital和持仓浮盈计算)
|
||||
self.max_capital = self.init_capital # 资金最高净值
|
||||
self.max_net_capital = self.init_capital
|
||||
self.avaliable = self.init_capital
|
||||
self.available = self.init_capital
|
||||
|
||||
self.max_pnl = 0 # 最高盈利
|
||||
self.min_pnl = 0 # 最大亏损
|
||||
@ -256,7 +257,7 @@ class BackTestingEngine(object):
|
||||
if self.net_capital == 0.0:
|
||||
self.percent = 0.0
|
||||
|
||||
return self.net_capital, self.avaliable, self.percent, self.percent_limit
|
||||
return self.net_capital, self.available, self.percent, self.percent_limit
|
||||
|
||||
def set_test_start_date(self, start_date: str = '20100416', init_days: int = 10):
|
||||
"""设置回测的启动日期"""
|
||||
@ -289,7 +290,7 @@ class BackTestingEngine(object):
|
||||
self.net_capital = capital # 实时资金净值(每日根据capital和持仓浮盈计算)
|
||||
self.max_capital = capital # 资金最高净值
|
||||
self.max_net_capital = capital
|
||||
self.avaliable = capital
|
||||
self.available = capital
|
||||
self.init_capital = capital
|
||||
|
||||
def set_margin_rate(self, vt_symbol: str, margin_rate: float):
|
||||
@ -345,8 +346,15 @@ class BackTestingEngine(object):
|
||||
def get_price_tick(self, vt_symbol: str):
|
||||
return self.price_tick.get(vt_symbol, 1)
|
||||
|
||||
def set_volume_tick(self, vt_symbol: str, volume_tick: float):
|
||||
"""设置委托单最小单位"""
|
||||
self.volume_tick.update({vt_symbol: volume_tick})
|
||||
|
||||
def get_volume_tick(self, vt_symbol: str):
|
||||
return self.volume_tick.get(vt_symbol, 1)
|
||||
|
||||
def set_contract(self, symbol: str, exchange: Exchange, product: Product, name: str, size: int,
|
||||
price_tick: float, margin_rate: float = 0.1):
|
||||
price_tick: float, volume_tick: float = 1, margin_rate: float = 0.1):
|
||||
"""设置合约信息"""
|
||||
vt_symbol = '.'.join([symbol, exchange.value])
|
||||
if vt_symbol not in self.contract_dict:
|
||||
@ -364,6 +372,7 @@ class BackTestingEngine(object):
|
||||
self.set_size(vt_symbol, size)
|
||||
self.set_margin_rate(vt_symbol, margin_rate)
|
||||
self.set_price_tick(vt_symbol, price_tick)
|
||||
self.set_volume_tick(vt_symbol, volume_tick)
|
||||
self.symbol_exchange_dict.update({symbol: exchange})
|
||||
|
||||
@lru_cache()
|
||||
@ -528,7 +537,8 @@ class BackTestingEngine(object):
|
||||
for symbol, symbol_data in data_dict.items():
|
||||
self.write_log(u'配置{}数据:{}'.format(symbol, symbol_data))
|
||||
self.set_price_tick(symbol, symbol_data.get('price_tick', 1))
|
||||
|
||||
volume_tick = symbol_data.get('min_volume', symbol_data.get('volume_tick', 1))
|
||||
self.set_volume_tick(symbol, volume_tick)
|
||||
self.set_slippage(symbol, symbol_data.get('slippage', 0))
|
||||
|
||||
self.set_size(symbol, symbol_data.get('symbol_size', 10))
|
||||
@ -544,6 +554,7 @@ class BackTestingEngine(object):
|
||||
product=Product(symbol_data.get('product', "期货")),
|
||||
size=symbol_data.get('symbol_size', 10),
|
||||
price_tick=symbol_data.get('price_tick', 1),
|
||||
volume_tick=volume_tick,
|
||||
margin_rate=margin_rate
|
||||
)
|
||||
|
||||
@ -1786,7 +1797,7 @@ class BackTestingEngine(object):
|
||||
occupy_short_money_dict.get(underly_symbol, 0))
|
||||
|
||||
# 可用资金 = 当前净值 - 占用保证金
|
||||
self.avaliable = self.net_capital - occupy_money
|
||||
self.available = self.net_capital - occupy_money
|
||||
# 当前保证金占比
|
||||
self.percent = round(float(occupy_money * 100 / self.net_capital), 2)
|
||||
# 更新最大保证金占比
|
||||
@ -1840,7 +1851,7 @@ class BackTestingEngine(object):
|
||||
self.write_log(msg)
|
||||
|
||||
# 重新计算一次avaliable
|
||||
self.avaliable = self.net_capital - occupy_money
|
||||
self.available = self.net_capital - occupy_money
|
||||
self.percent = round(float(occupy_money * 100 / self.net_capital), 2)
|
||||
|
||||
def saving_daily_data(self, d, c, m, commission, benchmark=0):
|
||||
|
@ -28,7 +28,10 @@ from vnpy.trader.object import (
|
||||
SubscribeRequest,
|
||||
LogData,
|
||||
TickData,
|
||||
ContractData
|
||||
ContractData,
|
||||
HistoryRequest,
|
||||
Interval,
|
||||
BarData
|
||||
)
|
||||
from vnpy.trader.event import (
|
||||
EVENT_TIMER,
|
||||
@ -351,6 +354,7 @@ class CtaEngine(BaseEngine):
|
||||
# Update GUI
|
||||
self.put_strategy_event(strategy)
|
||||
|
||||
# 如果配置文件 cta_stock_config.json中,有trade_2_wx的设置项,则发送微信通知
|
||||
if self.engine_config.get('trade_2_wx', False):
|
||||
accountid = self.engine_config.get('accountid', 'XXX')
|
||||
d = {
|
||||
@ -371,7 +375,6 @@ class CtaEngine(BaseEngine):
|
||||
|
||||
self.offset_converter.update_position(position)
|
||||
|
||||
|
||||
def check_unsubscribed_symbols(self):
|
||||
"""检查未订阅合约"""
|
||||
|
||||
@ -809,11 +812,22 @@ class CtaEngine(BaseEngine):
|
||||
"""查询价格最小跳动"""
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
if contract is None:
|
||||
self.write_error(f'查询不到{vt_symbol}合约信息')
|
||||
self.write_error(f'查询不到{vt_symbol}合约信息,缺省使用0.1作为价格跳动')
|
||||
return 0.1
|
||||
|
||||
return contract.pricetick
|
||||
|
||||
@lru_cache()
|
||||
def get_volume_tick(self, vt_symbol: str):
|
||||
"""查询合约的最小成交数量"""
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
if contract is None:
|
||||
self.write_error(f'查询不到{vt_symbol}合约信息,缺省使用1作为最小成交数量')
|
||||
return 1
|
||||
|
||||
return contract.min_volume
|
||||
|
||||
|
||||
def get_tick(self, vt_symbol: str):
|
||||
"""获取合约得最新tick"""
|
||||
return self.main_engine.get_tick(vt_symbol)
|
||||
@ -878,6 +892,40 @@ class CtaEngine(BaseEngine):
|
||||
log_path = os.path.abspath(os.path.join(TRADER_DIR, 'log'))
|
||||
return log_path
|
||||
|
||||
def load_bar(
|
||||
self,
|
||||
vt_symbol: str,
|
||||
days: int,
|
||||
interval: Interval,
|
||||
callback: Callable[[BarData], None],
|
||||
interval_num: int = 1
|
||||
):
|
||||
"""获取历史记录"""
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
end = datetime.now()
|
||||
start = end - timedelta(days)
|
||||
bars = []
|
||||
|
||||
# Query bars from gateway if available
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
|
||||
if contract and contract.history_data:
|
||||
req = HistoryRequest(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
interval=interval,
|
||||
interval_num=interval_num,
|
||||
start=start,
|
||||
end=end
|
||||
)
|
||||
bars = self.main_engine.query_history(req, contract.gateway_name)
|
||||
|
||||
for bar in bars:
|
||||
if bar.trading_day:
|
||||
bar.trading_day = bar.datetime.strftime('%Y-%m-%d')
|
||||
|
||||
callback(bar)
|
||||
|
||||
def call_strategy_func(
|
||||
self, strategy: CtaTemplate, func: Callable, params: Any = None
|
||||
):
|
||||
|
@ -131,6 +131,9 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
:return:
|
||||
"""
|
||||
self.output('comine_df')
|
||||
if len(self.bar_df_dict) == 0:
|
||||
self.output(f'无加载任何数据,请检查bar文件路径配置')
|
||||
|
||||
self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index()
|
||||
self.bar_df_dict.clear()
|
||||
|
||||
|
229
vnpy/data/binance/binance_spot_data.py
Normal file
229
vnpy/data/binance/binance_spot_data.py
Normal file
@ -0,0 +1,229 @@
|
||||
# 币安现货数据
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Any
|
||||
from datetime import datetime, timedelta
|
||||
from vnpy.api.rest.rest_client import RestClient
|
||||
from vnpy.trader.object import (
|
||||
Interval,
|
||||
Exchange,
|
||||
Product,
|
||||
BarData,
|
||||
HistoryRequest
|
||||
)
|
||||
from vnpy.trader.utility import save_json, load_json
|
||||
|
||||
BINANCE_INTERVALS = ["1m", "3m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "8h", "12h", "1d", "3d", "1w", "1M"]
|
||||
|
||||
INTERVAL_VT2BINANCEF: Dict[Interval, str] = {
|
||||
Interval.MINUTE: "1m",
|
||||
Interval.HOUR: "1h",
|
||||
Interval.DAILY: "1d",
|
||||
}
|
||||
|
||||
TIMEDELTA_MAP: Dict[Interval, timedelta] = {
|
||||
Interval.MINUTE: timedelta(minutes=1),
|
||||
Interval.HOUR: timedelta(hours=1),
|
||||
Interval.DAILY: timedelta(days=1),
|
||||
}
|
||||
|
||||
REST_HOST: str = "https://api.binance.com"
|
||||
|
||||
|
||||
class BinanceSpotData(RestClient):
|
||||
"""现货数据接口"""
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__()
|
||||
self.parent = parent
|
||||
self.init(url_base=REST_HOST)
|
||||
|
||||
def write_log(self, msg):
|
||||
"""日志"""
|
||||
if self.parent and hasattr(self.parent, 'write_log'):
|
||||
func = getattr(self.parent, 'write_log')
|
||||
func(msg)
|
||||
else:
|
||||
print(msg)
|
||||
|
||||
def get_interval(self, interval, interval_num):
|
||||
""" =》K线间隔"""
|
||||
t = interval[-1]
|
||||
b_interval = f'{interval_num}{t}'
|
||||
if b_interval not in BINANCE_INTERVALS:
|
||||
return interval
|
||||
else:
|
||||
return b_interval
|
||||
|
||||
def get_bars(self,
|
||||
req: HistoryRequest,
|
||||
return_dict=True,
|
||||
) -> List[Any]:
|
||||
"""获取历史kline"""
|
||||
bars = []
|
||||
limit = 1000
|
||||
start_time = int(datetime.timestamp(req.start))
|
||||
b_interval = self.get_interval(INTERVAL_VT2BINANCEF[req.interval], req.interval_num)
|
||||
while True:
|
||||
# Create query params
|
||||
params = {
|
||||
"symbol": req.symbol,
|
||||
"interval": b_interval,
|
||||
"limit": limit,
|
||||
"startTime": start_time * 1000, # convert to millisecond
|
||||
}
|
||||
|
||||
# Add end time if specified
|
||||
if req.end:
|
||||
end_time = int(datetime.timestamp(req.end))
|
||||
params["endTime"] = end_time * 1000 # convert to millisecond
|
||||
|
||||
# Get response from server
|
||||
resp = self.request(
|
||||
"GET",
|
||||
"/api/v3/klines",
|
||||
data={},
|
||||
params=params
|
||||
)
|
||||
|
||||
# Break if request failed with other status code
|
||||
if resp.status_code // 100 != 2:
|
||||
msg = f"获取历史数据失败,状态码:{resp.status_code},信息:{resp.text}"
|
||||
self.write_log(msg)
|
||||
break
|
||||
else:
|
||||
datas = resp.json()
|
||||
if not datas:
|
||||
msg = f"获取历史数据为空,开始时间:{start_time}"
|
||||
self.write_log(msg)
|
||||
break
|
||||
|
||||
buf = []
|
||||
begin, end = None, None
|
||||
for data in datas:
|
||||
dt = datetime.fromtimestamp(data[0] / 1000) # convert to second
|
||||
if not begin:
|
||||
begin = dt
|
||||
end = dt
|
||||
if return_dict:
|
||||
bar = {
|
||||
"datetime": dt,
|
||||
"symbol": req.symbol,
|
||||
"exchange": req.exchange.value,
|
||||
"vt_symbol": f'{req.symbol}.{req.exchange.value}',
|
||||
"interval": req.interval.value,
|
||||
"volume": float(data[5]),
|
||||
"open": float(data[1]),
|
||||
"high": float(data[2]),
|
||||
"low": float(data[3]),
|
||||
"close": float(data[4]),
|
||||
"gateway_name": "",
|
||||
"open_interest": 0,
|
||||
"trading_day": dt.strftime('%Y-%m-%d')
|
||||
}
|
||||
else:
|
||||
bar = BarData(
|
||||
symbol=req.symbol,
|
||||
exchange=req.exchange,
|
||||
datetime=dt,
|
||||
trading_day=dt.strftime('%Y-%m-%d'),
|
||||
interval=req.interval,
|
||||
volume=float(data[5]),
|
||||
open_price=float(data[1]),
|
||||
high_price=float(data[2]),
|
||||
low_price=float(data[3]),
|
||||
close_price=float(data[4]),
|
||||
gateway_name=self.gateway_name
|
||||
)
|
||||
buf.append(bar)
|
||||
|
||||
bars.extend(buf)
|
||||
|
||||
msg = f"获取历史数据成功,{req.symbol} - {b_interval},{begin} - {end}"
|
||||
self.write_log(msg)
|
||||
|
||||
# Break if total data count less than limit (latest date collected)
|
||||
if len(datas) < limit:
|
||||
break
|
||||
|
||||
# Update start time
|
||||
start_dt = end + TIMEDELTA_MAP[req.interval] * req.interval_num
|
||||
start_time = int(datetime.timestamp(start_dt))
|
||||
|
||||
return bars
|
||||
|
||||
def export_to(self, bars, file_name):
|
||||
"""导出bar到文件"""
|
||||
if len(bars) == 0:
|
||||
self.write_log('not data in bars')
|
||||
return
|
||||
|
||||
import pandas as pd
|
||||
df = pd.DataFrame(bars)
|
||||
df = df.set_index('datetime')
|
||||
df.index.name = 'datetime'
|
||||
df.to_csv(file_name, index=True)
|
||||
self.write_log('保存成功')
|
||||
|
||||
def get_contracts(self):
|
||||
|
||||
contracts = {}
|
||||
# Get response from server
|
||||
resp = self.request(
|
||||
"GET",
|
||||
"/api/v3/exchangeInfo",
|
||||
data={}
|
||||
)
|
||||
if resp.status_code // 100 != 2:
|
||||
msg = f"获取交易所失败,状态码:{resp.status_code},信息:{resp.text}"
|
||||
self.write_log(msg)
|
||||
else:
|
||||
data = resp.json()
|
||||
for d in data["symbols"]:
|
||||
self.write_log(json.dumps(d, indent=2))
|
||||
base_currency = d["baseAsset"]
|
||||
quote_currency = d["quoteAsset"]
|
||||
name = f"{base_currency.upper()}/{quote_currency.upper()}"
|
||||
|
||||
pricetick = 1
|
||||
min_volume = 1
|
||||
|
||||
for f in d["filters"]:
|
||||
if f["filterType"] == "PRICE_FILTER":
|
||||
pricetick = float(f["tickSize"])
|
||||
elif f["filterType"] == "LOT_SIZE":
|
||||
min_volume = float(f["stepSize"])
|
||||
|
||||
contract = {
|
||||
"symbol": d["symbol"],
|
||||
"exchange": Exchange.BINANCE.value,
|
||||
"vt_symbol": d["symbol"] + '.' + Exchange.BINANCE.value,
|
||||
"name": name,
|
||||
"price_tick": pricetick,
|
||||
"symbol_size": 20,
|
||||
"margin_rate": 1,
|
||||
"min_volume": min_volume,
|
||||
"product": Product.SPOT.value,
|
||||
"commission_rate": 0.005
|
||||
}
|
||||
|
||||
contracts.update({contract.get('vt_symbol'): contract})
|
||||
|
||||
return contracts
|
||||
|
||||
@classmethod
|
||||
def load_contracts(self):
|
||||
"""读取本地配置文件获取期货合约配置"""
|
||||
f = os.path.abspath(os.path.join(os.path.dirname(__file__), 'spot_contracts.json'))
|
||||
contracts = load_json(f, auto_save=False)
|
||||
return contracts
|
||||
|
||||
def save_contracts(self):
|
||||
"""保存合约配置"""
|
||||
contracts = self.get_contracts()
|
||||
|
||||
if len(contracts) > 0:
|
||||
f = os.path.abspath(os.path.join(os.path.dirname(__file__), 'spot_contracts.json'))
|
||||
save_json(f, contracts)
|
||||
self.write_log(f'保存合约配置=>{f}')
|
@ -398,7 +398,7 @@ class TdxFutureData(object):
|
||||
add_bar.low_price = float(row['low'])
|
||||
add_bar.close_price = float(row['close'])
|
||||
add_bar.volume = float(row['volume'])
|
||||
add_bar.openInterest = float(row['open_interest'])
|
||||
add_bar.open_interest = float(row['open_interest'])
|
||||
except Exception as ex:
|
||||
self.write_error(
|
||||
'error when convert bar:{},ex:{},t:{}'.format(row, str(ex), traceback.format_exc()))
|
||||
|
@ -453,9 +453,12 @@ class OmsEngine(BaseEngine):
|
||||
contract_file_name = 'vn_contract.pkb2'
|
||||
if not os.path.exists(contract_file_name):
|
||||
return
|
||||
try:
|
||||
with bz2.BZ2File(contract_file_name, 'rb') as f:
|
||||
self.contracts = pickle.load(f)
|
||||
self.write_log(f'加载缓存合约字典:{contract_file_name}')
|
||||
except Exception as ex:
|
||||
self.write_log(f'加载缓存合约异常:{str(ex)}')
|
||||
|
||||
# 更新自定义合约
|
||||
custom_contracts = self.get_all_custom_contracts()
|
||||
|
@ -74,7 +74,7 @@ def get_underlying_symbol(symbol: str):
|
||||
p = re.compile(r"([A-Z]+)[0-9]+", re.I)
|
||||
underlying_symbol = p.match(symbol)
|
||||
|
||||
if underlying_symbol is None:
|
||||
if underlying_symbol is None or len(underlying_symbol) == 0:
|
||||
return symbol
|
||||
|
||||
return underlying_symbol.group(1)
|
||||
|
Loading…
Reference in New Issue
Block a user