001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.stat.clustering;
019    
020    import java.util.ArrayList;
021    import java.util.Collection;
022    import java.util.List;
023    import java.util.Random;
024    
025    /**
026     * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
027     * @param <T> type of the points to cluster
028     * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
029     * @version $Revision: 811685 $ $Date: 2009-09-05 13:36:48 -0400 (Sat, 05 Sep 2009) $
030     * @since 2.0
031     */
032    public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
033    
034        /** Random generator for choosing initial centers. */
035        private final Random random;
036    
037        /** Build a clusterer.
038         * @param random random generator to use for choosing initial centers
039         */
040        public KMeansPlusPlusClusterer(final Random random) {
041            this.random = random;
042        }
043    
044        /**
045         * Runs the K-means++ clustering algorithm.
046         *
047         * @param points the points to cluster
048         * @param k the number of clusters to split the data into
049         * @param maxIterations the maximum number of iterations to run the algorithm
050         *     for.  If negative, no maximum will be used
051         * @return a list of clusters containing the points
052         */
053        public List<Cluster<T>> cluster(final Collection<T> points,
054                                        final int k, final int maxIterations) {
055            // create the initial clusters
056            List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
057            assignPointsToClusters(clusters, points);
058    
059            // iterate through updating the centers until we're done
060            final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
061            for (int count = 0; count < max; count++) {
062                boolean clusteringChanged = false;
063                List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
064                for (final Cluster<T> cluster : clusters) {
065                    final T newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
066                    if (!newCenter.equals(cluster.getCenter())) {
067                        clusteringChanged = true;
068                    }
069                    newClusters.add(new Cluster<T>(newCenter));
070                }
071                if (!clusteringChanged) {
072                    return clusters;
073                }
074                assignPointsToClusters(newClusters, points);
075                clusters = newClusters;
076            }
077            return clusters;
078        }
079    
080        /**
081         * Adds the given points to the closest {@link Cluster}.
082         *
083         * @param <T> type of the points to cluster
084         * @param clusters the {@link Cluster}s to add the points to
085         * @param points the points to add to the given {@link Cluster}s
086         */
087        private static <T extends Clusterable<T>> void
088            assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) {
089            for (final T p : points) {
090                Cluster<T> cluster = getNearestCluster(clusters, p);
091                cluster.addPoint(p);
092            }
093        }
094    
095        /**
096         * Use K-means++ to choose the initial centers.
097         *
098         * @param <T> type of the points to cluster
099         * @param points the points to choose the initial centers from
100         * @param k the number of centers to choose
101         * @param random random generator to use
102         * @return the initial centers
103         */
104        private static <T extends Clusterable<T>> List<Cluster<T>>
105            chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
106    
107            final List<T> pointSet = new ArrayList<T>(points);
108            final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
109    
110            // Choose one center uniformly at random from among the data points.
111            final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
112            resultSet.add(new Cluster<T>(firstPoint));
113    
114            final double[] dx2 = new double[pointSet.size()];
115            while (resultSet.size() < k) {
116                // For each data point x, compute D(x), the distance between x and
117                // the nearest center that has already been chosen.
118                int sum = 0;
119                for (int i = 0; i < pointSet.size(); i++) {
120                    final T p = pointSet.get(i);
121                    final Cluster<T> nearest = getNearestCluster(resultSet, p);
122                    final double d = p.distanceFrom(nearest.getCenter());
123                    sum += d * d;
124                    dx2[i] = sum;
125                }
126    
127                // Add one new data point as a center. Each point x is chosen with
128                // probability proportional to D(x)2
129                final double r = random.nextDouble() * sum;
130                for (int i = 0 ; i < dx2.length; i++) {
131                    if (dx2[i] >= r) {
132                        final T p = pointSet.remove(i);
133                        resultSet.add(new Cluster<T>(p));
134                        break;
135                    }
136                }
137            }
138    
139            return resultSet;
140    
141        }
142    
143        /**
144         * Returns the nearest {@link Cluster} to the given point
145         *
146         * @param <T> type of the points to cluster
147         * @param clusters the {@link Cluster}s to search
148         * @param point the point to find the nearest {@link Cluster} for
149         * @return the nearest {@link Cluster} to the given point
150         */
151        private static <T extends Clusterable<T>> Cluster<T>
152            getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
153            double minDistance = Double.MAX_VALUE;
154            Cluster<T> minCluster = null;
155            for (final Cluster<T> c : clusters) {
156                final double distance = point.distanceFrom(c.getCenter());
157                if (distance < minDistance) {
158                    minDistance = distance;
159                    minCluster = c;
160                }
161            }
162            return minCluster;
163        }
164    
165    }