Created
January 21, 2020 11:59
-
-
Save ashokpant/3de221c48cc0370bac1d25902ec34653 to your computer and use it in GitHub Desktop.
Character segmentation for dataset preparation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
-- Ashok Kumar Pant ([email protected])
-- Treeleaf Technologies Pvt. Ltd.
-- Date: 1/18/20
""" | |
import argparse | |
import os | |
from random import random | |
from traceback import print_exc | |
import cv2 | |
def _get_boxes(img):
    """
    Find bounding boxes of large dark regions in a BGR image.

    The image is converted to grayscale and binarised so that every pixel with
    intensity <= 240 becomes foreground, then dilated with a 5x5 rectangular
    kernel before contours are extracted.

    :param img: BGR image array (as produced by ``cv2.imread``)
    :return: list of ``[left, top, right, bottom]`` boxes that are more than
             100 px wide and more than 100 px high, and that cover less than
             half of the image area
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, img_bin = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY)
    # Invert so that the dark content (not the bright background) is foreground.
    img_bin = 255 - img_bin
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    dilated = cv2.dilate(img_bin, kernel)
    # OpenCV 3.x returns (image, contours, hierarchy) while 2.x/4.x return
    # (contours, hierarchy); the contours are always second from the end.
    contours = cv2.findContours(dilated, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    # Boxes covering at least half of the image are treated as the page/frame itself.
    max_area = 0.5 * (img.shape[0] * img.shape[1])
    table_boxes = []
    for c in contours:
        x, y, w, h = cv2.boundingRect(c)
        if w * h >= max_area:
            continue
        if w > 100 and h > 100:  # Minimum width/height check
            table_boxes.append([x, y, x + w, y + h])
    return table_boxes
def _save_images(img, boxes, output_dir, prefix=""):
    """
    Crop each box out of ``img`` and write the crop to ``output_dir`` as a JPEG.

    :param img: image array that is sliced as ``img[top:bottom, left:right]``
    :param boxes: iterable of ``(left, top, right, bottom)`` boxes
    :param output_dir: directory the chips are written to
    :param prefix: string placed before ``_<index>.jpg`` in each file name,
                   where ``<index>`` is the position of the box in ``boxes``
    """
    for index, (left, top, right, bottom) in enumerate(boxes):
        region = img[top:bottom, left:right]
        chip_name = prefix + "_" + str(index) + ".jpg"
        target = os.path.join(output_dir, chip_name)
        cv2.imwrite(target, region)
def segment_characters(filename, output_dir, prefix=""):
    """
    Segment one image file and save every detected region into ``output_dir``.

    Processing is best-effort: any exception is printed with its traceback and
    reported through the return value instead of being propagated.

    :param filename: path of the image to read
    :param output_dir: directory for the cropped chips; created if missing
    :param prefix: string prepended to the name of every chip written
    :return: True when the image was processed, False when it could not be
             read or processing raised an exception
    """
    try:
        img = cv2.imread(filename)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # rather than raising, so handle it explicitly.
            print("Unable to read image: {}".format(filename))
            return False
        boxes = _get_boxes(img)
        os.makedirs(output_dir, exist_ok=True)
        _save_images(img, boxes, output_dir, prefix=prefix)
        return True
    except Exception:
        print_exc()
        # Explicit failure value instead of an implicit None.
        return False
def _list_images(directory, shuffle=False):
    """
    List the image files found under a directory tree.

    The tree is walked recursively with ``os.walk``; a file is treated as an
    image when its name ends with one of the known image extensions, compared
    case-insensitively.

    :param directory: root directory to search
    :param shuffle: when true, the collected list is shuffled before returning
    :return: list of image file paths
    """
    # Imported locally: the module-level ``from random import random`` binds the
    # ``random()`` function, which has no ``shuffle`` attribute.
    from random import shuffle as _shuffle

    _ext = ('.jpg', '.jpeg', '.bmp', '.png', '.ppm', '.pgm', '.webp')
    _images = []
    for subdir, _dirs, files in os.walk(directory):
        for file in files:
            # Lower-casing the name accepts .JPG/.PNG/... as well.
            if file.lower().endswith(_ext):
                _images.append(os.path.join(subdir, file))
    if shuffle:
        _shuffle(_images)
    return _images
def _is_dir(path):
    """
    Tell whether ``path`` is an existing directory.

    :param path: file system path to check
    :return: result of ``os.path.isdir(path)``
    """
    is_directory = os.path.isdir(path)
    return is_directory
def _parse_bool(value):
    """
    Convert a command-line flag value to a bool.

    :param value: bool or string taken from the command line
    :return: False for an empty string or one of the words
             ``0/false/f/no/n/off`` (any case), True for anything else
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() not in ("", "0", "false", "f", "no", "n", "off")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Character segmentation")
    parser.add_argument('--input', help="Input image file or directory", default=None)
    parser.add_argument('--output', help="Output directory", default="tmp/")
    # nargs='?' accepts both a bare `--shuffle` and `--shuffle True/False`;
    # the value is parsed so that the string "False" no longer counts as true.
    parser.add_argument('--shuffle', help="Shuffle images", nargs='?', const=True, default=False,
                        type=_parse_bool)
    args = parser.parse_args()
    if args.input is None:
        print("Input image file or directory to process!")
        print("Example: python character_segmentation.py --input image.jpg --output tmp/")
        # URL carried over from the original listing (presumably a sample input image):
        # https://github.com/ashokpant/ai_sapphire/blob/master/images/image.jpg
        parser.print_help()
        # SystemExit instead of the site-provided `exit` helper; status unchanged.
        raise SystemExit(0)
    if _is_dir(args.input):
        images = _list_images(args.input, args.shuffle)
    else:
        images = [args.input]
    # Count from 1 so the output prefix and the progress message both reflect
    # the number of files handled so far.
    for count, filename in enumerate(images, start=1):
        success = segment_characters(filename, args.output, prefix=str(count))
        if not success:
            print("Unable to process the file: {}".format(filename))
        if count % 10 == 0:
            print("Processed {} files: ".format(count))
    print("Processed total {} files".format(len(images)))
    print("Done.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment