import tensorflow as tf
import keras.layers as L

# This code implements a single-GRU seq2seq model. You will have to improve it later in the assignment.
# Note 1: when using several recurrent layers, TF can mix up the weights of different recurrent layers.
# In that case, make sure you both create AND use each rnn/gru/lstm/custom layer in a unique variable scope,
# e.g. with tf.variable_scope("first_lstm"): new_cell, new_out = self.lstm_1(...)
#      with tf.variable_scope("second_lstm"): new_cell2, new_out2 = self.lstm_2(...)
# Note 2: everything you need for decoding should be stored in the model state (the output list of both encode and decode),
# e.g. for attention, you should store the whole encoder sequence and the input mask there in addition to lstm/gru states.

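# A minimal sketch of the scoping pattern from Note 1 (illustrative only; the names
# "enc_gru_1" / self.enc1 are made up and not part of the assignment code). The point is
# that each cell is both created and called under its own variable scope:
#
#   with tf.variable_scope("enc_gru_1"):
#       self.enc1 = tf.nn.rnn_cell.GRUCell(hid_size)
#   ...
#   with tf.variable_scope("enc_gru_1"):
#       new_out1, new_state1 = self.enc1(step_input, prev_state1)
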
class BasicTranslationModel:
    def __init__(self, name, inp_voc, out_voc, emb_size, hid_size):

        self.name = name
        self.inp_voc = inp_voc
        self.out_voc = out_voc

        with tf.variable_scope(name):
            self.emb_inp = L.Embedding(len(inp_voc), emb_size)
            self.emb_out = L.Embedding(len(out_voc), emb_size)
            self.enc0 = tf.nn.rnn_cell.GRUCell(hid_size)
            self.dec_start = L.Dense(hid_size)
            self.dec0 = tf.nn.rnn_cell.GRUCell(hid_size)
            self.logits = L.Dense(len(out_voc))

            # run on dummy inputs to .build all layers (and therefore create their weights)
            inp = tf.placeholder('int32', [None, None])
            out = tf.placeholder('int32', [None, None])
            h0 = self.encode(inp)
            h1 = self.decode(h0, out[:, 0])
            # h2 = self.decode(h1, out[:, 1]), etc.

        self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=name)

    def encode(self, inp, **flags):
        """
        Takes symbolic input sequence, computes initial state
        :param inp: matrix of input tokens [batch, time]
        :return: a list of initial decoder state tensors
        """
        inp_lengths = infer_length(inp, self.inp_voc.eos_ix)
        inp_emb = self.emb_inp(inp)

        _, enc_last = tf.nn.dynamic_rnn(
            self.enc0, inp_emb,
            sequence_length=inp_lengths,
            dtype=inp_emb.dtype)

        dec_start = self.dec_start(enc_last)
        return [dec_start]

    def decode(self, prev_state, prev_tokens, **flags):
        """
        Takes previous decoder state and tokens, returns new state and logits
        :param prev_state: a list of previous decoder state tensors
        :param prev_tokens: previous output tokens, an int vector of [batch_size]
        :return: a list of next decoder state tensors, a tensor of logits [batch, n_tokens]
        """
        [prev_dec] = prev_state

        prev_emb = self.emb_out(prev_tokens[:, None])[:, 0]

        new_dec_out, new_dec_state = self.dec0(prev_emb, prev_dec)

        output_logits = self.logits(new_dec_out)

        return [new_dec_state], output_logits

    def symbolic_score(self, inp, out, eps=1e-30, **flags):
        """
        Takes symbolic int32 matrices of Hebrew words and their English translations.
        Computes the log-probabilities of all possible English characters given English prefixes and the Hebrew word.
        :param inp: input sequence, int32 matrix of shape [batch, time]
        :param out: output sequence, int32 matrix of shape [batch, time]
        :return: log-probabilities of all possible English characters of shape [batch, time, n_tokens]

        NOTE: the log-probabilities' time axis is synchronized with out.
        In other words, logp are probabilities of the __current__ output at each tick, not the next one,
        therefore you can get the likelihood as logprobas * tf.one_hot(out, n_tokens)
        """
        first_state = self.encode(inp, **flags)

        batch_size = tf.shape(inp)[0]
        bos = tf.fill([batch_size], self.out_voc.bos_ix)
        first_logits = tf.log(tf.one_hot(bos, len(self.out_voc)) + eps)

        def step(blob, y_prev):
            h_prev = blob[:-1]
            h_new, logits = self.decode(h_prev, y_prev, **flags)
            return list(h_new) + [logits]

        results = tf.scan(step, initializer=list(first_state) + [first_logits],
                          elems=tf.transpose(out))

        # gather state and logits, each of shape [time, batch, ...]
        states_seq, logits_seq = results[:-1], results[-1]

        # add initial state and logits
        logits_seq = tf.concat((first_logits[None], logits_seq), axis=0)
        states_seq = [tf.concat((init[None], states), axis=0)
                      for init, states in zip(first_state, states_seq)]

        # convert from [time, batch, ...] to [batch, time, ...]
        logits_seq = tf.transpose(logits_seq, [1, 0, 2])
        states_seq = [tf.transpose(states, [1, 0] + list(range(2, states.shape.ndims)))
                      for states in states_seq]

        return tf.nn.log_softmax(logits_seq)

    def symbolic_translate(self, inp, greedy=False, max_len=None, eps=1e-30, **flags):
        """
        Takes a symbolic int32 matrix of Hebrew words, produces output tokens sampled
        from the model and output log-probabilities for all possible tokens at each tick.
        :param inp: input sequence, int32 matrix of shape [batch, time]
        :param greedy: if greedy, takes the token with the highest probability at each tick.
            Otherwise samples proportionally to probability.
        :param max_len: max length of output, defaults to 2 * input length
        :return: output tokens int32[batch, time] and
            log-probabilities of all tokens at each tick, [batch, time, n_tokens]
        """
        first_state = self.encode(inp, **flags)

        batch_size = tf.shape(inp)[0]
        bos = tf.fill([batch_size], self.out_voc.bos_ix)
        first_logits = tf.log(tf.one_hot(bos, len(self.out_voc)) + eps)
        if max_len is None:
            max_len = tf.reduce_max(tf.shape(inp)[1]) * 2

        def step(blob, t):
            h_prev, y_prev = blob[:-2], blob[-1]
            h_new, logits = self.decode(h_prev, y_prev, **flags)
            y_new = tf.argmax(logits, axis=-1) if greedy else tf.multinomial(logits, 1)[:, 0]
            return list(h_new) + [logits, tf.cast(y_new, y_prev.dtype)]

        results = tf.scan(step, initializer=list(first_state) + [first_logits, bos],
                          elems=[tf.range(max_len)])

        # gather state, logits and outputs, each of shape [time, batch, ...]
        states_seq, logits_seq, out_seq = results[:-2], results[-2], results[-1]

        # add initial state, logits and output
        logits_seq = tf.concat((first_logits[None], logits_seq), axis=0)
        out_seq = tf.concat((bos[None], out_seq), axis=0)
        states_seq = [tf.concat((init[None], states), axis=0)
                      for init, states in zip(first_state, states_seq)]

        # convert from [time, batch, ...] to [batch, time, ...]
        logits_seq = tf.transpose(logits_seq, [1, 0, 2])
        out_seq = tf.transpose(out_seq)
        states_seq = [tf.transpose(states, [1, 0] + list(range(2, states.shape.ndims)))
                      for states in states_seq]

        return out_seq, tf.nn.log_softmax(logits_seq)

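# A minimal usage sketch (an illustration, not part of the assignment file). It assumes
# vocabulary objects `inp_voc`/`out_voc` supporting len(), .bos_ix and .eos_ix, and a
# matrix `inp_batch` of int32 token ids -- these names are placeholders:
#
#   model = BasicTranslationModel('model', inp_voc, out_voc, emb_size=64, hid_size=128)
#   inp = tf.placeholder('int32', [None, None])
#   out_tokens, logp = model.symbolic_translate(inp, greedy=True)
#
#   sess = tf.InteractiveSession()
#   initialize_uninitialized(sess)
#   translations = sess.run(out_tokens, {inp: inp_batch})
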

### Utility functions ###

def initialize_uninitialized(sess=None):
    """
    Initialize uninitialized variables; doesn't affect those already initialized
    :param sess: in which session to initialize stuff. Defaults to tf.get_default_session()
    """
    sess = sess or tf.get_default_session()
    global_vars = tf.global_variables()
    is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]

    if len(not_initialized_vars):
        sess.run(tf.variables_initializer(not_initialized_vars))


def infer_length(seq, eos_ix, time_major=False, dtype=tf.int32):
    """
    Compute sequence lengths given output indices and the eos code
    :param seq: tf matrix [time, batch] if time_major else [batch, time]
    :param eos_ix: integer index of the end-of-sentence token
    :returns: lengths, int32 vector of shape [batch]
    """
    axis = 0 if time_major else 1
    is_eos = tf.cast(tf.equal(seq, eos_ix), dtype)
    count_eos = tf.cumsum(is_eos, axis=axis, exclusive=True)
    lengths = tf.reduce_sum(tf.cast(tf.equal(count_eos, 0), dtype), axis=axis)
    return lengths

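# Worked example for infer_length (illustrative values): with eos_ix = 1,
#   sess.run(infer_length(tf.constant([[5, 2, 1, 0, 0]]), eos_ix=1))  ->  [3]
# i.e. every token up to and including the first EOS is counted; padding after it is ignored.
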
def infer_mask(seq, eos_ix, time_major=False, dtype=tf.float32):
    """
    Compute mask given output indices and the eos code
    :param seq: tf matrix [time, batch] if time_major else [batch, time]
    :param eos_ix: integer index of the end-of-sentence token
    :returns: mask, float32 matrix with '0's and '1's of the same shape as seq
    """
    axis = 0 if time_major else 1
    lengths = infer_length(seq, eos_ix, time_major=time_major)
    mask = tf.sequence_mask(lengths, maxlen=tf.shape(seq)[axis], dtype=dtype)
    if time_major:
        mask = tf.transpose(mask)
    return mask

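# Worked example for infer_mask (illustrative values): with eos_ix = 1,
#   sess.run(infer_mask(tf.constant([[5, 2, 1, 0, 0]]), eos_ix=1))  ->  [[1., 1., 1., 0., 0.]]
# i.e. ones over the real tokens (including EOS) and zeros over the padding.
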

def select_values_over_last_axis(values, indices):
    """
    Auxiliary function to select logits corresponding to the chosen tokens.
    :param values: logits for all actions: float32[batch, tick, action]
    :param indices: action ids int32[batch, tick]
    :returns: values selected for the given actions: float[batch, tick]
    """
    assert values.shape.ndims == 3 and indices.shape.ndims == 2
    batch_size, seq_len = tf.shape(indices)[0], tf.shape(indices)[1]
    batch_i = tf.tile(tf.range(0, batch_size)[:, None], [1, seq_len])
    time_i = tf.tile(tf.range(0, seq_len)[None, :], [batch_size, 1])
    indices_nd = tf.stack([batch_i, time_i, indices], axis=-1)

    return tf.gather_nd(values, indices_nd)

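# A minimal sketch of how these utilities are typically combined into a training objective
# (an assumption about intended usage, not code from the original file); `model`, `inp`
# and `out` are illustrative placeholders:
#
#   logprobas = model.symbolic_score(inp, out)                 # [batch, time, n_tokens]
#   logp_out = select_values_over_last_axis(logprobas, out)    # log P(out[t] | prefix), [batch, time]
#   mask = infer_mask(out, model.out_voc.eos_ix)               # 1 for real tokens, 0 for padding
#   loss = -tf.reduce_sum(logp_out * mask) / tf.reduce_sum(mask)
#   train_step = tf.train.AdamOptimizer().minimize(loss, var_list=model.weights)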