From d186d94f801eeeda95fd9d09c51cb7f3a0fb7fb8 Mon Sep 17 00:00:00 2001
From: Graeme Nail
Date: Wed, 24 Aug 2022 14:23:45 +0100
Subject: [PATCH] Check shapes on transformer cache

---
 src/models/transformer.h | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/models/transformer.h b/src/models/transformer.h
index d87594e0e..89adfee87 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -28,7 +28,7 @@ class Transformer : public EncoderOrDecoderBase {
 protected:
   using Base::options_; using Base::inference_; using Base::batchIndex_; using Base::graph_;
 
-  std::unordered_map<std::string, Expr> cache_;  // caching transformation of the encoder that should not be created again
+  std::unordered_map<std::string, std::pair<Shape, Expr>> cache_;  // caching transformation of the encoder that should not be created again
   mutable/*lazy*/ std::vector<float> sinusoidalEmbeddingsFreq_, sinusoidalEmbeddingsOffs_;  // cached contributions to sinusoidal embeddings
 
   bool depthScaling_{false}; // As recommended in the GPT-2 paper, down-scale layer weights by a factor of 1 / sqrt(depth);
@@ -288,10 +288,10 @@
     // Caching transformation of the encoder that should not be created again.
     // @TODO: set this automatically by memoizing encoder context and
     //        memoization propagation (short-term)
-    if (cache                                                                          // if caching
-        && cache_.count(prefix + "_keys") > 0                                          // and the keys expression has been seen
-        && cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change
-      kh = cache_[prefix + "_keys"];                                                   // then return cached tensor
+    if (cache                                                 // if caching
+        && cache_.count(prefix + "_keys") > 0                 // and the keys expression has been seen
+        && cache_[prefix + "_keys"].first == keys->shape()) { // and the underlying shape did not change
+      kh = cache_[prefix + "_keys"].second;                   // then return cached tensor
     } else {
       int dimKeys = keys->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
       auto Wk = graph_->param(prefix + "_Wk", {dimKeys, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
@@ -300,14 +300,14 @@
       kh = affine(keys, Wk, bk);     // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
       kh = SplitHeads(kh, dimHeads); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
 
-      cache_[prefix + "_keys"] = kh;
+      cache_[prefix + "_keys"] = std::make_pair(keys->shape(), kh);
     }
 
     Expr vh;
     if (cache
         && cache_.count(prefix + "_values") > 0
-        && cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) {
-      vh = cache_[prefix + "_values"];
+        && cache_[prefix + "_values"].first == values->shape()) {
+      vh = cache_[prefix + "_values"].second;
     } else {
       int dimValues = values->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
       auto Wv = graph_->param(prefix + "_Wv", {dimValues, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
@@ -315,7 +315,7 @@
       vh = affine(values, Wv, bv); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
       vh = SplitHeads(vh, dimHeads);
 
-      cache_[prefix + "_values"] = vh;
+      cache_[prefix + "_values"] = std::make_pair(values->shape(), vh);
     }
 
     int dimBeam = q->shape()[-4];
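
Rationale, with a standalone sketch below (the Shape/Expr definitions are
simplified stand-ins for illustration, not Marian's real classes, and the
cache key is a hypothetical example): comparing element counts can produce
a false cache hit when a new batch has the same number of elements in a
different layout, e.g. 2 sentences of length 3 versus 3 sentences of
length 2. Storing the shape next to the cached expression, as this patch
does, and comparing shapes exactly rejects such stale entries.

  #include <cassert>
  #include <functional>
  #include <memory>
  #include <numeric>
  #include <string>
  #include <unordered_map>
  #include <utility>
  #include <vector>

  struct Node {};                         // opaque stand-in for a graph node
  using Expr  = std::shared_ptr<Node>;    // simplified stand-in for marian::Expr
  using Shape = std::vector<int>;         // simplified stand-in, e.g. {batch, length, dim}

  static int elements(const Shape& s) {   // total element count, as the old check used
    return std::accumulate(s.begin(), s.end(), 1, std::multiplies<int>());
  }

  int main() {
    // Cache layout after this patch: shape stored alongside the expression.
    std::unordered_map<std::string, std::pair<Shape, Expr>> cache;

    Shape cached = {2, 3, 512};           // batch 2, length 3
    Shape query  = {3, 2, 512};           // batch 3, length 2: same size, new layout
    cache["prefix_keys"] = std::make_pair(cached, std::make_shared<Node>());

    // Old check: element counts match, so the stale entry would be reused.
    assert(elements(cache["prefix_keys"].first) == elements(query));

    // New check: exact shape comparison misses, forcing a recompute.
    assert(!(cache["prefix_keys"].first == query));
    return 0;
  }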