Skip to content

Commit bbdd85b

Browse files
authored
Add tests to increase code coverage rate for Constraint API (#706)
* Add tests to increase code coverage rate for Constraint API * Update test script based on review comments * Update code based on review comments
1 parent 63975f3 commit bbdd85b

File tree

2 files changed

+141
-78
lines changed

2 files changed

+141
-78
lines changed

tests/cpp/integration/image_constraints.cc

+22-9
Original file line numberDiff line numberDiff line change
@@ -328,18 +328,31 @@ void test_invalid()
328328
auto runtime = legate::Runtime::get_runtime();
329329
auto context = runtime->find_library(library_name);
330330

331-
auto func = runtime->create_store(legate::Shape{10, 10}, legate::int32());
332-
auto range = runtime->create_store(legate::Shape{10, 10}, legate::int64());
331+
auto create_task = [&](auto func, auto range) {
332+
auto task = runtime->create_task(context, static_cast<std::int64_t>(IMAGE_TESTER) + func.dim());
333+
auto part_domain = task.declare_partition();
334+
auto part_range = task.declare_partition();
333335

334-
auto task = runtime->create_task(context, static_cast<std::int64_t>(IMAGE_TESTER) + 1);
335-
auto part_domain = task.declare_partition();
336-
auto part_range = task.declare_partition();
336+
task.add_input(func, part_domain);
337+
task.add_input(range, part_range);
338+
task.add_constraint(legate::image(part_domain, part_range));
337339

338-
task.add_input(func, part_domain);
339-
task.add_input(range, part_range);
340-
task.add_constraint(legate::image(part_domain, part_range));
340+
return task;
341+
};
341342

342-
EXPECT_THROW(runtime->submit(std::move(task)), std::invalid_argument);
343+
{
344+
auto func1 = runtime->create_store(legate::Shape{10, 10}, legate::int32());
345+
auto range1 = runtime->create_store(legate::Shape{10, 10}, legate::int64());
346+
auto task = create_task(func1, range1);
347+
EXPECT_THROW(runtime->submit(std::move(task)), std::invalid_argument);
348+
}
349+
350+
{
351+
auto func2 = runtime->create_store(legate::Shape{4, 4}, legate::point_type(2));
352+
auto range2 = runtime->create_store(legate::Shape{10}, legate::int64());
353+
auto task = create_task(func2, range2.promote(1, 1));
354+
EXPECT_THROW(runtime->submit(std::move(task)), std::runtime_error);
355+
}
343356
}
344357

345358
TEST_P(Valid, 1D)

tests/cpp/unit/constraint.cc

+119-69
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,11 @@
1717

1818
#include <gtest/gtest.h>
1919

20-
namespace unit {
21-
22-
using Variable = DefaultFixture;
23-
using Alignment = DefaultFixture;
24-
using Broadcast = DefaultFixture;
25-
using ImageConstraint = DefaultFixture;
26-
2720
namespace {
2821

29-
constexpr const char library_name[] = "test_constraints";
22+
// NOLINTBEGIN(readability-magic-numbers)
3023

31-
} // namespace
24+
constexpr const char library_name[] = "test_constraints";
3225

3326
// Dummy task to make the runtime think the store is initialized
3427
struct Initializer : public legate::LegateTask<Initializer> {
@@ -37,39 +30,36 @@ struct Initializer : public legate::LegateTask<Initializer> {
3730
static void cpu_variant(legate::TaskContext /*context*/) {}
3831
};
3932

40-
void register_tasks()
41-
{
42-
static bool prepared = false;
43-
if (prepared) {
44-
return;
33+
class Constraint : public DefaultFixture {
34+
public:
35+
void SetUp() override
36+
{
37+
DefaultFixture::SetUp();
38+
auto runtime = legate::Runtime::get_runtime();
39+
auto context = runtime->create_library(library_name);
40+
Initializer::register_variants(context);
4541
}
46-
prepared = true;
47-
auto runtime = legate::Runtime::get_runtime();
48-
auto context = runtime->create_library(library_name);
49-
Initializer::register_variants(context);
50-
}
42+
};
5143

52-
TEST_F(Variable, BasicMethods)
44+
TEST_F(Constraint, Variable)
5345
{
54-
register_tasks();
55-
5646
auto runtime = legate::Runtime::get_runtime();
5747
auto context = runtime->find_library(library_name);
5848
auto task = runtime->create_task(context, Initializer::TASK_ID);
5949

6050
// Test basic properties
6151
auto part = task.declare_partition();
6252
auto part_imp = part.impl();
63-
EXPECT_FALSE(part_imp->closed());
64-
EXPECT_EQ(part_imp->kind(), legate::detail::Expr::Kind::VARIABLE);
65-
EXPECT_EQ(part_imp->as_literal(), nullptr);
66-
EXPECT_EQ(part_imp->as_variable(), part_imp);
67-
EXPECT_TRUE(part_imp->operation() != nullptr);
53+
ASSERT_FALSE(part_imp->closed());
54+
ASSERT_EQ(part_imp->kind(), legate::detail::Expr::Kind::VARIABLE);
55+
ASSERT_EQ(part_imp->as_literal(), nullptr);
56+
ASSERT_EQ(part_imp->as_variable(), part_imp);
57+
ASSERT_TRUE(part_imp->operation() != nullptr);
6858

6959
// Test equal
7060
auto part1(part);
7161
auto part1_imp = part1.impl();
72-
EXPECT_EQ(*part_imp, *part1_imp);
62+
ASSERT_EQ(*part_imp, *part1_imp);
7363
auto part2 = task.declare_partition();
7464
auto part2_imp = part2.impl();
7565

@@ -78,16 +68,14 @@ TEST_F(Variable, BasicMethods)
7868
part_imp->find_partition_symbols(symbols);
7969
part1_imp->find_partition_symbols(symbols);
8070
part2_imp->find_partition_symbols(symbols);
81-
EXPECT_EQ(symbols.size(), 3);
82-
EXPECT_TRUE(std::find(symbols.begin(), symbols.end(), part_imp) != symbols.end());
83-
EXPECT_TRUE(std::find(symbols.begin(), symbols.end(), part1_imp) != symbols.end());
84-
EXPECT_TRUE(std::find(symbols.begin(), symbols.end(), part2_imp) != symbols.end());
71+
ASSERT_EQ(symbols.size(), 3);
72+
ASSERT_TRUE(std::find(symbols.begin(), symbols.end(), part_imp) != symbols.end());
73+
ASSERT_TRUE(std::find(symbols.begin(), symbols.end(), part1_imp) != symbols.end());
74+
ASSERT_TRUE(std::find(symbols.begin(), symbols.end(), part2_imp) != symbols.end());
8575
}
8676

87-
TEST_F(Alignment, BasicMethods)
77+
TEST_F(Constraint, Alignment)
8878
{
89-
register_tasks();
90-
9179
auto runtime = legate::Runtime::get_runtime();
9280
auto context = runtime->find_library(library_name);
9381
auto task = runtime->create_task(context, Initializer::TASK_ID);
@@ -96,51 +84,51 @@ TEST_F(Alignment, BasicMethods)
9684
auto part2 = task.declare_partition();
9785

9886
auto aligment = legate::detail::align(part1.impl(), part2.impl());
99-
EXPECT_EQ(aligment->kind(), legate::detail::Constraint::Kind::ALIGNMENT);
100-
EXPECT_EQ(aligment->lhs(), part1.impl());
101-
EXPECT_EQ(aligment->rhs(), part2.impl());
102-
EXPECT_EQ(aligment->as_alignment(), aligment.get());
103-
EXPECT_EQ(aligment->as_broadcast(), nullptr);
104-
EXPECT_EQ(aligment->as_image_constraint(), nullptr);
105-
EXPECT_FALSE(aligment->is_trivial());
87+
ASSERT_EQ(aligment->kind(), legate::detail::Constraint::Kind::ALIGNMENT);
88+
ASSERT_EQ(aligment->lhs(), part1.impl());
89+
ASSERT_EQ(aligment->rhs(), part2.impl());
90+
ASSERT_EQ(aligment->as_alignment(), aligment.get());
91+
ASSERT_EQ(aligment->as_broadcast(), nullptr);
92+
ASSERT_EQ(aligment->as_image_constraint(), nullptr);
93+
ASSERT_EQ(aligment->as_scale_constraint(), nullptr);
94+
ASSERT_EQ(aligment->as_bloat_constraint(), nullptr);
95+
ASSERT_FALSE(aligment->is_trivial());
10696

10797
// Test find_partition_symbols
10898
std::vector<const legate::detail::Variable*> symbols = {};
10999
aligment->find_partition_symbols(symbols);
110-
EXPECT_EQ(symbols.size(), 2);
111-
EXPECT_TRUE(std::find(symbols.begin(), symbols.end(), part1.impl()) != symbols.end());
112-
EXPECT_TRUE(std::find(symbols.begin(), symbols.end(), part2.impl()) != symbols.end());
100+
ASSERT_EQ(symbols.size(), 2);
101+
ASSERT_TRUE(std::find(symbols.begin(), symbols.end(), part1.impl()) != symbols.end());
102+
ASSERT_TRUE(std::find(symbols.begin(), symbols.end(), part2.impl()) != symbols.end());
113103
}
114104

115-
TEST_F(Broadcast, BasicMethods)
105+
TEST_F(Constraint, Broadcast)
116106
{
117-
register_tasks();
118-
119107
auto runtime = legate::Runtime::get_runtime();
120108
auto context = runtime->find_library(library_name);
121109
auto task = runtime->create_task(context, Initializer::TASK_ID);
122110
auto part1 = task.declare_partition();
123111

124112
auto dims = legate::from_range<std::uint32_t>(3);
125113
auto broadcast = legate::detail::broadcast(part1.impl(), dims);
126-
EXPECT_EQ(broadcast->kind(), legate::detail::Constraint::Kind::BROADCAST);
127-
EXPECT_EQ(broadcast->variable(), part1.impl());
128-
EXPECT_EQ(broadcast->axes(), dims);
129-
EXPECT_EQ(broadcast->as_alignment(), nullptr);
130-
EXPECT_EQ(broadcast->as_broadcast(), broadcast.get());
131-
EXPECT_EQ(broadcast->as_image_constraint(), nullptr);
114+
ASSERT_EQ(broadcast->kind(), legate::detail::Constraint::Kind::BROADCAST);
115+
ASSERT_EQ(broadcast->variable(), part1.impl());
116+
ASSERT_EQ(broadcast->axes(), dims);
117+
ASSERT_EQ(broadcast->as_alignment(), nullptr);
118+
ASSERT_EQ(broadcast->as_broadcast(), broadcast.get());
119+
ASSERT_EQ(broadcast->as_image_constraint(), nullptr);
120+
ASSERT_EQ(broadcast->as_scale_constraint(), nullptr);
121+
ASSERT_EQ(broadcast->as_bloat_constraint(), nullptr);
132122

133123
// Test find_partition_symbols
134124
std::vector<const legate::detail::Variable*> symbols = {};
135125
broadcast->find_partition_symbols(symbols);
136-
EXPECT_EQ(symbols.size(), 1);
137-
EXPECT_TRUE(std::find(symbols.begin(), symbols.end(), part1.impl()) != symbols.end());
126+
ASSERT_EQ(symbols.size(), 1);
127+
ASSERT_TRUE(std::find(symbols.begin(), symbols.end(), part1.impl()) != symbols.end());
138128
}
139129

140-
TEST_F(ImageConstraint, BasicMethods)
130+
TEST_F(Constraint, ImageConstraint)
141131
{
142-
register_tasks();
143-
144132
auto runtime = legate::Runtime::get_runtime();
145133
auto context = runtime->find_library(library_name);
146134
auto task = runtime->create_task(context, Initializer::TASK_ID);
@@ -149,18 +137,80 @@ TEST_F(ImageConstraint, BasicMethods)
149137

150138
auto image_constraint = legate::detail::image(
151139
part_func.impl(), part_range.impl(), legate::ImageComputationHint::NO_HINT);
152-
EXPECT_EQ(image_constraint->kind(), legate::detail::Constraint::Kind::IMAGE);
153-
EXPECT_EQ(image_constraint->var_function(), part_func.impl());
154-
EXPECT_EQ(image_constraint->var_range(), part_range.impl());
155-
EXPECT_EQ(image_constraint->as_alignment(), nullptr);
156-
EXPECT_EQ(image_constraint->as_broadcast(), nullptr);
157-
EXPECT_EQ(image_constraint->as_image_constraint(), image_constraint.get());
140+
ASSERT_EQ(image_constraint->kind(), legate::detail::Constraint::Kind::IMAGE);
141+
ASSERT_EQ(image_constraint->var_function(), part_func.impl());
142+
ASSERT_EQ(image_constraint->var_range(), part_range.impl());
143+
ASSERT_EQ(image_constraint->as_alignment(), nullptr);
144+
ASSERT_EQ(image_constraint->as_broadcast(), nullptr);
145+
ASSERT_EQ(image_constraint->as_image_constraint(), image_constraint.get());
146+
ASSERT_EQ(image_constraint->as_scale_constraint(), nullptr);
147+
ASSERT_EQ(image_constraint->as_bloat_constraint(), nullptr);
158148

159149
// Test find_partition_symbols
160150
std::vector<const legate::detail::Variable*> symbols = {};
161151
image_constraint->find_partition_symbols(symbols);
162-
EXPECT_EQ(symbols.size(), 2);
163-
EXPECT_TRUE(std::find(symbols.begin(), symbols.end(), part_func.impl()) != symbols.end());
164-
EXPECT_TRUE(std::find(symbols.begin(), symbols.end(), part_range.impl()) != symbols.end());
152+
ASSERT_EQ(symbols.size(), 2);
153+
ASSERT_TRUE(std::find(symbols.begin(), symbols.end(), part_func.impl()) != symbols.end());
154+
ASSERT_TRUE(std::find(symbols.begin(), symbols.end(), part_range.impl()) != symbols.end());
155+
}
156+
157+
TEST_F(Constraint, ScaleConstraint)
158+
{
159+
auto runtime = legate::Runtime::get_runtime();
160+
auto context = runtime->find_library(library_name);
161+
auto task = runtime->create_task(context, Initializer::TASK_ID);
162+
auto smaller = runtime->create_store({3}, legate::int64());
163+
auto bigger = runtime->create_store({5}, legate::int64());
164+
auto part_smaller = task.add_output(smaller);
165+
auto part_bigger = task.add_output(bigger);
166+
167+
auto scale_constraint = legate::detail::scale({1}, part_smaller.impl(), part_bigger.impl());
168+
ASSERT_EQ(scale_constraint->kind(), legate::detail::Constraint::Kind::SCALE);
169+
ASSERT_EQ(scale_constraint->var_smaller(), part_smaller.impl());
170+
ASSERT_EQ(scale_constraint->var_bigger(), part_bigger.impl());
171+
ASSERT_EQ(scale_constraint->as_alignment(), nullptr);
172+
ASSERT_EQ(scale_constraint->as_broadcast(), nullptr);
173+
ASSERT_EQ(scale_constraint->as_image_constraint(), nullptr);
174+
ASSERT_EQ(scale_constraint->as_scale_constraint(), scale_constraint.get());
175+
ASSERT_EQ(scale_constraint->as_bloat_constraint(), nullptr);
176+
177+
// Test find_partition_symbols
178+
std::vector<const legate::detail::Variable*> symbols = {};
179+
scale_constraint->find_partition_symbols(symbols);
180+
ASSERT_EQ(symbols.size(), 2);
181+
ASSERT_TRUE(std::find(symbols.begin(), symbols.end(), part_smaller.impl()) != symbols.end());
182+
ASSERT_TRUE(std::find(symbols.begin(), symbols.end(), part_bigger.impl()) != symbols.end());
165183
}
166-
} // namespace unit
184+
185+
TEST_F(Constraint, BloatConstraint)
186+
{
187+
auto runtime = legate::Runtime::get_runtime();
188+
auto context = runtime->find_library(library_name);
189+
auto task = runtime->create_task(context, Initializer::TASK_ID);
190+
auto source = runtime->create_store({5}, legate::int64());
191+
auto bloated = runtime->create_store({5}, legate::int64());
192+
runtime->issue_fill(source, legate::Scalar(std::int64_t{0}));
193+
runtime->issue_fill(bloated, legate::Scalar(std::int64_t{0}));
194+
auto part_source = task.add_input(source);
195+
auto part_bloated = task.add_input(bloated);
196+
197+
auto bloat_constraint = legate::detail::bloat(part_source.impl(), part_bloated.impl(), {1}, {3});
198+
ASSERT_EQ(bloat_constraint->kind(), legate::detail::Constraint::Kind::BLOAT);
199+
ASSERT_EQ(bloat_constraint->var_source(), part_source.impl());
200+
ASSERT_EQ(bloat_constraint->var_bloat(), part_bloated.impl());
201+
ASSERT_EQ(bloat_constraint->as_alignment(), nullptr);
202+
ASSERT_EQ(bloat_constraint->as_broadcast(), nullptr);
203+
ASSERT_EQ(bloat_constraint->as_image_constraint(), nullptr);
204+
ASSERT_EQ(bloat_constraint->as_scale_constraint(), nullptr);
205+
ASSERT_EQ(bloat_constraint->as_bloat_constraint(), bloat_constraint.get());
206+
207+
// Test find_partition_symbols
208+
std::vector<const legate::detail::Variable*> symbols = {};
209+
bloat_constraint->find_partition_symbols(symbols);
210+
ASSERT_EQ(symbols.size(), 2);
211+
ASSERT_TRUE(std::find(symbols.begin(), symbols.end(), part_source.impl()) != symbols.end());
212+
ASSERT_TRUE(std::find(symbols.begin(), symbols.end(), part_bloated.impl()) != symbols.end());
213+
}
214+
215+
// NOLINTEND(readability-magic-numbers)
216+
} // namespace

0 commit comments

Comments
 (0)