Skip to content
This repository was archived by the owner on Jul 16, 2021. It is now read-only.

Commit 4c46d25

Browse files
committed
rename linkage, use debug_assert
1 parent 406cce0 commit 4c46d25

2 files changed

Lines changed: 44 additions & 44 deletions

File tree

src/learning/agglomerative.rs

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
//! # Usage
66
//!
77
//! ```
8-
//! use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Metrics};
8+
//! use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Linkage};
99
//! use rusty_machine::learning::SupModel;
1010
//! use rusty_machine::linalg::{Matrix, Vector};
1111
//!
1212
//! let inputs = Matrix::new(4, 2, vec![1., 3., 2., 3., 4., 3., 5., 3.]);
13-
//! let mut agg = AgglomerativeClustering::new(2, Metrics::Single);
13+
//! let mut agg = AgglomerativeClustering::new(2, Linkage::Single);
1414
//!
1515
//! // Train the model and get the clustering result
1616
//! let res = agg.train(&inputs).unwrap();
@@ -27,7 +27,7 @@ use learning::{LearningResult};
2727

2828
/// Agglomerative clustering distances
2929
#[derive(Debug)]
30-
pub enum Metrics {
30+
pub enum Linkage {
3131
/// Single linkage clustering
3232
Single,
3333
/// Complete linkage clustering
@@ -49,7 +49,7 @@ pub enum Metrics {
4949
Ward2,
5050
}
5151

52-
impl Metrics {
52+
impl Linkage {
5353

5454
// calculate distance using Lance-Williams algorithm
5555
fn dist(&self, ci: &Cluster, cj: &Cluster, ck: &Cluster, dmat: &DistanceMatrix) -> f64 {
@@ -58,38 +58,38 @@ impl Metrics {
5858
let djk = dmat.get(ck.id, cj.id);
5959

6060
match self {
61-
&Metrics::Single => {
61+
&Linkage::Single => {
6262
// 0.5 * dik + 0.5 * djk + 0. * dij - 0.5 * (dik - djk).abs()
6363
dik.min(djk)
6464
},
65-
&Metrics::Complete => {
65+
&Linkage::Complete => {
6666
// 0.5 * dik + 0.5 * djk + 0. * dij + 0.5 * (dik - djk).abs()
6767
dik.max(djk)
6868
},
69-
&Metrics::Average => {
69+
&Linkage::Average => {
7070
let s = ci.size + cj.size;
7171
ci.size / s * dik + cj.size / s * djk
7272
},
73-
&Metrics::Centroid => {
73+
&Linkage::Centroid => {
7474
let s = ci.size + cj.size;
7575
let ai = ci.size / s;
7676
let aj = cj.size / s;
7777
let dij = dmat.get(ci.id, cj.id);
7878
ai * dik + aj * djk - ai * aj * dij
7979
},
80-
&Metrics::Median => {
80+
&Linkage::Median => {
8181
let dij = dmat.get(ci.id, cj.id);
8282
0.5 * dik + 0.5 * djk - 0.25 * dij
8383
},
84-
&Metrics::Ward1 => {
84+
&Linkage::Ward1 => {
8585
let s = ci.size + cj.size + ck.size;
8686
let dij = dmat.get(ci.id, cj.id);
8787
(ci.size + ck.size) / s * dik + (cj.size + ck.size) / s * djk - ck.size / s * dij
8888
},
89-
&Metrics::Ward | &Metrics::Ward2 => {
89+
&Linkage::Ward | &Linkage::Ward2 => {
9090
let s = ci.size + cj.size + ck.size;
9191
let dij = dmat.get(ci.id, cj.id);
92-
((ci.size + ck.size) / s * dik.powf(2.) + (cj.size + ck.size) / s * djk.powf(2.) - ck.size / s * dij.powf(2.)).sqrt()
92+
((ci.size + ck.size) / s * dik * dik + (cj.size + ck.size) / s * djk * djk - ck.size / s * dij * dij).sqrt()
9393
}
9494
}
9595
}
@@ -147,11 +147,11 @@ impl DistanceMatrix {
147147

148148
unsafe {
149149
for i in 0..n {
150-
for j in i..inputs.rows() {
150+
for j in (i + 1)..inputs.rows() {
151151
let mut val = 0.;
152152
for k in 0..inputs.cols() {
153-
val += (inputs.get_unchecked([i, k]) -
154-
inputs.get_unchecked([j, k])).abs().powf(2.);
153+
let d = inputs.get_unchecked([i, k]) - inputs.get_unchecked([j, k]);
154+
val += d * d;
155155
}
156156
val = val.sqrt();
157157
data.insert((i, j), val);
@@ -177,13 +177,13 @@ impl DistanceMatrix {
177177
/// Add distance between i-th and j-th item
178178
/// i must be smaller than j
179179
fn insert(&mut self, i: usize, j: usize, dist: f64) {
180-
assert!(i < j, "i must be smaller than j");
180+
debug_assert!(i < j, "i must be smaller than j");
181181
self.data.insert((i, j), dist);
182182
}
183183

184184
/// Delete distance between i-th and j-th item
185185
fn delete(&mut self, i: usize, j: usize) {
186-
assert!(i != j, "DistanceMatrix doesn't store distance when i == j, because it is 0.0");
186+
debug_assert!(i != j, "DistanceMatrix doesn't store distance when i == j, because it is 0.0");
187187
if i > j {
188188
self.data.remove(&(j, i));
189189
} else {
@@ -196,7 +196,7 @@ impl DistanceMatrix {
196196
#[derive(Debug)]
197197
pub struct AgglomerativeClustering {
198198
n: usize,
199-
method: Metrics,
199+
linkage: Linkage,
200200

201201
// internally stores distances / merged history (currently for testing)
202202
distances: Option<Vec<f64>>,
@@ -208,19 +208,19 @@ impl AgglomerativeClustering {
208208
/// Constructs an untrained Decision Tree with specified
209209
///
210210
/// - `n` - Number of clusters
211-
/// - `method` - Distance metrics
211+
/// - `linkage` - Linkage method
212212
///
213213
/// # Examples
214214
///
215215
/// ```
216-
/// use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Metrics};
216+
/// use rusty_machine::learning::agglomerative::{AgglomerativeClustering, Linkage};
217217
///
218-
/// let _ = AgglomerativeClustering::new(3, Metrics::Single);
218+
/// let _ = AgglomerativeClustering::new(3, Linkage::Single);
219219
/// ```
220-
pub fn new(n: usize, method: Metrics) -> Self {
220+
pub fn new(n: usize, linkage: Linkage) -> Self {
221221
AgglomerativeClustering {
222222
n: n,
223-
method: method,
223+
linkage: linkage,
224224

225225
distances: None,
226226
merged: None
@@ -269,7 +269,7 @@ impl AgglomerativeClustering {
269269

270270
// update distances using Lance Williams algorithm
271271
for ck in clusters.iter() {
272-
let d = self.method.dist(&ci, &cj, ck, &dmat);
272+
let d = self.linkage.dist(&ci, &cj, ck, &dmat);
273273
dmat.insert(ck.id, id, d);
274274

275275
// remove unnecessary distances
@@ -301,7 +301,7 @@ impl AgglomerativeClustering {
301301
#[cfg(test)]
302302
mod tests {
303303

304-
use super::{AgglomerativeClustering, DistanceMatrix, Metrics};
304+
use super::{AgglomerativeClustering, DistanceMatrix, Linkage};
305305

306306
#[test]
307307
fn test_distance_matrix() {
@@ -348,66 +348,66 @@ mod tests {
348348
55., 65., 80., 75., 85.;
349349
90., 85., 88., 92., 95.];
350350

351-
let mut hclust = AgglomerativeClustering::new(1, Metrics::Single);
351+
let mut hclust = AgglomerativeClustering::new(1, Linkage::Single);
352352
let _ = hclust.train(&data);
353353
let exp = vec![12.409673645990857, 21.307275752662516, 28.478061731796284,
354354
38.1051177665153, 47.10626285325551, 54.31390245600108];
355355
assert_eq!(hclust.distances.unwrap(), exp);
356356
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
357357
assert_eq!(hclust.merged.unwrap(), exp);
358358

359-
let mut hclust = AgglomerativeClustering::new(1, Metrics::Complete);
359+
let mut hclust = AgglomerativeClustering::new(1, Linkage::Complete);
360360
let _ = hclust.train(&data);
361361
let exp = vec![12.409673645990857, 21.307275752662516, 33.77869150810907,
362362
45.58508528016593, 60.13318551349163, 91.53141537199127];
363363
assert_eq!(hclust.distances.unwrap(), exp);
364364
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
365365
assert_eq!(hclust.merged.unwrap(), exp);
366366

367-
let mut hclust = AgglomerativeClustering::new(1, Metrics::Average);
367+
let mut hclust = AgglomerativeClustering::new(1, Linkage::Average);
368368
let _ = hclust.train(&data);
369369
let exp = vec![12.409673645990857, 21.307275752662516, 31.128376619952675,
370370
41.84510152334062, 53.305905710336944, 69.92295649225116];
371371
assert_eq!(hclust.distances.unwrap(), exp);
372372
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
373373
assert_eq!(hclust.merged.unwrap(), exp);
374374

375-
let mut hclust = AgglomerativeClustering::new(1, Metrics::Centroid);
375+
let mut hclust = AgglomerativeClustering::new(1, Linkage::Centroid);
376376
let _ = hclust.train(&data);
377377
let exp = vec![12.409673645990857, 21.307275752662516, 25.801557681787045,
378378
38.7426831118429, 44.021013600051624, 44.02758328256392];
379379
assert_eq!(hclust.distances.unwrap(), exp);
380380
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
381381
assert_eq!(hclust.merged.unwrap(), exp);
382382

383-
let mut hclust = AgglomerativeClustering::new(1, Metrics::Median);
383+
let mut hclust = AgglomerativeClustering::new(1, Linkage::Median);
384384
let _ = hclust.train(&data);
385385
let exp = vec![12.409673645990857, 21.307275752662516, 25.801557681787045,
386386
38.7426831118429, 45.898926771596045, 45.42216730738696];
387387
assert_eq!(hclust.distances.unwrap(), exp);
388388
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
389389
assert_eq!(hclust.merged.unwrap(), exp);
390390

391-
let mut hclust = AgglomerativeClustering::new(1, Metrics::Ward1);
391+
let mut hclust = AgglomerativeClustering::new(1, Linkage::Ward1);
392392
let _ = hclust.train(&data);
393393
let exp = vec![12.409673645990857, 21.307275752662516, 34.4020769090494,
394394
51.65691081579053, 66.03152040007744, 150.95171411164773];
395395
assert_eq!(hclust.distances.unwrap(), exp);
396396
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
397397
assert_eq!(hclust.merged.unwrap(), exp);
398398

399-
let mut hclust = AgglomerativeClustering::new(1, Metrics::Ward2);
399+
let mut hclust = AgglomerativeClustering::new(1, Linkage::Ward2);
400400
let _ = hclust.train(&data);
401401
let exp = vec![12.409673645990857, 21.307275752662516, 33.911649915626334,
402-
47.97916214358062, 62.481997407253225, 115.91869071527186];
402+
47.97916214358062, 62.48199740725323, 115.91869071527186];
403403
assert_eq!(hclust.distances.unwrap(), exp);
404404
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
405405
assert_eq!(hclust.merged.unwrap(), exp);
406406

407-
let mut hclust = AgglomerativeClustering::new(1, Metrics::Ward);
407+
let mut hclust = AgglomerativeClustering::new(1, Linkage::Ward);
408408
let _ = hclust.train(&data);
409409
let exp = vec![12.409673645990857, 21.307275752662516, 33.911649915626334,
410-
47.97916214358062, 62.481997407253225, 115.91869071527186];
410+
47.97916214358062, 62.48199740725323, 115.91869071527186];
411411
assert_eq!(hclust.distances.unwrap(), exp);
412412
let exp = vec![(1, 5), (2, 4), (0, 8), (6, 7), (3, 9), (10, 11)];
413413
assert_eq!(hclust.merged.unwrap(), exp);

tests/learning/agglomerative.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use rm::linalg::{Matrix, Vector};
2-
use rm::learning::agglomerative::{AgglomerativeClustering, Metrics};
2+
use rm::learning::agglomerative::{AgglomerativeClustering, Linkage};
33

44
#[test]
55
fn test_cluster() {
@@ -11,42 +11,42 @@ fn test_cluster() {
1111
55., 65., 80., 75., 85.,
1212
90., 85., 88., 92., 95.]);
1313

14-
let mut hclust = AgglomerativeClustering::new(3, Metrics::Single);
14+
let mut hclust = AgglomerativeClustering::new(3, Linkage::Single);
1515
let res = hclust.train(&data);
1616
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
1717
assert_eq!(res.unwrap(), exp);
1818

19-
let mut hclust = AgglomerativeClustering::new(3, Metrics::Complete);
19+
let mut hclust = AgglomerativeClustering::new(3, Linkage::Complete);
2020
let res = hclust.train(&data);
2121
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
2222
assert_eq!(res.unwrap(), exp);
2323

24-
let mut hclust = AgglomerativeClustering::new(3, Metrics::Average);
24+
let mut hclust = AgglomerativeClustering::new(3, Linkage::Average);
2525
let res = hclust.train(&data);
2626
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
2727
assert_eq!(res.unwrap(), exp);
2828

29-
let mut hclust = AgglomerativeClustering::new(3, Metrics::Centroid);
29+
let mut hclust = AgglomerativeClustering::new(3, Linkage::Centroid);
3030
let res = hclust.train(&data);
3131
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
3232
assert_eq!(res.unwrap(), exp);
3333

34-
let mut hclust = AgglomerativeClustering::new(3, Metrics::Median);
34+
let mut hclust = AgglomerativeClustering::new(3, Linkage::Median);
3535
let res = hclust.train(&data);
3636
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
3737
assert_eq!(res.unwrap(), exp);
3838

39-
let mut hclust = AgglomerativeClustering::new(3, Metrics::Ward1);
39+
let mut hclust = AgglomerativeClustering::new(3, Linkage::Ward1);
4040
let res = hclust.train(&data);
4141
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
4242
assert_eq!(res.unwrap(), exp);
4343

44-
let mut hclust = AgglomerativeClustering::new(3, Metrics::Ward2);
44+
let mut hclust = AgglomerativeClustering::new(3, Linkage::Ward2);
4545
let res = hclust.train(&data);
4646
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
4747
assert_eq!(res.unwrap(), exp);
4848

49-
let mut hclust = AgglomerativeClustering::new(3, Metrics::Ward);
49+
let mut hclust = AgglomerativeClustering::new(3, Linkage::Ward);
5050
let res = hclust.train(&data);
5151
let exp = Vector::new(vec![1, 2, 1, 0, 1, 2, 2]);
5252
assert_eq!(res.unwrap(), exp);

0 commit comments

Comments
 (0)