[新功能] 增加砖图数据对象
This commit is contained in:
parent
0426820e78
commit
301326453a
@ -5,6 +5,13 @@ General constant string used in VN Trader.
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Color(Enum):
|
||||
""" Kline color """
|
||||
RED = 'Red'
|
||||
BLUE = 'Blue'
|
||||
EQUAL = 'Equal'
|
||||
|
||||
|
||||
class Direction(Enum):
|
||||
"""
|
||||
Direction of order/trade/position.
|
||||
@ -141,7 +148,9 @@ class Interval(Enum):
|
||||
"""
|
||||
Interval of bar data.
|
||||
"""
|
||||
SECOND = "1s"
|
||||
MINUTE = "1m"
|
||||
HOUR = "1h"
|
||||
DAILY = "d"
|
||||
WEEKLY = "w"
|
||||
RENKO = 'renko'
|
||||
|
@ -24,7 +24,8 @@ class BaseDatabaseManager(ABC):
|
||||
exchange: "Exchange",
|
||||
interval: "Interval",
|
||||
start: datetime,
|
||||
end: datetime
|
||||
end: datetime,
|
||||
**kwargs
|
||||
) -> Sequence["BarData"]:
|
||||
pass
|
||||
|
||||
@ -34,7 +35,8 @@ class BaseDatabaseManager(ABC):
|
||||
symbol: str,
|
||||
exchange: "Exchange",
|
||||
start: datetime,
|
||||
end: datetime
|
||||
end: datetime,
|
||||
**kwargs
|
||||
) -> Sequence["TickData"]:
|
||||
pass
|
||||
|
||||
|
@ -112,6 +112,9 @@ class DbTickData(Document):
|
||||
symbol: str = StringField()
|
||||
exchange: str = StringField()
|
||||
datetime: datetime = DateTimeField()
|
||||
date: str = StringField()
|
||||
time: str = StringField()
|
||||
trading_day: str = StringField()
|
||||
|
||||
name: str = StringField()
|
||||
volume: float = FloatField()
|
||||
@ -170,6 +173,9 @@ class DbTickData(Document):
|
||||
db_tick.symbol = tick.symbol
|
||||
db_tick.exchange = tick.exchange.value
|
||||
db_tick.datetime = tick.datetime
|
||||
db_tick.date = tick.date
|
||||
db_tick.time = tick.time
|
||||
db_tick.trading_day = tick.trading_day
|
||||
db_tick.name = tick.name
|
||||
db_tick.volume = tick.volume
|
||||
db_tick.open_interest = tick.open_interest
|
||||
@ -218,6 +224,9 @@ class DbTickData(Document):
|
||||
symbol=self.symbol,
|
||||
exchange=Exchange(self.exchange),
|
||||
datetime=self.datetime,
|
||||
date=self.date,
|
||||
time=self.time,
|
||||
trading_day=self.trading_day,
|
||||
name=self.name,
|
||||
volume=self.volume,
|
||||
open_interest=self.open_interest,
|
||||
@ -269,6 +278,7 @@ class MongoManager(BaseDatabaseManager):
|
||||
interval: Interval,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
**kwargs
|
||||
) -> Sequence[BarData]:
|
||||
s = DbBarData.objects(
|
||||
symbol=symbol,
|
||||
@ -281,7 +291,7 @@ class MongoManager(BaseDatabaseManager):
|
||||
return data
|
||||
|
||||
def load_tick_data(
|
||||
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
|
||||
self, symbol: str, exchange: Exchange, start: datetime, end: datetime, **kwargs
|
||||
) -> Sequence[TickData]:
|
||||
s = DbTickData.objects(
|
||||
symbol=symbol,
|
||||
|
@ -337,6 +337,7 @@ class SqlManager(BaseDatabaseManager):
|
||||
interval: Interval,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
**kwargs
|
||||
) -> Sequence[BarData]:
|
||||
s = (
|
||||
self.class_bar.select()
|
||||
@ -353,7 +354,7 @@ class SqlManager(BaseDatabaseManager):
|
||||
return data
|
||||
|
||||
def load_tick_data(
|
||||
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
|
||||
self, symbol: str, exchange: Exchange, start: datetime, end: datetime, **kwargs
|
||||
) -> Sequence[TickData]:
|
||||
s = (
|
||||
self.class_tick.select()
|
||||
|
@ -33,6 +33,9 @@ from .object import (
|
||||
from .setting import SETTINGS
|
||||
from .utility import get_folder_path, TRADER_DIR
|
||||
|
||||
# 专有的logger文件
|
||||
from .util_logger import setup_logger
|
||||
|
||||
|
||||
class MainEngine:
|
||||
"""
|
||||
@ -239,6 +242,37 @@ class BaseEngine(ABC):
|
||||
self.event_engine = event_engine
|
||||
self.engine_name = engine_name
|
||||
|
||||
self.logger = None
|
||||
|
||||
def create_logger(self, logger_name: str = 'base_engine'):
|
||||
"""
|
||||
创建engine独有的日志
|
||||
:param logger_name: 日志名,缺省为engine的名称
|
||||
:return:
|
||||
"""
|
||||
log_path = get_folder_path("log")
|
||||
log_filename = os.path.abspath(os.path.join(log_path, logger_name))
|
||||
print(u'create logger:{}'.format(log_filename))
|
||||
self.logger = setup_logger(file_name=log_filename, name=logger_name,
|
||||
log_level=SETTINGS.get('log.level', logging.DEBUG))
|
||||
|
||||
def write_log(self, msg: str, source: str = "", level: int = logging.DEBUG):
|
||||
"""
|
||||
写入日志
|
||||
:param msg: 日志内容
|
||||
:param source: 来源
|
||||
:param level: 日志级别
|
||||
:return:
|
||||
"""
|
||||
if self.logger:
|
||||
if len(source) > 0:
|
||||
msg = f'[{source}]{msg}'
|
||||
self.logger.log(level, msg)
|
||||
else:
|
||||
log = LogData(msg=msg, level=level)
|
||||
event = Event(EVENT_LOG, log)
|
||||
self.event_engine.put(event)
|
||||
|
||||
def close(self):
|
||||
""""""
|
||||
pass
|
||||
|
@ -5,6 +5,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Sequence
|
||||
from copy import copy
|
||||
from logging import INFO
|
||||
|
||||
from vnpy.event import Event, EventEngine
|
||||
from .event import (
|
||||
@ -140,11 +141,11 @@ class BaseGateway(ABC):
|
||||
"""
|
||||
self.on_event(EVENT_CONTRACT, contract)
|
||||
|
||||
def write_log(self, msg: str):
|
||||
def write_log(self, msg: str, level: int = INFO):
|
||||
"""
|
||||
Write a log event from gateway.
|
||||
"""
|
||||
log = LogData(msg=msg, gateway_name=self.gateway_name)
|
||||
log = LogData(msg=msg, level=level, gateway_name=self.gateway_name)
|
||||
self.on_log(log)
|
||||
|
||||
@abstractmethod
|
||||
|
@ -103,6 +103,20 @@ class BarData(BaseData):
|
||||
self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RenkoBarData(BarData):
|
||||
"""
|
||||
Renko bar data of a certain trading period.
|
||||
"""
|
||||
seconds: int = 0 # 当前Bar的秒数(针对RenkoBar)
|
||||
high_seconds: int = -1 # 当前Bar的上限秒数
|
||||
low_seconds: int = -1 # 当前bar的下限秒数
|
||||
height: float = 3 # 当前Bar的高度限制(针对RenkoBar和RangeBar类)
|
||||
up_band: float = 0 # 高位区域的基线
|
||||
down_band: float = 0 # 低位区域的基线
|
||||
low_time = None # 最后一次进入低位区域的时间
|
||||
high_time = None # 最后一次进入高位区域的时间
|
||||
|
||||
@dataclass
|
||||
class OrderData(BaseData):
|
||||
"""
|
||||
|
@ -159,11 +159,19 @@ class MultiprocessHandler(logging. FileHandler):
|
||||
self.handleError(record)
|
||||
|
||||
|
||||
def setup_logger(filename, name=None, debug=False, force=False, backtesing=False):
|
||||
def setup_logger(file_name: str,
|
||||
name: str = None,
|
||||
log_level: int = logging.DEBUG,
|
||||
force: bool = False,
|
||||
backtesing: bool = False):
|
||||
"""
|
||||
设置日志文件,包括路径
|
||||
自动在后面添加 "_日期.log"
|
||||
:param logger_file_name:
|
||||
:param file_name: 日志文件名
|
||||
:param name: logger 名
|
||||
:param log_level: 日志级别
|
||||
:param force: 是否强制更新日志名称
|
||||
:param backtesing: 是否为回测(回测输出的格式不同)
|
||||
:return:
|
||||
"""
|
||||
|
||||
@ -171,21 +179,21 @@ def setup_logger(filename, name=None, debug=False, force=False, backtesing=False
|
||||
global _fileHandler
|
||||
global _logger_filename
|
||||
|
||||
if _logger is not None and _logger_filename == filename and not force:
|
||||
if _logger is not None and _logger_filename == file_name and not force:
|
||||
return _logger
|
||||
|
||||
if _logger_filename != filename or force:
|
||||
if _logger_filename != file_name or force:
|
||||
if force:
|
||||
_logger_filename = filename
|
||||
_logger_filename = file_name
|
||||
|
||||
# 定义日志输出格式
|
||||
fmt = logging.Formatter(RECORD_FORMAT if not backtesing else BACKTEST_FORMAT)
|
||||
if name is None:
|
||||
names = filename.replace('.log', '').split('/')
|
||||
names = file_name.replace('.log', '').split('/')
|
||||
name = names[-1]
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
if debug:
|
||||
if log_level == logging.DEBUG:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
stream_handler.setLevel(logging.DEBUG)
|
||||
@ -193,26 +201,26 @@ def setup_logger(filename, name=None, debug=False, force=False, backtesing=False
|
||||
if not logger.hasHandlers():
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
fileHandler = MultiprocessHandler(filename, encoding='utf8', interval='D')
|
||||
if debug:
|
||||
fileHandler.setLevel(logging.DEBUG)
|
||||
else:
|
||||
fileHandler.setLevel(logging.WARNING)
|
||||
|
||||
# 创建文件日志
|
||||
fileHandler = MultiprocessHandler(file_name, encoding='utf8', interval='D')
|
||||
fileHandler.setLevel(log_level)
|
||||
fileHandler.setFormatter(fmt)
|
||||
logger.addHandler(fileHandler)
|
||||
|
||||
if debug:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
else:
|
||||
logger.setLevel(logging.WARNING)
|
||||
# 设置logger的级别
|
||||
logger.setLevel(log_level)
|
||||
|
||||
return logger
|
||||
|
||||
return _logger
|
||||
|
||||
|
||||
def get_logger(name=None):
|
||||
def get_logger(name: str = None):
|
||||
"""
|
||||
根据name获取logger
|
||||
:param name:
|
||||
:return:
|
||||
"""
|
||||
global _logger
|
||||
|
||||
if _logger is None:
|
||||
@ -227,6 +235,7 @@ def get_logger(name=None):
|
||||
|
||||
# -------------------测试代码------------
|
||||
def single_func(para):
|
||||
""" 测试单进程"""
|
||||
setup_logger('logs/MyLog{}'.format(para))
|
||||
logger = get_logger()
|
||||
if para > 5:
|
||||
@ -240,6 +249,7 @@ def single_func(para):
|
||||
|
||||
|
||||
def multi_func():
|
||||
"""测试多进程"""
|
||||
# 启动多进程
|
||||
pool = multiprocessing.Pool(multiprocessing.cpu_count())
|
||||
setup_logger('logs/MyLog')
|
||||
|
@ -19,7 +19,6 @@ import talib
|
||||
from .object import BarData, TickData
|
||||
from .constant import Exchange, Interval
|
||||
|
||||
|
||||
log_formatter = logging.Formatter('[%(asctime)s] %(message)s')
|
||||
|
||||
|
||||
@ -74,6 +73,30 @@ def get_underlying_symbol(symbol: str):
|
||||
return underlying_symbol.group(1)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_stock_exchange(code, vn=True):
|
||||
"""根据股票代码,获取交易所"""
|
||||
# vn:取EXCHANGE_SSE 和 EXCHANGE_SZSE
|
||||
code = str(code)
|
||||
if len(code) < 6:
|
||||
return ''
|
||||
|
||||
market_id = 0 # 缺省深圳
|
||||
code = str(code)
|
||||
if code[0] in ['5', '6', '9'] or code[:3] in ["009", "126", "110", "201", "202", "203", "204"]:
|
||||
market_id = 1 # 上海
|
||||
try:
|
||||
from vnpy.trader.constant import Exchange
|
||||
if vn:
|
||||
return Exchange.SSE.value if market_id == 1 else Exchange.SZSE.value
|
||||
else:
|
||||
return 'XSHG' if market_id == 1 else 'XSHE'
|
||||
except Exception as ex:
|
||||
print(u'加载数据异常:{}'.format(str(ex)))
|
||||
|
||||
return ''
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_full_symbol(symbol: str):
|
||||
"""
|
||||
@ -140,6 +163,11 @@ def _get_trader_dir(temp_name: str):
|
||||
"""
|
||||
Get path where trader is running in.
|
||||
"""
|
||||
# by incenselee
|
||||
# 原方法,当前目录必须自建.vntrader子目录,否则在用户得目录下创建
|
||||
# 为兼容多账号管理,取消此方法。
|
||||
return Path.cwd(), Path.cwd()
|
||||
|
||||
cwd = Path.cwd()
|
||||
temp_path = cwd.joinpath(temp_name)
|
||||
|
||||
@ -161,6 +189,7 @@ def _get_trader_dir(temp_name: str):
|
||||
|
||||
TRADER_DIR, TEMP_DIR = _get_trader_dir(".vntrader")
|
||||
sys.path.append(str(TRADER_DIR))
|
||||
print(f'sys.path append: {str(TRADER_DIR)}')
|
||||
|
||||
|
||||
def get_file_path(filename: str):
|
||||
|
Loading…
Reference in New Issue
Block a user