tensorlayer3/predict.py

105 lines
4.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

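#----------------------------------------------------#
#   Modified from the predict.py shown in the video:
#   single-image prediction, camera detection and FPS
#   testing are merged into this one .py file; switch
#   between them by setting `mode`.
#----------------------------------------------------#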
import time
import cv2
import numpy as np
import tensorflow as tf
from PIL import Image
from frcnn import FRCNN
# gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
# for gpu in gpus:
# tf.config.experimental.set_memory_growth(gpu, True)
if __name__ == "__main__":
    # Build the detector once; every mode below reuses this instance.
    frcnn = FRCNN()
    #-------------------------------------------------------------------------#
    #   mode selects what the script does:
    #   'predict'  detect objects in single images entered interactively
    #   'video'    run detection on a video file or a camera stream
    #   'fps'      measure detection speed
    #-------------------------------------------------------------------------#
    mode = "predict"
    #-------------------------------------------------------------------------#
    #   video_path       path of the video to process; 0 selects the camera
    #   video_save_path  where to save the annotated video; "" disables saving
    #   video_fps        frame rate of the saved video
    #   These three settings are only used when mode == 'video'.
    #   When saving, exit with Esc (in the preview window) or Ctrl+C so that
    #   the writer is released and the file is finalized; do not kill the
    #   process directly.
    #-------------------------------------------------------------------------#
    video_path = 0
    video_save_path = ""
    video_fps = 25.0

    if mode == "predict":
        # Notes from the original author:
        # 1. This loop does not do batch prediction. For batch prediction,
        #    walk a folder with os.listdir() and open each file with
        #    Image.open; see get_dr_txt.py, which implements both the
        #    traversal and the saving of the detected targets.
        # 2. To save the annotated image, call r_image.save("img.jpg") here.
        # 3. To obtain the box coordinates, read top, left, bottom and right
        #    in the drawing part of frcnn.detect_image.
        # 4. To crop the detected targets, use those same four values in the
        #    drawing part of frcnn.detect_image to slice the original image
        #    as a matrix.
        # 5. To write extra text on the result (e.g. how many objects of a
        #    given class were found), test predicted_class in the drawing
        #    part of frcnn.detect_image, e.g. `if predicted_class == 'car':`,
        #    count the matches and write the text with draw.text.
        while True:
            img = input('Input image filename:')
            try:
                image = Image.open(img)
            except Exception:
                # Best effort: any failure to open/decode just re-prompts.
                # `Exception` (not a bare except) lets Ctrl+C and SystemExit
                # terminate the script instead of being swallowed.
                print('Open Error! Try again!')
                continue
            r_image = frcnn.detect_image(image)
            r_image.show()

    elif mode == "video":
        capture = cv2.VideoCapture(video_path)
        if not capture.isOpened():
            # Fail with a clear message instead of crashing later on a
            # None frame inside cv2.cvtColor.
            raise ValueError("Could not open the camera/video: %r" % (video_path,))

        # `out` stays None when saving is disabled, so the cleanup below can
        # tell whether there is a writer to release.
        out = None
        fps = 0.0
        try:
            if video_save_path != "":
                fourcc = cv2.VideoWriter_fourcc(*'XVID')
                size = (
                    int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)),
                )
                out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)

            while True:
                t1 = time.time()
                # Grab the next frame; `ref` is False when no frame could be
                # read (end of the video or a camera failure).
                ref, frame = capture.read()
                if not ref:
                    break
                # OpenCV delivers BGR; the detector is fed an RGB PIL image.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = Image.fromarray(np.uint8(frame))
                # Run detection and convert the annotated image to an array.
                frame = np.array(frcnn.detect_image(frame))
                # Back to BGR, the channel order OpenCV expects for display.
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

                # Smooth the reading: average the previous estimate with the
                # instantaneous rate of this frame.
                fps = (fps + (1. / (time.time() - t1))) / 2
                print("fps= %.2f" % (fps))
                frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40),
                                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

                cv2.imshow("video", frame)
                c = cv2.waitKey(1) & 0xff
                if out is not None:
                    out.write(frame)

                # 27 is the Esc key.
                if c == 27:
                    break
        finally:
            # Runs on Esc, end of stream, Ctrl+C or any error, so the capture
            # and (if present) the writer are always released.
            capture.release()
            if out is not None:
                out.release()
            cv2.destroyAllWindows()

    elif mode == "fps":
        # Number of runs passed to frcnn.get_FPS for the timing measurement.
        test_interval = 100
        img = Image.open('img/street.jpg')
        tact_time = frcnn.get_FPS(img, test_interval)
        print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')

    else:
        raise AssertionError("Please specify the correct mode: 'predict', 'video' or 'fps'.")