55//! # Usage
66//!
77//! ```
8- //! use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Metrics };
8+ //! use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Linkage };
99//! use rusty_machine::learning::SupModel;
1010//! use rusty_machine::linalg::{Matrix, Vector};
1111//!
1212//! let inputs = Matrix::new(4, 2, vec![1., 3., 2., 3., 4., 3., 5., 3.]);
13- //! let mut agg = AgglomerativeClustering::new(2, Metrics ::Single);
13+ //! let mut agg = AgglomerativeClustering::new(2, Linkage ::Single);
1414//!
1515//! // Train the model and get the clustering result
1616//! let res = agg.train(&inputs).unwrap();
@@ -27,7 +27,7 @@ use learning::{LearningResult};
2727
2828/// Linkage criteria for agglomerative clustering
2929#[ derive( Debug ) ]
30- pub enum Metrics {
30+ pub enum Linkage {
3131 /// Single linkage clustering
3232 Single ,
3333 /// Complete linkage clustering
@@ -49,7 +49,7 @@ pub enum Metrics {
4949 Ward2 ,
5050}
5151
52- impl Metrics {
52+ impl Linkage {
5353
5454 // Calculate the updated distance to cluster `ck` after `ci` and `cj` are merged, using the Lance-Williams update formula
5555 fn dist ( & self , ci : & Cluster , cj : & Cluster , ck : & Cluster , dmat : & DistanceMatrix ) -> f64 {
@@ -58,38 +58,38 @@ impl Metrics {
5858 let djk = dmat. get ( ck. id , cj. id ) ;
5959
6060 match self {
61- & Metrics :: Single => {
61+ & Linkage :: Single => {
6262 // 0.5 * dik + 0.5 * djk + 0. * dij - 0.5 * (dik - djk).abs()
6363 dik. min ( djk)
6464 } ,
65- & Metrics :: Complete => {
65+ & Linkage :: Complete => {
6666 // 0.5 * dik + 0.5 * djk + 0. * dij + 0.5 * (dik - djk).abs()
6767 dik. max ( djk)
6868 } ,
69- & Metrics :: Average => {
69+ & Linkage :: Average => {
7070 let s = ci. size + cj. size ;
7171 ci. size / s * dik + cj. size / s * djk
7272 } ,
73- & Metrics :: Centroid => {
73+ & Linkage :: Centroid => {
7474 let s = ci. size + cj. size ;
7575 let ai = ci. size / s;
7676 let aj = cj. size / s;
7777 let dij = dmat. get ( ci. id , cj. id ) ;
7878 ai * dik + aj * djk - ai * aj * dij
7979 } ,
80- & Metrics :: Median => {
80+ & Linkage :: Median => {
8181 let dij = dmat. get ( ci. id , cj. id ) ;
8282 0.5 * dik + 0.5 * djk - 0.25 * dij
8383 } ,
84- & Metrics :: Ward1 => {
84+ & Linkage :: Ward1 => {
8585 let s = ci. size + cj. size + ck. size ;
8686 let dij = dmat. get ( ci. id , cj. id ) ;
8787 ( ci. size + ck. size ) / s * dik + ( cj. size + ck. size ) / s * djk - ck. size / s * dij
8888 } ,
89- & Metrics :: Ward | & Metrics :: Ward2 => {
89+ & Linkage :: Ward | & Linkage :: Ward2 => {
9090 let s = ci. size + cj. size + ck. size ;
9191 let dij = dmat. get ( ci. id , cj. id ) ;
92- ( ( ci. size + ck. size ) / s * dik. powf ( 2. ) + ( cj. size + ck. size ) / s * djk. powf ( 2. ) - ck. size / s * dij. powf ( 2. ) ) . sqrt ( )
92+ ( ( ci. size + ck. size ) / s * dik * dik + ( cj. size + ck. size ) / s * djk * djk - ck. size / s * dij * dij ) . sqrt ( )
9393 }
9494 }
9595 }
@@ -147,11 +147,11 @@ impl DistanceMatrix {
147147
148148 unsafe {
149149 for i in 0 ..n {
150- for j in i ..inputs. rows ( ) {
150+ for j in ( i + 1 ) ..inputs. rows ( ) {
151151 let mut val = 0. ;
152152 for k in 0 ..inputs. cols ( ) {
153- val += ( inputs. get_unchecked ( [ i, k] ) -
154- inputs . get_unchecked ( [ j , k ] ) ) . abs ( ) . powf ( 2. ) ;
153+ let d = inputs. get_unchecked ( [ i, k] ) - inputs . get_unchecked ( [ j , k ] ) ;
154+ val += d * d ;
155155 }
156156 val = val. sqrt ( ) ;
157157 data. insert ( ( i, j) , val) ;
@@ -177,13 +177,13 @@ impl DistanceMatrix {
177177 /// Add distance between i-th and j-th item
178178 /// i must be smaller than j
179179 fn insert ( & mut self , i : usize , j : usize , dist : f64 ) {
180- assert ! ( i < j, "i must be smaller than j" ) ;
180+ debug_assert ! ( i < j, "i must be smaller than j" ) ;
181181 self . data . insert ( ( i, j) , dist) ;
182182 }
183183
184184 /// Delete distance between i-th and j-th item
185185 fn delete ( & mut self , i : usize , j : usize ) {
186- assert ! ( i != j, "DistanceMatrix doesn't store distance when i == j, because it is 0.0" ) ;
186+ debug_assert ! ( i != j, "DistanceMatrix doesn't store distance when i == j, because it is 0.0" ) ;
187187 if i > j {
188188 self . data . remove ( & ( j, i) ) ;
189189 } else {
@@ -196,7 +196,7 @@ impl DistanceMatrix {
196196#[ derive( Debug ) ]
197197pub struct AgglomerativeClustering {
198198 n : usize ,
199- method : Metrics ,
199+ linkage : Linkage ,
200200
201201 // internally stores distances / merged history (currently for testing)
202202 distances : Option < Vec < f64 > > ,
@@ -208,19 +208,19 @@ impl AgglomerativeClustering {
208208 /// Constructs an untrained agglomerative clustering model with the specified parameters
209209 ///
210210 /// - `n` - Number of clusters
211- /// - `method ` - Distance metrics
211+ /// - `linkage ` - Linkage method
212212 ///
213213 /// # Examples
214214 ///
215215 /// ```
216- /// use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Metrics };
216+ /// use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Linkage };
217217 ///
218- /// let _ = AgglomerativeClustering::new(3, Metrics ::Single);
218+ /// let _ = AgglomerativeClustering::new(3, Linkage ::Single);
219219 /// ```
220- pub fn new ( n : usize , method : Metrics ) -> Self {
220+ pub fn new ( n : usize , linkage : Linkage ) -> Self {
221221 AgglomerativeClustering {
222222 n : n,
223- method : method ,
223+ linkage : linkage ,
224224
225225 distances : None ,
226226 merged : None
@@ -269,7 +269,7 @@ impl AgglomerativeClustering {
269269
270270 // update distances using the Lance-Williams update formula
271271 for ck in clusters. iter ( ) {
272- let d = self . method . dist ( & ci, & cj, ck, & dmat) ;
272+ let d = self . linkage . dist ( & ci, & cj, ck, & dmat) ;
273273 dmat. insert ( ck. id , id, d) ;
274274
275275 // remove unnecessary distances
@@ -301,7 +301,7 @@ impl AgglomerativeClustering {
301301#[ cfg( test) ]
302302mod tests {
303303
304- use super :: { AgglomerativeClustering , DistanceMatrix , Metrics } ;
304+ use super :: { AgglomerativeClustering , DistanceMatrix , Linkage } ;
305305
306306 #[ test]
307307 fn test_distance_matrix ( ) {
@@ -348,66 +348,66 @@ mod tests {
348348 55. , 65. , 80. , 75. , 85. ;
349349 90. , 85. , 88. , 92. , 95. ] ;
350350
351- let mut hclust = AgglomerativeClustering :: new ( 1 , Metrics :: Single ) ;
351+ let mut hclust = AgglomerativeClustering :: new ( 1 , Linkage :: Single ) ;
352352 let _ = hclust. train ( & data) ;
353353 let exp = vec ! [ 12.409673645990857 , 21.307275752662516 , 28.478061731796284 ,
354354 38.1051177665153 , 47.10626285325551 , 54.31390245600108 ] ;
355355 assert_eq ! ( hclust. distances. unwrap( ) , exp) ;
356356 let exp = vec ! [ ( 1 , 5 ) , ( 2 , 4 ) , ( 0 , 8 ) , ( 6 , 7 ) , ( 3 , 9 ) , ( 10 , 11 ) ] ;
357357 assert_eq ! ( hclust. merged. unwrap( ) , exp) ;
358358
359- let mut hclust = AgglomerativeClustering :: new ( 1 , Metrics :: Complete ) ;
359+ let mut hclust = AgglomerativeClustering :: new ( 1 , Linkage :: Complete ) ;
360360 let _ = hclust. train ( & data) ;
361361 let exp = vec ! [ 12.409673645990857 , 21.307275752662516 , 33.77869150810907 ,
362362 45.58508528016593 , 60.13318551349163 , 91.53141537199127 ] ;
363363 assert_eq ! ( hclust. distances. unwrap( ) , exp) ;
364364 let exp = vec ! [ ( 1 , 5 ) , ( 2 , 4 ) , ( 0 , 8 ) , ( 6 , 7 ) , ( 3 , 9 ) , ( 10 , 11 ) ] ;
365365 assert_eq ! ( hclust. merged. unwrap( ) , exp) ;
366366
367- let mut hclust = AgglomerativeClustering :: new ( 1 , Metrics :: Average ) ;
367+ let mut hclust = AgglomerativeClustering :: new ( 1 , Linkage :: Average ) ;
368368 let _ = hclust. train ( & data) ;
369369 let exp = vec ! [ 12.409673645990857 , 21.307275752662516 , 31.128376619952675 ,
370370 41.84510152334062 , 53.305905710336944 , 69.92295649225116 ] ;
371371 assert_eq ! ( hclust. distances. unwrap( ) , exp) ;
372372 let exp = vec ! [ ( 1 , 5 ) , ( 2 , 4 ) , ( 0 , 8 ) , ( 6 , 7 ) , ( 3 , 9 ) , ( 10 , 11 ) ] ;
373373 assert_eq ! ( hclust. merged. unwrap( ) , exp) ;
374374
375- let mut hclust = AgglomerativeClustering :: new ( 1 , Metrics :: Centroid ) ;
375+ let mut hclust = AgglomerativeClustering :: new ( 1 , Linkage :: Centroid ) ;
376376 let _ = hclust. train ( & data) ;
377377 let exp = vec ! [ 12.409673645990857 , 21.307275752662516 , 25.801557681787045 ,
378378 38.7426831118429 , 44.021013600051624 , 44.02758328256392 ] ;
379379 assert_eq ! ( hclust. distances. unwrap( ) , exp) ;
380380 let exp = vec ! [ ( 1 , 5 ) , ( 2 , 4 ) , ( 0 , 8 ) , ( 6 , 7 ) , ( 3 , 9 ) , ( 10 , 11 ) ] ;
381381 assert_eq ! ( hclust. merged. unwrap( ) , exp) ;
382382
383- let mut hclust = AgglomerativeClustering :: new ( 1 , Metrics :: Median ) ;
383+ let mut hclust = AgglomerativeClustering :: new ( 1 , Linkage :: Median ) ;
384384 let _ = hclust. train ( & data) ;
385385 let exp = vec ! [ 12.409673645990857 , 21.307275752662516 , 25.801557681787045 ,
386386 38.7426831118429 , 45.898926771596045 , 45.42216730738696 ] ;
387387 assert_eq ! ( hclust. distances. unwrap( ) , exp) ;
388388 let exp = vec ! [ ( 1 , 5 ) , ( 2 , 4 ) , ( 0 , 8 ) , ( 6 , 7 ) , ( 3 , 9 ) , ( 10 , 11 ) ] ;
389389 assert_eq ! ( hclust. merged. unwrap( ) , exp) ;
390390
391- let mut hclust = AgglomerativeClustering :: new ( 1 , Metrics :: Ward1 ) ;
391+ let mut hclust = AgglomerativeClustering :: new ( 1 , Linkage :: Ward1 ) ;
392392 let _ = hclust. train ( & data) ;
393393 let exp = vec ! [ 12.409673645990857 , 21.307275752662516 , 34.4020769090494 ,
394394 51.65691081579053 , 66.03152040007744 , 150.95171411164773 ] ;
395395 assert_eq ! ( hclust. distances. unwrap( ) , exp) ;
396396 let exp = vec ! [ ( 1 , 5 ) , ( 2 , 4 ) , ( 0 , 8 ) , ( 6 , 7 ) , ( 3 , 9 ) , ( 10 , 11 ) ] ;
397397 assert_eq ! ( hclust. merged. unwrap( ) , exp) ;
398398
399- let mut hclust = AgglomerativeClustering :: new ( 1 , Metrics :: Ward2 ) ;
399+ let mut hclust = AgglomerativeClustering :: new ( 1 , Linkage :: Ward2 ) ;
400400 let _ = hclust. train ( & data) ;
401401 let exp = vec ! [ 12.409673645990857 , 21.307275752662516 , 33.911649915626334 ,
402- 47.97916214358062 , 62.481997407253225 , 115.91869071527186 ] ;
402+ 47.97916214358062 , 62.48199740725323 , 115.91869071527186 ] ;
403403 assert_eq ! ( hclust. distances. unwrap( ) , exp) ;
404404 let exp = vec ! [ ( 1 , 5 ) , ( 2 , 4 ) , ( 0 , 8 ) , ( 6 , 7 ) , ( 3 , 9 ) , ( 10 , 11 ) ] ;
405405 assert_eq ! ( hclust. merged. unwrap( ) , exp) ;
406406
407- let mut hclust = AgglomerativeClustering :: new ( 1 , Metrics :: Ward ) ;
407+ let mut hclust = AgglomerativeClustering :: new ( 1 , Linkage :: Ward ) ;
408408 let _ = hclust. train ( & data) ;
409409 let exp = vec ! [ 12.409673645990857 , 21.307275752662516 , 33.911649915626334 ,
410- 47.97916214358062 , 62.481997407253225 , 115.91869071527186 ] ;
410+ 47.97916214358062 , 62.48199740725323 , 115.91869071527186 ] ;
411411 assert_eq ! ( hclust. distances. unwrap( ) , exp) ;
412412 let exp = vec ! [ ( 1 , 5 ) , ( 2 , 4 ) , ( 0 , 8 ) , ( 6 , 7 ) , ( 3 , 9 ) , ( 10 , 11 ) ] ;
413413 assert_eq ! ( hclust. merged. unwrap( ) , exp) ;
0 commit comments