This commit is contained in:
msincenselee 2015-10-08 00:27:06 +08:00
parent e5de7cba31
commit aa489ea31b
10 changed files with 317 additions and 29 deletions

7
.gitignore vendored
View File

@ -36,4 +36,9 @@ Release/
# 其他文件
*.dump
*.vssettings
*.vssettings
vnpy.pyproj.user
.idea/workspace.xml
.idea/.name
.idea/vnpy.iml
*.xml

View File

@ -1,6 +1,7 @@
# encoding: UTF-8
import sys
from time import sleep
from PyQt4 import QtGui

96
vn.data/azure_mssql.py Normal file
View File

@ -0,0 +1,96 @@
# encoding: UTF-8
#!/usr/bin/env python
#-------------------------------------------------------------------------------
# Name: azure_mssql.py
# Purpose: 使用 pymssql库链接Azure的MSSQL
#
# Author: IncenseLee@hotmail.com
#
# Created: 09/19/2015
#-------------------------------------------------------------------------------
import pymssql
class azure_mssql:
"""
对pymssql的简单封装
pymssql库该库到这里下载http://www.lfd.uci.edu/~gohlke/pythonlibs/#pymssql
使用该库时需要在Sql Server Configuration Manager里面将TCP/IP协议开启
http://pymssql.sourceforge.net/ref_pymssql.php
sql语句中有中文的时候进行encode
insertSql = "insert into WeiBo([UserId],[WeiBoContent],[PublishDate]) values(1,'测试','2012/2/1')".encode("utf8")
连接的时候加入charset设置信息
pymssql.connect(host=self.host,user=self.user,password=self.pwd,database=self.db,charset="utf8")
用法
"""
def __init__(self,host,user,pwd,db):
self.host = host
self.user = user
self.pwd = pwd
self.db = db
def __get_connect(self):
"""
得到连接信息
返回: conn.cursor()
"""
if not self.db:
raise(NameError,"没有设置数据库信息")
self.conn = pymssql.connect(host=self.host, user=self.user, password=self.pwd, database=self.db, charset="utf8")
cur = self.conn.cursor()
if not cur:
raise(NameError,"连接数据库失败")
else:
return cur
def exec_query(self,sql):
"""
执行查询语句
返回的是一个包含tuple的listlist的元素是记录行tuple的元素是每行记录的字段
调用示例
ms = azure_mssql(host="stock.database.windows.net",user="sqladmin",pwd="7uhb*IJN",db="stockcn")
resList = ms.ExecQuery("SELECT id,NickName FROM WeiBoUser")
for (id,NickName) in resList:
print str(id),NickName
"""
cur = self.__get_connect()
cur.execute(sql)
resList = cur.fetchall()
#查询完毕后必须关闭连接
self.conn.close()
return resList
def exec_non_query(self,sql):
"""
执行非查询语句
调用示例
cur = self.__GetConnect()
cur.execute(sql)
self.conn.commit()
self.conn.close()
"""
cur = self.__get_connect()
cur.execute(sql)
self.conn.commit()
self.conn.close()
def main():
## ms = azure_mssql(host="stock.database.windows.net",user="sqladmin",pwd="7uhb*IJN",db="stockcn")
## #返回的是一个包含tuple的listlist的元素是记录行tuple的元素是每行记录的字段
## ms.ExecNonQuery("insert into WeiBoUser values('2','3')")
ms = azure_mssql(host="stock.database.windows.net:1433", user="sqladmin", pwd="7uhb*IJN", db="stockcn")
resList = ms.exec_query("SELECT count(*) as recordcounts FROM tb_ami")
for (recordcounts) in resList:
print str(recrodcounts).decode("utf8")
if __name__ == '__main__':
main()

47
vn.data/test_mysql.py Normal file
View File

@ -0,0 +1,47 @@
# encoding: UTF-8
import MySQLdb
try:
#连接数据库
conn = MySQLdb.connect(host='vnpy.cloudapp.net', user='stockcn', passwd='7uhb*IJN', db='stockcn', port=3306)
#获取指针
cur = conn.cursor(MySQLdb.cursors.DictCursor)
symbol = 'a'
#执行脚本,返回记录数
count = cur.execute(' select \'{0}\' as InstrumentID, str_to_date(concat(ndate,\' \', ntime),\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB_{0}MI limit 0,100;'.format(symbol))
print 'there has %s rows record' % count
#取回第一条记录
result = cur.fetchone()
print result
#取回5条记录
results = cur.fetchmany(5)
for r in results:
print r
print '=='*10
cur.scroll(0, mode = 'absolute')
results = cur.fetchall()
desc = cur.description
print 'cur.description:', desc
for r in results:
#InstrumentID, UpdateTime, LastPrice, Volume, OpenInterest, BidPrice1, BidVolume1, AskPrice1, AskVolume1 = r
#print InstrumentID, UpdateTime
#print LastPrice*Volume;
print r['InstrumentID'], float(r['LastPrice'])
#关闭指针,关闭连接
cur.close()
conn.close()
except MySQLdb.Error, e:
print "Mysql Error %d: %s" % (e.args[0], e.args[1])

View File

@ -3,7 +3,7 @@
import shelve
from eventEngine import *
from pymongo import Connection
from pymongo import MongoClient as Connection
from pymongo.errors import *
from strategyEngine import *

View File

@ -1,10 +1,12 @@
# encoding: UTF-8
import shelve
import MySQLdb
from eventEngine import *
from pymongo import MongoClient as Connection
from pymongo.errors import *
from datetime import datetime, timedelta, time
from strategyEngine import *
@ -73,11 +75,12 @@ class BacktestingEngine(object):
self.__mongoTickDB = self.__mongoConnection['TickDB']
self.writeLog(u'回测引擎连接MongoDB成功')
except ConnectionFailure:
self.writeLog(u'回测引擎连接MongoDB失败')
self.writeLog(u'回测引擎连接MongoDB失败')
#----------------------------------------------------------------------
def loadDataHistory(self, symbol, startDate, endDate):
"""载入历史TICK数据"""
def loadMongoDataHistory(self, symbol, startDate, endDate):
"""从Mongo载入历史TICK数据"""
if self.__mongoConnected:
collection = self.__mongoTickDB[symbol]
@ -95,7 +98,98 @@ class BacktestingEngine(object):
self.writeLog(u'历史TICK数据载入完成')
else:
self.writeLog(u'MongoDB未连接请检查')
#----------------------------------------------------------------------
def connectMysql(self):
"""连接MysqlDB"""
try:
self.__mysqlConnection = MySQLdb.connect(host='vnpy.cloudapp.net', user='stockcn',
passwd='7uhb*IJN', db='stockcn', port=3306)
self.__mysqlConnected = True
self.writeLog(u'回测引擎连接MysqlDB成功')
except ConnectionFailure:
self.writeLog(u'回测引擎连接MysqlDB失败')
#----------------------------------------------------------------------
def loadMysqlDataHistory(self, symbol, startDate, endDate):
"""从Mysql载入历史TICK数据,"""
try:
if self.__mysqlConnected:
#获取指针
cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor)
if endDate:
sqlstring = ' select \'{0}\' as InstrumentID, str_to_date(concat(ndate,\' \', ntime),' \
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,' \
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB_{0}MI ' \
'where ndate between cast(\'{1}\' as date) and cast(\'{2}\' as date)'.format(symbol, startDate, endDate)
elif startDate:
sqlstring = ' select \'{0}\' as InstrumentID,str_to_date(concat(ndate,\' \', ntime),' \
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,' \
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB__{0}MI ' \
'where ndate > cast(\'{1}\' as date)'.format( symbol, startDate)
else:
sqlstring =' select \'{0}\' as InstrumentID,str_to_date(concat(ndate,\' \', ntime),' \
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,' \
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB__{0}MI '.format(symbol)
self.writeLog(sqlstring)
count = cur.execute(sqlstring)
# 将TICK数据读入内存
self.listDataHistory = cur.fetchall()
self.writeLog(u'历史TICK数据载入完成{0}'.format(count))
else:
self.writeLog(u'MysqlDB未连接请检查')
except MySQLdb.Error, e:
self.writeLog(u'MysqlDB载入数据失败请检查.Error {0}: {1}'.format(e.arg[0],e.arg[1]))
#----------------------------------------------------------------------
def getMysqlDeltaDate(self,symbol, startDate, decreaseDays):
try:
if self.__mysqlConnected:
#获取指针
cur = self.__mysqlConnection.cursor()
sqlstring='select distinct ndate from TB_{0}MI where ndate < ' \
'cast(\'{1}\' as date) order by ndate desc limit {2},1'.format(symbol, startDate, decreaseDays-1)
self.writeLog(sqlstring)
count = cur.execute(sqlstring)
if count > 0:
result = cur.fetchone()
return result[0]
else:
self.writeLog(u'MysqlDB没有查询结果请检查日期')
else:
self.writeLog(u'MysqlDB未连接请检查')
except MySQLdb.Error, e:
self.writeLog(u'MysqlDB载入数据失败请检查.Error {0}: {1}'.format(e.arg[0],e.arg[1]))
td = timedelta(days=3)
return startDate-td;
#----------------------------------------------------------------------
def processLimitOrder(self):
"""处理限价单"""
@ -123,7 +217,9 @@ class BacktestingEngine(object):
tradeData['OffsetFlag'] = order.offset
tradeData['Price'] = price
tradeData['Volume'] = order.volume
print tradeData
tradeEvent = Event()
tradeEvent.dict_['data'] = tradeData
self.strategyEngine.updateTrade(tradeEvent)
@ -200,7 +296,13 @@ class BacktestingEngine(object):
def writeLog(self, log):
"""写日志"""
print log
#----------------------------------------------------------------------
def subscribe(self, symbol, exchange):
"""仿真订阅合约"""
pass
#----------------------------------------------------------------------
def selectInstrument(self, symbol):
"""读取合约数据"""
@ -214,11 +316,7 @@ class BacktestingEngine(object):
f = shelve.open('result.vn')
f['listTrade'] = self.listTrade
f.close()
#----------------------------------------------------------------------
def subscribe(self, symbol, exchange):
"""仿真订阅合约"""
pass

View File

@ -129,8 +129,7 @@ class DemoMdApi(MdApi):
#for instrument in self.__setSubscribed:
#self.subscribe(instrument[0], instrument[1])
onRspUserLogin
print u'DemoApi.py DemoMdApi.onRspUserLogin() end'
#----------------------------------------------------------------------
@ -162,9 +161,7 @@ class DemoMdApi(MdApi):
#----------------------------------------------------------------------
def onRspUnSubMarketData(self, data, error, n, last):
"""退订合约回报"""
onRspUnSubMarketData
print u'DemoApi.py DemoMdApi.onRspUnSubMarketData()'
# 同上
pass

View File

@ -3,12 +3,14 @@
from strategyEngine import *
from backtestingEngine import *
from demoStrategy import SimpleEmaStrategy
import decimal
# 回测脚本
if __name__ == '__main__':
symbol = 'IF1506'
#symbol = 'IF1506'
symbol = 'a'
# 创建回测引擎
be = BacktestingEngine()
@ -18,20 +20,25 @@ if __name__ == '__main__':
be.setStrategyEngine(se)
# 初始化回测引擎
be.connectMongo()
be.loadDataHistory(symbol, datetime(2015,5,1), datetime.today())
#be.connectMongo()
be.connectMysql()
#be.loadMongoDataHistory(symbol, datetime(2015,5,1), datetime.today())
#be.loadMongoDataHistory(symbol, datetime(2012,1,9), datetime(2012,1,14))
be.loadMysqlDataHistory(symbol, datetime(2012,1,9), datetime(2012,1,20))
# 创建策略对象
setting = {}
setting['fastAlpha'] = 0.2
setting['slowAlpha'] = 0.05
setting['startDate'] = datetime(year=2015, month=5, day=20)
setting['slowAlpha'] = 0.05
#setting['startDate'] = datetime(year=2015, month=5, day=20)
setting['startDate'] = datetime(year=2012, month=1, day=9)
se.createStrategy(u'EMA演示策略', symbol, SimpleEmaStrategy, setting)
# 启动所有策略
se.startAll()
# 开始回测
be.startBacktesting()

View File

@ -0,0 +1,30 @@
import numpy as np
import matplotlib.pyplot as plt
#first generate some datapoint for a randomly sampled noisy sinewave
x = np.random.random(1000)*10
noise = np.random.normal(scale=0.3,size=len(x))
y = np.sin(x) + noise
#plot the data
plt.plot(x,y,'ro',alpha=0.3,ms=4,label='data')
plt.xlabel('Time')
plt.ylabel('Intensity')
#define a moving average function
def moving_average(x,y,step_size=.1,bin_size=1):
bin_centers = np.arange(np.min(x),np.max(x)-0.5*step_size,step_size)+0.5*step_size
bin_avg = np.zeros(len(bin_centers))
for index in range(0,len(bin_centers)):
bin_center = bin_centers[index]
items_in_bin = y[(x>(bin_center-bin_size*0.5) ) & (x<(bin_center+bin_size*0.5))]
bin_avg[index] = np.mean(items_in_bin)
return bin_centers,bin_avg
#plot the moving average
bins, average = moving_average(x,y)
plt.plot(bins, average,label='moving average')
plt.show()

7
vn.training/t1.py Normal file
View File

@ -0,0 +1,7 @@
# encoding: UTF-8
import matplotlib.pyplot as plt
plt.plot([10,20,30])
plt.xlabel('times')
plt.ylabel('numbers')
plt.show()