diff --git a/dart/packages/fory-test/test/datatype_test/bfloat16_test.dart b/dart/packages/fory-test/test/datatype_test/bfloat16_test.dart new file mode 100644 index 0000000000..c82957030a --- /dev/null +++ b/dart/packages/fory-test/test/datatype_test/bfloat16_test.dart @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import 'package:test/test.dart'; +import 'package:fory/fory.dart'; + +void main() { + group('BFloat16', () { + test('converts from and to bits', () { + var bf = BFloat16.fromBits(0x3f80); // 1.0 in bfloat16 + expect(bf.toBits(), 0x3f80); + expect(bf.toFloat32(), 1.0); + expect(bf.value, 1.0); + }); + + test('converts from float32', () { + var bf = BFloat16.fromFloat32(1.0); + expect(bf.toBits(), 0x3f80); + expect(bf.toFloat32(), 1.0); + + var bf2 = BFloat16.fromFloat32(-1.0); + expect(bf2.toFloat32(), -1.0); + + var bf3 = BFloat16.fromFloat32(0.0); + expect(bf3.toFloat32(), 0.0); + }); + + test('equality and hashcode', () { + var bf1 = BFloat16.fromBits(0x3f80); + var bf2 = BFloat16.fromFloat32(1.0); + var bf3 = BFloat16.fromFloat32(2.0); + + expect(bf1 == bf2, isTrue); + expect(bf1 == bf3, isFalse); + expect(bf1.hashCode == bf2.hashCode, isTrue); + }); + }); + + group('BFloat16Array', () { + test('creates from length', () { + var arr = BFloat16Array.fromLength(5); + expect(arr.length, 5); + expect(arr.raw.length, 5); + }); + + test('creates from list and sets values', () { + var arr = BFloat16Array.fromList([1.0, 2.0, BFloat16.fromFloat32(3.0)]); + expect(arr.length, 3); + expect(arr.get(0).toFloat32(), 1.0); + expect(arr.get(1).toFloat32(), 2.0); + expect(arr.get(2).toFloat32(), 3.0); + + arr.set(0, 4.0); + expect(arr.get(0).toFloat32(), 4.0); + + arr.set(1, BFloat16.fromFloat32(5.0)); + expect(arr.get(1).toFloat32(), 5.0); + }); + }); +} diff --git a/dart/packages/fory/lib/fory.dart b/dart/packages/fory/lib/fory.dart index ba5526ca23..55ff44bccf 100644 --- a/dart/packages/fory/lib/fory.dart +++ b/dart/packages/fory/lib/fory.dart @@ -61,7 +61,9 @@ export 'src/datatype/int32.dart'; export 'src/datatype/uint8.dart'; export 'src/datatype/uint16.dart'; export 'src/datatype/uint32.dart'; +export 'src/datatype/float16.dart'; export 'src/datatype/float32.dart'; +export 'src/datatype/bfloat16.dart'; export 'src/datatype/local_date.dart'; export 'src/datatype/timestamp.dart'; diff --git a/dart/packages/fory/lib/src/const/dart_type.dart b/dart/packages/fory/lib/src/const/dart_type.dart index 85c37fcf11..956671b3f5 100644 --- a/dart/packages/fory/lib/src/const/dart_type.dart +++ b/dart/packages/fory/lib/src/const/dart_type.dart @@ -21,6 +21,7 @@ import 'dart:collection'; import 'dart:typed_data'; import 'package:collection/collection.dart' show BoolList; import 'package:decimal/decimal.dart'; +import 'package:fory/src/datatype/bfloat16.dart'; import 'package:fory/src/datatype/float16.dart'; import 'package:fory/src/datatype/float32.dart'; import 'package:fory/src/datatype/int16.dart'; @@ -60,6 +61,8 @@ enum DartTypeEnum { ObjType.FLOAT32, true, 'dart:core@Float32'), FLOAT16(Float16, true, 'Float16', 'package', 'fory/src/datatype/float16.dart', ObjType.FLOAT16, true, 'dart:core@Float16'), + BFLOAT16(BFloat16, true, 'BFloat16', 'package', 'fory/src/datatype/bfloat16.dart', + ObjType.BFLOAT16, true, 'dart:core@BFloat16'), DOUBLE(double, true, 'double', 'dart', 'core', ObjType.FLOAT64, true, 'dart:core@double'), STRING(String, true, 'String', 'dart', 'core', ObjType.STRING, true, diff --git a/dart/packages/fory/lib/src/datatype/bfloat16.dart b/dart/packages/fory/lib/src/datatype/bfloat16.dart new file mode 100644 index 0000000000..ff25f8ccb7 --- /dev/null +++ b/dart/packages/fory/lib/src/datatype/bfloat16.dart @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import 'dart:typed_data'; + +import 'fory_fixed_num.dart'; + +/// BFloat16: 16-bit brain floating point type. +/// Wraps a 16-bit integer representing the bfloat16 format. +final class BFloat16 extends FixedNum { + /// The raw 16-bit integer storage. + final int _bits; + + /// Internal constructor from raw bits. + const BFloat16._(this._bits); + + /// Creates a [BFloat16] from a number. + factory BFloat16(num value) => BFloat16.fromFloat32(value.toDouble()); + + /// Returns the raw 16-bit integer representation. + int toBits() { + return _bits; + } + + /// Creates a [BFloat16] from a raw 16-bit integer. + factory BFloat16.fromBits(int bits) { + return BFloat16._(bits & 0xffff); + } + + /// Creates a [BFloat16] by converting a standard 32-bit floating-point number. + factory BFloat16.fromFloat32(double f32) { + var float32View = Float32List(1); + var uint32View = Uint32List.view(float32View.buffer); + + float32View[0] = f32; + var bits = uint32View[0]; + var exponent = (bits >> 23) & 0xff; + if (exponent == 255) { + return BFloat16.fromBits((bits >> 16) & 0xffff); + } + var remainder = bits & 0x1ffff; + var u = (bits + 0x8000) >> 16; + if (remainder == 0x8000 && (u & 1) != 0) { + u--; + } + return BFloat16.fromBits(u & 0xffff); + } + + /// Converts this [BFloat16] to a standard 32-bit floating-point number. + double toFloat32() { + var float32View = Float32List(1); + var uint32View = Uint32List.view(float32View.buffer); + + float32View[0] = 0.0; + uint32View[0] = _bits << 16; + return float32View[0]; + } + + /// Gets the numeric value as a double. + @override + num get value => toFloat32(); + + @override + String toString() => 'BFloat16($value)'; + + @override + bool operator ==(Object other) => + identical(this, other) || + other is BFloat16 && + runtimeType == other.runtimeType && + _bits == other._bits; + + @override + int get hashCode => _bits.hashCode; +} + +/// A fixed-length list of [BFloat16] values, backed by a [Uint16List]. +class BFloat16Array { + /// The underlying raw 16-bit storage. + final Uint16List _data; + + /// Creates an array of the given [length], initialized to zero. + BFloat16Array.fromLength(int length) : _data = Uint16List(length); + + /// Creates an array backed by a copy of the given [Uint16List]. + BFloat16Array.fromUint16List(Uint16List source) + : _data = Uint16List.fromList(source); + + /// Creates an array from a list of [BFloat16] or [double] values. + BFloat16Array.fromList(List source) + : _data = Uint16List(source.length) { + for (int i = 0; i < source.length; i++) { + var v = source[i]; + _data[i] = v is BFloat16 ? v.toBits() : BFloat16.fromFloat32(v as double).toBits(); + } + } + + /// The length of the array. + int get length => _data.length; + + /// Retrieves the [BFloat16] at the given [index]. + BFloat16 get(int index) => BFloat16.fromBits(_data[index]); + + /// Sets the [BFloat16] value at the given [index]. + /// Will automatically convert a [double] to a [BFloat16]. + void set(int index, dynamic value) { + _data[index] = value is BFloat16 + ? value.toBits() + : BFloat16.fromFloat32(value as double).toBits(); + } + + /// Exposes the underlying [Uint16List] used for storage. + Uint16List get raw => _data; + + /// Creates a [BFloat16Array] initialized using the provided raw [Uint16List]. + factory BFloat16Array.fromRaw(Uint16List data) { + return BFloat16Array.fromUint16List(data); + } +} diff --git a/dart/packages/fory/lib/src/datatype/fory_fixed_num.dart b/dart/packages/fory/lib/src/datatype/fory_fixed_num.dart index 2e1fead364..e4128ba93b 100644 --- a/dart/packages/fory/lib/src/datatype/fory_fixed_num.dart +++ b/dart/packages/fory/lib/src/datatype/fory_fixed_num.dart @@ -22,8 +22,9 @@ import 'float32.dart'; import 'int16.dart'; import 'int32.dart'; import 'int8.dart'; +import 'bfloat16.dart'; -enum NumType { int8, int16, int32, float16, float32 } +enum NumType { int8, int16, int32, float16, float32, bfloat16 } /// Base abstract class for fixed-size numeric types abstract base class FixedNum implements Comparable { @@ -44,6 +45,8 @@ abstract base class FixedNum implements Comparable { return Float16(value); case NumType.float32: return Float32(value); + case NumType.bfloat16: + return BFloat16(value); } } diff --git a/dart/packages/fory/lib/src/memory/byte_reader.dart b/dart/packages/fory/lib/src/memory/byte_reader.dart index b8531d7826..04792511bf 100644 --- a/dart/packages/fory/lib/src/memory/byte_reader.dart +++ b/dart/packages/fory/lib/src/memory/byte_reader.dart @@ -20,6 +20,7 @@ import 'dart:typed_data'; import 'package:meta/meta.dart'; import 'package:fory/src/datatype/float16.dart'; +import 'package:fory/src/datatype/bfloat16.dart'; import 'package:fory/src/memory/byte_reader_impl.dart'; abstract base class ByteReader { @@ -70,6 +71,9 @@ abstract base class ByteReader { /// Reads a 16-bit floating point number from the stream. Float16 readFloat16(); + /// Reads a 16-bit brain float point number from the stream. + BFloat16 readBFloat16(); + int readVarUint36Small(); int readVarInt32(); diff --git a/dart/packages/fory/lib/src/memory/byte_reader_impl.dart b/dart/packages/fory/lib/src/memory/byte_reader_impl.dart index 571a9d2e3f..18ce592ba8 100644 --- a/dart/packages/fory/lib/src/memory/byte_reader_impl.dart +++ b/dart/packages/fory/lib/src/memory/byte_reader_impl.dart @@ -20,6 +20,7 @@ import 'dart:typed_data'; import 'package:fory/src/dev_annotation/optimize.dart'; import 'package:fory/src/datatype/float16.dart'; +import 'package:fory/src/datatype/bfloat16.dart'; import 'package:fory/src/memory/byte_reader.dart'; final class ByteReaderImpl extends ByteReader { @@ -126,6 +127,13 @@ final class ByteReaderImpl extends ByteReader { return Float16.fromBits(value); } + @override + BFloat16 readBFloat16() { + int value = _bd.getUint16(_offset, endian); + _offset += 2; + return BFloat16.fromBits(value); + } + @override Uint8List readBytesView(int length) { // create a view of the original list diff --git a/dart/packages/fory/lib/src/memory/byte_writer.dart b/dart/packages/fory/lib/src/memory/byte_writer.dart index fb556afd02..93ede7475d 100644 --- a/dart/packages/fory/lib/src/memory/byte_writer.dart +++ b/dart/packages/fory/lib/src/memory/byte_writer.dart @@ -21,6 +21,7 @@ import 'dart:typed_data'; import 'package:meta/meta.dart'; import 'package:fory/src/memory/byte_writer_impl.dart'; import 'package:fory/src/datatype/float16.dart'; +import 'package:fory/src/datatype/bfloat16.dart'; abstract base class ByteWriter { @protected @@ -47,6 +48,7 @@ abstract base class ByteWriter { void writeFloat32(double value); void writeFloat64(double value); void writeFloat16(Float16 value); + void writeBFloat16(BFloat16 value); void writeBytes(List bytes); diff --git a/dart/packages/fory/lib/src/memory/byte_writer_impl.dart b/dart/packages/fory/lib/src/memory/byte_writer_impl.dart index 56d7ecbfd0..5a168ed7a6 100644 --- a/dart/packages/fory/lib/src/memory/byte_writer_impl.dart +++ b/dart/packages/fory/lib/src/memory/byte_writer_impl.dart @@ -20,6 +20,7 @@ import 'dart:typed_data'; import 'package:fory/src/dev_annotation/optimize.dart'; import 'package:fory/src/datatype/float16.dart'; +import 'package:fory/src/datatype/bfloat16.dart'; import 'package:fory/src/memory/byte_writer.dart'; final class ByteWriterImpl extends ByteWriter { @@ -142,6 +143,13 @@ final class ByteWriterImpl extends ByteWriter { writeUint16(value.toBits()); } + /// Append a `BFloat16` (2 bytes, Little Endian) to the buffer + @inline + @override + void writeBFloat16(BFloat16 value) { + writeUint16(value.toBits()); + } + /// Append a list of bytes to the buffer @override @inline diff --git a/dart/packages/fory/lib/src/serializer/primitive_type_serializer.dart b/dart/packages/fory/lib/src/serializer/primitive_type_serializer.dart index fdbc5ea94f..8075d92ada 100644 --- a/dart/packages/fory/lib/src/serializer/primitive_type_serializer.dart +++ b/dart/packages/fory/lib/src/serializer/primitive_type_serializer.dart @@ -34,6 +34,7 @@ import 'package:fory/src/memory/byte_writer.dart'; import 'package:fory/src/serializer/serializer.dart'; import 'package:fory/src/serializer/serializer_cache.dart'; import 'package:fory/src/serialization_context.dart'; +import 'package:fory/src/datatype/bfloat16.dart'; abstract base class PrimitiveSerializerCache extends SerializerCache { const PrimitiveSerializerCache(); @@ -308,6 +309,23 @@ final class Float16Serializer extends Serializer { } } +final class BFloat16Serializer extends Serializer { + static const SerializerCache cache = _BFloat16SerializerCache(); + + BFloat16Serializer._(bool writeRef) : super(ObjType.BFLOAT16, writeRef); + + @override + BFloat16 read(ByteReader br, int refId, DeserializationContext pack) { + return br.readBFloat16(); + } + + @override + void write(ByteWriter bw, covariant BFloat16 v, SerializationContext pack) { + // No checks are performed here + bw.writeBFloat16(v); + } +} + final class _Float64SerializerCache extends PrimitiveSerializerCache { static Float64Serializer? serializerWithRef; static Float64Serializer? serializerWithoutRef; @@ -573,3 +591,21 @@ final class TaggedUInt64Serializer extends Serializer { bw.writeVarInt64(v); } } + +final class _BFloat16SerializerCache extends PrimitiveSerializerCache { + static BFloat16Serializer? serializerWithRef; + static BFloat16Serializer? serializerWithoutRef; + + const _BFloat16SerializerCache(); + + @override + Serializer getSerializerWithRef(bool writeRef) { + if (writeRef) { + serializerWithRef ??= BFloat16Serializer._(true); + return serializerWithRef!; + } else { + serializerWithoutRef ??= BFloat16Serializer._(false); + return serializerWithoutRef!; + } + } +} diff --git a/dart/packages/fory/lib/src/serializer/serializer_pool.dart b/dart/packages/fory/lib/src/serializer/serializer_pool.dart index 6bf2cee007..986ec25166 100644 --- a/dart/packages/fory/lib/src/serializer/serializer_pool.dart +++ b/dart/packages/fory/lib/src/serializer/serializer_pool.dart @@ -23,8 +23,9 @@ import 'package:collection/collection.dart'; import 'package:fory/src/config/fory_config.dart'; import 'package:fory/src/const/dart_type.dart'; import 'package:fory/src/const/types.dart'; -import 'package:fory/src/datatype/float16.dart'; import 'package:fory/src/datatype/float32.dart'; +import 'package:fory/src/datatype/float16.dart'; +import 'package:fory/src/datatype/bfloat16.dart'; import 'package:fory/src/datatype/int16.dart'; import 'package:fory/src/datatype/int32.dart'; import 'package:fory/src/datatype/int8.dart'; @@ -82,6 +83,8 @@ class SerializerPool { Float32Serializer.cache.getSerializer(conf); typeToTypeInfo[Float16]!.serializer = Float16Serializer.cache.getSerializer(conf); + typeToTypeInfo[BFloat16]!.serializer = + BFloat16Serializer.cache.getSerializer(conf); typeToTypeInfo[String]!.serializer = StringSerializer.cache.getSerializer(conf);