@@ -16,7 +16,7 @@ import {scalar, Tensor, tensor, tensor1d, tensor2d} from '@tensorflow/tfjs-core'
1616
1717import { setEpsilon } from './backend/common' ;
1818import * as tfl from './index' ;
19- import { binaryAccuracy , categoricalAccuracy , get , getLossOrMetricName } from './metrics' ;
19+ import { binaryAccuracy , categoricalAccuracy , get , getLossOrMetricName , r2Score } from './metrics' ;
2020import { LossOrMetricFn } from './types' ;
2121import { describeMathCPUAndGPU , describeMathCPUAndWebGL2 , expectTensorsClose } from './utils/test_utils' ;
2222
@@ -283,6 +283,27 @@ describeMathCPUAndGPU('recall metric', () => {
283283 } ) ;
284284} ) ;
285285
286+ describeMathCPUAndGPU ( 'r2Score' , ( ) => {
287+ it ( '1D' , ( ) => {
288+ const yTrue = tensor1d ( [ 3 , - 0.5 , 2 , 7 , 4.2 , 8.5 , 1.3 , 2.8 , 6.7 , 9.0 ] ) ;
289+ const yPred = tensor1d ( [ 2.5 , 0.0 , 2.1 , 7.8 , 4.0 , 8.2 , 1.4 , 2.9 , 6.5 , 9.1 ] ) ;
290+ const score = r2Score ( yTrue , yPred ) ;
291+ expectTensorsClose ( score , scalar ( 0.985 ) ) ;
292+ } ) ;
293+ it ( '2D' , ( ) => {
294+ const yTrue = tensor2d ( [
295+ [ 3 , 2.5 ] , [ - 0.5 , 3.2 ] , [ 2 , 1.9 ] , [ 7 , 5.1 ] , [ 4.2 , 3.8 ] , [ 8.5 , 7.4 ] ,
296+ [ 1.3 , 0.6 ] , [ 2.8 , 2.1 ] , [ 6.7 , 5.3 ] , [ 9.0 , 8.7 ]
297+ ] ) ;
298+ const yPred = tensor2d ( [
299+ [ 2.7 , 2.3 ] , [ 0.0 , 3.1 ] , [ 2.1 , 1.8 ] , [ 6.8 , 5.0 ] , [ 4.1 , 3.7 ] , [ 8.4 , 7.2 ] ,
300+ [ 1.4 , 0.7 ] , [ 2.9 , 2.2 ] , [ 6.6 , 5.2 ] , [ 9.2 , 8.9 ]
301+ ] ) ;
302+ const score = r2Score ( yTrue , yPred ) ;
303+ expectTensorsClose ( score , scalar ( 0.995 ) ) ;
304+ } ) ;
305+ } ) ;
306+
286307describe ( 'metrics.get' , ( ) => {
287308 it ( 'valid name, not alias' , ( ) => {
288309 expect ( get ( 'binaryAccuracy' ) === get ( 'categoricalAccuracy' ) ) . toEqual ( false ) ;
0 commit comments