欢迎您访问程序员文章站！本站旨在为大家提供、分享程序员计算机编程知识！
您现在的位置是: 首页

pytorch 多进程读写同一个文件

程序员文章站 2024-01-25 20:18:28
...

示例一：多进程用 torch 读写同一个图片文件

# -*- coding:utf-8 -*-
from threading import Thread

import bind_cv as demo
import time

import torch
import torch.multiprocessing as mp

from torch.multiprocessing import Pool, Manager


import cv2


class VideoLoader(Thread):
    """Background thread that plays a video stream and publishes each frame.

    Playback is driven by ``demo.play_url`` (the ``bind_cv`` extension, imported
    as ``demo``), which calls back into :meth:`m_callback` per frame and
    :meth:`end_callback` when playback stops. Each frame is written to
    ``c:/0.dat`` with ``torch.save`` and a notification is put on ``model`` so a
    reader process knows a new frame is available.

    Note: every frame overwrites the same file and ``torch.save`` is not atomic,
    so a concurrent reader can observe a partially written file and must
    tolerate load failures.
    """

    def __init__(self, name, cam_url, model):
        """
        Args:
            name: thread name.
            cam_url: stream URL passed to ``demo.play_url`` (e.g. an RTSP URL).
            model: queue shared with the reader process; receives one
                notification per saved frame.
        """
        super(VideoLoader, self).__init__()
        self.name = name
        self.cam_url = cam_url
        self.index = 0  # not used by this class; kept for external callers
        self.model = model

    def end_callback(self, type):
        """Restart playback 10 seconds after ``demo.play_url`` reports the end.

        Args:
            type: end reason reported by ``bind_cv`` (meaning defined by that
                library — TODO confirm).
        """
        print("end", type)
        time.sleep(10)
        print("play restart")
        try:
            # NOTE(review): this restarts playback from inside the library's own
            # callback; if play_url blocks for the stream's lifetime, each
            # restart nests one level deeper — confirm with bind_cv.
            demo.play_url(self.cam_url, self.m_callback, self.end_callback)
        except Exception as e:
            print("连接失败", e)

    def m_callback(self, a, width, height, t1):
        """Convert one decoded frame to BGR, save it, and notify the reader.

        Args:
            a: frame buffer holding ``height * width * 3`` values, in RGB
                channel order (presumably ``uint8`` — TODO confirm).
            width: frame width in pixels.
            height: frame height in pixels.
            t1: timestamp supplied by the library; ignored.

        Returns:
            0, which is the value handed back to ``bind_cv``.
        """
        frame = a.reshape(height, width, 3)
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        t_start = time.time()
        x = torch.from_numpy(frame)
        torch.save(x, "c:/" + str(0) + '.dat')
        # Blocks if the queue is bounded and the reader has fallen behind.
        self.model.put(1)
        print('存入时间', time.time() - t_start)
        return 0

    def run(self):
        """Thread entry point: start playback of ``cam_url``."""
        demo.play_url(self.cam_url, self.m_callback, self.end_callback)

def write(model):
    """Display frames that the capture thread saves to ``c:/0.dat``.

    Despite its name, this is the *reader* side of the demo. For every
    notification taken from ``model`` it loads the tensor from ``c:/0.dat``
    and shows it in an OpenCV window named ``"img"``. It never returns.

    Args:
        model: queue shared with the producer; each item signals that a new
            frame was written (the item's value is ignored).
    """
    err = 0  # total failed loads since start-up (previously reset every frame)
    while True:
        model.get()  # block until the producer announces a new frame
        start = time.time()
        try:
            # The producer may be overwriting the file right now (torch.save is
            # not atomic), so loading can fail on a truncated file; skip it.
            data = torch.load("c:/" + str(0) + '.dat')
        except Exception:
            err += 1
            print('err time', time.time() - start, err)
            continue
        if time.time() - start > 0.002:
            print('get time', time.time() - start)
        img = data.numpy()
        cv2.imshow("img", img)
        cv2.waitKeyEx(1)




if __name__ == '__main__':
    manager = Manager()
    num_processes = 1
    # Notification queue: the capture thread puts one item per saved frame.
    model = manager.Queue(10)
    filename = 'rtsp://192.168.1.12:554/h264/ch1/main/av_stream'
    video_loader = VideoLoader("1", filename, model)
    video_loader.start()

    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=write, args=(model,))
        p.start()
        processes.append(p)
    # Join only after every reader has started; joining inside the start loop
    # would run the readers one after another instead of concurrently.
    for p in processes:
        p.join()

 

 

示例二：多进程用 torch 读写同一个 tensor 文件

import os
import time
import torch


import torch.multiprocessing as mp

from torch.multiprocessing import Pool, Manager

def train(model, path='d:/lib/0.dat', iterations=300, retry_delay=0.002):
    """Repeatedly load the tensor file that a writer process keeps overwriting.

    Each of the ``iterations`` reads retries until ``torch.load`` succeeds. A
    concurrent ``torch.save`` can leave the file missing or partially written,
    so individual loads may fail.

    Args:
        model: shared queue from the launcher; not used here, but kept so that
            ``mp.Process(target=train, args=(model,))`` still works.
        path: file to read (defaults to the demo's ``d:/lib/0.dat``).
        iterations: number of successful loads to perform.
        retry_delay: seconds to sleep after a failed load before retrying.

    Returns:
        Total number of failed load attempts across all iterations.
    """
    total_err = 0
    for _ in range(iterations):
        start = time.time()
        err = 0  # failures for this read only
        while True:
            try:
                torch.load(path)
            except Exception:
                # Any failure (missing/truncated file, unpickling error) is
                # treated as "writer busy": report and retry.
                err += 1
                print('err time', time.time() - start, err)
                time.sleep(retry_delay)
            else:
                break
        total_err += err
    return total_err

def write(model, path='d:/lib/0.dat', iterations=300, shape=(1, 3, 1280, 720)):
    """Repeatedly overwrite ``path`` with a fresh uniform-random tensor.

    Prints the duration of any save that takes longer than 10 ms. Each
    ``torch.save`` rewrites the same file in place, so a concurrent reader can
    see it missing or half-written.

    Args:
        model: shared queue from the launcher; not used here, but kept so that
            ``mp.Process(target=write, args=(model,))`` still works.
        path: file to overwrite (defaults to the demo's ``d:/lib/0.dat``).
        iterations: number of tensors to save.
        shape: shape of each random tensor (defaults to the demo's size).
    """
    for _ in range(iterations):
        x = torch.rand(*shape)
        start = time.time()
        torch.save(x, path)
        # Measure once so the printed value is the one that was compared.
        elapsed = time.time() - start
        if elapsed > 0.01:
            print('save time', elapsed)

if __name__ == '__main__':
    manager = Manager()
    num_processes = 2
    # Both workers receive this queue, but neither reads from it.
    model = manager.Queue(2)

    processes = []
    for rank in range(num_processes):
        # Rank 1 reads the shared file; every other rank writes it.
        target = train if rank == 1 else write
        p = mp.Process(target=target, args=(model,))
        p.start()
        processes.append(p)

    # Wait for all workers only after every one of them is running.
    for p in processes:
        p.join()