Skip to content

Commit

Permalink
Merge pull request #59 from kuke/ctc_decoder_dev
Browse files Browse the repository at this point in the history
Add CTC prefix beam search decoder.
  • Loading branch information
xinghai-sun authored Jul 5, 2017
2 parents 853b538 + d7f5ee6 commit de86560
Show file tree
Hide file tree
Showing 10 changed files with 900 additions and 31 deletions.
26 changes: 26 additions & 0 deletions deep_speech_2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,29 @@ More help for arguments:
```
python infer.py --help
```

### Evaluating

```
CUDA_VISIBLE_DEVICES=0 python evaluate.py
```

More help for arguments:

```
python evaluate.py --help
```

### Parameters tuning

Parameters tuning for the CTC beam search decoder

```
CUDA_VISIBLE_DEVICES=0 python tune.py
```

More help for arguments:

```
python tune.py --help
```
216 changes: 197 additions & 19 deletions deep_speech_2/decoder.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
"""Contains various CTC decoder."""
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from itertools import groupby
import numpy as np
from math import log
import multiprocessing


def ctc_best_path_decode(probs_seq, vocabulary):
"""Best path decoding, also called argmax decoding or greedy decoding.
def ctc_best_path_decoder(probs_seq, vocabulary):
"""Best path decoder, also called argmax decoder or greedy decoder.
Path consisting of the most probable tokens are further post-processed to
remove consecutive repetitions and all blanks.
Expand Down Expand Up @@ -36,24 +38,200 @@ def ctc_best_path_decode(probs_seq, vocabulary):
return ''.join([vocabulary[index] for index in index_list])


def ctc_decode(probs_seq, vocabulary, method):
"""CTC-like sequence decoding from a sequence of likelihood probablilites.
def ctc_beam_search_decoder(probs_seq,
beam_size,
vocabulary,
blank_id,
cutoff_prob=1.0,
ext_scoring_func=None,
nproc=False):
"""Beam search decoder for CTC-trained network. It utilizes beam search
to approximately select top best decoding labels and returning results
in the descending order. The implementation is based on Prefix
Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned. Two important modifications: 1) in the iterative computation
of probabilities, the assignment operation is changed to accumulation for
one prefix may comes from different paths; 2) the if condition "if l^+ not
in A_prev then" after probabilities' computation is deprecated for it is
hard to understand and seems unnecessary.
:param probs_seq: 2-D list of probabilities over the vocabulary for each
character. Each element is a list of float probabilities
for one character.
:type probs_seq: list
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param method: Decoding method name, with options: "best_path".
:type method: basestring
:return: Decoding result string.
:rtype: baseline
:param blank_id: ID of blank.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type external_scoring_func: callable
:param nproc: Whether the decoder used in multiprocesses.
:type nproc: bool
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
# dimension check
for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("probs dimension mismatchedd with vocabulary")
if method == "best_path":
return ctc_best_path_decode(probs_seq, vocabulary)
else:
raise ValueError("Decoding method [%s] is not supported.")
raise ValueError("The shape of prob_seq does not match with the "
"shape of the vocabulary.")

# blank_id check
if not blank_id < len(probs_seq[0]):
raise ValueError("blank_id shouldn't be greater than probs dimension")

# If the decoder called in the multiprocesses, then use the global scorer
# instantiated in ctc_beam_search_decoder_batch().
if nproc is True:
global ext_nproc_scorer
ext_scoring_func = ext_nproc_scorer

## initialize
# prefix_set_prev: the set containing selected prefixes
# probs_b_prev: prefixes' probability ending with blank in previous step
# probs_nb_prev: prefixes' probability ending with non-blank in previous step
prefix_set_prev = {'\t': 1.0}
probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

## extend prefix in loop
for time_step in xrange(len(probs_seq)):
# prefix_set_next: the set containing candidate prefixes
# probs_b_cur: prefixes' probability ending with blank in current step
# probs_nb_cur: prefixes' probability ending with non-blank in current step
prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {}

prob_idx = list(enumerate(probs_seq[time_step]))
cutoff_len = len(prob_idx)
#If pruning is enabled
if cutoff_prob < 1.0:
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len, cum_prob = 0, 0.0
for i in xrange(len(prob_idx)):
cum_prob += prob_idx[i][1]
cutoff_len += 1
if cum_prob >= cutoff_prob:
break
prob_idx = prob_idx[0:cutoff_len]

for l in prefix_set_prev:
if not prefix_set_next.has_key(l):
probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

# extend prefix by travering prob_idx
for index in xrange(cutoff_len):
c, prob_c = prob_idx[index][0], prob_idx[index][1]

if c == blank_id:
probs_b_cur[l] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
else:
last_char = l[-1]
new_char = vocabulary[c]
l_plus = l + new_char
if not prefix_set_next.has_key(l_plus):
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

if new_char == last_char:
probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
probs_nb_cur[l] += prob_c * probs_nb_prev[l]
elif new_char == ' ':
if (ext_scoring_func is None) or (len(l) == 1):
score = 1.0
else:
prefix = l[1:]
score = ext_scoring_func(prefix)
probs_nb_cur[l_plus] += score * prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
else:
probs_nb_cur[l_plus] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
# add l_plus into prefix_set_next
prefix_set_next[l_plus] = probs_nb_cur[
l_plus] + probs_b_cur[l_plus]
# add l into prefix_set_next
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
# update probs
probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

## store top beam_size prefixes
prefix_set_prev = sorted(
prefix_set_next.iteritems(), key=lambda asd: asd[1], reverse=True)
if beam_size < len(prefix_set_prev):
prefix_set_prev = prefix_set_prev[:beam_size]
prefix_set_prev = dict(prefix_set_prev)

beam_result = []
for seq, prob in prefix_set_prev.items():
if prob > 0.0 and len(seq) > 1:
result = seq[1:]
# score last word by external scorer
if (ext_scoring_func is not None) and (result[-1] != ' '):
prob = prob * ext_scoring_func(result)
log_prob = log(prob)
beam_result.append((log_prob, result))

## output top beam_size decoding results
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
return beam_result


def ctc_beam_search_decoder_batch(probs_split,
beam_size,
vocabulary,
blank_id,
num_processes,
cutoff_prob=1.0,
ext_scoring_func=None):
"""CTC beam search decoder using multiple processes.
:param probs_seq: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoder().
:type probs_seq: 3-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank.
:type blank_id: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:param num_processes: Number of parallel processes.
:type num_processes: int
:type cutoff_prob: float
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type external_scoring_function: callable
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
if not num_processes > 0:
raise ValueError("Number of processes must be positive!")

# use global variable to pass the externnal scorer to beam search decoder
global ext_nproc_scorer
ext_nproc_scorer = ext_scoring_func
nproc = True

pool = multiprocessing.Pool(processes=num_processes)
results = []
for i, probs_list in enumerate(probs_split):
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
nproc)
results.append(pool.apply_async(ctc_beam_search_decoder, args))

pool.close()
pool.join()
beam_search_results = [result.get() for result in results]
return beam_search_results
Loading

0 comments on commit de86560

Please sign in to comment.