-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add ctc beam search decoder #59
Changes from all commits
a51ed51
b930d69
f382528
1e9ae32
142a79f
1d7ade1
6698fe6
3a7a52e
2922a45
ed23e21
bcd01f7
0deb2e6
ef1350f
4b8fe7e
0fa063e
08203ee
c02c650
c8b0ae8
c26dab9
04881fa
40b75e3
46df7c4
6224d36
accaf92
63a72c1
359dc2a
20b50ca
2b594b4
66bb901
eec7cb4
d7f5ee6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,16 @@ | ||
"""Contains various CTC decoder.""" | ||
"""Contains various CTC decoders.""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import numpy as np | ||
from itertools import groupby | ||
import numpy as np | ||
from math import log | ||
import multiprocessing | ||
|
||
|
||
def ctc_best_path_decode(probs_seq, vocabulary): | ||
"""Best path decoding, also called argmax decoding or greedy decoding. | ||
def ctc_best_path_decoder(probs_seq, vocabulary): | ||
"""Best path decoder, also called argmax decoder or greedy decoder. | ||
Path consisting of the most probable tokens are further post-processed to | ||
remove consecutive repetitions and all blanks. | ||
|
||
|
@@ -36,24 +38,200 @@ def ctc_best_path_decode(probs_seq, vocabulary): | |
return ''.join([vocabulary[index] for index in index_list]) | ||
|
||
|
||
def ctc_decode(probs_seq, vocabulary, method): | ||
"""CTC-like sequence decoding from a sequence of likelihood probabilities. | ||
def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            blank_id,
                            cutoff_prob=1.0,
                            ext_scoring_func=None,
                            nproc=False):
    """Beam search decoder for CTC-trained network. It utilizes beam search
    to approximately select top best decoding labels and returning results
    in the descending order. The implementation is based on Prefix
    Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part is
    redesigned. Two important modifications: 1) in the iterative computation
    of probabilities, the assignment operation is changed to accumulation for
    one prefix may comes from different paths; 2) the if condition "if l^+ not
    in A_prev then" after probabilities' computation is deprecated for it is
    hard to understand and seems unnecessary.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank.
    :type blank_id: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :param nproc: Whether the decoder used in multiprocesses.
    :type nproc: bool
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    # dimension check: each time step needs one probability per vocabulary
    # token plus one for the blank symbol
    for prob_list in probs_seq:
        if not len(prob_list) == len(vocabulary) + 1:
            raise ValueError("The shape of prob_seq does not match with the "
                             "shape of the vocabulary.")

    # blank_id check
    if not blank_id < len(probs_seq[0]):
        raise ValueError("blank_id shouldn't be greater than probs dimension")

    # If the decoder called in the multiprocesses, then use the global scorer
    # instantiated in ctc_beam_search_decoder_batch().
    if nproc is True:
        global ext_nproc_scorer
        ext_scoring_func = ext_nproc_scorer

    ## initialize
    # '\t' serves as the sentinel for the empty prefix; it is stripped from
    # the final results below.
    # prefix_set_prev: the set containing selected prefixes
    # probs_b_prev: prefixes' probability ending with blank in previous step
    # probs_nb_prev: prefixes' probability ending with non-blank in previous step
    prefix_set_prev = {'\t': 1.0}
    probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

    ## extend prefix in loop
    for time_step in range(len(probs_seq)):
        # prefix_set_next: the set containing candidate prefixes
        # probs_b_cur: prefixes' probability ending with blank in current step
        # probs_nb_cur: prefixes' probability ending with non-blank in current step
        prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {}

        prob_idx = list(enumerate(probs_seq[time_step]))
        cutoff_len = len(prob_idx)
        # If pruning is enabled, keep only the most probable tokens whose
        # cumulative probability reaches cutoff_prob.
        if cutoff_prob < 1.0:
            prob_idx = sorted(prob_idx, key=lambda pair: pair[1], reverse=True)
            cutoff_len, cum_prob = 0, 0.0
            for i in range(len(prob_idx)):
                cum_prob += prob_idx[i][1]
                cutoff_len += 1
                if cum_prob >= cutoff_prob:
                    break
            prob_idx = prob_idx[0:cutoff_len]

        for l in prefix_set_prev:
            if l not in prefix_set_next:
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

            # extend prefix by traversing prob_idx
            for index in range(cutoff_len):
                c, prob_c = prob_idx[index][0], prob_idx[index][1]

                if c == blank_id:
                    probs_b_cur[l] += prob_c * (
                        probs_b_prev[l] + probs_nb_prev[l])
                else:
                    last_char = l[-1]
                    new_char = vocabulary[c]
                    l_plus = l + new_char
                    if l_plus not in prefix_set_next:
                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                    if new_char == last_char:
                        # repeated character: extending requires a preceding
                        # blank, otherwise the repetition collapses into l
                        probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
                        probs_nb_cur[l] += prob_c * probs_nb_prev[l]
                    elif new_char == ' ':
                        # word boundary: apply the external scorer to the
                        # completed words (skip for the empty prefix)
                        if (ext_scoring_func is None) or (len(l) == 1):
                            score = 1.0
                        else:
                            prefix = l[1:]
                            score = ext_scoring_func(prefix)
                        probs_nb_cur[l_plus] += score * prob_c * (
                            probs_b_prev[l] + probs_nb_prev[l])
                    else:
                        probs_nb_cur[l_plus] += prob_c * (
                            probs_b_prev[l] + probs_nb_prev[l])
                    # add l_plus into prefix_set_next
                    prefix_set_next[l_plus] = probs_nb_cur[
                        l_plus] + probs_b_cur[l_plus]
            # add l into prefix_set_next
            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
        # update probs
        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

        ## store top beam_size prefixes
        prefix_set_prev = sorted(
            prefix_set_next.items(), key=lambda pair: pair[1], reverse=True)
        if beam_size < len(prefix_set_prev):
            prefix_set_prev = prefix_set_prev[:beam_size]
        prefix_set_prev = dict(prefix_set_prev)

    beam_result = []
    for seq, prob in prefix_set_prev.items():
        if prob > 0.0 and len(seq) > 1:
            # strip the leading '\t' sentinel
            result = seq[1:]
            # score last word by external scorer
            if (ext_scoring_func is not None) and (result[-1] != ' '):
                prob = prob * ext_scoring_func(result)
            log_prob = log(prob)
            beam_result.append((log_prob, result))

    ## output top beam_size decoding results
    beam_result = sorted(beam_result, key=lambda pair: pair[0], reverse=True)
    return beam_result
|
||
|
||
def ctc_beam_search_decoder_batch(probs_split,
                                  beam_size,
                                  vocabulary,
                                  blank_id,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  ext_scoring_func=None):
    """CTC beam search decoder using multiple processes.

    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank.
    :type blank_id: int
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    # use global variable to pass the external scorer to beam search decoder;
    # a callable argument may not be picklable across process boundaries
    global ext_nproc_scorer
    ext_nproc_scorer = ext_scoring_func
    nproc = True

    pool = multiprocessing.Pool(processes=num_processes)
    results = []
    for probs_list in probs_split:
        # pass None for the scorer and nproc=True so each worker picks up
        # the global ext_nproc_scorer instead
        args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
                nproc)
        results.append(pool.apply_async(ctc_beam_search_decoder, args))

    pool.close()
    pool.join()
    beam_search_results = [result.get() for result in results]
    return beam_search_results
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we make sure that these modifications are correct?