From 711dc2764b76d366dfc12fad8b7cc92578a536a4 Mon Sep 17 00:00:00 2001 From: Wonchan Lee Date: Wed, 24 May 2023 22:14:19 -0700 Subject: [PATCH 1/2] Add specialized constructors and safety checks to legate::Scalar --- src/core/data/scalar.cc | 17 ++++++++++++++++- src/core/data/scalar.h | 20 ++++++++++++++++++++ src/core/data/scalar.inl | 36 ++++++++++++++++++++++++++++++++++-- 3 files changed, 70 insertions(+), 3 deletions(-) diff --git a/src/core/data/scalar.cc b/src/core/data/scalar.cc index a0246d00f..2ab75fff1 100644 --- a/src/core/data/scalar.cc +++ b/src/core/data/scalar.cc @@ -15,16 +15,31 @@ */ #include "core/data/scalar.h" -#include "core/utilities/dispatch.h" namespace legate { Scalar::Scalar(const Scalar& other) : own_(other.own_), type_(other.type_->clone()) { copy(other); } +Scalar::Scalar(Scalar&& other) : own_(other.own_), type_(std::move(other.type_)), data_(other.data_) +{ + other.own_ = false; + other.type_ = nullptr; + other.data_ = nullptr; +} + Scalar::Scalar(std::unique_ptr type, const void* data) : type_(std::move(type)), data_(data) { } +Scalar::Scalar(const std::string& string) : own_(true), type_(string_type()) +{ + auto data_size = sizeof(char) * string.size(); + auto buffer = malloc(sizeof(uint32_t) + data_size); + *static_cast(buffer) = string.size(); + memcpy(static_cast(buffer) + sizeof(uint32_t), string.data(), data_size); + data_ = buffer; +} + Scalar::~Scalar() { if (own_) diff --git a/src/core/data/scalar.h b/src/core/data/scalar.h index f06cc41a6..2e2cd785c 100644 --- a/src/core/data/scalar.h +++ b/src/core/data/scalar.h @@ -45,6 +45,8 @@ class Scalar { public: Scalar() = default; Scalar(const Scalar& other); + Scalar(Scalar&& other); + /** * @brief Creates a shared `Scalar` with an existing allocation. The caller is responsible * for passing in a sufficiently big allocation. @@ -65,6 +67,24 @@ class Scalar { */ template Scalar(T value); + /** + * @brief Creates an owned scalar of a specified type from a scalar value + * + * @tparam T The scalar type to wrap + * + * @param type A type of the scalar + * @param value A scalar value to create a `Scalar` with + */ + template + Scalar(T value, std::unique_ptr type); + /** + * @brief Creates an owned scalar from a string. The value from the + * original string will be copied. + * + * @param string A string to create a `Scalar` with + */ + Scalar(const std::string& string); + /** * @brief Creates an owned scalar from a tuple of scalars. The values in the input vector * will be copied. diff --git a/src/core/data/scalar.inl b/src/core/data/scalar.inl index 7d61d9306..8f201d5ba 100644 --- a/src/core/data/scalar.inl +++ b/src/core/data/scalar.inl @@ -19,6 +19,22 @@ namespace legate { template Scalar::Scalar(T value) : own_(true), type_(primitive_type(legate_type_code_of)) { + static_assert(legate_type_code_of != Type::Code::FIXED_ARRAY); + static_assert(legate_type_code_of != Type::Code::STRUCT); + static_assert(legate_type_code_of != Type::Code::STRING); + static_assert(legate_type_code_of != Type::Code::INVALID); + auto buffer = malloc(sizeof(T)); + memcpy(buffer, &value, sizeof(T)); + data_ = buffer; +} + +template +Scalar::Scalar(T value, std::unique_ptr type) : own_(true), type_(std::move(type)) +{ + if (type_->code == Type::Code::INVALID) + throw std::invalid_argument("Invalid type cannot be used"); + if (type_->size() != sizeof(T)) + throw std::invalid_argument("Size of the value doesn't match with the type"); auto buffer = malloc(sizeof(T)); memcpy(buffer, &value, sizeof(T)); data_ = buffer; @@ -37,12 +53,17 @@ Scalar::Scalar(const std::vector& values) template VAL Scalar::value() const { + if (sizeof(VAL) != type_->size()) + throw std::invalid_argument("Size of the scalar is " + std::to_string(type_->size()) + + ", but the requested type has size " + std::to_string(sizeof(VAL))); return *static_cast(data_); } template <> inline std::string Scalar::value() const { + if (type_->code != Type::Code::STRING) + throw std::invalid_argument("Type of the scalar is not string"); // Getting a span of a temporary scalar is illegal in general, // but we know this is safe as the span's pointer is held by this object. auto len = *static_cast(data_); @@ -55,10 +76,21 @@ template Span Scalar::values() const { if (type_->code == Type::Code::FIXED_ARRAY) { - auto size = static_cast(type_.get())->num_elements(); + auto arr_type = static_cast(type_.get()); + const auto& elem_type = arr_type->element_type(); + if (sizeof(VAL) != elem_type.size()) + throw std::invalid_argument( + "The scalar's element type has size " + std::to_string(elem_type.size()) + + ", but the requested element type has size " + std::to_string(sizeof(VAL))); + auto size = arr_type->num_elements(); return Span(reinterpret_cast(data_), size); - } else + } else { + if (sizeof(VAL) != type_->size()) + throw std::invalid_argument("Size of the scalar is " + std::to_string(type_->size()) + + ", but the requested element type has size " + + std::to_string(sizeof(VAL))); return Span(static_cast(data_), 1); + } } template <> From 7f6f7be257ae1d4acc55bf6c3aa9006ee9b8f094 Mon Sep 17 00:00:00 2001 From: Wonchan Lee Date: Thu, 25 May 2023 17:09:13 -0700 Subject: [PATCH 2/2] Update src/core/data/scalar.h Co-authored-by: Manolis Papadakis --- src/core/data/scalar.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/data/scalar.h b/src/core/data/scalar.h index 2e2cd785c..62e48bdb9 100644 --- a/src/core/data/scalar.h +++ b/src/core/data/scalar.h @@ -72,7 +72,7 @@ class Scalar { * * @tparam T The scalar type to wrap * - * @param type A type of the scalar + * @param type The type of the scalar * @param value A scalar value to create a `Scalar` with */ template