Skip to content

Commit

Permalink
Add specialized constructors and safety checks to legate::Scalar (#736
Browse files Browse the repository at this point in the history
)

* Add specialized constructors and safety checks to legate::Scalar

* Update src/core/data/scalar.h

Co-authored-by: Manolis Papadakis <manopapad@gmail.com>

---------

Co-authored-by: Manolis Papadakis <manopapad@gmail.com>
  • Loading branch information
magnatelee and manopapad authored May 30, 2023
1 parent 2680264 commit fff8f3c
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 3 deletions.
17 changes: 16 additions & 1 deletion src/core/data/scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,31 @@
*/

#include "core/data/scalar.h"
#include "core/utilities/dispatch.h"

namespace legate {

Scalar::Scalar(const Scalar& other) : own_(other.own_), type_(other.type_->clone()) { copy(other); }

Scalar::Scalar(Scalar&& other) : own_(other.own_), type_(std::move(other.type_)), data_(other.data_)
{
other.own_ = false;
other.type_ = nullptr;
other.data_ = nullptr;
}

Scalar::Scalar(std::unique_ptr<Type> type, const void* data) : type_(std::move(type)), data_(data)
{
}

Scalar::Scalar(const std::string& string) : own_(true), type_(string_type())
{
auto data_size = sizeof(char) * string.size();
auto buffer = malloc(sizeof(uint32_t) + data_size);
*static_cast<uint32_t*>(buffer) = string.size();
memcpy(static_cast<int8_t*>(buffer) + sizeof(uint32_t), string.data(), data_size);
data_ = buffer;
}

Scalar::~Scalar()
{
if (own_)
Expand Down
20 changes: 20 additions & 0 deletions src/core/data/scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class Scalar {
public:
Scalar() = default;
Scalar(const Scalar& other);
Scalar(Scalar&& other);

/**
* @brief Creates a shared `Scalar` with an existing allocation. The caller is responsible
* for passing in a sufficiently big allocation.
Expand All @@ -65,6 +67,24 @@ class Scalar {
*/
template <typename T>
Scalar(T value);
/**
* @brief Creates an owned scalar of a specified type from a scalar value
*
* @tparam T The scalar type to wrap
*
* @param type The type of the scalar
* @param value A scalar value to create a `Scalar` with
*/
template <typename T>
Scalar(T value, std::unique_ptr<Type> type);
/**
* @brief Creates an owned scalar from a string. The value from the
* original string will be copied.
*
* @param string A string to create a `Scalar` with
*/
Scalar(const std::string& string);

/**
* @brief Creates an owned scalar from a tuple of scalars. The values in the input vector
* will be copied.
Expand Down
36 changes: 34 additions & 2 deletions src/core/data/scalar.inl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@ namespace legate {
template <typename T>
Scalar::Scalar(T value) : own_(true), type_(primitive_type(legate_type_code_of<T>))
{
static_assert(legate_type_code_of<T> != Type::Code::FIXED_ARRAY);
static_assert(legate_type_code_of<T> != Type::Code::STRUCT);
static_assert(legate_type_code_of<T> != Type::Code::STRING);
static_assert(legate_type_code_of<T> != Type::Code::INVALID);
auto buffer = malloc(sizeof(T));
memcpy(buffer, &value, sizeof(T));
data_ = buffer;
}

template <typename T>
Scalar::Scalar(T value, std::unique_ptr<Type> type) : own_(true), type_(std::move(type))
{
if (type_->code == Type::Code::INVALID)
throw std::invalid_argument("Invalid type cannot be used");
if (type_->size() != sizeof(T))
throw std::invalid_argument("Size of the value doesn't match with the type");
auto buffer = malloc(sizeof(T));
memcpy(buffer, &value, sizeof(T));
data_ = buffer;
Expand All @@ -37,12 +53,17 @@ Scalar::Scalar(const std::vector<T>& values)
template <typename VAL>
VAL Scalar::value() const
{
if (sizeof(VAL) != type_->size())
throw std::invalid_argument("Size of the scalar is " + std::to_string(type_->size()) +
", but the requested type has size " + std::to_string(sizeof(VAL)));
return *static_cast<const VAL*>(data_);
}

template <>
inline std::string Scalar::value() const
{
if (type_->code != Type::Code::STRING)
throw std::invalid_argument("Type of the scalar is not string");
// Getting a span of a temporary scalar is illegal in general,
// but we know this is safe as the span's pointer is held by this object.
auto len = *static_cast<const uint32_t*>(data_);
Expand All @@ -55,10 +76,21 @@ template <typename VAL>
Span<const VAL> Scalar::values() const
{
if (type_->code == Type::Code::FIXED_ARRAY) {
auto size = static_cast<const FixedArrayType*>(type_.get())->num_elements();
auto arr_type = static_cast<const FixedArrayType*>(type_.get());
const auto& elem_type = arr_type->element_type();
if (sizeof(VAL) != elem_type.size())
throw std::invalid_argument(
"The scalar's element type has size " + std::to_string(elem_type.size()) +
", but the requested element type has size " + std::to_string(sizeof(VAL)));
auto size = arr_type->num_elements();
return Span<const VAL>(reinterpret_cast<const VAL*>(data_), size);
} else
} else {
if (sizeof(VAL) != type_->size())
throw std::invalid_argument("Size of the scalar is " + std::to_string(type_->size()) +
", but the requested element type has size " +
std::to_string(sizeof(VAL)));
return Span<const VAL>(static_cast<const VAL*>(data_), 1);
}
}

template <>
Expand Down

0 comments on commit fff8f3c

Please sign in to comment.