Merge pull request #2025 from nanoric/remove_trailing_white_space_and_spaces_in_blank_lines

[Del] remove trailing white space & spaces in blank lines
This commit is contained in:
vn.py 2019-08-15 16:42:26 +08:00 committed by GitHub
commit 4450774bda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 3117 additions and 3119 deletions

View File

@ -3,5 +3,3 @@ exclude = venv,build,__pycache__,__init__.py,ib,talib,uic
ignore = ignore =
E501 line too long, fixed by black E501 line too long, fixed by black
W503 line break before binary operator W503 line break before binary operator
W293 blank line contains whitespace
W291 trailing whitespace

View File

@ -7,7 +7,7 @@ from vnpy.rpc import RpcClient
class TestClient(RpcClient): class TestClient(RpcClient):
""" """
Test RpcClient Test RpcClient
""" """
def __init__(self): def __init__(self):

File diff suppressed because it is too large Load Diff

View File

@ -147,7 +147,7 @@ class RestClient(object):
""" """
Add a new request. Add a new request.
:param method: GET, POST, PUT, DELETE, QUERY :param method: GET, POST, PUT, DELETE, QUERY
:param path: :param path:
:param callback: callback function if 2xx status, type: (dict, Request) :param callback: callback function if 2xx status, type: (dict, Request)
:param params: dict for query string :param params: dict for query string
:param data: Http body. If it is a dict, it will be converted to form-data. Otherwise, it will be converted to bytes. :param data: Http body. If it is a dict, it will be converted to form-data. Otherwise, it will be converted to bytes.
@ -296,7 +296,7 @@ class RestClient(object):
""" """
Add a new request. Add a new request.
:param method: GET, POST, PUT, DELETE, QUERY :param method: GET, POST, PUT, DELETE, QUERY
:param path: :param path:
:param params: dict for query string :param params: dict for query string
:param data: dict for body :param data: dict for body
:param headers: dict for headers :param headers: dict for headers

View File

@ -14,7 +14,7 @@ class DmaAlgo(AlgoTemplate):
"vt_symbol": "", "vt_symbol": "",
"direction": [Direction.LONG.value, Direction.SHORT.value], "direction": [Direction.LONG.value, Direction.SHORT.value],
"order_type": [ "order_type": [
OrderType.MARKET.value, OrderType.MARKET.value,
OrderType.LIMIT.value, OrderType.LIMIT.value,
OrderType.STOP.value, OrderType.STOP.value,
OrderType.FAK.value, OrderType.FAK.value,
@ -74,7 +74,7 @@ class DmaAlgo(AlgoTemplate):
self.order_type, self.order_type,
self.offset self.offset
) )
else: else:
self.vt_orderid = self.sell( self.vt_orderid = self.sell(
self.vt_symbol, self.vt_symbol,
@ -96,4 +96,4 @@ class DmaAlgo(AlgoTemplate):
def on_trade(self, trade: TradeData): def on_trade(self, trade: TradeData):
"""""" """"""
pass pass

View File

@ -1013,13 +1013,13 @@ class CandleChartDialog(QtWidgets.QDialog):
def update_trades(self, trades: list): def update_trades(self, trades: list):
"""""" """"""
trade_data = [] trade_data = []
for trade in trades: for trade in trades:
ix = self.dt_ix_map[trade.datetime] ix = self.dt_ix_map[trade.datetime]
scatter = { scatter = {
"pos": (ix, trade.price), "pos": (ix, trade.price),
"data": 1, "data": 1,
"size": 14, "size": 14,
"pen": pg.mkPen((255, 255, 255)) "pen": pg.mkPen((255, 255, 255))
} }
@ -1030,11 +1030,11 @@ class CandleChartDialog(QtWidgets.QDialog):
else: else:
scatter["symbol"] = "t" scatter["symbol"] = "t"
scatter["brush"] = pg.mkBrush((0, 0, 255)) scatter["brush"] = pg.mkBrush((0, 0, 255))
trade_data.append(scatter) trade_data.append(scatter)
self.trade_scatter.setData(trade_data) self.trade_scatter.setData(trade_data)
def clear_data(self): def clear_data(self):
"""""" """"""
self.updated = False self.updated = False
@ -1042,7 +1042,7 @@ class CandleChartDialog(QtWidgets.QDialog):
self.dt_ix_map.clear() self.dt_ix_map.clear()
self.trade_scatter.clear() self.trade_scatter.clear()
def is_updated(self): def is_updated(self):
"""""" """"""
return self.updated return self.updated

View File

@ -13,7 +13,7 @@ import seaborn as sns
from pandas import DataFrame from pandas import DataFrame
from deap import creator, base, tools, algorithms from deap import creator, base, tools, algorithms
from vnpy.trader.constant import (Direction, Offset, Exchange, from vnpy.trader.constant import (Direction, Offset, Exchange,
Interval, Status) Interval, Status)
from vnpy.trader.database import database_manager from vnpy.trader.database import database_manager
from vnpy.trader.object import OrderData, TradeData, BarData, TickData from vnpy.trader.object import OrderData, TradeData, BarData, TickData
@ -84,12 +84,12 @@ class OptimizationSetting:
settings.append(setting) settings.append(setting)
return settings return settings
def generate_setting_ga(self): def generate_setting_ga(self):
"""""" """"""
settings_ga = [] settings_ga = []
settings = self.generate_setting() settings = self.generate_setting()
for d in settings: for d in settings:
param = [tuple(i) for i in d.items()] param = [tuple(i) for i in d.items()]
settings_ga.append(param) settings_ga.append(param)
return settings_ga return settings_ga
@ -215,8 +215,8 @@ class BacktestingEngine:
self.end = datetime.now() self.end = datetime.now()
if self.start >= self.end: if self.start >= self.end:
self.output("起始日期必须小于结束日期") self.output("起始日期必须小于结束日期")
return return
self.history_data.clear() # Clear previously loaded history data self.history_data.clear() # Clear previously loaded history data
@ -230,7 +230,7 @@ class BacktestingEngine:
while start < self.end: while start < self.end:
end = min(end, self.end) # Make sure end time stays within set range end = min(end, self.end) # Make sure end time stays within set range
if self.mode == BacktestingMode.BAR: if self.mode == BacktestingMode.BAR:
data = load_bar_data( data = load_bar_data(
self.symbol, self.symbol,
@ -248,15 +248,15 @@ class BacktestingEngine:
) )
self.history_data.extend(data) self.history_data.extend(data)
progress += progress_delta / total_delta progress += progress_delta / total_delta
progress = min(progress, 1) progress = min(progress, 1)
progress_bar = "#" * int(progress * 10) progress_bar = "#" * int(progress * 10)
self.output(f"加载进度:{progress_bar} [{progress:.0%}]") self.output(f"加载进度:{progress_bar} [{progress:.0%}]")
start = end start = end
end += progress_delta end += progress_delta
self.output(f"历史数据加载完成,数据量:{len(self.history_data)}") self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
def run_backtesting(self): def run_backtesting(self):
@ -271,7 +271,7 @@ class BacktestingEngine:
# Use the first [days] of history data for initializing strategy # Use the first [days] of history data for initializing strategy
day_count = 0 day_count = 0
ix = 0 ix = 0
for ix, data in enumerate(self.history_data): for ix, data in enumerate(self.history_data):
if self.datetime and data.datetime.day != self.datetime.day: if self.datetime and data.datetime.day != self.datetime.day:
day_count += 1 day_count += 1
@ -339,8 +339,8 @@ class BacktestingEngine:
# Check DataFrame input exterior # Check DataFrame input exterior
if df is None: if df is None:
df = self.daily_df df = self.daily_df
# Check for init DataFrame # Check for init DataFrame
if df is None: if df is None:
# Set all statistics to 0 if no trade. # Set all statistics to 0 if no trade.
start_date = "" start_date = ""
@ -484,11 +484,11 @@ class BacktestingEngine:
def show_chart(self, df: DataFrame = None): def show_chart(self, df: DataFrame = None):
"""""" """"""
# Check DataFrame input exterior # Check DataFrame input exterior
if df is None: if df is None:
df = self.daily_df df = self.daily_df
# Check for init DataFrame # Check for init DataFrame
if df is None: if df is None:
return return
@ -580,7 +580,7 @@ class BacktestingEngine:
def generate_parameter(): def generate_parameter():
"""""" """"""
return random.choice(settings) return random.choice(settings)
def mutate_individual(individual, indpb): def mutate_individual(individual, indpb):
"""""" """"""
size = len(individual) size = len(individual)
@ -620,24 +620,24 @@ class BacktestingEngine:
ga_mode = self.mode ga_mode = self.mode
# Set up genetic algorithem # Set up genetic algorithem
toolbox = base.Toolbox() toolbox = base.Toolbox()
toolbox.register("individual", tools.initIterate, creator.Individual, generate_parameter) toolbox.register("individual", tools.initIterate, creator.Individual, generate_parameter)
toolbox.register("population", tools.initRepeat, list, toolbox.individual) toolbox.register("population", tools.initRepeat, list, toolbox.individual)
toolbox.register("mate", tools.cxTwoPoint) toolbox.register("mate", tools.cxTwoPoint)
toolbox.register("mutate", mutate_individual, indpb=1) toolbox.register("mutate", mutate_individual, indpb=1)
toolbox.register("evaluate", ga_optimize) toolbox.register("evaluate", ga_optimize)
toolbox.register("select", tools.selNSGA2) toolbox.register("select", tools.selNSGA2)
total_size = len(settings) total_size = len(settings)
pop_size = population_size # number of individuals in each generation pop_size = population_size # number of individuals in each generation
lambda_ = pop_size # number of children to produce at each generation lambda_ = pop_size # number of children to produce at each generation
mu = int(pop_size * 0.8) # number of individuals to select for the next generation mu = int(pop_size * 0.8) # number of individuals to select for the next generation
cxpb = 0.95 # probability that an offspring is produced by crossover cxpb = 0.95 # probability that an offspring is produced by crossover
mutpb = 1 - cxpb # probability that an offspring is produced by mutation mutpb = 1 - cxpb # probability that an offspring is produced by mutation
ngen = ngen_size # number of generation ngen = ngen_size # number of generation
pop = toolbox.population(pop_size) pop = toolbox.population(pop_size)
hof = tools.ParetoFront() # end result of pareto front hof = tools.ParetoFront() # end result of pareto front
stats = tools.Statistics(lambda ind: ind.fitness.values) stats = tools.Statistics(lambda ind: ind.fitness.values)
@ -662,22 +662,22 @@ class BacktestingEngine:
start = time() start = time()
algorithms.eaMuPlusLambda( algorithms.eaMuPlusLambda(
pop, pop,
toolbox, toolbox,
mu, mu,
lambda_, lambda_,
cxpb, cxpb,
mutpb, mutpb,
ngen, ngen,
stats, stats,
halloffame=hof halloffame=hof
) )
end = time() end = time()
cost = int((end - start)) cost = int((end - start))
self.output(f"遗传算法优化完成,耗时{cost}") self.output(f"遗传算法优化完成,耗时{cost}")
# Return result list # Return result list
results = [] results = []
@ -685,7 +685,7 @@ class BacktestingEngine:
setting = dict(parameter_values) setting = dict(parameter_values)
target_value = ga_optimize(parameter_values)[0] target_value = ga_optimize(parameter_values)[0]
results.append((setting, target_value, {})) results.append((setting, target_value, {}))
return results return results
def update_daily_close(self, price: float): def update_daily_close(self, price: float):
@ -743,14 +743,14 @@ class BacktestingEngine:
# Check whether limit orders can be filled. # Check whether limit orders can be filled.
long_cross = ( long_cross = (
order.direction == Direction.LONG order.direction == Direction.LONG
and order.price >= long_cross_price and order.price >= long_cross_price
and long_cross_price > 0 and long_cross_price > 0
) )
short_cross = ( short_cross = (
order.direction == Direction.SHORT order.direction == Direction.SHORT
and order.price <= short_cross_price and order.price <= short_cross_price
and short_cross_price > 0 and short_cross_price > 0
) )
@ -811,12 +811,12 @@ class BacktestingEngine:
for stop_order in list(self.active_stop_orders.values()): for stop_order in list(self.active_stop_orders.values()):
# Check whether stop order can be triggered. # Check whether stop order can be triggered.
long_cross = ( long_cross = (
stop_order.direction == Direction.LONG stop_order.direction == Direction.LONG
and stop_order.price <= long_cross_price and stop_order.price <= long_cross_price
) )
short_cross = ( short_cross = (
stop_order.direction == Direction.SHORT stop_order.direction == Direction.SHORT
and stop_order.price >= short_cross_price and stop_order.price >= short_cross_price
) )
@ -911,10 +911,10 @@ class BacktestingEngine:
return [vt_orderid] return [vt_orderid]
def send_stop_order( def send_stop_order(
self, self,
direction: Direction, direction: Direction,
offset: Offset, offset: Offset,
price: float, price: float,
volume: float volume: float
): ):
"""""" """"""
@ -936,15 +936,15 @@ class BacktestingEngine:
return stop_order.stop_orderid return stop_order.stop_orderid
def send_limit_order( def send_limit_order(
self, self,
direction: Direction, direction: Direction,
offset: Offset, offset: Offset,
price: float, price: float,
volume: float volume: float
): ):
"""""" """"""
self.limit_order_count += 1 self.limit_order_count += 1
order = OrderData( order = OrderData(
symbol=self.symbol, symbol=self.symbol,
exchange=self.exchange, exchange=self.exchange,
@ -1008,13 +1008,13 @@ class BacktestingEngine:
""" """
msg = f"{self.datetime}\t{msg}" msg = f"{self.datetime}\t{msg}"
self.logs.append(msg) self.logs.append(msg)
def send_email(self, msg: str, strategy: CtaTemplate = None): def send_email(self, msg: str, strategy: CtaTemplate = None):
""" """
Send email to default receiver. Send email to default receiver.
""" """
pass pass
def sync_strategy_data(self, strategy: CtaTemplate): def sync_strategy_data(self, strategy: CtaTemplate):
""" """
Sync strategy data into json file. Sync strategy data into json file.
@ -1145,7 +1145,7 @@ def optimize(
Function for running in multiprocessing.pool Function for running in multiprocessing.pool
""" """
engine = BacktestingEngine() engine = BacktestingEngine()
engine.set_parameters( engine.set_parameters(
vt_symbol=vt_symbol, vt_symbol=vt_symbol,
interval=interval, interval=interval,

View File

@ -22,17 +22,17 @@ from vnpy.trader.object import (
ContractData ContractData
) )
from vnpy.trader.event import ( from vnpy.trader.event import (
EVENT_TICK, EVENT_TICK,
EVENT_ORDER, EVENT_ORDER,
EVENT_TRADE, EVENT_TRADE,
EVENT_POSITION EVENT_POSITION
) )
from vnpy.trader.constant import ( from vnpy.trader.constant import (
Direction, Direction,
OrderType, OrderType,
Interval, Interval,
Exchange, Exchange,
Offset, Offset,
Status Status
) )
from vnpy.trader.utility import load_json, save_json, extract_vt_symbol, round_to from vnpy.trader.utility import load_json, save_json, extract_vt_symbol, round_to
@ -162,7 +162,7 @@ class CtaEngine(BaseEngine):
def process_order_event(self, event: Event): def process_order_event(self, event: Event):
"""""" """"""
order = event.data order = event.data
self.offset_converter.update_order(order) self.offset_converter.update_order(order)
strategy = self.orderid_strategy_map.get(order.vt_orderid, None) strategy = self.orderid_strategy_map.get(order.vt_orderid, None)
@ -187,7 +187,7 @@ class CtaEngine(BaseEngine):
status=STOP_STATUS_MAP[order.status], status=STOP_STATUS_MAP[order.status],
vt_orderids=[order.vt_orderid], vt_orderids=[order.vt_orderid],
) )
self.call_strategy_func(strategy, strategy.on_stop_order, so) self.call_strategy_func(strategy, strategy.on_stop_order, so)
# Call strategy on_order function # Call strategy on_order function
self.call_strategy_func(strategy, strategy.on_order, order) self.call_strategy_func(strategy, strategy.on_order, order)
@ -256,15 +256,15 @@ class CtaEngine(BaseEngine):
price = tick.limit_down price = tick.limit_down
else: else:
price = tick.bid_price_5 price = tick.bid_price_5
contract = self.main_engine.get_contract(stop_order.vt_symbol) contract = self.main_engine.get_contract(stop_order.vt_symbol)
vt_orderids = self.send_limit_order( vt_orderids = self.send_limit_order(
strategy, strategy,
contract, contract,
stop_order.direction, stop_order.direction,
stop_order.offset, stop_order.offset,
price, price,
stop_order.volume, stop_order.volume,
stop_order.lock stop_order.lock
) )
@ -329,13 +329,13 @@ class CtaEngine(BaseEngine):
vt_orderids.append(vt_orderid) vt_orderids.append(vt_orderid)
self.offset_converter.update_order_request(req, vt_orderid) self.offset_converter.update_order_request(req, vt_orderid)
# Save relationship between orderid and strategy. # Save relationship between orderid and strategy.
self.orderid_strategy_map[vt_orderid] = strategy self.orderid_strategy_map[vt_orderid] = strategy
self.strategy_orderid_map[strategy.strategy_name].add(vt_orderid) self.strategy_orderid_map[strategy.strategy_name].add(vt_orderid)
return vt_orderids return vt_orderids
def send_limit_order( def send_limit_order(
self, self,
strategy: CtaTemplate, strategy: CtaTemplate,
@ -359,7 +359,7 @@ class CtaEngine(BaseEngine):
OrderType.LIMIT, OrderType.LIMIT,
lock lock
) )
def send_server_stop_order( def send_server_stop_order(
self, self,
strategy: CtaTemplate, strategy: CtaTemplate,
@ -372,8 +372,8 @@ class CtaEngine(BaseEngine):
): ):
""" """
Send a stop order to server. Send a stop order to server.
Should only be used if stop order supported Should only be used if stop order supported
on the trading server. on the trading server.
""" """
return self.send_server_order( return self.send_server_order(
@ -473,11 +473,11 @@ class CtaEngine(BaseEngine):
if not contract: if not contract:
self.write_log(f"委托失败,找不到合约:{strategy.vt_symbol}", strategy) self.write_log(f"委托失败,找不到合约:{strategy.vt_symbol}", strategy)
return "" return ""
# Round order price and volume to nearest incremental value # Round order price and volume to nearest incremental value
price = round_to(price, contract.pricetick) price = round_to(price, contract.pricetick)
volume = round_to(volume, contract.min_volume) volume = round_to(volume, contract.min_volume)
if stop: if stop:
if contract.stop_supported: if contract.stop_supported:
return self.send_server_stop_order(strategy, contract, direction, offset, price, volume, lock) return self.send_server_stop_order(strategy, contract, direction, offset, price, volume, lock)
@ -510,9 +510,9 @@ class CtaEngine(BaseEngine):
return self.engine_type return self.engine_type
def load_bar( def load_bar(
self, self,
vt_symbol: str, vt_symbol: str,
days: int, days: int,
interval: Interval, interval: Interval,
callback: Callable[[BarData], None] callback: Callable[[BarData], None]
): ):
@ -536,7 +536,7 @@ class CtaEngine(BaseEngine):
callback(bar) callback(bar)
def load_tick( def load_tick(
self, self,
vt_symbol: str, vt_symbol: str,
days: int, days: int,
callback: Callable[[TickData], None] callback: Callable[[TickData], None]
@ -836,9 +836,9 @@ class CtaEngine(BaseEngine):
for strategy_name, strategy_config in self.strategy_setting.items(): for strategy_name, strategy_config in self.strategy_setting.items():
self.add_strategy( self.add_strategy(
strategy_config["class_name"], strategy_config["class_name"],
strategy_name, strategy_name,
strategy_config["vt_symbol"], strategy_config["vt_symbol"],
strategy_config["setting"] strategy_config["setting"]
) )

View File

@ -33,7 +33,7 @@ class CtaTemplate(ABC):
self.trading = False self.trading = False
self.pos = 0 self.pos = 0
# Copy a new variables list here to avoid duplicate insert when multiple # Copy a new variables list here to avoid duplicate insert when multiple
# strategy instances are created with the same strategy class. # strategy instances are created with the same strategy class.
self.variables = copy(self.variables) self.variables = copy(self.variables)
self.variables.insert(0, "inited") self.variables.insert(0, "inited")

View File

@ -13,9 +13,9 @@ EVENT_TIMER = "eTimer"
class Event: class Event:
""" """
Event object consists of a type string which is used Event object consists of a type string which is used
by event engine for distributing event, and a data by event engine for distributing event, and a data
object which contains the real data. object which contains the real data.
""" """
def __init__(self, type: str, data: Any = None): def __init__(self, type: str, data: Any = None):
@ -30,7 +30,7 @@ HandlerType = Callable[[Event], None]
class EventEngine: class EventEngine:
""" """
Event engine distributes event object based on its type Event engine distributes event object based on its type
to those handlers registered. to those handlers registered.
It also generates timer event by every interval seconds, It also generates timer event by every interval seconds,
@ -64,7 +64,7 @@ class EventEngine:
def _process(self, event: Event): def _process(self, event: Event):
""" """
First ditribute event to those handlers registered listening First ditribute event to those handlers registered listening
to this type. to this type.
Then distrubute event to those general handlers which listens Then distrubute event to those general handlers which listens
to all types. to all types.
@ -108,7 +108,7 @@ class EventEngine:
def register(self, type: str, handler: HandlerType): def register(self, type: str, handler: HandlerType):
""" """
Register a new handler function for a specific event type. Every Register a new handler function for a specific event type. Every
function can only be registered once for each event type. function can only be registered once for each event type.
""" """
handler_list = self._handlers[type] handler_list = self._handlers[type]
@ -129,7 +129,7 @@ class EventEngine:
def register_general(self, handler: HandlerType): def register_general(self, handler: HandlerType):
""" """
Register a new handler function for all event types. Every Register a new handler function for all event types. Every
function can only be registered once for each event type. function can only be registered once for each event type.
""" """
if handler not in self._general_handlers: if handler not in self._general_handlers:

View File

@ -543,7 +543,7 @@ class BinanceRestApi(RestClient):
"limit": limit, "limit": limit,
"startTime": start_time * 1000, # convert to millisecond "startTime": start_time * 1000, # convert to millisecond
} }
# Add end time if specified # Add end time if specified
if req.end: if req.end:
end_time = int(datetime.timestamp(req.end)) end_time = int(datetime.timestamp(req.end))
@ -570,7 +570,7 @@ class BinanceRestApi(RestClient):
break break
buf = [] buf = []
for l in data: for l in data:
dt = datetime.fromtimestamp(l[0] / 1000) # convert to second dt = datetime.fromtimestamp(l[0] / 1000) # convert to second
@ -641,7 +641,7 @@ class BinanceTradeWebsocketApi(WebsocketClient):
frozen=float(d["l"]), frozen=float(d["l"]),
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
if account.balance: if account.balance:
self.gateway.on_account(account) self.gateway.on_account(account)

View File

@ -13,7 +13,7 @@ from vnpy.api.ctp import (
THOST_FTDC_OST_PartTradedQueueing, THOST_FTDC_OST_PartTradedQueueing,
THOST_FTDC_OST_AllTraded, THOST_FTDC_OST_AllTraded,
THOST_FTDC_OST_Canceled, THOST_FTDC_OST_Canceled,
THOST_FTDC_D_Buy, THOST_FTDC_D_Buy,
THOST_FTDC_D_Sell, THOST_FTDC_D_Sell,
THOST_FTDC_PD_Long, THOST_FTDC_PD_Long,
THOST_FTDC_PD_Short, THOST_FTDC_PD_Short,
@ -73,7 +73,7 @@ STATUS_CTP2VT = {
} }
DIRECTION_VT2CTP = { DIRECTION_VT2CTP = {
Direction.LONG: THOST_FTDC_D_Buy, Direction.LONG: THOST_FTDC_D_Buy,
Direction.SHORT: THOST_FTDC_D_Sell Direction.SHORT: THOST_FTDC_D_Sell
} }
DIRECTION_CTP2VT = {v: k for k, v in DIRECTION_VT2CTP.items()} DIRECTION_CTP2VT = {v: k for k, v in DIRECTION_VT2CTP.items()}
@ -81,13 +81,13 @@ DIRECTION_CTP2VT[THOST_FTDC_PD_Long] = Direction.LONG
DIRECTION_CTP2VT[THOST_FTDC_PD_Short] = Direction.SHORT DIRECTION_CTP2VT[THOST_FTDC_PD_Short] = Direction.SHORT
ORDERTYPE_VT2CTP = { ORDERTYPE_VT2CTP = {
OrderType.LIMIT: THOST_FTDC_OPT_LimitPrice, OrderType.LIMIT: THOST_FTDC_OPT_LimitPrice,
OrderType.MARKET: THOST_FTDC_OPT_AnyPrice OrderType.MARKET: THOST_FTDC_OPT_AnyPrice
} }
ORDERTYPE_CTP2VT = {v: k for k, v in ORDERTYPE_VT2CTP.items()} ORDERTYPE_CTP2VT = {v: k for k, v in ORDERTYPE_VT2CTP.items()}
OFFSET_VT2CTP = { OFFSET_VT2CTP = {
Offset.OPEN: THOST_FTDC_OF_Open, Offset.OPEN: THOST_FTDC_OF_Open,
Offset.CLOSE: THOST_FTDC_OFEN_Close, Offset.CLOSE: THOST_FTDC_OFEN_Close,
Offset.CLOSETODAY: THOST_FTDC_OFEN_CloseToday, Offset.CLOSETODAY: THOST_FTDC_OFEN_CloseToday,
Offset.CLOSEYESTERDAY: THOST_FTDC_OFEN_CloseYesterday, Offset.CLOSEYESTERDAY: THOST_FTDC_OFEN_CloseYesterday,
@ -136,7 +136,7 @@ class CtpGateway(BaseGateway):
} }
exchanges = list(EXCHANGE_CTP2VT.values()) exchanges = list(EXCHANGE_CTP2VT.values())
def __init__(self, event_engine): def __init__(self, event_engine):
"""Constructor""" """Constructor"""
super().__init__(event_engine, "CTP") super().__init__(event_engine, "CTP")
@ -154,15 +154,15 @@ class CtpGateway(BaseGateway):
appid = setting["产品名称"] appid = setting["产品名称"]
auth_code = setting["授权编码"] auth_code = setting["授权编码"]
product_info = setting["产品信息"] product_info = setting["产品信息"]
if not td_address.startswith("tcp://"): if not td_address.startswith("tcp://"):
td_address = "tcp://" + td_address td_address = "tcp://" + td_address
if not md_address.startswith("tcp://"): if not md_address.startswith("tcp://"):
md_address = "tcp://" + md_address md_address = "tcp://" + md_address
self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid, product_info) self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid, product_info)
self.md_api.connect(md_address, userid, password, brokerid) self.md_api.connect(md_address, userid, password, brokerid)
self.init_query() self.init_query()
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
@ -195,19 +195,19 @@ class CtpGateway(BaseGateway):
error_id = error["ErrorID"] error_id = error["ErrorID"]
error_msg = error["ErrorMsg"] error_msg = error["ErrorMsg"]
msg = f"{msg},代码:{error_id},信息:{error_msg}" msg = f"{msg},代码:{error_id},信息:{error_msg}"
self.write_log(msg) self.write_log(msg)
def process_timer_event(self, event): def process_timer_event(self, event):
"""""" """"""
self.count += 1 self.count += 1
if self.count < 2: if self.count < 2:
return return
self.count = 0 self.count = 0
func = self.query_functions.pop(0) func = self.query_functions.pop(0)
func() func()
self.query_functions.append(func) self.query_functions.append(func)
def init_query(self): def init_query(self):
"""""" """"""
self.count = 0 self.count = 0
@ -221,20 +221,20 @@ class CtpMdApi(MdApi):
def __init__(self, gateway): def __init__(self, gateway):
"""Constructor""" """Constructor"""
super(CtpMdApi, self).__init__() super(CtpMdApi, self).__init__()
self.gateway = gateway self.gateway = gateway
self.gateway_name = gateway.gateway_name self.gateway_name = gateway.gateway_name
self.reqid = 0 self.reqid = 0
self.connect_status = False self.connect_status = False
self.login_status = False self.login_status = False
self.subscribed = set() self.subscribed = set()
self.userid = "" self.userid = ""
self.password = "" self.password = ""
self.brokerid = "" self.brokerid = ""
def onFrontConnected(self): def onFrontConnected(self):
""" """
Callback when front server is connected. Callback when front server is connected.
@ -256,23 +256,23 @@ class CtpMdApi(MdApi):
if not error["ErrorID"]: if not error["ErrorID"]:
self.login_status = True self.login_status = True
self.gateway.write_log("行情服务器登录成功") self.gateway.write_log("行情服务器登录成功")
for symbol in self.subscribed: for symbol in self.subscribed:
self.subscribeMarketData(symbol) self.subscribeMarketData(symbol)
else: else:
self.gateway.write_error("行情服务器登录失败", error) self.gateway.write_error("行情服务器登录失败", error)
def onRspError(self, error: dict, reqid: int, last: bool): def onRspError(self, error: dict, reqid: int, last: bool):
""" """
Callback when error occured. Callback when error occured.
""" """
self.gateway.write_error("行情接口报错", error) self.gateway.write_error("行情接口报错", error)
def onRspSubMarketData(self, data: dict, error: dict, reqid: int, last: bool): def onRspSubMarketData(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not error or not error["ErrorID"]: if not error or not error["ErrorID"]:
return return
self.gateway.write_error("行情订阅失败", error) self.gateway.write_error("行情订阅失败", error)
def onRtnDepthMarketData(self, data: dict): def onRtnDepthMarketData(self, data: dict):
@ -283,9 +283,9 @@ class CtpMdApi(MdApi):
exchange = symbol_exchange_map.get(symbol, "") exchange = symbol_exchange_map.get(symbol, "")
if not exchange: if not exchange:
return return
timestamp = f"{data['ActionDay']} {data['UpdateTime']}.{int(data['UpdateMillisec']/100)}" timestamp = f"{data['ActionDay']} {data['UpdateTime']}.{int(data['UpdateMillisec']/100)}"
tick = TickData( tick = TickData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -306,7 +306,7 @@ class CtpMdApi(MdApi):
ask_volume_1=data["AskVolume1"], ask_volume_1=data["AskVolume1"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_tick(tick) self.gateway.on_tick(tick)
def connect(self, address: str, userid: str, password: str, brokerid: int): def connect(self, address: str, userid: str, password: str, brokerid: int):
""" """
@ -315,12 +315,12 @@ class CtpMdApi(MdApi):
self.userid = userid self.userid = userid
self.password = password self.password = password
self.brokerid = brokerid self.brokerid = brokerid
# If not connected, then start connection first. # If not connected, then start connection first.
if not self.connect_status: if not self.connect_status:
path = get_folder_path(self.gateway_name.lower()) path = get_folder_path(self.gateway_name.lower())
self.createFtdcMdApi(str(path) + "\\Md") self.createFtdcMdApi(str(path) + "\\Md")
self.registerFront(address) self.registerFront(address)
self.init() self.init()
@ -328,7 +328,7 @@ class CtpMdApi(MdApi):
# If already connected, then login immediately. # If already connected, then login immediately.
elif not self.login_status: elif not self.login_status:
self.login() self.login()
def login(self): def login(self):
""" """
Login onto server. Login onto server.
@ -338,10 +338,10 @@ class CtpMdApi(MdApi):
"Password": self.password, "Password": self.password,
"BrokerID": self.brokerid "BrokerID": self.brokerid
} }
self.reqid += 1 self.reqid += 1
self.reqUserLogin(req, self.reqid) self.reqUserLogin(req, self.reqid)
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
""" """
Subscribe to tick data update. Subscribe to tick data update.
@ -349,7 +349,7 @@ class CtpMdApi(MdApi):
if self.login_status: if self.login_status:
self.subscribeMarketData(req.symbol) self.subscribeMarketData(req.symbol)
self.subscribed.add(req.symbol) self.subscribed.add(req.symbol)
def close(self): def close(self):
""" """
Close the connection. Close the connection.
@ -364,47 +364,47 @@ class CtpTdApi(TdApi):
def __init__(self, gateway): def __init__(self, gateway):
"""Constructor""" """Constructor"""
super(CtpTdApi, self).__init__() super(CtpTdApi, self).__init__()
self.gateway = gateway self.gateway = gateway
self.gateway_name = gateway.gateway_name self.gateway_name = gateway.gateway_name
self.reqid = 0 self.reqid = 0
self.order_ref = 0 self.order_ref = 0
self.connect_status = False self.connect_status = False
self.login_status = False self.login_status = False
self.auth_staus = False self.auth_staus = False
self.login_failed = False self.login_failed = False
self.userid = "" self.userid = ""
self.password = "" self.password = ""
self.brokerid = "" self.brokerid = ""
self.auth_code = "" self.auth_code = ""
self.appid = "" self.appid = ""
self.product_info = "" self.product_info = ""
self.frontid = 0 self.frontid = 0
self.sessionid = 0 self.sessionid = 0
self.order_data = [] self.order_data = []
self.trade_data = [] self.trade_data = []
self.positions = {} self.positions = {}
self.sysid_orderid_map = {} self.sysid_orderid_map = {}
def onFrontConnected(self): def onFrontConnected(self):
"""""" """"""
self.gateway.write_log("交易服务器连接成功") self.gateway.write_log("交易服务器连接成功")
if self.auth_code: if self.auth_code:
self.authenticate() self.authenticate()
else: else:
self.login() self.login()
def onFrontDisconnected(self, reason: int): def onFrontDisconnected(self, reason: int):
"""""" """"""
self.login_status = False self.login_status = False
self.gateway.write_log(f"交易服务器连接断开,原因{reason}") self.gateway.write_log(f"交易服务器连接断开,原因{reason}")
def onRspAuthenticate(self, data: dict, error: dict, reqid: int, last: bool): def onRspAuthenticate(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not error['ErrorID']: if not error['ErrorID']:
@ -413,7 +413,7 @@ class CtpTdApi(TdApi):
self.login() self.login()
else: else:
self.gateway.write_error("交易服务器授权验证失败", error) self.gateway.write_error("交易服务器授权验证失败", error)
def onRspUserLogin(self, data: dict, error: dict, reqid: int, last: bool): def onRspUserLogin(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not error["ErrorID"]: if not error["ErrorID"]:
@ -421,7 +421,7 @@ class CtpTdApi(TdApi):
self.sessionid = data["SessionID"] self.sessionid = data["SessionID"]
self.login_status = True self.login_status = True
self.gateway.write_log("交易服务器登录成功") self.gateway.write_log("交易服务器登录成功")
# Confirm settlement # Confirm settlement
req = { req = {
"BrokerID": self.brokerid, "BrokerID": self.brokerid,
@ -431,17 +431,17 @@ class CtpTdApi(TdApi):
self.reqSettlementInfoConfirm(req, self.reqid) self.reqSettlementInfoConfirm(req, self.reqid)
else: else:
self.login_failed = True self.login_failed = True
self.gateway.write_error("交易服务器登录失败", error) self.gateway.write_error("交易服务器登录失败", error)
def onRspOrderInsert(self, data: dict, error: dict, reqid: int, last: bool): def onRspOrderInsert(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
order_ref = data["OrderRef"] order_ref = data["OrderRef"]
orderid = f"{self.frontid}_{self.sessionid}_{order_ref}" orderid = f"{self.frontid}_{self.sessionid}_{order_ref}"
symbol = data["InstrumentID"] symbol = data["InstrumentID"]
exchange = symbol_exchange_map[symbol] exchange = symbol_exchange_map[symbol]
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -454,31 +454,31 @@ class CtpTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_order(order) self.gateway.on_order(order)
self.gateway.write_error("交易委托失败", error) self.gateway.write_error("交易委托失败", error)
def onRspOrderAction(self, data: dict, error: dict, reqid: int, last: bool): def onRspOrderAction(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
self.gateway.write_error("交易撤单失败", error) self.gateway.write_error("交易撤单失败", error)
def onRspQueryMaxOrderVolume(self, data: dict, error: dict, reqid: int, last: bool): def onRspQueryMaxOrderVolume(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
pass pass
def onRspSettlementInfoConfirm(self, data: dict, error: dict, reqid: int, last: bool): def onRspSettlementInfoConfirm(self, data: dict, error: dict, reqid: int, last: bool):
""" """
Callback of settlment info confimation. Callback of settlment info confimation.
""" """
self.gateway.write_log("结算信息确认成功") self.gateway.write_log("结算信息确认成功")
self.reqid += 1 self.reqid += 1
self.reqQryInstrument({}, self.reqid) self.reqQryInstrument({}, self.reqid)
def onRspQryInvestorPosition(self, data: dict, error: dict, reqid: int, last: bool): def onRspQryInvestorPosition(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not data: if not data:
return return
# Get buffered position object # Get buffered position object
key = f"{data['InstrumentID'], data['PosiDirection']}" key = f"{data['InstrumentID'], data['PosiDirection']}"
position = self.positions.get(key, None) position = self.positions.get(key, None)
@ -490,7 +490,7 @@ class CtpTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.positions[key] = position self.positions[key] = position
# For SHFE position data update # For SHFE position data update
if position.exchange == Exchange.SHFE: if position.exchange == Exchange.SHFE:
if data["YdPosition"] and not data["TodayPosition"]: if data["YdPosition"] and not data["TodayPosition"]:
@ -498,34 +498,34 @@ class CtpTdApi(TdApi):
# For other exchange position data update # For other exchange position data update
else: else:
position.yd_volume = data["Position"] - data["TodayPosition"] position.yd_volume = data["Position"] - data["TodayPosition"]
# Get contract size (spread contract has no size value) # Get contract size (spread contract has no size value)
size = symbol_size_map.get(position.symbol, 0) size = symbol_size_map.get(position.symbol, 0)
# Calculate previous position cost # Calculate previous position cost
cost = position.price * position.volume * size cost = position.price * position.volume * size
# Update new position volume # Update new position volume
position.volume += data["Position"] position.volume += data["Position"]
position.pnl += data["PositionProfit"] position.pnl += data["PositionProfit"]
# Calculate average position price # Calculate average position price
if position.volume and size: if position.volume and size:
cost += data["PositionCost"] cost += data["PositionCost"]
position.price = cost / (position.volume * size) position.price = cost / (position.volume * size)
# Get frozen volume # Get frozen volume
if position.direction == Direction.LONG: if position.direction == Direction.LONG:
position.frozen += data["ShortFrozen"] position.frozen += data["ShortFrozen"]
else: else:
position.frozen += data["LongFrozen"] position.frozen += data["LongFrozen"]
if last: if last:
for position in self.positions.values(): for position in self.positions.values():
self.gateway.on_position(position) self.gateway.on_position(position)
self.positions.clear() self.positions.clear()
def onRspQryTradingAccount(self, data: dict, error: dict, reqid: int, last: bool): def onRspQryTradingAccount(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if "AccountID" not in data: if "AccountID" not in data:
@ -538,15 +538,15 @@ class CtpTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
account.available = data["Available"] account.available = data["Available"]
self.gateway.on_account(account) self.gateway.on_account(account)
def onRspQryInstrument(self, data: dict, error: dict, reqid: int, last: bool): def onRspQryInstrument(self, data: dict, error: dict, reqid: int, last: bool):
""" """
Callback of instrument query. Callback of instrument query.
""" """
product = PRODUCT_CTP2VT.get(data["ProductClass"], None) product = PRODUCT_CTP2VT.get(data["ProductClass"], None)
if product: if product:
contract = ContractData( contract = ContractData(
symbol=data["InstrumentID"], symbol=data["InstrumentID"],
exchange=EXCHANGE_CTP2VT[data["ExchangeID"]], exchange=EXCHANGE_CTP2VT[data["ExchangeID"]],
@ -556,31 +556,31 @@ class CtpTdApi(TdApi):
pricetick=data["PriceTick"], pricetick=data["PriceTick"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
# For option only # For option only
if contract.product == Product.OPTION: if contract.product == Product.OPTION:
contract.option_underlying = data["UnderlyingInstrID"], contract.option_underlying = data["UnderlyingInstrID"],
contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None), contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None),
contract.option_strike = data["StrikePrice"], contract.option_strike = data["StrikePrice"],
contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"), contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"),
self.gateway.on_contract(contract) self.gateway.on_contract(contract)
symbol_exchange_map[contract.symbol] = contract.exchange symbol_exchange_map[contract.symbol] = contract.exchange
symbol_name_map[contract.symbol] = contract.name symbol_name_map[contract.symbol] = contract.name
symbol_size_map[contract.symbol] = contract.size symbol_size_map[contract.symbol] = contract.size
if last: if last:
self.gateway.write_log("合约信息查询成功") self.gateway.write_log("合约信息查询成功")
for data in self.order_data: for data in self.order_data:
self.onRtnOrder(data) self.onRtnOrder(data)
self.order_data.clear() self.order_data.clear()
for data in self.trade_data: for data in self.trade_data:
self.onRtnTrade(data) self.onRtnTrade(data)
self.trade_data.clear() self.trade_data.clear()
def onRtnOrder(self, data: dict): def onRtnOrder(self, data: dict):
""" """
Callback of order status update. Callback of order status update.
@ -590,12 +590,12 @@ class CtpTdApi(TdApi):
if not exchange: if not exchange:
self.order_data.append(data) self.order_data.append(data)
return return
frontid = data["FrontID"] frontid = data["FrontID"]
sessionid = data["SessionID"] sessionid = data["SessionID"]
order_ref = data["OrderRef"] order_ref = data["OrderRef"]
orderid = f"{frontid}_{sessionid}_{order_ref}" orderid = f"{frontid}_{sessionid}_{order_ref}"
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -611,9 +611,9 @@ class CtpTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_order(order) self.gateway.on_order(order)
self.sysid_orderid_map[data["OrderSysID"]] = orderid self.sysid_orderid_map[data["OrderSysID"]] = orderid
def onRtnTrade(self, data: dict): def onRtnTrade(self, data: dict):
""" """
Callback of trade status update. Callback of trade status update.
@ -625,7 +625,7 @@ class CtpTdApi(TdApi):
return return
orderid = self.sysid_orderid_map[data["OrderSysID"]] orderid = self.sysid_orderid_map[data["OrderSysID"]]
trade = TradeData( trade = TradeData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -638,15 +638,15 @@ class CtpTdApi(TdApi):
time=data["TradeTime"], time=data["TradeTime"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_trade(trade) self.gateway.on_trade(trade)
def connect( def connect(
self, self,
address: str, address: str,
userid: str, userid: str,
password: str, password: str,
brokerid: int, brokerid: int,
auth_code: str, auth_code: str,
appid: str, appid: str,
product_info product_info
): ):
@ -659,21 +659,21 @@ class CtpTdApi(TdApi):
self.auth_code = auth_code self.auth_code = auth_code
self.appid = appid self.appid = appid
self.product_info = product_info self.product_info = product_info
if not self.connect_status: if not self.connect_status:
path = get_folder_path(self.gateway_name.lower()) path = get_folder_path(self.gateway_name.lower())
self.createFtdcTraderApi(str(path) + "\\Td") self.createFtdcTraderApi(str(path) + "\\Td")
self.subscribePrivateTopic(0) self.subscribePrivateTopic(0)
self.subscribePublicTopic(0) self.subscribePublicTopic(0)
self.registerFront(address) self.registerFront(address)
self.init() self.init()
self.connect_status = True self.connect_status = True
else: else:
self.authenticate() self.authenticate()
def authenticate(self): def authenticate(self):
""" """
Authenticate with auth_code and appid. Authenticate with auth_code and appid.
@ -687,10 +687,10 @@ class CtpTdApi(TdApi):
if self.product_info: if self.product_info:
req["UserProductInfo"] = self.product_info req["UserProductInfo"] = self.product_info
self.reqid += 1 self.reqid += 1
self.reqAuthenticate(req, self.reqid) self.reqAuthenticate(req, self.reqid)
def login(self): def login(self):
""" """
Login onto server. Login onto server.
@ -707,16 +707,16 @@ class CtpTdApi(TdApi):
if self.product_info: if self.product_info:
req["UserProductInfo"] = self.product_info req["UserProductInfo"] = self.product_info
self.reqid += 1 self.reqid += 1
self.reqUserLogin(req, self.reqid) self.reqUserLogin(req, self.reqid)
def send_order(self, req: OrderRequest): def send_order(self, req: OrderRequest):
""" """
Send new order. Send new order.
""" """
self.order_ref += 1 self.order_ref += 1
ctp_req = { ctp_req = {
"InstrumentID": req.symbol, "InstrumentID": req.symbol,
"ExchangeID": req.exchange.value, "ExchangeID": req.exchange.value,
@ -737,7 +737,7 @@ class CtpTdApi(TdApi):
"VolumeCondition": THOST_FTDC_VC_AV, "VolumeCondition": THOST_FTDC_VC_AV,
"MinVolume": 1 "MinVolume": 1
} }
if req.type == OrderType.FAK: if req.type == OrderType.FAK:
ctp_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice ctp_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice
ctp_req["TimeCondition"] = THOST_FTDC_TC_IOC ctp_req["TimeCondition"] = THOST_FTDC_TC_IOC
@ -745,23 +745,23 @@ class CtpTdApi(TdApi):
elif req.type == OrderType.FOK: elif req.type == OrderType.FOK:
ctp_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice ctp_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice
ctp_req["TimeCondition"] = THOST_FTDC_TC_IOC ctp_req["TimeCondition"] = THOST_FTDC_TC_IOC
ctp_req["VolumeCondition"] = THOST_FTDC_VC_CV ctp_req["VolumeCondition"] = THOST_FTDC_VC_CV
self.reqid += 1 self.reqid += 1
self.reqOrderInsert(ctp_req, self.reqid) self.reqOrderInsert(ctp_req, self.reqid)
orderid = f"{self.frontid}_{self.sessionid}_{self.order_ref}" orderid = f"{self.frontid}_{self.sessionid}_{self.order_ref}"
order = req.create_order_data(orderid, self.gateway_name) order = req.create_order_data(orderid, self.gateway_name)
self.gateway.on_order(order) self.gateway.on_order(order)
return order.vt_orderid return order.vt_orderid
def cancel_order(self, req: CancelRequest): def cancel_order(self, req: CancelRequest):
""" """
Cancel existing order. Cancel existing order.
""" """
frontid, sessionid, order_ref = req.orderid.split("_") frontid, sessionid, order_ref = req.orderid.split("_")
ctp_req = { ctp_req = {
"InstrumentID": req.symbol, "InstrumentID": req.symbol,
"ExchangeID": req.exchange.value, "ExchangeID": req.exchange.value,
@ -772,32 +772,32 @@ class CtpTdApi(TdApi):
"BrokerID": self.brokerid, "BrokerID": self.brokerid,
"InvestorID": self.userid "InvestorID": self.userid
} }
self.reqid += 1 self.reqid += 1
self.reqOrderAction(ctp_req, self.reqid) self.reqOrderAction(ctp_req, self.reqid)
def query_account(self): def query_account(self):
""" """
Query account balance data. Query account balance data.
""" """
self.reqid += 1 self.reqid += 1
self.reqQryTradingAccount({}, self.reqid) self.reqQryTradingAccount({}, self.reqid)
def query_position(self): def query_position(self):
""" """
Query position holding data. Query position holding data.
""" """
if not symbol_exchange_map: if not symbol_exchange_map:
return return
req = { req = {
"BrokerID": self.brokerid, "BrokerID": self.brokerid,
"InvestorID": self.userid "InvestorID": self.userid
} }
self.reqid += 1 self.reqid += 1
self.reqQryInvestorPosition(req, self.reqid) self.reqQryInvestorPosition(req, self.reqid)
def close(self): def close(self):
"""""" """"""
if self.connect_status: if self.connect_status:

View File

@ -136,7 +136,7 @@ class CtptestGateway(BaseGateway):
} }
exchanges = list(EXCHANGE_CTP2VT.values()) exchanges = list(EXCHANGE_CTP2VT.values())
def __init__(self, event_engine): def __init__(self, event_engine):
"""Constructor""" """Constructor"""
super().__init__(event_engine, "CTPTEST") super().__init__(event_engine, "CTPTEST")
@ -154,15 +154,15 @@ class CtptestGateway(BaseGateway):
appid = setting["产品名称"] appid = setting["产品名称"]
auth_code = setting["授权编码"] auth_code = setting["授权编码"]
product_info = setting["产品信息"] product_info = setting["产品信息"]
if not td_address.startswith("tcp://"): if not td_address.startswith("tcp://"):
td_address = "tcp://" + td_address td_address = "tcp://" + td_address
if not md_address.startswith("tcp://"): if not md_address.startswith("tcp://"):
md_address = "tcp://" + md_address md_address = "tcp://" + md_address
self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid, product_info) self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid, product_info)
self.md_api.connect(md_address, userid, password, brokerid) self.md_api.connect(md_address, userid, password, brokerid)
self.init_query() self.init_query()
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
@ -195,19 +195,19 @@ class CtptestGateway(BaseGateway):
error_id = error["ErrorID"] error_id = error["ErrorID"]
error_msg = error["ErrorMsg"] error_msg = error["ErrorMsg"]
msg = f"{msg},代码:{error_id},信息:{error_msg}" msg = f"{msg},代码:{error_id},信息:{error_msg}"
self.write_log(msg) self.write_log(msg)
def process_timer_event(self, event): def process_timer_event(self, event):
"""""" """"""
self.count += 1 self.count += 1
if self.count < 2: if self.count < 2:
return return
self.count = 0 self.count = 0
func = self.query_functions.pop(0) func = self.query_functions.pop(0)
func() func()
self.query_functions.append(func) self.query_functions.append(func)
def init_query(self): def init_query(self):
"""""" """"""
self.count = 0 self.count = 0
@ -221,20 +221,20 @@ class CtpMdApi(MdApi):
def __init__(self, gateway): def __init__(self, gateway):
"""Constructor""" """Constructor"""
super(CtpMdApi, self).__init__() super(CtpMdApi, self).__init__()
self.gateway = gateway self.gateway = gateway
self.gateway_name = gateway.gateway_name self.gateway_name = gateway.gateway_name
self.reqid = 0 self.reqid = 0
self.connect_status = False self.connect_status = False
self.login_status = False self.login_status = False
self.subscribed = set() self.subscribed = set()
self.userid = "" self.userid = ""
self.password = "" self.password = ""
self.brokerid = "" self.brokerid = ""
def onFrontConnected(self): def onFrontConnected(self):
""" """
Callback when front server is connected. Callback when front server is connected.
@ -256,23 +256,23 @@ class CtpMdApi(MdApi):
if not error["ErrorID"]: if not error["ErrorID"]:
self.login_status = True self.login_status = True
self.gateway.write_log("行情服务器登录成功") self.gateway.write_log("行情服务器登录成功")
for symbol in self.subscribed: for symbol in self.subscribed:
self.subscribeMarketData(symbol) self.subscribeMarketData(symbol)
else: else:
self.gateway.write_error("行情服务器登录失败", error) self.gateway.write_error("行情服务器登录失败", error)
def onRspError(self, error: dict, reqid: int, last: bool): def onRspError(self, error: dict, reqid: int, last: bool):
""" """
Callback when error occured. Callback when error occured.
""" """
self.gateway.write_error("行情接口报错", error) self.gateway.write_error("行情接口报错", error)
def onRspSubMarketData(self, data: dict, error: dict, reqid: int, last: bool): def onRspSubMarketData(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not error or not error["ErrorID"]: if not error or not error["ErrorID"]:
return return
self.gateway.write_error("行情订阅失败", error) self.gateway.write_error("行情订阅失败", error)
def onRtnDepthMarketData(self, data: dict): def onRtnDepthMarketData(self, data: dict):
@ -283,9 +283,9 @@ class CtpMdApi(MdApi):
exchange = symbol_exchange_map.get(symbol, "") exchange = symbol_exchange_map.get(symbol, "")
if not exchange: if not exchange:
return return
timestamp = f"{data['ActionDay']} {data['UpdateTime']}.{int(data['UpdateMillisec']/100)}" timestamp = f"{data['ActionDay']} {data['UpdateTime']}.{int(data['UpdateMillisec']/100)}"
tick = TickData( tick = TickData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -305,7 +305,7 @@ class CtpMdApi(MdApi):
ask_volume_1=data["AskVolume1"], ask_volume_1=data["AskVolume1"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_tick(tick) self.gateway.on_tick(tick)
def connect(self, address: str, userid: str, password: str, brokerid: int): def connect(self, address: str, userid: str, password: str, brokerid: int):
""" """
@ -314,12 +314,12 @@ class CtpMdApi(MdApi):
self.userid = userid self.userid = userid
self.password = password self.password = password
self.brokerid = brokerid self.brokerid = brokerid
# If not connected, then start connection first. # If not connected, then start connection first.
if not self.connect_status: if not self.connect_status:
path = get_folder_path(self.gateway_name.lower()) path = get_folder_path(self.gateway_name.lower())
self.createFtdcMdApi(str(path) + "\\Md") self.createFtdcMdApi(str(path) + "\\Md")
self.registerFront(address) self.registerFront(address)
self.init() self.init()
@ -327,7 +327,7 @@ class CtpMdApi(MdApi):
# If already connected, then login immediately. # If already connected, then login immediately.
elif not self.login_status: elif not self.login_status:
self.login() self.login()
def login(self): def login(self):
""" """
Login onto server. Login onto server.
@ -337,10 +337,10 @@ class CtpMdApi(MdApi):
"Password": self.password, "Password": self.password,
"BrokerID": self.brokerid "BrokerID": self.brokerid
} }
self.reqid += 1 self.reqid += 1
self.reqUserLogin(req, self.reqid) self.reqUserLogin(req, self.reqid)
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
""" """
Subscribe to tick data update. Subscribe to tick data update.
@ -348,7 +348,7 @@ class CtpMdApi(MdApi):
if self.login_status: if self.login_status:
self.subscribeMarketData(req.symbol) self.subscribeMarketData(req.symbol)
self.subscribed.add(req.symbol) self.subscribed.add(req.symbol)
def close(self): def close(self):
""" """
Close the connection. Close the connection.
@ -363,47 +363,47 @@ class CtpTdApi(TdApi):
def __init__(self, gateway): def __init__(self, gateway):
"""Constructor""" """Constructor"""
super(CtpTdApi, self).__init__() super(CtpTdApi, self).__init__()
self.gateway = gateway self.gateway = gateway
self.gateway_name = gateway.gateway_name self.gateway_name = gateway.gateway_name
self.reqid = 0 self.reqid = 0
self.order_ref = 0 self.order_ref = 0
self.connect_status = False self.connect_status = False
self.login_status = False self.login_status = False
self.auth_staus = False self.auth_staus = False
self.login_failed = False self.login_failed = False
self.userid = "" self.userid = ""
self.password = "" self.password = ""
self.brokerid = "" self.brokerid = ""
self.auth_code = "" self.auth_code = ""
self.appid = "" self.appid = ""
self.product_info = "" self.product_info = ""
self.frontid = 0 self.frontid = 0
self.sessionid = 0 self.sessionid = 0
self.order_data = [] self.order_data = []
self.trade_data = [] self.trade_data = []
self.positions = {} self.positions = {}
self.sysid_orderid_map = {} self.sysid_orderid_map = {}
def onFrontConnected(self): def onFrontConnected(self):
"""""" """"""
self.gateway.write_log("交易服务器连接成功") self.gateway.write_log("交易服务器连接成功")
if self.auth_code: if self.auth_code:
self.authenticate() self.authenticate()
else: else:
self.login() self.login()
def onFrontDisconnected(self, reason: int): def onFrontDisconnected(self, reason: int):
"""""" """"""
self.login_status = False self.login_status = False
self.gateway.write_log(f"交易服务器连接断开,原因{reason}") self.gateway.write_log(f"交易服务器连接断开,原因{reason}")
def onRspAuthenticate(self, data: dict, error: dict, reqid: int, last: bool): def onRspAuthenticate(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not error['ErrorID']: if not error['ErrorID']:
@ -412,7 +412,7 @@ class CtpTdApi(TdApi):
self.login() self.login()
else: else:
self.gateway.write_error("交易服务器授权验证失败", error) self.gateway.write_error("交易服务器授权验证失败", error)
def onRspUserLogin(self, data: dict, error: dict, reqid: int, last: bool): def onRspUserLogin(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not error["ErrorID"]: if not error["ErrorID"]:
@ -420,7 +420,7 @@ class CtpTdApi(TdApi):
self.sessionid = data["SessionID"] self.sessionid = data["SessionID"]
self.login_status = True self.login_status = True
self.gateway.write_log("交易服务器登录成功") self.gateway.write_log("交易服务器登录成功")
# Confirm settlement # Confirm settlement
req = { req = {
"BrokerID": self.brokerid, "BrokerID": self.brokerid,
@ -430,17 +430,17 @@ class CtpTdApi(TdApi):
self.reqSettlementInfoConfirm(req, self.reqid) self.reqSettlementInfoConfirm(req, self.reqid)
else: else:
self.login_failed = True self.login_failed = True
self.gateway.write_error("交易服务器登录失败", error) self.gateway.write_error("交易服务器登录失败", error)
def onRspOrderInsert(self, data: dict, error: dict, reqid: int, last: bool): def onRspOrderInsert(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
order_ref = data["OrderRef"] order_ref = data["OrderRef"]
orderid = f"{self.frontid}_{self.sessionid}_{order_ref}" orderid = f"{self.frontid}_{self.sessionid}_{order_ref}"
symbol = data["InstrumentID"] symbol = data["InstrumentID"]
exchange = symbol_exchange_map[symbol] exchange = symbol_exchange_map[symbol]
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -453,31 +453,31 @@ class CtpTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_order(order) self.gateway.on_order(order)
self.gateway.write_error("交易委托失败", error) self.gateway.write_error("交易委托失败", error)
def onRspOrderAction(self, data: dict, error: dict, reqid: int, last: bool): def onRspOrderAction(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
self.gateway.write_error("交易撤单失败", error) self.gateway.write_error("交易撤单失败", error)
def onRspQueryMaxOrderVolume(self, data: dict, error: dict, reqid: int, last: bool): def onRspQueryMaxOrderVolume(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
pass pass
def onRspSettlementInfoConfirm(self, data: dict, error: dict, reqid: int, last: bool): def onRspSettlementInfoConfirm(self, data: dict, error: dict, reqid: int, last: bool):
""" """
Callback of settlment info confimation. Callback of settlment info confimation.
""" """
self.gateway.write_log("结算信息确认成功") self.gateway.write_log("结算信息确认成功")
self.reqid += 1 self.reqid += 1
self.reqQryInstrument({}, self.reqid) self.reqQryInstrument({}, self.reqid)
def onRspQryInvestorPosition(self, data: dict, error: dict, reqid: int, last: bool): def onRspQryInvestorPosition(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not data: if not data:
return return
# Get buffered position object # Get buffered position object
key = f"{data['InstrumentID'], data['PosiDirection']}" key = f"{data['InstrumentID'], data['PosiDirection']}"
position = self.positions.get(key, None) position = self.positions.get(key, None)
@ -489,7 +489,7 @@ class CtpTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.positions[key] = position self.positions[key] = position
# For SHFE position data update # For SHFE position data update
if position.exchange == Exchange.SHFE: if position.exchange == Exchange.SHFE:
if data["YdPosition"] and not data["TodayPosition"]: if data["YdPosition"] and not data["TodayPosition"]:
@ -497,34 +497,34 @@ class CtpTdApi(TdApi):
# For other exchange position data update # For other exchange position data update
else: else:
position.yd_volume = data["Position"] - data["TodayPosition"] position.yd_volume = data["Position"] - data["TodayPosition"]
# Get contract size (spread contract has no size value) # Get contract size (spread contract has no size value)
size = symbol_size_map.get(position.symbol, 0) size = symbol_size_map.get(position.symbol, 0)
# Calculate previous position cost # Calculate previous position cost
cost = position.price * position.volume * size cost = position.price * position.volume * size
# Update new position volume # Update new position volume
position.volume += data["Position"] position.volume += data["Position"]
position.pnl += data["PositionProfit"] position.pnl += data["PositionProfit"]
# Calculate average position price # Calculate average position price
if position.volume and size: if position.volume and size:
cost += data["PositionCost"] cost += data["PositionCost"]
position.price = cost / (position.volume * size) position.price = cost / (position.volume * size)
# Get frozen volume # Get frozen volume
if position.direction == Direction.LONG: if position.direction == Direction.LONG:
position.frozen += data["ShortFrozen"] position.frozen += data["ShortFrozen"]
else: else:
position.frozen += data["LongFrozen"] position.frozen += data["LongFrozen"]
if last: if last:
for position in self.positions.values(): for position in self.positions.values():
self.gateway.on_position(position) self.gateway.on_position(position)
self.positions.clear() self.positions.clear()
def onRspQryTradingAccount(self, data: dict, error: dict, reqid: int, last: bool): def onRspQryTradingAccount(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if "AccountID" not in data: if "AccountID" not in data:
@ -537,15 +537,15 @@ class CtpTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
account.available = data["Available"] account.available = data["Available"]
self.gateway.on_account(account) self.gateway.on_account(account)
def onRspQryInstrument(self, data: dict, error: dict, reqid: int, last: bool): def onRspQryInstrument(self, data: dict, error: dict, reqid: int, last: bool):
""" """
Callback of instrument query. Callback of instrument query.
""" """
product = PRODUCT_CTP2VT.get(data["ProductClass"], None) product = PRODUCT_CTP2VT.get(data["ProductClass"], None)
if product: if product:
contract = ContractData( contract = ContractData(
symbol=data["InstrumentID"], symbol=data["InstrumentID"],
exchange=EXCHANGE_CTP2VT[data["ExchangeID"]], exchange=EXCHANGE_CTP2VT[data["ExchangeID"]],
@ -555,31 +555,31 @@ class CtpTdApi(TdApi):
pricetick=data["PriceTick"], pricetick=data["PriceTick"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
# For option only # For option only
if contract.product == Product.OPTION: if contract.product == Product.OPTION:
contract.option_underlying = data["UnderlyingInstrID"], contract.option_underlying = data["UnderlyingInstrID"],
contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None), contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None),
contract.option_strike = data["StrikePrice"], contract.option_strike = data["StrikePrice"],
contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"), contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"),
self.gateway.on_contract(contract) self.gateway.on_contract(contract)
symbol_exchange_map[contract.symbol] = contract.exchange symbol_exchange_map[contract.symbol] = contract.exchange
symbol_name_map[contract.symbol] = contract.name symbol_name_map[contract.symbol] = contract.name
symbol_size_map[contract.symbol] = contract.size symbol_size_map[contract.symbol] = contract.size
if last: if last:
self.gateway.write_log("合约信息查询成功") self.gateway.write_log("合约信息查询成功")
for data in self.order_data: for data in self.order_data:
self.onRtnOrder(data) self.onRtnOrder(data)
self.order_data.clear() self.order_data.clear()
for data in self.trade_data: for data in self.trade_data:
self.onRtnTrade(data) self.onRtnTrade(data)
self.trade_data.clear() self.trade_data.clear()
def onRtnOrder(self, data: dict): def onRtnOrder(self, data: dict):
""" """
Callback of order status update. Callback of order status update.
@ -589,12 +589,12 @@ class CtpTdApi(TdApi):
if not exchange: if not exchange:
self.order_data.append(data) self.order_data.append(data)
return return
frontid = data["FrontID"] frontid = data["FrontID"]
sessionid = data["SessionID"] sessionid = data["SessionID"]
order_ref = data["OrderRef"] order_ref = data["OrderRef"]
orderid = f"{frontid}_{sessionid}_{order_ref}" orderid = f"{frontid}_{sessionid}_{order_ref}"
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -610,9 +610,9 @@ class CtpTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_order(order) self.gateway.on_order(order)
self.sysid_orderid_map[data["OrderSysID"]] = orderid self.sysid_orderid_map[data["OrderSysID"]] = orderid
def onRtnTrade(self, data: dict): def onRtnTrade(self, data: dict):
""" """
Callback of trade status update. Callback of trade status update.
@ -624,7 +624,7 @@ class CtpTdApi(TdApi):
return return
orderid = self.sysid_orderid_map[data["OrderSysID"]] orderid = self.sysid_orderid_map[data["OrderSysID"]]
trade = TradeData( trade = TradeData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -637,15 +637,15 @@ class CtpTdApi(TdApi):
time=data["TradeTime"], time=data["TradeTime"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_trade(trade) self.gateway.on_trade(trade)
def connect( def connect(
self, self,
address: str, address: str,
userid: str, userid: str,
password: str, password: str,
brokerid: int, brokerid: int,
auth_code: str, auth_code: str,
appid: str, appid: str,
product_info product_info
): ):
@ -658,21 +658,21 @@ class CtpTdApi(TdApi):
self.auth_code = auth_code self.auth_code = auth_code
self.appid = appid self.appid = appid
self.product_info = product_info self.product_info = product_info
if not self.connect_status: if not self.connect_status:
path = get_folder_path(self.gateway_name.lower()) path = get_folder_path(self.gateway_name.lower())
self.createFtdcTraderApi(str(path) + "\\Td") self.createFtdcTraderApi(str(path) + "\\Td")
self.subscribePrivateTopic(0) self.subscribePrivateTopic(0)
self.subscribePublicTopic(0) self.subscribePublicTopic(0)
self.registerFront(address) self.registerFront(address)
self.init() self.init()
self.connect_status = True self.connect_status = True
else: else:
self.authenticate() self.authenticate()
def authenticate(self): def authenticate(self):
""" """
Authenticate with auth_code and appid. Authenticate with auth_code and appid.
@ -686,10 +686,10 @@ class CtpTdApi(TdApi):
if self.product_info: if self.product_info:
req["UserProductInfo"] = self.product_info req["UserProductInfo"] = self.product_info
self.reqid += 1 self.reqid += 1
self.reqAuthenticate(req, self.reqid) self.reqAuthenticate(req, self.reqid)
def login(self): def login(self):
""" """
Login onto server. Login onto server.
@ -706,16 +706,16 @@ class CtpTdApi(TdApi):
if self.product_info: if self.product_info:
req["UserProductInfo"] = self.product_info req["UserProductInfo"] = self.product_info
self.reqid += 1 self.reqid += 1
self.reqUserLogin(req, self.reqid) self.reqUserLogin(req, self.reqid)
def send_order(self, req: OrderRequest): def send_order(self, req: OrderRequest):
""" """
Send new order. Send new order.
""" """
self.order_ref += 1 self.order_ref += 1
ctp_req = { ctp_req = {
"InstrumentID": req.symbol, "InstrumentID": req.symbol,
"LimitPrice": req.price, "LimitPrice": req.price,
@ -735,7 +735,7 @@ class CtpTdApi(TdApi):
"VolumeCondition": THOST_FTDC_VC_AV, "VolumeCondition": THOST_FTDC_VC_AV,
"MinVolume": 1 "MinVolume": 1
} }
if req.type == OrderType.FAK: if req.type == OrderType.FAK:
ctp_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice ctp_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice
ctp_req["TimeCondition"] = THOST_FTDC_TC_IOC ctp_req["TimeCondition"] = THOST_FTDC_TC_IOC
@ -743,23 +743,23 @@ class CtpTdApi(TdApi):
elif req.type == OrderType.FOK: elif req.type == OrderType.FOK:
ctp_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice ctp_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice
ctp_req["TimeCondition"] = THOST_FTDC_TC_IOC ctp_req["TimeCondition"] = THOST_FTDC_TC_IOC
ctp_req["VolumeCondition"] = THOST_FTDC_VC_CV ctp_req["VolumeCondition"] = THOST_FTDC_VC_CV
self.reqid += 1 self.reqid += 1
self.reqOrderInsert(ctp_req, self.reqid) self.reqOrderInsert(ctp_req, self.reqid)
orderid = f"{self.frontid}_{self.sessionid}_{self.order_ref}" orderid = f"{self.frontid}_{self.sessionid}_{self.order_ref}"
order = req.create_order_data(orderid, self.gateway_name) order = req.create_order_data(orderid, self.gateway_name)
self.gateway.on_order(order) self.gateway.on_order(order)
return order.vt_orderid return order.vt_orderid
def cancel_order(self, req: CancelRequest): def cancel_order(self, req: CancelRequest):
""" """
Cancel existing order. Cancel existing order.
""" """
frontid, sessionid, order_ref = req.orderid.split("_") frontid, sessionid, order_ref = req.orderid.split("_")
ctp_req = { ctp_req = {
"InstrumentID": req.symbol, "InstrumentID": req.symbol,
"Exchange": req.exchange, "Exchange": req.exchange,
@ -770,32 +770,32 @@ class CtpTdApi(TdApi):
"BrokerID": self.brokerid, "BrokerID": self.brokerid,
"InvestorID": self.userid "InvestorID": self.userid
} }
self.reqid += 1 self.reqid += 1
self.reqOrderAction(ctp_req, self.reqid) self.reqOrderAction(ctp_req, self.reqid)
def query_account(self): def query_account(self):
""" """
Query account balance data. Query account balance data.
""" """
self.reqid += 1 self.reqid += 1
self.reqQryTradingAccount({}, self.reqid) self.reqQryTradingAccount({}, self.reqid)
def query_position(self): def query_position(self):
""" """
Query position holding data. Query position holding data.
""" """
if not symbol_exchange_map: if not symbol_exchange_map:
return return
req = { req = {
"BrokerID": self.brokerid, "BrokerID": self.brokerid,
"InvestorID": self.userid "InvestorID": self.userid
} }
self.reqid += 1 self.reqid += 1
self.reqQryInvestorPosition(req, self.reqid) self.reqQryInvestorPosition(req, self.reqid)
def close(self): def close(self):
"""""" """"""
if self.connect_status: if self.connect_status:

View File

@ -366,7 +366,7 @@ class HbdmRestApi(RestClient):
else: else:
for d in data["data"]: for d in data["data"]:
dt = datetime.fromtimestamp(d["id"]) dt = datetime.fromtimestamp(d["id"])
bar = BarData( bar = BarData(
symbol=req.symbol, symbol=req.symbol,
exchange=req.exchange, exchange=req.exchange,
@ -617,7 +617,7 @@ class HbdmRestApi(RestClient):
for d in data["data"]["trades"]: for d in data["data"]["trades"]:
dt = datetime.fromtimestamp(d["create_date"] / 1000) dt = datetime.fromtimestamp(d["create_date"] / 1000)
time = dt.strftime("%H:%M:%S") time = dt.strftime("%H:%M:%S")
trade = TradeData( trade = TradeData(
tradeid=d["match_id"], tradeid=d["match_id"],
orderid=d["order_id"], orderid=d["order_id"],
@ -744,7 +744,7 @@ class HbdmRestApi(RestClient):
Callback when sending order caused exception. Callback when sending order caused exception.
""" """
orders = request.extra orders = request.extra
for order in orders: for order in orders:
order.status = Status.REJECTED order.status = Status.REJECTED
self.gateway.on_order(order) self.gateway.on_order(order)
@ -770,7 +770,7 @@ class HbdmRestApi(RestClient):
"""""" """"""
if data["status"] != "error": if data["status"] != "error":
return False return False
error_code = data["err_code"] error_code = data["err_code"]
error_msg = data["err_msg"] error_msg = data["err_msg"]
@ -795,17 +795,17 @@ class HbdmWebsocketApiBase(WebsocketClient):
self.req_id = 0 self.req_id = 0
def connect( def connect(
self, self,
key: str, key: str,
secret: str, secret: str,
url: str, url: str,
proxy_host: str, proxy_host: str,
proxy_port: int proxy_port: int
): ):
"""""" """"""
self.key = key self.key = key
self.secret = secret self.secret = secret
host, path = _split_url(url) host, path = _split_url(url)
self.sign_host = host self.sign_host = host
self.path = path self.path = path
@ -822,7 +822,7 @@ class HbdmWebsocketApiBase(WebsocketClient):
"type": "api", "type": "api",
"cid": str(self.req_id), "cid": str(self.req_id),
} }
params.update(create_signature(self.key, "GET", self.sign_host, self.path, self.secret)) params.update(create_signature(self.key, "GET", self.sign_host, self.path, self.secret))
return self.send_packet(params) return self.send_packet(params)
def on_login(self, packet): def on_login(self, packet):
@ -832,7 +832,7 @@ class HbdmWebsocketApiBase(WebsocketClient):
@staticmethod @staticmethod
def unpack_data(data): def unpack_data(data):
"""""" """"""
return json.loads(zlib.decompress(data, 31)) return json.loads(zlib.decompress(data, 31))
def on_packet(self, packet): def on_packet(self, packet):
"""""" """"""
@ -851,17 +851,17 @@ class HbdmWebsocketApiBase(WebsocketClient):
return self.on_login() return self.on_login()
else: else:
self.on_data(packet) self.on_data(packet)
def on_data(self, packet): def on_data(self, packet):
"""""" """"""
print("data : {}".format(packet)) print("data : {}".format(packet))
def on_error_msg(self, packet): def on_error_msg(self, packet):
"""""" """"""
msg = packet["err-msg"] msg = packet["err-msg"]
if msg == "invalid pong": if msg == "invalid pong":
return return
self.gateway.write_log(packet["err-msg"]) self.gateway.write_log(packet["err-msg"])
@ -900,7 +900,7 @@ class HbdmTradeWebsocketApi(HbdmWebsocketApiBase):
op = packet.get("op", None) op = packet.get("op", None)
if op != "notify": if op != "notify":
return return
topic = packet["topic"] topic = packet["topic"]
if "orders" in topic: if "orders" in topic:
self.on_order(packet) self.on_order(packet)
@ -930,7 +930,7 @@ class HbdmTradeWebsocketApi(HbdmWebsocketApiBase):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_order(order) self.gateway.on_order(order)
# Push trade event # Push trade event
trades = data["trade"] trades = data["trade"]
if not trades: if not trades:
@ -951,7 +951,7 @@ class HbdmTradeWebsocketApi(HbdmWebsocketApiBase):
volume=d["trade_volume"], volume=d["trade_volume"],
time=time, time=time,
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
) )
self.gateway.on_trade(trade) self.gateway.on_trade(trade)
@ -974,7 +974,7 @@ class HbdmDataWebsocketApi(HbdmWebsocketApiBase):
for ws_symbol in self.ticks.keys(): for ws_symbol in self.ticks.keys():
self.subscribe_data(ws_symbol) self.subscribe_data(ws_symbol)
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
"""""" """"""
contract_type = symbol_type_map.get(req.symbol, "") contract_type = symbol_type_map.get(req.symbol, "")
@ -995,25 +995,25 @@ class HbdmDataWebsocketApi(HbdmWebsocketApiBase):
datetime=datetime.now(), datetime=datetime.now(),
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
) )
self.ticks[ws_symbol] = tick self.ticks[ws_symbol] = tick
self.subscribe_data(ws_symbol) self.subscribe_data(ws_symbol)
def subscribe_data(self, ws_symbol: str): def subscribe_data(self, ws_symbol: str):
"""""" """"""
# Subscribe to market depth update # Subscribe to market depth update
self.req_id += 1 self.req_id += 1
req = { req = {
"sub": f"market.{ws_symbol}.depth.step0", "sub": f"market.{ws_symbol}.depth.step0",
"id": str(self.req_id) "id": str(self.req_id)
} }
self.send_packet(req) self.send_packet(req)
# Subscribe to market detail update # Subscribe to market detail update
self.req_id += 1 self.req_id += 1
req = { req = {
"sub": f"market.{ws_symbol}.detail", "sub": f"market.{ws_symbol}.detail",
"id": str(self.req_id) "id": str(self.req_id)
} }
self.send_packet(req) self.send_packet(req)
@ -1040,7 +1040,7 @@ class HbdmDataWebsocketApi(HbdmWebsocketApiBase):
if "bids" not in tick_data or "asks" not in tick_data: if "bids" not in tick_data or "asks" not in tick_data:
print(data) print(data)
return return
bids = tick_data["bids"] bids = tick_data["bids"]
for n in range(5): for n in range(5):
price, volume = bids[n] price, volume = bids[n]
@ -1061,7 +1061,7 @@ class HbdmDataWebsocketApi(HbdmWebsocketApiBase):
ws_symbol = data["ch"].split(".")[1] ws_symbol = data["ch"].split(".")[1]
tick = self.ticks[ws_symbol] tick = self.ticks[ws_symbol]
tick.datetime = datetime.fromtimestamp(data["ts"] / 1000) tick.datetime = datetime.fromtimestamp(data["ts"] / 1000)
tick_data = data["tick"] tick_data = data["tick"]
tick.open_price = tick_data["open"] tick.open_price = tick_data["open"]
tick.high_price = tick_data["high"] tick.high_price = tick_data["high"]
@ -1100,16 +1100,16 @@ def create_signature(api_key, method, host, path, secret_key, get_params=None):
sorted_params.extend(list(get_params.items())) sorted_params.extend(list(get_params.items()))
sorted_params = list(sorted(sorted_params)) sorted_params = list(sorted(sorted_params))
encode_params = urllib.parse.urlencode(sorted_params) encode_params = urllib.parse.urlencode(sorted_params)
payload = [method, host, path, encode_params] payload = [method, host, path, encode_params]
payload = "\n".join(payload) payload = "\n".join(payload)
payload = payload.encode(encoding="UTF8") payload = payload.encode(encoding="UTF8")
secret_key = secret_key.encode(encoding="UTF8") secret_key = secret_key.encode(encoding="UTF8")
digest = hmac.new(secret_key, payload, digestmod=hashlib.sha256).digest() digest = hmac.new(secret_key, payload, digestmod=hashlib.sha256).digest()
signature = base64.b64encode(digest) signature = base64.b64encode(digest)
params = dict(sorted_params) params = dict(sorted_params)
params["Signature"] = signature.decode("UTF8") params["Signature"] = signature.decode("UTF8")
return params return params

View File

@ -373,7 +373,7 @@ class HuobiRestApi(RestClient):
name = f"{base_currency.upper()}/{quote_currency.upper()}" name = f"{base_currency.upper()}/{quote_currency.upper()}"
pricetick = 1 / pow(10, d["price-precision"]) pricetick = 1 / pow(10, d["price-precision"])
min_volume = 1 / pow(10, d["amount-precision"]) min_volume = 1 / pow(10, d["amount-precision"])
contract = ContractData( contract = ContractData(
symbol=d["symbol"], symbol=d["symbol"],
exchange=Exchange.HUOBI, exchange=Exchange.HUOBI,
@ -433,13 +433,13 @@ class HuobiRestApi(RestClient):
cancel_request = request.extra cancel_request = request.extra
local_orderid = cancel_request.orderid local_orderid = cancel_request.orderid
order = self.order_manager.get_order_with_local_orderid(local_orderid) order = self.order_manager.get_order_with_local_orderid(local_orderid)
if self.check_error(data, "撤单"): if self.check_error(data, "撤单"):
order.status = Status.REJECTED order.status = Status.REJECTED
else: else:
order.status = Status.CANCELLED order.status = Status.CANCELLED
self.gateway.write_log(f"委托撤单成功:{order.orderid}") self.gateway.write_log(f"委托撤单成功:{order.orderid}")
self.order_manager.on_order(order) self.order_manager.on_order(order)
def on_error( def on_error(
@ -459,7 +459,7 @@ class HuobiRestApi(RestClient):
"""""" """"""
if data["status"] != "error": if data["status"] != "error":
return False return False
error_code = data["err-code"] error_code = data["err-code"]
error_msg = data["err-msg"] error_msg = data["err-msg"]
@ -483,17 +483,17 @@ class HuobiWebsocketApiBase(WebsocketClient):
self.path = "" self.path = ""
def connect( def connect(
self, self,
key: str, key: str,
secret: str, secret: str,
url: str, url: str,
proxy_host: str, proxy_host: str,
proxy_port: int proxy_port: int
): ):
"""""" """"""
self.key = key self.key = key
self.secret = secret self.secret = secret
host, path = _split_url(url) host, path = _split_url(url)
self.sign_host = host self.sign_host = host
self.path = path self.path = path
@ -504,7 +504,7 @@ class HuobiWebsocketApiBase(WebsocketClient):
def login(self): def login(self):
"""""" """"""
params = {"op": "auth"} params = {"op": "auth"}
params.update(create_signature(self.key, "GET", self.sign_host, self.path, self.secret)) params.update(create_signature(self.key, "GET", self.sign_host, self.path, self.secret))
return self.send_packet(params) return self.send_packet(params)
def on_login(self, packet): def on_login(self, packet):
@ -514,7 +514,7 @@ class HuobiWebsocketApiBase(WebsocketClient):
@staticmethod @staticmethod
def unpack_data(data): def unpack_data(data):
"""""" """"""
return json.loads(zlib.decompress(data, 31)) return json.loads(zlib.decompress(data, 31))
def on_packet(self, packet): def on_packet(self, packet):
"""""" """"""
@ -533,17 +533,17 @@ class HuobiWebsocketApiBase(WebsocketClient):
return self.on_login() return self.on_login()
else: else:
self.on_data(packet) self.on_data(packet)
def on_data(self, packet): def on_data(self, packet):
"""""" """"""
print("data : {}".format(packet)) print("data : {}".format(packet))
def on_error_msg(self, packet): def on_error_msg(self, packet):
"""""" """"""
msg = packet["err-msg"] msg = packet["err-msg"]
if msg == "invalid pong": if msg == "invalid pong":
return return
self.gateway.write_log(packet["err-msg"]) self.gateway.write_log(packet["err-msg"])
@ -586,7 +586,7 @@ class HuobiTradeWebsocketApi(HuobiWebsocketApiBase):
op = packet.get("op", None) op = packet.get("op", None)
if op != "notify": if op != "notify":
return return
topic = packet["topic"] topic = packet["topic"]
if "orders" in topic: if "orders" in topic:
self.on_order(packet["data"]) self.on_order(packet["data"])
@ -594,19 +594,19 @@ class HuobiTradeWebsocketApi(HuobiWebsocketApiBase):
def on_order(self, data: dict): def on_order(self, data: dict):
"""""" """"""
sys_orderid = str(data["order-id"]) sys_orderid = str(data["order-id"])
order = self.order_manager.get_order_with_sys_orderid(sys_orderid) order = self.order_manager.get_order_with_sys_orderid(sys_orderid)
if not order: if not order:
self.order_manager.add_push_data(sys_orderid, data) self.order_manager.add_push_data(sys_orderid, data)
return return
traded_volume = float(data["filled-amount"]) traded_volume = float(data["filled-amount"])
# Push order event # Push order event
order.traded += traded_volume order.traded += traded_volume
order.status = STATUS_HUOBI2VT.get(data["order-state"], None) order.status = STATUS_HUOBI2VT.get(data["order-state"], None)
self.order_manager.on_order(order) self.order_manager.on_order(order)
# Push trade event # Push trade event
if not traded_volume: if not traded_volume:
return return
@ -621,7 +621,7 @@ class HuobiTradeWebsocketApi(HuobiWebsocketApiBase):
volume=float(data["filled-amount"]), volume=float(data["filled-amount"]),
time=datetime.now().strftime("%H:%M:%S"), time=datetime.now().strftime("%H:%M:%S"),
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
) )
self.gateway.on_trade(trade) self.gateway.on_trade(trade)
@ -642,7 +642,7 @@ class HuobiDataWebsocketApi(HuobiWebsocketApiBase):
def on_connected(self): def on_connected(self):
"""""" """"""
self.gateway.write_log("行情Websocket API连接成功") self.gateway.write_log("行情Websocket API连接成功")
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
"""""" """"""
symbol = req.symbol symbol = req.symbol
@ -655,21 +655,21 @@ class HuobiDataWebsocketApi(HuobiWebsocketApiBase):
datetime=datetime.now(), datetime=datetime.now(),
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
) )
self.ticks[symbol] = tick self.ticks[symbol] = tick
# Subscribe to market depth update # Subscribe to market depth update
self.req_id += 1 self.req_id += 1
req = { req = {
"sub": f"market.{symbol}.depth.step0", "sub": f"market.{symbol}.depth.step0",
"id": str(self.req_id) "id": str(self.req_id)
} }
self.send_packet(req) self.send_packet(req)
# Subscribe to market detail update # Subscribe to market detail update
self.req_id += 1 self.req_id += 1
req = { req = {
"sub": f"market.{symbol}.detail", "sub": f"market.{symbol}.detail",
"id": str(self.req_id) "id": str(self.req_id)
} }
self.send_packet(req) self.send_packet(req)
@ -691,7 +691,7 @@ class HuobiDataWebsocketApi(HuobiWebsocketApiBase):
symbol = data["ch"].split(".")[1] symbol = data["ch"].split(".")[1]
tick = self.ticks[symbol] tick = self.ticks[symbol]
tick.datetime = datetime.fromtimestamp(data["ts"] / 1000) tick.datetime = datetime.fromtimestamp(data["ts"] / 1000)
bids = data["tick"]["bids"] bids = data["tick"]["bids"]
for n in range(5): for n in range(5):
price, volume = bids[n] price, volume = bids[n]
@ -712,7 +712,7 @@ class HuobiDataWebsocketApi(HuobiWebsocketApiBase):
symbol = data["ch"].split(".")[1] symbol = data["ch"].split(".")[1]
tick = self.ticks[symbol] tick = self.ticks[symbol]
tick.datetime = datetime.fromtimestamp(data["ts"] / 1000) tick.datetime = datetime.fromtimestamp(data["ts"] / 1000)
tick_data = data["tick"] tick_data = data["tick"]
tick.open_price = float(tick_data["open"]) tick.open_price = float(tick_data["open"])
tick.high_price = float(tick_data["high"]) tick.high_price = float(tick_data["high"])
@ -751,16 +751,16 @@ def create_signature(api_key, method, host, path, secret_key, get_params=None):
sorted_params.extend(list(get_params.items())) sorted_params.extend(list(get_params.items()))
sorted_params = list(sorted(sorted_params)) sorted_params = list(sorted(sorted_params))
encode_params = urllib.parse.urlencode(sorted_params) encode_params = urllib.parse.urlencode(sorted_params)
payload = [method, host, path, encode_params] payload = [method, host, path, encode_params]
payload = "\n".join(payload) payload = "\n".join(payload)
payload = payload.encode(encoding="UTF8") payload = payload.encode(encoding="UTF8")
secret_key = secret_key.encode(encoding="UTF8") secret_key = secret_key.encode(encoding="UTF8")
digest = hmac.new(secret_key, payload, digestmod=hashlib.sha256).digest() digest = hmac.new(secret_key, payload, digestmod=hashlib.sha256).digest()
signature = base64.b64encode(digest) signature = base64.b64encode(digest)
params = dict(sorted_params) params = dict(sorted_params)
params["Signature"] = signature.decode("UTF8") params["Signature"] = signature.decode("UTF8")
return params return params

View File

@ -13,7 +13,7 @@ from vnpy.api.mini import (
THOST_FTDC_OST_PartTradedQueueing, THOST_FTDC_OST_PartTradedQueueing,
THOST_FTDC_OST_AllTraded, THOST_FTDC_OST_AllTraded,
THOST_FTDC_OST_Canceled, THOST_FTDC_OST_Canceled,
THOST_FTDC_D_Buy, THOST_FTDC_D_Buy,
THOST_FTDC_D_Sell, THOST_FTDC_D_Sell,
THOST_FTDC_PD_Long, THOST_FTDC_PD_Long,
THOST_FTDC_PD_Short, THOST_FTDC_PD_Short,
@ -73,7 +73,7 @@ STATUS_MINI2VT = {
} }
DIRECTION_VT2MINI = { DIRECTION_VT2MINI = {
Direction.LONG: THOST_FTDC_D_Buy, Direction.LONG: THOST_FTDC_D_Buy,
Direction.SHORT: THOST_FTDC_D_Sell Direction.SHORT: THOST_FTDC_D_Sell
} }
DIRECTION_MINI2VT = {v: k for k, v in DIRECTION_VT2MINI.items()} DIRECTION_MINI2VT = {v: k for k, v in DIRECTION_VT2MINI.items()}
@ -81,13 +81,13 @@ DIRECTION_MINI2VT[THOST_FTDC_PD_Long] = Direction.LONG
DIRECTION_MINI2VT[THOST_FTDC_PD_Short] = Direction.SHORT DIRECTION_MINI2VT[THOST_FTDC_PD_Short] = Direction.SHORT
ORDERTYPE_VT2MINI = { ORDERTYPE_VT2MINI = {
OrderType.LIMIT: THOST_FTDC_OPT_LimitPrice, OrderType.LIMIT: THOST_FTDC_OPT_LimitPrice,
OrderType.MARKET: THOST_FTDC_OPT_AnyPrice OrderType.MARKET: THOST_FTDC_OPT_AnyPrice
} }
ORDERTYPE_MINI2VT = {v: k for k, v in ORDERTYPE_VT2MINI.items()} ORDERTYPE_MINI2VT = {v: k for k, v in ORDERTYPE_VT2MINI.items()}
OFFSET_VT2MINI = { OFFSET_VT2MINI = {
Offset.OPEN: THOST_FTDC_OF_Open, Offset.OPEN: THOST_FTDC_OF_Open,
Offset.CLOSE: THOST_FTDC_OFEN_Close, Offset.CLOSE: THOST_FTDC_OFEN_Close,
Offset.CLOSETODAY: THOST_FTDC_OFEN_CloseToday, Offset.CLOSETODAY: THOST_FTDC_OFEN_CloseToday,
Offset.CLOSEYESTERDAY: THOST_FTDC_OFEN_CloseYesterday, Offset.CLOSEYESTERDAY: THOST_FTDC_OFEN_CloseYesterday,
@ -136,7 +136,7 @@ class MiniGateway(BaseGateway):
} }
exchanges = list(EXCHANGE_MINI2VT.values()) exchanges = list(EXCHANGE_MINI2VT.values())
def __init__(self, event_engine): def __init__(self, event_engine):
"""Constructor""" """Constructor"""
super().__init__(event_engine, "MINI") super().__init__(event_engine, "MINI")
@ -154,15 +154,15 @@ class MiniGateway(BaseGateway):
appid = setting["产品名称"] appid = setting["产品名称"]
auth_code = setting["授权编码"] auth_code = setting["授权编码"]
product_info = setting["产品信息"] product_info = setting["产品信息"]
if not td_address.startswith("tcp://"): if not td_address.startswith("tcp://"):
td_address = "tcp://" + td_address td_address = "tcp://" + td_address
if not md_address.startswith("tcp://"): if not md_address.startswith("tcp://"):
md_address = "tcp://" + md_address md_address = "tcp://" + md_address
self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid, product_info) self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid, product_info)
self.md_api.connect(md_address, userid, password, brokerid) self.md_api.connect(md_address, userid, password, brokerid)
self.init_query() self.init_query()
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
@ -195,19 +195,19 @@ class MiniGateway(BaseGateway):
error_id = error["ErrorID"] error_id = error["ErrorID"]
error_msg = error["ErrorMsg"] error_msg = error["ErrorMsg"]
msg = f"{msg},代码:{error_id},信息:{error_msg}" msg = f"{msg},代码:{error_id},信息:{error_msg}"
self.write_log(msg) self.write_log(msg)
def process_timer_event(self, event): def process_timer_event(self, event):
"""""" """"""
self.count += 1 self.count += 1
if self.count < 2: if self.count < 2:
return return
self.count = 0 self.count = 0
func = self.query_functions.pop(0) func = self.query_functions.pop(0)
func() func()
self.query_functions.append(func) self.query_functions.append(func)
def init_query(self): def init_query(self):
"""""" """"""
self.count = 0 self.count = 0
@ -221,20 +221,20 @@ class MiniMdApi(MdApi):
def __init__(self, gateway): def __init__(self, gateway):
"""Constructor""" """Constructor"""
super(MiniMdApi, self).__init__() super(MiniMdApi, self).__init__()
self.gateway = gateway self.gateway = gateway
self.gateway_name = gateway.gateway_name self.gateway_name = gateway.gateway_name
self.reqid = 0 self.reqid = 0
self.connect_status = False self.connect_status = False
self.login_status = False self.login_status = False
self.subscribed = set() self.subscribed = set()
self.userid = "" self.userid = ""
self.password = "" self.password = ""
self.brokerid = "" self.brokerid = ""
def onFrontConnected(self): def onFrontConnected(self):
""" """
Callback when front server is connected. Callback when front server is connected.
@ -256,23 +256,23 @@ class MiniMdApi(MdApi):
if not error["ErrorID"]: if not error["ErrorID"]:
self.login_status = True self.login_status = True
self.gateway.write_log("行情服务器登录成功") self.gateway.write_log("行情服务器登录成功")
for symbol in self.subscribed: for symbol in self.subscribed:
self.subscribeMarketData(symbol) self.subscribeMarketData(symbol)
else: else:
self.gateway.write_error("行情服务器登录失败", error) self.gateway.write_error("行情服务器登录失败", error)
def onRspError(self, error: dict, reqid: int, last: bool): def onRspError(self, error: dict, reqid: int, last: bool):
""" """
Callback when error occured. Callback when error occured.
""" """
self.gateway.write_error("行情接口报错", error) self.gateway.write_error("行情接口报错", error)
def onRspSubMarketData(self, data: dict, error: dict, reqid: int, last: bool): def onRspSubMarketData(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not error or not error["ErrorID"]: if not error or not error["ErrorID"]:
return return
self.gateway.write_error("行情订阅失败", error) self.gateway.write_error("行情订阅失败", error)
def onRtnDepthMarketData(self, data: dict): def onRtnDepthMarketData(self, data: dict):
@ -283,9 +283,9 @@ class MiniMdApi(MdApi):
exchange = symbol_exchange_map.get(symbol, "") exchange = symbol_exchange_map.get(symbol, "")
if not exchange: if not exchange:
return return
timestamp = f"{data['ActionDay']} {data['UpdateTime']}.{int(data['UpdateMillisec']/100)}" timestamp = f"{data['ActionDay']} {data['UpdateTime']}.{int(data['UpdateMillisec']/100)}"
tick = TickData( tick = TickData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -328,7 +328,7 @@ class MiniMdApi(MdApi):
tick.ask_volume_4 = data["AskVolume4"] tick.ask_volume_4 = data["AskVolume4"]
tick.ask_volume_5 = data["AskVolume5"] tick.ask_volume_5 = data["AskVolume5"]
self.gateway.on_tick(tick) self.gateway.on_tick(tick)
def connect(self, address: str, userid: str, password: str, brokerid: int): def connect(self, address: str, userid: str, password: str, brokerid: int):
""" """
@ -337,12 +337,12 @@ class MiniMdApi(MdApi):
self.userid = userid self.userid = userid
self.password = password self.password = password
self.brokerid = brokerid self.brokerid = brokerid
# If not connected, then start connection first. # If not connected, then start connection first.
if not self.connect_status: if not self.connect_status:
path = get_folder_path(self.gateway_name.lower()) path = get_folder_path(self.gateway_name.lower())
self.createFtdcMdApi(str(path) + "\\Md") self.createFtdcMdApi(str(path) + "\\Md")
self.registerFront(address) self.registerFront(address)
self.init() self.init()
@ -350,7 +350,7 @@ class MiniMdApi(MdApi):
# If already connected, then login immediately. # If already connected, then login immediately.
elif not self.login_status: elif not self.login_status:
self.login() self.login()
def login(self): def login(self):
""" """
Login onto server. Login onto server.
@ -360,10 +360,10 @@ class MiniMdApi(MdApi):
"Password": self.password, "Password": self.password,
"BrokerID": self.brokerid "BrokerID": self.brokerid
} }
self.reqid += 1 self.reqid += 1
self.reqUserLogin(req, self.reqid) self.reqUserLogin(req, self.reqid)
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
""" """
Subscribe to tick data update. Subscribe to tick data update.
@ -371,7 +371,7 @@ class MiniMdApi(MdApi):
if self.login_status: if self.login_status:
self.subscribeMarketData(req.symbol) self.subscribeMarketData(req.symbol)
self.subscribed.add(req.symbol) self.subscribed.add(req.symbol)
def close(self): def close(self):
""" """
Close the connection. Close the connection.
@ -386,47 +386,47 @@ class MiniTdApi(TdApi):
def __init__(self, gateway): def __init__(self, gateway):
"""Constructor""" """Constructor"""
super(MiniTdApi, self).__init__() super(MiniTdApi, self).__init__()
self.gateway = gateway self.gateway = gateway
self.gateway_name = gateway.gateway_name self.gateway_name = gateway.gateway_name
self.reqid = 0 self.reqid = 0
self.order_ref = 0 self.order_ref = 0
self.connect_status = False self.connect_status = False
self.login_status = False self.login_status = False
self.auth_staus = False self.auth_staus = False
self.login_failed = False self.login_failed = False
self.userid = "" self.userid = ""
self.password = "" self.password = ""
self.brokerid = "" self.brokerid = ""
self.auth_code = "" self.auth_code = ""
self.appid = "" self.appid = ""
self.product_info = "" self.product_info = ""
self.frontid = 0 self.frontid = 0
self.sessionid = 0 self.sessionid = 0
self.order_data = [] self.order_data = []
self.trade_data = [] self.trade_data = []
self.positions = {} self.positions = {}
self.sysid_orderid_map = {} self.sysid_orderid_map = {}
def onFrontConnected(self): def onFrontConnected(self):
"""""" """"""
self.gateway.write_log("交易服务器连接成功") self.gateway.write_log("交易服务器连接成功")
if self.auth_code: if self.auth_code:
self.authenticate() self.authenticate()
else: else:
self.login() self.login()
def onFrontDisconnected(self, reason: int): def onFrontDisconnected(self, reason: int):
"""""" """"""
self.login_status = False self.login_status = False
self.gateway.write_log(f"交易服务器连接断开,原因{reason}") self.gateway.write_log(f"交易服务器连接断开,原因{reason}")
def onRspAuthenticate(self, data: dict, error: dict, reqid: int, last: bool): def onRspAuthenticate(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not error['ErrorID']: if not error['ErrorID']:
@ -435,7 +435,7 @@ class MiniTdApi(TdApi):
self.login() self.login()
else: else:
self.gateway.write_error("交易服务器授权验证失败", error) self.gateway.write_error("交易服务器授权验证失败", error)
def onRspUserLogin(self, data: dict, error: dict, reqid: int, last: bool): def onRspUserLogin(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not error["ErrorID"]: if not error["ErrorID"]:
@ -443,23 +443,23 @@ class MiniTdApi(TdApi):
self.sessionid = data["SessionID"] self.sessionid = data["SessionID"]
self.login_status = True self.login_status = True
self.gateway.write_log("交易服务器登录成功") self.gateway.write_log("交易服务器登录成功")
# Get instrument data directly without confirm settlement # Get instrument data directly without confirm settlement
self.reqid += 1 self.reqid += 1
self.reqQryInstrument({}, self.reqid) self.reqQryInstrument({}, self.reqid)
else: else:
self.login_failed = True self.login_failed = True
self.gateway.write_error("交易服务器登录失败", error) self.gateway.write_error("交易服务器登录失败", error)
def onRspOrderInsert(self, data: dict, error: dict, reqid: int, last: bool): def onRspOrderInsert(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
order_ref = data["OrderRef"] order_ref = data["OrderRef"]
orderid = f"{self.frontid}_{self.sessionid}_{order_ref}" orderid = f"{self.frontid}_{self.sessionid}_{order_ref}"
symbol = data["InstrumentID"] symbol = data["InstrumentID"]
exchange = symbol_exchange_map[symbol] exchange = symbol_exchange_map[symbol]
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -472,23 +472,23 @@ class MiniTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_order(order) self.gateway.on_order(order)
self.gateway.write_error("交易委托失败", error) self.gateway.write_error("交易委托失败", error)
def onRspOrderAction(self, data: dict, error: dict, reqid: int, last: bool): def onRspOrderAction(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
self.gateway.write_error("交易撤单失败", error) self.gateway.write_error("交易撤单失败", error)
def onRspQueryMaxOrderVolume(self, data: dict, error: dict, reqid: int, last: bool): def onRspQueryMaxOrderVolume(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
pass pass
def onRspSettlementInfoConfirm(self, data: dict, error: dict, reqid: int, last: bool): def onRspSettlementInfoConfirm(self, data: dict, error: dict, reqid: int, last: bool):
""" """
Callback of settlment info confimation. Callback of settlment info confimation.
""" """
pass pass
def onRspQryInvestorPosition(self, data: dict, error: dict, reqid: int, last: bool): def onRspQryInvestorPosition(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if data: if data:
@ -503,7 +503,7 @@ class MiniTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.positions[key] = position self.positions[key] = position
# For SHFE position data update # For SHFE position data update
if position.exchange == Exchange.SHFE: if position.exchange == Exchange.SHFE:
if data["YdPosition"] and not data["TodayPosition"]: if data["YdPosition"] and not data["TodayPosition"]:
@ -511,34 +511,34 @@ class MiniTdApi(TdApi):
# For other exchange position data update # For other exchange position data update
else: else:
position.yd_volume = data["Position"] - data["TodayPosition"] position.yd_volume = data["Position"] - data["TodayPosition"]
# Get contract size (spread contract has no size value) # Get contract size (spread contract has no size value)
size = symbol_size_map.get(position.symbol, 0) size = symbol_size_map.get(position.symbol, 0)
# Calculate previous position cost # Calculate previous position cost
cost = position.price * position.volume * size cost = position.price * position.volume * size
# Update new position volume # Update new position volume
position.volume += data["Position"] position.volume += data["Position"]
position.pnl += data["PositionProfit"] position.pnl += data["PositionProfit"]
# Calculate average position price # Calculate average position price
if position.volume and size: if position.volume and size:
cost += data["PositionCost"] cost += data["PositionCost"]
position.price = cost / (position.volume * size) position.price = cost / (position.volume * size)
# Get frozen volume # Get frozen volume
if position.direction == Direction.LONG: if position.direction == Direction.LONG:
position.frozen += data["ShortFrozen"] position.frozen += data["ShortFrozen"]
else: else:
position.frozen += data["LongFrozen"] position.frozen += data["LongFrozen"]
if last: if last:
for position in self.positions.values(): for position in self.positions.values():
self.gateway.on_position(position) self.gateway.on_position(position)
self.positions.clear() self.positions.clear()
def onRspQryTradingAccount(self, data: dict, error: dict, reqid: int, last: bool): def onRspQryTradingAccount(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if "AccountID" not in data: if "AccountID" not in data:
@ -551,15 +551,15 @@ class MiniTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
account.available = data["Available"] account.available = data["Available"]
self.gateway.on_account(account) self.gateway.on_account(account)
def onRspQryInstrument(self, data: dict, error: dict, reqid: int, last: bool): def onRspQryInstrument(self, data: dict, error: dict, reqid: int, last: bool):
""" """
Callback of instrument query. Callback of instrument query.
""" """
product = PRODUCT_MINI2VT.get(data.get("ProductClass", None), None) product = PRODUCT_MINI2VT.get(data.get("ProductClass", None), None)
if product: if product:
contract = ContractData( contract = ContractData(
symbol=data["InstrumentID"], symbol=data["InstrumentID"],
exchange=EXCHANGE_MINI2VT[data["ExchangeID"]], exchange=EXCHANGE_MINI2VT[data["ExchangeID"]],
@ -569,31 +569,31 @@ class MiniTdApi(TdApi):
pricetick=data["PriceTick"], pricetick=data["PriceTick"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
# For option only # For option only
if contract.product == Product.OPTION: if contract.product == Product.OPTION:
contract.option_underlying = data["UnderlyingInstrID"], contract.option_underlying = data["UnderlyingInstrID"],
contract.option_type = OPTIONTYPE_MINI2VT.get(data["OptionsType"], None), contract.option_type = OPTIONTYPE_MINI2VT.get(data["OptionsType"], None),
contract.option_strike = data["StrikePrice"], contract.option_strike = data["StrikePrice"],
contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"), contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"),
self.gateway.on_contract(contract) self.gateway.on_contract(contract)
symbol_exchange_map[contract.symbol] = contract.exchange symbol_exchange_map[contract.symbol] = contract.exchange
symbol_name_map[contract.symbol] = contract.name symbol_name_map[contract.symbol] = contract.name
symbol_size_map[contract.symbol] = contract.size symbol_size_map[contract.symbol] = contract.size
if last: if last:
self.gateway.write_log("合约信息查询成功") self.gateway.write_log("合约信息查询成功")
for data in self.order_data: for data in self.order_data:
self.onRtnOrder(data) self.onRtnOrder(data)
self.order_data.clear() self.order_data.clear()
for data in self.trade_data: for data in self.trade_data:
self.onRtnTrade(data) self.onRtnTrade(data)
self.trade_data.clear() self.trade_data.clear()
def onRtnOrder(self, data: dict): def onRtnOrder(self, data: dict):
""" """
Callback of order status update. Callback of order status update.
@ -603,12 +603,12 @@ class MiniTdApi(TdApi):
if not exchange: if not exchange:
self.order_data.append(data) self.order_data.append(data)
return return
frontid = data["FrontID"] frontid = data["FrontID"]
sessionid = data["SessionID"] sessionid = data["SessionID"]
order_ref = data["OrderRef"] order_ref = data["OrderRef"]
orderid = f"{frontid}_{sessionid}_{order_ref}" orderid = f"{frontid}_{sessionid}_{order_ref}"
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -624,9 +624,9 @@ class MiniTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_order(order) self.gateway.on_order(order)
self.sysid_orderid_map[data["OrderSysID"]] = orderid self.sysid_orderid_map[data["OrderSysID"]] = orderid
def onRtnTrade(self, data: dict): def onRtnTrade(self, data: dict):
""" """
Callback of trade status update. Callback of trade status update.
@ -638,7 +638,7 @@ class MiniTdApi(TdApi):
return return
orderid = self.sysid_orderid_map[data["OrderSysID"]] orderid = self.sysid_orderid_map[data["OrderSysID"]]
trade = TradeData( trade = TradeData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -651,15 +651,15 @@ class MiniTdApi(TdApi):
time=data["TradeTime"], time=data["TradeTime"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_trade(trade) self.gateway.on_trade(trade)
def connect( def connect(
self, self,
address: str, address: str,
userid: str, userid: str,
password: str, password: str,
brokerid: int, brokerid: int,
auth_code: str, auth_code: str,
appid: str, appid: str,
product_info product_info
): ):
@ -672,21 +672,21 @@ class MiniTdApi(TdApi):
self.auth_code = auth_code self.auth_code = auth_code
self.appid = appid self.appid = appid
self.product_info = product_info self.product_info = product_info
if not self.connect_status: if not self.connect_status:
path = get_folder_path(self.gateway_name.lower()) path = get_folder_path(self.gateway_name.lower())
self.createFtdcTraderApi(str(path) + "\\Td") self.createFtdcTraderApi(str(path) + "\\Td")
self.subscribePrivateTopic(0) self.subscribePrivateTopic(0)
self.subscribePublicTopic(0) self.subscribePublicTopic(0)
self.registerFront(address) self.registerFront(address)
self.init() self.init()
self.connect_status = True self.connect_status = True
else: else:
self.authenticate() self.authenticate()
def authenticate(self): def authenticate(self):
""" """
Authenticate with auth_code and appid. Authenticate with auth_code and appid.
@ -700,10 +700,10 @@ class MiniTdApi(TdApi):
if self.product_info: if self.product_info:
req["UserProductInfo"] = self.product_info req["UserProductInfo"] = self.product_info
self.reqid += 1 self.reqid += 1
self.reqAuthenticate(req, self.reqid) self.reqAuthenticate(req, self.reqid)
def login(self): def login(self):
""" """
Login onto server. Login onto server.
@ -720,16 +720,16 @@ class MiniTdApi(TdApi):
if self.product_info: if self.product_info:
req["UserProductInfo"] = self.product_info req["UserProductInfo"] = self.product_info
self.reqid += 1 self.reqid += 1
self.reqUserLogin(req, self.reqid) self.reqUserLogin(req, self.reqid)
def send_order(self, req: OrderRequest): def send_order(self, req: OrderRequest):
""" """
Send new order. Send new order.
""" """
self.order_ref += 1 self.order_ref += 1
mini_req = { mini_req = {
"InstrumentID": req.symbol, "InstrumentID": req.symbol,
"ExchangeID": req.exchange.value, "ExchangeID": req.exchange.value,
@ -750,7 +750,7 @@ class MiniTdApi(TdApi):
"VolumeCondition": THOST_FTDC_VC_AV, "VolumeCondition": THOST_FTDC_VC_AV,
"MinVolume": 1 "MinVolume": 1
} }
if req.type == OrderType.FAK: if req.type == OrderType.FAK:
mini_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice mini_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice
mini_req["TimeCondition"] = THOST_FTDC_TC_IOC mini_req["TimeCondition"] = THOST_FTDC_TC_IOC
@ -758,23 +758,23 @@ class MiniTdApi(TdApi):
elif req.type == OrderType.FOK: elif req.type == OrderType.FOK:
mini_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice mini_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice
mini_req["TimeCondition"] = THOST_FTDC_TC_IOC mini_req["TimeCondition"] = THOST_FTDC_TC_IOC
mini_req["VolumeCondition"] = THOST_FTDC_VC_CV mini_req["VolumeCondition"] = THOST_FTDC_VC_CV
self.reqid += 1 self.reqid += 1
self.reqOrderInsert(mini_req, self.reqid) self.reqOrderInsert(mini_req, self.reqid)
orderid = f"{self.frontid}_{self.sessionid}_{self.order_ref}" orderid = f"{self.frontid}_{self.sessionid}_{self.order_ref}"
order = req.create_order_data(orderid, self.gateway_name) order = req.create_order_data(orderid, self.gateway_name)
self.gateway.on_order(order) self.gateway.on_order(order)
return order.vt_orderid return order.vt_orderid
def cancel_order(self, req: CancelRequest): def cancel_order(self, req: CancelRequest):
""" """
Cancel existing order. Cancel existing order.
""" """
frontid, sessionid, order_ref = req.orderid.split("_") frontid, sessionid, order_ref = req.orderid.split("_")
mini_req = { mini_req = {
"InstrumentID": req.symbol, "InstrumentID": req.symbol,
"ExchangeID": req.exchange.value, "ExchangeID": req.exchange.value,
@ -785,32 +785,32 @@ class MiniTdApi(TdApi):
"BrokerID": self.brokerid, "BrokerID": self.brokerid,
"InvestorID": self.userid "InvestorID": self.userid
} }
self.reqid += 1 self.reqid += 1
self.reqOrderAction(mini_req, self.reqid) self.reqOrderAction(mini_req, self.reqid)
def query_account(self): def query_account(self):
""" """
Query account balance data. Query account balance data.
""" """
self.reqid += 1 self.reqid += 1
self.reqQryTradingAccount({}, self.reqid) self.reqQryTradingAccount({}, self.reqid)
def query_position(self): def query_position(self):
""" """
Query position holding data. Query position holding data.
""" """
if not symbol_exchange_map: if not symbol_exchange_map:
return return
req = { req = {
"BrokerID": self.brokerid, "BrokerID": self.brokerid,
"InvestorID": self.userid "InvestorID": self.userid
} }
self.reqid += 1 self.reqid += 1
self.reqQryInvestorPosition(req, self.reqid) self.reqQryInvestorPosition(req, self.reqid)
def close(self): def close(self):
"""""" """"""
if self.connect_status: if self.connect_status:

View File

@ -11,7 +11,7 @@ from vnpy.api.mini import (
THOST_FTDC_OST_PartTradedQueueing, THOST_FTDC_OST_PartTradedQueueing,
THOST_FTDC_OST_AllTraded, THOST_FTDC_OST_AllTraded,
THOST_FTDC_OST_Canceled, THOST_FTDC_OST_Canceled,
THOST_FTDC_D_Buy, THOST_FTDC_D_Buy,
THOST_FTDC_D_Sell, THOST_FTDC_D_Sell,
THOST_FTDC_PD_Long, THOST_FTDC_PD_Long,
THOST_FTDC_PD_Short, THOST_FTDC_PD_Short,
@ -73,7 +73,7 @@ STATUS_MINI2VT = {
} }
DIRECTION_VT2MINI = { DIRECTION_VT2MINI = {
Direction.LONG: THOST_FTDC_D_Buy, Direction.LONG: THOST_FTDC_D_Buy,
Direction.SHORT: THOST_FTDC_D_Sell Direction.SHORT: THOST_FTDC_D_Sell
} }
DIRECTION_MINI2VT = {v: k for k, v in DIRECTION_VT2MINI.items()} DIRECTION_MINI2VT = {v: k for k, v in DIRECTION_VT2MINI.items()}
@ -81,13 +81,13 @@ DIRECTION_MINI2VT[THOST_FTDC_PD_Long] = Direction.LONG
DIRECTION_MINI2VT[THOST_FTDC_PD_Short] = Direction.SHORT DIRECTION_MINI2VT[THOST_FTDC_PD_Short] = Direction.SHORT
ORDERTYPE_VT2MINI = { ORDERTYPE_VT2MINI = {
OrderType.LIMIT: THOST_FTDC_OPT_LimitPrice, OrderType.LIMIT: THOST_FTDC_OPT_LimitPrice,
OrderType.MARKET: THOST_FTDC_OPT_AnyPrice OrderType.MARKET: THOST_FTDC_OPT_AnyPrice
} }
ORDERTYPE_MINI2VT = {v: k for k, v in ORDERTYPE_VT2MINI.items()} ORDERTYPE_MINI2VT = {v: k for k, v in ORDERTYPE_VT2MINI.items()}
OFFSET_VT2MINI = { OFFSET_VT2MINI = {
Offset.OPEN: THOST_FTDC_OF_Open, Offset.OPEN: THOST_FTDC_OF_Open,
Offset.CLOSE: THOST_FTDC_OFEN_Close, Offset.CLOSE: THOST_FTDC_OFEN_Close,
Offset.CLOSETODAY: THOST_FTDC_OFEN_CloseToday, Offset.CLOSETODAY: THOST_FTDC_OFEN_CloseToday,
Offset.CLOSEYESTERDAY: THOST_FTDC_OFEN_CloseYesterday, Offset.CLOSEYESTERDAY: THOST_FTDC_OFEN_CloseYesterday,
@ -136,7 +136,7 @@ class MinitestGateway(BaseGateway):
} }
exchanges = list(EXCHANGE_MINI2VT.values()) exchanges = list(EXCHANGE_MINI2VT.values())
def __init__(self, event_engine): def __init__(self, event_engine):
"""Constructor""" """Constructor"""
super().__init__(event_engine, "MINITEST") super().__init__(event_engine, "MINITEST")
@ -154,15 +154,15 @@ class MinitestGateway(BaseGateway):
appid = setting["产品名称"] appid = setting["产品名称"]
auth_code = setting["授权编码"] auth_code = setting["授权编码"]
product_info = setting["产品信息"] product_info = setting["产品信息"]
if not td_address.startswith("tcp://"): if not td_address.startswith("tcp://"):
td_address = "tcp://" + td_address td_address = "tcp://" + td_address
if not md_address.startswith("tcp://"): if not md_address.startswith("tcp://"):
md_address = "tcp://" + md_address md_address = "tcp://" + md_address
self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid, product_info) self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid, product_info)
self.md_api.connect(md_address, userid, password, brokerid) self.md_api.connect(md_address, userid, password, brokerid)
self.init_query() self.init_query()
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
@ -195,19 +195,19 @@ class MinitestGateway(BaseGateway):
error_id = error["ErrorID"] error_id = error["ErrorID"]
error_msg = error["ErrorMsg"] error_msg = error["ErrorMsg"]
msg = f"{msg},代码:{error_id},信息:{error_msg}" msg = f"{msg},代码:{error_id},信息:{error_msg}"
self.write_log(msg) self.write_log(msg)
def process_timer_event(self, event): def process_timer_event(self, event):
"""""" """"""
self.count += 1 self.count += 1
if self.count < 2: if self.count < 2:
return return
self.count = 0 self.count = 0
func = self.query_functions.pop(0) func = self.query_functions.pop(0)
func() func()
self.query_functions.append(func) self.query_functions.append(func)
def init_query(self): def init_query(self):
"""""" """"""
self.count = 0 self.count = 0
@ -221,20 +221,20 @@ class MiniMdApi(MdApi):
def __init__(self, gateway): def __init__(self, gateway):
"""Constructor""" """Constructor"""
super(MiniMdApi, self).__init__() super(MiniMdApi, self).__init__()
self.gateway = gateway self.gateway = gateway
self.gateway_name = gateway.gateway_name self.gateway_name = gateway.gateway_name
self.reqid = 0 self.reqid = 0
self.connect_status = False self.connect_status = False
self.login_status = False self.login_status = False
self.subscribed = set() self.subscribed = set()
self.userid = "" self.userid = ""
self.password = "" self.password = ""
self.brokerid = "" self.brokerid = ""
def onFrontConnected(self): def onFrontConnected(self):
""" """
Callback when front server is connected. Callback when front server is connected.
@ -256,23 +256,23 @@ class MiniMdApi(MdApi):
if not error["ErrorID"]: if not error["ErrorID"]:
self.login_status = True self.login_status = True
self.gateway.write_log("行情服务器登录成功") self.gateway.write_log("行情服务器登录成功")
for symbol in self.subscribed: for symbol in self.subscribed:
self.subscribeMarketData(symbol) self.subscribeMarketData(symbol)
else: else:
self.gateway.write_error("行情服务器登录失败", error) self.gateway.write_error("行情服务器登录失败", error)
def onRspError(self, error: dict, reqid: int, last: bool): def onRspError(self, error: dict, reqid: int, last: bool):
""" """
Callback when error occured. Callback when error occured.
""" """
self.gateway.write_error("行情接口报错", error) self.gateway.write_error("行情接口报错", error)
def onRspSubMarketData(self, data: dict, error: dict, reqid: int, last: bool): def onRspSubMarketData(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not error or not error["ErrorID"]: if not error or not error["ErrorID"]:
return return
self.gateway.write_error("行情订阅失败", error) self.gateway.write_error("行情订阅失败", error)
def onRtnDepthMarketData(self, data: dict): def onRtnDepthMarketData(self, data: dict):
@ -283,9 +283,9 @@ class MiniMdApi(MdApi):
exchange = symbol_exchange_map.get(symbol, "") exchange = symbol_exchange_map.get(symbol, "")
if not exchange: if not exchange:
return return
timestamp = f"{data['ActionDay']} {data['UpdateTime']}.{int(data['UpdateMillisec']/100)}" timestamp = f"{data['ActionDay']} {data['UpdateTime']}.{int(data['UpdateMillisec']/100)}"
tick = TickData( tick = TickData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -328,7 +328,7 @@ class MiniMdApi(MdApi):
tick.ask_volume_4 = data["AskVolume4"] tick.ask_volume_4 = data["AskVolume4"]
tick.ask_volume_5 = data["AskVolume5"] tick.ask_volume_5 = data["AskVolume5"]
self.gateway.on_tick(tick) self.gateway.on_tick(tick)
def connect(self, address: str, userid: str, password: str, brokerid: int): def connect(self, address: str, userid: str, password: str, brokerid: int):
""" """
@ -337,12 +337,12 @@ class MiniMdApi(MdApi):
self.userid = userid self.userid = userid
self.password = password self.password = password
self.brokerid = brokerid self.brokerid = brokerid
# If not connected, then start connection first. # If not connected, then start connection first.
if not self.connect_status: if not self.connect_status:
path = get_folder_path(self.gateway_name.lower()) path = get_folder_path(self.gateway_name.lower())
self.createFtdcMdApi(str(path) + "\\Md") self.createFtdcMdApi(str(path) + "\\Md")
self.registerFront(address) self.registerFront(address)
self.init() self.init()
@ -350,7 +350,7 @@ class MiniMdApi(MdApi):
# If already connected, then login immediately. # If already connected, then login immediately.
elif not self.login_status: elif not self.login_status:
self.login() self.login()
def login(self): def login(self):
""" """
Login onto server. Login onto server.
@ -360,10 +360,10 @@ class MiniMdApi(MdApi):
"Password": self.password, "Password": self.password,
"BrokerID": self.brokerid "BrokerID": self.brokerid
} }
self.reqid += 1 self.reqid += 1
self.reqUserLogin(req, self.reqid) self.reqUserLogin(req, self.reqid)
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
""" """
Subscribe to tick data update. Subscribe to tick data update.
@ -371,7 +371,7 @@ class MiniMdApi(MdApi):
if self.login_status: if self.login_status:
self.subscribeMarketData(req.symbol) self.subscribeMarketData(req.symbol)
self.subscribed.add(req.symbol) self.subscribed.add(req.symbol)
def close(self): def close(self):
""" """
Close the connection. Close the connection.
@ -386,47 +386,47 @@ class MiniTdApi(TdApi):
def __init__(self, gateway): def __init__(self, gateway):
"""Constructor""" """Constructor"""
super(MiniTdApi, self).__init__() super(MiniTdApi, self).__init__()
self.gateway = gateway self.gateway = gateway
self.gateway_name = gateway.gateway_name self.gateway_name = gateway.gateway_name
self.reqid = 0 self.reqid = 0
self.order_ref = 0 self.order_ref = 0
self.connect_status = False self.connect_status = False
self.login_status = False self.login_status = False
self.auth_staus = False self.auth_staus = False
self.login_failed = False self.login_failed = False
self.userid = "" self.userid = ""
self.password = "" self.password = ""
self.brokerid = "" self.brokerid = ""
self.auth_code = "" self.auth_code = ""
self.appid = "" self.appid = ""
self.product_info = "" self.product_info = ""
self.frontid = 0 self.frontid = 0
self.sessionid = 0 self.sessionid = 0
self.order_data = [] self.order_data = []
self.trade_data = [] self.trade_data = []
self.positions = {} self.positions = {}
self.sysid_orderid_map = {} self.sysid_orderid_map = {}
def onFrontConnected(self): def onFrontConnected(self):
"""""" """"""
self.gateway.write_log("交易服务器连接成功") self.gateway.write_log("交易服务器连接成功")
if self.auth_code: if self.auth_code:
self.authenticate() self.authenticate()
else: else:
self.login() self.login()
def onFrontDisconnected(self, reason: int): def onFrontDisconnected(self, reason: int):
"""""" """"""
self.login_status = False self.login_status = False
self.gateway.write_log(f"交易服务器连接断开,原因{reason}") self.gateway.write_log(f"交易服务器连接断开,原因{reason}")
def onRspAuthenticate(self, data: dict, error: dict, reqid: int, last: bool): def onRspAuthenticate(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not error['ErrorID']: if not error['ErrorID']:
@ -435,7 +435,7 @@ class MiniTdApi(TdApi):
self.login() self.login()
else: else:
self.gateway.write_error("交易服务器授权验证失败", error) self.gateway.write_error("交易服务器授权验证失败", error)
def onRspUserLogin(self, data: dict, error: dict, reqid: int, last: bool): def onRspUserLogin(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not error["ErrorID"]: if not error["ErrorID"]:
@ -443,23 +443,23 @@ class MiniTdApi(TdApi):
self.sessionid = data["SessionID"] self.sessionid = data["SessionID"]
self.login_status = True self.login_status = True
self.gateway.write_log("交易服务器登录成功") self.gateway.write_log("交易服务器登录成功")
# Get instrument data directly without confirm settlement # Get instrument data directly without confirm settlement
self.reqid += 1 self.reqid += 1
self.reqQryInstrument({}, self.reqid) self.reqQryInstrument({}, self.reqid)
else: else:
self.login_failed = True self.login_failed = True
self.gateway.write_error("交易服务器登录失败", error) self.gateway.write_error("交易服务器登录失败", error)
def onRspOrderInsert(self, data: dict, error: dict, reqid: int, last: bool): def onRspOrderInsert(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
order_ref = data["OrderRef"] order_ref = data["OrderRef"]
orderid = f"{self.frontid}_{self.sessionid}_{order_ref}" orderid = f"{self.frontid}_{self.sessionid}_{order_ref}"
symbol = data["InstrumentID"] symbol = data["InstrumentID"]
exchange = symbol_exchange_map[symbol] exchange = symbol_exchange_map[symbol]
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -472,23 +472,23 @@ class MiniTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_order(order) self.gateway.on_order(order)
self.gateway.write_error("交易委托失败", error) self.gateway.write_error("交易委托失败", error)
def onRspOrderAction(self, data: dict, error: dict, reqid: int, last: bool): def onRspOrderAction(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
self.gateway.write_error("交易撤单失败", error) self.gateway.write_error("交易撤单失败", error)
def onRspQueryMaxOrderVolume(self, data: dict, error: dict, reqid: int, last: bool): def onRspQueryMaxOrderVolume(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
pass pass
def onRspSettlementInfoConfirm(self, data: dict, error: dict, reqid: int, last: bool): def onRspSettlementInfoConfirm(self, data: dict, error: dict, reqid: int, last: bool):
""" """
Callback of settlment info confimation. Callback of settlment info confimation.
""" """
pass pass
def onRspQryInvestorPosition(self, data: dict, error: dict, reqid: int, last: bool): def onRspQryInvestorPosition(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if data: if data:
@ -503,7 +503,7 @@ class MiniTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.positions[key] = position self.positions[key] = position
# For SHFE position data update # For SHFE position data update
if position.exchange == Exchange.SHFE: if position.exchange == Exchange.SHFE:
if data["YdPosition"] and not data["TodayPosition"]: if data["YdPosition"] and not data["TodayPosition"]:
@ -511,34 +511,34 @@ class MiniTdApi(TdApi):
# For other exchange position data update # For other exchange position data update
else: else:
position.yd_volume = data["Position"] - data["TodayPosition"] position.yd_volume = data["Position"] - data["TodayPosition"]
# Get contract size (spread contract has no size value) # Get contract size (spread contract has no size value)
size = symbol_size_map.get(position.symbol, 0) size = symbol_size_map.get(position.symbol, 0)
# Calculate previous position cost # Calculate previous position cost
cost = position.price * position.volume * size cost = position.price * position.volume * size
# Update new position volume # Update new position volume
position.volume += data["Position"] position.volume += data["Position"]
position.pnl += data["PositionProfit"] position.pnl += data["PositionProfit"]
# Calculate average position price # Calculate average position price
if position.volume and size: if position.volume and size:
cost += data["PositionCost"] cost += data["PositionCost"]
position.price = cost / (position.volume * size) position.price = cost / (position.volume * size)
# Get frozen volume # Get frozen volume
if position.direction == Direction.LONG: if position.direction == Direction.LONG:
position.frozen += data["ShortFrozen"] position.frozen += data["ShortFrozen"]
else: else:
position.frozen += data["LongFrozen"] position.frozen += data["LongFrozen"]
if last: if last:
for position in self.positions.values(): for position in self.positions.values():
self.gateway.on_position(position) self.gateway.on_position(position)
self.positions.clear() self.positions.clear()
def onRspQryTradingAccount(self, data: dict, error: dict, reqid: int, last: bool): def onRspQryTradingAccount(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if "AccountID" not in data: if "AccountID" not in data:
@ -551,15 +551,15 @@ class MiniTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
account.available = data["Available"] account.available = data["Available"]
self.gateway.on_account(account) self.gateway.on_account(account)
def onRspQryInstrument(self, data: dict, error: dict, reqid: int, last: bool): def onRspQryInstrument(self, data: dict, error: dict, reqid: int, last: bool):
""" """
Callback of instrument query. Callback of instrument query.
""" """
product = PRODUCT_MINI2VT.get(data.get("ProductClass", None), None) product = PRODUCT_MINI2VT.get(data.get("ProductClass", None), None)
if product: if product:
contract = ContractData( contract = ContractData(
symbol=data["InstrumentID"], symbol=data["InstrumentID"],
exchange=EXCHANGE_MINI2VT[data["ExchangeID"]], exchange=EXCHANGE_MINI2VT[data["ExchangeID"]],
@ -569,31 +569,31 @@ class MiniTdApi(TdApi):
pricetick=data["PriceTick"], pricetick=data["PriceTick"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
# For option only # For option only
if contract.product == Product.OPTION: if contract.product == Product.OPTION:
contract.option_underlying = data["UnderlyingInstrID"], contract.option_underlying = data["UnderlyingInstrID"],
contract.option_type = OPTIONTYPE_MINI2VT.get(data["OptionsType"], None), contract.option_type = OPTIONTYPE_MINI2VT.get(data["OptionsType"], None),
contract.option_strike = data["StrikePrice"], contract.option_strike = data["StrikePrice"],
contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"), contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"),
self.gateway.on_contract(contract) self.gateway.on_contract(contract)
symbol_exchange_map[contract.symbol] = contract.exchange symbol_exchange_map[contract.symbol] = contract.exchange
symbol_name_map[contract.symbol] = contract.name symbol_name_map[contract.symbol] = contract.name
symbol_size_map[contract.symbol] = contract.size symbol_size_map[contract.symbol] = contract.size
if last: if last:
self.gateway.write_log("合约信息查询成功") self.gateway.write_log("合约信息查询成功")
for data in self.order_data: for data in self.order_data:
self.onRtnOrder(data) self.onRtnOrder(data)
self.order_data.clear() self.order_data.clear()
for data in self.trade_data: for data in self.trade_data:
self.onRtnTrade(data) self.onRtnTrade(data)
self.trade_data.clear() self.trade_data.clear()
def onRtnOrder(self, data: dict): def onRtnOrder(self, data: dict):
""" """
Callback of order status update. Callback of order status update.
@ -603,12 +603,12 @@ class MiniTdApi(TdApi):
if not exchange: if not exchange:
self.order_data.append(data) self.order_data.append(data)
return return
frontid = data["FrontID"] frontid = data["FrontID"]
sessionid = data["SessionID"] sessionid = data["SessionID"]
order_ref = data["OrderRef"] order_ref = data["OrderRef"]
orderid = f"{frontid}_{sessionid}_{order_ref}" orderid = f"{frontid}_{sessionid}_{order_ref}"
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -624,9 +624,9 @@ class MiniTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_order(order) self.gateway.on_order(order)
self.sysid_orderid_map[data["OrderSysID"]] = orderid self.sysid_orderid_map[data["OrderSysID"]] = orderid
def onRtnTrade(self, data: dict): def onRtnTrade(self, data: dict):
""" """
Callback of trade status update. Callback of trade status update.
@ -638,7 +638,7 @@ class MiniTdApi(TdApi):
return return
orderid = self.sysid_orderid_map[data["OrderSysID"]] orderid = self.sysid_orderid_map[data["OrderSysID"]]
trade = TradeData( trade = TradeData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -651,15 +651,15 @@ class MiniTdApi(TdApi):
time=data["TradeTime"], time=data["TradeTime"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_trade(trade) self.gateway.on_trade(trade)
def connect( def connect(
self, self,
address: str, address: str,
userid: str, userid: str,
password: str, password: str,
brokerid: int, brokerid: int,
auth_code: str, auth_code: str,
appid: str, appid: str,
product_info product_info
): ):
@ -672,21 +672,21 @@ class MiniTdApi(TdApi):
self.auth_code = auth_code self.auth_code = auth_code
self.appid = appid self.appid = appid
self.product_info = product_info self.product_info = product_info
if not self.connect_status: if not self.connect_status:
path = get_folder_path(self.gateway_name.lower()) path = get_folder_path(self.gateway_name.lower())
self.createFtdcTraderApi(str(path) + "\\Td") self.createFtdcTraderApi(str(path) + "\\Td")
self.subscribePrivateTopic(0) self.subscribePrivateTopic(0)
self.subscribePublicTopic(0) self.subscribePublicTopic(0)
self.registerFront(address) self.registerFront(address)
self.init() self.init()
self.connect_status = True self.connect_status = True
else: else:
self.authenticate() self.authenticate()
def authenticate(self): def authenticate(self):
""" """
Authenticate with auth_code and appid. Authenticate with auth_code and appid.
@ -700,10 +700,10 @@ class MiniTdApi(TdApi):
if self.product_info: if self.product_info:
req["UserProductInfo"] = self.product_info req["UserProductInfo"] = self.product_info
self.reqid += 1 self.reqid += 1
self.reqAuthenticate(req, self.reqid) self.reqAuthenticate(req, self.reqid)
def login(self): def login(self):
""" """
Login onto server. Login onto server.
@ -720,16 +720,16 @@ class MiniTdApi(TdApi):
if self.product_info: if self.product_info:
req["UserProductInfo"] = self.product_info req["UserProductInfo"] = self.product_info
self.reqid += 1 self.reqid += 1
self.reqUserLogin(req, self.reqid) self.reqUserLogin(req, self.reqid)
def send_order(self, req: OrderRequest): def send_order(self, req: OrderRequest):
""" """
Send new order. Send new order.
""" """
self.order_ref += 1 self.order_ref += 1
mini_req = { mini_req = {
"InstrumentID": req.symbol, "InstrumentID": req.symbol,
"ExchangeID": req.exchange.value, "ExchangeID": req.exchange.value,
@ -750,7 +750,7 @@ class MiniTdApi(TdApi):
"VolumeCondition": THOST_FTDC_VC_AV, "VolumeCondition": THOST_FTDC_VC_AV,
"MinVolume": 1 "MinVolume": 1
} }
if req.type == OrderType.FAK: if req.type == OrderType.FAK:
mini_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice mini_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice
mini_req["TimeCondition"] = THOST_FTDC_TC_IOC mini_req["TimeCondition"] = THOST_FTDC_TC_IOC
@ -758,23 +758,23 @@ class MiniTdApi(TdApi):
elif req.type == OrderType.FOK: elif req.type == OrderType.FOK:
mini_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice mini_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice
mini_req["TimeCondition"] = THOST_FTDC_TC_IOC mini_req["TimeCondition"] = THOST_FTDC_TC_IOC
mini_req["VolumeCondition"] = THOST_FTDC_VC_CV mini_req["VolumeCondition"] = THOST_FTDC_VC_CV
self.reqid += 1 self.reqid += 1
self.reqOrderInsert(mini_req, self.reqid) self.reqOrderInsert(mini_req, self.reqid)
orderid = f"{self.frontid}_{self.sessionid}_{self.order_ref}" orderid = f"{self.frontid}_{self.sessionid}_{self.order_ref}"
order = req.create_order_data(orderid, self.gateway_name) order = req.create_order_data(orderid, self.gateway_name)
self.gateway.on_order(order) self.gateway.on_order(order)
return order.vt_orderid return order.vt_orderid
def cancel_order(self, req: CancelRequest): def cancel_order(self, req: CancelRequest):
""" """
Cancel existing order. Cancel existing order.
""" """
frontid, sessionid, order_ref = req.orderid.split("_") frontid, sessionid, order_ref = req.orderid.split("_")
mini_req = { mini_req = {
"InstrumentID": req.symbol, "InstrumentID": req.symbol,
"ExchangeID": req.exchange.value, "ExchangeID": req.exchange.value,
@ -785,32 +785,32 @@ class MiniTdApi(TdApi):
"BrokerID": self.brokerid, "BrokerID": self.brokerid,
"InvestorID": self.userid "InvestorID": self.userid
} }
self.reqid += 1 self.reqid += 1
self.reqOrderAction(mini_req, self.reqid) self.reqOrderAction(mini_req, self.reqid)
def query_account(self): def query_account(self):
""" """
Query account balance data. Query account balance data.
""" """
self.reqid += 1 self.reqid += 1
self.reqQryTradingAccount({}, self.reqid) self.reqQryTradingAccount({}, self.reqid)
def query_position(self): def query_position(self):
""" """
Query position holding data. Query position holding data.
""" """
if not symbol_exchange_map: if not symbol_exchange_map:
return return
req = { req = {
"BrokerID": self.brokerid, "BrokerID": self.brokerid,
"InvestorID": self.userid "InvestorID": self.userid
} }
self.reqid += 1 self.reqid += 1
self.reqQryInvestorPosition(req, self.reqid) self.reqQryInvestorPosition(req, self.reqid)
def close(self): def close(self):
"""""" """"""
if self.connect_status: if self.connect_status:

View File

@ -84,7 +84,7 @@ class OkexfGateway(BaseGateway):
"API Key": "", "API Key": "",
"Secret Key": "", "Secret Key": "",
"Passphrase": "", "Passphrase": "",
"Leverage": 10, "Leverage": 10,
"会话数": 3, "会话数": 3,
"代理地址": "", "代理地址": "",
"代理端口": "", "代理端口": "",
@ -248,7 +248,7 @@ class OkexfRestApi(RestClient):
return "" return ""
orderid = f"a{self.connect_time}{self._new_order_id()}" orderid = f"a{self.connect_time}{self._new_order_id()}"
data = { data = {
"client_oid": orderid, "client_oid": orderid,
"type": TYPE_VT2OKEXF[(req.offset, req.direction)], "type": TYPE_VT2OKEXF[(req.offset, req.direction)],
@ -377,8 +377,8 @@ class OkexfRestApi(RestClient):
balance=float(d["equity"]), balance=float(d["equity"]),
frozen=float(d.get("margin_for_unfilled", 0)), frozen=float(d.get("margin_for_unfilled", 0)),
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
) )
self.gateway.on_account(account) self.gateway.on_account(account)
self.gateway.write_log("账户资金查询成功") self.gateway.write_log("账户资金查询成功")
def on_query_position(self, data, request): def on_query_position(self, data, request):
@ -447,7 +447,7 @@ class OkexfRestApi(RestClient):
order = request.extra order = request.extra
order.status = Status.REJECTED order.status = Status.REJECTED
order.time = datetime.now().strftime("%H:%M:%S.%f") order.time = datetime.now().strftime("%H:%M:%S.%f")
self.gateway.on_order(order) self.gateway.on_order(order)
msg = f"委托失败,状态码:{status_code},信息:{request.response.text}" msg = f"委托失败,状态码:{status_code},信息:{request.response.text}"
self.gateway.write_log(msg) self.gateway.write_log(msg)
@ -532,12 +532,12 @@ class OkexfRestApi(RestClient):
for i in range(10): for i in range(10):
path = f"/api/futures/v3/instruments/{req.symbol}/candles" path = f"/api/futures/v3/instruments/{req.symbol}/candles"
# Create query params # Create query params
params = { params = {
"granularity": INTERVAL_VT2OKEXF[req.interval] "granularity": INTERVAL_VT2OKEXF[req.interval]
} }
if end_time: if end_time:
params["end"] = end_time params["end"] = end_time
@ -586,7 +586,7 @@ class OkexfRestApi(RestClient):
index = list(buf.keys()) index = list(buf.keys())
index.sort() index.sort()
history = [buf[i] for i in index] history = [buf[i] for i in index]
return history return history
@ -886,7 +886,7 @@ class OkexfWebsocketApi(WebsocketClient):
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
) )
self.gateway.on_position(pos) self.gateway.on_position(pos)
def generate_signature(msg: str, secret_key: str): def generate_signature(msg: str, secret_key: str):
"""OKEX V3 signature""" """OKEX V3 signature"""
@ -901,6 +901,6 @@ def get_timestamp():
def utc_to_local(timestamp): def utc_to_local(timestamp):
time = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ") time = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ")
utc_time = time + timedelta(hours=8) utc_time = time + timedelta(hours=8)
return utc_time return utc_time

View File

@ -13,7 +13,7 @@ from vnpy.api.sopt import (
THOST_FTDC_OST_PartTradedQueueing, THOST_FTDC_OST_PartTradedQueueing,
THOST_FTDC_OST_AllTraded, THOST_FTDC_OST_AllTraded,
THOST_FTDC_OST_Canceled, THOST_FTDC_OST_Canceled,
THOST_FTDC_D_Buy, THOST_FTDC_D_Buy,
THOST_FTDC_D_Sell, THOST_FTDC_D_Sell,
THOST_FTDC_PD_Long, THOST_FTDC_PD_Long,
THOST_FTDC_PD_Short, THOST_FTDC_PD_Short,
@ -73,7 +73,7 @@ STATUS_SOPT2VT = {
} }
DIRECTION_VT2SOPT = { DIRECTION_VT2SOPT = {
Direction.LONG: THOST_FTDC_D_Buy, Direction.LONG: THOST_FTDC_D_Buy,
Direction.SHORT: THOST_FTDC_D_Sell Direction.SHORT: THOST_FTDC_D_Sell
} }
DIRECTION_SOPT2VT = {v: k for k, v in DIRECTION_VT2SOPT.items()} DIRECTION_SOPT2VT = {v: k for k, v in DIRECTION_VT2SOPT.items()}
@ -81,13 +81,13 @@ DIRECTION_SOPT2VT[THOST_FTDC_PD_Long] = Direction.LONG
DIRECTION_SOPT2VT[THOST_FTDC_PD_Short] = Direction.SHORT DIRECTION_SOPT2VT[THOST_FTDC_PD_Short] = Direction.SHORT
ORDERTYPE_VT2SOPT = { ORDERTYPE_VT2SOPT = {
OrderType.LIMIT: THOST_FTDC_OPT_LimitPrice, OrderType.LIMIT: THOST_FTDC_OPT_LimitPrice,
OrderType.MARKET: THOST_FTDC_OPT_AnyPrice OrderType.MARKET: THOST_FTDC_OPT_AnyPrice
} }
ORDERTYPE_SOPT2VT = {v: k for k, v in ORDERTYPE_VT2SOPT.items()} ORDERTYPE_SOPT2VT = {v: k for k, v in ORDERTYPE_VT2SOPT.items()}
OFFSET_VT2SOPT = { OFFSET_VT2SOPT = {
Offset.OPEN: THOST_FTDC_OF_Open, Offset.OPEN: THOST_FTDC_OF_Open,
Offset.CLOSE: THOST_FTDC_OFEN_Close, Offset.CLOSE: THOST_FTDC_OFEN_Close,
Offset.CLOSETODAY: THOST_FTDC_OFEN_CloseToday, Offset.CLOSETODAY: THOST_FTDC_OFEN_CloseToday,
Offset.CLOSEYESTERDAY: THOST_FTDC_OFEN_CloseYesterday, Offset.CLOSEYESTERDAY: THOST_FTDC_OFEN_CloseYesterday,
@ -133,7 +133,7 @@ class SoptGateway(BaseGateway):
} }
exchanges = list(EXCHANGE_SOPT2VT.values()) exchanges = list(EXCHANGE_SOPT2VT.values())
def __init__(self, event_engine): def __init__(self, event_engine):
"""Constructor""" """Constructor"""
super().__init__(event_engine, "SOPT") super().__init__(event_engine, "SOPT")
@ -151,15 +151,15 @@ class SoptGateway(BaseGateway):
appid = setting["产品名称"] appid = setting["产品名称"]
auth_code = setting["授权编码"] auth_code = setting["授权编码"]
product_info = setting["产品信息"] product_info = setting["产品信息"]
if not td_address.startswith("tcp://"): if not td_address.startswith("tcp://"):
td_address = "tcp://" + td_address td_address = "tcp://" + td_address
if not md_address.startswith("tcp://"): if not md_address.startswith("tcp://"):
md_address = "tcp://" + md_address md_address = "tcp://" + md_address
self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid, product_info) self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid, product_info)
self.md_api.connect(md_address, userid, password, brokerid) self.md_api.connect(md_address, userid, password, brokerid)
self.init_query() self.init_query()
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
@ -192,19 +192,19 @@ class SoptGateway(BaseGateway):
error_id = error["ErrorID"] error_id = error["ErrorID"]
error_msg = error["ErrorMsg"] error_msg = error["ErrorMsg"]
msg = f"{msg},代码:{error_id},信息:{error_msg}" msg = f"{msg},代码:{error_id},信息:{error_msg}"
self.write_log(msg) self.write_log(msg)
def process_timer_event(self, event): def process_timer_event(self, event):
"""""" """"""
self.count += 1 self.count += 1
if self.count < 2: if self.count < 2:
return return
self.count = 0 self.count = 0
func = self.query_functions.pop(0) func = self.query_functions.pop(0)
func() func()
self.query_functions.append(func) self.query_functions.append(func)
def init_query(self): def init_query(self):
"""""" """"""
self.count = 0 self.count = 0
@ -218,20 +218,20 @@ class SoptMdApi(MdApi):
def __init__(self, gateway): def __init__(self, gateway):
"""Constructor""" """Constructor"""
super(SoptMdApi, self).__init__() super(SoptMdApi, self).__init__()
self.gateway = gateway self.gateway = gateway
self.gateway_name = gateway.gateway_name self.gateway_name = gateway.gateway_name
self.reqid = 0 self.reqid = 0
self.connect_status = False self.connect_status = False
self.login_status = False self.login_status = False
self.subscribed = set() self.subscribed = set()
self.userid = "" self.userid = ""
self.password = "" self.password = ""
self.brokerid = "" self.brokerid = ""
def onFrontConnected(self): def onFrontConnected(self):
""" """
Callback when front server is connected. Callback when front server is connected.
@ -253,23 +253,23 @@ class SoptMdApi(MdApi):
if not error["ErrorID"]: if not error["ErrorID"]:
self.login_status = True self.login_status = True
self.gateway.write_log("行情服务器登录成功") self.gateway.write_log("行情服务器登录成功")
for symbol in self.subscribed: for symbol in self.subscribed:
self.subscribeMarketData(symbol) self.subscribeMarketData(symbol)
else: else:
self.gateway.write_error("行情服务器登录失败", error) self.gateway.write_error("行情服务器登录失败", error)
def onRspError(self, error: dict, reqid: int, last: bool): def onRspError(self, error: dict, reqid: int, last: bool):
""" """
Callback when error occured. Callback when error occured.
""" """
self.gateway.write_error("行情接口报错", error) self.gateway.write_error("行情接口报错", error)
def onRspSubMarketData(self, data: dict, error: dict, reqid: int, last: bool): def onRspSubMarketData(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not error or not error["ErrorID"]: if not error or not error["ErrorID"]:
return return
self.gateway.write_error("行情订阅失败", error) self.gateway.write_error("行情订阅失败", error)
def onRtnDepthMarketData(self, data: dict): def onRtnDepthMarketData(self, data: dict):
@ -281,7 +281,7 @@ class SoptMdApi(MdApi):
if not exchange: if not exchange:
return return
timestamp = f"{data['TradingDay']} {data['UpdateTime']}.{int(data['UpdateMillisec']/100)}" timestamp = f"{data['TradingDay']} {data['UpdateTime']}.{int(data['UpdateMillisec']/100)}"
tick = TickData( tick = TickData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -302,7 +302,7 @@ class SoptMdApi(MdApi):
ask_volume_1=data["AskVolume1"], ask_volume_1=data["AskVolume1"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_tick(tick) self.gateway.on_tick(tick)
def connect(self, address: str, userid: str, password: str, brokerid: int): def connect(self, address: str, userid: str, password: str, brokerid: int):
""" """
@ -311,7 +311,7 @@ class SoptMdApi(MdApi):
self.userid = userid self.userid = userid
self.password = password self.password = password
self.brokerid = brokerid self.brokerid = brokerid
# If not connected, then start connection first. # If not connected, then start connection first.
if not self.connect_status: if not self.connect_status:
path = get_folder_path(self.gateway_name.lower()) path = get_folder_path(self.gateway_name.lower())
@ -322,7 +322,7 @@ class SoptMdApi(MdApi):
# If already connected, then login immediately. # If already connected, then login immediately.
elif not self.login_status: elif not self.login_status:
self.login() self.login()
def login(self): def login(self):
""" """
Login onto server. Login onto server.
@ -332,10 +332,10 @@ class SoptMdApi(MdApi):
"Password": self.password, "Password": self.password,
"BrokerID": self.brokerid "BrokerID": self.brokerid
} }
self.reqid += 1 self.reqid += 1
self.reqUserLogin(req, self.reqid) self.reqUserLogin(req, self.reqid)
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
""" """
Subscribe to tick data update. Subscribe to tick data update.
@ -343,7 +343,7 @@ class SoptMdApi(MdApi):
if self.login_status: if self.login_status:
self.subscribeMarketData(req.symbol) self.subscribeMarketData(req.symbol)
self.subscribed.add(req.symbol) self.subscribed.add(req.symbol)
def close(self): def close(self):
""" """
Close the connection. Close the connection.
@ -358,49 +358,49 @@ class SoptTdApi(TdApi):
def __init__(self, gateway): def __init__(self, gateway):
"""Constructor""" """Constructor"""
super(SoptTdApi, self).__init__() super(SoptTdApi, self).__init__()
self.test = [] self.test = []
self.gateway = gateway self.gateway = gateway
self.gateway_name = gateway.gateway_name self.gateway_name = gateway.gateway_name
self.reqid = 0 self.reqid = 0
self.order_ref = 0 self.order_ref = 0
self.connect_status = False self.connect_status = False
self.login_status = False self.login_status = False
self.auth_staus = False self.auth_staus = False
self.login_failed = False self.login_failed = False
self.userid = "" self.userid = ""
self.password = "" self.password = ""
self.brokerid = "" self.brokerid = ""
self.auth_code = "" self.auth_code = ""
self.appid = "" self.appid = ""
self.product_info = "" self.product_info = ""
self.frontid = 0 self.frontid = 0
self.sessionid = 0 self.sessionid = 0
self.order_data = [] self.order_data = []
self.trade_data = [] self.trade_data = []
self.positions = {} self.positions = {}
self.sysid_orderid_map = {} self.sysid_orderid_map = {}
def onFrontConnected(self): def onFrontConnected(self):
"""""" """"""
self.gateway.write_log("交易服务器连接成功") self.gateway.write_log("交易服务器连接成功")
if self.auth_code: if self.auth_code:
self.authenticate() self.authenticate()
else: else:
self.login() self.login()
def onFrontDisconnected(self, reason: int): def onFrontDisconnected(self, reason: int):
"""""" """"""
self.login_status = False self.login_status = False
self.gateway.write_log(f"交易服务器连接断开,原因{reason}") self.gateway.write_log(f"交易服务器连接断开,原因{reason}")
def onRspAuthenticate(self, data: dict, error: dict, reqid: int, last: bool): def onRspAuthenticate(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not error['ErrorID']: if not error['ErrorID']:
@ -409,7 +409,7 @@ class SoptTdApi(TdApi):
self.login() self.login()
else: else:
self.gateway.write_error("交易服务器授权验证失败", error) self.gateway.write_error("交易服务器授权验证失败", error)
def onRspUserLogin(self, data: dict, error: dict, reqid: int, last: bool): def onRspUserLogin(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not error["ErrorID"]: if not error["ErrorID"]:
@ -417,7 +417,7 @@ class SoptTdApi(TdApi):
self.sessionid = data["SessionID"] self.sessionid = data["SessionID"]
self.login_status = True self.login_status = True
self.gateway.write_log("交易服务器登录成功") self.gateway.write_log("交易服务器登录成功")
# Confirm settlement # Confirm settlement
req = { req = {
"BrokerID": self.brokerid, "BrokerID": self.brokerid,
@ -427,17 +427,17 @@ class SoptTdApi(TdApi):
self.reqSettlementInfoConfirm(req, self.reqid) self.reqSettlementInfoConfirm(req, self.reqid)
else: else:
self.login_failed = True self.login_failed = True
self.gateway.write_error("交易服务器登录失败", error) self.gateway.write_error("交易服务器登录失败", error)
def onRspOrderInsert(self, data: dict, error: dict, reqid: int, last: bool): def onRspOrderInsert(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
order_ref = data["OrderRef"] order_ref = data["OrderRef"]
orderid = f"{self.frontid}_{self.sessionid}_{order_ref}" orderid = f"{self.frontid}_{self.sessionid}_{order_ref}"
symbol = data["InstrumentID"] symbol = data["InstrumentID"]
exchange = symbol_exchange_map[symbol] exchange = symbol_exchange_map[symbol]
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -450,31 +450,31 @@ class SoptTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_order(order) self.gateway.on_order(order)
self.gateway.write_error("交易委托失败", error) self.gateway.write_error("交易委托失败", error)
def onRspOrderAction(self, data: dict, error: dict, reqid: int, last: bool): def onRspOrderAction(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
self.gateway.write_error("交易撤单失败", error) self.gateway.write_error("交易撤单失败", error)
def onRspQueryMaxOrderVolume(self, data: dict, error: dict, reqid: int, last: bool): def onRspQueryMaxOrderVolume(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
pass pass
def onRspSettlementInfoConfirm(self, data: dict, error: dict, reqid: int, last: bool): def onRspSettlementInfoConfirm(self, data: dict, error: dict, reqid: int, last: bool):
""" """
Callback of settlment info confimation. Callback of settlment info confimation.
""" """
self.gateway.write_log("结算信息确认成功") self.gateway.write_log("结算信息确认成功")
self.reqid += 1 self.reqid += 1
self.reqQryInstrument({}, self.reqid) self.reqQryInstrument({}, self.reqid)
def onRspQryInvestorPosition(self, data: dict, error: dict, reqid: int, last: bool): def onRspQryInvestorPosition(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
if not data: if not data:
return return
# Get buffered position object # Get buffered position object
key = f"{data['InstrumentID'], data['PosiDirection']}" key = f"{data['InstrumentID'], data['PosiDirection']}"
position = self.positions.get(key, None) position = self.positions.get(key, None)
@ -486,7 +486,7 @@ class SoptTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.positions[key] = position self.positions[key] = position
# For SHFE position data update # For SHFE position data update
if position.exchange == Exchange.SHFE: if position.exchange == Exchange.SHFE:
if data["YdPosition"] and not data["TodayPosition"]: if data["YdPosition"] and not data["TodayPosition"]:
@ -494,34 +494,34 @@ class SoptTdApi(TdApi):
# For other exchange position data update # For other exchange position data update
else: else:
position.yd_volume = data["Position"] - data["TodayPosition"] position.yd_volume = data["Position"] - data["TodayPosition"]
# Get contract size (spread contract has no size value) # Get contract size (spread contract has no size value)
size = symbol_size_map.get(position.symbol, 0) size = symbol_size_map.get(position.symbol, 0)
# Calculate previous position cost # Calculate previous position cost
cost = position.price * position.volume * size cost = position.price * position.volume * size
# Update new position volume # Update new position volume
position.volume += data["Position"] position.volume += data["Position"]
position.pnl += data["PositionProfit"] position.pnl += data["PositionProfit"]
# Calculate average position price # Calculate average position price
if position.volume and size: if position.volume and size:
cost += data["PositionCost"] cost += data["PositionCost"]
position.price = cost / (position.volume * size) position.price = cost / (position.volume * size)
# Get frozen volume # Get frozen volume
if position.direction == Direction.LONG: if position.direction == Direction.LONG:
position.frozen += data["ShortFrozen"] position.frozen += data["ShortFrozen"]
else: else:
position.frozen += data["LongFrozen"] position.frozen += data["LongFrozen"]
if last: if last:
for position in self.positions.values(): for position in self.positions.values():
self.gateway.on_position(position) self.gateway.on_position(position)
self.positions.clear() self.positions.clear()
def onRspQryTradingAccount(self, data: dict, error: dict, reqid: int, last: bool): def onRspQryTradingAccount(self, data: dict, error: dict, reqid: int, last: bool):
"""""" """"""
account = AccountData( account = AccountData(
@ -531,16 +531,16 @@ class SoptTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
account.available = data["Available"] account.available = data["Available"]
self.gateway.on_account(account) self.gateway.on_account(account)
def onRspQryInstrument(self, data: dict, error: dict, reqid: int, last: bool): def onRspQryInstrument(self, data: dict, error: dict, reqid: int, last: bool):
""" """
Callback of instrument query. Callback of instrument query.
""" """
product = PRODUCT_SOPT2VT.get(data["ProductClass"], None) product = PRODUCT_SOPT2VT.get(data["ProductClass"], None)
if product: if product:
contract = ContractData( contract = ContractData(
symbol=data["InstrumentID"], symbol=data["InstrumentID"],
exchange=EXCHANGE_SOPT2VT[data["ExchangeID"]], exchange=EXCHANGE_SOPT2VT[data["ExchangeID"]],
@ -550,31 +550,31 @@ class SoptTdApi(TdApi):
pricetick=data["PriceTick"], pricetick=data["PriceTick"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
# For option only # For option only
if contract.product == Product.OPTION: if contract.product == Product.OPTION:
contract.option_underlying = data["UnderlyingInstrID"], contract.option_underlying = data["UnderlyingInstrID"],
contract.option_type = OPTIONTYPE_SOPT2VT.get(data["OptionsType"], None), contract.option_type = OPTIONTYPE_SOPT2VT.get(data["OptionsType"], None),
contract.option_strike = data["StrikePrice"], contract.option_strike = data["StrikePrice"],
contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"), contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"),
self.gateway.on_contract(contract) self.gateway.on_contract(contract)
symbol_exchange_map[contract.symbol] = contract.exchange symbol_exchange_map[contract.symbol] = contract.exchange
symbol_name_map[contract.symbol] = contract.name symbol_name_map[contract.symbol] = contract.name
symbol_size_map[contract.symbol] = contract.size symbol_size_map[contract.symbol] = contract.size
if last: if last:
self.gateway.write_log("合约信息查询成功") self.gateway.write_log("合约信息查询成功")
for data in self.order_data: for data in self.order_data:
self.onRtnOrder(data) self.onRtnOrder(data)
self.order_data.clear() self.order_data.clear()
for data in self.trade_data: for data in self.trade_data:
self.onRtnTrade(data) self.onRtnTrade(data)
self.trade_data.clear() self.trade_data.clear()
def onRtnOrder(self, data: dict): def onRtnOrder(self, data: dict):
""" """
Callback of order status update. Callback of order status update.
@ -584,12 +584,12 @@ class SoptTdApi(TdApi):
if not exchange: if not exchange:
self.order_data.append(data) self.order_data.append(data)
return return
frontid = data["FrontID"] frontid = data["FrontID"]
sessionid = data["SessionID"] sessionid = data["SessionID"]
order_ref = data["OrderRef"] order_ref = data["OrderRef"]
orderid = f"{frontid}_{sessionid}_{order_ref}" orderid = f"{frontid}_{sessionid}_{order_ref}"
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -605,9 +605,9 @@ class SoptTdApi(TdApi):
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_order(order) self.gateway.on_order(order)
self.sysid_orderid_map[data["OrderSysID"]] = orderid self.sysid_orderid_map[data["OrderSysID"]] = orderid
def onRtnTrade(self, data: dict): def onRtnTrade(self, data: dict):
""" """
Callback of trade status update. Callback of trade status update.
@ -619,7 +619,7 @@ class SoptTdApi(TdApi):
return return
orderid = self.sysid_orderid_map[data["OrderSysID"]] orderid = self.sysid_orderid_map[data["OrderSysID"]]
trade = TradeData( trade = TradeData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -632,15 +632,15 @@ class SoptTdApi(TdApi):
time=data["TradeTime"], time=data["TradeTime"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
self.gateway.on_trade(trade) self.gateway.on_trade(trade)
def connect( def connect(
self, self,
address: str, address: str,
userid: str, userid: str,
password: str, password: str,
brokerid: int, brokerid: int,
auth_code: str, auth_code: str,
appid: str, appid: str,
product_info product_info
): ):
@ -653,21 +653,21 @@ class SoptTdApi(TdApi):
self.auth_code = auth_code self.auth_code = auth_code
self.appid = appid self.appid = appid
self.product_info = product_info self.product_info = product_info
if not self.connect_status: if not self.connect_status:
path = get_folder_path(self.gateway_name.lower()) path = get_folder_path(self.gateway_name.lower())
self.createFtdcTraderApi(str(path) + "\\Td") self.createFtdcTraderApi(str(path) + "\\Td")
self.subscribePrivateTopic(0) self.subscribePrivateTopic(0)
self.subscribePublicTopic(0) self.subscribePublicTopic(0)
self.registerFront(address) self.registerFront(address)
self.init() self.init()
self.connect_status = True self.connect_status = True
else: else:
self.authenticate() self.authenticate()
def authenticate(self): def authenticate(self):
""" """
Authenticate with auth_code and appid. Authenticate with auth_code and appid.
@ -681,10 +681,10 @@ class SoptTdApi(TdApi):
if self.product_info: if self.product_info:
req["UserProductInfo"] = self.product_info req["UserProductInfo"] = self.product_info
self.reqid += 1 self.reqid += 1
self.reqAuthenticate(req, self.reqid) self.reqAuthenticate(req, self.reqid)
def login(self): def login(self):
""" """
Login onto server. Login onto server.
@ -701,16 +701,16 @@ class SoptTdApi(TdApi):
if self.product_info: if self.product_info:
req["UserProductInfo"] = self.product_info req["UserProductInfo"] = self.product_info
self.reqid += 1 self.reqid += 1
self.reqUserLogin(req, self.reqid) self.reqUserLogin(req, self.reqid)
def send_order(self, req: OrderRequest): def send_order(self, req: OrderRequest):
""" """
Send new order. Send new order.
""" """
self.order_ref += 1 self.order_ref += 1
sopt_req = { sopt_req = {
"InstrumentID": req.symbol, "InstrumentID": req.symbol,
"ExchangeID": req.exchange.value, "ExchangeID": req.exchange.value,
@ -731,7 +731,7 @@ class SoptTdApi(TdApi):
"VolumeCondition": THOST_FTDC_VC_AV, "VolumeCondition": THOST_FTDC_VC_AV,
"MinVolume": 1 "MinVolume": 1
} }
if req.type == OrderType.FAK: if req.type == OrderType.FAK:
sopt_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice sopt_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice
sopt_req["TimeCondition"] = THOST_FTDC_TC_IOC sopt_req["TimeCondition"] = THOST_FTDC_TC_IOC
@ -739,23 +739,23 @@ class SoptTdApi(TdApi):
elif req.type == OrderType.FOK: elif req.type == OrderType.FOK:
sopt_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice sopt_req["OrderPriceType"] = THOST_FTDC_OPT_LimitPrice
sopt_req["TimeCondition"] = THOST_FTDC_TC_IOC sopt_req["TimeCondition"] = THOST_FTDC_TC_IOC
sopt_req["VolumeCondition"] = THOST_FTDC_VC_CV sopt_req["VolumeCondition"] = THOST_FTDC_VC_CV
self.reqid += 1 self.reqid += 1
self.reqOrderInsert(sopt_req, self.reqid) self.reqOrderInsert(sopt_req, self.reqid)
orderid = f"{self.frontid}_{self.sessionid}_{self.order_ref}" orderid = f"{self.frontid}_{self.sessionid}_{self.order_ref}"
order = req.create_order_data(orderid, self.gateway_name) order = req.create_order_data(orderid, self.gateway_name)
self.gateway.on_order(order) self.gateway.on_order(order)
return order.vt_orderid return order.vt_orderid
def cancel_order(self, req: CancelRequest): def cancel_order(self, req: CancelRequest):
""" """
Cancel existing order. Cancel existing order.
""" """
frontid, sessionid, order_ref = req.orderid.split("_") frontid, sessionid, order_ref = req.orderid.split("_")
sopt_req = { sopt_req = {
"InstrumentID": req.symbol, "InstrumentID": req.symbol,
"Exchange": req.exchange, "Exchange": req.exchange,
@ -766,32 +766,32 @@ class SoptTdApi(TdApi):
"BrokerID": self.brokerid, "BrokerID": self.brokerid,
"InvestorID": self.userid "InvestorID": self.userid
} }
self.reqid += 1 self.reqid += 1
self.reqOrderAction(sopt_req, self.reqid) self.reqOrderAction(sopt_req, self.reqid)
def query_account(self): def query_account(self):
""" """
Query account balance data. Query account balance data.
""" """
self.reqid += 1 self.reqid += 1
self.reqQryTradingAccount({}, self.reqid) self.reqQryTradingAccount({}, self.reqid)
def query_position(self): def query_position(self):
""" """
Query position holding data. Query position holding data.
""" """
if not symbol_exchange_map: if not symbol_exchange_map:
return return
req = { req = {
"BrokerID": self.brokerid, "BrokerID": self.brokerid,
"InvestorID": self.userid "InvestorID": self.userid
} }
self.reqid += 1 self.reqid += 1
self.reqQryInvestorPosition(req, self.reqid) self.reqQryInvestorPosition(req, self.reqid)
def close(self): def close(self):
"""""" """"""
if self.connect_status: if self.connect_status:

View File

@ -222,7 +222,7 @@ class QuoteApi(ITapQuoteAPINotify):
def OnAPIReady(self): def OnAPIReady(self):
""" """
Callback when API is ready for sending requests or queries. Callback when API is ready for sending requests or queries.
""" """
self.api.QryCommodity() self.api.QryCommodity()
@ -400,7 +400,7 @@ class TradeApi(ITapTradeAPINotify):
def OnAPIReady(self, code: int): def OnAPIReady(self, code: int):
""" """
Callback when API is ready for sending requests or queries. Callback when API is ready for sending requests or queries.
""" """
self.api.QryCommodity() self.api.QryCommodity()

View File

@ -512,7 +512,7 @@ class XtpTraderApi(API.TraderSpi):
self.margin_trading = False self.margin_trading = False
self.option_trading = False self.option_trading = False
# #
self.short_positions = {} self.short_positions = {}
def connect( def connect(
@ -791,13 +791,13 @@ class XtpTraderApi(API.TraderSpi):
"""""" """"""
pass pass
def OnQueryCreditDebtInfo(self, debt_info: XTPCrdDebtInfo, error_info: XTPRspInfoStruct, def OnQueryCreditDebtInfo(self, debt_info: XTPCrdDebtInfo, error_info: XTPRspInfoStruct,
request_id: int, is_last: bool, session_id: int) -> Any: request_id: int, is_last: bool, session_id: int) -> Any:
"""""" """"""
if debt_info.debt_type == 1: if debt_info.debt_type == 1:
symbol = debt_info.ticker symbol = debt_info.ticker
exchange = MARKET_XTP2VT[debt_info.market] exchange = MARKET_XTP2VT[debt_info.market]
position = self.short_positions.get(symbol, None) position = self.short_positions.get(symbol, None)
if not position: if not position:
position = PositionData( position = PositionData(
@ -809,7 +809,7 @@ class XtpTraderApi(API.TraderSpi):
self.short_positions[symbol] = position self.short_positions[symbol] = position
position.volume += debt_info.remain_qty position.volume += debt_info.remain_qty
if is_last: if is_last:
for position in self.short_positions.values(): for position in self.short_positions.values():
self.gateway.on_position(position) self.gateway.on_position(position)

View File

@ -293,7 +293,7 @@ class LogEngine(BaseEngine):
def add_file_handler(self): def add_file_handler(self):
""" """
Add file output of log. Add file output of log.
""" """
today_date = datetime.now().strftime("%Y%m%d") today_date = datetime.now().strftime("%Y%m%d")
filename = f"vt_{today_date}.log" filename = f"vt_{today_date}.log"

View File

@ -213,7 +213,7 @@ class BaseGateway(ABC):
def send_orders(self, reqs: Sequence[OrderRequest]): def send_orders(self, reqs: Sequence[OrderRequest]):
""" """
Send a batch of orders to server. Send a batch of orders to server.
Use a for loop of send_order function by default. Use a for loop of send_order function by default.
Reimplement this function if batch order supported on server. Reimplement this function if batch order supported on server.
""" """
vt_orderids = [] vt_orderids = []
@ -227,7 +227,7 @@ class BaseGateway(ABC):
def cancel_orders(self, reqs: Sequence[CancelRequest]): def cancel_orders(self, reqs: Sequence[CancelRequest]):
""" """
Cancel a batch of orders to server. Cancel a batch of orders to server.
Use a for loop of cancel_order function by default. Use a for loop of cancel_order function by default.
Reimplement this function if batch cancel supported on server. Reimplement this function if batch cancel supported on server.
""" """
for req in reqs: for req in reqs:

View File

@ -14,7 +14,7 @@ ACTIVE_STATUSES = set([Status.SUBMITTING, Status.NOTTRADED, Status.PARTTRADED])
@dataclass @dataclass
class BaseData: class BaseData:
""" """
Any data object needs a gateway_name as source Any data object needs a gateway_name as source
and should inherit base data. and should inherit base data.
""" """
@ -102,7 +102,7 @@ class BarData(BaseData):
@dataclass @dataclass
class OrderData(BaseData): class OrderData(BaseData):
""" """
Order data contains information for tracking lastest status Order data contains information for tracking lastest status
of a specific order. of a specific order.
""" """

View File

@ -132,7 +132,7 @@ class PnlCell(BaseCell):
def set_content(self, content: Any, data: Any): def set_content(self, content: Any, data: Any):
""" """
Cell color is set based on whether pnl is Cell color is set based on whether pnl is
positive or negative. positive or negative.
""" """
super(PnlCell, self).set_content(content, data) super(PnlCell, self).set_content(content, data)
@ -993,7 +993,7 @@ class AboutDialog(QtWidgets.QDialog):
text = """ text = """
Developed by Traders, for Traders. Developed by Traders, for Traders.
LicenseMIT LicenseMIT
Websitewww.vnpy.com Websitewww.vnpy.com
Githubwww.github.com/vnpy/vnpy Githubwww.github.com/vnpy/vnpy

View File

@ -119,7 +119,7 @@ def round_to(value: float, target: float):
class BarGenerator: class BarGenerator:
""" """
For: For:
1. generating 1 minute bar data from tick data 1. generating 1 minute bar data from tick data
2. generateing x minute bar/x hour bar data from 1 minute data 2. generateing x minute bar/x hour bar data from 1 minute data