# Copyright 2020 The PEGASUS Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Parsing with public available ops.

This is a wrapper of sentencepiece ops for public release.
"""

from typing import List
import tensorflow as tf
import sentencepiece as sentencepiece_processor

# Raw sentencepiece ids > 1 are shifted up by this amount, reserving ids
# 2..104 for special tokens.
_SHIFT_RESERVED_TOKENS = 103
_NEWLINE_SYMBOL = "<n>"


def create_text_encoder(encoder_type: str, vocab_filename: str):
  """Creates a text encoder from a sentencepiece model file."""
  if encoder_type == "sentencepiece":
    return SentencePieceEncoder(vocab_filename)
  elif encoder_type == "sentencepiece_newline":
    return SentencePieceEncoder(vocab_filename, newline_symbol=_NEWLINE_SYMBOL)
  else:
    raise ValueError("Unsupported encoder type: %s" % encoder_type)
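

# A minimal usage sketch ("spm.model" is a hypothetical model path; substitute
# the path to a trained sentencepiece model):
#
#   encoder = create_text_encoder("sentencepiece_newline", "spm.model")
#   ids = encoder.encode("First line.\nSecond line.")
#   text = encoder.decode(ids)  # Newlines are restored on decode.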


class SentencePieceEncoder(object):
  """SentencePieceEncoder.

  First two ids are pad=0, eos=1, rest ids are being shifted up by
  shift_reserved_tokens. If newline_symbol is provided, will replace newline in
  the text with that token.
  """

  def __init__(self,
               sentencepiece_model_file: str,
               shift_reserved_tokens: int = _SHIFT_RESERVED_TOKENS,
               newline_symbol: str = ""):
    self._tokenizer = sentencepiece_processor.SentencePieceProcessor()
    # Read the serialized model through tf.io.gfile so that remote paths
    # (e.g. Google Cloud Storage) work as well as local files.
    self._sp_model = tf.io.gfile.GFile(sentencepiece_model_file, "rb").read()
    self._tokenizer.LoadFromSerializedProto(self._sp_model)
    self._shift_reserved_tokens = shift_reserved_tokens
    self._newline_symbol = newline_symbol

  @property
  def vocab_size(self) -> int:
    return self._tokenizer.GetPieceSize() + self._shift_reserved_tokens

  def encode(self, text: str) -> List[int]:
    if self._newline_symbol:
      text = text.replace("\n", self._newline_symbol)
    ids = self._tokenizer.EncodeAsIds(text)
    # Shift every id except pad (0) and eos (1) up by shift_reserved_tokens,
    # leaving the range [2, 1 + shift_reserved_tokens] free for reserved
    # tokens.
    ids = [i + self._shift_reserved_tokens if i > 1 else i for i in ids]
    return ids

  def decode(self, ids: List[int]) -> str:
    # Shift ids above the reserved range back down to raw sentencepiece ids;
    # pad, eos, and reserved-token ids pass through unchanged.
    ids = [
        i - self._shift_reserved_tokens
        if i > 1 + self._shift_reserved_tokens else i for i in ids
    ]
    text = self._tokenizer.DecodeIds(ids)
    if self._newline_symbol:
      text = text.replace(self._newline_symbol, "\n")
    return text
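

# A minimal round-trip sketch ("spm.model" is again a hypothetical model path):
#
#   encoder = SentencePieceEncoder("spm.model")
#   ids = encoder.encode("hello world")
#   # Every raw sentencepiece id > 1 is shifted up by 103 (the default), so
#   # ids 2..104 never occur in the output and remain free for reserved tokens.
#   text = encoder.decode(ids)  # "hello world" for typical inputs.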