public_parsing_ops.py

# Copyright 2020 The PEGASUS Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Parsing with publicly available ops.

This is a wrapper of sentencepiece ops for public release.
"""

from typing import List

import tensorflow as tf
import sentencepiece as sentencepiece_processor

# Raw SentencePiece ids above eos=1 are shifted up by this amount, so the
# low ids of the final vocabulary stay reserved for special tokens.
_SHIFT_RESERVED_TOKENS = 103
_NEWLINE_SYMBOL = "<n>"


def create_text_encoder(encoder_type: str, vocab_filename: str):
  """Creates a text encoder of the given type from a SentencePiece model."""
  if encoder_type == "sentencepiece":
    return SentencePieceEncoder(vocab_filename)
  elif encoder_type == "sentencepiece_newline":
    return SentencePieceEncoder(vocab_filename, newline_symbol=_NEWLINE_SYMBOL)
  else:
    raise ValueError("Unsupported encoder type: %s" % encoder_type)
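
# A quick illustration, not in the original file; the model path below is
# hypothetical:
#   encoder = create_text_encoder("sentencepiece_newline", "/path/to/spm.model")
# The "sentencepiece_newline" type replaces "\n" with "<n>" before tokenizing,
# so line breaks remain visible to the model; the plain "sentencepiece" type
# leaves the text untouched.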


class SentencePieceEncoder(object):
  """SentencePieceEncoder.

  The first two ids are pad=0 and eos=1; all remaining ids are shifted up by
  shift_reserved_tokens. If newline_symbol is provided, newlines in the text
  are replaced with that token.
  """

  def __init__(self,
               sentencepiece_model_file: str,
               shift_reserved_tokens: int = _SHIFT_RESERVED_TOKENS,
               newline_symbol: str = ""):
    self._tokenizer = sentencepiece_processor.SentencePieceProcessor()
    self._sp_model = tf.io.gfile.GFile(sentencepiece_model_file, "rb").read()
    self._tokenizer.LoadFromSerializedProto(self._sp_model)
    self._shift_reserved_tokens = shift_reserved_tokens
    self._newline_symbol = newline_symbol

  @property
  def vocab_size(self) -> int:
    return self._tokenizer.GetPieceSize() + self._shift_reserved_tokens

  def encode(self, text: str) -> List[int]:
    if self._newline_symbol:
      text = text.replace("\n", self._newline_symbol)
    ids = self._tokenizer.EncodeAsIds(text)
    # Shift every id above eos=1 to make room for the reserved tokens,
    # e.g. a raw id of 5 becomes 108 with the default shift of 103.
    ids = [i + self._shift_reserved_tokens if i > 1 else i for i in ids]
    return ids

  def decode(self, ids: List[int]) -> str:
    # Undo the shift for ordinary ids; pad=0, eos=1 and the reserved-token
    # ids (everything at or below 1 + shift) pass through unchanged.
    ids = [
        i - self._shift_reserved_tokens
        if i > 1 + self._shift_reserved_tokens else i for i in ids
    ]
    text = self._tokenizer.DecodeIds(ids)
    if self._newline_symbol:
      text = text.replace(self._newline_symbol, "\n")
    return text
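

# A minimal usage sketch, not part of the original file. It assumes a trained
# SentencePiece model at "sp_model.model"; that path is hypothetical.
if __name__ == "__main__":
  encoder = create_text_encoder("sentencepiece_newline", "sp_model.model")
  ids = encoder.encode("First line.\nSecond line.")
  print(ids)                  # token ids, shifted up by 103
  print(encoder.decode(ids))  # round-trips back to the original text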