Skip to content

Commit 08df88f

Browse files
committed
feat: mask support in kronecker product for intermediate result matrix
1 parent 2c419fb commit 08df88f

File tree

5 files changed

+442
-6
lines changed

5 files changed

+442
-6
lines changed

Source/kronecker/GB_kron.c

Lines changed: 354 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,68 @@
2323
GB_Matrix_free (&T) ; \
2424
}
2525

26+
#define GBI(Ai,p,avlen) ((Ai == NULL) ? ((p) % (avlen)) : Ai [p])
27+
28+
#define GBB(Ab,p) ((Ab == NULL) ? 1 : Ab [p])
29+
30+
#define GBP(Ap,k,avlen) ((Ap == NULL) ? ((k) * (avlen)) : Ap [k])
31+
32+
#define GBH(Ah,k) ((Ah == NULL) ? (k) : Ah [k])
33+
2634
#include "kronecker/GB_kron.h"
2735
#include "mxm/GB_mxm.h"
2836
#include "transpose/GB_transpose.h"
2937
#include "mask/GB_accum_mask.h"
3038

39+
static bool GB_lookup_xoffset (
40+
GrB_Index* p,
41+
GrB_Matrix A,
42+
GrB_Index row,
43+
GrB_Index col
44+
)
45+
{
46+
GrB_Index vector = A->is_csc ? col : row ;
47+
GrB_Index coord = A->is_csc ? row : col ;
48+
49+
if (A->p == NULL) {
50+
GrB_Index offset = vector * A->vlen + coord ;
51+
if (A->b == NULL || ((int8_t*)A->b)[offset]) {
52+
*p = A->iso ? 0 : offset ;
53+
return true ;
54+
}
55+
return false ;
56+
}
57+
58+
int64_t start, end ;
59+
bool res ;
60+
61+
if (A->h == NULL) {
62+
start = A->p_is_32 ? ((uint32_t*)A->p)[vector] : ((uint64_t*)A->p)[vector] ;
63+
end = A->p_is_32 ? ((uint32_t*)A->p)[vector + 1] : ((uint64_t*)A->p)[vector + 1] ;
64+
end-- ;
65+
if (start > end) return false ;
66+
res = GB_binary_search(coord, A->i, A->i_is_32, &start, &end) ;
67+
if (res) { *p = A->iso ? 0 : start ; }
68+
return res ;
69+
}
70+
else
71+
{
72+
start = 0 ; end = A->plen - 1 ;
73+
res = GB_binary_search(vector, A->h, A->j_is_32, &start, &end) ;
74+
if (!res) return false ;
75+
int64_t k = start ;
76+
start = A->p_is_32 ? ((uint32_t*)A->p)[k] : ((uint64_t*)A->p)[k] ;
77+
end = A->p_is_32 ? ((uint32_t*)A->p)[k+1] : ((uint64_t*)A->p)[k+1] ;
78+
end-- ;
79+
if (start > end) return false ;
80+
res = GB_binary_search(coord, A->i, A->i_is_32, &start, &end) ;
81+
if (res) { *p = A->iso ? 0 : start ; }
82+
return res ;
83+
}
84+
}
85+
86+
#include "emult/GB_emult.h"
87+
3188
GrB_Info GB_kron // C<M> = accum (C, kron(A,B))
3289
(
3390
GrB_Matrix C, // input/output matrix for results
@@ -104,6 +161,302 @@ GrB_Info GB_kron // C<M> = accum (C, kron(A,B))
104161
// quick return if an empty mask is complemented
105162
GB_RETURN_IF_QUICK_MASK (C, C_replace, M, Mask_comp, Mask_struct) ;
106163

164+
// check if it's possible to apply mask immediately in kron
165+
// TODO: make MT of same CSR/CSC format as C
166+
167+
GrB_Matrix MT;
168+
if (M != NULL && !Mask_comp) {
169+
170+
// iterate over mask, count how many elements will be present in MT
171+
// initialize MT->p
172+
173+
GB_MATRIX_WAIT(M);
174+
175+
size_t allocated = 0 ;
176+
bool MT_hypersparse = (A->h != NULL) || (B->h != NULL);
177+
int64_t centries ;
178+
uint64_t nvecs ;
179+
centries = 0 ;
180+
nvecs = 0 ;
181+
182+
uint32_t* MTp32 = NULL ; uint64_t* MTp64 = NULL ;
183+
MTp32 = M->p_is_32 ? GB_calloc_memory (M->vdim + 1, sizeof(uint32_t), &allocated) : NULL ;
184+
MTp64 = M->p_is_32 ? NULL : GB_calloc_memory (M->vdim + 1, sizeof(uint64_t), &allocated) ;
185+
if (MTp32 == NULL && MTp64 == NULL)
186+
{
187+
OUT_OF_MEM_p:
188+
GB_FREE_WORKSPACE ;
189+
return GrB_OUT_OF_MEMORY ;
190+
}
191+
192+
GrB_Type MTtype = op->ztype ;
193+
const size_t MTsize = MTtype->size ;
194+
GB_void MTscalar [GB_VLA(MTsize)] ;
195+
bool MTiso = GB_emult_iso (MTscalar, MTtype, A, B, op) ;
196+
197+
GB_Mp_DECLARE(Mp, ) ;
198+
GB_Mp_PTR(Mp, M) ;
199+
200+
GB_Mh_DECLARE(Mh, ) ;
201+
GB_Mh_PTR(Mh, M) ;
202+
203+
GB_Mi_DECLARE(Mi, ) ;
204+
GB_Mi_PTR(Mi, M) ;
205+
206+
GB_cast_function cast_A = NULL ;
207+
GB_cast_function cast_B = NULL ;
208+
209+
cast_A = GB_cast_factory (op->xtype->code, A->type->code) ;
210+
cast_B = GB_cast_factory (op->ytype->code, B->type->code) ;
211+
212+
int64_t vlen = M->vlen ;
213+
#pragma omp parallel
214+
{
215+
GrB_Index offset ;
216+
217+
#pragma omp for reduction(+:nvecs)
218+
for (GrB_Index k = 0 ; k < M->nvec ; k++)
219+
{
220+
GrB_Index j = Mh32 ? GBH (Mh32, k) : GBH (Mh64, k) ;
221+
222+
int64_t pA_start = Mp32 ? GBP (Mp32, k, vlen) : GBP(Mp64, k, vlen) ;
223+
int64_t pA_end = Mp32 ? GBP (Mp32, k+1, vlen) : GBP(Mp64, k+1, vlen) ;
224+
bool nonempty = false ;
225+
for (GrB_Index p = pA_start ; p < pA_end ; p++)
226+
{
227+
if (!GBB (M->b, p)) continue ;
228+
229+
int64_t i = Mi32 ? GBI (Mi32, p, vlen) : GBI (Mi64, p, vlen) ;
230+
GrB_Index Mrow = M->is_csc ? i : j ; GrB_Index Mcol = M->is_csc ? j : i ;
231+
232+
// extract elements from A and B, increment MTp
233+
234+
if (Mask_struct || (M->iso ? ((int8_t*)M->x)[0] : ((int8_t*)M->x)[p])) {
235+
236+
GrB_Index arow = A_transpose ? (Mcol / bncols) : (Mrow / bnrows);
237+
GrB_Index acol = A_transpose ? (Mrow / bnrows) : (Mcol / bncols);
238+
239+
GrB_Index brow = B_transpose ? (Mcol % bncols) : (Mrow % bnrows);
240+
GrB_Index bcol = B_transpose ? (Mrow % bnrows) : (Mcol % bncols);
241+
242+
bool code = GB_lookup_xoffset(&offset, A, arow, acol) ;
243+
if (!code) {
244+
continue;
245+
}
246+
247+
code = GB_lookup_xoffset(&offset, B, brow, bcol) ;
248+
if (!code) {
249+
continue;
250+
}
251+
252+
if (M->p_is_32)
253+
{
254+
(MTp32[j])++ ;
255+
}
256+
else
257+
{
258+
(MTp64[j])++ ;
259+
}
260+
nonempty = true ;
261+
}
262+
}
263+
if (nonempty) nvecs++ ;
264+
}
265+
}
266+
267+
// GB_cumsum for MT->p
268+
269+
double work = M->vdim ;
270+
int nthreads_max = GB_Context_nthreads_max ( ) ;
271+
double chunk = GB_Context_chunk ( ) ;
272+
int cumsum_threads = GB_nthreads (work, chunk, nthreads_max) ;
273+
M->p_is_32 ? GB_cumsum(MTp32, M->p_is_32, M->vdim, NULL, cumsum_threads, Werk) :
274+
GB_cumsum(MTp64, M->p_is_32, M->vdim, NULL, cumsum_threads, Werk) ;
275+
276+
centries = M->p_is_32 ? MTp32[M->vdim] : MTp64[M->vdim] ;
277+
278+
uint32_t* MTi32 = NULL ; uint64_t* MTi64 = NULL;
279+
MTi32 = M->i_is_32 ? GB_malloc_memory (centries, sizeof(uint32_t), &allocated) : NULL ;
280+
MTi64 = M->i_is_32 ? NULL : GB_malloc_memory (centries, sizeof(uint64_t), &allocated) ;
281+
282+
if (centries > 0 && MTi32 == NULL && MTi64 == NULL) {
283+
OUT_OF_MEM_i:
284+
if (M->p_is_32) { GB_free_memory (&MTp32, (M->vdim + 1) * sizeof(uint32_t)) ; }
285+
else { GB_free_memory (&MTp64, (M->vdim + 1) * sizeof(uint64_t)) ; }
286+
goto OUT_OF_MEM_p ;
287+
}
288+
289+
void* MTx = NULL ;
290+
if (!MTiso)
291+
{
292+
MTx = GB_malloc_memory (centries, op->ztype->size, &allocated) ;
293+
}
294+
else
295+
{
296+
MTx = GB_malloc_memory (1, op->ztype->size, &allocated) ;
297+
if (MTx == NULL) goto OUT_OF_MEM_x ;
298+
memcpy (MTx, MTscalar, MTsize) ;
299+
}
300+
301+
if (centries > 0 && MTx == NULL)
302+
{
303+
OUT_OF_MEM_x:
304+
if (M->i_is_32) { GB_free_memory (&MTi32, centries * sizeof(uint32_t)) ; }
305+
else { GB_free_memory (&MTi64, centries * sizeof (uint64_t)) ; }
306+
goto OUT_OF_MEM_i ;
307+
}
308+
309+
#pragma omp parallel
310+
{
311+
GrB_Index offset ;
312+
GB_void a_elem[op->xtype->size] ;
313+
GB_void b_elem[op->ytype->size] ;
314+
315+
#pragma omp for
316+
for (GrB_Index k = 0 ; k < M->nvec ; k++)
317+
{
318+
GrB_Index j = Mh32 ? GBH (Mh32, k) : GBH (Mh64, k) ;
319+
320+
int64_t pA_start = Mp32 ? GBP (Mp32, k, vlen) : GBP(Mp64, k, vlen) ;
321+
int64_t pA_end = Mp32 ? GBP (Mp32, k+1, vlen) : GBP(Mp64, k+1, vlen) ;
322+
GrB_Index pos = M->p_is_32 ? MTp32[j] : MTp64[j] ;
323+
for (GrB_Index p = pA_start ; p < pA_end ; p++)
324+
{
325+
if (!GBB (M->b, p)) continue ;
326+
327+
int64_t i = Mi32 ? GBI (Mi32, p, vlen) : GBI (Mi64, p, vlen) ;
328+
GrB_Index Mrow = M->is_csc ? i : j ; GrB_Index Mcol = M->is_csc ? j : i ;
329+
330+
// extract elements from A and B,
331+
// initialize offset in MTi and MTx,
332+
// get result of op, place it in MTx
333+
334+
if (Mask_struct || (M->iso ? ((int8_t*)M->x)[0] : ((int8_t*)M->x)[p])) {
335+
336+
GrB_Index arow = A_transpose ? (Mcol / bncols) : (Mrow / bnrows);
337+
GrB_Index acol = A_transpose ? (Mrow / bnrows) : (Mcol / bncols);
338+
339+
GrB_Index brow = B_transpose ? (Mcol % bncols) : (Mrow % bnrows);
340+
GrB_Index bcol = B_transpose ? (Mrow % bnrows) : (Mcol % bncols);
341+
342+
bool code = GB_lookup_xoffset (&offset, A, arow, acol) ;
343+
if (!code) {
344+
continue;
345+
}
346+
if (!MTiso)
347+
cast_A (a_elem, A->x + offset * A->type->size, A->type->size) ;
348+
349+
code = GB_lookup_xoffset (&offset, B, brow, bcol) ;
350+
if (!code) {
351+
continue;
352+
}
353+
if (!MTiso)
354+
cast_B (b_elem, B->x + offset * B->type->size, B->type->size) ;
355+
356+
if (!MTiso)
357+
{
358+
if (op->binop_function) {
359+
op->binop_function (MTx + op->ztype->size * pos, a_elem, b_elem) ;
360+
}
361+
else {
362+
GrB_Index ix, iy, jx, jy ;
363+
ix = A_transpose ? acol : arow ;
364+
iy = A_transpose ? arow : acol ;
365+
jx = B_transpose ? bcol : brow ;
366+
jy = B_transpose ? brow : bcol ;
367+
op->idxbinop_function (MTx + op->ztype->size * pos, a_elem, ix, iy,
368+
b_elem, jx, jy, op->theta) ;
369+
}
370+
}
371+
372+
if (M->i_is_32) { MTi32[pos] = i ; } else { MTi64[pos] = i ; }
373+
pos++ ;
374+
}
375+
}
376+
}
377+
}
378+
379+
#undef GBI
380+
#undef GBB
381+
#undef GBP
382+
#undef GBH
383+
384+
// initialize other fields of MT properly
385+
386+
MT = NULL ;
387+
GrB_Info MTalloc = GB_new_bix (&MT, op->ztype, vlen, M->vdim, GB_ph_null, M->is_csc,
388+
GxB_SPARSE, true, M->hyper_switch, M->vdim, centries, true, MTiso,
389+
M->p_is_32, M->j_is_32, M->i_is_32) ;
390+
if (MTalloc != GrB_SUCCESS) {
391+
if (MTiso) { GB_free_memory (&MTx, op->ztype->size) ; }
392+
else { GB_free_memory (&MTx, centries * op->ztype->size) ; }
393+
goto OUT_OF_MEM_x ;
394+
}
395+
396+
GB_MATRIX_WAIT(MT) ;
397+
398+
GB_free_memory (&MT->i, MT->i_size) ;
399+
GB_free_memory (&MT->x, MT->x_size) ;
400+
401+
MT->p = M->p_is_32 ? (void*)MTp32 : (void*)MTp64 ;
402+
MT->i = M->i_is_32 ? (void*)MTi32 : (void*)MTi64 ;
403+
MT->x = MTx ;
404+
405+
MT->p_size = (M->p_is_32 ? sizeof(uint32_t) : sizeof(uint64_t)) * (M->vdim + 1) ;
406+
MT->i_size = ((M->i_is_32 ? sizeof(uint32_t) : sizeof(uint64_t)) * centries) ;
407+
MT->x_size = MT->iso ? op->ztype->size : op->ztype->size * centries ;
408+
MT->magic = GB_MAGIC ;
409+
MT->nvals = centries ;
410+
MT->nvec_nonempty = nvecs ;
411+
412+
// transpose and convert to hyper if needed
413+
414+
if (MT->is_csc != C->is_csc)
415+
{
416+
GrB_Info MTtranspose = GB_transpose_in_place (MT, true, Werk) ;
417+
if (MTtranspose != GrB_SUCCESS)
418+
{
419+
GB_FREE_WORKSPACE ;
420+
GB_Matrix_free (&MT) ;
421+
return MTtranspose ;
422+
}
423+
}
424+
425+
if (MT_hypersparse) {
426+
uint32_t* MTh32 = NULL ; uint64_t* MTh64 = NULL ;
427+
if (MT->j_is_32) {
428+
MTh32 = GB_malloc_memory (MT->vdim, sizeof(uint32_t), &allocated) ;
429+
}
430+
else {
431+
MTh64 = GB_malloc_memory (MT->vdim, sizeof(uint64_t), &allocated) ;
432+
}
433+
434+
if (MTh32 == NULL && MTh64 == NULL)
435+
{
436+
GB_FREE_WORKSPACE ;
437+
GB_Matrix_free (&MT) ;
438+
return GrB_OUT_OF_MEMORY ;
439+
}
440+
441+
#pragma omp parallel for
442+
for (GrB_Index i = 0; i < MT->vdim; i++) {
443+
if (MT->j_is_32) { MTh32[i] = i ; } else { MTh64[i] = i ; }
444+
}
445+
446+
MT->h = MTh32 ? (void*)MTh32 : (void*)MTh64 ;
447+
448+
GrB_Info MThyperprune = GB_hyper_prune (MT, Werk) ;
449+
if (MThyperprune != GrB_SUCCESS) {
450+
GB_FREE_WORKSPACE ;
451+
GB_Matrix_free (&MT) ;
452+
return MThyperprune ;
453+
}
454+
}
455+
456+
return (GB_accum_mask (C, M, NULL, accum, &MT, C_replace, Mask_comp, Mask_struct, Werk)) ;
457+
}
458+
459+
107460
//--------------------------------------------------------------------------
108461
// transpose A and B if requested
109462
//--------------------------------------------------------------------------
@@ -153,7 +506,7 @@ GrB_Info GB_kron // C<M> = accum (C, kron(A,B))
153506
GB_CLEAR_MATRIX_HEADER (T, &T_header) ;
154507
GB_OK (GB_kroner (T, T_is_csc, op, flipij,
155508
A_transpose ? AT : A, A_is_pattern,
156-
B_transpose ? BT : B, B_is_pattern, Werk)) ;
509+
B_transpose ? BT : B, B_is_pattern, M, Mask_comp, Mask_struct, Werk)) ;
157510

158511
GB_FREE_WORKSPACE ;
159512
ASSERT_MATRIX_OK (T, "T = kron(A,B)", GB0) ;

Source/kronecker/GB_kron.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ GrB_Info GB_kroner // C = kron (A,B)
3737
bool A_is_pattern, // true if values of A are not used
3838
const GrB_Matrix B, // input matrix
3939
bool B_is_pattern, // true if values of B are not used
40+
const GrB_Matrix Mask,
41+
const bool Mask_comp,
42+
const bool Mask_struct,
4043
GB_Werk Werk
4144
) ;
4245

0 commit comments

Comments
 (0)