Skip to content

Commit 212f74d

Browse files
authored
Merge pull request #329 from flucoma/feature/nmf-random-seeding
BufNMF Random Seeding
2 parents 35b63f5 + 50478ae commit 212f74d

6 files changed

Lines changed: 108 additions & 23 deletions

File tree

include/flucoma/algorithms/public/NMF.hpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@ under the European Union’s Horizon 2020 research and innovation programme
1111
#pragma once
1212

1313
#include "../util/AlgorithmUtils.hpp"
14+
#include "../util/EigenRandom.hpp"
1415
#include "../util/FluidEigenMappings.hpp"
1516
#include "../../data/FluidIndex.hpp"
16-
#include "../../data/TensorTypes.hpp"
1717
#include "../../data/FluidMemory.hpp"
18+
#include "../../data/TensorTypes.hpp"
1819
#include <Eigen/Core>
1920
#include <vector>
2021

@@ -42,17 +43,16 @@ class NMF
4243

4344
// processFrame computes activations of a dictionary W in a given frame
4445
void processFrame(const RealVectorView x, const RealMatrixView W0,
45-
RealVectorView out, index nIterations,
46-
RealVectorView v, Allocator& alloc)
46+
RealVectorView out, index nIterations, RealVectorView v,
47+
index randomSeed, Allocator& alloc)
4748
{
4849
using namespace Eigen;
4950
using namespace _impl;
5051
index rank = W0.extent(0);
5152
FluidEigenMap<Matrix> W = asEigen<Matrix>(W0);
52-
53+
5354
ScopedEigenMap<VectorXd> h(rank, alloc);
54-
h = VectorXd::Random(rank) * 0.5 + VectorXd::Constant(rank, 0.5);
55-
55+
h = EigenRandom<VectorXd>(rank, RandomSeed{randomSeed}, Range{0.0, 1.0});
5656
ScopedEigenMap<VectorXd> v0(x.size(), alloc);
5757
v0 = asEigen<Matrix>(x);
5858
W = W.array().max(epsilon).matrix();
@@ -90,7 +90,7 @@ class NMF
9090

9191
void process(const RealMatrixView X, RealMatrixView W1, RealMatrixView H1,
9292
RealMatrixView V1, index rank, index nIterations, bool updateW,
93-
bool updateH = false,
93+
bool updateH = false, index randomSeed = -1,
9494
RealMatrixView W0 = RealMatrixView(nullptr, 0, 0, 0),
9595
RealMatrixView H0 = RealMatrixView(nullptr, 0, 0, 0))
9696
{
@@ -101,8 +101,8 @@ class NMF
101101
MatrixXd W;
102102
if (W0.extent(0) == 0 && W0.extent(1) == 0)
103103
{
104-
W = MatrixXd::Random(nBins, rank) * 0.5 +
105-
MatrixXd::Constant(nBins, rank, 0.5);
104+
W = EigenRandom<MatrixXd>(nBins, rank, RandomSeed{randomSeed},
105+
Range{0.0, 1.0});
106106
}
107107
else
108108
{
@@ -113,8 +113,8 @@ class NMF
113113
MatrixXd H;
114114
if (H0.extent(0) == 0 && H0.extent(1) == 0)
115115
{
116-
H = MatrixXd::Random(rank, nFrames) * 0.5 +
117-
MatrixXd::Constant(rank, nFrames, 0.5);
116+
H = EigenRandom<MatrixXd>(rank, nFrames, RandomSeed{randomSeed},
117+
Range{0.0, 1.0});
118118
}
119119
else
120120
{

include/flucoma/clients/nrt/NMFClient.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ under the European Union’s Horizon 2020 research and innovation programme
1818
#include "../../algorithms/public/NMF.hpp"
1919
#include "../../algorithms/public/RatioMask.hpp"
2020
#include "../../algorithms/public/STFT.hpp"
21-
#include "../../data/FluidTensor.hpp"
2221
#include "../../data/FluidMemory.hpp"
22+
#include "../../data/FluidTensor.hpp"
2323
#include <algorithm> //for max_element
2424
#include <cassert>
2525
#include <sstream> //for ostringstream
@@ -47,6 +47,7 @@ enum NMFParamIndex {
4747
kEnvelopesUpdate,
4848
kRank,
4949
kIterations,
50+
kRandomSeed,
5051
kFFT
5152
};
5253

@@ -57,7 +58,7 @@ constexpr auto BufNMFParams = defineParameters(
5758
LongParam("startChan", "Start Channel", 0, Min(0)),
5859
LongParam("numChans", "Number Channels", -1),
5960
BufferParam("resynth", "Resynthesis Buffer"),
60-
LongParam("resynthMode","Resynthesise components", 0,Min(0),Max(1)),
61+
LongParam("resynthMode", "Resynthesise components", 0, Min(0), Max(1)),
6162
BufferParam("bases", "Bases Buffer"),
6263
EnumParam("basesMode", "Bases Buffer Update Mode", 0, "None", "Seed",
6364
"Fixed"),
@@ -66,6 +67,7 @@ constexpr auto BufNMFParams = defineParameters(
6667
"Fixed"),
6768
LongParam("components", "Number of Components", 1, Min(1)),
6869
LongParam("iterations", "Number of Iterations", 100, Min(1)),
70+
LongParam("seed", "Random Seed", -1),
6971
FFTParam("fftSettings", "FFT Settings", 1024, -1, -1));
7072

7173
class NMFClient : public FluidBaseClient, public OfflineIn, public OfflineOut
@@ -98,7 +100,7 @@ class NMFClient : public FluidBaseClient, public OfflineIn, public OfflineOut
98100
index nFrames = get<kNumFrames>();
99101
index nChannels = get<kNumChans>();
100102
auto rangeCheck = bufferRangeCheck(get<kSource>().get(), get<kOffset>(),
101-
nFrames, get<kStartChan>(), nChannels);
103+
nFrames, get<kStartChan>(), nChannels);
102104

103105
if (!rangeCheck.ok()) return rangeCheck;
104106

@@ -264,8 +266,9 @@ class NMFClient : public FluidBaseClient, public OfflineIn, public OfflineOut
264266
: true;
265267
});
266268
nmf.process(magnitude, outputFilters, outputEnvelopes, outputMags,
267-
get<kRank>(), get<kIterations>() * needsAnalysis, !fixFilters, !fixEnvelopes,
268-
seededFilters, seededEnvelopes);
269+
get<kRank>(), get<kIterations>() * needsAnalysis, !fixFilters,
270+
!fixEnvelopes, get<kRandomSeed>(), seededFilters,
271+
seededEnvelopes);
269272

270273
if (c.task() && c.task()->cancelled())
271274
return {Result::Status::kCancelled, ""};

include/flucoma/clients/rt/NMFFilterClient.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@ namespace fluid {
2222
namespace client {
2323
namespace nmffilter {
2424

25-
enum NMFFilterIndex { kFilterbuf, kMaxRank, kIterations, kFFT };
25+
enum NMFFilterIndex { kFilterbuf, kMaxRank, kIterations, kRandomSeed, kFFT };
2626

2727
constexpr auto NMFFilterParams = defineParameters(
2828
InputBufferParam("bases", "Bases Buffer"),
2929
LongParamRuntimeMax<Primary>("maxComponents", "Maximum Number of Components", 20, Min(1)),
3030
LongParam("iterations", "Number of Iterations", 10, Min(1)),
31+
LongParam("seed", "Random Seed", -1),
3132
FFTParam("fftSettings", "FFT Settings", 1024, -1, -1));
3233

3334
class NMFFilterClient : public FluidBaseClient, public AudioIn, public AudioOut
@@ -103,7 +104,8 @@ class NMFFilterClient : public FluidBaseClient, public AudioIn, public AudioOut
103104
[&](ComplexMatrixView in, ComplexMatrixView out) {
104105
algorithm::STFT::magnitude(in, tmpMagnitude);
105106
mNMF.processFrame(tmpMagnitude.row(0), tmpFilt, tmpOut,
106-
get<kIterations>(), tmpEstimate.row(0), c.allocator());
107+
get<kIterations>(), tmpEstimate.row(0),
108+
get<kRandomSeed>(), c.allocator());
107109
mMask.init(tmpEstimate);
108110
for (index i = 0; i < rank; ++i)
109111
{

include/flucoma/clients/rt/NMFMatchClient.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ enum NMFMatchParamIndex {
2525
kFilterbuf,
2626
kMaxRank,
2727
kIterations,
28+
kRandomSeed,
2829
kFFT
2930
};
3031

@@ -33,6 +34,7 @@ constexpr auto NMFMatchParams = defineParameters(
3334
LongParamRuntimeMax<Primary>("maxComponents", "Maximum Number of Components", 20,
3435
Min(1)),
3536
LongParam("iterations", "Number of Iterations", 10, Min(1)),
37+
LongParam("seed", "Random Seed", -1),
3638
FFTParam("fftSettings", "FFT Settings", 1024, -1, -1));
3739

3840
class NMFMatchClient : public FluidBaseClient, public AudioIn, public ControlOut
@@ -104,14 +106,16 @@ class NMFMatchClient : public FluidBaseClient, public AudioIn, public ControlOut
104106
for (index i = 0; i < filter.rows(); ++i)
105107
filter.row(i) <<= filterBuffer.samps(i);
106108

107-
mSTFTProcessor.processInput(get<kFFT>(), input, c, [&](ComplexMatrixView in) {
108-
algorithm::STFT::magnitude(in, mags);
109-
mNMF.processFrame(mags.row(0), filter, activations,
110-
10, FluidTensorView<double,1>{nullptr, 0, 0}, c.allocator());
111-
});
112109

113110
output[0](Slice(0,rank)) <<= activations;
114111
output[0](Slice(rank,get<kMaxRank>().max() - rank)).fill(0);
112+
mSTFTProcessor.processInput(
113+
get<kFFT>(), input, c, [&](ComplexMatrixView in) {
114+
algorithm::STFT::magnitude(in, mags);
115+
mNMF.processFrame(mags.row(0), filter, activations, 10,
116+
FluidTensorView<double, 1>{nullptr, 0, 0},
117+
get<kRandomSeed>(), c.allocator());
118+
});
115119
}
116120
}
117121

tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ add_test_executable(TestTransientSlice algorithms/public/TestTransientSlice.cpp)
116116

117117
add_test_executable(TestMLP algorithms/public/TestMLP.cpp)
118118
add_test_executable(TestKMeans algorithms/public/TestKMeans.cpp)
119+
add_test_executable(TestNMF algorithms/public/TestNMF.cpp)
119120
add_test_executable(TestUMAP algorithms/public/TestUMAP.cpp)
120121

121122
add_test_executable(TestDataSampler data/detail/TestDataSampler.cpp)
@@ -157,6 +158,7 @@ catch_discover_tests(TestBufferedProcess WORKING_DIRECTORY "${CMAKE_BINARY_DIR}"
157158
catch_discover_tests(TestMLP WORKING_DIRECTORY "${CMAKE_BINARY_DIR}")
158159
catch_discover_tests(TestKMeans WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
159160
catch_discover_tests(TestEigenRandom WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
161+
catch_discover_tests(TestNMF WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
160162
catch_discover_tests(TestRTPGHI WORKING_DIRECTORY "${CMAKE_BINARY_DIR}")
161163
catch_discover_tests(TestUMAP WORKING_DIRECTORY "${CMAKE_BINARY_DIR}")
162164

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#define CATCH_CONFIG_MAIN
2+
#include <catch2/catch_all.hpp>
3+
#include <flucoma/algorithms/public/NMF.hpp>
4+
#include <flucoma/data/FluidTensor.hpp>
5+
#include <algorithm>
6+
#include <iostream>
7+
#include <vector>
8+
9+
namespace fluid {
10+
11+
TEST_CASE("NMF is repeatable with user-supplied random seed")
12+
{
13+
14+
using algorithm::NMF;
15+
using Tensor = FluidTensor<double, 2>;
16+
NMF algo;
17+
18+
Tensor input{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
19+
20+
std::vector Vs(4, Tensor(3, 3));
21+
std::vector Ws(4, Tensor(2, 3));
22+
std::vector Hs(4, Tensor(3, 2));
23+
24+
algo.process(input, Ws[0], Hs[0], Vs[0], 2, 1, true, true, 42);
25+
algo.process(input, Ws[1], Hs[1], Vs[1], 2, 1, true, true, 42);
26+
algo.process(input, Ws[2], Hs[2], Vs[2], 2, 1, true, true, 5063);
27+
algo.process(input, Ws[3], Hs[3], Vs[3], 2, 1, true, true, 5063);
28+
29+
using Catch::Matchers::RangeEquals;
30+
31+
SECTION("Calls with the same seed have the same output")
32+
{
33+
REQUIRE_THAT(Ws[1], RangeEquals(Ws[0]));
34+
REQUIRE_THAT(Hs[1], RangeEquals(Hs[0]));
35+
REQUIRE_THAT(Vs[1], RangeEquals(Vs[0]));
36+
REQUIRE_THAT(Ws[3], RangeEquals(Ws[2]));
37+
REQUIRE_THAT(Hs[3], RangeEquals(Hs[2]));
38+
REQUIRE_THAT(Vs[3], RangeEquals(Vs[2]));
39+
}
40+
SECTION("Calls with different seeds have different outputs")
41+
{
42+
REQUIRE_THAT(Ws[1], !RangeEquals(Ws[2]));
43+
REQUIRE_THAT(Hs[1], !RangeEquals(Hs[2]));
44+
REQUIRE_THAT(Vs[1], !RangeEquals(Vs[2]));
45+
}
46+
}
47+
48+
TEST_CASE("NMF processFrame() is repeatable with user-supplied random seed")
49+
{
50+
using fluid::algorithm::NMF;
51+
using Tensor = fluid::FluidTensor<double, 2>;
52+
using Vector = fluid::FluidTensor<double, 1>;
53+
NMF algo;
54+
55+
Vector input{{1, 0, 1, 0}};
56+
Tensor bases{{0, 0, 1, 0}, {1, 0, 0, 0}};
57+
Vector v(4);
58+
59+
std::vector outputs(3, Vector(2));
60+
61+
index nIter{0};
62+
algo.processFrame(input, bases, outputs[0], nIter, v, 42,
63+
FluidDefaultAllocator());
64+
algo.processFrame(input, bases, outputs[1], nIter, v, 42,
65+
FluidDefaultAllocator());
66+
algo.processFrame(input, bases, outputs[2], nIter, v, 7863,
67+
FluidDefaultAllocator());
68+
69+
using Catch::Matchers::RangeEquals;
70+
71+
REQUIRE_THAT(outputs[1], RangeEquals(outputs[0]));
72+
REQUIRE_THAT(outputs[1], !RangeEquals(outputs[2]));
73+
}
74+
} // namespace fluid

0 commit comments

Comments
 (0)