-
Notifications
You must be signed in to change notification settings - Fork 220
Add support for bfloat16 type #1544
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
585b02d
Fix bfloat16 tensor printing (issue #1469)
alinpahontu2912 33d3754
Add tests for BFloat16 and Float16 tensor printing
alinpahontu2912 0fa8769
Add standalone BFloat16 test scripts
alinpahontu2912 854aaac
Remove standalone BFloat16 test scripts
alinpahontu2912 18ea647
Add full BFloat16 managed type support
alinpahontu2912 c899f63
Address PR review comments for BFloat16 support
alinpahontu2912 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,128 @@ | ||
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using System.Globalization;
using System.Runtime.InteropServices;

#nullable enable
namespace TorchSharp
{
    /// <summary>
    /// Represents a 16-bit brain floating-point number (BFloat16).
    /// Binary layout: 1 sign bit, 8 exponent bits, 7 mantissa bits — the upper 16 bits of IEEE 754 float32.
    /// Binary-compatible with c10::BFloat16 in LibTorch.
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    public readonly struct BFloat16 : IComparable<BFloat16>, IEquatable<BFloat16>, IComparable, IFormattable
    {
        // Raw 16-bit representation: the upper half of the equivalent float32 bit pattern.
        internal readonly ushort value;

        // Raw-bits constructor; the unused bool parameter only disambiguates this
        // overload from BFloat16(float), which performs rounding.
        internal BFloat16(ushort rawValue, bool _)
        {
            value = rawValue;
        }

        /// <summary>
        /// Creates a BFloat16 from a float value using round-to-nearest-even (matching PyTorch c10::BFloat16).
        /// </summary>
        public BFloat16(float f)
        {
            value = FloatToBFloat16Bits(f);
        }

        /// <summary>
        /// Creates a BFloat16 from the raw 16-bit representation.
        /// </summary>
        public static BFloat16 FromRawValue(ushort rawValue) => new BFloat16(rawValue, false);

        // --- Conversion to/from float ---

        // Converts float32 bits to bfloat16 bits with round-to-nearest-even,
        // mirroring the conversion in PyTorch's c10::BFloat16.
        // Uses BitConverter bit-reinterpretation (.NET 6+) instead of unsafe pointer punning.
        private static ushort FloatToBFloat16Bits(float f)
        {
            uint bits = BitConverter.SingleToUInt32Bits(f);
            // NaN: truncate the payload but force the quiet bit so the result remains
            // a NaN even when the payload's retained high bits are all zero.
            if ((bits & 0x7F800000u) == 0x7F800000u && (bits & 0x007FFFFFu) != 0)
                return (ushort)(bits >> 16 | 0x0040u); // quiet NaN
            // Round-to-nearest-even: add 0x7FFF plus the LSB of the retained half, then
            // truncate. Overflow past the largest finite value correctly yields infinity.
            uint lsb = (bits >> 16) & 1u;
            uint roundingBias = 0x7FFFu + lsb;
            bits += roundingBias;
            return (ushort)(bits >> 16);
        }

        // Widens bfloat16 bits to float32 by zero-filling the low 16 bits (exact, no rounding).
        private static float BFloat16BitsToFloat(ushort raw)
        {
            return BitConverter.UInt32BitsToSingle((uint)raw << 16);
        }

        /// <summary>
        /// Converts this BFloat16 to a float. The conversion is exact.
        /// </summary>
        public float ToSingle() => BFloat16BitsToFloat(value);

        // --- Conversion operators ---

        public static explicit operator float(BFloat16 bf) => bf.ToSingle();
        public static explicit operator double(BFloat16 bf) => bf.ToSingle();
        public static explicit operator BFloat16(float f) => new BFloat16(f);
        public static explicit operator BFloat16(double d) => new BFloat16((float)d);

        // --- Arithmetic operators (promote to float, round back to nearest-even) ---

        public static BFloat16 operator +(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() + b.ToSingle());
        public static BFloat16 operator -(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() - b.ToSingle());
        public static BFloat16 operator *(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() * b.ToSingle());
        public static BFloat16 operator /(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() / b.ToSingle());
        public static BFloat16 operator %(BFloat16 a, BFloat16 b) => new BFloat16(a.ToSingle() % b.ToSingle());
        public static BFloat16 operator -(BFloat16 a) => new BFloat16(-a.ToSingle());

        // --- Comparison operators (IEEE 754 semantics: NaN compares false, +0 == -0) ---

        public static bool operator ==(BFloat16 a, BFloat16 b) => a.ToSingle() == b.ToSingle();
        public static bool operator !=(BFloat16 a, BFloat16 b) => a.ToSingle() != b.ToSingle();
        public static bool operator <(BFloat16 a, BFloat16 b) => a.ToSingle() < b.ToSingle();
        public static bool operator >(BFloat16 a, BFloat16 b) => a.ToSingle() > b.ToSingle();
        public static bool operator <=(BFloat16 a, BFloat16 b) => a.ToSingle() <= b.ToSingle();
        public static bool operator >=(BFloat16 a, BFloat16 b) => a.ToSingle() >= b.ToSingle();

        // --- IEquatable / IComparable ---

        // Note: operator == follows IEEE 754 (NaN != NaN), while Equals must be reflexive
        // per the IEquatable<T> contract. Delegating to System.Single.Equals makes
        // NaN.Equals(NaN) true (as for float/double), so BFloat16 works correctly as a
        // dictionary or hash-set key, consistent with GetHashCode below.
        public bool Equals(BFloat16 other) => ToSingle().Equals(other.ToSingle());
        public override bool Equals(object? obj) => obj is BFloat16 other && Equals(other);
        public override int GetHashCode() => ToSingle().GetHashCode();

        public int CompareTo(BFloat16 other) => ToSingle().CompareTo(other.ToSingle());
        public int CompareTo(object? obj)
        {
            if (obj is null) return 1;
            if (obj is BFloat16 other) return CompareTo(other);
            throw new ArgumentException("Object must be of type BFloat16.");
        }

        // --- Formatting ---

        public override string ToString() => ToSingle().ToString();
        public string ToString(string? format, IFormatProvider? formatProvider) => ToSingle().ToString(format, formatProvider);

        // --- Constants ---

        public static readonly BFloat16 Zero = FromRawValue(0x0000);
        public static readonly BFloat16 One = FromRawValue(0x3F80);
        public static readonly BFloat16 NaN = FromRawValue(0x7FC0);
        public static readonly BFloat16 PositiveInfinity = FromRawValue(0x7F80);
        public static readonly BFloat16 NegativeInfinity = FromRawValue(0xFF80);
        public static readonly BFloat16 MaxValue = FromRawValue(0x7F7F); // ~3.39e+38
        public static readonly BFloat16 MinValue = FromRawValue(0xFF7F); // ~-3.39e+38
        public static readonly BFloat16 SmallestSubnormal = FromRawValue(0x0001); // smallest positive (subnormal)
        public static readonly BFloat16 Epsilon = SmallestSubnormal; // .NET-style epsilon: smallest positive > 0
        public static readonly BFloat16 MinNormal = FromRawValue(0x0080); // smallest normal

        // --- Static helpers ---

        public static bool IsNaN(BFloat16 bf) => float.IsNaN(bf.ToSingle());
        public static bool IsInfinity(BFloat16 bf) => float.IsInfinity(bf.ToSingle());
        public static bool IsPositiveInfinity(BFloat16 bf) => float.IsPositiveInfinity(bf.ToSingle());
        public static bool IsNegativeInfinity(BFloat16 bf) => float.IsNegativeInfinity(bf.ToSingle());
        public static bool IsFinite(BFloat16 bf) => !IsInfinity(bf) && !IsNaN(bf);
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,134 @@ | ||
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;

#nullable enable
namespace TorchSharp
{
    public static partial class torch
    {
        /// <summary>
        /// Create a scalar (0-dimensional) tensor holding a single BFloat16 value.
        /// </summary>
        [Pure]
        public static Tensor tensor(BFloat16 scalar, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
            => _tensor_generic(new[] { scalar }, stackalloc long[] { }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad);

        /// <summary>
        /// Create a tensor from a list of BFloat16 values, shaped according to <paramref name="dimensions"/>.
        /// </summary>
        /// <remarks>The Torch runtime does not take ownership of the data, so there is no device argument.</remarks>
        [Pure]
        public static Tensor tensor(IList<BFloat16> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
            => _tensor_generic(rawArray.ToArray(), dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, false, names);

        /// <summary>
        /// Create a 1-D tensor from an array of BFloat16 values.
        /// </summary>
        [Pure]
        public static Tensor tensor(BFloat16[] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
            => _tensor_generic(rawArray, stackalloc long[] { rawArray.LongLength }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);

        /// <summary>
        /// Create a tensor from an array of BFloat16 values, shaped according to <paramref name="dimensions"/>.
        /// </summary>
        [Pure]
        public static Tensor tensor(BFloat16[] rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
            => _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);

        /// <summary>
        /// Create a 1-D tensor from a list of BFloat16 values, shaped according to the list length.
        /// </summary>
        /// <remarks>The Torch runtime does not take ownership of the data, so there is no device argument.</remarks>
        [Pure]
        public static Tensor tensor(IList<BFloat16> rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
            => tensor(rawArray, stackalloc long[] { (long)rawArray.Count }, dtype, device, requires_grad, names: names);

        /// <summary>
        /// Create a two-dimensional tensor from a list of BFloat16 values.
        /// </summary>
        /// <remarks>
        /// The Torch runtime does not take ownership of the data, so there is no device argument.
        /// The input list must have rows * columns elements.
        /// </remarks>
        [Pure]
        public static Tensor tensor(IList<BFloat16> rawArray, long rows, long columns, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
            => tensor(rawArray, stackalloc long[] { rows, columns }, dtype, device, requires_grad, names: names);

        /// <summary>
        /// Create a three-dimensional tensor from a list of BFloat16 values.
        /// </summary>
        /// <remarks>
        /// The Torch runtime does not take ownership of the data, so there is no device argument.
        /// The input list must have dim0*dim1*dim2 elements.
        /// </remarks>
        [Pure]
        public static Tensor tensor(IList<BFloat16> rawArray, long dim0, long dim1, long dim2, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
            => tensor(rawArray, stackalloc long[] { dim0, dim1, dim2 }, dtype, device, requires_grad, names: names);

        /// <summary>
        /// Create a four-dimensional tensor from a list of BFloat16 values.
        /// </summary>
        /// <remarks>
        /// The Torch runtime does not take ownership of the data, so there is no device argument.
        /// The input list must have dim0*dim1*dim2*dim3 elements.
        /// </remarks>
        [Pure]
        public static Tensor tensor(IList<BFloat16> rawArray, long dim0, long dim1, long dim2, long dim3, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
            => tensor(rawArray, stackalloc long[] { dim0, dim1, dim2, dim3 }, dtype, device, requires_grad, names: names);

        /// <summary>
        /// Create a two-dimensional tensor from a rectangular 2-D array of BFloat16 values.
        /// </summary>
        [Pure]
        public static Tensor tensor(BFloat16[,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
            => _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1) }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);

        /// <summary>
        /// Create a three-dimensional tensor from a rectangular 3-D array of BFloat16 values.
        /// </summary>
        [Pure]
        public static Tensor tensor(BFloat16[,,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
            => _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2) }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);

        /// <summary>
        /// Create a four-dimensional tensor from a rectangular 4-D array of BFloat16 values.
        /// </summary>
        [Pure]
        public static Tensor tensor(BFloat16[,,,] rawArray, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
            => _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2), rawArray.GetLongLength(3) }, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);

        /// <summary>
        /// Create a tensor over a memory region of BFloat16 values, shaped according to <paramref name="dimensions"/>.
        /// </summary>
        [Pure]
        public static Tensor tensor(Memory<BFloat16> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
            => _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.BFloat16, dtype, device, requires_grad, names: names);
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this
BFloat16 different from System.Numerics.BFloat16 that's now available in .NET 11?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have to check, but the problem is we are now building with .NET 8 only, and I don't think directly jumping to 11 would be okay.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right. But we could, only for .NET 11, exclude this and use the built-in type — and eventually remove this custom code completely.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will leave a note in the release notes about that.