Skip to content

Commit c3545fe

Browse files
committed
feat(draw): 添加绘制系统收益的年-月热力图功能(matplotlib)
1 parent 23d0560 commit c3545fe

File tree

3 files changed

+71
-3
lines changed

3 files changed

+71
-3
lines changed

hikyuu/draw/drawplot/__init__.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
# 1. 20171122, Added by fasiondog
3030
# ===============================================================================
3131

32-
from hikyuu.core import KData, Indicator, SignalBase, ConditionBase, EnvironmentBase, System, Portfolio
32+
from hikyuu.core import KData, Indicator, SignalBase, ConditionBase, EnvironmentBase, System, Portfolio, TradeManager
3333

3434
import matplotlib
3535
from matplotlib.pylab import gca as mpl_gca
@@ -49,6 +49,8 @@
4949
from .matplotlib_draw import ax_set_locator_formatter as mpl_ax_set_locator_formatter
5050
from .matplotlib_draw import adjust_axes_show as mpl_adjust_axes_show
5151
from .matplotlib_draw import sys_performance as mpl_sys_performance
52+
from .matplotlib_draw import tm_heatmap as mpl_tm_heatmap
53+
from .matplotlib_draw import sys_heatmap as mpl_sys_heatmap
5254
from .matplotlib_draw import (DRAWNULL, STICKLINE, DRAWBAND, RGB, PLOYLINE,
5355
DRAWLINE, DRAWTEXT, DRAWNUMBER, DRAWTEXT_FIX, DRAWNUMBER_FIX, DRAWSL,
5456
DRAWIMG, DRAWICON, DRAWBMP, SHOWICONS, DRAWRECTREL)
@@ -128,6 +130,9 @@ def use_draw_with_matplotlib():
128130
System.plot = mpl_sysplot
129131
System.performance = mpl_sys_performance
130132
Portfolio.performance = mpl_sys_performance
133+
TradeManager.heatmap = mpl_tm_heatmap
134+
System.heatmap = mpl_sys_heatmap
135+
Portfolio.heatmap = mpl_sys_heatmap
131136

132137

133138
def use_draw_with_echarts():

hikyuu/draw/drawplot/matplotlib_draw.py

+64-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
"""
66
import sys
77
import os
8-
import datetime
98
import logging
109
import numpy as np
1110
import matplotlib
11+
import seaborn as sns
1212
import math
1313
from typing import Union
1414
from matplotlib.pylab import Rectangle, gca, gcf, figure, ylabel, axes, draw
@@ -770,6 +770,13 @@ def sysplot(sys, new=True, axes=None, style=1, only_draw_close=False):
770770

771771

772772
def sys_performance(sys, ref_stk=None):
773+
"""
774+
绘制系统绩效,即账户累积收益率曲线
775+
776+
:param SystemBase | PortfolioBase sys: SYS或PF实例
777+
:param Stock ref_stk: 参考股票, 默认为沪深300: sh000300, 绘制参考标的的收益曲线
778+
:return: None
779+
"""
773780
if ref_stk is None:
774781
ref_stk = get_stock('sh000300')
775782

@@ -847,10 +854,65 @@ def sys_performance(sys, ref_stk=None):
847854
ax3.set_frame_on(False)
848855

849856

857+
def tm_heatmap(tm, start_date, end_date=None, axes=None):
858+
"""
859+
绘制账户收益年-月收益热力图
860+
861+
:param tm: 交易账户
862+
:param start_date: 开始日期
863+
:param end_date: 结束日期,默认为今天
864+
:param axes: 绘制的轴对象,默认为None,表示创建新的轴对象
865+
:return: None
866+
"""
867+
if end_date is None:
868+
end_date = Datetime.today() + Days(1)
869+
870+
dates = get_date_range(start_date, end_date)
871+
if len(dates) == 0:
872+
hku_error("没有数据,请检查日期范围!start_date={}, end_date={}", start_date, end_date)
873+
return
874+
875+
profit = tm.get_funds_curve(dates)
876+
if len(profit) == 0:
877+
hku_error("获取 tm 收益曲线失败,请检查 tm 初始日期!tm.init_datetime={} start_date={}, end_date={}",
878+
tm.init_datetime, start_date, end_date)
879+
return
880+
881+
data = pd.DataFrame({'date': dates, 'value': profit})
882+
883+
# 提取年月信息
884+
data['year'] = data['date'].apply(lambda v: v.year)
885+
data['month'] = data['date'].apply(lambda v: v.month)
886+
887+
# 获取每个月的收益
888+
monthly = data.groupby(['year', 'month']).last()['value'].reset_index()
889+
monthly['return'] = ((monthly['value'] - monthly['value'].shift(1)) / monthly['value'].shift(1)) * 100.
890+
891+
pivot_data = monthly.pivot_table(index='year', columns='month', values='return')
892+
893+
if axes is None:
894+
axes = create_figure()
895+
896+
sns.heatmap(pivot_data, cmap='RdYlGn_r', center=0, annot=True, fmt="<.2f", ax=axes)
897+
# 设置标题和坐标轴标签
898+
axes.set_title('年-月度收益率(%)热力图')
899+
axes.set_xlabel('月度')
900+
axes.set_ylabel('年份')
901+
902+
903+
def sys_heatmap(sys, axes=None):
904+
"""
905+
绘制系统收益年-月收益热力图
906+
"""
907+
hku_check(sys.tm is not None, "系统未初始化交易账户")
908+
query = sys.query
909+
k = get_kdata('sh000001', query)
910+
tm_heatmap(sys.tm, k[0].datetime, k[-1].datetime, axes)
911+
912+
850913
# ============================================================================
851914
# 通达信画图函数
852915
# ============================================================================
853-
854916
DRAWNULL = constant.null_price
855917

856918

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
click
22
numpy
33
matplotlib
4+
seaborn
45
pandas>=0.17.1
56
pytdx
67
PyQt5

0 commit comments

Comments
 (0)