
YOLACT: Converting a PyTorch Model to TensorFlow SavedModel Format


YOLACT: PyTorch model -> ONNX -> TensorFlow SavedModel.

The PyTorch source needs a few modifications before the model can be exported to ONNX, and the decode and NMS post-processing has to be re-implemented in TensorFlow.
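For the export step itself, a plain torch.onnx.export call is enough once the post-processing has been stripped from the forward pass. The snippet below is only a sketch: the Yolact import path, weight file name, and export settings are assumptions, not taken verbatim from this project.

import torch
from yolact import Yolact  # assumed import path for the YOLACT repo

net = Yolact()
net.load_state_dict(torch.load("yolact.pth", map_location="cpu"))
net.eval()

# YOLACT expects 550x550 inputs in NCHW layout.
dummy = torch.randn(1, 3, 550, 550)
torch.onnx.export(net, dummy, "YOLACT.onnx",
                  opset_version=11,
                  do_constant_folding=True)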

The code below implements decode and NMS and converts the model to the SavedModel format:

import time
import cv2
import tensorflow as tf
import numpy as np

# ImageNet mean/std in BGR order, shaped [1, C, 1, 1] for NCHW inputs
MEANS = np.array([103.94, 116.78, 123.68])[None, :, None, None]
STD = np.array([57.38, 57.12, 58.40])[None, :, None, None]


def crop(pred, boxes):
    # pred: [n, h, w] mask logits; boxes: [n, 4] normalized (ymin, xmin, ymax, xmax).
    # Zeros out every pixel that falls outside its corresponding box.
    pred_shape = tf.shape(pred)
    w = tf.cast(tf.range(pred_shape[2]), tf.float32)
    h = tf.expand_dims(tf.cast(tf.range(pred_shape[1]), tf.float32), axis=-1)

    # Normalized x coordinate of each pixel (varies along the width axis)
    rows = tf.broadcast_to(w, pred_shape) / tf.cast(pred_shape[2], tf.float32)
    # Normalized y coordinate of each pixel (varies along the height axis)
    cols = tf.broadcast_to(h, pred_shape) / tf.cast(pred_shape[1], tf.float32)

    ymin = tf.broadcast_to(tf.reshape(boxes[:, 0], [-1, 1, 1]), pred_shape)
    xmin = tf.broadcast_to(tf.reshape(boxes[:, 1], [-1, 1, 1]), pred_shape)
    ymax = tf.broadcast_to(tf.reshape(boxes[:, 2], [-1, 1, 1]), pred_shape)
    xmax = tf.broadcast_to(tf.reshape(boxes[:, 3], [-1, 1, 1]), pred_shape)

    mask_left = (rows >= xmin)
    mask_right = (rows <= xmax)
    mask_bottom = (cols >= ymin)
    mask_top = (cols <= ymax)

    crop_mask = tf.math.logical_and(tf.math.logical_and(mask_left, mask_right),
                                    tf.math.logical_and(mask_bottom, mask_top))
    crop_mask = tf.cast(crop_mask, tf.float32)

    return pred * crop_mask


# conf_preds [1, 2+1, 19248]  mask_data [1, 19248, 32]  decoded_boxes [19248, 4]  proto_data [138, 138, 32]
def detect(batch_idx, conf_preds, mask_data, decoded_boxes, proto_data, conf_thresh=0.15, nms_thresh=0.5, top_k=100):
    cur_scores = conf_preds[batch_idx, 1:, :]  # drop the background class
    conf_scores = tf.math.reduce_max(cur_scores, axis=0)
    conf_score_id = tf.argmax(cur_scores, axis=0)
    # reshape instead of squeeze so a single detection keeps rank 1
    keep = tf.reshape(tf.where(conf_scores > conf_thresh), [-1])

    if tf.size(keep) == 0:  # only meaningful in eager mode; never fires during graph construction
        return None
    scores = tf.gather(conf_scores, keep)
    boxes = tf.gather(decoded_boxes, keep)  # boxes above the confidence threshold
    masks = tf.gather(mask_data[batch_idx], keep)  # mask coefficients above the threshold
    classes = tf.gather(conf_score_id, keep)
    selected_indices = tf.image.non_max_suppression(boxes, scores, top_k, nms_thresh)
    boxes = tf.gather(boxes, selected_indices)
    scores = tf.gather(scores, selected_indices)
    masks = tf.gather(masks, selected_indices)
    classes = tf.gather(classes, selected_indices)

    # Assemble instance masks from prototypes and per-instance coefficients:
    # [138, 138, 32] x [32, k] -> [138, 138, k]
    masks = tf.linalg.matmul(proto_data, masks, transpose_a=False, transpose_b=True)
    masks = tf.nn.sigmoid(masks)
    masks = tf.transpose(masks, perm=(2, 0, 1))
    masks = crop(masks, boxes)

    masks = tf.image.resize(tf.expand_dims(masks, axis=-1), [550, 550], method="bilinear")
    masks = tf.cast(masks + 0.5, tf.int32)  # binarize at 0.5
    # masks = tf.squeeze(tf.cast(masks, tf.float32))
    return boxes, masks, scores, classes
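
# detect() returns, for the k detections that survive NMS:
#   boxes   [k, 4]            normalized (y1, x1, y2, x2)
#   masks   [k, 550, 550, 1]  int32 binary masks
#   scores  [k]               per-detection confidence
#   classes [k]               class indices (background excluded)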


def decode(loc, priors):
    # priors are (cx, cy, w, h); loc holds the predicted offsets.
    variances = [0.1, 0.2]
    cxy = priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:]  # shift the prior to get the box center
    wh = priors[:, 2:] * tf.exp(loc[:, 2:] * variances[1])  # recover width/height
    y1 = cxy[:, 1] - wh[:, 1] / 2
    x1 = cxy[:, 0] - wh[:, 0] / 2
    y2 = cxy[:, 1] + wh[:, 1] / 2
    x2 = cxy[:, 0] + wh[:, 0] / 2
    boxes = tf.stack((y1, x1, y2, x2), 1)  # order expected by tf.image.non_max_suppression

    return boxes
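
# Quick eager-mode sanity check for decode (hypothetical values; uncomment to run):
# priors = tf.constant([[0.5, 0.5, 0.2, 0.2]], tf.float32)  # (cx, cy, w, h)
# loc = tf.zeros([1, 4])                                    # zero offsets
# print(decode(loc, priors))  # -> [[0.4, 0.4, 0.6, 0.6]], i.e. (y1, x1, y2, x2)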


def create():
    with tf.compat.v1.Session() as sess:
        # Load the frozen graph produced by onnx-tf.
        output_graph_def = tf.compat.v1.GraphDef()
        with open("./output.pb", "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(output_graph_def, name="")
        # The tensor names below are specific to this exported graph.
        input_ = sess.graph.get_tensor_by_name("input.1:0")
        proto_data = sess.graph.get_tensor_by_name("1090:0")
        priors = sess.graph.get_tensor_by_name("1289:0")
        boxes_ = sess.graph.get_tensor_by_name("1286:0")
        conf_preds = sess.graph.get_tensor_by_name("1290:0")
        mask_data = sess.graph.get_tensor_by_name("1288:0")
        # Append the decode + NMS post-processing to the imported graph.
        decoded_boxes = decode(boxes_[0], priors)
        conf_preds = tf.transpose(conf_preds, [0, 2, 1])
        boxes, masks, scores, classes = detect(0, conf_preds, mask_data, decoded_boxes, proto_data[0])
        tf.compat.v1.saved_model.simple_save(sess, "../output/savedmodel/", inputs={"input": input_},
                                             outputs={"output0": boxes, "output1": masks, "output2": scores,
                                                      "output3": classes})
        print("create savedmodel files success!")


def run_savedmodel():
    with tf.compat.v1.Session() as sess:
        meta_graph_def = tf.compat.v1.saved_model.loader.load(sess,
                                                              [tf.compat.v1.saved_model.tag_constants.SERVING],
                                                              "../output/savedmodel")
        signature = meta_graph_def.signature_def

        # Look up the tensor names recorded in the serving signature.
        in_tensor_name = signature['serving_default'].inputs['input'].name
        boxes = signature['serving_default'].outputs['output0'].name
        masks = signature['serving_default'].outputs['output1'].name
        scores = signature['serving_default'].outputs['output2'].name
        classes = signature['serving_default'].outputs['output3'].name

        input_ = sess.graph.get_tensor_by_name(in_tensor_name)
        boxes = sess.graph.get_tensor_by_name(boxes)
        masks = sess.graph.get_tensor_by_name(masks)
        scores = sess.graph.get_tensor_by_name(scores)
        classes = sess.graph.get_tensor_by_name(classes)

        # Preprocess: resize to 550x550, HWC -> NCHW, normalize, then BGR -> RGB.
        # MEANS/STD are in BGR order, so normalization happens before the channel swap.
        img_roi = cv2.imread("../output/samples/bus3.jpg")
        img_roi = cv2.resize(img_roi, (550, 550))
        img = img_roi.astype(np.float32)
        img = np.transpose(img, [2, 0, 1])
        img = np.expand_dims(img, 0)
        img = (img - MEANS) / STD
        img = img[:, (2, 1, 0), :, :]

        # Run twice: the first pass includes one-time warm-up cost.
        for _ in range(2):
            start_ = time.time()
            b, m, s, c = sess.run([boxes, masks, scores, classes], feed_dict={input_: img})
            print("run time:", time.time() - start_)
            for i in range(m.shape[0]):
                cv2.imshow(f"mask{i}", m[i].astype(np.uint8) * 255)
            # Boxes are normalized (y1, x1, y2, x2); scale back to the 550x550 image.
            for one in b:
                x1 = int(one[1] * 550)
                y1 = int(one[0] * 550)
                x2 = int(one[3] * 550)
                y2 = int(one[2] * 550)
                img_roi = cv2.rectangle(img_roi, (x1, y1), (x2, y2), [255, 0, 0], 2)
            cv2.imshow("box", img_roi)
            cv2.waitKey(0)



# 1. Convert the ONNX model to a frozen TF graph:
#    onnx-tf convert -i ./YOLACT.onnx -o ./output.pb
# 2. Build the post-processing graph and export it in SavedModel format:
# create()
# 3. Test the exported SavedModel:
run_savedmodel()
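
The exported SavedModel can also be consumed with the TF2 high-level API instead of a v1 Session. A minimal sketch, assuming the export above succeeded (the zero image simply stands in for an input preprocessed as in run_savedmodel):

import numpy as np
import tensorflow as tf

loaded = tf.saved_model.load("../output/savedmodel")
infer = loaded.signatures["serving_default"]

# Dummy NCHW input standing in for a preprocessed 550x550 image.
img = np.zeros((1, 3, 550, 550), np.float32)
outputs = infer(input=tf.constant(img))  # keyword matches the "input" key used at export time
print({name: t.shape for name, t in outputs.items()})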