@@ -125,7 +125,7 @@ def _vectorized(
125125 raise TypeError ("allow_core_scalar must be literal." )
126126 allow_core_scalar = allow_core_scalar .literal_value
127127
128- batch_ndim = len (input_bc_patterns [0 ])
128+ batch_ndim = len (output_bc_patterns [0 ])
129129 nin = len (constant_inputs_types ) + len (input_types )
130130 nout = len (output_bc_patterns )
131131
@@ -138,13 +138,6 @@ def _vectorized(
138138 if not all (isinstance (input , types .Array ) for input in input_types ):
139139 raise TypingError ("Vectorized inputs must be arrays." )
140140
141- if not all (
142- len (pattern ) == batch_ndim for pattern in input_bc_patterns + output_bc_patterns
143- ):
144- raise TypingError (
145- "Vectorized broadcastable patterns must have the same length."
146- )
147-
148141 core_input_types = []
149142 for input_type , bc_pattern in zip (input_types , input_bc_patterns , strict = True ):
150143 core_ndim = input_type .ndim - len (bc_pattern )
@@ -291,16 +284,21 @@ def compute_itershape(
291284 size : list [ir .Instruction ] | None ,
292285):
293286 one = ir .IntType (64 )(1 )
294- batch_ndim = len (broadcast_pattern [ 0 ] )
287+ batch_ndim = max (( len (p ) for p in broadcast_pattern ), default = 0 )
295288 shape = [None ] * batch_ndim
296289 if size is not None :
297290 shape = size
298291 for i in range (batch_ndim ):
299292 for j , (bc , in_shape ) in enumerate (
300293 zip (broadcast_pattern , in_shapes , strict = True )
301294 ):
302- length = in_shape [i ]
303- if bc [i ]:
295+ # Offset for inputs with fewer dims than batch_ndim
296+ offset = batch_ndim - len (bc )
297+ if i < offset :
298+ # Implicit broadcast dim — no array dim to check
299+ continue
300+ length = in_shape [i - offset ]
301+ if bc [i - offset ]:
304302 with builder .if_then (
305303 builder .icmp_unsigned ("!=" , length , one ), likely = False
306304 ):
@@ -336,8 +334,11 @@ def compute_itershape(
336334 for j , (bc , in_shape ) in enumerate (
337335 zip (broadcast_pattern , in_shapes , strict = True )
338336 ):
339- length = in_shape [i ]
340- if bc [i ]:
337+ offset = batch_ndim - len (bc )
338+ if i < offset :
339+ continue
340+ length = in_shape [i - offset ]
341+ if bc [i - offset ]:
341342 with builder .if_then (
342343 builder .icmp_unsigned ("!=" , length , one ), likely = False
343344 ):
@@ -452,6 +453,7 @@ def make_loop_call(
452453 # output_scope_set = mod.add_metadata([input_scope, output_scope])
453454
454455 zero = ir .Constant (ir .IntType (64 ), 0 )
456+ batch_ndim = len (iter_shape )
455457
456458 # Setup loops and initialize accumulators for outputs
457459 # This part corresponds to opening the loops
@@ -480,9 +482,12 @@ def make_loop_call(
480482 for input , input_type , bc in zip (inputs , input_types , input_bc , strict = True ):
481483 core_ndim = input_type .ndim - len (bc )
482484
483- idxs_bc = [zero if bc else idx for idx , bc in zip (idxs , bc , strict = True )] + [
484- zero
485- ] * core_ndim
485+ # For inputs with fewer batch dims than the loop, skip leading loop indices
486+ offset = batch_ndim - len (bc )
487+ idxs_bc = [
488+ zero if bc_dim else idx
489+ for idx , bc_dim in zip (idxs [offset :], bc , strict = True )
490+ ] + [zero ] * core_ndim
486491 ptr = cgutils .get_item_pointer2 (
487492 context ,
488493 builder ,
0 commit comments