"""Live webcam viewer that overlays Mask R-CNN detection boxes on each frame."""

import sys

import cv2
import numpy as np
import torch
import torchvision
from PIL import Image, ImageDraw
from PyQt5.QtCore import QThread, pyqtSignal, pyqtSlot
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtWidgets import QApplication, QLabel, QVBoxLayout, QWidget
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.transforms import functional as F


def get_model_instance_segmentation(num_classes):
    """Build a Mask R-CNN (ResNet-50 FPN) whose heads predict ``num_classes``.

    Args:
        num_classes: Number of output classes for both the box and the mask
            predictor.

    Returns:
        The torchvision Mask R-CNN model with both predictor heads replaced.
    """
    # Load an instance segmentation model pre-trained on COCO.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # Replace the pre-trained box head with one sized for num_classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Do the same for the mask head.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes,
    )
    return model


class VideoThread(QThread):
    """Worker thread that reads webcam frames and runs detection on them.

    Each processed frame is emitted through ``pixmap_signal`` as an RGB
    ``numpy`` array with the predicted boxes drawn on it.
    """

    pixmap_signal = pyqtSignal(np.ndarray)

    def __init__(self, model_path: str):
        """Create the thread and load the model weights from ``model_path``.

        Args:
            model_path: Path to a state dict file produced by ``torch.save``.
        """
        super().__init__()
        self._is_running = True
        # The model was saved with
        # torch.save(model.state_dict(), 'segmentation_model.pt').
        self.model = get_model_instance_segmentation(2)
        # The model is never moved off the CPU, so map the checkpoint to the
        # CPU as well; otherwise a checkpoint saved from a GPU cannot be
        # loaded on a machine without CUDA.
        # NOTE(review): torch.load unpickles the file - only load checkpoints
        # from a trusted source.
        state_dict = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def run(self):
        """Capture frames from camera 0 until ``stop`` is called."""
        capture = cv2.VideoCapture(0)
        try:
            if not capture.isOpened():
                # No camera available: nothing to read, so do not spin.
                return
            while self._is_running:
                ret, frame = capture.read()
                if not ret:
                    # Back off briefly instead of busy-looping on failed reads.
                    self.msleep(10)
                    continue
                img = self.convert_to_pil(frame)
                prediction = self.predict(img)
                self.pixmap_signal.emit(self.draw_boxes(img, prediction))
        finally:
            # Release the device even if inference raised.
            capture.release()

    def convert_to_pil(self, img: np.ndarray) -> Image.Image:
        """Convert a BGR OpenCV frame into an RGB PIL image."""
        return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    def predict(self, img: Image.Image):
        """Run the model on one PIL image and return its raw output.

        The output is indexed per input image; ``draw_boxes`` reads the
        ``'boxes'`` entry of the first element.
        """
        tensor = F.pil_to_tensor(img)
        tensor = F.convert_image_dtype(tensor, dtype=torch.float)
        with torch.no_grad():
            prediction = self.model([tensor])
        return prediction

    def draw_boxes(self, img: Image.Image, prediction) -> np.ndarray:
        """Draw every predicted box on ``img`` (in place) in red.

        Args:
            img: RGB PIL image the prediction was computed for.
            prediction: Model output as returned by ``predict``.

        Returns:
            The annotated image as an RGB ``numpy`` array.
        """
        draw = ImageDraw.Draw(img)
        # Plain Python floats are accepted by every Pillow version.
        for box in prediction[0]['boxes'].cpu().tolist():
            draw.rectangle(box, outline="#FF0000")
        return np.array(img)

    def stop(self):
        """Ask the capture loop to finish and block until the thread exits."""
        self._is_running = False
        self.wait()


class VideoWidget(QWidget):
    """Window that shows the annotated frames produced by ``VideoThread``."""

    def __init__(self):
        super().__init__()
        self.display_width = 640
        self.display_height = 480

        self.image_label = QLabel(self)
        layout = QVBoxLayout()
        layout.addWidget(self.image_label)
        self.setLayout(layout)

        self.video_thread = VideoThread('./segmentation_model.pt')
        self.video_thread.pixmap_signal.connect(self.update_image)
        self.video_thread.start()

    def closeEvent(self, event):
        """Stop the worker thread before the window closes."""
        self.video_thread.stop()
        event.accept()

    def np_to_pixmap(self, img: np.ndarray) -> QPixmap:
        """Convert an RGB ``numpy`` frame into a pixmap of the display size.

        The frames emitted by ``VideoThread`` are already RGB (they are
        converted once in ``convert_to_pil``), so no further channel swap is
        done here; swapping again would exchange red and blue on screen.
        """
        # QImage needs one contiguous buffer with a known row stride.
        rgb_img = np.ascontiguousarray(img)
        h, w, ch = rgb_img.shape
        bytes_per_line = ch * w
        # QImage only wraps the buffer; scaled() returns an independent copy,
        # so the result does not depend on rgb_img staying alive.
        qimage = QImage(rgb_img.data, w, h, bytes_per_line, QImage.Format_RGB888)
        return QPixmap.fromImage(
            qimage.scaled(self.display_width, self.display_height)
        )

    @pyqtSlot(np.ndarray)
    def update_image(self, img: np.ndarray):
        """Show a new frame received from the worker thread."""
        pixmap = self.np_to_pixmap(img)
        self.image_label.setPixmap(pixmap)


def main():
    """Start the Qt application and show the video window."""
    app = QApplication(sys.argv)
    window = VideoWidget()
    window.show()
    sys.exit(app.exec_())


if __name__ == '__main__':
    main()
# Comment unavailable.