escript Revision_
MUMPS.h
Go to the documentation of this file.
1
2/*****************************************************************************
3*
4* Copyright (c) 2003-2020 by The University of Queensland
5* http://www.uq.edu.au
6*
7* Primary Business: Queensland, Australia
8* Licensed under the Apache License, version 2.0
9* http://www.apache.org/licenses/LICENSE-2.0
10*
11* Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12* Development 2012-2013 by School of Earth Sciences
13* Development from 2014 by Centre for Geoscience Computing (GeoComp)
14*
15*****************************************************************************/
16
17
18/****************************************************************************/
19
20/* Paso: interface to the MUMPS library */
21
22/****************************************************************************/
23
24#ifndef __PASO_MUMPS_H__
25#define __PASO_MUMPS_H__
26
27#include "SparseMatrix.h"
28#include "Options.h"
29#include "PasoException.h"
30
31#ifdef ESYS_HAVE_MUMPS
32// TODO: is this needed? #pragma push_macro("MPI_COMM_WORLD")
33#if defined(MPI_COMM_WORLD)
34#undef MPI_COMM_WORLD // breaks mumps_mpi.h, defined in escriptcore/src/EsysMPI.h
35#endif
36#include <mumps_mpi.h>
37// TODO: is this needed? #pragma pop_macro("MPI_COMM_WORLD")
38// #include <zmumps_c.h>
39#include <dmumps_c.h>
40#include <zmumps_c.h>
41#define MUMPS_JOB_INIT -1
42#define MUMPS_JOB_END -2
43#define MUMPS_USE_COMM_WORLD -987654
44#define ICNTL(I) icntl[(I)-1] // macro s.t. indices match documentation
45
46#ifdef _WIN32
47#define NOMINMAX
48#include <windows.h>
49#undef NOMINMAX
50#endif
51#endif // ESYS_HAVE_MUMPS
52
53namespace paso {
54
56 bool verbose;
57 std::stringstream ssExceptMsg;
58#ifdef ESYS_HAVE_MUMPS
59 MUMPS_INT myid;
60#ifdef _WIN32 // workaround for d/zmumps dll clash
61 HINSTANCE h_mumps_c_dll;
62#endif
63#endif // ESYS_HAVE_MUMPS
64};
65
66template <typename T>
70
71template <typename T>
73
74template <typename T>
75void MUMPS_solve(SparseMatrix_ptr<T> A, T* out, T* in, dim_t numRefinements, bool verbose);
76
77template <typename T>
78void MUMPS_print_list(const char* name, const T* vals, const int n, const int max_n=100);
79
80std::ostream& operator<<(std::ostream& os, const cplx_t& c);
81
82template <>
84 double* rhs;
85#ifdef ESYS_HAVE_MUMPS
86 DMUMPS_STRUC_C id;
87 typedef double mumps_t;
88#ifdef _WIN32 // workaround for d/zmumps dll clash
89 typedef HRESULT (CALLBACK* MUMPS_C_FUNC_PTR)(DMUMPS_STRUC_C*);
90 MUMPS_C_FUNC_PTR mumps_c;
91 const char* mumps_lib = "libdmumps";
92 const char* mumps_proc = "dmumps_c";
93#else
94 void (*mumps_c)(DMUMPS_STRUC_C*) = &dmumps_c;
95#endif
96#endif // ESYS_HAVE_MUMPS
97};
98
99template <>
101 cplx_t* rhs;
102#ifdef ESYS_HAVE_MUMPS
103 ZMUMPS_STRUC_C id;
104 typedef ZMUMPS_COMPLEX mumps_t;
105#ifdef _WIN32 // workaround for dmumps/zdmumps dll clash
106 typedef HRESULT (CALLBACK* MUMPS_C_FUNC_PTR)(ZMUMPS_STRUC_C*);
107 MUMPS_C_FUNC_PTR mumps_c;
108 const char* mumps_lib = "libzmumps";
109 const char* mumps_proc = "zmumps_c";
110#else
111 void (*mumps_c)(ZMUMPS_STRUC_C*) = &zmumps_c;
112#endif
113#endif // ESYS_HAVE_MUMPS
114};
115
117template <typename T>
119{
120 if (A && A->solver_p) {
121#ifdef ESYS_HAVE_MUMPS
122 // Terminate instance.
123 auto pt = static_cast<MUMPS_Handler<T>*>(A->solver_p);
124 delete[] pt->rhs;
125 pt->id.job = MUMPS_JOB_END;
126 pt->mumps_c(&pt->id);
127#ifdef _WIN32
128 FreeLibrary(pt->h_mumps_c_dll);
129#endif
130 if (pt->myid == 0) {
131 std::string message = pt->ssExceptMsg.str();
132 if (!message.empty()) {
133 // terminating with solve error message
134 throw PasoException(message);
135 }
136 }
137 MUMPS_INT ierr = MPI_Finalize();
138 if (pt->verbose) {
139 std::cout << "MUMPS: instance terminated." << std::endl;
140 }
141 delete pt;
142#endif
143 A->solver_p = NULL;
144 }
145}
146
148template <typename T>
149void MUMPS_solve(SparseMatrix_ptr<T> A, T* out, T* in, dim_t numRefinements, bool verbose)
150{
151#ifdef ESYS_HAVE_MUMPS
152 if (! (A->type & (MATRIX_FORMAT_OFFSET1 + MATRIX_FORMAT_BLK1)) ) {
153 throw PasoException("Paso: MUMPS requires CSR format with index offset 1 and block size 1.");
154 }
155
156 auto pt = reinterpret_cast<MUMPS_Handler<T>*>(A->solver_p);
157 if (pt == NULL) {
158 pt = new MUMPS_Handler<T>;
159#ifdef _WIN32
160 pt->h_mumps_c_dll = LoadLibrary(pt->mumps_lib);
161 if (pt->h_mumps_c_dll == NULL) {
162 std::stringstream ss;
163 ss << "Paso: MUMPS LoadLibrary failed - \"" << pt->mumps_lib << "\".";
164 throw PasoException(ss.str());
165 }
166 pt->mumps_c = (MUMPS_Handler<T>::MUMPS_C_FUNC_PTR)GetProcAddress(pt->h_mumps_c_dll, pt->mumps_proc);
167 if (pt->mumps_c == NULL) {
168 std::stringstream ss;
169 ss << "Paso: MUMPS GetProcAddress failed - \"" << pt->mumps_proc << "\".";
170 throw PasoException(ss.str());
171 }
172#endif
173 A->solver_p = (void*) pt;
174 A->solver_package = PASO_MUMPS;
175 double time0 = escript::gettime();
176
177 A->pattern->csrToHB(); // generate Harwell-Boeing format needed for MUMPS from CSR
178 MUMPS_INT n = A->numRows; // matrix order
179 MUMPS_INT8 nnz = A->pattern->len; // number non-zeros
180 MUMPS_INT* irn = reinterpret_cast<MUMPS_INT*>(A->pattern->hb_row); // row indices array
181 MUMPS_INT* jcn = reinterpret_cast<MUMPS_INT*>(A->pattern->hb_col); // col indices array
182 pt->verbose = verbose;
183 if (pt->verbose) {
184 std::cout << "MUMPS in ===>" << std::endl;
185 std::cout << "n = " << n << std::endl;
186 std::cout << "nnz = " << nnz << std::endl;
187 MUMPS_print_list("val", A->val, nnz);
188 MUMPS_print_list("in", in, n);
189 MUMPS_print_list("ptr", A->pattern->ptr, n+1);
190 MUMPS_print_list("index", A->pattern->index, nnz);
191 MUMPS_print_list("hb_row", A->pattern->hb_row, nnz);
192 MUMPS_print_list("hb_col", A->pattern->hb_col, nnz);
193 }
194 pt->rhs = new T[n];
195 std::memcpy(pt->rhs, in, n*sizeof(T));
196 MUMPS_INT ierr;
197 ierr = MPI_Init(NULL, NULL);
198 ierr = MPI_Comm_rank(MPI_COMM_WORLD, &pt->myid);
199
200 // Initialize a MUMPS instance. Use MPI_COMM_WORLD
201 pt->id.comm_fortran = MUMPS_USE_COMM_WORLD;
202 pt->id.par = 1; pt->id.sym = 0;
203 pt->id.job = MUMPS_JOB_INIT;
204 pt->mumps_c(&pt->id);
205 // Define the problem on the host
206 if (pt->myid == 0) {
207 pt->id.n = n; pt->id.nnz = nnz;
208 pt->id.irn = irn; pt->id.jcn = jcn;
209 pt->id.a = reinterpret_cast<typename MUMPS_Handler<T>::mumps_t*>(A->val);
210 pt->id.rhs = reinterpret_cast<typename MUMPS_Handler<T>::mumps_t*>(pt->rhs);
211 }
212 if (!pt->verbose) {
213 // No outputs
214 pt->id.ICNTL(1)=-1; pt->id.ICNTL(2)=-1; pt->id.ICNTL(3)=-1; pt->id.ICNTL(4)=0;
215 }
216
217 // Call the MUMPS package (analyse, factorization and solve).
218 pt->id.job = 6;
219 pt->mumps_c(&pt->id);
220 if (pt->id.infog[0] < 0) {
221 pt->ssExceptMsg << "(PROC " << pt->myid << ") MUMPS ERROR: INFOG(1)=" << pt->id.infog[0]
222 << ", INFOG(2)=" << pt->id.infog[1];
223 } else {
224 std::memcpy(out, reinterpret_cast<T*>(pt->rhs), n*sizeof(T));
225 if (pt->id.infog[0] > 0) {
226 std::cout << "(PROC " << pt->myid << ") MUMPS WARNING: INFOG(1)=" << pt->id.infog[0]
227 << ", INFOG(2)=" << pt->id.infog[1];
228 }
229 if (pt->verbose) {
230 std::cout << "MUMPS out ===>" << std::endl;
231 MUMPS_print_list("out", out, n);
232 std::cout << "MUMPS: factorization and solve completed (time = "
233 << escript::gettime()-time0 << ")." << std::endl;
234 }
235 }
236 }
237#else // ESYS_HAVE_MUMPS
238 throw PasoException("Paso: Not compiled with MUMPS.");
239#endif
240}
241
/// Prints an array to std::cout as `name = [ v0, v1, ... ]` (for debugging
/// the solver).
///
/// @param name   label printed in front of the list
/// @param vals   array to print; only the printed entries are read
/// @param n      number of entries in `vals`
/// @param max_n  maximum number of entries to print (the declaration
///               defaults it to 100); a value <= 0 means no limit. When
///               entries are omitted the list ends with ", ...".
template <typename T>
void MUMPS_print_list(const char* name, const T* vals, const int n, const int max_n)
{
    // Decide up front how many entries are shown so that exactly max_n
    // values are printed and the ellipsis appears only if something is
    // actually left out.
    const bool truncated = (max_n > 0 && n > max_n);
    const int shown = truncated ? max_n : n;

    std::cout << name << " = [ ";
    for (int i = 0; i < shown; i++) {
        if (i > 0) {
            std::cout << ", ";
        }
        std::cout << vals[i];
    }
    if (truncated) {
        std::cout << ", ...";
    }
    std::cout << " ]" << std::endl;
}
262
263} // namespace paso
264
265#endif // __PASO_MUMPS_H__
266
#define MPI_COMM_WORLD
Definition EsysMPI.h:50
#define PASO_MUMPS
Definition Options.h:57
#define MATRIX_FORMAT_BLK1
Definition Paso.h:63
#define MATRIX_FORMAT_OFFSET1
Definition Paso.h:64
PasoException exception class.
Definition PasoException.h:34
double gettime()
returns the current ticks for timing
Definition EsysMPI.h:192
Definition BiCGStab.cpp:25
void MUMPS_print_list(const char *name, const T *vals, const int n, const int max_n=100)
Definition MUMPS.h:245
boost::shared_ptr< SparseMatrix< T > > SparseMatrix_ptr
Definition SparseMatrix.h:37
std::ostream & operator<<(std::ostream &os, const cplx_t &c)
Definition MUMPS.cpp:34
void MUMPS_solve(SparseMatrix_ptr< T > A, T *out, T *in, dim_t numRefinements, bool verbose)
calls the solver
Definition MUMPS.h:149
void MUMPS_free(SparseMatrix< T > *A)
frees any MUMPS related data from the matrix
Definition MUMPS.h:118
Definition blocktools.h:70
cplx_t * rhs
Definition MUMPS.h:101
double * rhs
Definition MUMPS.h:84
Definition MUMPS.h:55
std::stringstream ssExceptMsg
Definition MUMPS.h:57
bool verbose
Definition MUMPS.h:56
Definition MUMPS.h:67
T * rhs
Definition MUMPS.h:68
Definition SparseMatrix.h:45
void * solver_p
pointer to data needed by a solver
Definition SparseMatrix.h:177