import sys

import cv2
import numpy as np
import torch
import torchvision
from PIL import Image, ImageDraw
from PyQt5.QtCore import QThread, pyqtSignal, pyqtSlot
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtWidgets import QApplication, QLabel, QVBoxLayout, QWidget
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.transforms import functional as F
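
# Webcam instance-segmentation demo: a worker QThread captures frames, runs
# Mask R-CNN on each one, and a Qt widget displays the annotated result.
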
def get_model_instance_segmentation(num_classes):
# load an instance segmentation model pre-trained on COCO
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
# now get the number of input features for the mask classifier
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
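    # 256 matches the hidden width torchvision uses in its stock mask head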
# and replace the mask predictor with a new one
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
hidden_layer,
num_classes)
return model
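

# A sketch of how the checkpoint loaded below was presumably produced
# (num_classes=2 means one foreground class plus background):
#
#     model = get_model_instance_segmentation(2)
#     # ... fine-tune on a custom dataset ...
#     torch.save(model.state_dict(), 'segmentation_model.pt')
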
class VideoThread(QThread):
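    # Worker thread: grabs webcam frames, runs the model on each one, and
    # emits the annotated frame as an RGB numpy array via pixmap_signal.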
pixmap_signal = pyqtSignal(np.ndarray)
def __init__(self, model_path: str):
        super().__init__()
        self._is_running = True
        # the model was saved with torch.save(model.state_dict(), 'segmentation_model.pt')
        self.model = get_model_instance_segmentation(2)
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine
        self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
        self.model.eval()
    def run(self):
        capture = cv2.VideoCapture(0)  # default webcam
        while self._is_running:
            ret, img = capture.read()
            if ret:
                img = self.convert_to_pil(img)
                prediction = self.predict(img)
                # emit the annotated frame for the GUI thread to display
                self.pixmap_signal.emit(self.draw_boxes(img, prediction))
        capture.release()
    def convert_to_pil(self, img: np.ndarray) -> Image.Image:
        # OpenCV delivers BGR frames; PIL expects RGB
        return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    def predict(self, img: Image.Image):
        # the model expects a list of 3xHxW float tensors scaled to [0, 1]
        img = F.pil_to_tensor(img)
        img = F.convert_image_dtype(img, dtype=torch.float)
        with torch.no_grad():
            prediction = self.model([img])
        return prediction
    def draw_boxes(self, img: Image.Image, prediction, score_threshold: float = 0.5):
        # the model returns one dict per input image with 'boxes', 'labels',
        # 'scores' and 'masks'; draw only reasonably confident detections
        pred = prediction[0]
        draw = ImageDraw.Draw(img)
        for box, score in zip(pred['boxes'], pred['scores']):
            if score >= score_threshold:
                draw.rectangle(box.tolist(), outline="#FF0000")
        return np.array(img)
    def stop(self):
        # ask run() to exit its loop, then block until the thread finishes
        self._is_running = False
        self.wait()
class VideoWidget(QWidget):
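    # Main window: a single QLabel fed with annotated frames from VideoThread.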
    def __init__(self):
        super().__init__()
        self.display_width = 640
        self.display_height = 480
        self.image_label = QLabel(self)
        layout = QVBoxLayout()
        layout.addWidget(self.image_label)
        self.setLayout(layout)
        # start the capture/inference thread and route its frames to the label
        self.video_thread = VideoThread('./segmentation_model.pt')
        self.video_thread.pixmap_signal.connect(self.update_image)
        self.video_thread.start()
    def closeEvent(self, event):
        # shut the worker down cleanly so the process can exit
        self.video_thread.stop()
        event.accept()
    def np_to_pixmap(self, img: np.ndarray) -> QPixmap:
        # frames emitted by VideoThread are already RGB, so converting with
        # cv2.COLOR_BGR2RGB here would swap the red and blue channels
        h, w, ch = img.shape
        bytes_per_line = ch * w
        # .copy() detaches the QImage from the numpy buffer, which may be
        # garbage-collected after this method returns
        qimage = QImage(img.data, w, h, bytes_per_line, QImage.Format_RGB888).copy()
        return QPixmap.fromImage(
            qimage.scaled(self.display_width, self.display_height)
        )
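    # Slot executed in the GUI thread; Qt delivers the cross-thread signal
    # from VideoThread as a queued connection automatically.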
@pyqtSlot(np.ndarray)
def update_image(self, img: np.ndarray):
pixmap = self.np_to_pixmap(img)
self.image_label.setPixmap(pixmap)
if __name__ == '__main__':
app = QApplication(sys.argv)
window = VideoWidget()
window.show()
sys.exit(app.exec_())