Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ctc beam search decoder #59

Merged
merged 31 commits into from
Jul 5, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
a51ed51
update code & add test
Jun 2, 2017
b930d69
mv ctc_beam_search_decoder into deep_speech_2/
Jun 4, 2017
f382528
update annotations
Jun 4, 2017
1e9ae32
enable ctc beam search decoder
Jun 5, 2017
142a79f
code clean & add external scorer
Jun 6, 2017
1d7ade1
add annotations
Jun 7, 2017
6698fe6
modify language model scoring
Jun 7, 2017
3a7a52e
rename variables in decoder
Jun 7, 2017
2922a45
tiny modify to pass CI
Jun 7, 2017
ed23e21
improve external scorer
Jun 7, 2017
bcd01f7
optimize the efficiency of beam search
Jun 8, 2017
0deb2e6
add beam search decoder using multiprocesses
Jun 12, 2017
ef1350f
correct typos in annotations
Jun 12, 2017
4b8fe7e
enable lm in multiprocessing decoder & add script for params tuning
Jun 13, 2017
0fa063e
change two arguments
Jun 13, 2017
08203ee
final refining on old data provider: enable pruning & add evaluation …
Jun 18, 2017
c02c650
add scoring last word in beam search
Jun 18, 2017
c8b0ae8
adapt to the new data provider
Jun 19, 2017
c26dab9
Merge branch 'develop' of https://github.com/PaddlePaddle/models into…
Jun 20, 2017
04881fa
tiny adjust
Jun 20, 2017
40b75e3
Merge branch 'develop' of https://github.com/PaddlePaddle/models into…
Jun 20, 2017
46df7c4
add unit test for decoders
Jun 21, 2017
6224d36
Merge branch 'develop' of https://github.com/PaddlePaddle/models into…
Jun 26, 2017
accaf92
Merge branch 'develop' of https://github.com/PaddlePaddle/models into…
Jun 26, 2017
63a72c1
refine ctc_beam_search_decoder
Jun 27, 2017
359dc2a
fix decoders' unittest
Jun 27, 2017
20b50ca
append README.md
Jun 27, 2017
2b594b4
resolve conflicts in requirements.txt
Jun 27, 2017
66bb901
Merge branch 'develop' into ctc_decoder_dev
Jun 29, 2017
eec7cb4
enable resetting params in scorer
Jul 4, 2017
d7f5ee6
follow comments in code format
Jul 5, 2017
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions deep_speech_2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,29 @@ More help for arguments:
```
python infer.py --help
```

### Evaluating

```
CUDA_VISIBLE_DEVICES=0 python evaluate.py
```

More help for arguments:

```
python evaluate.py --help
```

### Parameters tuning

Parameters tuning for the CTC beam search decoder

```
CUDA_VISIBLE_DEVICES=0 python tune.py
```

More help for arguments:

```
python tune.py --help
```
216 changes: 197 additions & 19 deletions deep_speech_2/decoder.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
"""Contains various CTC decoder."""
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from itertools import groupby
import numpy as np
from math import log
import multiprocessing


def ctc_best_path_decode(probs_seq, vocabulary):
"""Best path decoding, also called argmax decoding or greedy decoding.
def ctc_best_path_decoder(probs_seq, vocabulary):
"""Best path decoder, also called argmax decoder or greedy decoder.
Path consisting of the most probable tokens are further post-processed to
remove consecutive repetitions and all blanks.

Expand Down Expand Up @@ -36,24 +38,200 @@ def ctc_best_path_decode(probs_seq, vocabulary):
return ''.join([vocabulary[index] for index in index_list])


def ctc_decode(probs_seq, vocabulary, method):
"""CTC-like sequence decoding from a sequence of likelihood probablilites.
def ctc_beam_search_decoder(probs_seq,
beam_size,
vocabulary,
blank_id,
cutoff_prob=1.0,
ext_scoring_func=None,
nproc=False):
"""Beam search decoder for CTC-trained network. It utilizes beam search
to approximately select top best decoding labels and returning results
in the descending order. The implementation is based on Prefix
Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned. Two important modifications: 1) in the iterative computation
of probabilities, the assignment operation is changed to accumulation for
one prefix may comes from different paths; 2) the if condition "if l^+ not
in A_prev then" after probabilities' computation is deprecated for it is
hard to understand and seems unnecessary.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make sure that these modifications are correct?


:param probs_seq: 2-D list of probabilities over the vocabulary for each
character. Each element is a list of float probabilities
for one character.
:type probs_seq: list
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param method: Decoding method name, with options: "best_path".
:type method: basestring
:return: Decoding result string.
:rtype: baseline
:param blank_id: ID of blank.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type external_scoring_func: callable
:param nproc: Whether the decoder used in multiprocesses.
:type nproc: bool
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
# dimension check
for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("probs dimension mismatchedd with vocabulary")
if method == "best_path":
return ctc_best_path_decode(probs_seq, vocabulary)
else:
raise ValueError("Decoding method [%s] is not supported.")
raise ValueError("The shape of prob_seq does not match with the "
"shape of the vocabulary.")

# blank_id check
if not blank_id < len(probs_seq[0]):
raise ValueError("blank_id shouldn't be greater than probs dimension")

# If the decoder called in the multiprocesses, then use the global scorer
# instantiated in ctc_beam_search_decoder_batch().
if nproc is True:
global ext_nproc_scorer
ext_scoring_func = ext_nproc_scorer

## initialize
# prefix_set_prev: the set containing selected prefixes
# probs_b_prev: prefixes' probability ending with blank in previous step
# probs_nb_prev: prefixes' probability ending with non-blank in previous step
prefix_set_prev = {'\t': 1.0}
probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

## extend prefix in loop
for time_step in xrange(len(probs_seq)):
# prefix_set_next: the set containing candidate prefixes
# probs_b_cur: prefixes' probability ending with blank in current step
# probs_nb_cur: prefixes' probability ending with non-blank in current step
prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {}

prob_idx = list(enumerate(probs_seq[time_step]))
cutoff_len = len(prob_idx)
#If pruning is enabled
if cutoff_prob < 1.0:
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len, cum_prob = 0, 0.0
for i in xrange(len(prob_idx)):
cum_prob += prob_idx[i][1]
cutoff_len += 1
if cum_prob >= cutoff_prob:
break
prob_idx = prob_idx[0:cutoff_len]

for l in prefix_set_prev:
if not prefix_set_next.has_key(l):
probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

# extend prefix by travering prob_idx
for index in xrange(cutoff_len):
c, prob_c = prob_idx[index][0], prob_idx[index][1]

if c == blank_id:
probs_b_cur[l] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
else:
last_char = l[-1]
new_char = vocabulary[c]
l_plus = l + new_char
if not prefix_set_next.has_key(l_plus):
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

if new_char == last_char:
probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
probs_nb_cur[l] += prob_c * probs_nb_prev[l]
elif new_char == ' ':
if (ext_scoring_func is None) or (len(l) == 1):
score = 1.0
else:
prefix = l[1:]
score = ext_scoring_func(prefix)
probs_nb_cur[l_plus] += score * prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
else:
probs_nb_cur[l_plus] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
# add l_plus into prefix_set_next
prefix_set_next[l_plus] = probs_nb_cur[
l_plus] + probs_b_cur[l_plus]
# add l into prefix_set_next
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
# update probs
probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

## store top beam_size prefixes
prefix_set_prev = sorted(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, we don't need a complete sorting but a partial sorting, is it very time-consuming ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems it is not very time-consuming, and only occupies a small portion of decoding time. We can use priority queue or something to sort partially, but I am not optimistic about the improvement of efficiency in Python.

prefix_set_next.iteritems(), key=lambda asd: asd[1], reverse=True)
if beam_size < len(prefix_set_prev):
prefix_set_prev = prefix_set_prev[:beam_size]
prefix_set_prev = dict(prefix_set_prev)

beam_result = []
for seq, prob in prefix_set_prev.items():
if prob > 0.0 and len(seq) > 1:
result = seq[1:]
# score last word by external scorer
if (ext_scoring_func is not None) and (result[-1] != ' '):
prob = prob * ext_scoring_func(result)
log_prob = log(prob)
beam_result.append((log_prob, result))

## output top beam_size decoding results
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
return beam_result


def ctc_beam_search_decoder_batch(probs_split,
beam_size,
vocabulary,
blank_id,
num_processes,
cutoff_prob=1.0,
ext_scoring_func=None):
"""CTC beam search decoder using multiple processes.

:param probs_seq: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoder().
:type probs_seq: 3-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank.
:type blank_id: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:param num_processes: Number of parallel processes.
:type num_processes: int
:type cutoff_prob: float
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type external_scoring_function: callable
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
if not num_processes > 0:
raise ValueError("Number of processes must be positive!")

# use global variable to pass the externnal scorer to beam search decoder
global ext_nproc_scorer
ext_nproc_scorer = ext_scoring_func
nproc = True

pool = multiprocessing.Pool(processes=num_processes)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pool could be reusable. It might be very slow to create a new process pool for each batch.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't matter for batch size can be taken a relative large value. And seems that to synchronize, pool.close() must be called and after that the pool cannot be used again.

results = []
for i, probs_list in enumerate(probs_split):
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
nproc)
results.append(pool.apply_async(ctc_beam_search_decoder, args))

pool.close()
pool.join()
beam_search_results = [result.get() for result in results]
return beam_search_results
Loading