[新功能] 增加砖图数据对象

This commit is contained in:
msincenselee 2019-12-25 19:02:27 +08:00
parent 0426820e78
commit 301326453a
9 changed files with 135 additions and 25 deletions

View File

@ -5,6 +5,13 @@ General constant string used in VN Trader.
from enum import Enum
class Color(Enum):
""" Kline color """
RED = 'Red'
BLUE = 'Blue'
EQUAL = 'Equal'
class Direction(Enum):
"""
Direction of order/trade/position.
@ -141,7 +148,9 @@ class Interval(Enum):
"""
Interval of bar data.
"""
SECOND = "1s"
MINUTE = "1m"
HOUR = "1h"
DAILY = "d"
WEEKLY = "w"
RENKO = 'renko'

View File

@ -24,7 +24,8 @@ class BaseDatabaseManager(ABC):
exchange: "Exchange",
interval: "Interval",
start: datetime,
end: datetime
end: datetime,
**kwargs
) -> Sequence["BarData"]:
pass
@ -34,7 +35,8 @@ class BaseDatabaseManager(ABC):
symbol: str,
exchange: "Exchange",
start: datetime,
end: datetime
end: datetime,
**kwargs
) -> Sequence["TickData"]:
pass

View File

@ -112,6 +112,9 @@ class DbTickData(Document):
symbol: str = StringField()
exchange: str = StringField()
datetime: datetime = DateTimeField()
date: str = StringField()
time: str = StringField()
trading_day: str = StringField()
name: str = StringField()
volume: float = FloatField()
@ -170,6 +173,9 @@ class DbTickData(Document):
db_tick.symbol = tick.symbol
db_tick.exchange = tick.exchange.value
db_tick.datetime = tick.datetime
db_tick.date = tick.date
db_tick.time = tick.time
db_tick.trading_day = tick.trading_day
db_tick.name = tick.name
db_tick.volume = tick.volume
db_tick.open_interest = tick.open_interest
@ -218,6 +224,9 @@ class DbTickData(Document):
symbol=self.symbol,
exchange=Exchange(self.exchange),
datetime=self.datetime,
date=self.date,
time=self.time,
trading_day=self.trading_day,
name=self.name,
volume=self.volume,
open_interest=self.open_interest,
@ -269,6 +278,7 @@ class MongoManager(BaseDatabaseManager):
interval: Interval,
start: datetime,
end: datetime,
**kwargs
) -> Sequence[BarData]:
s = DbBarData.objects(
symbol=symbol,
@ -281,7 +291,7 @@ class MongoManager(BaseDatabaseManager):
return data
def load_tick_data(
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
self, symbol: str, exchange: Exchange, start: datetime, end: datetime, **kwargs
) -> Sequence[TickData]:
s = DbTickData.objects(
symbol=symbol,

View File

@ -337,6 +337,7 @@ class SqlManager(BaseDatabaseManager):
interval: Interval,
start: datetime,
end: datetime,
**kwargs
) -> Sequence[BarData]:
s = (
self.class_bar.select()
@ -353,7 +354,7 @@ class SqlManager(BaseDatabaseManager):
return data
def load_tick_data(
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
self, symbol: str, exchange: Exchange, start: datetime, end: datetime, **kwargs
) -> Sequence[TickData]:
s = (
self.class_tick.select()

View File

@ -33,6 +33,9 @@ from .object import (
from .setting import SETTINGS
from .utility import get_folder_path, TRADER_DIR
# 专有的logger文件
from .util_logger import setup_logger
class MainEngine:
"""
@ -239,6 +242,37 @@ class BaseEngine(ABC):
self.event_engine = event_engine
self.engine_name = engine_name
self.logger = None
def create_logger(self, logger_name: str = 'base_engine'):
"""
创建engine独有的日志
:param logger_name: 日志名缺省为engine的名称
:return:
"""
log_path = get_folder_path("log")
log_filename = os.path.abspath(os.path.join(log_path, logger_name))
print(u'create logger:{}'.format(log_filename))
self.logger = setup_logger(file_name=log_filename, name=logger_name,
log_level=SETTINGS.get('log.level', logging.DEBUG))
def write_log(self, msg: str, source: str = "", level: int = logging.DEBUG):
"""
写入日志
:param msg: 日志内容
:param source: 来源
:param level: 日志级别
:return:
"""
if self.logger:
if len(source) > 0:
msg = f'[{source}]{msg}'
self.logger.log(level, msg)
else:
log = LogData(msg=msg, level=level)
event = Event(EVENT_LOG, log)
self.event_engine.put(event)
def close(self):
""""""
pass

View File

@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from typing import Any, Sequence
from copy import copy
from logging import INFO
from vnpy.event import Event, EventEngine
from .event import (
@ -140,11 +141,11 @@ class BaseGateway(ABC):
"""
self.on_event(EVENT_CONTRACT, contract)
def write_log(self, msg: str):
def write_log(self, msg: str, level: int = INFO):
"""
Write a log event from gateway.
"""
log = LogData(msg=msg, gateway_name=self.gateway_name)
log = LogData(msg=msg, level=level, gateway_name=self.gateway_name)
self.on_log(log)
@abstractmethod

View File

@ -103,6 +103,20 @@ class BarData(BaseData):
self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
@dataclass
class RenkoBarData(BarData):
"""
Renko bar data of a certain trading period.
"""
seconds: int = 0 # 当前Bar的秒数针对RenkoBar)
high_seconds: int = -1 # 当前Bar的上限秒数
low_seconds: int = -1 # 当前bar的下限秒数
height: float = 3 # 当前Bar的高度限制针对RenkoBar和RangeBar类
up_band: float = 0 # 高位区域的基线
down_band: float = 0 # 低位区域的基线
low_time = None # 最后一次进入低位区域的时间
high_time = None # 最后一次进入高位区域的时间
@dataclass
class OrderData(BaseData):
"""

View File

@ -159,11 +159,19 @@ class MultiprocessHandler(logging. FileHandler):
self.handleError(record)
def setup_logger(filename, name=None, debug=False, force=False, backtesing=False):
def setup_logger(file_name: str,
name: str = None,
log_level: int = logging.DEBUG,
force: bool = False,
backtesing: bool = False):
"""
设置日志文件包括路径
自动在后面添加 "_日期.log"
:param logger_file_name:
:param file_name: 日志文件名
:param name: logger
:param log_level: 日志级别
:param force: 是否强制更新日志名称
:param backtesing: 是否为回测(回测输出的格式不同
:return:
"""
@ -171,21 +179,21 @@ def setup_logger(filename, name=None, debug=False, force=False, backtesing=False
global _fileHandler
global _logger_filename
if _logger is not None and _logger_filename == filename and not force:
if _logger is not None and _logger_filename == file_name and not force:
return _logger
if _logger_filename != filename or force:
if _logger_filename != file_name or force:
if force:
_logger_filename = filename
_logger_filename = file_name
# 定义日志输出格式
fmt = logging.Formatter(RECORD_FORMAT if not backtesing else BACKTEST_FORMAT)
if name is None:
names = filename.replace('.log', '').split('/')
names = file_name.replace('.log', '').split('/')
name = names[-1]
logger = logging.getLogger(name)
if debug:
if log_level == logging.DEBUG:
logger.setLevel(logging.DEBUG)
stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setLevel(logging.DEBUG)
@ -193,26 +201,26 @@ def setup_logger(filename, name=None, debug=False, force=False, backtesing=False
if not logger.hasHandlers():
logger.addHandler(stream_handler)
fileHandler = MultiprocessHandler(filename, encoding='utf8', interval='D')
if debug:
fileHandler.setLevel(logging.DEBUG)
else:
fileHandler.setLevel(logging.WARNING)
# 创建文件日志
fileHandler = MultiprocessHandler(file_name, encoding='utf8', interval='D')
fileHandler.setLevel(log_level)
fileHandler.setFormatter(fmt)
logger.addHandler(fileHandler)
if debug:
logger.setLevel(logging.DEBUG)
else:
logger.setLevel(logging.WARNING)
# 设置logger的级别
logger.setLevel(log_level)
return logger
return _logger
def get_logger(name=None):
def get_logger(name: str = None):
"""
根据name获取logger
:param name:
:return:
"""
global _logger
if _logger is None:
@ -227,6 +235,7 @@ def get_logger(name=None):
# -------------------测试代码------------
def single_func(para):
""" 测试单进程"""
setup_logger('logs/MyLog{}'.format(para))
logger = get_logger()
if para > 5:
@ -240,6 +249,7 @@ def single_func(para):
def multi_func():
"""测试多进程"""
# 启动多进程
pool = multiprocessing.Pool(multiprocessing.cpu_count())
setup_logger('logs/MyLog')

View File

@ -19,7 +19,6 @@ import talib
from .object import BarData, TickData
from .constant import Exchange, Interval
log_formatter = logging.Formatter('[%(asctime)s] %(message)s')
@ -74,6 +73,30 @@ def get_underlying_symbol(symbol: str):
return underlying_symbol.group(1)
@lru_cache()
def get_stock_exchange(code, vn=True):
"""根据股票代码,获取交易所"""
# vn取EXCHANGE_SSE 和 EXCHANGE_SZSE
code = str(code)
if len(code) < 6:
return ''
market_id = 0 # 缺省深圳
code = str(code)
if code[0] in ['5', '6', '9'] or code[:3] in ["009", "126", "110", "201", "202", "203", "204"]:
market_id = 1 # 上海
try:
from vnpy.trader.constant import Exchange
if vn:
return Exchange.SSE.value if market_id == 1 else Exchange.SZSE.value
else:
return 'XSHG' if market_id == 1 else 'XSHE'
except Exception as ex:
print(u'加载数据异常:{}'.format(str(ex)))
return ''
@lru_cache()
def get_full_symbol(symbol: str):
"""
@ -140,6 +163,11 @@ def _get_trader_dir(temp_name: str):
"""
Get path where trader is running in.
"""
# by incenselee
# 原方法,当前目录必须自建.vntrader子目录否则在用户得目录下创建
# 为兼容多账号管理,取消此方法。
return Path.cwd(), Path.cwd()
cwd = Path.cwd()
temp_path = cwd.joinpath(temp_name)
@ -161,6 +189,7 @@ def _get_trader_dir(temp_name: str):
TRADER_DIR, TEMP_DIR = _get_trader_dir(".vntrader")
sys.path.append(str(TRADER_DIR))
print(f'sys.path append: {str(TRADER_DIR)}')
def get_file_path(filename: str):