Skip to content

Commit b8bc6aa

Browse files
Add Tensor.to_sparse() API for sparse COO tensor conversion
Add to_sparse() and to_sparse(int sparse_dim) methods to convert dense tensors to sparse COO format. This enables GCN and GAT graph neural network examples that require sparse matrix operations. Changes across all 4 binding layers: - THSTensor.h: declarations - THSTensor.cpp: implementations - LibTorchSharp.THSTensor.cs: P/Invoke - Tensor.cs: managed C# methods
1 parent bdc2bcb commit b8bc6aa

4 files changed

Lines changed: 42 additions & 0 deletions

File tree

src/Native/LibTorchSharp/THSTensor.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1866,6 +1866,16 @@ Tensor THSTensor_to_dense(Tensor tensor)
18661866
CATCH_TENSOR(tensor->to_dense());
18671867
}
18681868

1869+
Tensor THSTensor_to_sparse(Tensor tensor)
1870+
{
1871+
CATCH_TENSOR(tensor->to_sparse());
1872+
}
1873+
1874+
Tensor THSTensor_to_sparse_with_dims(Tensor tensor, const int64_t sparse_dim)
1875+
{
1876+
CATCH_TENSOR(tensor->to_sparse(sparse_dim));
1877+
}
1878+
18691879
void THSTensor_set_(Tensor tensor, const Tensor source)
18701880
{
18711881
CATCH(tensor->set_(*source););

src/Native/LibTorchSharp/THSTensor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,6 +1378,9 @@ EXPORT_API(Tensor) THSTensor_trapezoid_dx(const Tensor y, const double dx, int64
13781378

13791379
EXPORT_API(Tensor) THSTensor_to_dense(Tensor tensor);
13801380

1381+
EXPORT_API(Tensor) THSTensor_to_sparse(Tensor tensor);
1382+
EXPORT_API(Tensor) THSTensor_to_sparse_with_dims(Tensor tensor, const int64_t sparse_dim);
1383+
13811384
EXPORT_API(Tensor) THSTensor_to_device(const Tensor tensor, const int device_type, const int device_index, const bool copy, const bool non_blocking);
13821385

13831386
EXPORT_API(Tensor) THSTensor_to_type(const Tensor tensor, int8_t scalar_type, const bool copy, const bool non_blocking);

src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,12 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
376376
[DllImport("LibTorchSharp")]
377377
internal static extern IntPtr THSTensor_to_dense(IntPtr handle);
378378

379+
[DllImport("LibTorchSharp")]
380+
internal static extern IntPtr THSTensor_to_sparse(IntPtr handle);
381+
382+
[DllImport("LibTorchSharp")]
383+
internal static extern IntPtr THSTensor_to_sparse_with_dims(IntPtr handle, long sparse_dim);
384+
379385
[DllImport("LibTorchSharp")]
380386
internal static extern IntPtr THSTensor_clone(IntPtr handle);
381387

src/TorchSharp/Tensor/Tensor.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,6 +1287,29 @@ public Tensor to_dense()
12871287
return new Tensor(res);
12881288
}
12891289

1290+
/// <summary>
1291+
/// Converts a dense tensor to a sparse COO tensor.
1292+
/// </summary>
1293+
public Tensor to_sparse()
1294+
{
1295+
var res = NativeMethods.THSTensor_to_sparse(Handle);
1296+
if (res == IntPtr.Zero)
1297+
CheckForErrors();
1298+
return new Tensor(res);
1299+
}
1300+
1301+
/// <summary>
1302+
/// Converts a dense tensor to a sparse COO tensor with the specified number of sparse dimensions.
1303+
/// </summary>
1304+
/// <param name="sparse_dim">The number of sparse dimensions.</param>
1305+
public Tensor to_sparse(int sparse_dim)
1306+
{
1307+
var res = NativeMethods.THSTensor_to_sparse_with_dims(Handle, sparse_dim);
1308+
if (res == IntPtr.Zero)
1309+
CheckForErrors();
1310+
return new Tensor(res);
1311+
}
1312+
12901313
/// <summary>
12911314
/// Returns a copy of the tensor input.
12921315
/// </summary>

0 commit comments

Comments
 (0)