17
17
18
18
#include < gtest/gtest.h>
19
19
20
- namespace unit {
21
-
22
- using Variable = DefaultFixture;
23
- using Alignment = DefaultFixture;
24
- using Broadcast = DefaultFixture;
25
- using ImageConstraint = DefaultFixture;
26
-
27
20
namespace {
28
21
29
- constexpr const char library_name[] = " test_constraints " ;
22
+ // NOLINTBEGIN(readability-magic-numbers)
30
23
31
- } // namespace
24
+ constexpr const char library_name[] = " test_constraints " ;
32
25
33
26
// Dummy task to make the runtime think the store is initialized
34
27
struct Initializer : public legate ::LegateTask<Initializer> {
@@ -37,39 +30,36 @@ struct Initializer : public legate::LegateTask<Initializer> {
37
30
static void cpu_variant (legate::TaskContext /* context*/ ) {}
38
31
};
39
32
40
- void register_tasks ()
41
- {
42
- static bool prepared = false ;
43
- if (prepared) {
44
- return ;
33
+ class Constraint : public DefaultFixture {
34
+ public:
35
+ void SetUp () override
36
+ {
37
+ DefaultFixture::SetUp ();
38
+ auto runtime = legate::Runtime::get_runtime ();
39
+ auto context = runtime->create_library (library_name);
40
+ Initializer::register_variants (context);
45
41
}
46
- prepared = true ;
47
- auto runtime = legate::Runtime::get_runtime ();
48
- auto context = runtime->create_library (library_name);
49
- Initializer::register_variants (context);
50
- }
42
+ };
51
43
52
- TEST_F (Variable, BasicMethods )
44
+ TEST_F (Constraint, Variable )
53
45
{
54
- register_tasks ();
55
-
56
46
auto runtime = legate::Runtime::get_runtime ();
57
47
auto context = runtime->find_library (library_name);
58
48
auto task = runtime->create_task (context, Initializer::TASK_ID);
59
49
60
50
// Test basic properties
61
51
auto part = task.declare_partition ();
62
52
auto part_imp = part.impl ();
63
- EXPECT_FALSE (part_imp->closed ());
64
- EXPECT_EQ (part_imp->kind (), legate::detail::Expr::Kind::VARIABLE);
65
- EXPECT_EQ (part_imp->as_literal (), nullptr );
66
- EXPECT_EQ (part_imp->as_variable (), part_imp);
67
- EXPECT_TRUE (part_imp->operation () != nullptr );
53
+ ASSERT_FALSE (part_imp->closed ());
54
+ ASSERT_EQ (part_imp->kind (), legate::detail::Expr::Kind::VARIABLE);
55
+ ASSERT_EQ (part_imp->as_literal (), nullptr );
56
+ ASSERT_EQ (part_imp->as_variable (), part_imp);
57
+ ASSERT_TRUE (part_imp->operation () != nullptr );
68
58
69
59
// Test equal
70
60
auto part1 (part);
71
61
auto part1_imp = part1.impl ();
72
- EXPECT_EQ (*part_imp, *part1_imp);
62
+ ASSERT_EQ (*part_imp, *part1_imp);
73
63
auto part2 = task.declare_partition ();
74
64
auto part2_imp = part2.impl ();
75
65
@@ -78,16 +68,14 @@ TEST_F(Variable, BasicMethods)
78
68
part_imp->find_partition_symbols (symbols);
79
69
part1_imp->find_partition_symbols (symbols);
80
70
part2_imp->find_partition_symbols (symbols);
81
- EXPECT_EQ (symbols.size (), 3 );
82
- EXPECT_TRUE (std::find (symbols.begin (), symbols.end (), part_imp) != symbols.end ());
83
- EXPECT_TRUE (std::find (symbols.begin (), symbols.end (), part1_imp) != symbols.end ());
84
- EXPECT_TRUE (std::find (symbols.begin (), symbols.end (), part2_imp) != symbols.end ());
71
+ ASSERT_EQ (symbols.size (), 3 );
72
+ ASSERT_TRUE (std::find (symbols.begin (), symbols.end (), part_imp) != symbols.end ());
73
+ ASSERT_TRUE (std::find (symbols.begin (), symbols.end (), part1_imp) != symbols.end ());
74
+ ASSERT_TRUE (std::find (symbols.begin (), symbols.end (), part2_imp) != symbols.end ());
85
75
}
86
76
87
- TEST_F (Alignment, BasicMethods )
77
+ TEST_F (Constraint, Alignment )
88
78
{
89
- register_tasks ();
90
-
91
79
auto runtime = legate::Runtime::get_runtime ();
92
80
auto context = runtime->find_library (library_name);
93
81
auto task = runtime->create_task (context, Initializer::TASK_ID);
@@ -96,51 +84,51 @@ TEST_F(Alignment, BasicMethods)
96
84
auto part2 = task.declare_partition ();
97
85
98
86
auto aligment = legate::detail::align (part1.impl (), part2.impl ());
99
- EXPECT_EQ (aligment->kind (), legate::detail::Constraint::Kind::ALIGNMENT);
100
- EXPECT_EQ (aligment->lhs (), part1.impl ());
101
- EXPECT_EQ (aligment->rhs (), part2.impl ());
102
- EXPECT_EQ (aligment->as_alignment (), aligment.get ());
103
- EXPECT_EQ (aligment->as_broadcast (), nullptr );
104
- EXPECT_EQ (aligment->as_image_constraint (), nullptr );
105
- EXPECT_FALSE (aligment->is_trivial ());
87
+ ASSERT_EQ (aligment->kind (), legate::detail::Constraint::Kind::ALIGNMENT);
88
+ ASSERT_EQ (aligment->lhs (), part1.impl ());
89
+ ASSERT_EQ (aligment->rhs (), part2.impl ());
90
+ ASSERT_EQ (aligment->as_alignment (), aligment.get ());
91
+ ASSERT_EQ (aligment->as_broadcast (), nullptr );
92
+ ASSERT_EQ (aligment->as_image_constraint (), nullptr );
93
+ ASSERT_EQ (aligment->as_scale_constraint (), nullptr );
94
+ ASSERT_EQ (aligment->as_bloat_constraint (), nullptr );
95
+ ASSERT_FALSE (aligment->is_trivial ());
106
96
107
97
// Test find_partition_symbols
108
98
std::vector<const legate::detail::Variable*> symbols = {};
109
99
aligment->find_partition_symbols (symbols);
110
- EXPECT_EQ (symbols.size (), 2 );
111
- EXPECT_TRUE (std::find (symbols.begin (), symbols.end (), part1.impl ()) != symbols.end ());
112
- EXPECT_TRUE (std::find (symbols.begin (), symbols.end (), part2.impl ()) != symbols.end ());
100
+ ASSERT_EQ (symbols.size (), 2 );
101
+ ASSERT_TRUE (std::find (symbols.begin (), symbols.end (), part1.impl ()) != symbols.end ());
102
+ ASSERT_TRUE (std::find (symbols.begin (), symbols.end (), part2.impl ()) != symbols.end ());
113
103
}
114
104
115
- TEST_F (Broadcast, BasicMethods )
105
+ TEST_F (Constraint, Broadcast )
116
106
{
117
- register_tasks ();
118
-
119
107
auto runtime = legate::Runtime::get_runtime ();
120
108
auto context = runtime->find_library (library_name);
121
109
auto task = runtime->create_task (context, Initializer::TASK_ID);
122
110
auto part1 = task.declare_partition ();
123
111
124
112
auto dims = legate::from_range<std::uint32_t >(3 );
125
113
auto broadcast = legate::detail::broadcast (part1.impl (), dims);
126
- EXPECT_EQ (broadcast->kind (), legate::detail::Constraint::Kind::BROADCAST);
127
- EXPECT_EQ (broadcast->variable (), part1.impl ());
128
- EXPECT_EQ (broadcast->axes (), dims);
129
- EXPECT_EQ (broadcast->as_alignment (), nullptr );
130
- EXPECT_EQ (broadcast->as_broadcast (), broadcast.get ());
131
- EXPECT_EQ (broadcast->as_image_constraint (), nullptr );
114
+ ASSERT_EQ (broadcast->kind (), legate::detail::Constraint::Kind::BROADCAST);
115
+ ASSERT_EQ (broadcast->variable (), part1.impl ());
116
+ ASSERT_EQ (broadcast->axes (), dims);
117
+ ASSERT_EQ (broadcast->as_alignment (), nullptr );
118
+ ASSERT_EQ (broadcast->as_broadcast (), broadcast.get ());
119
+ ASSERT_EQ (broadcast->as_image_constraint (), nullptr );
120
+ ASSERT_EQ (broadcast->as_scale_constraint (), nullptr );
121
+ ASSERT_EQ (broadcast->as_bloat_constraint (), nullptr );
132
122
133
123
// Test find_partition_symbols
134
124
std::vector<const legate::detail::Variable*> symbols = {};
135
125
broadcast->find_partition_symbols (symbols);
136
- EXPECT_EQ (symbols.size (), 1 );
137
- EXPECT_TRUE (std::find (symbols.begin (), symbols.end (), part1.impl ()) != symbols.end ());
126
+ ASSERT_EQ (symbols.size (), 1 );
127
+ ASSERT_TRUE (std::find (symbols.begin (), symbols.end (), part1.impl ()) != symbols.end ());
138
128
}
139
129
140
- TEST_F (ImageConstraint, BasicMethods )
130
+ TEST_F (Constraint, ImageConstraint )
141
131
{
142
- register_tasks ();
143
-
144
132
auto runtime = legate::Runtime::get_runtime ();
145
133
auto context = runtime->find_library (library_name);
146
134
auto task = runtime->create_task (context, Initializer::TASK_ID);
@@ -149,18 +137,80 @@ TEST_F(ImageConstraint, BasicMethods)
149
137
150
138
auto image_constraint = legate::detail::image (
151
139
part_func.impl (), part_range.impl (), legate::ImageComputationHint::NO_HINT);
152
- EXPECT_EQ (image_constraint->kind (), legate::detail::Constraint::Kind::IMAGE);
153
- EXPECT_EQ (image_constraint->var_function (), part_func.impl ());
154
- EXPECT_EQ (image_constraint->var_range (), part_range.impl ());
155
- EXPECT_EQ (image_constraint->as_alignment (), nullptr );
156
- EXPECT_EQ (image_constraint->as_broadcast (), nullptr );
157
- EXPECT_EQ (image_constraint->as_image_constraint (), image_constraint.get ());
140
+ ASSERT_EQ (image_constraint->kind (), legate::detail::Constraint::Kind::IMAGE);
141
+ ASSERT_EQ (image_constraint->var_function (), part_func.impl ());
142
+ ASSERT_EQ (image_constraint->var_range (), part_range.impl ());
143
+ ASSERT_EQ (image_constraint->as_alignment (), nullptr );
144
+ ASSERT_EQ (image_constraint->as_broadcast (), nullptr );
145
+ ASSERT_EQ (image_constraint->as_image_constraint (), image_constraint.get ());
146
+ ASSERT_EQ (image_constraint->as_scale_constraint (), nullptr );
147
+ ASSERT_EQ (image_constraint->as_bloat_constraint (), nullptr );
158
148
159
149
// Test find_partition_symbols
160
150
std::vector<const legate::detail::Variable*> symbols = {};
161
151
image_constraint->find_partition_symbols (symbols);
162
- EXPECT_EQ (symbols.size (), 2 );
163
- EXPECT_TRUE (std::find (symbols.begin (), symbols.end (), part_func.impl ()) != symbols.end ());
164
- EXPECT_TRUE (std::find (symbols.begin (), symbols.end (), part_range.impl ()) != symbols.end ());
152
+ ASSERT_EQ (symbols.size (), 2 );
153
+ ASSERT_TRUE (std::find (symbols.begin (), symbols.end (), part_func.impl ()) != symbols.end ());
154
+ ASSERT_TRUE (std::find (symbols.begin (), symbols.end (), part_range.impl ()) != symbols.end ());
155
+ }
156
+
157
+ TEST_F (Constraint, ScaleConstraint)
158
+ {
159
+ auto runtime = legate::Runtime::get_runtime ();
160
+ auto context = runtime->find_library (library_name);
161
+ auto task = runtime->create_task (context, Initializer::TASK_ID);
162
+ auto smaller = runtime->create_store ({3 }, legate::int64 ());
163
+ auto bigger = runtime->create_store ({5 }, legate::int64 ());
164
+ auto part_smaller = task.add_output (smaller);
165
+ auto part_bigger = task.add_output (bigger);
166
+
167
+ auto scale_constraint = legate::detail::scale ({1 }, part_smaller.impl (), part_bigger.impl ());
168
+ ASSERT_EQ (scale_constraint->kind (), legate::detail::Constraint::Kind::SCALE);
169
+ ASSERT_EQ (scale_constraint->var_smaller (), part_smaller.impl ());
170
+ ASSERT_EQ (scale_constraint->var_bigger (), part_bigger.impl ());
171
+ ASSERT_EQ (scale_constraint->as_alignment (), nullptr );
172
+ ASSERT_EQ (scale_constraint->as_broadcast (), nullptr );
173
+ ASSERT_EQ (scale_constraint->as_image_constraint (), nullptr );
174
+ ASSERT_EQ (scale_constraint->as_scale_constraint (), scale_constraint.get ());
175
+ ASSERT_EQ (scale_constraint->as_bloat_constraint (), nullptr );
176
+
177
+ // Test find_partition_symbols
178
+ std::vector<const legate::detail::Variable*> symbols = {};
179
+ scale_constraint->find_partition_symbols (symbols);
180
+ ASSERT_EQ (symbols.size (), 2 );
181
+ ASSERT_TRUE (std::find (symbols.begin (), symbols.end (), part_smaller.impl ()) != symbols.end ());
182
+ ASSERT_TRUE (std::find (symbols.begin (), symbols.end (), part_bigger.impl ()) != symbols.end ());
165
183
}
166
- } // namespace unit
184
+
185
+ TEST_F (Constraint, BloatConstraint)
186
+ {
187
+ auto runtime = legate::Runtime::get_runtime ();
188
+ auto context = runtime->find_library (library_name);
189
+ auto task = runtime->create_task (context, Initializer::TASK_ID);
190
+ auto source = runtime->create_store ({5 }, legate::int64 ());
191
+ auto bloated = runtime->create_store ({5 }, legate::int64 ());
192
+ runtime->issue_fill (source, legate::Scalar (std::int64_t {0 }));
193
+ runtime->issue_fill (bloated, legate::Scalar (std::int64_t {0 }));
194
+ auto part_source = task.add_input (source);
195
+ auto part_bloated = task.add_input (bloated);
196
+
197
+ auto bloat_constraint = legate::detail::bloat (part_source.impl (), part_bloated.impl (), {1 }, {3 });
198
+ ASSERT_EQ (bloat_constraint->kind (), legate::detail::Constraint::Kind::BLOAT);
199
+ ASSERT_EQ (bloat_constraint->var_source (), part_source.impl ());
200
+ ASSERT_EQ (bloat_constraint->var_bloat (), part_bloated.impl ());
201
+ ASSERT_EQ (bloat_constraint->as_alignment (), nullptr );
202
+ ASSERT_EQ (bloat_constraint->as_broadcast (), nullptr );
203
+ ASSERT_EQ (bloat_constraint->as_image_constraint (), nullptr );
204
+ ASSERT_EQ (bloat_constraint->as_scale_constraint (), nullptr );
205
+ ASSERT_EQ (bloat_constraint->as_bloat_constraint (), bloat_constraint.get ());
206
+
207
+ // Test find_partition_symbols
208
+ std::vector<const legate::detail::Variable*> symbols = {};
209
+ bloat_constraint->find_partition_symbols (symbols);
210
+ ASSERT_EQ (symbols.size (), 2 );
211
+ ASSERT_TRUE (std::find (symbols.begin (), symbols.end (), part_source.impl ()) != symbols.end ());
212
+ ASSERT_TRUE (std::find (symbols.begin (), symbols.end (), part_bloated.impl ()) != symbols.end ());
213
+ }
214
+
215
+ // NOLINTEND(readability-magic-numbers)
216
+ } // namespace
0 commit comments