1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
|
import numpy as np import random
# Toy vocabulary: integer ids mapped to tokens, including the BOS/EOS markers.
vocab = {
    0: 'a',
    1: 'b',
    2: 'c',
    3: 'd',
    4: 'e',
    5: 'BOS',
    6: 'EOS',
}

# Inverse lookup (token -> id) and the number of entries in the vocabulary.
reverse_vocab = {token: index for index, token in vocab.items()}
vocab_size = len(vocab)
def softmax(x):
    """Return the softmax of the scores in ``x``.

    The largest score is subtracted from every score before exponentiating
    (the usual numerical-stability shift; it does not change the result).
    Normalisation is over all elements of ``x`` taken together.
    """
    shifted = x - np.max(x)
    exponentials = np.exp(shifted)
    return exponentials / exponentials.sum()
def reduce_mul(l):
    """Return the product of all values in ``l``.

    The running product starts at ``1.0``, so an empty ``l`` yields ``1.0``.
    """
    product = 1.0
    for factor in l:
        product *= factor
    return product
def check_all_done(seqs):
    """Return True when the last element of every entry in ``seqs`` is truthy.

    Only the final element of each entry is inspected, so callers must put
    the "finished" flag there (e.g. ``(sequence, score, done)`` triples).
    An empty ``seqs`` yields True.

    NOTE: a bare word sequence ends with a non-empty ``(word_index, prob)``
    tuple, which is always truthy, so passing raw sequences always reports
    "done".
    """
    return all(seq[-1] for seq in seqs)


def decode_step(encoder_context, input_seq):
    """Produce a mock next-word distribution, most probable word first.

    This is a stand-in for a real decoder: both arguments are ignored and
    the raw scores are drawn from ``random.random()``.

    Returns:
        A list of ``(word_index, probability)`` pairs covering the whole
        vocabulary, sorted by descending probability.
    """
    scores = [random.random() for _ in range(vocab_size)]
    # BOS must never be produced as a next word.  The mask has to be -inf in
    # score space: a score of 0.0 would still get non-zero probability after
    # the softmax.
    scores[reverse_vocab['BOS']] = -np.inf
    probs = softmax(scores)
    return sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
def beam_search_step(encoder_context, top_seqs, k):
    """Advance every beam by one decoding step and keep the ``k`` best.

    Args:
        encoder_context: Opaque context handed through to ``decode_step``.
        top_seqs: Current beams; each beam is a list of
            ``(word_index, prob)`` pairs.
        k: Beam width. Used both as the number of expansions tried per
            beam and as the number of beams kept afterwards.

    Returns:
        ``(topk_seqs, all_done)``: the ``k`` highest-scoring sequences and a
        flag that is True only when every kept sequence ends in EOS.
    """
    eos_index = reverse_vocab['EOS']
    candidates = []
    for seq in top_seqs:
        # A sequence's score is the product of its per-word probabilities.
        seq_score = reduce_mul([prob for _, prob in seq])
        if seq[-1][0] == eos_index:
            # Finished beams are carried over unchanged so they keep
            # competing with the still-growing ones.
            candidates.append((seq, seq_score, True))
            continue
        # decode_step returns words best-first, so the first k entries are
        # the top-k expansions of this beam.
        for word in decode_step(encoder_context, seq)[:k]:
            word_index = word[0]
            word_score = word[1]
            candidates.append(
                (seq + [word], seq_score * word_score, word_index == eos_index)
            )

    candidates.sort(key=lambda cand: cand[1], reverse=True)
    best = candidates[:k]
    topk_seqs = [seq for seq, _, _ in best]
    # The done flag lives in the candidate triples. Checking the bare
    # sequences instead would test a non-empty (index, prob) tuple, which is
    # always truthy, and would end the search after the very first step.
    all_done = check_all_done(best)
    return topk_seqs, all_done
def beam_search(encoder_context, beam_size=3, max_len=10):
    """Run beam search and return the best sequences found.

    Args:
        encoder_context: Opaque context handed through to
            ``beam_search_step``.
        beam_size: Number of beams kept after every step (default 3).
        max_len: Maximum number of decoding steps (default 10).

    Returns:
        A list of sequences, each a list of ``(word_index, prob)`` pairs.
        Every sequence starts with the ``(BOS, 1.0)`` seed entry.
    """
    # Seed with a single beam holding only BOS at probability 1.0.
    top_seqs = [[(reverse_vocab['BOS'], 1.0)]]
    for _ in range(max_len):
        top_seqs, all_done = beam_search_step(
            encoder_context, top_seqs, beam_size
        )
        # Stop early once the step reports that every beam is finished.
        if all_done:
            break
    return top_seqs
if __name__ == '__main__':
    # Demo entry point: run the search without an encoder context and print
    # each resulting path as "token(probability)" entries.
    encoder_context = None
    top_seqs = beam_search(encoder_context)
    for i, seq in enumerate(top_seqs):
        # Single-argument print(...) calls behave the same on Python 2 and
        # Python 3; the original print statements only parse on Python 2.
        print('Path[%d]: ' % i)
        rendered = []
        # Skip the leading BOS entry and stop after the first EOS.
        for word in seq[1:]:
            word_index = word[0]
            word_prob = word[1]
            rendered.append('%s(%.4f)' % (vocab[word_index], word_prob))
            if word_index == reverse_vocab['EOS']:
                break
        print(' '.join(rendered))
        # Blank line between paths.
        print('')
|