Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 160 additions & 12 deletions src/NumSharp.Core/Logic/np.find_common_type.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,82 @@

namespace NumSharp
{
// ================================================================================
// TYPE PROMOTION SYSTEM
// ================================================================================
//
// This file implements NumPy-compatible type promotion for arithmetic operations.
// When two arrays (or an array and a scalar) are combined, this system determines
// the result dtype.
//
// ARCHITECTURE
// ============
//
// Four lookup tables are used (two pairs for Type and NPTypeCode access):
//
// _typemap_arr_arr / _nptypemap_arr_arr - Array + Array promotion
// _typemap_arr_scalar / _nptypemap_arr_scalar - Array + Scalar promotion
//
// The tables are FrozenDictionary<(T1, T2), TResult> for O(1) lookup.
//
// WHEN EACH TABLE IS USED
// =======================
//
// The _FindCommonType(NDArray, NDArray) method decides which table to use:
//
// if (both are non-scalar arrays) → _typemap_arr_arr
// if (both are scalar arrays) → _FindCommonScalarType (uses arr_arr rules)
// if (one is array, one is scalar) → _typemap_arr_scalar
//
// This matters because scalar promotion follows different rules than array promotion.
//
// KIND HIERARCHY
// ==============
//
// Types are grouped into "kinds" with a promotion hierarchy:
//
// boolean < integer < floating-point < complex
//
// When operands are of different kinds, the result promotes to the higher kind:
//
// int32 + float32 → float64 (int promotes to float)
// float32 + complex → complex (float promotes to complex)
//
// WITHIN-KIND PROMOTION
// =====================
//
// When operands are the same kind, promotion depends on the operation type:
//
// Array + Array (both non-scalar):
// - Result is the "larger" type that can hold both ranges
// - uint8 + int16 → int16 (int16 can hold uint8 range + negatives)
// - uint32 + int32 → int64 (need 64-bit to hold both ranges)
// - uint64 + int64 → float64 (no integer type can hold both!)
//
// Array + Scalar (NEP 50 behavior):
// - Array dtype wins when scalar is same-kind (e.g., both integers)
// - uint8_array + int32_scalar → uint8 (array wins)
// - float32_array + int32_scalar → float32 (array wins, same effective kind)
//
// EXAMPLES
// ========
//
// var a = np.array(new byte[] {1, 2, 3}); // uint8
// var b = np.array(new int[] {4, 5, 6}); // int32
//
// (a + b).dtype == np.int32 // arr+arr: promotes to int32
// (a + 5).dtype == np.uint8 // arr+scalar: array wins (NEP 50)
// (a + 5.0).dtype == np.float64 // cross-kind: float wins
//
// REFERENCES
// ==========
//
// - NumPy type promotion: https://numpy.org/doc/stable/reference/ufuncs.html#type-casting-rules
// - NEP 50 (scalar promotion): https://numpy.org/neps/nep-0050-scalar-promotion.html
// - Array API type promotion: https://data-apis.org/array-api/latest/API_specification/type_promotion.html
//
// ================================================================================

[SuppressMessage("ReSharper", "StaticMemberInitializerReferesToMemberBelow")]
public static partial class np
{
Expand Down Expand Up @@ -50,6 +126,39 @@ static np()
{
#region arr_arr

// ============================================================================
// ARRAY-ARRAY TYPE PROMOTION TABLE
// ============================================================================
//
// This table defines type promotion when TWO ARRAYS are combined.
// The key is (LeftArrayType, RightArrayType), the value is the result type.
//
// PROMOTION RULES:
//
// 1. Same type: result is that type
// int32 + int32 → int32
//
// 2. Same kind, different size: result is larger type
// int16 + int32 → int32
// float32 + float64 → float64
//
// 3. Signed + Unsigned (same size): result is next-larger signed type
// int16 + uint16 → int32 (need more bits for both ranges)
// int32 + uint32 → int64
// int64 + uint64 → float64 (no larger integer exists!)
//
// 4. Cross-kind: result is the higher kind
// int32 + float32 → float64 (int32 needs float64 precision)
// uint8 + float32 → float32 (uint8 fits in float32)
//
// 5. Complex: absorbs everything
// float32 + complex64 → complex64
// int32 + complex64 → complex128 (int32 needs float64 precision)
//
// This table matches NumPy 2.x arr+arr behavior exactly.
//
// ============================================================================

var typemap_arr_arr = new Dictionary<(Type, Type), Type>(180);
typemap_arr_arr.Add((np.@bool, np.@bool), np.@bool);
typemap_arr_arr.Add((np.@bool, np.uint8), np.uint8);
Expand Down Expand Up @@ -243,6 +352,45 @@ static np()

#region arr_scalar

// ============================================================================
// ARRAY-SCALAR TYPE PROMOTION TABLE
// ============================================================================
//
// This table defines type promotion when an array operates with a scalar value.
// The key is (ArrayType, ScalarType), the value is the result type.
//
// NUMSHARP DESIGN DECISION:
// C# primitive scalars (int, short, long, etc.) are treated as "weakly typed"
// like Python scalars in NumPy 2.x, NOT like NumPy scalars (np.int32, etc.).
//
// This means: np.array(new byte[]{1,2,3}) + 5 → uint8 result (not int32)
//
// WHY: This matches the natural Python/NumPy user experience where `arr + 5`
// preserves the array's dtype when both are integers. This is consistent with
// NumPy 2.x behavior under NEP 50 for Python scalar operands.
//
// NEP 50 (NumPy Enhancement Proposal 50):
// https://numpy.org/neps/nep-0050-scalar-promotion.html
//
// Key rule: When an array operates with a scalar of the same "kind" (e.g., both
// are integers), the array dtype wins. Cross-kind operations (int + float) still
// promote to the higher kind (float).
//
// AFFECTED ENTRIES (12 total - all unsigned array + signed scalar):
//
// | Array Type | Scalar Types | NumPy 1.x Result | NumPy 2.x Result |
// |------------|-------------------|------------------|------------------|
// | uint8 | int16/int32/int64 | int16/int32/int64| uint8 |
// | uint16 | int16/int32/int64 | int32/int32/int64| uint16 |
// | uint32 | int16/int32/int64 | int64/int64/int64| uint32 |
// | uint64 | int16/int32/int64 | float64 (!) | uint64 |
//
// Verified against NumPy 2.4.2:
// >>> (np.array([1,2,3], np.uint8) + 5).dtype
// dtype('uint8')
//
// ============================================================================

var typemap_arr_scalar = new Dictionary<(Type, Type), Type>();
typemap_arr_scalar.Add((np.@bool, np.@bool), np.@bool);
typemap_arr_scalar.Add((np.@bool, np.uint8), np.uint8);
Expand All @@ -259,11 +407,11 @@ static np()
typemap_arr_scalar.Add((np.uint8, np.@bool), np.uint8);
typemap_arr_scalar.Add((np.uint8, np.uint8), np.uint8);
typemap_arr_scalar.Add((np.uint8, np.@char), np.uint8);
typemap_arr_scalar.Add((np.uint8, np.int16), np.int16);
typemap_arr_scalar.Add((np.uint8, np.int16), np.uint8);
typemap_arr_scalar.Add((np.uint8, np.uint16), np.uint8);
typemap_arr_scalar.Add((np.uint8, np.int32), np.int32);
typemap_arr_scalar.Add((np.uint8, np.int32), np.uint8);
typemap_arr_scalar.Add((np.uint8, np.uint32), np.uint8);
typemap_arr_scalar.Add((np.uint8, np.int64), np.int64);
typemap_arr_scalar.Add((np.uint8, np.int64), np.uint8);
typemap_arr_scalar.Add((np.uint8, np.uint64), np.uint8);
typemap_arr_scalar.Add((np.uint8, np.float32), np.float32);
typemap_arr_scalar.Add((np.uint8, np.float64), np.float64);
Expand Down Expand Up @@ -298,11 +446,11 @@ static np()
typemap_arr_scalar.Add((np.uint16, np.@bool), np.uint16);
typemap_arr_scalar.Add((np.uint16, np.uint8), np.uint16);
typemap_arr_scalar.Add((np.uint16, np.@char), np.uint16);
typemap_arr_scalar.Add((np.uint16, np.int16), np.int32);
typemap_arr_scalar.Add((np.uint16, np.int16), np.uint16);
typemap_arr_scalar.Add((np.uint16, np.uint16), np.uint16);
typemap_arr_scalar.Add((np.uint16, np.int32), np.int32);
typemap_arr_scalar.Add((np.uint16, np.int32), np.uint16);
typemap_arr_scalar.Add((np.uint16, np.uint32), np.uint16);
typemap_arr_scalar.Add((np.uint16, np.int64), np.int64);
typemap_arr_scalar.Add((np.uint16, np.int64), np.uint16);
typemap_arr_scalar.Add((np.uint16, np.uint64), np.uint16);
typemap_arr_scalar.Add((np.uint16, np.float32), np.float32);
typemap_arr_scalar.Add((np.uint16, np.float64), np.float64);
Expand All @@ -324,11 +472,11 @@ static np()
typemap_arr_scalar.Add((np.uint32, np.@bool), np.uint32);
typemap_arr_scalar.Add((np.uint32, np.uint8), np.uint32);
typemap_arr_scalar.Add((np.uint32, np.@char), np.uint32);
typemap_arr_scalar.Add((np.uint32, np.int16), np.int64);
typemap_arr_scalar.Add((np.uint32, np.int16), np.uint32);
typemap_arr_scalar.Add((np.uint32, np.uint16), np.uint32);
typemap_arr_scalar.Add((np.uint32, np.int32), np.int64);
typemap_arr_scalar.Add((np.uint32, np.int32), np.uint32);
typemap_arr_scalar.Add((np.uint32, np.uint32), np.uint32);
typemap_arr_scalar.Add((np.uint32, np.int64), np.int64);
typemap_arr_scalar.Add((np.uint32, np.int64), np.uint32);
typemap_arr_scalar.Add((np.uint32, np.uint64), np.uint32);
typemap_arr_scalar.Add((np.uint32, np.float32), np.float64);
typemap_arr_scalar.Add((np.uint32, np.float64), np.float64);
Expand All @@ -350,11 +498,11 @@ static np()
typemap_arr_scalar.Add((np.uint64, np.@bool), np.uint64);
typemap_arr_scalar.Add((np.uint64, np.uint8), np.uint64);
typemap_arr_scalar.Add((np.uint64, np.@char), np.uint64);
typemap_arr_scalar.Add((np.uint64, np.int16), np.float64);
typemap_arr_scalar.Add((np.uint64, np.int16), np.uint64);
typemap_arr_scalar.Add((np.uint64, np.uint16), np.uint64);
typemap_arr_scalar.Add((np.uint64, np.int32), np.float64);
typemap_arr_scalar.Add((np.uint64, np.int32), np.uint64);
typemap_arr_scalar.Add((np.uint64, np.uint32), np.uint64);
typemap_arr_scalar.Add((np.uint64, np.int64), np.float64);
typemap_arr_scalar.Add((np.uint64, np.int64), np.uint64);
typemap_arr_scalar.Add((np.uint64, np.uint64), np.uint64);
typemap_arr_scalar.Add((np.uint64, np.float32), np.float64);
typemap_arr_scalar.Add((np.uint64, np.float64), np.float64);
Expand Down
Loading
Loading