
[SPARK-15699] [ML] Implement a Chi-Squared test statistic option for measuring split quality #13440

Closed
@@ -37,6 +37,7 @@ private[spark] class DTStatsAggregator(
val impurityAggregator: ImpurityAggregator = metadata.impurity match {
case Gini => new GiniAggregator(metadata.numClasses)
case Entropy => new EntropyAggregator(metadata.numClasses)
case ChiSquared => new ChiSquaredAggregator(metadata.numClasses)
case Variance => new VarianceAggregator()
case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
}
@@ -29,7 +29,7 @@ import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
import org.apache.spark.ml.util.Instrumentation
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.impurity.{Impurity, ImpurityCalculator}
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -670,14 +670,32 @@ private[spark] object RandomForest extends Logging {
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()

val leftWeight = leftCount / totalCount.toDouble
val rightWeight = rightCount / totalCount.toDouble
val gain = metadata.impurity match {
case imp if (imp.isTestStatistic) =>
// For split quality measures based on a test-statistic, run the test on the
// left and right sub-populations to get a p-value for the null hypothesis
val pval = imp.calculate(leftImpurityCalculator, rightImpurityCalculator)
// Transform the test statistic p-val into a larger-is-better gain value
Impurity.pValToGain(pval)

case _ =>
// Default purity-gain logic:
// measure the weighted decrease in impurity from parent to the left and right
val leftWeight = leftCount / totalCount.toDouble
val rightWeight = rightCount / totalCount.toDouble

impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
}

val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
// If the impurity being used is a test statistic p-val, apply a standard transform into
// a larger-is-better gain value for the minimum-gain threshold
val minGain =
if (metadata.impurity.isTestStatistic) Impurity.pValToGain(metadata.minInfoGain)
else metadata.minInfoGain
Member: Kind of a design question here... right now the caller has to switch logic based on what's inside metadata. Can methods like metadata.minInfoGain just implement different logic when the impurity is a test statistic, and so on? Push this down towards the impurity implementation? I wonder if isTestStatistic can go away with the right API, but I am not familiar with the details of what that requires.

Contributor Author: The main issue I recall was that all of the existing metrics assume some kind of "larger is better" gain, and p-values are "smaller is better." I'll take another pass over it and see if I can push that distinction down so it doesn't require exposing new methods.


// if information gain doesn't satisfy minimum information gain,
// then this split is invalid, return invalid information gain stats.
if (gain < metadata.minInfoGain) {
if (gain < minGain) {
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
}
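The two branches in this hunk can be sketched outside Spark. The following is an illustrative, self-contained Python sketch (not Spark code; `gini`, `weighted_gain`, and `chi2_pvalue_2x2` are hypothetical stand-ins) contrasting the default weighted-impurity gain with the test-statistic path, using the same clip-and-negative-log transform as `Impurity.pValToGain`:

```python
import math

def gini(counts):
    # Gini impurity of a label-count vector.
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)

def weighted_gain(parent, left, right):
    # Default purity-gain branch: weighted decrease in impurity
    # from parent to the left and right children.
    n_l, n_r = sum(left), sum(right)
    n = n_l + n_r
    return gini(parent) - (n_l / n) * gini(left) - (n_r / n) * gini(right)

def chi2_pvalue_2x2(left, right):
    # Test-statistic branch for two classes: Pearson chi-squared test on
    # the 2x2 table [left; right]; with 1 dof the survival function of
    # the chi-squared distribution is erfc(sqrt(x / 2)).
    table = [left, right]
    rows = [sum(r) for r in table]
    cols = [sum(c) for c in zip(*table)]
    n = sum(rows)
    stat = sum((table[i][j] - rows[i] * cols[j] / n) ** 2 / (rows[i] * cols[j] / n)
               for i in range(2) for j in range(2))
    return math.erfc(math.sqrt(stat / 2.0))

def p_val_to_gain(pval):
    # Same clip-and-negate-log transform as Impurity.pValToGain.
    return -math.log(max(1e-20, pval))

left, right = [30.0, 10.0], [10.0, 30.0]
parent = [l + r for l, r in zip(left, right)]
print(weighted_gain(parent, left, right))            # default gain branch
print(p_val_to_gain(chi2_pvalue_2x2(left, right)))   # test-statistic branch
```

Both branches end up producing a larger-is-better number, which is why the p-value path can reuse the existing `gain < minGain` comparison once the threshold is transformed the same way.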

@@ -26,7 +26,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
import org.apache.spark.mllib.tree.impurity.{ChiSquared => OldChiSquared, Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, ClassificationLoss => OldClassificationLoss, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}

@@ -213,7 +213,7 @@ private[ml] trait TreeClassifierParams extends Params {

/**
* Criterion used for information gain calculation (case-insensitive).
* Supported: "entropy" and "gini".
* Supported: "entropy", "gini", "chisquared".
* (default = gini)
* @group param
*/
@@ -240,6 +240,7 @@ private[ml] trait TreeClassifierParams extends Params {
getImpurity match {
case "entropy" => OldEntropy
case "gini" => OldGini
case "chisquared" => OldChiSquared
case _ =>
// Should never happen because of check in setter method.
throw new RuntimeException(
@@ -251,7 +252,7 @@ private[ml] trait TreeClassifierParams {
private[ml] object TreeClassifierParams {
// These options should be lowercase.
final val supportedImpurities: Array[String] =
Array("entropy", "gini").map(_.toLowerCase(Locale.ROOT))
Array("entropy", "gini", "chisquared").map(_.toLowerCase(Locale.ROOT))
}

private[ml] trait DecisionTreeClassifierParams
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
import org.apache.spark.mllib.tree.impurity.{ChiSquared, Entropy, Gini, Impurity, Variance}

/**
* Stores all the configuration options for tree construction
@@ -140,7 +140,7 @@ class Strategy @Since("1.3.0") (
require(numClasses >= 2,
s"DecisionTree Strategy for Classification must have numClasses >= 2," +
s" but numClasses = $numClasses.")
require(Set(Gini, Entropy).contains(impurity),
require(Set(Gini, Entropy, ChiSquared).contains(impurity),
s"DecisionTree Strategy given invalid impurity for Classification: $impurity." +
s" Valid settings: Gini, Entropy, ChiSquared")
case Regression =>
@@ -0,0 +1,162 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.impurity

import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}

/**
* :: Experimental ::
* Class for calculating chi-squared as a split-quality metric during binary classification.
*/
@Since("2.2.0")
Member: This will have to be 2.5.0 for the moment.

Contributor Author: I'll update those. 3.0 might be a good target, especially if I can't do this without the new isTestStatistic.

@Experimental
object ChiSquared extends Impurity {
private object CSTest extends org.apache.commons.math3.stat.inference.ChiSquareTest()

/**
* Get this impurity instance.
* This is useful for passing impurity parameters to a Strategy in Java.
*/
@Since("1.1.0")
Member: I think I'd label all these as Since 2.5.0 even if they override a method that existed earlier.

def instance: this.type = this

/**
* :: DeveloperApi ::
* Placeholder definition of classification-based purity.
* @param counts Array[Double] with counts for each label
* @param totalCount sum of counts for all labels
* @return This method will throw an exception for [[ChiSquared]]
*/
@Since("1.1.0")
@DeveloperApi
override def calculate(counts: Array[Double], totalCount: Double): Double =
throw new UnsupportedOperationException("ChiSquared.calculate")

/**
* :: DeveloperApi ::
* Placeholder definition of regression-based purity.
* @param count number of instances
* @param sum sum of labels
* @param sumSquares summation of squares of the labels
* @return This method will throw an exception for [[ChiSquared]]
*/
@Since("1.0.0")
@DeveloperApi
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("ChiSquared.calculate")

/**
* :: DeveloperApi ::
* Chi-squared p-values from [[ImpurityCalculator]] for left and right split populations
* @param calcL impurity calculator for the left split population
* @param calcR impurity calculator for the right split population
* @return The p-value for the chi-squared null hypothesis that the left and right split
* populations represent the same distribution of categorical values
*/
@Since("2.0.0")
@DeveloperApi
override def calculate(calcL: ImpurityCalculator, calcR: ImpurityCalculator): Double = {
CSTest.chiSquareTest(
Array(
calcL.stats.map(_.toLong),
calcR.stats.map(_.toLong)
)
)
}

/**
* :: DeveloperApi ::
* Determine if this impurity measure is a test-statistic measure (true for Chi-squared)
* @return For [[ChiSquared]] will return true
*/
@Since("2.0.0")
@DeveloperApi
override def isTestStatistic: Boolean = true
}

/**
* Class for updating views of a vector of sufficient statistics,
* in order to compute impurity from a sample.
* Note: Instances of this class do not hold the data; they operate on views of the data.
* @param numClasses Number of classes for label.
*/
private[spark] class ChiSquaredAggregator(numClasses: Int)
extends ImpurityAggregator(numClasses) with Serializable {

/**
* Update stats for one (node, feature, bin) with the given label.
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
allStats(offset + label.toInt) += instanceWeight
}

/**
* Get an [[ImpurityCalculator]] for a (node, feature, bin).
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
def getCalculator(allStats: Array[Double], offset: Int): ChiSquaredCalculator = {
new ChiSquaredCalculator(allStats.view(offset, offset + statsSize).toArray)
}
}

/**
* Stores statistics for one (node, feature, bin) for calculating impurity.
* This class stores its own data and is for a specific (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
private[spark] class ChiSquaredCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {

/**
* Make a deep copy of this [[ImpurityCalculator]].
*/
def copy: ChiSquaredCalculator = new ChiSquaredCalculator(stats.clone())

/**
* Calculate the impurity from the stored sufficient statistics.
*/
def calculate(): Double = 1.0

/**
* Number of data points accounted for in the sufficient statistics.
*/
def count: Long = stats.sum.toLong

/**
* Prediction which should be made based on the sufficient statistics.
*/
def predict: Double =
if (count == 0) 0 else indexOfLargestArrayElement(stats)

/**
* Probability of the label given by [[predict]].
*/
override def prob(label: Double): Double = {
val lbl = label.toInt
require(lbl < stats.length,
s"ChiSquaredCalculator.prob given invalid label: $lbl (should be < ${stats.length})")
require(lbl >= 0, "ChiSquaredImpurity does not support negative labels")
val cnt = count
if (cnt == 0) 0 else (stats(lbl) / cnt)
}

/** output in a string format */
override def toString: String = s"ChiSquaredCalculator(stats = [${stats.mkString(", ")}])"
}
@@ -74,6 +74,23 @@ object Entropy extends Impurity {
@Since("1.1.0")
def instance: this.type = this

/**
* :: DeveloperApi ::
* p-values for test-statistic measures, unsupported for [[Entropy]]
*/
@Since("2.2.0")
@DeveloperApi
def calculate(calcL: ImpurityCalculator, calcR: ImpurityCalculator): Double =
throw new UnsupportedOperationException("Entropy.calculate")

/**
* :: DeveloperApi ::
* Determine if this impurity measure is a test-statistic measure
* @return For [[Entropy]] will return false
*/
@Since("2.2.0")
@DeveloperApi
def isTestStatistic: Boolean = false
}

/**
@@ -71,6 +71,23 @@ object Gini extends Impurity {
@Since("1.1.0")
def instance: this.type = this

/**
* :: DeveloperApi ::
* p-values for test-statistic measures, unsupported for [[Gini]]
*/
@Since("2.2.0")
@DeveloperApi
def calculate(calcL: ImpurityCalculator, calcR: ImpurityCalculator): Double =
Member: It looks like this new method doesn't make sense to implement for existing implementations, only the new one. That kind of suggests to me it isn't part of the generic API for an impurity. Is this really something that belongs inside the logic of the implementations? Maybe there's a more general method that needs to be exposed, that can then be specialized for all implementations.

Contributor Author: I'll consider if there's a unifying idea here. p-value-based metrics require integrating information across the new split children, which I believe was not the case for existing methods.

Contributor Author: I suspect that the generalization is closer to my newer signature, `val pval = imp.calculate(leftImpurityCalculator, rightImpurityCalculator)`, where you have all the context from the left and right nodes. The existing gain-based calculation should fit into this framework, just doing its current weighted average of purity gain.

Contributor Author: @srowen @willb I cached the design of the metrics back in. In general, Impurity already uses methods that are only defined on certain impurity sub-classes, and so this new method does not change that situation.

My take on the "problem" is that the existing measures are all based on a localized concept of "purity" (or impurity) that can be calculated using only the data at a single node. Splitting based on statistical tests (p-values) breaks that model, since it makes use of a more generalized concept of split quality that requires the sample populations of both children from a candidate split. A maximally general signature would probably involve the parent and both children.

Another kink in the current design is that ImpurityCalculator is essentially parallel to Impurity, and in fact ImpurityCalculator#calculate() is how impurity measures are currently requested. Impurity seems somewhat redundant, and might be factored out in favor of ImpurityCalculator. The current signature calculate() might be generalized into a more inclusive concept of split quality that expects to make use of {parent, left, right}.

Calls to calculate() are not very widespread, but threading that change through is outside the scope of this particular PR. If people are interested in that kind of refactoring I could look into it in the near future, but probably not in the next couple of weeks.

That kind of change would also be API-breaking and so a good target for 3.0.
throw new UnsupportedOperationException("Gini.calculate")

/**
* :: DeveloperApi ::
* Determine if this impurity measure is a test-statistic measure
* @return For [[Gini]] will return false
*/
@Since("2.2.0")
@DeveloperApi
def isTestStatistic: Boolean = false
}

/**
@@ -26,6 +26,7 @@ private[mllib] object Impurities {
case "gini" => Gini
case "entropy" => Entropy
case "variance" => Variance
case "chisquared" => ChiSquared
case _ => throw new IllegalArgumentException(s"Did not recognize Impurity name: $name")
}

@@ -52,6 +52,49 @@ trait Impurity extends Serializable {
@Since("1.0.0")
@DeveloperApi
def calculate(count: Double, sum: Double, sumSquares: Double): Double

/**
* :: DeveloperApi ::
* Compute a test-statistic p-value quality measure from left and right split populations
* @param calcL impurity calculator for the left split population
* @param calcR impurity calculator for the right split population
* @return The p-value for the null hypothesis that the left and right split populations
* represent the same distribution
* @note Unless overridden this method will fail with an exception, for backward compatibility
*/
@Since("2.2.0")
@DeveloperApi
def calculate(calcL: ImpurityCalculator, calcR: ImpurityCalculator): Double

/**
* :: DeveloperApi ::
* Determine if this impurity measure is a test-statistic measure
* @return True if this is a split quality measure based on a test statistic (i.e. returns a
* p-value) or false otherwise.
* @note Unless overridden this method returns false by default, for backward compatibility
*/
@Since("2.2.0")
@DeveloperApi
def isTestStatistic: Boolean
Member: Adding methods to a public trait is technically an API-breaking change. This might be considered a Developer API even though it's not labeled that way. Still, if we can avoid adding to the API here, it'd be better.

Contributor Author: Can this be customized or extended externally to Spark? I'm wondering why it is public.
}

/**
* :: DeveloperApi ::
* Utility functions for Impurity measures
*/
@Since("2.0.0")
@DeveloperApi
Contributor: There is no need for this object to be publicly exposed?

Contributor Author: I don't think so. I don't recall any specific motivation to keep it private, but historically Spark seems to default things to "minimum visibility." The only method currently defined here is an implementation detail for hacking p-values into the existing "gain" system, where larger is assumed to be better.

object Impurity {
/**
* :: DeveloperApi ::
* Convert a test-statistic p-value into a "larger-is-better" gain value.
* @param pval The test statistic p-value
* @return The negative logarithm of the p-value. Any p-values smaller than 10^-20 are clipped
* to 10^-20 to prevent arithmetic errors
*/
@Since("2.0.0")
@DeveloperApi
def pValToGain(pval: Double): Double = -math.log(math.max(1e-20, pval))
Member: private to spark?
}
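A quick way to see why this transform lets p-values plug into the existing larger-is-better gain comparisons unchanged; an illustrative Python mirror of the one-liner (hypothetical sketch, not part of the PR):

```python
import math

def p_val_to_gain(pval):
    # Mirror of the Scala one-liner: clip at 1e-20, then negative log,
    # so -log(0) = inf can never occur.
    return -math.log(max(1e-20, pval))

# The transform is monotone decreasing, so comparing transformed values
# reverses the direction of the p-value comparison: a split passes the
# transformed threshold exactly when its p-value is at or below the
# configured minimum (interpreted here as a p-value threshold).
min_info_gain = 0.01
threshold = p_val_to_gain(min_info_gain)
for p in (0.5, 0.05, 0.001, 0.0):
    print(p, p_val_to_gain(p) >= threshold)
```

This is what makes the `gain < minGain` check in RandomForest work for both kinds of measure once `minInfoGain` itself is run through the same transform.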

/**
@@ -62,6 +62,23 @@ object Variance extends Impurity {
@Since("1.0.0")
def instance: this.type = this

/**
* :: DeveloperApi ::
* p-values for test-statistic measures, unsupported for [[Variance]]
*/
@Since("2.2.0")
@DeveloperApi
def calculate(calcL: ImpurityCalculator, calcR: ImpurityCalculator): Double =
throw new UnsupportedOperationException("Variance.calculate")

/**
* :: DeveloperApi ::
* Determine if this impurity measure is a test-statistic measure
* @return For [[Variance]] will return false
*/
@Since("2.2.0")
@DeveloperApi
def isTestStatistic: Boolean = false
}

/**