This code is modified from mxnet-image-caption.
What do you see in the picture below?
Well, some of you might say "A man is surfing on a wave", some might say "Surfer in the ocean riding a large wave", and yet others might say "A man is riding a wave in the ocean". All of these answers say the same thing with different words and different word orders. The task of producing a caption for a given image is called image captioning. Several existing online services provide this as an API, such as CaptionBot, powered by Microsoft Cognitive Services.
This code implements the paper Show and Tell: A Neural Image Caption Generator, and the model is trained on the MSCOCO dataset. Here, we only use a subset of the dataset: a set of images, each with one caption and a precomputed feature vector ($1\times2048$).
import os
import cv2
import sys
import time
import urllib
import pickle
import logging
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
import mxnet as mx
from mxnet.io import DataBatch
import types
def imports():
for name, val in globals().items():
if isinstance(val, types.ModuleType):
yield val.__name__, val.__version__ if hasattr(val, '__version__') else 'NaN'
for name, version in list(imports()):
print(name, version)
!wget -nc https://www.dropbox.com/s/vgtvff6zt15k6m4/captiondata10k.pickle?dl=1 --output-document captiondata10k.pickle
!wget -nc https://www.dropbox.com/s/dnsre7x5dxhxlfw/captiondataval10k.pickle?dl=1 --output-document captiondataval10k.pickle
allwords contains all the feature vectors and their corresponding captions (one caption per image).
The image feature vector is the output of the last layer before the fully connected layer of ResNet-50 from the MXNet Model Zoo. There are 16 residual blocks in total, each with 1 shortcut connection and 3 convolution layers, as shown below.
A caption is a list of word indices. < S > marks the start of a sentence and < \S > marks the end of a sentence.
allindexes contains the indices of all images.
vocabwords is the dictionary mapping word to id.
vocabids is the dictionary mapping id to word.
with open('captiondata10k.pickle', 'rb') as f:
[allwords, allindexes, vocabwords, vocabids, _] = pickle.load(f, encoding='latin1')
key = list(allwords.keys())[0]
print('feature vector: ', allwords[key][0])
print('caption indexes: ', allwords[key][1])
print('imgid: ', allindexes[key])
key = list(vocabwords.keys())[0]
print('word -> index: ', key, '->', vocabwords[key])
key = list(vocabids.keys())[0]
print('index -> word: ', key, '->', vocabids[key])
key = list(allwords.keys())[0]  # reuse the first image key; `key` was reassigned above
words = []
for idx in allwords[key][1]:
    words.append(vocabids[idx])
print('caption words: ', words)
i2h is the fully connected layer from input $x_t$ to the gates.
h2h is the fully connected layer from previous hidden state $h_{t-1}$ to the gates.
There are 4 gates: the forget gate forget_gate, the input gate in_gate, the output gate out_gate, and the candidate memory in_transform.
This is an illustration of an LSTM cell.
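Concretely, writing $\sigma$ for the sigmoid and $\odot$ for elementwise multiplication, the cell implemented below computes the standard LSTM update; the four gate pre-activations are produced by a single fully connected layer of size $4\times$num_hidden and then sliced apart:

$$\begin{aligned}
i_t &= \sigma(W_{xi}x_t + W_{hi}h_{t-1} + b_i) \\
\tilde{c}_t &= \tanh(W_{xc}x_t + W_{hc}h_{t-1} + b_c) \\
f_t &= \sigma(W_{xf}x_t + W_{hf}h_{t-1} + b_f) \\
o_t &= \sigma(W_{xo}x_t + W_{ho}h_{t-1} + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t, \qquad h_t = o_t \odot \tanh(c_t)
\end{aligned}$$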
'''
module that defines lstm network that is used for image captioning
'''
LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
"h2h_weight", "h2h_bias"])
LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
"init_states", "last_states",
"forward_state", "backward_state",
"seq_data", "seq_labels", "seq_outputs",
"param_blocks"])
def lstmcell(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.3):
'''
Defines an LSTM cell
Args:
num_hidden: number of hidden units
indata: input data to LSTM cell
prev_state: previous state vector
param: parameter for this LSTM (weights and biases)
seqidx: sequence id
layeridx: layer index (0 - first layer, 1 - second layer). Useful for
bi-directional LSTM
dropout: fraction of the input that gets dropped out during training
time
Returns:
LSTM cell object
'''
indata = mx.sym.Dropout(data=indata, p=dropout)
i2h = mx.sym.FullyConnected(data=indata,
weight=param.i2h_weight,
bias=param.i2h_bias,
num_hidden=num_hidden * 4,
name="t%d_l%d_i2h" % (seqidx, layeridx))
h2h = mx.sym.FullyConnected(data=prev_state.h,
weight=param.h2h_weight,
bias=param.h2h_bias,
num_hidden=num_hidden * 4,
name="t%d_l%d_h2h" % (seqidx, layeridx))
gates = i2h + h2h
slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
name="t%d_l%d_slice" %
(seqidx, layeridx))
in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")  # h_t = o_t * tanh(c_t)
return LSTMState(c=next_c, h=next_h)
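As a quick sanity check (toy sizes assumed: batch 2, a 2048-d input, 512 hidden units), the next hidden state returned by lstmcell has shape (batch, num_hidden):
p = LSTMParam(i2h_weight=mx.sym.Variable('l0_i2h_weight'),
              i2h_bias=mx.sym.Variable('l0_i2h_bias'),
              h2h_weight=mx.sym.Variable('l0_h2h_weight'),
              h2h_bias=mx.sym.Variable('l0_h2h_bias'))
s = LSTMState(c=mx.sym.Variable('l0_init_c'), h=mx.sym.Variable('l0_init_h'))
cell = lstmcell(512, indata=mx.sym.Variable('x'), prev_state=s, param=p,
                seqidx=0, layeridx=0)
# infer the output shape from the input shapes; the weights are deduced
_, out_shapes, _ = cell.h.infer_shape(x=(2, 2048), l0_init_c=(2, 512),
                                      l0_init_h=(2, 512))
print(out_shapes)  # [(2, 512)] -- the next hidden state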
def build_lstm_network(seq_len, input_size, num_hidden, num_embed, num_label,
prediction=False):
'''
Build the LSTM network
Args:
seq_len: length of the sequence - number of times to unroll
input_size: input vector dimension
        num_hidden: number of hidden units
num_embed: output dimension for the embedding unit
num_label: output dimension for the fully-connected unit
prediction: True if used for prediction, False if for training
Returns:
LSTM network symbol
'''
embed_weight = mx.sym.Variable("embed_weight")
cls_weight = mx.sym.Variable("cls_weight")
cls_bias = mx.sym.Variable("cls_bias")
# input image feature vector
data = mx.sym.Variable('data')
# word indices
label = mx.sym.Variable('softmax_label')
# one-hot encoding of word indices
veclabel = mx.sym.Variable('veclabel')
# veclabel = mx.sym.Reshape(veclabel, shape=(-1, seq_len)) # https://github.com/apache/incubator-mxnet/issues/7178
name = 'l0'
param = LSTMParam(i2h_weight=mx.sym.Variable(name+"_i2h_weight"),
i2h_bias=mx.sym.Variable(name+"_i2h_bias"),
h2h_weight=mx.sym.Variable(name+"_h2h_weight"),
h2h_bias=mx.sym.Variable(name+"_h2h_bias"))
lstm_state = LSTMState(c=mx.sym.Variable(name+"_init_c"),
h=mx.sym.Variable(name+"_init_h"))
allsm = []
# label indices
labelidx = mx.sym.SliceChannel(data=label, num_outputs=seq_len,
squeeze_axis=1)
# label one-hot vector
labelvec = mx.sym.SliceChannel(data=veclabel, num_outputs=seq_len,
squeeze_axis=1)
output = ''
targetlen = seq_len
if prediction:
# increase seq_len to generate till stop words during testing
# it is a hack for now
targetlen = seq_len + 10
for seqidx in range(targetlen):
k = seqidx
# testing may use more than seq_len, hence reuse the last input
# as dummy labels for softmax
if k >= seq_len:
k = seq_len - 1
# first iteration use image feature as input
if k == 0:
hidden = data
else:
# if in prediction mode and not in first iteration use the
# system output generated in previous timestep as input
            if prediction and (k > 1):
embed = mx.sym.Embedding(data=output,
input_dim=input_size,
weight=embed_weight,
output_dim=num_embed, name='embed')
else:
embed = mx.sym.Embedding(data=labelvec[k-1],
input_dim=input_size,
weight=embed_weight,
output_dim=num_embed, name='embed')
hidden = embed
next_state = lstmcell(num_hidden, indata=hidden,
prev_state=lstm_state,
param=param, seqidx=k, layeridx=0)
hidden = next_state.h
lstm_state = next_state
if k == 0:
continue
pred = mx.sym.FullyConnected(data=hidden, num_hidden=num_label,
weight=cls_weight,
bias=cls_bias, name='pred')
softmax_output = mx.sym.SoftmaxOutput(data=pred, label=labelidx[k],
name='softmax')
output = mx.sym.argmax(softmax_output, axis=1)
allsm.append(softmax_output)
allsm = mx.sym.Concat(*allsm, dim=1)
softmax_output = mx.sym.reshape(allsm, shape=(-1, num_label))
return (softmax_output,
['veclabel', 'l0_init_h', 'l0_init_c', 'data'],
['softmax_label'])
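As a small sketch (toy sizes assumed), constructing the unrolled symbol and printing the returned names shows exactly which inputs the bucketing module must provide:
net, data_names, label_names = build_lstm_network(seq_len=3, input_size=2048,
                                                  num_hidden=512,
                                                  num_embed=256, num_label=100)
print(data_names)   # ['veclabel', 'l0_init_h', 'l0_init_c', 'data']
print(label_names)  # ['softmax_label']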
When training a recurrent neural network (RNN), we unroll the network in time. For a single example of length T, we unroll the network for T steps. In the unrolled view, the weights are shared across time steps. The unrolled view allows us to train the network via backpropagation (backpropagation through time).
However, sequences in a dataset vary in length, so in the unrolled view each example requires a different number of unrollings. If we want to perform mini-batch training, we have to pad all the sequences to the length of the longest example. This can be wasteful, because on shorter sequences most of the computation is done on padded data, as the toy sketch below illustrates.
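As a toy illustration (hypothetical lengths), padding three sequences of lengths 5, 7, and 20 to the longest one wastes almost half the computation:
lengths = [5, 7, 20]
padded = max(lengths) * len(lengths)  # 60 timesteps computed per batch
useful = sum(lengths)                 # only 32 of them carry real data
print('wasted fraction: %.0f%%' % (100 * (1 - useful / padded)))  # 47%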
Bucketing offers an effective solution for making mini-batches out of varying-length sequences. Instead of unrolling the network to the maximum possible sequence length, we unroll multiple instances of it at different lengths (e.g., lengths 5, 10, 20, 30).
The function default_gen_buckets generates a list of buckets based on the set of captions in the input. The list contains the length of each bucket; a toy sanity check follows the definition below.
'''
module that defines bucketing data iter
'''
def default_gen_buckets(allwords, batch_size):
'''
Generate buckets based on data. This method generates a list of buckets
and the length of those buckets based on the input
Args:
allwords: all the sentences (set of words) that are part of the data
batch_size: batch size to check if a particular bucket has that many
elements
Returns:
returns the generated buckets
'''
    len_dict = {}
    # count how many captions have each length
    for key in allwords:
        words = allwords[key][1]
        if len(words) == 0:
            continue
        if len(words) in len_dict:
            len_dict[len(words)] += 1
        else:
            len_dict[len(words)] = 1
buckets = []
for length, num in len_dict.items():
if num >= batch_size:
buckets.append(length)
return buckets
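As a toy sanity check (hypothetical data): caption lengths 5 and 7 occur at least batch_size times and so become buckets, while the single length-9 caption is dropped:
toy = {'a': (None, [1]*5), 'b': (None, [1]*5), 'c': (None, [1]*5),
       'd': (None, [1]*7), 'e': (None, [1]*7), 'f': (None, [1]*9)}
print(default_gen_buckets(toy, batch_size=2))  # [5, 7]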
class BucketIter(mx.io.DataIter):
'''
Class that defines the data iter for image captioning module
'''
def __init__(self, captionf, batch_size=1):
'''
Init function for the class
Args:
captionf: pickle filename that has all the captions
batch_size: batch size for training data
'''
super(BucketIter, self).__init__()
self.batch_size = batch_size
# load datafiles, ignore the second output which is just the img ids
# for each data element - used in case we need access to img id during
# testing
[self.allwords, _, self.vocabwords, self.vocabids, \
self.unknown_id] = pickle.load(open(captionf, 'rb'), encoding='latin1')
# generate buckets
buckets = default_gen_buckets(self.allwords, batch_size)
buckets.sort()
self.buckets = buckets
        # assign the default bucket - ideally this should be the largest bucket
self.default_bucket_key = max(buckets)
# Assign data to their corresponding bucket
self.databkt = [[] for _ in buckets]
self.cursor = {}
self.num_data_bkt = {}
for idx in self.allwords:
strs = self.allwords[idx][1]
for i, bkt in enumerate(buckets):
if bkt == len(strs):
self.databkt[i].append(idx)
break
# initialize bucket specific parameters, the current index into
# the bucket and the remaining number of elements in the bucket
for i, bkt in enumerate(buckets):
self.cursor[i] = -1
self.num_data_bkt[i] = len(self.databkt[i])
# iterator variables
self.epoch = 0
self.bidx = np.argmax(buckets)
self.data, self.label = self.read(self.bidx)
self.reset()
@property
def bucket_key(self):
'''
bucket key for bucketiter module
'''
return self.buckets[self.bidx]
@property
def provide_data(self):
"""The name and shape of data provided by this iterator"""
# res = [(k, tuple(list(self.data[k].shape[0:]))) for k in self.data]
res = [('veclabel', tuple(list(self.data['veclabel'].shape[0:]))),
('l0_init_h', tuple(list(self.data['l0_init_h'].shape[0:]))),
('l0_init_c', tuple(list(self.data['l0_init_c'].shape[0:]))),
('data', tuple(list(self.data['data'].shape[0:])))]
return res
@property
def provide_label(self):
"""The name and shape of label provided by this iterator"""
res = [(k, tuple(list(self.label[k].shape[0:]))) for k in self.label]
return res
def reset(self):
'''
data iter reset
'''
        for index in self.cursor:
            self.cursor[index] = -1
self.epoch += 1
def next(self):
"""return one dict which contains "data" and "label" """
if self.iter_next():
# select one random bucket out of all the ones that has
# > batch_size remaining samples
rem = [i for i, _ in enumerate(self.buckets)
if (len(self.databkt[i])-self.cursor[i]) > self.batch_size]
bidx = np.random.randint(0, len(rem))
bidx = rem[bidx]
# read the samples from the bucket
self.data, self.label = self.read(bidx)
# prepare as databatch to return
res = DataBatch(provide_data=self.provide_data,
provide_label=self.provide_label,
bucket_key=self.buckets[self.bidx],
data=[mx.nd.array(self.data['veclabel']),
mx.nd.array(self.data['l0_init_h']),
mx.nd.array(self.data['l0_init_c']),
mx.nd.array(self.data['data'])],
label=[mx.nd.array(self.label['softmax_label'])],
pad=0, index=None)
return res
else:
raise StopIteration
def iter_next(self):
'''
check if next iteration can be done
'''
for i, _ in enumerate(self.buckets):
if self.cursor[i] + self.batch_size < self.num_data_bkt[i]:
return True
return False
def read(self, bidx):
'''
read the next set of data based on bucket index
Args:
bidx: bucket index
'''
self.bidx = bidx
data_array = []
allimgids = []
label = []
labelvec = []
index = 0
        while True:
self.cursor[bidx] += 1
# obtain the feature vector
data = self.get_data(bidx)
# obtain the label (caption word indices)
labels = self.get_label(bidx)
if len(labels) == 0:
continue
data_array.append(data)
# construct a one-hot vector of labels
labela = []
labelveca = []
for labelidx in labels:
labelarray = np.zeros((len(self.vocabwords)+1), dtype='int')
labelarray[labelidx] = 1
labela.append(labelarray)
labelveca.append(labelidx)
label.append(labela)
labelvec.append(labelveca)
index += 1
if index > (self.batch_size-1):
break
darray = np.vstack(data_array)
# this is also defined in training code - need to set the same number
# of hidden units for the LSTM
num_hidden = 512
data = {}
data['l0_init_h'] = np.zeros((darray.shape[0], num_hidden),
dtype='float')
data['l0_init_c'] = np.zeros((darray.shape[0], num_hidden),
dtype='float')
data['data'] = darray
data['veclabel'] = np.array(labelvec)
finallabel = {}
finallabel['softmax_label'] = np.asarray(label)
return (data, finallabel)
def get_data(self, bidx):
'''
Returns the feature vector based on the current cursor
and bucket index
Args:
bidx: bucket index
'''
idx = self.databkt[bidx][self.cursor[bidx]]
return self.allwords[idx][0]
def get_label(self, bidx):
'''
Returns the label vector based on the current cursor
and bucket index
Args:
bidx: bucket index
'''
idx = self.databkt[bidx][self.cursor[bidx]]
return self.allwords[idx][1]
'''
main code for training image captioning model
'''
DEBUG = True
def custommetric(label, pred):
'''
Simple metric that outputs the fraction of correct word predictions
to the total number of words
Args:
label: ground truth label
pred: predicted output
Returns:
accuracy metric
'''
# shift by one word to match prediction
label = label[:, 1:, :]
pred = np.reshape(pred, label.shape)
label = np.argmax(label, axis=2)
pred = np.argmax(pred, axis=2)
return float(np.sum(pred == label)) / np.sum(label >= 0)
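As a toy check (hypothetical 4-word vocabulary, one sequence of 4 one-hot labels): the metric drops the first label to account for the shift, so 2 of the 3 remaining predictions match:
lbl = np.zeros((1, 4, 4))
lbl[0, 0, 0] = lbl[0, 1, 2] = lbl[0, 2, 1] = lbl[0, 3, 3] = 1
prd = np.zeros((1, 3, 4))
prd[0, 0, 2] = prd[0, 1, 1] = prd[0, 2, 0] = 1
print(custommetric(lbl, prd))  # 0.666... (2 of 3 words correct)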
BATCH_SIZE = 192
NUM_HIDDEN = 512
NUM_EPOCH = 96
GENERATE_GRAPH = False
DATADIR = '.'
data_train = BucketIter(DATADIR+'/captiondata10k.pickle',
batch_size=BATCH_SIZE)
if DEBUG:
print('training data loaded ....')
print('provide data', data_train.provide_data, 'provide label', data_train.provide_label, 'default bucket key', data_train.default_bucket_key)
for i in range(len(data_train.provide_data)):  # TODO: the returned list does not guarantee order, so look up 'data' by name
if data_train.provide_data[i][0] == 'data':
INPUT_SIZE = data_train.provide_data[i][1][1]
EMBED_SIZE = data_train.provide_data[i][1][1]
break
NUM_LABEL = len(data_train.vocabwords)+1
if DEBUG:
print('input size', INPUT_SIZE, 'number label', NUM_LABEL, 'embed size', EMBED_SIZE)
data_val = BucketIter(DATADIR+'/captiondataval10k.pickle',
batch_size=BATCH_SIZE)
if DEBUG:
print('validation data loaded ....')
contexts = mx.gpu(3) # [mx.context.gpu(i) for i in range(1)]
# this is needed for the bucketing module
def sym_gen(seq_len):
'''
needed for bucketing module, network generated based on seq_len
Args:
seq_len: length of the current sequence
Returns:
Symbolic network
'''
return build_lstm_network(seq_len, INPUT_SIZE, NUM_HIDDEN, EMBED_SIZE,
NUM_LABEL)
model = mx.mod.BucketingModule(sym_gen, data_train.default_bucket_key,
context=contexts)
model.bind(data_shapes=data_train.provide_data,
label_shapes=data_train.provide_label)
model.init_params(initializer=mx.init.Normal(sigma=0.01))
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.INFO, format=head)
start_time = time.time()
if not os.path.exists('models'):
os.mkdir('models')
CHECKPOINT_NAME = 'models/imagecaption'
model.fit(data_train, data_val, num_epoch=NUM_EPOCH, optimizer='adam',
optimizer_params={'learning_rate': 1e-3},
eval_metric=mx.metric.CustomMetric(custommetric),
epoch_end_callback=mx.callback.do_checkpoint(CHECKPOINT_NAME))
print('program run time', time.time() - start_time)
'''
module that generates features for images using resnet
'''
Batch = namedtuple('Batch', ['data'])
def download(url):
'''
download the file given the url
Args:
url: path for the filename
'''
filename = url.split("/")[-1]
if not os.path.exists(filename):
urllib.request.urlretrieve(url, filename)
def get_model(prefix, epoch):
'''
get the model with prefix and epoch
Args:
prefix: model prefix
epoch: trained model - epoch
'''
download(prefix+'-symbol.json')
download(prefix+'-%04d.params' % (epoch,))
def get_image(filename):
'''
    return the image based on filename after resizing it to 224x224 to
    fit the ResNet input format
Args:
filename: filename of the image
'''
img = cv2.imread(filename) # read image in b,g,r order
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # change to r,g,b order
img = cv2.resize(img, (224, 224)) # resize to 224*224 to fit model
img = np.swapaxes(img, 0, 2)
img = np.swapaxes(img, 1, 2) # change to (channel, height, width)
    img = img[np.newaxis, :]  # extend to (example, channel, height, width)
return img
class Resnet(object):
'''
    ResNet class that constructs the ResNet model object
'''
def __init__(self):
'''
        Downloads the model from the MXNet model zoo and constructs the
        network for prediction
'''
url = 'http://data.mxnet.io/models/imagenet/resnet/50-layers/resnet-50'
get_model(url, 0)
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50', 0)
all_layers = sym.get_internals()
sym3 = all_layers['flatten0_output']
mod3 = mx.mod.Module(symbol=sym3, label_names=None, context=mx.cpu())
mod3.bind(for_training=False, data_shapes=[('data', (1, 3, 224, 224))])
mod3.set_params(arg_params, aux_params)
self.mod3 = mod3
def gen_features(self, img_path):
'''
generate features given an image
Args:
img_path: full path to the image
'''
img = get_image(img_path)
self.mod3.forward(Batch([mx.nd.array(img)]))
return self.mod3.get_outputs()[0].asnumpy()
def get_feature(imgfname):
'''
Returns the feature for the image
Args:
        imgfname: full path to the image
'''
network = Resnet()
return network.gen_features(imgfname)
'''
generate captions for an image
'''
%matplotlib inline
NUM_LSTM_LAYER = 1
BATCH_SIZE = 1
!wget -nc https://www.dropbox.com/s/41kfb3ezigssa9q/testimage.jpg?dl=1 --output-document testimage.jpg
imgfname = "testimage.jpg"
SEQ_LEN = 25
sym, arg_params, aux_params = \
mx.model.load_checkpoint(CHECKPOINT_NAME, NUM_EPOCH)
NUM_HIDDEN = arg_params['l0_h2h_weight'].shape[1]
INPUT_SIZE = arg_params['l0_h2h_weight'].shape[0]
NUM_LABEL = arg_params['cls_weight'].shape[0]
sym, _, _ = build_lstm_network(SEQ_LEN, INPUT_SIZE, NUM_HIDDEN,
INPUT_SIZE, NUM_LABEL, prediction=True)
init_c = [('l%d_init_c' % l, (BATCH_SIZE, NUM_HIDDEN))
for l in range(NUM_LSTM_LAYER)]
init_h = [('l%d_init_h' % l, (BATCH_SIZE, NUM_HIDDEN))
for l in range(NUM_LSTM_LAYER)]
data_shape = [("data", (BATCH_SIZE, 2048))]
label_shape = [("veclabel",
(BATCH_SIZE, SEQ_LEN, ))]
label_shape1 = [("softmax_label",
(BATCH_SIZE, SEQ_LEN, NUM_LABEL))]
f = get_feature(imgfname)
input_data = mx.nd.array(f)
veclabel = mx.nd.zeros((BATCH_SIZE, SEQ_LEN))
veclabel[0][0] = 0
input_shapes = dict(init_c+init_h+data_shape+label_shape+label_shape1)
executor = sym.simple_bind(ctx=mx.gpu(), **input_shapes)
for key in executor.arg_dict.keys():
if key in arg_params:
arg_params[key].copyto(executor.arg_dict[key])
# zero-initialize the LSTM states, feed the image feature and the start
# word, then run a single forward pass over the unrolled network
for i in range(NUM_LSTM_LAYER):
    executor.arg_dict["l%d_init_c" % i][:] = 0.
    executor.arg_dict["l%d_init_h" % i][:] = 0.
input_data.copyto(executor.arg_dict["data"])
veclabel.copyto(executor.arg_dict["veclabel"])
executor.forward()
prob = executor.outputs[0].asnumpy()
img = cv2.imread(imgfname)[:,:,::-1]
plt.imshow(img)
for index in range(0, BATCH_SIZE):
    # the first timestep consumes the image feature and emits no softmax
    # output, so the outputs span SEQ_LEN + 10 - 1 timesteps
    p = np.reshape(prob, (-1, SEQ_LEN+9, len(data_train.vocabwords)+1))
    p = np.argmax(p, axis=2)[index, :]
    caption = ''
    for i in p:
        if i == 2:  # stop at the end-of-sentence token
            break
        caption = caption + data_train.vocabids[i] + ' '
    print(caption)