-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use concept-erasure implementation of LEACE and SAL (#252)
* Use concept-erasure implementation of LEACE and SAL
* fix parameter name in ccs
* Fix test failures
* Be picky about the concept-erasure version
* Refactor to support concept-erasure v0.1
* Fix test failure

---------

Co-authored-by: Walter Laurito <lauritowal@yahoo.com>
- Loading branch information
1 parent
6f975ff
commit a88c01a
Showing
17 changed files
with
208 additions
and
519 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,15 @@ | ||
from .ccs_reporter import CcsReporter, CcsReporterConfig | ||
from .ccs_reporter import CcsConfig, CcsReporter | ||
from .classifier import Classifier | ||
from .concept_eraser import ConceptEraser | ||
from .eigen_reporter import EigenReporter, EigenReporterConfig | ||
from .reporter import Reporter, ReporterConfig | ||
from .common import FitterConfig | ||
from .eigen_reporter import EigenFitter, EigenFitterConfig | ||
from .platt_scaling import PlattMixin | ||
|
||
__all__ = [ | ||
"CcsReporter", | ||
"CcsReporterConfig", | ||
"CcsConfig", | ||
"Classifier", | ||
"ConceptEraser", | ||
"EigenReporter", | ||
"EigenReporterConfig", | ||
"Reporter", | ||
"ReporterConfig", | ||
"EigenFitter", | ||
"EigenFitterConfig", | ||
"FitterConfig", | ||
"PlattMixin", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
"""An ELK reporter network.""" | ||
|
||
from dataclasses import dataclass | ||
|
||
from concept_erasure import LeaceEraser | ||
from simple_parsing.helpers import Serializable | ||
from torch import Tensor, nn | ||
|
||
from .platt_scaling import PlattMixin | ||
|
||
|
||
# Serializable configuration dataclass; the only field declared here is the
# random seed.
# NOTE(review): `decode_into_subclasses=True` is forwarded to simple_parsing's
# `Serializable`; presumably it lets a serialized subclass config be restored
# as that subclass — confirm against the simple_parsing documentation.
@dataclass
class FitterConfig(Serializable, decode_into_subclasses=True):
    seed: int = 42
    """The random seed to use."""
|
||
|
||
@dataclass
class Reporter(PlattMixin):
    """Linear readout over erased hidden states with Platt scaling.

    Attributes:
        weight: Weight tensor of the linear readout; it is transposed over
            its last two dimensions before being applied.
        eraser: `LeaceEraser` called on the hidden states before the readout.
    """

    weight: Tensor
    eraser: LeaceEraser

    def __post_init__(self):
        # Platt scaling parameters, created with the dtype and device of
        # `weight`. They start as the identity map: zero bias, unit scale.
        template = self.weight
        self.bias = nn.Parameter(template.new_zeros(1))
        self.scale = nn.Parameter(template.new_ones(1))

    def __call__(self, hiddens: Tensor) -> Tensor:
        """Return the predicted log odds for `hiddens`.

        Args:
            hiddens: Hidden states to score. They are passed through
                `self.eraser` before the linear readout.

        Returns:
            The Platt-scaled scores, with the last dimension squeezed out
            when it has size one.
        """
        erased = self.eraser(hiddens)
        logits = erased @ self.weight.mT
        calibrated = logits * self.scale + self.bias
        return calibrated.squeeze(-1)
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.