Skip to content

Commit

Permalink
Version 2.0.0 changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Jun 26, 2023
1 parent 877be8e commit 248f927
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 8 deletions.
1 change: 1 addition & 0 deletions cca_zoo/classical/_iterative/_altmaxvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
convergence_checking=convergence_checking,
track=track,
verbose=verbose,
trainer_kwargs={"accelerator": "cpu"}
)
self.tau = tau
self.proximal = proximal
Expand Down
1 change: 1 addition & 0 deletions cca_zoo/classical/_iterative/_elasticnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(
patience=0,
track=track,
verbose=verbose,
trainer_kwargs={"accelerator": "cpu"}
)

def _check_params(self):
Expand Down
3 changes: 1 addition & 2 deletions cca_zoo/classical/_iterative/_pls_als.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __init__(
learning_rate=1,
initialization: Union[str, callable] = "random",
callbacks=None,
trainer_kwargs=None,
):
super().__init__(
latent_dimensions,
Expand All @@ -75,7 +74,7 @@ def __init__(
learning_rate=learning_rate,
initialization=initialization,
callbacks=callbacks,
trainer_kwargs=trainer_kwargs,
trainer_kwargs={"accelerator": "cpu"}
)

def _get_module(self, weights=None, k=None):
Expand Down
9 changes: 5 additions & 4 deletions cca_zoo/classical/_iterative/_pmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
patience=0,
track=track,
verbose=verbose,
trainer_kwargs={"accelerator": "cpu"}
)
self.tau = tau
self.positive = positive
Expand All @@ -113,7 +114,7 @@ def _get_module(self, weights=None, k=None):
k=k,
tau=self.tau,
tol=self.tol,
track=self.track,
tracking=self.track,
convergence_checking=self.convergence_checking,
)

Expand All @@ -128,13 +129,13 @@ def __init__(
k=None,
tau=None,
tol=1e-3,
track=False,
tracking=False,
convergence_checking=False,
):
super().__init__(
weights=weights,
k=k,
tracking=track,
tracking=tracking,
convergence_checking=convergence_checking,
)
self.tau = tau
Expand Down Expand Up @@ -163,7 +164,7 @@ def training_step(self, batch, batch_idx):
f"All result weights are zero in view {view_index}. "
"Try less regularisation or another initialisation"
)
# if track or convergence_checking is enabled, compute the objective function
# if tracking or convergence_checking is enabled, compute the objective function
if self.tracking or self.convergence_checking:
objective = self.objective(batch["views"])
# check that the maximum change in weights is smaller than the tolerance times the maximum absolute value of the weights
Expand Down
1 change: 1 addition & 0 deletions cca_zoo/classical/_iterative/_scca_parkhomenko.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
patience=patience,
track=track,
verbose=verbose,
trainer_kwargs={"accelerator": "cpu"}
)

def _check_params(self):
Expand Down
1 change: 1 addition & 0 deletions cca_zoo/classical/_iterative/_scca_span.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
random_state=random_state,
deflation=deflation,
verbose=verbose,
trainer_kwargs={"accelerator": "cpu"}
)
self.tau = tau
self.regularisation = regularisation
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ numpy
scipy
scikit-learn
scikit-prox
pytest
matplotlib
pandas
seaborn
tensorly
joblib
mvlearn
tqdm
setuptools
pytorch
pytorch-lightning

0 comments on commit 248f927

Please sign in to comment.