Skip to content

Commit c61f92e

Browse files
Fixed experiment indexing
1 parent 38b431c commit c61f92e

File tree

2 files changed

+1138
-1115
lines changed

2 files changed

+1138
-1115
lines changed

attribution/experiment_logger.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,9 @@ def score_to_color(self, score, vmin=-1, vmax=1):
192192
def print_text_total_attribution(
193193
self, exp_id: Optional[int] = None, score_agg: Literal["mean", "sum", "last"] = "mean"
194194
):
195-
if exp_id == -1:
196-
exp_id = self.df_experiments["exp_id"].max()
195+
196+
if exp_id is not None and exp_id < 0:
197+
exp_id = self.df_experiments["exp_id"].max() + 1 + exp_id
197198

198199
token_attrs_df = (
199200
self.df_input_token_attribution.groupby(["exp_id", "attribution_strategy"])
@@ -241,8 +242,8 @@ def print_text_total_attribution(
241242
self.pretty_print(df)
242243

243244
def print_text_attribution_matrix(self, exp_id: int = -1):
244-
if exp_id == -1:
245-
exp_id = self.df_experiments["exp_id"].max()
245+
if exp_id is not None and exp_id < 0:
246+
exp_id = self.df_experiments["exp_id"].max() + 1 + exp_id
246247

247248
matrix = self.get_attribution_matrix(exp_id)
248249

@@ -285,8 +286,8 @@ def print_total_attribution(
285286
self, exp_id: Optional[int] = None, score_agg: Literal["mean", "last"] = "mean"
286287
):
287288
totals = []
288-
if exp_id == -1:
289-
exp_id = self.df_experiments["exp_id"].max()
289+
if exp_id is not None and exp_id < 0:
290+
exp_id = self.df_experiments["exp_id"].max() + 1 + exp_id
290291

291292
token_attrs_df = (
292293
self.df_input_token_attribution.groupby(["exp_id", "attribution_strategy"])
@@ -333,8 +334,8 @@ def print_attribution_matrix(
333334
show_debug_cols: bool = False,
334335
score_agg: Literal["mean", "last"] = "mean",
335336
):
336-
if exp_id == -1:
337-
exp_id = self.df_experiments["exp_id"].max()
337+
if exp_id is not None and exp_id < 0:
338+
exp_id = self.df_experiments["exp_id"].max() + 1 + exp_id
338339
matrix = self.get_attribution_matrix(
339340
exp_id, attribution_strategy, show_debug_cols, score_agg
340341
)
@@ -357,8 +358,8 @@ def get_attribution_matrix(
357358
show_debug_cols: bool = False,
358359
score_agg: Literal["mean", "sum", "last"] = "mean",
359360
):
360-
if exp_id == -1:
361-
exp_id = self.df_experiments["exp_id"].max()
361+
if exp_id is not None and exp_id < 0:
362+
exp_id = self.df_experiments["exp_id"].max() + 1 + exp_id
362363

363364
if attribution_strategy is None:
364365
strategies = self.df_token_attribution_matrix[

0 commit comments

Comments
 (0)