Blame view

3rdparty/boost_1_81_0/libs/mpi/src/broadcast.cpp 4 KB
73ef4ff3   Hu Chunming   提交三方库
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
  // Copyright 2005 Douglas Gregor.
  
  // Use, modification and distribution is subject to the Boost Software
  // License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
  // http://www.boost.org/LICENSE_1_0.txt)
  
  // Message Passing Interface 1.1 -- Section 4.4. Broadcast
  
  #include <boost/mpi/config.hpp>
  #include <boost/mpi/collectives/broadcast.hpp>
  #include <boost/mpi/skeleton_and_content.hpp>
  #include <boost/mpi/detail/point_to_point.hpp>
  #include <boost/mpi/environment.hpp>
  #include <cassert>
  
  namespace boost { namespace mpi {
  
  template<>
  void
  broadcast<const packed_oarchive>(const communicator& comm,
                                   const packed_oarchive& oa,
                                   int root)
  {
    // A const output archive can only be sent, never received into, so this
    // specialization may only be invoked on the root process.
    assert(comm.rank() == root);

    int size = comm.size();
    // Nothing to send when this process is the only member.
    if (size < 2) return;

    // Tag reserved by the environment for collective traffic (not a
    // "maximum" tag).
    int tag = environment::collectives_tag();

    // Post one non-blocking send per non-root rank, then wait for all of them.
    // The vector is grown with push_back (after reserving size-1 slots)
    // rather than written through an iterator into a pre-sized vector: if the
    // root precondition is violated in a build with assertions disabled, the
    // loop emits `size` requests, and the old iterator write would then run
    // past the end of the buffer.
    std::vector<request> requests;
    requests.reserve(size - 1);
    for (int dest = 0; dest < size; ++dest) {
      if (dest != root) {
        requests.push_back(detail::packed_archive_isend(comm, dest, tag, oa));
      }
    }
    wait_all(requests.begin(), requests.end());
  }
  
  template<>
  void
  broadcast<packed_oarchive>(const communicator& comm, packed_oarchive& oa,
                             int root)
  {
    // Route a mutable output archive to the const packed_oarchive
    // specialization: binding a const reference makes template argument
    // deduction pick T = const packed_oarchive without a const_cast.
    const packed_oarchive& archive = oa;
    broadcast(comm, archive, root);
  }
  
  template<>
  void
  broadcast<packed_iarchive>(const communicator& comm, packed_iarchive& ia,
                             int root)
  {
    const int nprocs = comm.size();
    // A single-process communicator has nobody to exchange data with.
    if (nprocs < 2)
      return;

    // Tag reserved by the environment for collective traffic.
    const int tag = environment::collectives_tag();

    if (comm.rank() == root) {
      // Root: post a non-blocking send of `ia` to every other rank and block
      // until all sends have completed.
      std::vector<request> pending;
      pending.reserve(nprocs - 1);
      for (int peer = 0; peer != nprocs; ++peer) {
        if (peer == root)
          continue;
        pending.push_back(detail::packed_archive_isend(comm, peer, tag, ia));
      }
      wait_all(pending.begin(), pending.end());
    } else {
      // Non-root: receive the root's archive into `ia`.
      MPI_Status status;
      detail::packed_archive_recv(comm, root, tag, ia, status);
    }
  }
  
  template<>
  void
  broadcast<const packed_skeleton_oarchive>(const communicator& comm,
                                            const packed_skeleton_oarchive& oa,
                                            int root)
  {
    // Broadcasts only the archive returned by oa.get_skeleton(), by
    // forwarding to the broadcast() overload that matches its type.
    // NOTE(review): presumably get_skeleton() yields a const packed_oarchive&,
    // which would make this root-only like the const packed_oarchive
    // specialization in this file — confirm against skeleton_and_content.hpp.
    broadcast(comm, oa.get_skeleton(), root);
  }
  
  template<>
  void
  broadcast<packed_skeleton_oarchive>(const communicator& comm,
                                      packed_skeleton_oarchive& oa, int root)
  {
    // Broadcasts only the archive returned by oa.get_skeleton(), by
    // forwarding to the broadcast() overload that matches its type.
    // NOTE(review): the exact return type of get_skeleton() on a non-const
    // packed_skeleton_oarchive is not visible here — confirm which
    // packed_oarchive specialization this dispatches to.
    broadcast(comm, oa.get_skeleton(), root);
  }
  
  template<>
  void
  broadcast<packed_skeleton_iarchive>(const communicator& comm,
                                      packed_skeleton_iarchive& ia, int root)
  {
    // Broadcasts only the archive returned by ia.get_skeleton(), by
    // forwarding to the broadcast() overload that matches its type.
    // NOTE(review): presumably get_skeleton() returns a mutable
    // packed_iarchive& here so non-root ranks can receive into it — confirm
    // against skeleton_and_content.hpp.
    broadcast(comm, ia.get_skeleton(), root);
  }
  
  template<>
  void broadcast<content>(const communicator& comm, content& c, int root)
  {
    // Dispatch to the const content specialization; binding a const
    // reference lets deduction choose T = const content without a const_cast.
    const content& read_only = c;
    broadcast(comm, read_only, root);
  }
  
  template<>
  void broadcast<const content>(const communicator& comm, const content& c,
                                int root)
  {
  #if defined(BOOST_MPI_BCAST_BOTTOM_WORKS_FINE)
    // Single collective call: with MPI_BOTTOM as the buffer, the content's
    // datatype displacements are interpreted as absolute addresses.
    BOOST_MPI_CHECK_RESULT(MPI_Bcast,
                           (MPI_BOTTOM, 1, c.get_mpi_datatype(),
                            root, comm));
  #else
    const int nprocs = comm.size();
    if (nprocs < 2)
      return;

    const int tag = environment::collectives_tag();

    // Some versions of LAM/MPI misbehave when broadcasting with MPI_BOTTOM,
    // so the broadcast is emulated with blocking point-to-point messages.
    if (comm.rank() != root) {
      // Non-root: a single receive from the root.
      BOOST_MPI_CHECK_RESULT(MPI_Recv,
                             (MPI_BOTTOM, 1, c.get_mpi_datatype(),
                              root, tag, comm, MPI_STATUS_IGNORE));
      return;
    }

    // Root: send to every other rank in turn.
    for (int peer = 0; peer != nprocs; ++peer) {
      if (peer == root)
        continue;
      BOOST_MPI_CHECK_RESULT(MPI_Send,
                             (MPI_BOTTOM, 1, c.get_mpi_datatype(),
                              peer, tag, comm));
    }
  #endif
  }
  
  } } // end namespace boost::mpi