# Copyright (C) 2001-2021 NLTK Project
# Author: Edward Loper <edloper@gmail.com>
#         Steven Bird <stevenbird1@gmail.com>
#         Tom Aarsen <>
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT
@@ -201,6 +202,140 @@ def key(self):
201
202
202
203
return str
203
204
205
def recall(self, value):
    """Given a value in the confusion matrix, return the recall
    that corresponds to this value. The recall is defined as:

    - *r* = true positive / (true positive + false negative)

    and can loosely be considered the ratio of how often ``value``
    was predicted correctly relative to how often ``value`` was
    the true result.

    :param value: value used in the ConfusionMatrix
    :return: the recall corresponding to ``value``.
    :rtype: float
    """
    # Number of times `value` was correct, and also predicted
    TP = self[value, value]
    # Number of times `value` was correct: the whole true-`value` row,
    # i.e. TP + FN (summed over every predicted value).
    TP_FN = sum(self[value, pred_value] for pred_value in self._values)
    if TP_FN == 0:
        # `value` never occurred as a true result; avoid division by zero.
        return 0.0
    return TP / TP_FN
226
+
227
def precision(self, value):
    """Given a value in the confusion matrix, return the precision
    that corresponds to this value. The precision is defined as:

    - *p* = true positive / (true positive + false positive)

    and can loosely be considered the ratio of how often ``value``
    was predicted correctly relative to the number of predictions
    for ``value``.

    :param value: value used in the ConfusionMatrix
    :return: the precision corresponding to ``value``.
    :rtype: float
    """
    # Number of times `value` was correct, and also predicted
    TP = self[value, value]
    # Number of times `value` was predicted: the whole predicted-`value`
    # column, i.e. TP + FP (summed over every true value).
    TP_FP = sum(self[real_value, value] for real_value in self._values)
    if TP_FP == 0:
        # `value` was never predicted; avoid division by zero.
        return 0.0
    return TP / TP_FP
248
+
249
def f_measure(self, value, alpha=0.5):
    """
    Given a value used in the confusion matrix, return the f-measure
    that corresponds to this value. The f-measure is the harmonic mean
    of the ``precision`` and ``recall``, weighted by ``alpha``.
    In particular, given the precision *p* and recall *r* defined by:

    - *p* = true positive / (true positive + false positive)
    - *r* = true positive / (true positive + false negative)

    The f-measure is:

    - *1/(alpha/p + (1-alpha)/r)*

    With ``alpha = 0.5``, this reduces to:

    - *2pr / (p + r)*

    :param value: value used in the ConfusionMatrix
    :param alpha: Ratio of the cost of false negative compared to false
        positives. Defaults to 0.5, where the costs are equal.
    :type alpha: float
    :return: the F-measure corresponding to ``value``.
    :rtype: float
    """
    p = self.precision(value)
    r = self.recall(value)
    if p == 0.0 or r == 0.0:
        # The weighted harmonic mean is 0 (and the formula below would
        # divide by zero) when either component is 0.
        return 0.0
    return 1.0 / (alpha / p + (1 - alpha) / r)
279
+
280
def evaluate(self, alpha=0.5, truncate=None, sort_by_count=False):
    """
    Tabulate the **recall**, **precision** and **f-measure**
    for each value in this confusion matrix, one row per value.

    >>> reference = "DET NN VB DET JJ NN NN IN DET NN".split()
    >>> test = "DET VB VB DET NN NN NN IN DET NN".split()
    >>> cm = ConfusionMatrix(reference, test)
    >>> print(cm.evaluate())
    Tag | Prec. | Recall | F-measure
    ----+--------+--------+-----------
    DET | 1.0000 | 1.0000 | 1.0000
     IN | 1.0000 | 1.0000 | 1.0000
     JJ | 0.0000 | 0.0000 | 0.0000
     NN | 0.7500 | 0.7500 | 0.7500
     VB | 0.5000 | 1.0000 | 0.6667
    <BLANKLINE>

    :param alpha: Ratio of the cost of false negative compared to false
        positives, as used in the f-measure computation. Defaults to 0.5,
        where the costs are equal.
    :type alpha: float
    :param truncate: If specified, then only show the specified
        number of values. Any sorting (e.g., sort_by_count)
        will be performed before truncation. Defaults to None
    :type truncate: int, optional
    :param sort_by_count: Whether to sort the outputs on frequency
        in the reference label. Defaults to False.
    :type sort_by_count: bool, optional
    :return: A tabulated recall, precision and f-measure string
    :rtype: str
    """
    tags = self._values

    # Optionally order by frequency in the reference, then cut the list.
    if sort_by_count:
        tags = sorted(tags, key=lambda v: -sum(self._confusion[self._indices[v]]))
    if truncate:
        tags = tags[:truncate]

    # The tag column is at least as wide as "Tag" itself.
    width = max(max(len(tag) for tag in tags), 3)

    header = (
        f"{' ' * (width - 3)}Tag | Prec. | Recall | F-measure\n"
        f"{'-' * width}-+--------+--------+-----------\n"
    )

    # One metrics row per tag, right-aligned in the tag column.
    rows = "".join(
        f"{tag:>{width}} | "
        f"{self.precision(tag):<6.4f} | "
        f"{self.recall(tag):<6.4f} | "
        f"{self.f_measure(tag, alpha=alpha):.4f}\n"
        for tag in tags
    )

    return header + rows
338
+
204
339
205
340
def demo ():
206
341
reference = "DET NN VB DET JJ NN NN IN DET NN" .split ()
@@ -211,6 +346,8 @@ def demo():
211
346
print (ConfusionMatrix (reference , test ))
212
347
print (ConfusionMatrix (reference , test ).pretty_format (sort_by_count = True ))
213
348
349
+ print (ConfusionMatrix (reference , test ).recall ("VB" ))
350
+
214
351
215
352
# Run the demonstration when this module is executed as a script.
if __name__ == "__main__":
    demo()
0 commit comments