From 7d1c6ad4330b9fda4626bd00246634926b3a77e5 Mon Sep 17 00:00:00 2001 From: jalvesz Date: Thu, 15 Aug 2024 11:24:44 +0200 Subject: [PATCH] softmax for ranks from 1 to 4 --- src/stdlib_math_activations.fypp | 119 +++++++++++++++++++++++++++++-- 1 file changed, 112 insertions(+), 7 deletions(-) diff --git a/src/stdlib_math_activations.fypp b/src/stdlib_math_activations.fypp index 4a2dcf70b..3085c4699 100644 --- a/src/stdlib_math_activations.fypp +++ b/src/stdlib_math_activations.fypp @@ -104,14 +104,20 @@ module stdlib_math_activations interface Softmax #:for rk, rt in REAL_KINDS_TYPES - module procedure :: softmax_${rk}$ + module procedure :: Softmax_r1_${rk}$ + module procedure :: Softmax_r2_${rk}$ + module procedure :: Softmax_r3_${rk}$ + module procedure :: Softmax_r4_${rk}$ #:endfor end interface public :: softmax interface Softmax_grad #:for rk, rt in REAL_KINDS_TYPES - module procedure :: Softmax_grad_${rk}$ + module procedure :: Softmax_grad_r1_${rk}$ + module procedure :: Softmax_grad_r2_${rk}$ + module procedure :: Softmax_grad_r3_${rk}$ + module procedure :: Softmax_grad_r4_${rk}$ #:endfor end interface public :: Softmax_grad @@ -315,19 +321,118 @@ end function ! Softmax !================================================== #:for rk, rt in REAL_KINDS_TYPES -pure function Softmax_${rk}$( x ) result( y ) +pure function Softmax_r1_${rk}$( x ) result( y ) ${rt}$, intent(in) :: x(:) ${rt}$ :: y(size(x)) - y(:) = exp(x(:) - maxval(x(:)) ) - y(:) = y(:) / sum(y(:)) + y = exp(x - maxval(x)) + y = y / sum(y) end function -pure function Softmax_grad_${rk}$( x ) result( y ) +pure function Softmax_r2_${rk}$( x , dim ) result( y ) + ${rt}$, intent(in) :: x(:,:) + ${rt}$ :: y(size(x,dim=1),size(x,dim=2)) + + integer, intent(in), optional :: dim + integer :: dim_, j + + dim_ = 1; if(present(dim)) dim_ = dim + + if(dim_==1)then + do j = 1, size(x,dim=2) + y(:,j) = Softmax( x(:,j) ) + end do + else + do j = 1, size(x,dim=1) + y(j,:) = Softmax( x(j,:) ) + end do + end if +end function + +pure function Softmax_r3_${rk}$( x , dim ) result( y ) + ${rt}$, intent(in) :: x(:,:,:) + ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3)) + + integer, intent(in), optional :: dim + integer :: dim_, j + + dim_ = 1; if(present(dim)) dim_ = dim + + if(dim_<=2)then + do j = 1, size(x,dim=3) + y(:,:,j) = Softmax( x(:,:,j) , dim = dim_ ) + end do + else + do j = 1, size(x,dim=1) + y(j,:,:) = Softmax( x(j,:,:) , dim = 2 ) + end do + end if +end function + +pure function Softmax_r4_${rk}$( x , dim ) result( y ) + ${rt}$, intent(in) :: x(:,:,:,:) + ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4)) + + integer, intent(in), optional :: dim + integer :: dim_, j + + dim_ = 1; if(present(dim)) dim_ = dim + + if(dim_<=3)then + do j = 1, size(x,dim=4) + y(:,:,:,j) = Softmax( x(:,:,:,j) , dim = dim_ ) + end do + else + do j = 1, size(x,dim=1) + y(j,:,:,:) = Softmax( x(j,:,:,:) , dim = 3 ) + end do + end if +end function + +pure function Softmax_grad_r1_${rk}$( x ) result( y ) ${rt}$, intent(in) :: x(:) ${rt}$ :: y(size(x)) - y = softmax_${rk}$(x) + y = Softmax(x) + y = y * (1_${rk}$ - y) +end function + +pure function Softmax_grad_r2_${rk}$( x , dim ) result( y ) + ${rt}$, intent(in) :: x(:,:) + ${rt}$ :: y(size(x,dim=1),size(x,dim=2)) + + integer, intent(in), optional :: dim + integer :: dim_ + + dim_ = 1; if(present(dim)) dim_ = dim + + y = Softmax(x,dim_) + y = y * (1_${rk}$ - y) +end function + +pure function Softmax_grad_r3_${rk}$( x , dim ) result( y ) + ${rt}$, intent(in) :: x(:,:,:) + ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3)) + + integer, intent(in), optional :: dim + integer :: dim_ + + dim_ = 1; if(present(dim)) dim_ = dim + + y = Softmax(x,dim_) + y = y * (1_${rk}$ - y) +end function + +pure function Softmax_grad_r4_${rk}$( x , dim ) result( y ) + ${rt}$, intent(in) :: x(:,:,:) + ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4)) + + integer, intent(in), optional :: dim + integer :: dim_ + + dim_ = 1; if(present(dim)) dim_ = dim + + y = Softmax(x,dim_) y = y * (1_${rk}$ - y) end function