Created January 7, 2022 02:47
-
-
Save mattpopovich/0626d8d071014ff1fa45b76841fef4dc to your computer and use it in GitHub Desktop.
Issue #265 on zhiqwang / yolov5-rt-stack
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Author: Matt Popovich (mattpopovich.com)
# Date: January 6, 2022
# yolort Release: 0.5.2
# Except for a bit at the end, this is all copied from:
# https://github.com/zhiqwang/yolov5-rt-stack/blob/main/notebooks/how-to-align-with-ultralytics-yolov5.ipynb
import os
import sys

import cv2
import torch

# sys.path.insert(0, "/home/mpopovich/git/yolov5-rt-stack")

# Enumerate GPUs in PCI bus order and expose only GPU 0 to this process.
# These variables only take effect if set before CUDA is initialized.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "0"})

# Everything below runs on the CPU; use the commented line instead to pick a GPU when present.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

from yolort.models.yolo import YOLO
from yolort.utils import (
    cv2_imshow,
    get_image_from_url,
    read_image_to_tensor,
)
from yolort.utils.image_utils import plot_one_box, color_list
from yolort.v5 import load_yolov5_model, letterbox, non_max_suppression, scale_coords, attempt_download
# Set LABELS and COLORS: fetch the COCO class names (one per line) and a drawing palette.
import requests

label_path = "https://gitee.com/zhiqwang/yolov5-rt-stack/raw/master/notebooks/assets/coco.names"
# A timeout keeps the script from hanging forever on a stalled connection, and
# raise_for_status() stops an HTTP error page from being parsed as class names.
response = requests.get(label_path, timeout=30)
response.raise_for_status()
names = response.text
# splitlines() also strips '\r' from CRLF files, which split('\n') would leave
# attached to every label (and thus to every drawn caption).
LABELS = names.strip().splitlines()
COLORS = color_list()
# Get image: download the sample image once and cache it in the working directory.
img_name = 'bus.jpg'
img_url = 'https://raw.githubusercontent.com/zhiqwang/yolov5-rt-stack/main/test/assets/' + img_name
if os.path.isfile(img_name):
    print(img_name + " already downloaded!")
else:
    attempt_download(img_url)
    # Only report success once the file is actually present where cv2.imread will look.
    if not os.path.isfile(img_name):
        raise FileNotFoundError(f"Failed to download {img_url} to {img_name}")
    print("Downloaded " + img_name + " successfully!")
# cv2.imread returns None (instead of raising) when a file cannot be read or decoded;
# fail here rather than with an obscure error inside letterbox().
img_raw = cv2.imread(img_name)
if img_raw is None:
    raise RuntimeError(f"cv2.imread could not read image {img_name!r}")
# Preprocess: letterbox-resize the raw image to the 640x640 network input
# (element [0] of letterbox's result is the resized image), convert it to a tensor
# with yolort's helper, and place it on the inference device.
letterboxed = letterbox(img_raw, new_shape=(640, 640))[0]
img = read_image_to_tensor(letterboxed).to(device)
# Select the Ultralytics YOLOv5 release to compare against.
version = "6.0"  # "4.0" or "6.0"
rversion = "r" + version
model_url = "https://github.com/ultralytics/yolov5/releases/download/v" + version + "/yolov5s.pt"
full_model_name = "yolov5s-v" + version + ".pt"
# Download model from Ultralytics GitHub, keeping a version-tagged local copy so
# checkpoints from different releases do not overwrite each other.
if os.path.isfile(full_model_name):
    print(full_model_name + " already downloaded!")
else:
    attempt_download(model_url)
    # attempt_download is expected to save under the URL's basename ("yolov5s.pt")
    # in the current directory; derive that name instead of hard-coding it.
    os.rename(os.path.basename(model_url), full_model_name)
    print("Downloaded " + full_model_name + " successfully!")
# Bind checkpoint_path on both branches, and point it at the renamed file. The old
# value was the pre-rename path, which no longer exists after os.rename.
checkpoint_path = full_model_name
# Load the Ultralytics model as a plain module (autoshape=False) on the inference device.
score_thresh = 0.30  # confidence threshold (0-1)
iou = 0.45  # NMS IoU threshold (0-1)
model = load_yolov5_model(full_model_name, autoshape=False, verbose=True).to(device)
# NOTE(review): conf/iou/classes look like AutoShape-wrapper settings; with
# autoshape=False they are presumably not consulted. The script's non_max_suppression
# call passes score_thresh/iou explicitly, which is what actually filters detections.
model.conf = score_thresh
model.iou = iou
model.classes = None  # (optional list) filter by class, i.e. = [0, 15, 16] for persons, cats and dogs
model = model.eval()
# Perform inference with the Ultralytics model, then class-agnostic NMS.
# Rows of the NMS output are indexed below as [:4] box, [4] score, [5] class id.
with torch.no_grad():
    ultralytics_dets = model(img[None])[0]
    ultralytics_dets = non_max_suppression(ultralytics_dets, score_thresh, iou, agnostic=True)[0]
# Clone before rescaling so ultralytics_dets stays in letterboxed (640x640) coordinates
# for the later comparison against yolort.
scaled_ultralytics_dets = ultralytics_dets.clone()
print("Ultralytics detections:")
print(ultralytics_dets)
# Save Ultralytics inference image.
# Draw on a copy: plot_one_box presumably draws into the array it is given (OpenCV
# drawing functions mutate their input), and img_raw is reused later for the yolort
# image, which would otherwise end up containing the Ultralytics boxes too.
ultralytics_vis = img_raw.copy()
boxes = scale_coords(img.shape[1:], scaled_ultralytics_dets[:, :4], img_raw.shape[:-1])
labels = scaled_ultralytics_dets[:, 5]
for box, label in zip(boxes.tolist(), labels.tolist()):
    class_id = int(label)
    ultralytics_vis = plot_one_box(
        box, ultralytics_vis, color=COLORS[class_id % len(COLORS)], label=LABELS[class_id]
    )
cv2.imwrite(os.path.splitext(img_name)[0] + '-ultralytics-inference.jpg', ultralytics_vis)
# # Loading the trained checkpoint as instructed in:
# # https://github.com/zhiqwang/yolov5-rt-stack/issues/141#issuecomment-924221401
# # A model loaded this way can be scripted with torch.jit.script(model),
# # but its inference results come back empty.
# from yolort.models import yolov5s
# model = yolov5s(upstream_version=rversion, score_thresh=score_thresh)
# model.load_from_yolov5(checkpoint_path=full_model_name, version=rversion)
# model.eval()
# results = model.predict(img_name)
# print("Results from loading model via model.load_from_yolov5:")
# print(results)
# Update model weights from Ultralytics to yolort
# According to [8] in how-to-align-with-ultralytics-yolov5.ipynb
model = YOLO.load_from_yolov5(
    full_model_name,
    score_thresh=score_thresh,
    nms_thresh=iou,
    version=rversion,
)
# Put the yolort model on the same device as the input tensor `img`, as is done for the
# Ultralytics model. This is a no-op on CPU, but without it switching `device` to CUDA
# fails with a device mismatch.
model = model.to(device)
model.eval()
with torch.no_grad():
    yolort_dets = model(img[None])
# yolort returns one dict per image with 'boxes', 'scores' and 'labels' entries.
print(f"Detection boxes with yolort:\n{yolort_dets[0]['boxes']}")
print(f"Detection scores with yolort:\n{yolort_dets[0]['scores']}")
print(f"Detection labels with yolort:\n{yolort_dets[0]['labels']}")
# Verify the detection results between yolort and Ultralytics.
# torch.testing.assert_close replaces the deprecated assert_allclose. check_dtype=False
# keeps assert_allclose's tolerance of differing float dtypes.
# Testing boxes
torch.testing.assert_close(
    yolort_dets[0]['boxes'], ultralytics_dets[:, :4], rtol=1e-05, atol=1e-07, check_dtype=False)
# Testing scores
torch.testing.assert_close(
    yolort_dets[0]['scores'], ultralytics_dets[:, 4], rtol=1e-05, atol=1e-07, check_dtype=False)
# Testing labels: class ids are integers, so require an exact match.
torch.testing.assert_close(
    yolort_dets[0]['labels'], ultralytics_dets[:, 5].to(dtype=torch.int64), rtol=0, atol=0)
print("Exported model has been tested, and the result looks good!")
# Save yolort inference image.
# Draw on a fresh copy of the raw image so this output only shows yolort's boxes.
yolort_vis = img_raw.copy()
# Rescale a clone, mirroring the .clone() used for the Ultralytics detections, so the
# boxes stored in yolort_dets are not rewritten from letterboxed to original-image
# coordinates.
boxes = scale_coords(img.shape[1:], yolort_dets[0]['boxes'].clone(), img_raw.shape[:-1])
labels = yolort_dets[0]['labels']
for box, label in zip(boxes.tolist(), labels.tolist()):
    yolort_vis = plot_one_box(box, yolort_vis, color=COLORS[label % len(COLORS)], label=LABELS[label])
cv2.imwrite(os.path.splitext(img_name)[0] + '-yolort-inference.jpg', yolort_vis)
# TorchScript export of the yolort model, adapted from:
# https://github.com/zhiqwang/yolov5-rt-stack/blob/main/notebooks/inference-pytorch-export-libtorch.ipynb
print(f'Starting TorchScript export with torch {torch.__version__}...')
model_stem = os.path.splitext(full_model_name)[0]
export_script_name = f"{model_stem}-RT-v0.5.2.torchscript.pt"
# NOTE(review): the author reports that torch.jit.script fails at this point with
# yolort 0.5.2 (see issue #265 on zhiqwang/yolov5-rt-stack).
model_script = torch.jit.script(model)
model_script.eval()
# Optionally persist the scripted model for later use (e.g. from LibTorch).
model_script.save(export_script_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.