|
5 | 5 | """
|
6 | 6 | import sys
|
7 | 7 | import os
|
8 |
| -import datetime |
9 | 8 | import logging
|
10 | 9 | import numpy as np
|
11 | 10 | import matplotlib
|
| 11 | +import seaborn as sns |
12 | 12 | import math
|
13 | 13 | from typing import Union
|
14 | 14 | from matplotlib.pylab import Rectangle, gca, gcf, figure, ylabel, axes, draw
|
@@ -770,6 +770,13 @@ def sysplot(sys, new=True, axes=None, style=1, only_draw_close=False):
|
770 | 770 |
|
771 | 771 |
|
772 | 772 | def sys_performance(sys, ref_stk=None):
|
| 773 | + """ |
| 774 | + 绘制系统绩效,即账户累积收益率曲线 |
| 775 | +
|
| 776 | + :param SystemBase | PortfolioBase sys: SYS或PF实例 |
| 777 | + :param Stock ref_stk: 参考股票, 默认为沪深300: sh000300, 绘制参考标的的收益曲线 |
| 778 | + :return: None |
| 779 | + """ |
773 | 780 | if ref_stk is None:
|
774 | 781 | ref_stk = get_stock('sh000300')
|
775 | 782 |
|
@@ -847,10 +854,65 @@ def sys_performance(sys, ref_stk=None):
|
847 | 854 | ax3.set_frame_on(False)
|
848 | 855 |
|
849 | 856 |
|
| 857 | +def tm_heatmap(tm, start_date, end_date=None, axes=None): |
| 858 | + """ |
| 859 | + 绘制账户收益年-月收益热力图 |
| 860 | +
|
| 861 | + :param tm: 交易账户 |
| 862 | + :param start_date: 开始日期 |
| 863 | + :param end_date: 结束日期,默认为今天 |
| 864 | + :param axes: 绘制的轴对象,默认为None,表示创建新的轴对象 |
| 865 | + :return: None |
| 866 | + """ |
| 867 | + if end_date is None: |
| 868 | + end_date = Datetime.today() + Days(1) |
| 869 | + |
| 870 | + dates = get_date_range(start_date, end_date) |
| 871 | + if len(dates) == 0: |
| 872 | + hku_error("没有数据,请检查日期范围!start_date={}, end_date={}", start_date, end_date) |
| 873 | + return |
| 874 | + |
| 875 | + profit = tm.get_funds_curve(dates) |
| 876 | + if len(profit) == 0: |
| 877 | + hku_error("获取 tm 收益曲线失败,请检查 tm 初始日期!tm.init_datetime={} start_date={}, end_date={}", |
| 878 | + tm.init_datetime, start_date, end_date) |
| 879 | + return |
| 880 | + |
| 881 | + data = pd.DataFrame({'date': dates, 'value': profit}) |
| 882 | + |
| 883 | + # 提取年月信息 |
| 884 | + data['year'] = data['date'].apply(lambda v: v.year) |
| 885 | + data['month'] = data['date'].apply(lambda v: v.month) |
| 886 | + |
| 887 | + # 获取每个月的收益 |
| 888 | + monthly = data.groupby(['year', 'month']).last()['value'].reset_index() |
| 889 | + monthly['return'] = ((monthly['value'] - monthly['value'].shift(1)) / monthly['value'].shift(1)) * 100. |
| 890 | + |
| 891 | + pivot_data = monthly.pivot_table(index='year', columns='month', values='return') |
| 892 | + |
| 893 | + if axes is None: |
| 894 | + axes = create_figure() |
| 895 | + |
| 896 | + sns.heatmap(pivot_data, cmap='RdYlGn_r', center=0, annot=True, fmt="<.2f", ax=axes) |
| 897 | + # 设置标题和坐标轴标签 |
| 898 | + axes.set_title('年-月度收益率(%)热力图') |
| 899 | + axes.set_xlabel('月度') |
| 900 | + axes.set_ylabel('年份') |
| 901 | + |
| 902 | + |
| 903 | +def sys_heatmap(sys, axes=None): |
| 904 | + """ |
| 905 | + 绘制系统收益年-月收益热力图 |
| 906 | + """ |
| 907 | + hku_check(sys.tm is not None, "系统未初始化交易账户") |
| 908 | + query = sys.query |
| 909 | + k = get_kdata('sh000001', query) |
| 910 | + tm_heatmap(sys.tm, k[0].datetime, k[-1].datetime, axes) |
| 911 | + |
| 912 | + |
850 | 913 | # ============================================================================
|
851 | 914 | # 通达信画图函数
|
852 | 915 | # ============================================================================
|
853 |
| - |
854 | 916 | DRAWNULL = constant.null_price
|
855 | 917 |
|
856 | 918 |
|
|
0 commit comments