1+ #define CATCH_CONFIG_MAIN
2+ #include < catch2/catch_all.hpp>
3+ #include < flucoma/algorithms/public/NMF.hpp>
4+ #include < flucoma/data/FluidTensor.hpp>
5+ #include < algorithm>
6+ #include < iostream>
7+ #include < vector>
8+
9+ namespace fluid {
10+
11+ TEST_CASE (" NMF is repeatable with user-supplied random seed" )
12+ {
13+
14+ using algorithm::NMF;
15+ using Tensor = FluidTensor<double , 2 >;
16+ NMF algo;
17+
18+ Tensor input{{1 , 2 , 3 }, {4 , 5 , 6 }, {7 , 8 , 9 }};
19+
20+ std::vector Vs (4 , Tensor (3 , 3 ));
21+ std::vector Ws (4 , Tensor (2 , 3 ));
22+ std::vector Hs (4 , Tensor (3 , 2 ));
23+
24+ algo.process (input, Ws[0 ], Hs[0 ], Vs[0 ], 2 , 1 , true , true , 42 );
25+ algo.process (input, Ws[1 ], Hs[1 ], Vs[1 ], 2 , 1 , true , true , 42 );
26+ algo.process (input, Ws[2 ], Hs[2 ], Vs[2 ], 2 , 1 , true , true , 5063 );
27+ algo.process (input, Ws[3 ], Hs[3 ], Vs[3 ], 2 , 1 , true , true , 5063 );
28+
29+ using Catch::Matchers::RangeEquals;
30+
31+ SECTION (" Calls with the same seed have the same output" )
32+ {
33+ REQUIRE_THAT (Ws[1 ], RangeEquals (Ws[0 ]));
34+ REQUIRE_THAT (Hs[1 ], RangeEquals (Hs[0 ]));
35+ REQUIRE_THAT (Vs[1 ], RangeEquals (Vs[0 ]));
36+ REQUIRE_THAT (Ws[3 ], RangeEquals (Ws[2 ]));
37+ REQUIRE_THAT (Hs[3 ], RangeEquals (Hs[2 ]));
38+ REQUIRE_THAT (Vs[3 ], RangeEquals (Vs[2 ]));
39+ }
40+ SECTION (" Calls with different seeds have different outputs" )
41+ {
42+ REQUIRE_THAT (Ws[1 ], !RangeEquals (Ws[2 ]));
43+ REQUIRE_THAT (Hs[1 ], !RangeEquals (Hs[2 ]));
44+ REQUIRE_THAT (Vs[1 ], !RangeEquals (Vs[2 ]));
45+ }
46+ }
47+
48+ TEST_CASE (" NMF processFrame() is repeatable with user-supplied random seed" )
49+ {
50+ using fluid::algorithm::NMF;
51+ using Tensor = fluid::FluidTensor<double , 2 >;
52+ using Vector = fluid::FluidTensor<double , 1 >;
53+ NMF algo;
54+
55+ Vector input{{1 , 0 , 1 , 0 }};
56+ Tensor bases{{0 , 0 , 1 , 0 }, {1 , 0 , 0 , 0 }};
57+ Vector v (4 );
58+
59+ std::vector outputs (3 , Vector (2 ));
60+
61+ index nIter{0 };
62+ algo.processFrame (input, bases, outputs[0 ], nIter, v, 42 ,
63+ FluidDefaultAllocator ());
64+ algo.processFrame (input, bases, outputs[1 ], nIter, v, 42 ,
65+ FluidDefaultAllocator ());
66+ algo.processFrame (input, bases, outputs[2 ], nIter, v, 7863 ,
67+ FluidDefaultAllocator ());
68+
69+ using Catch::Matchers::RangeEquals;
70+
71+ REQUIRE_THAT (outputs[1 ], RangeEquals (outputs[0 ]));
72+ REQUIRE_THAT (outputs[1 ], !RangeEquals (outputs[2 ]));
73+ }
74+ } // namespace fluid
0 commit comments