Beam Search算法及其应用

概述

Beam Search算法是一种平衡性能与消耗的搜索算法,目的是在序列中解码出相对较优的路径。

Beam Search算法广泛运用于OCR、语音识别、翻译系统等场景。

CTC示例

以OCR为例,beam search算法可应用于笔划切分点的判断,CTC解码,Seq2Seq模型解码等步骤。

如文档图像经过识别模型CTC产生若干帧的输出,CTC概率矩阵输出如下:
ctc-decode

对于这种CTC解码,如果按照最简单的概率最大解码方式,那么解码的结果是:

1
我丛大前走过...

naive

但这从语义上来说很明显不是正确的解码结果。实际上,上图的GroundTruth是:

1
我从人前走过...

groundtruth

OCR不仅需要识别模型,同时也需要结合自然语言模型,才能得到最好的效果。识别模型一般输出每个字的置信度,而自然语言模型则会输出整个语句的自然性概率即语句是否自然真实的概率。

如果结合自然语言模型,采用穷举的方式来解码,那么计算量将大的惊人,常见一级汉字有3755个,6帧的解码搜索空间是3755^6,随着句子变长,指数增长的搜索范围将快速耗光计算资源。因此需要一种能在有限的计算资源限制下获取相对较优的解码结果。

算法描述

Beam Search算法作为一种折中手段,在相对受限的搜索空间中找出其最优解,得出的解接近于整个搜索空间中的最优解。

Beam Search算法一般分为两部分:

  • 路径搜索
  • 路径打分

路径搜索是指在受限空间中检索出所有路径,路径打分是指对某一条路径进行评估打分。

Beam Search的一般步骤为:

  1. 初始化beam_size个序列,序列均为空,这些序列称之为beam paths;
  2. 取下一个Frame的前N个候选值(N一般为beam size或者更大,Frame内部侯选值已按照概率倒序排列),与已存在的beam paths组合形成N * beam_size条路径,称之为prob_paths;
  3. 对prob_paths进行打分,取前beam_size个prob_path作为新的beam paths;
  4. 若解码结束在完成算法,否则回到2。

简单代码

下面是一个Beam Search在Seq2Seq模型中的应用示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#coding: utf-8
#demo of beam search for seq2seq model
import numpy as np
import random

vocab = {
0: 'a',
1: 'b',
2: 'c',
3: 'd',
4: 'e',
5: 'BOS',
6: 'EOS'
}
reverse_vocab = dict([(v,k) for k,v in vocab.items()])
vocab_size = len(vocab.items())

def softmax(x):
"""Compute softmax values for each sets of scores in x."""
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum()

def reduce_mul(l):
out = 1.0
for x in l:
out *= x
return out

def check_all_done(seqs):
for seq in seqs:
if not seq[-1]:
return False
return True

def decode_step(encoder_context, input_seq):
#encoder_context contains infortaion of encoder
#ouput_step contains the words' probability
#these two varibles should be generated by seq2seq model
words_prob = [random.random() for _ in range(vocab_size)]
#downvote BOS
words_prob[reverse_vocab['BOS']] = 0.0
words_prob = softmax(words_prob)
ouput_step = [(idx,prob) for idx,prob in enumerate(words_prob)]
ouput_step = sorted(ouput_step, key=lambda x: x[1], reverse=True)
return ouput_step

#seq: [[word,word],[word,word],[word,word]]
#output: [[word,word,word],[word,word,word],[word,word,word]]
def beam_search_step(encoder_context, top_seqs, k):
all_seqs = []
for seq in top_seqs:
seq_score = reduce_mul([_score for _,_score in seq])
if seq[-1][0] == reverse_vocab['EOS']:
all_seqs.append((seq, seq_score, True))
continue
#get current step using encoder_context & seq
current_step = decode_step(encoder_context, seq)
for i,word in enumerate(current_step):
if i >= k:
break
word_index = word[0]
word_score = word[1]
score = seq_score * word_score
rs_seq = seq + [word]
done = (word_index == reverse_vocab['EOS'])
all_seqs.append((rs_seq, score, done))
all_seqs = sorted(all_seqs, key = lambda seq: seq[1], reverse=True)
topk_seqs = [seq for seq,_,_ in all_seqs[:k]]
all_done = check_all_done(topk_seqs)
return topk_seqs, all_done

def beam_search(encoder_context):
beam_size = 3
max_len = 10
#START
top_seqs = [[(reverse_vocab['BOS'],1.0)]]
#loop
for _ in range(max_len):
top_seqs, all_done = beam_search_step(encoder_context, top_seqs, beam_size)
if all_done:
break
return top_seqs

if __name__ == '__main__':
#encoder_context is not inportant in this demo
encoder_context = None
top_seqs = beam_search(encoder_context)
for i,seq in enumerate(top_seqs):
print 'Path[%d]: ' % i
for word in seq[1:]:
word_index = word[0]
word_prob = word[1]
print '%s(%.4f)' % (vocab[word_index], word_prob),
if word_index == reverse_vocab['EOS']:
break
print '\n'