taneemishere's picture
added flagged functionality, removed printing model summary
5ee6650
raw
history blame
2 kB
from __future__ import print_function
from __future__ import absolute_import
__author__ = 'Taneem Jan, taneemishere.github.io'
from .Vocabulary import *
from .Utils import *
class Sampler:
    """Greedy sequence decoder driven by a caller-supplied model.

    The vocabulary is loaded once at construction time from ``voc_path``.
    The model is not stored on the instance; it is passed to
    :meth:`predict_greedy` on every call.
    """

    def __init__(self, voc_path, input_shape, output_size, context_length):
        """Load the vocabulary and remember the sizes used while decoding.

        Args:
            voc_path: location passed to ``Vocabulary.retrieve`` to load tokens.
            input_shape: kept on the instance and printed; this class does not
                otherwise use it.
            output_size: length of the one-hot vectors built for the context
                when sparse labels are requested (also passed to
                ``Utils.sparsify``).
            context_length: number of entries in the sliding context window
                handed to the model at each step.
        """
        self.context_length = context_length
        self.input_shape = input_shape
        self.output_size = output_size

        self.voc = Vocabulary()
        self.voc.retrieve(voc_path)

        # Configuration summary, printed on every construction.
        print("Vocabulary size: {}".format(self.voc.size))
        print("Input shape: {}".format(self.input_shape))
        print("Output size: {}".format(self.output_size))

    def predict_greedy(self, model, input_img, require_sparse_label=True, sequence_length=150, verbose=False):
        """Decode at most ``sequence_length`` tokens, taking the arg-max each step.

        Args:
            model: object with a ``predict(input_img, context_batch)`` method;
                its return value is arg-maxed (``np.argmax``) to choose the
                index of the next token.
            input_img: forwarded unchanged to ``model.predict`` at every step.
            require_sparse_label: when True, the context is first passed
                through ``Utils.sparsify`` and every new entry is a one-hot
                ``np.zeros(self.output_size)`` vector; when False, entries are
                the raw predicted indices.
            sequence_length: upper bound on the number of decoding steps.
            verbose: when True, print a progress line before each step.

        Returns:
            ``(predictions, out_probas)`` where ``predictions`` is
            ``START_TOKEN`` followed by every predicted token concatenated
            (including ``END_TOKEN`` if it was produced), and ``out_probas`` is
            the list of raw ``model.predict`` outputs, one per step taken.
        """
        token_to_index = self.voc.vocabulary
        # Seed window: (context_length - 1) placeholders, then the start token.
        window = [token_to_index[PLACEHOLDER]] * (self.context_length - 1) + [token_to_index[START_TOKEN]]
        if require_sparse_label:
            # NOTE(review): presumably turns indices into one-hot rows of width
            # output_size -- confirm against Utils.sparsify.
            window = Utils.sparsify(window, self.output_size)

        decoded = START_TOKEN
        raw_outputs = []

        for step in range(sequence_length):
            if verbose:
                print("predicting {}/{}...".format(step, sequence_length))

            scores = model.predict(input_img, np.array([window]))
            raw_outputs.append(scores)
            best = np.argmax(scores)

            # Encode the newest prediction in the same form as the window.
            if require_sparse_label:
                newest = np.zeros(self.output_size)
                newest[best] = 1
            else:
                newest = best

            # Drop the oldest entry and append the newest one on the right.
            window = [window[k] for k in range(1, self.context_length)] + [newest]

            token = self.voc.token_lookup[best]
            decoded += token
            if token == END_TOKEN:
                break

        return decoded, raw_outputs