diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala index 61091bb803e49..fbb2a0656e575 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala @@ -37,6 +37,7 @@ private[spark] class DTStatsAggregator( val impurityAggregator: ImpurityAggregator = metadata.impurity match { case Gini => new GiniAggregator(metadata.numClasses) case Entropy => new EntropyAggregator(metadata.numClasses) + case ChiSquared => new ChiSquaredAggregator(metadata.numClasses) case Variance => new VarianceAggregator() case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 82e1ed85a0a14..277d4e15f0912 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -29,7 +29,7 @@ import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ import org.apache.spark.ml.util.Instrumentation import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impurity.ImpurityCalculator +import org.apache.spark.mllib.tree.impurity.{Impurity, ImpurityCalculator} import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -670,14 +670,32 @@ private[spark] object RandomForest extends Logging { val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() - val leftWeight = leftCount / totalCount.toDouble - val rightWeight = rightCount / 
totalCount.toDouble + val gain = metadata.impurity match { + case imp if (imp.isTestStatistic) => + // For split quality measures based on a test-statistic, run the test on the + // left and right sub-populations to get a p-value for the null hypothesis + val pval = imp.calculate(leftImpurityCalculator, rightImpurityCalculator) + // Transform the test statistic p-val into a larger-is-better gain value + Impurity.pValToGain(pval) + + case _ => + // Default purity-gain logic: + // measure the weighted decrease in impurity from parent to the left and right + val leftWeight = leftCount / totalCount.toDouble + val rightWeight = rightCount / totalCount.toDouble + + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + // If the impurity being used is a test statistic p-val, apply a standard transform into + // a larger-is-better gain value for the minimum-gain threshold + val minGain = + if (metadata.impurity.isTestStatistic) Impurity.pValToGain(metadata.minInfoGain) + else metadata.minInfoGain // if information gain doesn't satisfy minimum information gain, // then this split is invalid, return invalid information gain stats. 
- if (gain < metadata.minInfoGain) { + if (gain < minGain) { return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 3fc3ac58b7795..924d7676eaa46 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} +import org.apache.spark.mllib.tree.impurity.{ChiSquared => OldChiSquared, Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, ClassificationLoss => OldClassificationLoss, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError} import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -213,7 +213,7 @@ private[ml] trait TreeClassifierParams extends Params { /** * Criterion used for information gain calculation (case-insensitive). - * Supported: "entropy" and "gini". + * Supported: "entropy", "gini", "chisquared". * (default = gini) * @group param */ @@ -240,6 +240,7 @@ private[ml] trait TreeClassifierParams extends Params { getImpurity match { case "entropy" => OldEntropy case "gini" => OldGini + case "chisquared" => OldChiSquared case _ => // Should never happen because of check in setter method. throw new RuntimeException( @@ -251,7 +252,7 @@ private[ml] trait TreeClassifierParams extends Params { private[ml] object TreeClassifierParams { // These options should be lowercase. 
final val supportedImpurities: Array[String] = - Array("entropy", "gini").map(_.toLowerCase(Locale.ROOT)) + Array("entropy", "gini", "chisquared").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait DecisionTreeClassifierParams diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 58e8f5be7b9f0..f6e736d9bc9e6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Since import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.impurity.{ChiSquared, Entropy, Gini, Impurity, Variance} /** * Stores all the configuration options for tree construction @@ -140,7 +140,7 @@ class Strategy @Since("1.3.0") ( require(numClasses >= 2, s"DecisionTree Strategy for Classification must have numClasses >= 2," + s" but numClasses = $numClasses.") - require(Set(Gini, Entropy).contains(impurity), + require(Set(Gini, Entropy, ChiSquared).contains(impurity), s"DecisionTree Strategy given invalid impurity for Classification: $impurity." + s" Valid settings: Gini, Entropy") case Regression => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ChiSquared.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ChiSquared.scala new file mode 100644 index 0000000000000..d331ab550911a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/ChiSquared.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} + +/** + * :: Experimental :: + * Class for calculating Chi Squared as a split quality metric during binary classification. + */ +@Since("2.2.0") +@Experimental +object ChiSquared extends Impurity { + private object CSTest extends org.apache.commons.math3.stat.inference.ChiSquareTest() + + /** + * Get this impurity instance. + * This is useful for passing impurity parameters to a Strategy in Java. + */ + @Since("2.2.0") + def instance: this.type = this + + /** + * :: DeveloperApi :: + * Placeholder definition of classification-based purity. + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return This method will throw an exception for [[ChiSquared]] + */ + @Since("2.2.0") + @DeveloperApi + override def calculate(counts: Array[Double], totalCount: Double): Double = + throw new UnsupportedOperationException("ChiSquared.calculate") + + /** + * :: DeveloperApi :: + * Placeholder definition of regression-based purity.
+ * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + * @return This method will throw an exception for [[ChiSquared]] + */ + @Since("2.2.0") + @DeveloperApi + override def calculate(count: Double, sum: Double, sumSquares: Double): Double = + throw new UnsupportedOperationException("ChiSquared.calculate") + + /** + * :: DeveloperApi :: + * Chi-squared p-values from [[ImpurityCalculator]] for left and right split populations + * @param calcL impurity calculator for the left split population + * @param calcR impurity calculator for the right split population + * @return The p-value for the chi squared null hypothesis; that left and right split populations + * represent the same distribution of categorical values + */ + @Since("2.2.0") + @DeveloperApi + override def calculate(calcL: ImpurityCalculator, calcR: ImpurityCalculator): Double = { + CSTest.chiSquareTest( + Array( + calcL.stats.map(_.toLong), + calcR.stats.map(_.toLong) + ) + ) + } + + /** + * :: DeveloperApi :: + * Determine if this impurity measure is a test-statistic measure (true for Chi-squared) + * @return For [[ChiSquared]] will return true + */ + @Since("2.2.0") + @DeveloperApi + override def isTestStatistic: Boolean = true +} + +/** + * Class for updating views of a vector of sufficient statistics, + * in order to compute impurity from a sample. + * Note: Instances of this class do not hold the data; they operate on views of the data. + * @param numClasses Number of classes for label. + */ +private[spark] class ChiSquaredAggregator(numClasses: Int) + extends ImpurityAggregator(numClasses) with Serializable { + + /** + * Update stats for one (node, feature, bin) with the given label. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). 
+ */ + def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { + allStats(offset + label.toInt) += instanceWeight + } + + /** + * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def getCalculator(allStats: Array[Double], offset: Int): ChiSquaredCalculator = { + new ChiSquaredCalculator(allStats.view(offset, offset + statsSize).toArray) + } +} + +/** + * Stores statistics for one (node, feature, bin) for calculating impurity. + * This class stores its own data and is for a specific (node, feature, bin). + * @param stats Array of sufficient statistics for a (node, feature, bin). + */ +private[spark] class ChiSquaredCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { + + /** + * Make a deep copy of this [[ImpurityCalculator]]. + */ + def copy: ChiSquaredCalculator = new ChiSquaredCalculator(stats.clone()) + + /** + * Calculate the impurity from the stored sufficient statistics. + */ + def calculate(): Double = 1.0 + + /** + * Number of data points accounted for in the sufficient statistics. + */ + def count: Long = stats.sum.toLong + + /** + * Prediction which should be made based on the sufficient statistics. + */ + def predict: Double = + if (count == 0) 0 else indexOfLargestArrayElement(stats) + + /** + * Probability of the label given by [[predict]]. 
+ */ + override def prob(label: Double): Double = { + val lbl = label.toInt + require(lbl < stats.length, + s"ChiSquaredCalculator.prob given invalid label: $lbl (should be < ${stats.length})") + require(lbl >= 0, "ChiSquaredImpurity does not support negative labels") + val cnt = count + if (cnt == 0) 0 else (stats(lbl) / cnt) + } + + /** output in a string format */ + override def toString: String = s"ChiSquaredCalculator(stats = [${stats.mkString(", ")}])" +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index d4448da9eef51..b1dd1006b3151 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -74,6 +74,23 @@ object Entropy extends Impurity { @Since("1.1.0") def instance: this.type = this + + /** + * :: DeveloperApi :: + * p-values for test-statistic measures, unsupported for [[Entropy]] + */ + @Since("2.2.0") + @DeveloperApi + def calculate(calcL: ImpurityCalculator, calcR: ImpurityCalculator): Double = + throw new UnsupportedOperationException("Entropy.calculate") + + /** + * :: DeveloperApi :: + * Determine if this impurity measure is a test-statistic measure + * @return For [[Entropy]] will return false + */ + @Since("2.2.0") + @DeveloperApi + def isTestStatistic: Boolean = false } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index c5e34ffa4f2e5..209abd42d275a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -71,6 +71,23 @@ object Gini extends Impurity { @Since("1.1.0") def instance: this.type = this + + /** + * :: DeveloperApi :: + * p-values for test-statistic measures, unsupported for [[Gini]] + */ + @Since("2.2.0") + 
@DeveloperApi + def calculate(calcL: ImpurityCalculator, calcR: ImpurityCalculator): Double = + throw new UnsupportedOperationException("Gini.calculate") + + /** + * :: DeveloperApi :: + * Determine if this impurity measure is a test-statistic measure + * @return For [[Gini]] will return false + */ + @Since("2.2.0") + @DeveloperApi + def isTestStatistic: Boolean = false } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala index 9a6452aa13a61..a680f88ab3cf7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala @@ -26,6 +26,7 @@ private[mllib] object Impurities { case "gini" => Gini case "entropy" => Entropy case "variance" => Variance + case "chisquared" => ChiSquared case _ => throw new IllegalArgumentException(s"Did not recognize Impurity name: $name") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 4c7746869dde1..d886c3029d42c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -52,6 +52,49 @@ trait Impurity extends Serializable { @Since("1.0.0") @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double + + /** + * :: DeveloperApi :: + * Compute a test-statistic p-value quality measure from left and right split populations + * @param calcL impurity calculator for the left split population + * @param calcR impurity calculator for the right split population + * @return The p-value for the null hypothesis; that left and right split populations + * represent the same distribution + * @note Unless overridden this method will fail with an exception, for backward compatibility + */ + 
@Since("2.2.0") + @DeveloperApi + def calculate(calcL: ImpurityCalculator, calcR: ImpurityCalculator): Double + + /** + * :: DeveloperApi :: + * Determine if this impurity measure is a test-statistic measure + * @return True if this is a split quality measure based on a test statistic (i.e. returns a + * p-value) or false otherwise. + * @note Unless overridden this method returns false by default, for backward compatibility + */ + @Since("2.2.0") + @DeveloperApi + def isTestStatistic: Boolean +} + +/** + * :: DeveloperApi :: + * Utility functions for Impurity measures + */ +@Since("2.2.0") +@DeveloperApi +object Impurity { + /** + * :: DeveloperApi :: + * Convert a test-statistic p-value into a "larger-is-better" gain value. + * @param pval The test statistic p-value + * @return The negative logarithm of the p-value. Any p-values smaller than 10^-20 are clipped + * to 10^-20 to prevent arithmetic errors + */ + @Since("2.2.0") + @DeveloperApi + def pValToGain(pval: Double): Double = -math.log(math.max(1e-20, pval)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index c9bf0db4de3c2..7951183e0b4b2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -62,6 +62,23 @@ object Variance extends Impurity { @Since("1.0.0") def instance: this.type = this + + /** + * :: DeveloperApi :: + * p-values for test-statistic measures, unsupported for [[Variance]] + */ + @Since("2.2.0") + @DeveloperApi + def calculate(calcL: ImpurityCalculator, calcR: ImpurityCalculator): Double = + throw new UnsupportedOperationException("Variance.calculate") + + /** + * :: DeveloperApi :: + * Determine if this impurity measure is a test-statistic measure + * @return For [[Variance]] will return false + */ + @Since("2.2.0") + @DeveloperApi + def isTestStatistic: Boolean = 
false } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 918ab27e2730b..cfbd5ba977d9f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -237,6 +237,41 @@ class DecisionTreeClassifierSuite compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) } + test("split quality using chi-squared and minimum gain") { + // Generate a data set where the 1st feature is useful and the others are noise + val features = Vector.fill(200) { + Array.fill(3) { scala.util.Random.nextInt(2).toDouble } + } + val labels = features.map { fv => + LabeledPoint(if (fv(0) == 1.0) 1.0 else 0.0, Vectors.dense(fv)) + } + val rdd = sc.parallelize(labels) + + // two-class learning problem + val numClasses = 2 + // all binary features + val catFeatures = Map(Vector.tabulate(features.head.length) { j => (j, 2) } : _*) + + // Chi-squared split quality with a p-value threshold of 0.01 should allow + // only the first feature to be used since the others are uncorrelated noise + val train: DataFrame = TreeTests.setMetadata(rdd, catFeatures, numClasses) + val dt = new DecisionTreeClassifier() + .setImpurity("chisquared") + .setMaxDepth(5) + .setMinInfoGain(0.01) + val treeModel = dt.fit(train) + + // The tree should use exactly one of the 3 features: feature(0) + val featImps = treeModel.featureImportances + assert(treeModel.depth === 1) + assert(featImps.size === 3) + assert(featImps(0) === 1.0) + assert(featImps(1) === 0.0) + assert(featImps(2) === 0.0) + + compareAPIs(rdd, dt, catFeatures, numClasses) + } + test("predictRaw and predictProbability") { val rdd = continuousDataPointsForMulticlassRDD val dt = new DecisionTreeClassifier() diff --git a/project/MimaExcludes.scala 
b/project/MimaExcludes.scala index 1793da03a2c3e..391e0a718194e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -139,6 +139,10 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkMasterRegex.MESOS_REGEX"), // [SPARK-16240] ML persistence backward compatibility for LDA ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$"), + // [SPARK-15699][ML] Add chi-squared test statistic as a split quality metric + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.tree.impurity.ChiSquared.calculate"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.tree.impurity.Impurity.calculate"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.tree.impurity.Impurity.isTestStatistic"), // [SPARK-17717] Add Find and Exists method to Catalog. ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getDatabase"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getTable"), diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 4af6f71e19257..525fe94693a02 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -664,7 +664,7 @@ class TreeClassifierParams(object): .. versionadded:: 1.4.0 """ - supportedImpurities = ["entropy", "gini"] + supportedImpurities = ["entropy", "gini", "chisquared"] impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 619fa16d463f5..9fa0838b6ae33 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -162,7 +162,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, indexed from 0: {0, 1, ..., k-1}. 
:param impurity: Criterion used for information gain calculation. - Supported values: "gini" or "entropy". + Supported values: "gini", "entropy" or "chisquared". (default: "gini") :param maxDepth: Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1