Improve checks on transformer cache #881

Merged 5 commits on Jan 24, 2022

1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -30,6 +30,7 @@ set(MARIAN_SOURCES
common/filesystem.cpp
common/file_stream.cpp
common/file_utils.cpp
common/hash.cpp
common/signal_handling.cpp
common/types.cpp

12 changes: 12 additions & 0 deletions src/common/hash.cpp
@@ -0,0 +1,12 @@
#include <string>

#include "hash.h"
#include "common/shape.h"

namespace std {
size_t hash<pair<string, marian::Shape>>::operator()(pair<string, marian::Shape> const& k) const {
size_t seed = hash<string>{}(k.first);
marian::util::hash_combine(seed, k.second.hash());
return seed;
}
} // namespace std
24 changes: 19 additions & 5 deletions src/common/hash.h
@@ -7,16 +7,18 @@ namespace util {

template <class T> using hash = std::hash<T>;

// This combinator is based on boost::hash_combine, but uses
// std::hash as the hash implementation. Used as a drop-in
// replacement for boost::hash_combine.
/**
* Combine hash values.
* This combinator is based on boost::hash_combine, but uses std::hash as the hash implementation.
* Used as a drop-in replacement for boost::hash_combine.
*/
template <class T, class HashType = std::size_t>
inline void hash_combine(HashType& seed, T const& v) {
hash<T> hasher;
seed ^= static_cast<HashType>(hasher(v)) + 0x9e3779b9 + (seed<<6) + (seed>>2);
}

// Hash a whole chunk of memory, mostly used for diagnostics
/** Hash a whole chunk of memory. */
template <class T, class HashType = std::size_t>
inline HashType hashMem(const T* beg, size_t len) {
HashType seed = 0;
@@ -25,5 +27,17 @@ inline HashType hashMem(const T* beg, size_t len) {
return seed;
}

}
} // namespace util

struct Shape; // Forward declaration
} // namespace marian

namespace std {
/**
* std::hash specialization for the string-shape pair used as a cache key in transformer.h.
*/
template <>
struct hash<pair<string, marian::Shape>> {
size_t operator()(pair<string, marian::Shape> const& k) const;
};
}
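
For context: the `std::hash` specialization above is what lets an `std::unordered_map` be keyed on a (name, shape) pair in the first place, since the standard library provides no hash for `std::pair`. The sketch below mirrors the pattern in a self-contained form, with a hypothetical `ToyShape` standing in for `marian::Shape` and an `int` payload standing in for `Expr`; it is illustrative only, not code from this PR.

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

// Boost-style combiner over std::hash, mirroring util::hash_combine above.
template <class T>
inline void hash_combine(std::size_t& seed, T const& v) {
  seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

// Hypothetical stand-in for marian::Shape; only hash() and operator== matter here.
struct ToyShape {
  int rows, cols;
  std::size_t hash() const {
    std::size_t seed = std::hash<int>{}(rows);
    hash_combine(seed, cols);
    return seed;
  }
  bool operator==(ToyShape const& o) const { return rows == o.rows && cols == o.cols; }
};

// Same structure as the PR's specialization, adapted to the stand-in type.
namespace std {
template <>
struct hash<pair<string, ToyShape>> {
  size_t operator()(pair<string, ToyShape> const& k) const {
    size_t seed = hash<string>{}(k.first);
    ::hash_combine(seed, k.second.hash());
    return seed;
  }
};
} // namespace std

int main() {
  // With the specialization in scope, a map keyed on (name, shape) pairs just works.
  std::unordered_map<std::pair<std::string, ToyShape>, int> cache;
  cache[{"encoder_keys", ToyShape{512, 64}}] = 1;
  std::cout << cache.count({"encoder_keys", ToyShape{512, 64}}) << "\n"; // 1: exact match
  std::cout << cache.count({"encoder_keys", ToyShape{64, 512}}) << "\n"; // 0: same element count, different shape
}
```

The last two lines show the point of the new key: two shapes with the same number of elements map to different cache entries, which is exactly the distinction the old elements()-based check in transformer.h could not make.
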
69 changes: 35 additions & 34 deletions src/models/transformer.h
@@ -5,6 +5,7 @@

#include "marian.h"

#include "common/hash.h"
#include "layers/constructors.h"
#include "models/decoder.h"
#include "models/encoder.h"
@@ -28,7 +29,7 @@ class Transformer : public EncoderOrDecoderBase {

protected:
using Base::options_; using Base::inference_; using Base::batchIndex_; using Base::graph_;
std::unordered_map<std::string, Expr> cache_; // caching transformation of the encoder that should not be created again
std::unordered_map<std::pair<std::string, Shape>, Expr> cache_; // caching transformation of the encoder that should not be created again
mutable/*lazy*/ std::vector<float> sinusoidalEmbeddingsFreq_, sinusoidalEmbeddingsOffs_; // cached contributions to sinusoidal embeddings

bool depthScaling_{false}; // As recommended in the GPT-2 paper, down-scale layer weights by a factor of 1 / sqrt(depth);
@@ -40,16 +41,16 @@ class Transformer : public EncoderOrDecoderBase {
std::vector<Expr> alignments_; // [max tgt len or 1][beam depth, max src length, batch size, 1]

// @TODO: make this go away
template <typename T>
T opt(const char* const key) const { Ptr<Options> options = options_; return options->get<T>(key); }
template <typename T>
T opt(const char* const key) const { Ptr<Options> options = options_; return options->get<T>(key); }

template <typename T>
T opt(const std::string& key) const { return opt<T>(key.c_str()); }
template <typename T>
T opt(const std::string& key) const { return opt<T>(key.c_str()); }

template <typename T>
template <typename T>
T opt(const char* const key, const T& def) const { Ptr<Options> options = options_; return options->get<T>(key, def); }

template <typename T>
template <typename T>
T opt(const std::string& key, const T& def) const { opt<T>(key.c_str(), def); }

public:
@@ -256,7 +257,7 @@

// take softmax along src sequence axis (-1)
auto weights = softmax(z); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: max src length]

if(saveAttentionWeights)
collectOneHead(weights, dimBeam);

@@ -289,34 +290,34 @@
// Caching transformation of the encoder that should not be created again.
// @TODO: set this automatically by memoizing encoder context and
// memoization propagation (short-term)
if (cache // if caching
&& cache_.count(prefix + "_keys") > 0 // and the keys expression has been seen
&& cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change
kh = cache_[prefix + "_keys"]; // then return cached tensor
}
else {
std::pair<std::unordered_map<std::pair<std::string, Shape>, Expr>::iterator, bool> cache_result;
if (cache
&& !((cache_result = cache_.insert(std::pair<std::pair<std::string, Shape>, Expr>({prefix + "_keys", keys->shape()}, kh))).second)
) {
kh = cache_result.first->second;
} else {
int dimKeys = keys->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
auto Wk = graph_->param(prefix + "_Wk", {dimKeys, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
auto bk = graph_->param(prefix + "_bk", {1, dimModel}, inits::zeros());

kh = affine(keys, Wk, bk); // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
kh = SplitHeads(kh, dimHeads); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
cache_[prefix + "_keys"] = kh;
if (cache) cache_result.first->second = kh;
}

Expr vh;
if (cache
&& cache_.count(prefix + "_values") > 0
&& cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) {
vh = cache_[prefix + "_values"];
if (cache
&& !((cache_result = cache_.insert(std::pair<std::pair<std::string, Shape>, Expr>({prefix + "_values", values->shape()}, vh))).second)
) {
vh = cache_result.first->second;
} else {
int dimValues = values->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
auto Wv = graph_->param(prefix + "_Wv", {dimValues, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
auto bv = graph_->param(prefix + "_bv", {1, dimModel}, inits::zeros());

vh = affine(values, Wv, bv); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
vh = SplitHeads(vh, dimHeads);
cache_[prefix + "_values"] = vh;
if (cache) cache_result.first->second = vh;
}

int dimBeam = q->shape()[-4];
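
The substantive change in this file is the probe above: the old code checked cache_.count(prefix + "_keys") and then compared only shape().elements(), so a cached tensor could be reused for a different shape that happened to have the same number of elements; the new code keys the cache on the full (name, shape) pair and probes it with a single insert(), which on a miss also reserves the slot that the freshly built projection is written into. A minimal sketch of that probe-and-reserve pattern with plain standard-library types (illustrative only; getOrCompute and the string-valued "shape" are stand-ins invented for this example):

```cpp
#include <iostream>
#include <map>
#include <string>
#include <utility>

// Toy stand-ins for the PR's types: the key is (layer prefix, shape) and the cached
// value is a string; in the PR the key is std::pair<std::string, marian::Shape> and
// the value is an Expr. std::map is used only to sidestep the hashing setup shown
// earlier -- the probe pattern is identical for std::unordered_map.
using Key   = std::pair<std::string, std::string>;
using Cache = std::map<Key, std::string>;

// Return the cached value for (prefix, shape), computing and storing it on a miss.
// A single insert() both probes the cache and, on a miss, reserves the slot that the
// freshly computed value is then written into.
std::string getOrCompute(Cache& cache, bool useCache,
                         const std::string& prefix, const std::string& shape) {
  std::pair<Cache::iterator, bool> result;
  if (useCache && !(result = cache.insert({{prefix, shape}, {}})).second) {
    return result.first->second; // hit: an entry with this exact (name, shape) key already exists
  }
  std::string value = "projection(" + prefix + ", " + shape + ")"; // stand-in for the affine + SplitHeads work
  if (useCache)
    result.first->second = value; // fill the slot reserved by insert()
  return value;
}

int main() {
  Cache cache;
  std::cout << getOrCompute(cache, true, "enc_keys", "512x64") << "\n"; // computed and cached
  std::cout << getOrCompute(cache, true, "enc_keys", "512x64") << "\n"; // served from the cache
  std::cout << getOrCompute(cache, true, "enc_keys", "256x64") << "\n"; // new shape -> recomputed
}
```
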
@@ -377,7 +378,7 @@

// multi-head self-attention over previous input
output = MultiHead(prefix, dimModel, dimHeads, output, keys, values, mask, cache, saveAttentionWeights);

auto opsPost = opt<std::string>("transformer-postprocess");
output = postProcess(prefix + "_Wo", opsPost, output, input, dropProb);

@@ -558,7 +559,7 @@ class EncoderTransformer : public Transformer<EncoderBase> {
auto embeddingLayer = getEmbeddingLayer(opt<bool>("ulr", false));
std::tie(batchEmbeddings, batchMask) = embeddingLayer->apply((*batch)[batchIndex_]);
batchEmbeddings = addSpecialEmbeddings(batchEmbeddings, /*start=*/0, batch);

// reorganize batch and timestep
batchEmbeddings = atleast_nd(batchEmbeddings, 4); // [beam depth=1, max length, batch size, vector dim]
batchMask = atleast_nd(batchMask, 4); // [beam depth=1, max length, batch size, vector dim=1]
@@ -593,7 +594,7 @@ class EncoderTransformer : public Transformer<EncoderBase> {
}

// this allows to run a final layernorm operation after going through the transformer layer stack.
// By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
// By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
// it is recommended to normalize here. Can also be used to add a skip connection from the very bottom if requested.
auto opsTop = opt<std::string>("transformer-postprocess-top", "");
layer = postProcess(prefix_ + "_top", opsTop, layer, prevLayer, dropProb);
@@ -622,14 +623,14 @@ class TransformerState : public DecoderState {
int beamSize) const override {

// @TODO: code duplication with DecoderState only because of isBatchMajor=true, should rather be a constructor argument of DecoderState?

std::vector<Ptr<EncoderState>> newEncStates;
for(auto& es : encStates_)
// If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries
for(auto& es : encStates_)
// If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries
newEncStates.push_back(es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));

// Create hypothesis-selected state based on current state and hyp indices
auto selectedState = New<TransformerState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/true), logProbs_, newEncStates, batch_);
auto selectedState = New<TransformerState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/true), logProbs_, newEncStates, batch_);

// Set the same target token position as the current state
// @TODO: This is the same as in base function.
@@ -763,8 +764,8 @@ class DecoderTransformer : public Transformer<DecoderBase> {

// This would happen if something goes wrong during batch pruning.
ABORT_IF(encoderContext->shape()[-3] != dimBatch,
"Context and query batch dimension do not match {} != {}",
encoderContext->shape()[-3],
"Context and query batch dimension do not match {} != {}",
encoderContext->shape()[-3],
dimBatch);

// LayerAttention expects mask in a different layout
@@ -871,7 +872,7 @@ class DecoderTransformer : public Transformer<DecoderBase> {
}

// This allows to run a final layernorm operation after going through the transformer layer stack.
// By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
// By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
// it is recommended to normalize here. Can also be used to add a skip connection from the very bottom if requested.
auto opsTop = opt<std::string>("transformer-postprocess-top", "");
query = postProcess(prefix_ + "_top", opsTop, query, prevQuery, dropProb);
@@ -884,7 +885,7 @@ class DecoderTransformer : public Transformer<DecoderBase> {
if(shortlist_)
output_->setShortlist(shortlist_);
auto logits = output_->applyAsLogits(decoderContext); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vocab or shortlist dim]

// return unnormalized(!) probabilities
Ptr<DecoderState> nextState;
if (opt<std::string>("transformer-decoder-autoreg", "self-attention") == "rnn") {
@@ -909,9 +910,9 @@ class DecoderTransformer : public Transformer<DecoderBase> {
output_->clear();
cache_.clear();
alignments_.clear();
perLayerRnn_.clear(); // this needs to be cleared between batches.
// @TODO: figure out how to detect stale nodes i.e. nodes that are referenced,
// but where underlying memory has been deallocated by dropping all tensors
perLayerRnn_.clear(); // this needs to be cleared between batches.
// @TODO: figure out how to detect stale nodes i.e. nodes that are referenced,
// but where underlying memory has been deallocated by dropping all tensors
// from a TensorAllocator object. This can happen during ExpressionGraph::clear()
}
};