-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-15699] [ML] Implement a Chi-Squared test statistic option for measuring split quality #13440
Changes from all commits
acbc515
ae8f7ea
cb76359
345ba6a
11e6ed5
bb2f660
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.tree.impurity | ||
|
||
import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} | ||
|
||
/** | ||
* :: Experimental :: | ||
* Class for calculating Chi Squared as a split quality metric during binary classification. | ||
*/ | ||
@Since("2.2.0") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will have to be 2.5.0 for the moment There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll update those. 3.0 might be a good target, especially if I can't do this without new |
||
@Experimental | ||
object ChiSquared extends Impurity { | ||
private object CSTest extends org.apache.commons.math3.stat.inference.ChiSquareTest() | ||
|
||
/** | ||
* Get this impurity instance. | ||
* This is useful for passing impurity parameters to a Strategy in Java. | ||
*/ | ||
@Since("1.1.0") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I'd label all these as Since 2.5.0 even if they override a method that existed earlier. |
||
def instance: this.type = this | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Placeholding definition of classification-based purity. | ||
* @param counts Array[Double] with counts for each label | ||
* @param totalCount sum of counts for all labels | ||
* @return This method will throw an exception for [[ChiSquared]] | ||
*/ | ||
@Since("1.1.0") | ||
@DeveloperApi | ||
override def calculate(counts: Array[Double], totalCount: Double): Double = | ||
throw new UnsupportedOperationException("ChiSquared.calculate") | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Placeholding definition of regression-based purity. | ||
* @param count number of instances | ||
* @param sum sum of labels | ||
* @param sumSquares summation of squares of the labels | ||
* @return This method will throw an exception for [[ChiSquared]] | ||
*/ | ||
@Since("1.0.0") | ||
@DeveloperApi | ||
override def calculate(count: Double, sum: Double, sumSquares: Double): Double = | ||
throw new UnsupportedOperationException("ChiSquared.calculate") | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Chi-squared p-values from [[ImpurityCalculator]] for left and right split populations | ||
* @param calcL impurity calculator for the left split population | ||
* @param calcR impurity calculator for the right split population | ||
* @return The p-value for the chi squared null hypothesis; that left and right split populations | ||
* represent the same distribution of categorical values | ||
*/ | ||
@Since("2.0.0") | ||
@DeveloperApi | ||
override def calculate(calcL: ImpurityCalculator, calcR: ImpurityCalculator): Double = { | ||
CSTest.chiSquareTest( | ||
Array( | ||
calcL.stats.map(_.toLong), | ||
calcR.stats.map(_.toLong) | ||
) | ||
) | ||
} | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Determine if this impurity measure is a test-statistic measure (true for Chi-squared) | ||
* @return For [[ChiSquared]] will return true | ||
*/ | ||
@Since("2.0.0") | ||
@DeveloperApi | ||
override def isTestStatistic: Boolean = true | ||
} | ||
|
||
/** | ||
* Class for updating views of a vector of sufficient statistics, | ||
* in order to compute impurity from a sample. | ||
* Note: Instances of this class do not hold the data; they operate on views of the data. | ||
* @param numClasses Number of classes for label. | ||
*/ | ||
private[spark] class ChiSquaredAggregator(numClasses: Int) | ||
extends ImpurityAggregator(numClasses) with Serializable { | ||
|
||
/** | ||
* Update stats for one (node, feature, bin) with the given label. | ||
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. | ||
* @param offset Start index of stats for this (node, feature, bin). | ||
*/ | ||
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { | ||
allStats(offset + label.toInt) += instanceWeight | ||
} | ||
|
||
/** | ||
* Get an [[ImpurityCalculator]] for a (node, feature, bin). | ||
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. | ||
* @param offset Start index of stats for this (node, feature, bin). | ||
*/ | ||
def getCalculator(allStats: Array[Double], offset: Int): ChiSquaredCalculator = { | ||
new ChiSquaredCalculator(allStats.view(offset, offset + statsSize).toArray) | ||
} | ||
} | ||
|
||
/** | ||
* Stores statistics for one (node, feature, bin) for calculating impurity. | ||
* This class stores its own data and is for a specific (node, feature, bin). | ||
* @param stats Array of sufficient statistics for a (node, feature, bin). | ||
*/ | ||
private[spark] class ChiSquaredCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { | ||
|
||
/** | ||
* Make a deep copy of this [[ImpurityCalculator]]. | ||
*/ | ||
def copy: ChiSquaredCalculator = new ChiSquaredCalculator(stats.clone()) | ||
|
||
/** | ||
* Calculate the impurity from the stored sufficient statistics. | ||
*/ | ||
def calculate(): Double = 1.0 | ||
|
||
/** | ||
* Number of data points accounted for in the sufficient statistics. | ||
*/ | ||
def count: Long = stats.sum.toLong | ||
|
||
/** | ||
* Prediction which should be made based on the sufficient statistics. | ||
*/ | ||
def predict: Double = | ||
if (count == 0) 0 else indexOfLargestArrayElement(stats) | ||
|
||
/** | ||
* Probability of the label given by [[predict]]. | ||
*/ | ||
override def prob(label: Double): Double = { | ||
val lbl = label.toInt | ||
require(lbl < stats.length, | ||
s"ChiSquaredCalculator.prob given invalid label: $lbl (should be < ${stats.length}") | ||
require(lbl >= 0, "ChiSquaredImpurity does not support negative labels") | ||
val cnt = count | ||
if (cnt == 0) 0 else (stats(lbl) / cnt) | ||
} | ||
|
||
/** output in a string format */ | ||
override def toString: String = s"ChiSquaredCalculator(stats = [${stats.mkString(", ")}])" | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -71,6 +71,23 @@ object Gini extends Impurity { | |
@Since("1.1.0") | ||
def instance: this.type = this | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* p-values for test-statistic measures, unsupported for [[Gini]] | ||
*/ | ||
@Since("2.2.0") | ||
@DeveloperApi | ||
def calculate(calcL: ImpurityCalculator, calcR: ImpurityCalculator): Double = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like this new method doesn't make sense to implement for existing implementations, only the new one. That kind of suggests to me it isn't part of the generic API for an impurity. Is this really something that belongs inside the logic of the implementations? maybe there's a more general method that needs to be exposed, that can then be specialized for all implementations. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll consider if there's a unifying idea here. pval-based metrics require integrating information across the new split children, which I believe was not the case for existing methods. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suspect that the generalization is closer to my newer signature There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @srowen @willb My take on the "problem" is that the existing measures are all based on a localized concept of "purity" (or impurity) that can be calculated using only the data at a single node. Splitting based on statistical tests (p-values) breaks that model, since it is making use of a more generalized concept of split quality that requires the sample populations of both children from a candidate split. A maximally general signature would probably involve the parent and both children. Another kink in the current design is that Calls to That kind of change would also be API breaking and so a good target for 3.0 |
||
throw new UnsupportedOperationException("Gini.calculate") | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Determine if this impurity measure is a test-statistic measure | ||
* @return For [[Gini]] will return false | ||
*/ | ||
@Since("2.2.0") | ||
@DeveloperApi | ||
def isTestStatistic: Boolean = false | ||
} | ||
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,6 +52,49 @@ trait Impurity extends Serializable { | |
@Since("1.0.0") | ||
@DeveloperApi | ||
def calculate(count: Double, sum: Double, sumSquares: Double): Double | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Compute a test-statistic p-value quality measure from left and right split populations | ||
* @param calcL impurity calculator for the left split population | ||
* @param calcR impurity calculator for the right split population | ||
* @return The p-value for the null hypothesis; that left and right split populations | ||
* represent the same distribution | ||
* @note Unless overridden this method will fail with an exception, for backward compatability | ||
*/ | ||
@Since("2.2.0") | ||
@DeveloperApi | ||
def calculate(calcL: ImpurityCalculator, calcR: ImpurityCalculator): Double | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Determine if this impurity measure is a test-statistic measure | ||
* @return True if this is a split quality measure based on a test statistic (i.e. returns a | ||
* p-value) or false otherwise. | ||
* @note Unless overridden this method returns false by default, for backward compatability | ||
*/ | ||
@Since("2.2.0") | ||
@DeveloperApi | ||
def isTestStatistic: Boolean | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adding methods to a public trait is technically an API breaking change. This might be considered a Developer API even though it's not labeled that way. Still if we can avoid adding to the API here, it'd be better. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this be customized or extended externally to spark? I'm wondering why it is public. |
||
} | ||
|
||
/** | ||
* :: DeveloperApi :: | ||
* Utility functions for Impurity measures | ||
*/ | ||
@Since("2.0.0") | ||
@DeveloperApi | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there is no need for this object to be publicly exposed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think so. I don't recall any specific motivation to keep it private, but historically Spark seems to default things to "minimum visibility." The only method currently defined here is an implementation detail for hacking p-values into the existing 'gain' system, where larger is assumed to be better. |
||
object Impurity { | ||
/** | ||
* :: DeveloperApi :: | ||
* Convert a test-statistic p-value into a "larger-is-better" gain value. | ||
* @param pval The test statistic p-value | ||
* @return The negative logarithm of the p-value. Any p-values smaller than 10^-20 are clipped | ||
* to 10^-20 to prevent arithmetic errors | ||
*/ | ||
@Since("2.0.0") | ||
@DeveloperApi | ||
def pValToGain(pval: Double): Double = -math.log(math.max(1e-20, pval)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. private to spark? |
||
} | ||
|
||
/** | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Kind of a design question here... right now the caller has to switch logic based on what's inside metadata. Can methods like metadata.minInfoGain just implement different logic when the impurity is a test statistic, and so on? push this down towards the impurity implementation? I wonder if
isTestStatistic
can go away with the right API, but I am not familiar with the details of what that requires.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main issue I recall was that all of the existing metrics assume some kind of "larger is better" gain, and p-values are "smaller is better." I'll take another pass over it and see if I can push that distinction down so it doesn't require exposing new methods.