Skip to content

Commit 4024965

Browse files
committed
add test_MF_Weight
1 parent 97ebaef commit 4024965

File tree

3 files changed

+220
-6
lines changed

3 files changed

+220
-6
lines changed

hikyuu_cpp/hikyuu/trade_sys/multifactor/MultiFactorBase.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include "hikyuu/KData.h"
1111
#include "ScoreRecord.h"
1212

13-
#define MF_USE_MULTI_THREAD 0
13+
#define MF_USE_MULTI_THREAD 1
1414

1515
namespace hku {
1616

hikyuu_cpp/hikyuu/trade_sys/multifactor/imp/WeightMultiFactor.cpp

+19-5
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,17 @@ vector<Indicator> WeightMultiFactor::_calculate(const vector<vector<Indicator>>&
3333
vector<price_t> sumByDate(days_total);
3434
vector<Indicator> all_factors(stk_count);
3535
for (size_t si = 0; si < stk_count; si++) {
36-
memset(sumByDate.data(), 0, sizeof(price_t) * days_total);
36+
memset(sumByDate.data(), Null<price_t>(), sizeof(price_t) * days_total);
3737

38+
size_t discard = 0;
3839
const auto& curStkInds = all_stk_inds[si];
39-
for (size_t di = 0; di < days_total; di++) {
40+
for (size_t ii = 0; ii < ind_count; ii++) {
41+
if (curStkInds[ii].discard() > discard) {
42+
discard = curStkInds[ii].discard();
43+
}
44+
}
45+
46+
for (size_t di = discard; di < days_total; di++) {
4047
for (size_t ii = 0; ii < ind_count; ii++) {
4148
const auto& value = curStkInds[ii][di];
4249
if (!std::isnan(value)) {
@@ -49,7 +56,7 @@ vector<Indicator> WeightMultiFactor::_calculate(const vector<vector<Indicator>>&
4956
all_factors[si].name("IC");
5057

5158
// 更新 discard
52-
for (size_t di = 0; di < days_total; di++) {
59+
for (size_t di = discard; di < days_total; di++) {
5360
if (!std::isnan(all_factors[si][di])) {
5461
all_factors[si].setDiscard(di);
5562
break;
@@ -66,8 +73,15 @@ vector<Indicator> WeightMultiFactor::_calculate(const vector<vector<Indicator>>&
6673
return parallel_for_index(0, stk_count, [&](size_t si) {
6774
vector<price_t> sumByDate(days_total);
6875

76+
size_t discard = 0;
6977
const auto& curStkInds = all_stk_inds[si];
70-
for (size_t di = 0; di < days_total; di++) {
78+
for (size_t ii = 0; ii < ind_count; ii++) {
79+
if (curStkInds[ii].discard() > discard) {
80+
discard = curStkInds[ii].discard();
81+
}
82+
}
83+
84+
for (size_t di = discard; di < days_total; di++) {
7185
for (size_t ii = 0; ii < ind_count; ii++) {
7286
const auto& value = curStkInds[ii][di];
7387
if (!std::isnan(value)) {
@@ -80,7 +94,7 @@ vector<Indicator> WeightMultiFactor::_calculate(const vector<vector<Indicator>>&
8094
ret.name("IC");
8195

8296
// 更新 discard
83-
for (size_t di = 0; di < days_total; di++) {
97+
for (size_t di = discard; di < days_total; di++) {
8498
if (!std::isnan(ret[di])) {
8599
ret.setDiscard(di);
86100
break;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
/*
2+
* test_ABS.cpp
3+
*
4+
* Created on: 2019年4月2日
5+
* Author: fasiondog
6+
*/
7+
8+
#include "../../test_config.h"
9+
#include <fstream>
10+
#include <hikyuu/StockManager.h>
11+
#include <hikyuu/indicator/crt/MA.h>
12+
#include <hikyuu/indicator/crt/AMA.h>
13+
#include <hikyuu/indicator/crt/EMA.h>
14+
#include <hikyuu/indicator/crt/IC.h>
15+
#include <hikyuu/indicator/crt/ROCR.h>
16+
#include <hikyuu/indicator/crt/KDATA.h>
17+
#include <hikyuu/trade_sys/multifactor/crt/MF_Weight.h>
18+
19+
using namespace hku;
20+
21+
/**
22+
* @defgroup test_MF_Weight test_MF_Weight
23+
* @ingroup test_hikyuu_trade_sys_suite
24+
* @{
25+
*/
26+
27+
/** @par 检测点 */
28+
TEST_CASE("test_MF_Weight") {
29+
StockManager& sm = StockManager::instance();
30+
StockList stks{sm["sh600004"], sm["sh600005"], sm["sz000001"], sm["sz000002"]};
31+
Stock ref_stk = sm["sh000001"];
32+
KQuery query = KQuery(-100);
33+
KData ref_k = ref_stk.getKData(query);
34+
DatetimeList ref_dates = ref_k.getDatetimeList();
35+
IndicatorList src_inds{MA(CLOSE()), AMA(CLOSE(), EMA(CLOSE()))};
36+
PriceList weights{0.2, 0.3, 0.5};
37+
38+
/** @arg 输入的股票列表中含有空股票 */
39+
CHECK_THROWS_AS(MF_Weight(src_inds, weights, {Null<Stock>()}, query, ref_stk), std::exception);
40+
41+
/** @arg 输入的原始因子列表为空 */
42+
CHECK_THROWS_AS(MF_Weight({}, weights, stks, query, ref_stk), std::exception);
43+
44+
/** @arg 输入的参考证券为空 */
45+
CHECK_THROWS_AS(MF_Weight({}, weights, stks, query, Null<Stock>()), std::exception);
46+
47+
/** @arg 数据长度不足 */
48+
CHECK_THROWS_AS(MF_Weight(src_inds, weights, stks, KQuery(-1), ref_stk), std::exception);
49+
50+
/** @arg 证券列表数量不足 */
51+
CHECK_THROWS_AS(MF_Weight(src_inds, weights, {sm["sh600004"]}, KQuery(-2), ref_stk),
52+
std::exception);
53+
54+
/** @arg 输入非法 ic_n */
55+
CHECK_THROWS_AS(MF_Weight(src_inds, weights, stks, KQuery(-2), ref_stk, 0), std::exception);
56+
57+
/** @arg 因子列表和权重列表长度不一致 */
58+
CHECK_THROWS_AS(MF_Weight(src_inds, PriceList{0.1}, stks, KQuery(-2), ref_stk, 0),
59+
std::exception);
60+
61+
/** @arg 临界状态, 原始因子数量为1, 证券数量2, 数据长度为2 */
62+
src_inds = {MA(CLOSE())};
63+
stks = {sm["sh600005"], sm["sh600004"]};
64+
query = KQuery(-2);
65+
ref_k = ref_stk.getKData(query);
66+
ref_dates = ref_k.getDatetimeList();
67+
auto mf = MF_Weight(src_inds, PriceList{1.0}, stks, query, ref_stk);
68+
CHECK_EQ(mf->name(), "MF_Weight");
69+
CHECK_THROWS_AS(mf->getFactor(sm["sz000001"]), std::exception);
70+
CHECK_EQ(mf->getDatetimeList(), ref_dates);
71+
72+
const auto& all_factors = mf->getAllFactors();
73+
CHECK_EQ(all_factors.size(), 2);
74+
auto ind1 = mf->getFactor(sm["sh600004"]);
75+
auto ind2 = MA(CLOSE(sm["sh600004"].getKData(query)));
76+
CHECK_UNARY(ind1.equal(ind2));
77+
CHECK_UNARY(all_factors[1].equal(ind2));
78+
ind1 = mf->getFactor(sm["sh600005"]);
79+
ind2 = MA(CLOSE(sm["sh600005"].getKData(query)));
80+
CHECK_UNARY(ind1.equal(ind2));
81+
CHECK_UNARY(all_factors[0].equal(ind2));
82+
auto ic1 = mf->getIC();
83+
auto ic2 = IC(MA(CLOSE()), stks, query, ref_stk, 1);
84+
CHECK_UNARY(ic1.equal(ic2));
85+
86+
CHECK_UNARY(mf->getScores(Datetime(20111204)).empty());
87+
auto cross = mf->getScores(Datetime(20111205));
88+
CHECK_EQ(cross.size(), 2);
89+
CHECK_EQ(cross[0].stock, sm["sh600004"]);
90+
CHECK_EQ(cross[0].value, doctest::Approx(6.85));
91+
CHECK_EQ(cross[1].stock, sm["sh600005"]);
92+
CHECK_EQ(cross[1].value, doctest::Approx(3.13));
93+
94+
cross = mf->getScores(Datetime(20111206));
95+
CHECK_EQ(cross.size(), 2);
96+
CHECK_EQ(cross[0].stock, sm["sh600004"]);
97+
CHECK_EQ(cross[0].value, doctest::Approx(6.855));
98+
CHECK_EQ(cross[1].stock, sm["sh600005"]);
99+
CHECK_EQ(cross[1].value, doctest::Approx(3.14));
100+
// HKU_INFO("\n{}", mf->getAllCross());
101+
102+
/** @arg 原始因子数量为3, 证券数量4, 数据长度为20, 指定比较收益率 3 日 */
103+
int ndays = 3;
104+
src_inds = {MA(ROCR(CLOSE(), ndays)), AMA(ROCR(CLOSE(), ndays)), EMA(ROCR(CLOSE(), ndays))};
105+
stks = {sm["sh600004"], sm["sh600005"], sm["sz000001"], sm["sz000002"]};
106+
query = KQuery(-20);
107+
ref_k = ref_stk.getKData(query);
108+
ref_dates = ref_k.getDatetimeList();
109+
mf = MF_Weight(src_inds, weights, stks, query, ref_stk, ndays);
110+
CHECK_EQ(mf->name(), "MF_Weight");
111+
CHECK_THROWS_AS(mf->getFactor(sm["sh600000"]), std::exception);
112+
113+
auto stk = sm["sh600004"];
114+
ind1 = MA(ROCR(CLOSE(stk.getKData(query)), ndays));
115+
ind2 = AMA(ROCR(CLOSE(stk.getKData(query)), ndays));
116+
auto ind3 = EMA(ROCR(CLOSE(stk.getKData(query)), ndays));
117+
auto ind4 = mf->getFactor(stk);
118+
CHECK_EQ(ind4.discard(), 3);
119+
for (size_t i = 0; i < ind4.discard(); i++) {
120+
CHECK_UNARY(std::isnan(ind4[i]));
121+
}
122+
for (size_t i = ind4.discard(), len = ref_dates.size(); i < len; i++) {
123+
CHECK_EQ(ind4[i],
124+
doctest::Approx(ind1[i] * weights[0] + ind2[i] * weights[1] + ind3[i] * weights[2])
125+
.epsilon(0.0001));
126+
}
127+
}
128+
129+
//-----------------------------------------------------------------------------
130+
// benchmark
131+
//-----------------------------------------------------------------------------
132+
#if ENABLE_BENCHMARK_TEST
133+
TEST_CASE("test_MF_Weight_benchmark") {
134+
StockManager& sm = StockManager::instance();
135+
int ndays = 3;
136+
IndicatorList src_inds = {MA(ROCR(CLOSE(), ndays)), AMA(ROCR(CLOSE(), ndays)),
137+
EMA(ROCR(CLOSE(), ndays))};
138+
StockList stks = {sm["sh600004"], sm["sh600005"], sm["sz000001"], sm["sz000002"]};
139+
KQuery query = KQuery(0);
140+
Stock ref_stk = sm["sh000001"];
141+
auto ref_k = ref_stk.getKData(query);
142+
auto ref_dates = ref_k.getDatetimeList();
143+
144+
int cycle = 10; // 测试循环次数
145+
146+
{
147+
BENCHMARK_TIME_MSG(test_MF_Weight_benchmark, cycle,
148+
fmt::format("data len: {}", ref_k.size()));
149+
SPEND_TIME_CONTROL(false);
150+
for (int i = 0; i < cycle; i++) {
151+
auto mf = MF_Weight(src_inds, PriceList{0.2, 0.3, 0.5}, stks, query, ref_stk);
152+
auto ic = mf->getIC();
153+
}
154+
}
155+
}
156+
#endif
157+
158+
//-----------------------------------------------------------------------------
159+
// test export
160+
//-----------------------------------------------------------------------------
161+
#if HKU_SUPPORT_SERIALIZATION
162+
/** @par 检测点 */
163+
TEST_CASE("test_MF_Weight_export") {
164+
StockManager& sm = StockManager::instance();
165+
int ndays = 3;
166+
IndicatorList src_inds = {MA(ROCR(CLOSE(), ndays)), AMA(ROCR(CLOSE(), ndays)),
167+
EMA(ROCR(CLOSE(), ndays))};
168+
StockList stks = {sm["sh600004"], sm["sh600005"], sm["sz000001"], sm["sz000002"]};
169+
KQuery query = KQuery(0);
170+
Stock ref_stk = sm["sh000001"];
171+
auto ref_k = ref_stk.getKData(query);
172+
auto ref_dates = ref_k.getDatetimeList();
173+
174+
string filename(sm.tmpdir());
175+
filename += "/MF_Weight.xml";
176+
177+
auto mf1 = MF_Weight(src_inds, PriceList{0.1, 0.3, 0.6}, stks, query, ref_stk);
178+
auto ic1 = mf1->getIC();
179+
{
180+
std::ofstream ofs(filename);
181+
boost::archive::xml_oarchive oa(ofs);
182+
oa << BOOST_SERIALIZATION_NVP(mf1);
183+
}
184+
185+
MFPtr mf2;
186+
{
187+
std::ifstream ifs(filename);
188+
boost::archive::xml_iarchive ia(ifs);
189+
ia >> BOOST_SERIALIZATION_NVP(mf2);
190+
}
191+
192+
CHECK_EQ(mf2->name(), mf1->name());
193+
auto ic2 = mf2->getIC();
194+
CHECK_EQ(ic1.size(), ic2.size());
195+
CHECK_EQ(ic1.discard(), ic2.discard());
196+
CHECK_UNARY(ic1.equal(ic2));
197+
}
198+
#endif /* #if HKU_SUPPORT_SERIALIZATION */
199+
200+
/** @} */

0 commit comments

Comments
 (0)