Eigen 3.4.0
MatrixProduct.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com)
// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_ALTIVEC_H

#ifndef EIGEN_ALTIVEC_USE_CUSTOM_PACK
#define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
#endif

#include "MatrixProductCommon.h"

// Since LLVM doesn't support dynamic dispatching, force either always MMA or VSX.
#if EIGEN_COMP_LLVM
#if !defined(EIGEN_ALTIVEC_DISABLE_MMA) && !defined(EIGEN_ALTIVEC_MMA_ONLY)
#ifdef __MMA__
#define EIGEN_ALTIVEC_MMA_ONLY
#else
#define EIGEN_ALTIVEC_DISABLE_MMA
#endif
#endif
#endif

#ifdef __has_builtin
#if __has_builtin(__builtin_mma_assemble_acc)
  #define ALTIVEC_MMA_SUPPORT
#endif
#endif

#if defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
  #include "MatrixProductMMA.h"
#endif

/**************************************************************************************************
 * TODO                                                                                           *
 *  - Check StorageOrder on dhs_pack (the innermost second loop seems unvectorized when it could  *
 *    be vectorized).                                                                             *
 *  - Check the possibility of transposing as GETREAL and GETIMAG when needed.                    *
 **************************************************************************************************/
namespace Eigen {

namespace internal {

/**************************
 * Constants and typedefs *
 **************************/
template<typename Scalar>
struct quad_traits
{
  typedef typename packet_traits<Scalar>::type vectortype;
  typedef PacketBlock<vectortype,4>            type;
  typedef vectortype                           rhstype;
  enum
  {
    vectorsize = packet_traits<Scalar>::size,
    size = 4,
    rows = 4
  };
};

template<>
struct quad_traits<double>
{
  typedef Packet2d                  vectortype;
  typedef PacketBlock<vectortype,4> type;
  typedef PacketBlock<Packet2d,2>   rhstype;
  enum
  {
    vectorsize = packet_traits<double>::size,
    size = 2,
    rows = 4
  };
};

// MatrixProduct decomposes real/imaginary vectors into a real vector and an imaginary vector; this turned
// out to be faster than Eigen's usual approach of keeping real/imaginary pairs on a single vector. The
// constants below are responsible for extracting and converting between Eigen's layout and MatrixProduct's.
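// For example, given two vectors of single-precision complex values a = [a0.r, a0.i, a1.r, a1.i]
// and b = [b0.r, b0.i, b1.r, b1.i], vec_perm(a, b, p16uc_GETREAL32) produces [a0.r, a1.r, b0.r, b1.r]
// and vec_perm(a, b, p16uc_GETIMAG32) produces [a0.i, a1.i, b0.i, b1.i]. The 64-bit masks below do the
// same with one double-precision complex value per input vector.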
const static Packet16uc p16uc_GETREAL32 = {  0,  1,  2,  3,
                                             8,  9, 10, 11,
                                            16, 17, 18, 19,
                                            24, 25, 26, 27};

const static Packet16uc p16uc_GETIMAG32 = {  4,  5,  6,  7,
                                            12, 13, 14, 15,
                                            20, 21, 22, 23,
                                            28, 29, 30, 31};

const static Packet16uc p16uc_GETREAL64 = {  0,  1,  2,  3,  4,  5,  6,  7,
                                            16, 17, 18, 19, 20, 21, 22, 23};

//[a,ai],[b,bi] = [ai,bi]
const static Packet16uc p16uc_GETIMAG64 = {  8,  9, 10, 11, 12, 13, 14, 15,
                                            24, 25, 26, 27, 28, 29, 30, 31};

/*********************************************
 * Single precision real and complex packing *
 *********************************************/

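// A short note on the symm packers below (summarized from the code itself): they pack one block
// of a selfadjoint matrix for gemm. Values are always read from the lower triangle of the source
// mapper; elements above the diagonal are obtained by mirroring across the diagonal and
// conjugating the imaginary part (see getAdjointVal), and the diagonal is forced to be real.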
template<typename Scalar, typename Index, int StorageOrder>
EIGEN_ALWAYS_INLINE std::complex<Scalar> getAdjointVal(Index i, Index j, const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder>& dt)
{
  std::complex<Scalar> v;
  if(i < j)
  {
    v.real( dt(j,i).real());
    v.imag(-dt(j,i).imag());
  } else if(i > j)
  {
    v.real( dt(i,j).real());
    v.imag( dt(i,j).imag());
  } else {
    v.real( dt(i,j).real());
    v.imag((Scalar)0.0);
  }
  return v;
}
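// Layout note for the complex symm packers below: within each packed panel all real parts are
// stored first and the matching imaginary parts follow vectorDelta scalars later; rir and rii
// track the two write positions, matching the real/imaginary split described above.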
template<typename Scalar, typename Index, int StorageOrder, int N>
EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* blockB, const std::complex<Scalar>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
{
  const Index depth = k2 + rows;
  const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> rhs(_rhs, rhsStride);
  const Index vectorSize = N*quad_traits<Scalar>::vectorsize;
  const Index vectorDelta = vectorSize * rows;
  Scalar* blockBf = reinterpret_cast<Scalar *>(blockB);

  Index rir = 0, rii, j = 0;
  for(; j + vectorSize <= cols; j+=vectorSize)
  {
    rii = rir + vectorDelta;

    for(Index i = k2; i < depth; i++)
    {
      for(Index k = 0; k < vectorSize; k++)
      {
        std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, j + k, rhs);

        blockBf[rir + k] = v.real();
        blockBf[rii + k] = v.imag();
      }
      rir += vectorSize;
      rii += vectorSize;
    }

    rir += vectorDelta;
  }
  if (j < cols)
  {
    rii = rir + ((cols - j) * rows);

    for(Index i = k2; i < depth; i++)
    {
      Index k = j;
      for(; k < cols; k++)
      {
        std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, k, rhs);

        blockBf[rir] = v.real();
        blockBf[rii] = v.imag();

        rir += 1;
        rii += 1;
      }
    }
  }
}

template<typename Scalar, typename Index, int StorageOrder>
EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar>* blockA, const std::complex<Scalar>* _lhs, Index lhsStride, Index cols, Index rows)
{
  const Index depth = cols;
  const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> lhs(_lhs, lhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;
  const Index vectorDelta = vectorSize * depth;
  Scalar* blockAf = reinterpret_cast<Scalar *>(blockA);

  Index rir = 0, rii, j = 0;
  for(; j + vectorSize <= rows; j+=vectorSize)
  {
    rii = rir + vectorDelta;

    for(Index i = 0; i < depth; i++)
    {
      for(Index k = 0; k < vectorSize; k++)
      {
        std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(j+k, i, lhs);

        blockAf[rir + k] = v.real();
        blockAf[rii + k] = v.imag();
      }
      rir += vectorSize;
      rii += vectorSize;
    }

    rir += vectorDelta;
  }

  if (j < rows)
  {
    rii = rir + ((rows - j) * depth);

    for(Index i = 0; i < depth; i++)
    {
      Index k = j;
      for(; k < rows; k++)
      {
        std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(k, i, lhs);

        blockAf[rir] = v.real();
        blockAf[rii] = v.imag();

        rir += 1;
        rii += 1;
      }
    }
  }
}

template<typename Scalar, typename Index, int StorageOrder, int N>
EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
{
  const Index depth = k2 + rows;
  const_blas_data_mapper<Scalar, Index, StorageOrder> rhs(_rhs, rhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;

  Index ri = 0, j = 0;
  for(; j + N*vectorSize <= cols; j+=N*vectorSize)
  {
    Index i = k2;
    for(; i < depth; i++)
    {
      for(Index k = 0; k < N*vectorSize; k++)
      {
        if(i <= j+k)
          blockB[ri + k] = rhs(j+k, i);
        else
          blockB[ri + k] = rhs(i, j+k);
      }
      ri += N*vectorSize;
    }
  }

  if (j < cols)
  {
    for(Index i = k2; i < depth; i++)
    {
      Index k = j;
      for(; k < cols; k++)
      {
        if(k <= i)
          blockB[ri] = rhs(i, k);
        else
          blockB[ri] = rhs(k, i);
        ri += 1;
      }
    }
  }
}

template<typename Scalar, typename Index, int StorageOrder>
EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows)
{
  const Index depth = cols;
  const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs, lhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;

  Index ri = 0, j = 0;
  for(; j + vectorSize <= rows; j+=vectorSize)
  {
    Index i = 0;

    for(; i < depth; i++)
    {
      for(Index k = 0; k < vectorSize; k++)
      {
        if(i <= j+k)
          blockA[ri + k] = lhs(j+k, i);
        else
          blockA[ri + k] = lhs(i, j+k);
      }
      ri += vectorSize;
    }
  }

  if (j < rows)
  {
    for(Index i = 0; i < depth; i++)
    {
      Index k = j;
      for(; k < rows; k++)
      {
        if(i <= k)
          blockA[ri] = lhs(k, i);
        else
          blockA[ri] = lhs(i, k);
        ri += 1;
      }
    }
  }
}

template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<std::complex<float>, Index, nr, StorageOrder>
{
  void operator()(std::complex<float>* blockB, const std::complex<float>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_complex_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<std::complex<float>, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(std::complex<float>* blockA, const std::complex<float>* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_complex_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

// *********** symm_pack std::complex<float64> ***********

template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<std::complex<double>, Index, nr, StorageOrder>
{
  void operator()(std::complex<double>* blockB, const std::complex<double>* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_complex_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<std::complex<double>, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(std::complex<double>* blockA, const std::complex<double>* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_complex_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

// *********** symm_pack float32 ***********
template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<float, Index, nr, StorageOrder>
{
  void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<float, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};

// *********** symm_pack float64 ***********
template<typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<double, Index, nr, StorageOrder>
{
  void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
  {
    symm_pack_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder>
{
  void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows)
  {
    symm_pack_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};
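// storeBlock flushes one fully assembled PacketBlock (4 or 2 packets) contiguously at `to`;
// `size` is the packet length in scalars (the 16-byte VSX vector width divided by sizeof(Scalar)).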
template<typename Scalar, typename Packet, typename Index>
EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,4>& block)
{
  const Index size = 16 / sizeof(Scalar);
  pstore<Scalar>(to + (0 * size), block.packet[0]);
  pstore<Scalar>(to + (1 * size), block.packet[1]);
  pstore<Scalar>(to + (2 * size), block.packet[2]);
  pstore<Scalar>(to + (3 * size), block.packet[3]);
}

template<typename Scalar, typename Packet, typename Index>
EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,2>& block)
{
  const Index size = 16 / sizeof(Scalar);
  pstore<Scalar>(to + (0 * size), block.packet[0]);
  pstore<Scalar>(to + (1 * size), block.packet[1]);
}

// General template for lhs & rhs complex packing.
template<typename Scalar, typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode, bool UseLhs>
struct dhs_cpack {
  EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<Scalar>::vectorsize;
    const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii;
    Scalar* blockAt = reinterpret_cast<Scalar *>(blockA);
    Index j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      rii = rir + vectorDelta;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet,4> blockr, blocki;
        PacketBlock<PacketC,8> cblock;

        if (UseLhs) {
          bload<DataMapper, PacketC, Index, 2, 0, StorageOrder>(cblock, lhs, j, i);
        } else {
          bload<DataMapper, PacketC, Index, 2, 0, StorageOrder>(cblock, lhs, i, j);
        }

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
        blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32);
        blockr.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32);
        blockr.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32);

        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32);
        blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32);
        blocki.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32);
        blocki.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
          blocki.packet[1] = -blocki.packet[1];
          blocki.packet[2] = -blocki.packet[2];
          blocki.packet[3] = -blocki.packet[3];
        }

        if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
        {
          ptranspose(blockr);
          ptranspose(blocki);
        }

        storeBlock<Scalar, Packet, Index>(blockAt + rir, blockr);
        storeBlock<Scalar, Packet, Index>(blockAt + rii, blocki);

        rir += 4*vectorSize;
        rii += 4*vectorSize;
      }
      for(; i < depth; i++)
      {
        PacketBlock<Packet,1> blockr, blocki;
        PacketBlock<PacketC,2> cblock;

        if(((StorageOrder == ColMajor) && UseLhs) || ((StorageOrder == RowMajor) && !UseLhs))
        {
          if (UseLhs) {
            cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
            cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 2, i);
          } else {
            cblock.packet[0] = lhs.template loadPacket<PacketC>(i, j + 0);
            cblock.packet[1] = lhs.template loadPacket<PacketC>(i, j + 2);
          }
        } else {
          std::complex<Scalar> lhs0, lhs1;
          if (UseLhs) {
            lhs0 = lhs(j + 0, i);
            lhs1 = lhs(j + 1, i);
            cblock.packet[0] = pload2(&lhs0, &lhs1);
            lhs0 = lhs(j + 2, i);
            lhs1 = lhs(j + 3, i);
            cblock.packet[1] = pload2(&lhs0, &lhs1);
          } else {
            lhs0 = lhs(i, j + 0);
            lhs1 = lhs(i, j + 1);
            cblock.packet[0] = pload2(&lhs0, &lhs1);
            lhs0 = lhs(i, j + 2);
            lhs1 = lhs(i, j + 3);
            cblock.packet[1] = pload2(&lhs0, &lhs1);
          }
        }

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL32);
        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG32);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
        }

        pstore<Scalar>(blockAt + rir, blockr.packet[0]);
        pstore<Scalar>(blockAt + rii, blocki.packet[0]);

        rir += vectorSize;
        rii += vectorSize;
      }

      rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
    }

    if (j < rows)
    {
      if(PanelMode) rir += (offset*(rows - j - vectorSize));
      rii = rir + (((PanelMode) ? stride : depth) * (rows - j));

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < rows; k++)
        {
          if (UseLhs) {
            blockAt[rir] = lhs(k, i).real();

            if(Conjugate)
              blockAt[rii] = -lhs(k, i).imag();
            else
              blockAt[rii] = lhs(k, i).imag();
          } else {
            blockAt[rir] = lhs(i, k).real();

            if(Conjugate)
              blockAt[rii] = -lhs(i, k).imag();
            else
              blockAt[rii] = lhs(i, k).imag();
          }

          rir += 1;
          rii += 1;
        }
      }
    }
  }
};

// General template for lhs & rhs packing.
template<typename Scalar, typename Index, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
struct dhs_pack{
  EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<Scalar>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      if(PanelMode) ri += vectorSize*offset;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet,4> block;

        if (UseLhs) {
          bload<DataMapper, Packet, Index, 4, 0, StorageOrder>(block, lhs, j, i);
        } else {
          bload<DataMapper, Packet, Index, 4, 0, StorageOrder>(block, lhs, i, j);
        }
        if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
        {
          ptranspose(block);
        }

        storeBlock<Scalar, Packet, Index>(blockA + ri, block);

        ri += 4*vectorSize;
      }
      for(; i < depth; i++)
      {
        if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
        {
          if (UseLhs) {
            blockA[ri+0] = lhs(j+0, i);
            blockA[ri+1] = lhs(j+1, i);
            blockA[ri+2] = lhs(j+2, i);
            blockA[ri+3] = lhs(j+3, i);
          } else {
            blockA[ri+0] = lhs(i, j+0);
            blockA[ri+1] = lhs(i, j+1);
            blockA[ri+2] = lhs(i, j+2);
            blockA[ri+3] = lhs(i, j+3);
          }
        } else {
          Packet lhsV;
          if (UseLhs) {
            lhsV = lhs.template loadPacket<Packet>(j, i);
          } else {
            lhsV = lhs.template loadPacket<Packet>(i, j);
          }
          pstore<Scalar>(blockA + ri, lhsV);
        }

        ri += vectorSize;
      }

      if(PanelMode) ri += vectorSize*(stride - offset - depth);
    }

    if (j < rows)
    {
      if(PanelMode) ri += offset*(rows - j);

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < rows; k++)
        {
          if (UseLhs) {
            blockA[ri] = lhs(k, i);
          } else {
            blockA[ri] = lhs(i, k);
          }
          ri += 1;
        }
      }
    }
  }
};

// General template for lhs packing, float64 specialization.
template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, true>
{
  EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      if(PanelMode) ri += vectorSize*offset;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet2d,2> block;
        if(StorageOrder == RowMajor)
        {
          block.packet[0] = lhs.template loadPacket<Packet2d>(j + 0, i);
          block.packet[1] = lhs.template loadPacket<Packet2d>(j + 1, i);

          ptranspose(block);
        } else {
          block.packet[0] = lhs.template loadPacket<Packet2d>(j, i + 0);
          block.packet[1] = lhs.template loadPacket<Packet2d>(j, i + 1);
        }

        storeBlock<double, Packet2d, Index>(blockA + ri, block);

        ri += 2*vectorSize;
      }
      for(; i < depth; i++)
      {
        if(StorageOrder == RowMajor)
        {
          blockA[ri+0] = lhs(j+0, i);
          blockA[ri+1] = lhs(j+1, i);
        } else {
          Packet2d lhsV = lhs.template loadPacket<Packet2d>(j, i);
          pstore<double>(blockA + ri, lhsV);
        }

        ri += vectorSize;
      }

      if(PanelMode) ri += vectorSize*(stride - offset - depth);
    }

    if (j < rows)
    {
      if(PanelMode) ri += offset*(rows - j);

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < rows; k++)
        {
          blockA[ri] = lhs(k, i);
          ri += 1;
        }
      }
    }
  }
};

// General template for rhs packing, float64 specialization.
template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, false>
{
  EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
    {
      Index i = 0;

      if(PanelMode) ri += offset*(2*vectorSize);

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet2d,4> block;
        if(StorageOrder == ColMajor)
        {
          PacketBlock<Packet2d,2> block1, block2;
          block1.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 0);
          block1.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 1);
          block2.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 2);
          block2.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 3);

          ptranspose(block1);
          ptranspose(block2);

          pstore<double>(blockB + ri    , block1.packet[0]);
          pstore<double>(blockB + ri + 2, block2.packet[0]);
          pstore<double>(blockB + ri + 4, block1.packet[1]);
          pstore<double>(blockB + ri + 6, block2.packet[1]);
        } else {
          block.packet[0] = rhs.template loadPacket<Packet2d>(i + 0, j + 0); //[a1 a2]
          block.packet[1] = rhs.template loadPacket<Packet2d>(i + 0, j + 2); //[a3 a4]
          block.packet[2] = rhs.template loadPacket<Packet2d>(i + 1, j + 0); //[b1 b2]
          block.packet[3] = rhs.template loadPacket<Packet2d>(i + 1, j + 2); //[b3 b4]

          storeBlock<double, Packet2d, Index>(blockB + ri, block);
        }

        ri += 4*vectorSize;
      }
      for(; i < depth; i++)
      {
        if(StorageOrder == ColMajor)
        {
          blockB[ri+0] = rhs(i, j+0);
          blockB[ri+1] = rhs(i, j+1);

          ri += vectorSize;

          blockB[ri+0] = rhs(i, j+2);
          blockB[ri+1] = rhs(i, j+3);
        } else {
          Packet2d rhsV = rhs.template loadPacket<Packet2d>(i, j);
          pstore<double>(blockB + ri, rhsV);

          ri += vectorSize;

          rhsV = rhs.template loadPacket<Packet2d>(i, j + 2);
          pstore<double>(blockB + ri, rhsV);
        }
        ri += vectorSize;
      }

      if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth);
    }

    if (j < cols)
    {
      if(PanelMode) ri += offset*(cols - j);

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < cols; k++)
        {
          blockB[ri] = rhs(i, k);
          ri += 1;
        }
      }
    }
  }
};

// General template for lhs complex packing, float64 specialization.
template<typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true>
{
  EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii;
    double* blockAt = reinterpret_cast<double *>(blockA);
    Index j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      rii = rir + vectorDelta;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet,2> blockr, blocki;
        PacketBlock<PacketC,4> cblock;

        if(StorageOrder == ColMajor)
        {
          cblock.packet[0] = lhs.template loadPacket<PacketC>(j, i + 0); //[a1 a1i]
          cblock.packet[1] = lhs.template loadPacket<PacketC>(j, i + 1); //[b1 b1i]

          cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 1, i + 0); //[a2 a2i]
          cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]

          blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2]
          blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]

          blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETIMAG64);
          blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETIMAG64);
        } else {
          cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i); //[a1 a1i]
          cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i); //[a2 a2i]

          cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 0, i + 1); //[b1 b1i]
          cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]

          blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2]
          blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]

          blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
          blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);
        }

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
          blocki.packet[1] = -blocki.packet[1];
        }

        storeBlock<double, Packet, Index>(blockAt + rir, blockr);
        storeBlock<double, Packet, Index>(blockAt + rii, blocki);

        rir += 2*vectorSize;
        rii += 2*vectorSize;
      }
      for(; i < depth; i++)
      {
        PacketBlock<Packet,1> blockr, blocki;
        PacketBlock<PacketC,2> cblock;

        cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
        cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i);

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
        }

        pstore<double>(blockAt + rir, blockr.packet[0]);
        pstore<double>(blockAt + rii, blocki.packet[0]);

        rir += vectorSize;
        rii += vectorSize;
      }

      rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
    }

    if (j < rows)
    {
      if(PanelMode) rir += (offset*(rows - j - vectorSize));
      rii = rir + (((PanelMode) ? stride : depth) * (rows - j));

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < rows; k++)
        {
          blockAt[rir] = lhs(k, i).real();

          if(Conjugate)
            blockAt[rii] = -lhs(k, i).imag();
          else
            blockAt[rii] = lhs(k, i).imag();

          rir += 1;
          rii += 1;
        }
      }
    }
  }
};

// General template for rhs complex packing, float64 specialization.
template<typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, false>
{
  EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    const Index vectorDelta = 2*vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (2*vectorSize*offset) : 0), rii;
    double* blockBt = reinterpret_cast<double *>(blockB);
    Index j = 0;

    for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
    {
      Index i = 0;

      rii = rir + vectorDelta;

      for(; i < depth; i++)
      {
        PacketBlock<PacketC,4> cblock;
        PacketBlock<Packet,2> blockr, blocki;

        bload<DataMapper, PacketC, Index, 2, 0, ColMajor>(cblock, rhs, i, j);

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
        blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64);

        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
        blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
          blocki.packet[1] = -blocki.packet[1];
        }

        storeBlock<double, Packet, Index>(blockBt + rir, blockr);
        storeBlock<double, Packet, Index>(blockBt + rii, blocki);

        rir += 2*vectorSize;
        rii += 2*vectorSize;
      }

      rir += ((PanelMode) ? (2*vectorSize*(2*stride - depth)) : vectorDelta);
    }

    if (j < cols)
    {
      if(PanelMode) rir += (offset*(cols - j - 2*vectorSize));
      rii = rir + (((PanelMode) ? stride : depth) * (cols - j));

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < cols; k++)
        {
          blockBt[rir] = rhs(i, k).real();

          if(Conjugate)
            blockBt[rii] = -rhs(i, k).imag();
          else
            blockBt[rii] = rhs(i, k).imag();

          rir += 1;
          rii += 1;
        }
      }
    }
  }
};

/**************
 * GEMM utils *
 **************/

// 512-bit rank-1 update of acc. It can either positively or negatively accumulate (useful for complex gemm).
template<typename Packet, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,4>* acc, const Packet& lhsV, const Packet* rhsV)
{
  if(NegativeAccumulate)
  {
    acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
    acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]);
    acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]);
    acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]);
  } else {
    acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
    acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]);
    acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]);
    acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]);
  }
}

template<typename Packet, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,1>* acc, const Packet& lhsV, const Packet* rhsV)
{
  if(NegativeAccumulate)
  {
    acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
  } else {
    acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
  }
}

template<int N, typename Scalar, typename Packet, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV)
{
  Packet lhsV = pload<Packet>(lhs);

  pger_common<Packet, NegativeAccumulate>(acc, lhsV, rhsV);
}
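
// In scalar terms, the pger variants above compute acc->packet[j] += lhsV * rhsV[j] lane-wise
// (or acc->packet[j] -= lhsV * rhsV[j] when NegativeAccumulate, via vec_nmsub), i.e. one step
// of a rank-1 update of a 4 x vectorsize accumulator tile.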

template<typename Scalar, typename Packet, typename Index>
EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV, Index remaining_rows)
{
#ifdef _ARCH_PWR9
  lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar));
#else
  Index i = 0;
  do {
    lhsV[i] = lhs[i];
  } while (++i < remaining_rows);
#endif
}

template<int N, typename Scalar, typename Packet, typename Index, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows)
{
  Packet lhsV;
  loadPacketRemaining<Scalar, Packet, Index>(lhs, lhsV, remaining_rows);

  pger_common<Packet, NegativeAccumulate>(acc, lhsV, rhsV);
}

// 512-bit rank-1 update of complex acc. It takes decoupled accumulators as entries. It also takes care of the mixed types real * complex and complex * real.
template<int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Packet &lhsV, const Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi)
{
  pger_common<Packet, false>(accReal, lhsV, rhsV);
  if(LhsIsReal)
  {
    pger_common<Packet, ConjugateRhs>(accImag, lhsV, rhsVi);
    EIGEN_UNUSED_VARIABLE(lhsVi);
  } else {
    if (!RhsIsReal) {
      pger_common<Packet, ConjugateLhs == ConjugateRhs>(accReal, lhsVi, rhsVi);
      pger_common<Packet, ConjugateRhs>(accImag, lhsV, rhsVi);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pger_common<Packet, ConjugateLhs>(accImag, lhsVi, rhsV);
  }
}

template<int N, typename Scalar, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi)
{
  Packet lhsV = ploadLhs<Scalar, Packet>(lhs_ptr);
  Packet lhsVi;
  if(!LhsIsReal) lhsVi = ploadLhs<Scalar, Packet>(lhs_ptr_imag);
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);

  pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
}

template<typename Scalar, typename Packet, typename Index, bool LhsIsReal>
EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi, Index remaining_rows)
{
#ifdef _ARCH_PWR9
  lhsV = vec_xl_len((Scalar *)lhs_ptr, remaining_rows * sizeof(Scalar));
  if(!LhsIsReal) lhsVi = vec_xl_len((Scalar *)lhs_ptr_imag, remaining_rows * sizeof(Scalar));
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
#else
  Index i = 0;
  do {
    lhsV[i] = lhs_ptr[i];
    if(!LhsIsReal) lhsVi[i] = lhs_ptr_imag[i];
  } while (++i < remaining_rows);
  if(LhsIsReal) EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
#endif
}

template<int N, typename Scalar, typename Packet, typename Index, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi, Index remaining_rows)
{
  Packet lhsV, lhsVi;
  loadPacketRemaining<Scalar, Packet, Index, LhsIsReal>(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi, remaining_rows);

  pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
}

template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs)
{
  return ploadu<Packet>(lhs);
}

// Zero the accumulator on PacketBlock.
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,4>& acc)
{
  acc.packet[0] = pset1<Packet>((Scalar)0);
  acc.packet[1] = pset1<Packet>((Scalar)0);
  acc.packet[2] = pset1<Packet>((Scalar)0);
  acc.packet[3] = pset1<Packet>((Scalar)0);
}

template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,1>& acc)
{
  acc.packet[0] = pset1<Packet>((Scalar)0);
}

// Scale the PacketBlock vectors by alpha.
template<typename Packet>
EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha)
{
  acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
  acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]);
  acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]);
  acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]);
}

template<typename Packet>
EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,1>& acc, PacketBlock<Packet,1>& accZ, const Packet& pAlpha)
{
  acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
}

template<typename Packet>
EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha)
{
  acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha);
  acc.packet[1] = pmul<Packet>(accZ.packet[1], pAlpha);
  acc.packet[2] = pmul<Packet>(accZ.packet[2], pAlpha);
  acc.packet[3] = pmul<Packet>(accZ.packet[3], pAlpha);
}

template<typename Packet>
EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,1>& acc, PacketBlock<Packet,1>& accZ, const Packet& pAlpha)
{
  acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha);
}

// Complex version of PacketBlock scaling.
template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag)
{
  bscalec_common<Packet>(cReal, aReal, bReal);

  bscalec_common<Packet>(cImag, aImag, bReal);

  pger_common<Packet, true>(&cReal, bImag, aImag.packet);

  pger_common<Packet, false>(&cImag, bImag, aReal.packet);
}
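
// The sequence above is the usual complex product: cReal = aReal*bReal - aImag*bImag and
// cImag = aImag*bReal + aReal*bImag, built from the two real scalings plus one negative and
// one positive rank-1 accumulation.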

template<typename Packet>
EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet,4>& acc, const Packet& pMask)
{
  acc.packet[0] = pand(acc.packet[0], pMask);
  acc.packet[1] = pand(acc.packet[1], pMask);
  acc.packet[2] = pand(acc.packet[2], pMask);
  acc.packet[3] = pand(acc.packet[3], pMask);
}

template<typename Packet>
EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,4>& aReal, PacketBlock<Packet,4>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,4>& cReal, PacketBlock<Packet,4>& cImag, const Packet& pMask)
{
  band<Packet>(aReal, pMask);
  band<Packet>(aImag, pMask);

  bscalec<Packet,4>(aReal, aImag, bReal, bImag, cReal, cImag);
}

// Load a PacketBlock; the N parameters make tuning gemm easier, so we can add more accumulators as needed.
template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,4>& acc, const DataMapper& res, Index row, Index col)
{
  if (StorageOrder == RowMajor) {
    acc.packet[0] = res.template loadPacket<Packet>(row + 0, col + N*accCols);
    acc.packet[1] = res.template loadPacket<Packet>(row + 1, col + N*accCols);
    acc.packet[2] = res.template loadPacket<Packet>(row + 2, col + N*accCols);
    acc.packet[3] = res.template loadPacket<Packet>(row + 3, col + N*accCols);
  } else {
    acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
    acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1);
    acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2);
    acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3);
  }
}

// An overload of bload for a PacketBlock with 8 vectors.
template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res, Index row, Index col)
{
  if (StorageOrder == RowMajor) {
    acc.packet[0] = res.template loadPacket<Packet>(row + 0, col + N*accCols);
    acc.packet[1] = res.template loadPacket<Packet>(row + 1, col + N*accCols);
    acc.packet[2] = res.template loadPacket<Packet>(row + 2, col + N*accCols);
    acc.packet[3] = res.template loadPacket<Packet>(row + 3, col + N*accCols);
    acc.packet[4] = res.template loadPacket<Packet>(row + 0, col + (N+1)*accCols);
    acc.packet[5] = res.template loadPacket<Packet>(row + 1, col + (N+1)*accCols);
    acc.packet[6] = res.template loadPacket<Packet>(row + 2, col + (N+1)*accCols);
    acc.packet[7] = res.template loadPacket<Packet>(row + 3, col + (N+1)*accCols);
  } else {
    acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
    acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1);
    acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2);
    acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3);
    acc.packet[4] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 0);
    acc.packet[5] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 1);
    acc.packet[6] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 2);
    acc.packet[7] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 3);
  }
}

template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,2>& acc, const DataMapper& res, Index row, Index col)
{
  acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
  acc.packet[1] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 0);
}
const static Packet4i mask41 = { -1,  0,  0,  0 };
const static Packet4i mask42 = { -1, -1,  0,  0 };
const static Packet4i mask43 = { -1, -1, -1,  0 };

const static Packet2l mask21 = { -1, 0 };

template<typename Packet>
EIGEN_ALWAYS_INLINE Packet bmask(const int remaining_rows)
{
  if (remaining_rows == 0) {
    return pset1<Packet>(float(0.0)); // Not used
  } else {
    switch (remaining_rows) {
      case 1:  return Packet(mask41);
      case 2:  return Packet(mask42);
      default: return Packet(mask43);
    }
  }
}
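
// For example, bmask<Packet4f>(2) yields a mask whose first two 32-bit lanes are all ones, so
// band() keeps only the two valid rows of a partial accumulator before it is scaled and stored.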

template<>
EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d>(const int remaining_rows)
{
  if (remaining_rows == 0) {
    return pset1<Packet2d>(double(0.0)); // Not used
  } else {
    return Packet2d(mask21);
  }
}

template<typename Packet>
EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha, const Packet& pMask)
{
  band<Packet>(accZ, pMask);

  bscale<Packet>(acc, accZ, pAlpha);
}

template<typename Packet>
EIGEN_ALWAYS_INLINE void pbroadcast4_old(const __UNPACK_TYPE__(Packet)* a, Packet& a0, Packet& a1, Packet& a2, Packet& a3)
{
  pbroadcast4<Packet>(a, a0, a1, a2, a3);
}

template<>
EIGEN_ALWAYS_INLINE void pbroadcast4_old<Packet2d>(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
{
  a1 = pload<Packet2d>(a);
  a3 = pload<Packet2d>(a + 2);
  a0 = vec_splat(a1, 0);
  a1 = vec_splat(a1, 1);
  a2 = vec_splat(a3, 0);
  a3 = vec_splat(a3, 1);
}

// PEEL loop factor.
#define PEEL 7
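
// With PEEL == 7, the main depth loops below issue one prefetch and then run seven unrolled
// rank-1 updates per iteration; any leftover depth is handled by the scalar-count loops that
// follow each peeled loop.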

template<typename Scalar, typename Packet, typename Index>
EIGEN_ALWAYS_INLINE void MICRO_EXTRA_COL(
  const Scalar* &lhs_ptr,
  const Scalar* &rhs_ptr,
  PacketBlock<Packet,1> &accZero,
  Index remaining_rows,
  Index remaining_cols)
{
  Packet rhsV[1];
  rhsV[0] = pset1<Packet>(rhs_ptr[0]);
  pger<1, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
  lhs_ptr += remaining_rows;
  rhs_ptr += remaining_cols;
}

template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
EIGEN_STRONG_INLINE void gemm_extra_col(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index row,
  Index col,
  Index remaining_rows,
  Index remaining_cols,
  const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
  PacketBlock<Packet,1> accZero;

  bsetzero<Scalar, Packet>(accZero);

  Index remaining_depth = (depth & -accRows);
  Index k = 0;
  for(; k + PEEL <= remaining_depth; k+= PEEL)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    EIGEN_POWER_PREFETCH(lhs_ptr);
    for (int l = 0; l < PEEL; l++) {
      MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
    }
  }
  for(; k < remaining_depth; k++)
  {
    MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
  }
  for(; k < depth; k++)
  {
    Packet rhsV[1];
    rhsV[0] = pset1<Packet>(rhs_ptr[0]);
    pger<1, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
    lhs_ptr += remaining_rows;
    rhs_ptr += remaining_cols;
  }

  accZero.packet[0] = vec_mul(pAlpha, accZero.packet[0]);
  for(Index i = 0; i < remaining_rows; i++) {
    res(row + i, col) += accZero.packet[0][i];
  }
}

template<typename Scalar, typename Packet, typename Index, const Index accRows>
EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(
  const Scalar* &lhs_ptr,
  const Scalar* &rhs_ptr,
  PacketBlock<Packet,4> &accZero,
  Index remaining_rows)
{
  Packet rhsV[4];
  pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
  pger<4, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
  lhs_ptr += remaining_rows;
  rhs_ptr += accRows;
}

template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm_extra_row(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index row,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
  PacketBlock<Packet,4> accZero, acc;

  bsetzero<Scalar, Packet>(accZero);

  Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows);
  Index k = 0;
  for(; k + PEEL <= remaining_depth; k+= PEEL)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    EIGEN_POWER_PREFETCH(lhs_ptr);
    for (int l = 0; l < PEEL; l++) {
      MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
    }
  }
  for(; k < remaining_depth; k++)
  {
    MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
  }

  if ((remaining_depth == depth) && (rows >= accCols))
  {
    for(Index j = 0; j < 4; j++) {
      acc.packet[j] = res.template loadPacket<Packet>(row, col + j);
    }
    bscale<Packet>(acc, accZero, pAlpha, pMask);
    res.template storePacketBlock<Packet,4>(row, col, acc);
  } else {
    for(; k < depth; k++)
    {
      Packet rhsV[4];
      pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
      pger<4, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
      lhs_ptr += remaining_rows;
      rhs_ptr += accRows;
    }

    for(Index j = 0; j < 4; j++) {
      accZero.packet[j] = vec_mul(pAlpha, accZero.packet[j]);
    }
    for(Index j = 0; j < 4; j++) {
      for(Index i = 0; i < remaining_rows; i++) {
        res(row + i, col + j) += accZero.packet[j][i];
      }
    }
  }
}

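// The MICRO_* macros below implement a compile-time-unrolled micro kernel: `iter` indexes up to
// eight lhs pointers/accumulators and `peel` indexes the depth-peeled rhs loads. Guards such as
// `if (unroll_factor > iter)` and `if (PEEL > peel)` are compile-time constants the compiler
// folds away, so only the requested number of accumulators survives in the generated code.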
#define MICRO_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_UNROLL_WORK(func, func2, peel) \
  MICRO_UNROLL(func2); \
  func(0,peel) func(1,peel) func(2,peel) func(3,peel) \
  func(4,peel) func(5,peel) func(6,peel) func(7,peel)

#define MICRO_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
    lhs_ptr##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
  }

#define MICRO_WORK_ONE(iter, peel) \
  if (unroll_factor > iter) { \
    pger_common<Packet, false>(&accZero##iter, lhsV##iter, rhsV##peel); \
  }

#define MICRO_TYPE_PEEL4(func, func2, peel) \
  if (PEEL > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    pbroadcast4<Packet>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
    MICRO_UNROLL_WORK(func, func2, peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

#define MICRO_TYPE_PEEL1(func, func2, peel) \
  if (PEEL > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    rhsV##peel[0] = pset1<Packet>(rhs_ptr[remaining_cols * peel]); \
    MICRO_UNROLL_WORK(func, func2, peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \
  Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \
  func(func1,func2,0); func(func1,func2,1); \
  func(func1,func2,2); func(func1,func2,3); \
  func(func1,func2,4); func(func1,func2,5); \
  func(func1,func2,6); func(func1,func2,7); \
  func(func1,func2,8); func(func1,func2,9);

#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
  Packet rhsV0[M]; \
  func(func1,func2,0);

#define MICRO_ONE_PEEL4 \
  MICRO_UNROLL_TYPE_PEEL(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
  rhs_ptr += (accRows * PEEL);

#define MICRO_ONE4 \
  MICRO_UNROLL_TYPE_ONE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
  rhs_ptr += accRows;

#define MICRO_ONE_PEEL1 \
  MICRO_UNROLL_TYPE_PEEL(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
  rhs_ptr += (remaining_cols * PEEL);

#define MICRO_ONE1 \
  MICRO_UNROLL_TYPE_ONE(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
  rhs_ptr += remaining_cols;

#define MICRO_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzero<Scalar, Packet>(accZero##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE)

#define MICRO_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
  }

#define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE)

#define MICRO_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
  }

#define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE)

#define MICRO_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
    acc.packet[1] = res.template loadPacket<Packet>(row + iter*accCols, col + 1); \
    acc.packet[2] = res.template loadPacket<Packet>(row + iter*accCols, col + 2); \
    acc.packet[3] = res.template loadPacket<Packet>(row + iter*accCols, col + 3); \
    bscale<Packet>(acc, accZero##iter, pAlpha); \
    res.template storePacketBlock<Packet,4>(row + iter*accCols, col, acc); \
  }

#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)

#define MICRO_COL_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
    bscale<Packet>(acc, accZero##iter, pAlpha); \
    res.template storePacketBlock<Packet,1>(row + iter*accCols, col, acc); \
  }

#define MICRO_COL_STORE MICRO_UNROLL(MICRO_COL_STORE_ONE)

template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index& row,
  Index col,
  const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
  PacketBlock<Packet,4> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
  PacketBlock<Packet,4> acc;

  MICRO_SRC_PTR
  MICRO_DST_PTR

  Index k = 0;
  for(; k + PEEL <= depth; k+= PEEL)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_PREFETCH
    MICRO_ONE_PEEL4
  }
  for(; k < depth; k++)
  {
    MICRO_ONE4
  }
  MICRO_STORE

  row += unroll_factor*accCols;
}

template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_col_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index& row,
  Index col,
  Index remaining_cols,
  const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
  PacketBlock<Packet,1> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
  PacketBlock<Packet,1> acc;

  MICRO_SRC_PTR
  MICRO_DST_PTR

  Index k = 0;
  for(; k + PEEL <= depth; k+= PEEL)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_PREFETCH
    MICRO_ONE_PEEL1
  }
  for(; k < depth; k++)
  {
    MICRO_ONE1
  }
  MICRO_COL_STORE

  row += unroll_factor*accCols;
}

template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_col(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index& row,
  Index rows,
  Index col,
  Index remaining_cols,
  const Packet& pAlpha)
{
#define MAX_UNROLL 6
  while(row + MAX_UNROLL*accCols <= rows) {
    gemm_unrolled_col_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
  }
  switch( (rows-row)/accCols ) {
#if MAX_UNROLL > 7
    case 7:
      gemm_unrolled_col_iteration<7, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
      break;
#endif
#if MAX_UNROLL > 6
    case 6:
      gemm_unrolled_col_iteration<6, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
      break;
#endif
#if MAX_UNROLL > 5
    case 5:
      gemm_unrolled_col_iteration<5, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
      break;
#endif
#if MAX_UNROLL > 4
    case 4:
      gemm_unrolled_col_iteration<4, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
      break;
#endif
#if MAX_UNROLL > 3
    case 3:
      gemm_unrolled_col_iteration<3, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
      break;
#endif
#if MAX_UNROLL > 2
    case 2:
      gemm_unrolled_col_iteration<2, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
      break;
#endif
#if MAX_UNROLL > 1
    case 1:
      gemm_unrolled_col_iteration<1, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
      break;
#endif
    default:
      break;
  }
#undef MAX_UNROLL
}
1696
1697/****************
1698 * GEMM kernels *
1699 * **************/
1700template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
1701EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
1702{
1703 const Index remaining_rows = rows % accCols;
1704 const Index remaining_cols = cols % accRows;
1705
1706 if( strideA == -1 ) strideA = depth;
1707 if( strideB == -1 ) strideB = depth;
1708
1709 const Packet pAlpha = pset1<Packet>(alpha);
1710 const Packet pMask = bmask<Packet>((const int)(remaining_rows));
1711
1712 Index col = 0;
1713 for(; col + accRows <= cols; col += accRows)
1714 {
1715 const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
1716 const Scalar* lhs_base = blockA;
1717 Index row = 0;
1718
1719#define MAX_UNROLL 6
1720 while(row + MAX_UNROLL*accCols <= rows) {
1721 gemm_unrolled_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
1722 }
1723 switch( (rows-row)/accCols ) {
1724#if MAX_UNROLL > 7
1725 case 7:
1726 gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
1727 break;
1728#endif
1729#if MAX_UNROLL > 6
1730 case 6:
1731 gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
1732 break;
1733#endif
1734#if MAX_UNROLL > 5
1735 case 5:
1736 gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
1737 break;
1738#endif
1739#if MAX_UNROLL > 4
1740 case 4:
1741 gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
1742 break;
1743#endif
1744#if MAX_UNROLL > 3
1745 case 3:
1746 gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
1747 break;
1748#endif
1749#if MAX_UNROLL > 2
1750 case 2:
1751 gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
1752 break;
1753#endif
1754#if MAX_UNROLL > 1
1755 case 1:
1756 gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
1757 break;
1758#endif
1759 default:
1760 break;
1761 }
1762#undef MAX_UNROLL
1763
1764 if(remaining_rows > 0)
1765 {
1766 gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
1767 }
1768 }
1769
1770 if(remaining_cols > 0)
1771 {
1772 const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB;
1773 const Scalar* lhs_base = blockA;
1774
1775 for(; col < cols; col++)
1776 {
1777 Index row = 0;
1778
1779 gemm_unrolled_col<Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha);
1780
1781 if (remaining_rows > 0)
1782 {
1783 gemm_extra_col<Scalar, Packet, DataMapper, Index, accRows>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha);
1784 }
1785 rhs_base++;
1786 }
1787 }
1788}
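// Worked example (illustrative): for float with accRows == accCols == 4,
// rows == 10 and cols == 7, the main loop runs one full column panel
// (cols 0-3), taking switch case 2 for rows 0-7 plus a remaining_rows == 2
// tail, and then processes cols 4-6 one column at a time through
// gemm_unrolled_col and gemm_extra_col.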
1789
1790#define accColsC (accCols / 2)
1791#define advanceRows ((LhsIsReal) ? 1 : 2)
1792#define advanceCols ((RhsIsReal) ? 1 : 2)
1793
1794// PEEL_COMPLEX: depth-loop peeling factor for the complex kernels (k iterations handled per prefetch round).
1795#define PEEL_COMPLEX 3
1796
1797template<typename Scalar, typename Packet, typename Index, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
1798EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_COL(
1799 const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
1800 const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag,
1801 PacketBlock<Packet,1> &accReal, PacketBlock<Packet,1> &accImag,
1802 Index remaining_rows,
1803 Index remaining_cols)
1804{
1805 Packet rhsV[1], rhsVi[1];
1806 rhsV[0] = pset1<Packet>(rhs_ptr_real[0]);
1807 if(!RhsIsReal) rhsVi[0] = pset1<Packet>(rhs_ptr_imag[0]);
1808 pgerc<1, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
1809 lhs_ptr_real += remaining_rows;
1810 if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
1811 else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
1812 rhs_ptr_real += remaining_cols;
1813 if(!RhsIsReal) rhs_ptr_imag += remaining_cols;
1814 else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
1815}
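// One k step of the update above keeps real and imaginary parts in separate
// accumulators; schematically (a sketch, with signs resolved by
// ConjugateLhs/ConjugateRhs inside pgerc):
//   accReal += lhs_re*rhs_re -/+ lhs_im*rhs_im
//   accImag += lhs_re*rhs_im +/- lhs_im*rhs_re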
1816
1817template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
1818EIGEN_STRONG_INLINE void gemm_complex_extra_col(
1819 const DataMapper& res,
1820 const Scalar* lhs_base,
1821 const Scalar* rhs_base,
1822 Index depth,
1823 Index strideA,
1824 Index offsetA,
1825 Index strideB,
1826 Index row,
1827 Index col,
1828 Index remaining_rows,
1829 Index remaining_cols,
1830 const Packet& pAlphaReal,
1831 const Packet& pAlphaImag)
1832{
1833 const Scalar* rhs_ptr_real = rhs_base;
1834 const Scalar* rhs_ptr_imag;
1835 if(!RhsIsReal) rhs_ptr_imag = rhs_base + remaining_cols*strideB;
1836 else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
1837 const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA;
1838 const Scalar* lhs_ptr_imag;
1839 if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA;
1840 else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
1841 PacketBlock<Packet,1> accReal, accImag;
1842 PacketBlock<Packet,1> taccReal, taccImag;
1843 PacketBlock<Packetc,1> acc0, acc1;
1844
1845 bsetzero<Scalar, Packet>(accReal);
1846 bsetzero<Scalar, Packet>(accImag);
1847
1848 Index remaining_depth = (depth & -accRows);
1849 Index k = 0;
1850 for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX)
1851 {
1852 EIGEN_POWER_PREFETCH(rhs_ptr_real);
1853 if(!RhsIsReal) {
1854 EIGEN_POWER_PREFETCH(rhs_ptr_imag);
1855 }
1856 EIGEN_POWER_PREFETCH(lhs_ptr_real);
1857 if(!LhsIsReal) {
1858 EIGEN_POWER_PREFETCH(lhs_ptr_imag);
1859 }
1860 for (int l = 0; l < PEEL_COMPLEX; l++) {
1861 MICRO_COMPLEX_EXTRA_COL<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols);
1862 }
1863 }
1864 for(; k < remaining_depth; k++)
1865 {
1866 MICRO_COMPLEX_EXTRA_COL<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols);
1867 }
1868
1869 for(; k < depth; k++)
1870 {
1871 Packet rhsV[1], rhsVi[1];
1872 rhsV[0] = pset1<Packet>(rhs_ptr_real[0]);
1873 if(!RhsIsReal) rhsVi[0] = pset1<Packet>(rhs_ptr_imag[0]);
1874 pgerc<1, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, remaining_rows);
1875 lhs_ptr_real += remaining_rows;
1876 if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
1877 rhs_ptr_real += remaining_cols;
1878 if(!RhsIsReal) rhs_ptr_imag += remaining_cols;
1879 }
1880
1881 bscalec<Packet,1>(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag);
1882 bcouple_common<Packet, Packetc>(taccReal, taccImag, acc0, acc1);
1883
1884 if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1))
1885 {
1886 res(row + 0, col + 0) += pfirst<Packetc>(acc0.packet[0]);
1887 } else {
1888 acc0.packet[0] += res.template loadPacket<Packetc>(row + 0, col + 0);
1889 res.template storePacketBlock<Packetc,1>(row + 0, col + 0, acc0);
1890 if(remaining_rows > accColsC) {
1891 res(row + accColsC, col + 0) += pfirst<Packetc>(acc1.packet[0]);
1892 }
1893 }
1894}
1895
1896template<typename Scalar, typename Packet, typename Index, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
1897EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(
1898 const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
1899 const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag,
1900 PacketBlock<Packet,4> &accReal, PacketBlock<Packet,4> &accImag,
1901 Index remaining_rows)
1902{
1903 Packet rhsV[4], rhsVi[4];
1904 pbroadcast4_old<Packet>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
1905 if(!RhsIsReal) pbroadcast4_old<Packet>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
1906 pgerc<4, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
1907 lhs_ptr_real += remaining_rows;
1908 if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
1909 else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
1910 rhs_ptr_real += accRows;
1911 if(!RhsIsReal) rhs_ptr_imag += accRows;
1912 else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
1913}
1914
1915template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
1916EIGEN_STRONG_INLINE void gemm_complex_extra_row(
1917 const DataMapper& res,
1918 const Scalar* lhs_base,
1919 const Scalar* rhs_base,
1920 Index depth,
1921 Index strideA,
1922 Index offsetA,
1923 Index strideB,
1924 Index row,
1925 Index col,
1926 Index rows,
1927 Index cols,
1928 Index remaining_rows,
1929 const Packet& pAlphaReal,
1930 const Packet& pAlphaImag,
1931 const Packet& pMask)
1932{
1933 const Scalar* rhs_ptr_real = rhs_base;
1934 const Scalar* rhs_ptr_imag;
1935 if(!RhsIsReal) rhs_ptr_imag = rhs_base + accRows*strideB;
1936 else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
1937 const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA;
1938 const Scalar* lhs_ptr_imag;
1939 if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA;
1940 else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
1941 PacketBlock<Packet,4> accReal, accImag;
1942 PacketBlock<Packet,4> taccReal, taccImag;
1943 PacketBlock<Packetc,4> acc0, acc1;
1944 PacketBlock<Packetc,8> tRes;
1945
1946 bsetzero<Scalar, Packet>(accReal);
1947 bsetzero<Scalar, Packet>(accImag);
1948
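  // On the last (possibly partial) column panel, remaining_depth rounds depth
  // down to a multiple of accRows; the leftover k iterations are then drained
  // one at a time by the scalar tail in the else branch below.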
1949 Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows);
1950 Index k = 0;
1951 for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX)
1952 {
1953 EIGEN_POWER_PREFETCH(rhs_ptr_real);
1954 if(!RhsIsReal) {
1955 EIGEN_POWER_PREFETCH(rhs_ptr_imag);
1956 }
1957 EIGEN_POWER_PREFETCH(lhs_ptr_real);
1958 if(!LhsIsReal) {
1959 EIGEN_POWER_PREFETCH(lhs_ptr_imag);
1960 }
1961 for (int l = 0; l < PEEL_COMPLEX; l++) {
1962 MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows);
1963 }
1964 }
1965 for(; k < remaining_depth; k++)
1966 {
1967 MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows);
1968 }
1969
1970 if ((remaining_depth == depth) && (rows >= accCols))
1971 {
1972 bload<DataMapper, Packetc, Index, accColsC, 0, ColMajor>(tRes, res, row, col);
1973 bscalec<Packet>(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
1974 bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc0, acc1);
1975 res.template storePacketBlock<Packetc,4>(row + 0, col, acc0);
1976 res.template storePacketBlock<Packetc,4>(row + accColsC, col, acc1);
1977 } else {
1978 for(; k < depth; k++)
1979 {
1980 Packet rhsV[4], rhsVi[4];
1981 pbroadcast4_old<Packet>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
1982 if(!RhsIsReal) pbroadcast4_old<Packet>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
1983 pgerc<4, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, remaining_rows);
1984 lhs_ptr_real += remaining_rows;
1985 if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
1986 rhs_ptr_real += accRows;
1987 if(!RhsIsReal) rhs_ptr_imag += accRows;
1988 }
1989
1990 bscalec<Packet,4>(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag);
1991 bcouple_common<Packet, Packetc>(taccReal, taccImag, acc0, acc1);
1992
1993 if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1))
1994 {
1995 for(Index j = 0; j < 4; j++) {
1996 res(row + 0, col + j) += pfirst<Packetc>(acc0.packet[j]);
1997 }
1998 } else {
1999 for(Index j = 0; j < 4; j++) {
2000 PacketBlock<Packetc,1> acc2;
2001 acc2.packet[0] = res.template loadPacket<Packetc>(row + 0, col + j) + acc0.packet[j];
2002 res.template storePacketBlock<Packetc,1>(row + 0, col + j, acc2);
2003 if(remaining_rows > accColsC) {
2004 res(row + accColsC, col + j) += pfirst<Packetc>(acc1.packet[j]);
2005 }
2006 }
2007 }
2008 }
2009}
2010
2011#define MICRO_COMPLEX_UNROLL(func) \
2012 func(0) func(1) func(2) func(3) func(4)
2013
2014#define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
2015 MICRO_COMPLEX_UNROLL(func2); \
2016 func(0,peel) func(1,peel) func(2,peel) func(3,peel) func(4,peel)
2017
2018#define MICRO_COMPLEX_LOAD_ONE(iter) \
2019 if (unroll_factor > iter) { \
2020 lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
2021 lhs_ptr_real##iter += accCols; \
2022 if(!LhsIsReal) { \
2023 lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_imag##iter); \
2024 lhs_ptr_imag##iter += accCols; \
2025 } else { \
2026 EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
2027 } \
2028 } else { \
2029 EIGEN_UNUSED_VARIABLE(lhsV##iter); \
2030 EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
2031 }
2032
2033#define MICRO_COMPLEX_WORK_ONE4(iter, peel) \
2034 if (unroll_factor > iter) { \
2035 pgerc_common<4, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
2036 }
2037
2038#define MICRO_COMPLEX_WORK_ONE1(iter, peel) \
2039 if (unroll_factor > iter) { \
2040 pgerc_common<1, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
2041 }
2042
2043#define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel) \
2044 if (PEEL_COMPLEX > peel) { \
2045 Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
2046 Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
2047 pbroadcast4_old<Packet>(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
2048 if(!RhsIsReal) { \
2049 pbroadcast4_old<Packet>(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \
2050 } else { \
2051 EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
2052 } \
2053 MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
2054 } else { \
2055 EIGEN_UNUSED_VARIABLE(rhsV##peel); \
2056 EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
2057 }
2058
2059#define MICRO_COMPLEX_TYPE_PEEL1(func, func2, peel) \
2060 if (PEEL_COMPLEX > peel) { \
2061 Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
2062 Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
2063 rhsV##peel[0] = pset1<Packet>(rhs_ptr_real[remaining_cols * peel]); \
2064 if(!RhsIsReal) { \
2065 rhsVi##peel[0] = pset1<Packet>(rhs_ptr_imag[remaining_cols * peel]); \
2066 } else { \
2067 EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
2068 } \
2069 MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
2070 } else { \
2071 EIGEN_UNUSED_VARIABLE(rhsV##peel); \
2072 EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
2073 }
2074
2075#define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \
2076 Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \
2077 Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M], rhsVi4[M], rhsVi5[M], rhsVi6[M], rhsVi7[M], rhsVi8[M], rhsVi9[M]; \
2078 func(func1,func2,0); func(func1,func2,1); \
2079 func(func1,func2,2); func(func1,func2,3); \
2080 func(func1,func2,4); func(func1,func2,5); \
2081 func(func1,func2,6); func(func1,func2,7); \
2082 func(func1,func2,8); func(func1,func2,9);
2083
2084#define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \
2085 Packet rhsV0[M], rhsVi0[M];\
2086 func(func1,func2,0);
2087
2088#define MICRO_COMPLEX_ONE_PEEL4 \
2089 MICRO_COMPLEX_UNROLL_TYPE_PEEL(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \
2090 rhs_ptr_real += (accRows * PEEL_COMPLEX); \
2091 if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX);
2092
2093#define MICRO_COMPLEX_ONE4 \
2094 MICRO_COMPLEX_UNROLL_TYPE_ONE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \
2095 rhs_ptr_real += accRows; \
2096 if(!RhsIsReal) rhs_ptr_imag += accRows;
2097
2098#define MICRO_COMPLEX_ONE_PEEL1 \
2099 MICRO_COMPLEX_UNROLL_TYPE_PEEL(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \
2100 rhs_ptr_real += (remaining_cols * PEEL_COMPLEX); \
2101 if(!RhsIsReal) rhs_ptr_imag += (remaining_cols * PEEL_COMPLEX);
2102
2103#define MICRO_COMPLEX_ONE1 \
2104 MICRO_COMPLEX_UNROLL_TYPE_ONE(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \
2105 rhs_ptr_real += remaining_cols; \
2106 if(!RhsIsReal) rhs_ptr_imag += remaining_cols;
2107
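// The *_PEEL4/*_ONE4 macro variants above feed the full accRows-column
// kernels, while *_PEEL1/*_ONE1 serve the single remaining-column path; both
// reuse MICRO_COMPLEX_LOAD_ONE for the per-block lhs loads.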
2108#define MICRO_COMPLEX_DST_PTR_ONE(iter) \
2109 if (unroll_factor > iter) { \
2110 bsetzero<Scalar, Packet>(accReal##iter); \
2111 bsetzero<Scalar, Packet>(accImag##iter); \
2112 } else { \
2113 EIGEN_UNUSED_VARIABLE(accReal##iter); \
2114 EIGEN_UNUSED_VARIABLE(accImag##iter); \
2115 }
2116
2117#define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE)
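// Illustrative expansion (not generated code): unroll_factor is a
// compile-time constant, so with unroll_factor == 2 MICRO_COMPLEX_DST_PTR
// reduces to
//   bsetzero<Scalar, Packet>(accReal0); bsetzero<Scalar, Packet>(accImag0);
//   bsetzero<Scalar, Packet>(accReal1); bsetzero<Scalar, Packet>(accImag1);
// with accReal2..4/accImag2..4 merely marked as unused.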
2118
2119#define MICRO_COMPLEX_SRC_PTR_ONE(iter) \
2120 if (unroll_factor > iter) { \
2121 lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \
2122 if(!LhsIsReal) { \
2123 lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \
2124 } else { \
2125 EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
2126 } \
2127 } else { \
2128 EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
2129 EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
2130 }
2131
2132#define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)
2133
2134#define MICRO_COMPLEX_PREFETCH_ONE(iter) \
2135 if (unroll_factor > iter) { \
2136 EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
2137 if(!LhsIsReal) { \
2138 EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \
2139 } \
2140 }
2141
2142#define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)
2143
2144#define MICRO_COMPLEX_STORE_ONE(iter) \
2145 if (unroll_factor > iter) { \
2146 bload<DataMapper, Packetc, Index, accColsC, 0, ColMajor>(tRes, res, row + iter*accCols, col); \
2147 bscalec<Packet,4>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \
2148 bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc0, acc1); \
2149 res.template storePacketBlock<Packetc,4>(row + iter*accCols + 0, col, acc0); \
2150 res.template storePacketBlock<Packetc,4>(row + iter*accCols + accColsC, col, acc1); \
2151 }
2152
2153#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)
2154
2155#define MICRO_COMPLEX_COL_STORE_ONE(iter) \
2156 if (unroll_factor > iter) { \
2157 bload<DataMapper, Packetc, Index, accColsC, 0, ColMajor>(tRes, res, row + iter*accCols, col); \
2158 bscalec<Packet,1>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \
2159 bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc0, acc1); \
2160 res.template storePacketBlock<Packetc,1>(row + iter*accCols + 0, col, acc0); \
2161 res.template storePacketBlock<Packetc,1>(row + iter*accCols + accColsC, col, acc1); \
2162 }
2163
2164#define MICRO_COMPLEX_COL_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_COL_STORE_ONE)
2165
2166template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2167EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration(
2168 const DataMapper& res,
2169 const Scalar* lhs_base,
2170 const Scalar* rhs_base,
2171 Index depth,
2172 Index strideA,
2173 Index offsetA,
2174 Index strideB,
2175 Index& row,
2176 Index col,
2177 const Packet& pAlphaReal,
2178 const Packet& pAlphaImag)
2179{
2180 const Scalar* rhs_ptr_real = rhs_base;
2181 const Scalar* rhs_ptr_imag;
2182 if(!RhsIsReal) {
2183 rhs_ptr_imag = rhs_base + accRows*strideB;
2184 } else {
2185 EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
2186 }
2187 const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
2188 const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
2189 const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
2190 PacketBlock<Packet,4> accReal0, accImag0, accReal1, accImag1;
2191 PacketBlock<Packet,4> accReal2, accImag2, accReal3, accImag3;
2192 PacketBlock<Packet,4> accReal4, accImag4;
2193 PacketBlock<Packet,4> taccReal, taccImag;
2194 PacketBlock<Packetc,4> acc0, acc1;
2195 PacketBlock<Packetc,8> tRes;
2196
2197 MICRO_COMPLEX_SRC_PTR
2198 MICRO_COMPLEX_DST_PTR
2199
2200 Index k = 0;
2201 for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX)
2202 {
2203 EIGEN_POWER_PREFETCH(rhs_ptr_real);
2204 if(!RhsIsReal) {
2205 EIGEN_POWER_PREFETCH(rhs_ptr_imag);
2206 }
2207 MICRO_COMPLEX_PREFETCH
2208 MICRO_COMPLEX_ONE_PEEL4
2209 }
2210 for(; k < depth; k++)
2211 {
2212 MICRO_COMPLEX_ONE4
2213 }
2214 MICRO_COMPLEX_STORE
2215
2216 row += unroll_factor*accCols;
2217}
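// Each unrolled complex iteration mirrors the real kernel: set up per-block
// lhs pointers (MICRO_COMPLEX_SRC_PTR), clear the accumulators
// (MICRO_COMPLEX_DST_PTR), run the peeled depth loop, drain the remainder,
// then scale by alpha and store (MICRO_COMPLEX_STORE), advancing row by
// unroll_factor*accCols.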
2218
2219template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2220EIGEN_STRONG_INLINE void gemm_complex_unrolled_col_iteration(
2221 const DataMapper& res,
2222 const Scalar* lhs_base,
2223 const Scalar* rhs_base,
2224 Index depth,
2225 Index strideA,
2226 Index offsetA,
2227 Index strideB,
2228 Index& row,
2229 Index col,
2230 Index remaining_cols,
2231 const Packet& pAlphaReal,
2232 const Packet& pAlphaImag)
2233{
2234 const Scalar* rhs_ptr_real = rhs_base;
2235 const Scalar* rhs_ptr_imag;
2236 if(!RhsIsReal) {
2237 rhs_ptr_imag = rhs_base + remaining_cols*strideB;
2238 } else {
2239 EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
2240 }
2241 const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
2242 const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
2243 const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
2244 PacketBlock<Packet,1> accReal0, accImag0, accReal1, accImag1;
2245 PacketBlock<Packet,1> accReal2, accImag2, accReal3, accImag3;
2246 PacketBlock<Packet,1> accReal4, accImag4;
2247 PacketBlock<Packet,1> taccReal, taccImag;
2248 PacketBlock<Packetc,1> acc0, acc1;
2249 PacketBlock<Packetc,2> tRes;
2250
2251 MICRO_COMPLEX_SRC_PTR
2252 MICRO_COMPLEX_DST_PTR
2253
2254 Index k = 0;
2255 for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX)
2256 {
2257 EIGEN_POWER_PREFETCH(rhs_ptr_real);
2258 if(!RhsIsReal) {
2259 EIGEN_POWER_PREFETCH(rhs_ptr_imag);
2260 }
2261 MICRO_COMPLEX_PREFETCH
2262 MICRO_COMPLEX_ONE_PEEL1
2263 }
2264 for(; k < depth; k++)
2265 {
2266 MICRO_COMPLEX_ONE1
2267 }
2268 MICRO_COMPLEX_COL_STORE
2269
2270 row += unroll_factor*accCols;
2271}
2272
2273template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2274EIGEN_STRONG_INLINE void gemm_complex_unrolled_col(
2275 const DataMapper& res,
2276 const Scalar* lhs_base,
2277 const Scalar* rhs_base,
2278 Index depth,
2279 Index strideA,
2280 Index offsetA,
2281 Index strideB,
2282 Index& row,
2283 Index rows,
2284 Index col,
2285 Index remaining_cols,
2286 const Packet& pAlphaReal,
2287 const Packet& pAlphaImag)
2288{
2289#define MAX_COMPLEX_UNROLL 3
2290 while(row + MAX_COMPLEX_UNROLL*accCols <= rows) {
2291 gemm_complex_unrolled_col_iteration<MAX_COMPLEX_UNROLL, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
2292 }
2293 switch( (rows-row)/accCols ) {
2294#if MAX_COMPLEX_UNROLL > 4
2295 case 4:
2296 gemm_complex_unrolled_col_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
2297 break;
2298#endif
2299#if MAX_COMPLEX_UNROLL > 3
2300 case 3:
2301 gemm_complex_unrolled_col_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
2302 break;
2303#endif
2304#if MAX_COMPLEX_UNROLL > 2
2305 case 2:
2306 gemm_complex_unrolled_col_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
2307 break;
2308#endif
2309#if MAX_COMPLEX_UNROLL > 1
2310 case 1:
2311 gemm_complex_unrolled_col_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
2312 break;
2313#endif
2314 default:
2315 break;
2316 }
2317#undef MAX_COMPLEX_UNROLL
2318}
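// Same while+switch unroll dispatch as gemm_unrolled_col, capped at
// MAX_COMPLEX_UNROLL == 3, presumably because each complex block carries
// separate real and imaginary accumulators and thus needs roughly twice the
// vector registers.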
2319
2320template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2321EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
2322{
2323 const Index remaining_rows = rows % accCols;
2324 const Index remaining_cols = cols % accRows;
2325
2326 if( strideA == -1 ) strideA = depth;
2327 if( strideB == -1 ) strideB = depth;
2328
2329 const Packet pAlphaReal = pset1<Packet>(alpha.real());
2330 const Packet pAlphaImag = pset1<Packet>(alpha.imag());
2331 const Packet pMask = bmask<Packet>((const int)(remaining_rows));
2332
2333 const Scalar* blockA = (Scalar *) blockAc;
2334 const Scalar* blockB = (Scalar *) blockBc;
2335
2336 Index col = 0;
2337 for(; col + accRows <= cols; col += accRows)
2338 {
2339 const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
2340 const Scalar* lhs_base = blockA;
2341 Index row = 0;
2342
2343#define MAX_COMPLEX_UNROLL 3
2344 while(row + MAX_COMPLEX_UNROLL*accCols <= rows) {
2345 gemm_complex_unrolled_iteration<MAX_COMPLEX_UNROLL, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
2346 }
2347 switch( (rows-row)/accCols ) {
2348#if MAX_COMPLEX_UNROLL > 4
2349 case 4:
2350 gemm_complex_unrolled_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
2351 break;
2352#endif
2353#if MAX_COMPLEX_UNROLL > 3
2354 case 3:
2355 gemm_complex_unrolled_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
2356 break;
2357#endif
2358#if MAX_COMPLEX_UNROLL > 2
2359 case 2:
2360 gemm_complex_unrolled_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
2361 break;
2362#endif
2363#if MAX_COMPLEX_UNROLL > 1
2364 case 1:
2365 gemm_complex_unrolled_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
2366 break;
2367#endif
2368 default:
2369 break;
2370 }
2371#undef MAX_COMPLEX_UNROLL
2372
2373 if(remaining_rows > 0)
2374 {
2375 gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
2376 }
2377 }
2378
2379 if(remaining_cols > 0)
2380 {
2381 const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB;
2382 const Scalar* lhs_base = blockA;
2383
2384 for(; col < cols; col++)
2385 {
2386 Index row = 0;
2387
2388 gemm_complex_unrolled_col<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag);
2389
2390 if (remaining_rows > 0)
2391 {
2392 gemm_complex_extra_col<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag);
2393 }
2394 rhs_base++;
2395 }
2396 }
2397}
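// Note: blockAc/blockBc arrive as complex scalars but are walked as plain
// Scalar arrays; advanceRows/advanceCols (1 for a real operand, 2 for a
// complex one, per the defines above) account for the doubled panel
// footprint of complex inputs when indexing blockA/blockB.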
2398
2399#undef accColsC
2400#undef advanceCols
2401#undef advanceRows
2402
2403/************************************
2404 * ppc64le template specializations *
2405 * **********************************/
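// The specializations below route Eigen's generic packing entry points into
// the Power-specific kernels: gemm_pack_lhs/gemm_pack_rhs for float, double
// and their std::complex counterparts all forward to dhs_pack (real data) or
// dhs_cpack (complex data) with the matching packet type, storage order and
// lhs/rhs flag.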
2406template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2407struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2408{
2409 void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2410};
2411
2412template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2413void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2414 ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2415{
2416 dhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode, true> pack;
2417 pack(blockA, lhs, depth, rows, stride, offset);
2418}
2419
2420template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2421struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2422{
2423 void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2424};
2425
2426template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2427void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2428 ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2429{
2430 dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, true> pack;
2431 pack(blockA, lhs, depth, rows, stride, offset);
2432}
2433
2434#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
2435template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2436struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2437{
2438 void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2439};
2440
2441template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2442void gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2443 ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2444{
2445 dhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode, false> pack;
2446 pack(blockB, rhs, depth, cols, stride, offset);
2447}
2448
2449template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2450struct gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2451{
2452 void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2453};
2454
2455template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2456void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2457 ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2458{
2459 dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, false> pack;
2460 pack(blockB, rhs, depth, cols, stride, offset);
2461}
2462#endif
2463
2464template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2465struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2466{
2467 void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2468};
2469
2470template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2471void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2472 ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2473{
2474 dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, true> pack;
2475 pack(blockA, lhs, depth, rows, stride, offset);
2476}
2477
2478template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2479struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2480{
2481 void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2482};
2483
2484template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2485void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2486 ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2487{
2488 dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, true> pack;
2489 pack(blockA, lhs, depth, rows, stride, offset);
2490}
2491
2492template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2493struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2494{
2495 void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2496};
2497
2498template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2499void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2500 ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2501{
2502 dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, true> pack;
2503 pack(blockA, lhs, depth, rows, stride, offset);
2504}
2505
2506template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2507struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2508{
2509 void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2510};
2511
2512template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2513void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2514 ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2515{
2516 dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, true> pack;
2517 pack(blockA, lhs, depth, rows, stride, offset);
2518}
2519
2520#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
2521template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2522struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2523{
2524 void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2525};
2526
2527template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2528void gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2529 ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2530{
2531 dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, false> pack;
2532 pack(blockB, rhs, depth, cols, stride, offset);
2533}
2534
2535template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2536struct gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2537{
2538 void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2539};
2540
2541template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2542void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2543 ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2544{
2545 dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, false> pack;
2546 pack(blockB, rhs, depth, cols, stride, offset);
2547}
2548#endif
2549
2550template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2551struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2552{
2553 void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2554};
2555
2556template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2557void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2558 ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2559{
2560 dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, false> pack;
2561 pack(blockB, rhs, depth, cols, stride, offset);
2562}
2563
2564template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2565struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2566{
2567 void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2568};
2569
2570template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2571void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2572 ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2573{
2574 dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, false> pack;
2575 pack(blockB, rhs, depth, cols, stride, offset);
2576}
2577
2578template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2579struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2580{
2581 void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2582};
2583
2584template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2585void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
2586 ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2587{
2588 dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, true> pack;
2589 pack(blockA, lhs, depth, rows, stride, offset);
2590}
2591
2592template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2593struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2594{
2595 void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
2596};
2597
2598template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
2599void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
2600 ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
2601{
2602 dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, true> pack;
2603 pack(blockA, lhs, depth, rows, stride, offset);
2604}
2605
2606template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2607struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2608{
2609 void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2610};
2611
2612template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2613void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
2614 ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2615{
2616 dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, false> pack;
2617 pack(blockB, rhs, depth, cols, stride, offset);
2618}
2619
2620template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2621struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2622{
2623 void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
2624};
2625
2626template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
2627void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
2628 ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
2629{
2630 dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, false> pack;
2631 pack(blockB, rhs, depth, cols, stride, offset);
2632}
2633
2634// ********* gebp specializations *********
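// Each gebp_kernel specialization resolves its GEMM implementation when
// invoked: with EIGEN_ALTIVEC_MMA_ONLY the MMA variant is selected
// unconditionally; otherwise, when MMA support was compiled in,
// __builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma") gates
// the Power10 MMA path at run time, with the VSX kernel as the fallback.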
2635template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2636struct gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2637{
2638 typedef typename quad_traits<float>::vectortype Packet;
2639 typedef typename quad_traits<float>::rhstype RhsPacket;
2640
2641 void operator()(const DataMapper& res, const float* blockA, const float* blockB,
2642 Index rows, Index depth, Index cols, float alpha,
2643 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
2644};
2645
2646template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2647void gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2648 ::operator()(const DataMapper& res, const float* blockA, const float* blockB,
2649 Index rows, Index depth, Index cols, float alpha,
2650 Index strideA, Index strideB, Index offsetA, Index offsetB)
2651 {
2652 const Index accRows = quad_traits<float>::rows;
2653 const Index accCols = quad_traits<float>::size;
2654 void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index);
2655
2656 #ifdef EIGEN_ALTIVEC_MMA_ONLY
2657      // MMA-only build: always select the MMA kernel
2658 gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2659 #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
2660 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
2661 gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2662 }
2663 else{
2664 gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2665 }
2666 #else
2667 gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2668 #endif
2669 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2670 }
2671
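// The remaining specializations repeat this dispatch pattern, varying only
// the scalar types and the ConjugateLhs/ConjugateRhs/LhsIsReal/RhsIsReal
// flags. Illustrative usage (assumption: any GEMM-sized product on Power
// reaches these kernels through Eigen's public API):
//
//   #include <Eigen/Dense>
//   Eigen::MatrixXf a = Eigen::MatrixXf::Random(64, 64);
//   Eigen::MatrixXf b = Eigen::MatrixXf::Random(64, 64);
//   Eigen::MatrixXf c = a * b; // dispatches into gebp_kernel<float, float, ...>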
2672template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2673struct gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2674{
2675 typedef Packet4f Packet;
2676 typedef Packet2cf Packetc;
2677 typedef Packet4f RhsPacket;
2678
2679 void operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
2680 Index rows, Index depth, Index cols, std::complex<float> alpha,
2681 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
2682};
2683
2684template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2685void gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2686 ::operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
2687 Index rows, Index depth, Index cols, std::complex<float> alpha,
2688 Index strideA, Index strideB, Index offsetA, Index offsetB)
2689 {
2690 const Index accRows = quad_traits<float>::rows;
2691 const Index accCols = quad_traits<float>::size;
2692 void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*,
2693 Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
2694
2695 #ifdef EIGEN_ALTIVEC_MMA_ONLY
2696      // MMA-only build: always select the MMA kernel
2697 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2698 #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
2699 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
2700 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2701 }
2702 else{
2703 gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2704 }
2705 #else
2706 gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2707 #endif
2708 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2709 }
2710
2711template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2712struct gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2713{
2714 typedef Packet4f Packet;
2715 typedef Packet2cf Packetc;
2716 typedef Packet4f RhsPacket;
2717
2718 void operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB,
2719 Index rows, Index depth, Index cols, std::complex<float> alpha,
2720 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
2721};
2722
2723template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2724void gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2725 ::operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB,
2726 Index rows, Index depth, Index cols, std::complex<float> alpha,
2727 Index strideA, Index strideB, Index offsetA, Index offsetB)
2728 {
2729 const Index accRows = quad_traits<float>::rows;
2730 const Index accCols = quad_traits<float>::size;
2731 void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*,
2732 Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
2733 #ifdef EIGEN_ALTIVEC_MMA_ONLY
2734      // MMA-only build: always select the MMA kernel
2735 gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2736 #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
2737 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
2738 gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2739 }
2740 else{
2741 gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2742 }
2743 #else
2744 gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2745 #endif
2746 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2747 }
2748
2749template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2750struct gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2751{
2752 typedef Packet4f Packet;
2753 typedef Packet2cf Packetc;
2754 typedef Packet4f RhsPacket;
2755
2756 void operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB,
2757 Index rows, Index depth, Index cols, std::complex<float> alpha,
2758 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
2759};
2760
2761template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2762void gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2763 ::operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB,
2764 Index rows, Index depth, Index cols, std::complex<float> alpha,
2765 Index strideA, Index strideB, Index offsetA, Index offsetB)
2766 {
2767 const Index accRows = quad_traits<float>::rows;
2768 const Index accCols = quad_traits<float>::size;
2769 void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*,
2770 Index, Index, Index, std::complex<float>, Index, Index, Index, Index);
2771 #ifdef EIGEN_ALTIVEC_MMA_ONLY
2772      // MMA-only build: always select the MMA kernel
2773 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2774 #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
2775 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
2776 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2777 }
2778 else{
2779 gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2780 }
2781 #else
2782 gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2783 #endif
2784 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2785 }
2786
2787template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2788struct gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2789{
2790 typedef typename quad_traits<double>::vectortype Packet;
2791 typedef typename quad_traits<double>::rhstype RhsPacket;
2792
2793 void operator()(const DataMapper& res, const double* blockA, const double* blockB,
2794 Index rows, Index depth, Index cols, double alpha,
2795 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
2796};
2797
2798template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2799void gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2800 ::operator()(const DataMapper& res, const double* blockA, const double* blockB,
2801 Index rows, Index depth, Index cols, double alpha,
2802 Index strideA, Index strideB, Index offsetA, Index offsetB)
2803 {
2804 const Index accRows = quad_traits<double>::rows;
2805 const Index accCols = quad_traits<double>::size;
2806 void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index);
2807
2808 #ifdef EIGEN_ALTIVEC_MMA_ONLY
2809      // MMA-only build: always select the MMA kernel
2810 gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2811 #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
2812 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){
2813 gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2814 }
2815 else{
2816 gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2817 }
2818 #else
2819 gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>;
2820 #endif
2821 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2822 }
2823
2824template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2825struct gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2826{
2827 typedef quad_traits<double>::vectortype Packet;
2828 typedef Packet1cd Packetc;
2829 typedef quad_traits<double>::rhstype RhsPacket;
2830
2831 void operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
2832 Index rows, Index depth, Index cols, std::complex<double> alpha,
2833 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
2834};
2835
2836template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2837void gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2838 ::operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
2839 Index rows, Index depth, Index cols, std::complex<double> alpha,
2840 Index strideA, Index strideB, Index offsetA, Index offsetB)
2841 {
2842 const Index accRows = quad_traits<double>::rows;
2843 const Index accCols = quad_traits<double>::size;
2844 void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*,
2845 Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
2846 #ifdef EIGEN_ALTIVEC_MMA_ONLY
2847      // MMA-only build: always select the MMA kernel
2848 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2849 #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
2850    if (__builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma")) {
2851 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2852 }
2853    else {
2854 gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2855 }
2856 #else
2857 gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>;
2858 #endif
2859 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2860 }
2861
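// The two trailing boolean template arguments of gemm_complex/gemm_complexMMA
// flag whether the packed lhs/rhs blocks hold real-only data; both operands
// are complex here, hence <..., false, false>.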
2862template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2863struct gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2864{
2865 typedef quad_traits<double>::vectortype Packet;
2866 typedef Packet1cd Packetc;
2867 typedef quad_traits<double>::rhstype RhsPacket;
2868
2869 void operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB,
2870 Index rows, Index depth, Index cols, std::complex<double> alpha,
2871 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
2872};
2873
2874template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2875void gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2876 ::operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB,
2877 Index rows, Index depth, Index cols, std::complex<double> alpha,
2878 Index strideA, Index strideB, Index offsetA, Index offsetB)
2879 {
2880 const Index accRows = quad_traits<double>::rows;
2881 const Index accCols = quad_traits<double>::size;
2882 void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*,
2883 Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
2884 #ifdef EIGEN_ALTIVEC_MMA_ONLY
2885    // generate with MMA only
2886 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2887 #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
2888    if (__builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma")) {
2889 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2890 }
2891    else {
2892 gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2893 }
2894 #else
2895 gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>;
2896 #endif
2897 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2898 }
2899
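// Mixed case: the rhs block holds plain doubles, so only the trailing
// rhs-is-real flag is set, i.e. <..., false, true>.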
2900template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2901struct gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2902{
2903 typedef quad_traits<double>::vectortype Packet;
2904 typedef Packet1cd Packetc;
2905 typedef quad_traits<double>::rhstype RhsPacket;
2906
2907 void operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB,
2908 Index rows, Index depth, Index cols, std::complex<double> alpha,
2909 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0);
2910};
2911
2912template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
2913void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
2914 ::operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB,
2915 Index rows, Index depth, Index cols, std::complex<double> alpha,
2916 Index strideA, Index strideB, Index offsetA, Index offsetB)
2917 {
2918 const Index accRows = quad_traits<double>::rows;
2919 const Index accCols = quad_traits<double>::size;
2920 void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*,
2921 Index, Index, Index, std::complex<double>, Index, Index, Index, Index);
2922 #ifdef EIGEN_ALTIVEC_MMA_ONLY
2923    // generate with MMA only
2924 gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2925 #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
2926    if (__builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma")) {
2927 gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2928 }
2929    else {
2930 gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2931 }
2932 #else
2933 gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>;
2934 #endif
2935 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
2936 }
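// Mirror of the previous case: the lhs block holds plain doubles, so the
// lhs-is-real flag is set instead, i.e. <..., true, false>.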
2937} // end namespace internal
2938
2939} // end namespace Eigen
2940
2941#endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H
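For illustration only (not part of the header): a minimal standalone sketch of
the same runtime-dispatch idea, assuming a GCC-style compiler on POWER where
__builtin_cpu_supports is usable (advertised via the __BUILTIN_CPU_SUPPORTS__
macro); the kernel names here are hypothetical.

#include <cstdio>

// Two interchangeable kernels with identical signatures.
static void gemm_vsx(int n) { std::printf("VSX path, n=%d\n", n); }
static void gemm_mma(int n) { std::printf("MMA path, n=%d\n", n); }

void run_gemm(int n)
{
#ifdef __BUILTIN_CPU_SUPPORTS__
  // Probe the CPU once and pick the best available kernel, as the
  // gebp_kernel specializations above do.
  void (*kernel)(int) = __builtin_cpu_supports("mma") ? &gemm_mma : &gemm_vsx;
#else
  void (*kernel)(int) = &gemm_vsx; // conservative fallback
#endif
  kernel(n);
}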