欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

MPI使用-python

程序员文章站 2022-07-12 21:11:46
...

MPI使用

简介

可能会遇到的问题

preprocessing

import numpy as np
import os

import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib import rc
import matplotlib
matplotlib.rcParams.update({'font.size': 14})
  • 一个简单的测试
  • 注意:在jupyter中是一个进程,如果要实现多个进程的效果,可以将代码保存为py文件,然后在命令行中运行,windows中是使用mpiexec -np 4 python your_file.py,4是设置进程的数目
from mpi4py import MPI
print( "my rank is %d" % MPI.COMM_WORLD.Get_rank() )
my rank is 0

多个进程之间通信

点对点的通信

  • 实现2个进程之间的数据传输
  • 注意:下面的程序必须至少使用2个进程运行,否则会出现异常(需要进程之间通信)
# 需要将这段代码保存成文件才能实现多进程程序的运行
import mpi4py.MPI as MPI
comm = MPI.COMM_WORLD
comm_rank = comm.Get_rank()
comm_size = comm.Get_size()

data_send = [comm_rank] * 4
comm.send( data_send, dest=(comm_rank+1)%comm_size )
# 如果comm_rank-1<0,会自动加comm_size变为正数
data_recv = comm.recv( source=(comm_rank-1)%comm_size )
print( "my rank is %d, I received :" % comm_rank )
print( data_recv )

群体通信

  • 主要有2种:广播(broadcast)和散播(scatter)

广播通信

  • 将一个数据发送给所有的进程,每个进程都会得到这所有的数据
import mpi4py.MPI as MPI

comm = MPI.COMM_WORLD
comm_rank = comm.Get_rank()
comm_size = comm.Get_size()

if comm_rank == 0:
    data = [i for i in range(comm_size)]

data = comm.bcast( data if comm_rank == 0 else None, root=0 )
print( "rank %d, got : " % comm_rank )
print( data )
  • 在上面的代码中,root进程建立了一个列表,然后广播给所有的进程,因此所有的进程都会拥有这个列表(数据),从而实现了数据共享。
  • 上面有一个问题,即这种方法进行数据广播的时间复杂度为O(N),如果要实现O(loh(N))的方法,需要使用规约树广播。具体解释的链接:https://blog.csdn.net/zouxy09/article/details/49031845

散播

  • 将一份数据平分给所有的进程,比如说有10个数据,给2个进程,则每个进程可以得到5个数据。
  • 注意:root也会得到自己散播出去的数据并进行处理
  • mpi4py可以无缝使用numpy,很方便
import mpi4py.MPI as MPI
import numpy as np

comm = MPI.COMM_WORLD
comm_rank = comm.Get_rank()
comm_size = comm.Get_size()

if comm_rank == 0:
    # 一定要确保data的长度是np的数量
    data = np.random.rand(comm_size,3)
    # data = [i for i in range(comm_size)]
    # data = [[1], [2], [3], [4]]
    print( "all data by rank %d : " % comm_rank )
    print( data )
else:
    data = None

local_data = comm.scatter( data , root=0 )
print( "rank %d, got : " % comm_rank )
print( local_data )
all data by rank 0 : 
[[ 0.83198306  0.67017775  0.04115034]]
rank 0, got : 
[ 0.83198306  0.67017775  0.04115034]

收集数据

  • 在求一个数组的最大值的时候,可以让一个进程进行处理,也可以让多个进程同时处理,给每个进程一组特定的数据,求完之后再最最大值回收,可以减小root的数据处理压力,同时减小时间复杂度。
import mpi4py.MPI as MPI
import numpy as np

comm = MPI.COMM_WORLD
comm_rank = comm.Get_rank()
comm_size = comm.Get_size()

if comm_rank == 0:
    data = np.random.rand(comm_size, 2)
    print( data )
else:
    data = None

local_data = comm.scatter( data, root=0 )
local_data = -local_data # 对数据进行处理
print( "rank %d got data and finished dealing." % comm_rank )
print( local_data )

# 由root=0进行数据的收集
# 因为需要进行收集工作,所以是最后执行完的。
combine_data = comm.gather( local_data, root = 0 )
if comm_rank == 0:
    print( combine_data )

规约

  • 在求一个数组的最大值的时候,可以让一个进程进行处理,也可以让多个进程同时处理,给每个进程一组特定的数据,求完之后再最最大值回收,可以减小root的数据处理压力,同时减小时间复杂度。
import mpi4py.MPI as MPI
import numpy as np

comm = MPI.COMM_WORLD
comm_rank = comm.Get_rank()
comm_size = comm.Get_size()

if comm_rank == 0:
    data = np.random.rand(comm_size, 2)
    print( data )
else:
    data = None

local_data = comm.scatter( data, root=0 )
local_data = -local_data # 对数据进行处理
print( "rank %d got data and finished dealing." % comm_rank )
print( local_data )

all_sum = comm.reduce( local_data, root=0, op=MPI.SUM )
if comm_rank == 0:
    print( "sum is : %f", all_sum )
[[ 0.99687077  0.9394709 ]]
rank 0 got data and finished dealing.
[-0.99687077 -0.9394709 ]
sum is : %f [-0.99687077 -0.9394709 ]

实现自定义的处理函数

  • MPI定义的op有SUM,MIN,MAX等函数,但是如果我们希望能够自定义处理函数,则可以自己实现。
  • op的输入参数是2个类型相同的变量,返回一个参数,一个简单的定义如下:my_func
  • code

    import mpi4py.MPI as MPI
    import numpy as np
    
    # 自定义op
    def my_func( a, b ):
        f = a*a + b*b
        return f
    
    comm = MPI.COMM_WORLD
    comm_rank = comm.Get_rank()
    comm_size = comm.Get_size()
    
    if comm_rank == 0:
        data = np.random.rand(comm_size, 1)
        print( data )
    else:
        data = None
    
    local_data = comm.scatter( data, root=0 )
    local_data = -local_data # 对数据进行处理
    print( "rank %d got data and finished dealing." % comm_rank )
    print( local_data )
    
    all_sum = comm.reduce( local_data, root=0, op=my_func )
    if comm_rank == 0:
        print( "sum is : %f", all_sum )