-
-
Save Coderx7/830ada70049ebbe7d7dc75101645b2f9 to your computer and use it in GitHub Desktop.
Caffe script to compute accuracy and a confusion matrix. Mean subtraction has been added, so it now reports the same accuracy as Caffe itself.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
# -*- coding: utf-8 -*- | |
# Author: Axel Angel, copyright 2015, license GPLv3. | |
# added mean subtraction so that, the accuracy can be reported accurately just like caffe when doing a mean subtraction | |
# Seyyed Hossein Hasan Pour | |
# [email protected] | |
# 7/3/2016 | |
import sys | |
import caffe | |
import numpy as np | |
import lmdb | |
import argparse | |
from collections import defaultdict | |
def flat_shape(x):
    """Return x without singleton dimensions, eg: (1,28,28) -> (28,28).

    The target shape is built as an explicit tuple. The previous code handed
    the result of ``filter`` directly to ``reshape``; on Python 3 ``filter``
    returns an iterator, which ``reshape`` rejects with
    ``TypeError: expected sequence object with len >= 0 or a single integer``.

    If every dimension has size 1 the result is a 0-d array.
    """
    squeezed_shape = tuple(s for s in x.shape if s > 1)
    return x.reshape(squeezed_shape)
def lmdb_reader(fpath):
    """Yield ``(key, image, label)`` for every record of the LMDB at *fpath*.

    Each value is parsed as a Caffe ``Datum``. The image is the datum
    converted to an array, cast to ``uint8`` and passed through
    ``flat_shape``; the label is ``datum.label`` as an ``int``.

    The environment is closed when iteration finishes, or when the generator
    is closed or garbage-collected early, so the handle no longer leaks.
    """
    import lmdb
    lmdb_env = lmdb.open(fpath)
    try:
        lmdb_txn = lmdb_env.begin()
        lmdb_cursor = lmdb_txn.cursor()
        for key, value in lmdb_cursor:
            datum = caffe.proto.caffe_pb2.Datum()
            datum.ParseFromString(value)
            label = int(datum.label)
            image = caffe.io.datum_to_array(datum).astype(np.uint8)
            yield (key, flat_shape(image), label)
    finally:
        # Closing the environment also invalidates the open transaction
        # and cursor.
        lmdb_env.close()
def leveldb_reader(fpath):
    """Yield ``(key, image, label)`` for every record of the LevelDB at *fpath*.

    Each stored value is parsed as a Caffe ``Datum``; the image is the datum
    converted to a ``uint8`` array and passed through ``flat_shape``, and the
    label is ``datum.label`` converted to ``int``.
    """
    import leveldb
    # Keep the database object referenced for as long as we iterate over it.
    store = leveldb.LevelDB(fpath)
    for record_key, raw_value in store.RangeIter():
        record = caffe.proto.caffe_pb2.Datum()
        record.ParseFromString(raw_value)
        record_label = int(record.label)
        pixels = caffe.io.datum_to_array(record).astype(np.uint8)
        yield (record_key, flat_shape(pixels), record_label)
def npz_reader(fpath):
    """Yield ``(index, x, l)`` for each sample stored in the ``.npz`` at *fpath*.

    The archive must contain ``arr_0`` (the samples) and ``arr_1`` (the
    labels), which are the names ``np.savez`` gives to positional arguments.
    Samples and labels are paired element by element along their first axis.

    Pairing uses ``zip`` rather than ``np.array([xs, ls]).T``: stacking two
    arrays whose trailing shapes differ (e.g. images and scalar labels)
    raises ``ValueError``, and stacking same-shaped arrays silently coerces
    them to one dtype.
    """
    # Both arrays are read into memory inside the ``with`` so the archive's
    # file handle is released before iteration starts.
    with np.load(fpath) as npz:
        xs = npz['arr_0']
        ls = npz['arr_1']
    for i, (x, l) in enumerate(zip(xs, ls)):
        yield (i, x, l)
if __name__ == "__main__":
    # Command line: network definition, trained weights, mean file, and
    # exactly one dataset source (the group is mutually exclusive and
    # required).
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', type=str, required=True)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--mean', type=str, required=True)
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--lmdb', type=str, default=None)
    group.add_argument('--leveldb', type=str, default=None)
    group.add_argument('--npz', type=str, default=None)
    args = parser.parse_args()

    # Extract mean from the mean image file. The ``with`` block closes the
    # file even if parsing the BlobProto raises.
    mean_blobproto_new = caffe.proto.caffe_pb2.BlobProto()
    with open(args.mean, 'rb') as f:
        mean_blobproto_new.ParseFromString(f.read())
    mean_image = caffe.io.blobproto_to_array(mean_blobproto_new)

    count = 0
    correct = 0
    matrix = defaultdict(int)  # (real,pred) -> int
    labels_set = set()

    # CNN reconstruction and loading the trained weights
    net = caffe.Net(args.proto, args.model, caffe.TEST)
    caffe.set_mode_cpu()

    # Single-argument print(...) calls emit the same text on Python 2 and 3.
    print("args %s" % (vars(args),))

    # argparse guarantees exactly one of the three sources is set.
    if args.lmdb is not None:
        reader = lmdb_reader(args.lmdb)
    elif args.leveldb is not None:
        reader = leveldb_reader(args.leveldb)
    elif args.npz is not None:
        reader = npz_reader(args.npz)

    for i, image, label in reader:
        # Prepend an axis of size 1; np.asarray([...]) below adds one more
        # before the mean image is subtracted.
        image_caffe = image.reshape(1, *image.shape)
        out = net.forward_all(data=np.asarray([image_caffe]) - mean_image)
        # Predicted label: index of the largest entry in the 'prob' output.
        plabel = int(out['prob'][0].argmax(axis=0))

        count += 1
        iscorrect = label == plabel
        correct += (1 if iscorrect else 0)
        matrix[(label, plabel)] += 1
        labels_set.update([label, plabel])

        if not iscorrect:
            print("\rError: i=%s, expected %i but predicted %i"
                  % (i, label, plabel))

        # '\r' rewrites the running accuracy on the same terminal line.
        sys.stdout.write("\rAccuracy: %.1f%%" % (100. * correct / count))
        sys.stdout.flush()

    print(", %i/%i corrects" % (correct, count))

    print("")
    print("Confusion matrix:")
    print("(r , p) | count")
    # Sorted so the rows come out in a deterministic order.
    ordered_labels = sorted(labels_set)
    for l in ordered_labels:
        for pl in ordered_labels:
            print("(%i , %i) | %i" % (l, pl, matrix[(l, pl)]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks for your amazing work.
I am using your script but I get stuck at:
I1119 16:59:51.464190 12772 net.cpp:283] Network initialization done.
args{'proto': 'test.prototxt', 'model': 'models/caffenet_age_train_iter_50000.caffemodel', 'mean': 'mean_image/mean.binaryproto', 'lmdb': 'lmdb_full/age_test_lmdb/', 'leveldb': None, 'npz': None}
Traceback (most recent call last):
File "Caffe_Convnet_ConfuxionMatrix.py", line 85, in
for i, image, label in reader:
File "Caffe_Convnet_ConfuxionMatrix.py", line 29, in lmdb_reader
yield (key, flat_shape(image), label)
File "Caffe_Convnet_ConfuxionMatrix.py", line 16, in flat_shape
return x.reshape(filter(lambda s: s > 1, x.shape))
TypeError: expected sequence object with len >= 0 or a single integer
Can you suggest a solution? Thanks.