c rda: grid search over the regularization pair (lambda, gamma).
      subroutine rda (p,n,nk,x,y,w,pi,fl,al,ga,ipr,sln,sp,dp)
c
c regularized discriminant analysis.
c
c Friedman J. H. (1989). Regularized discriminant analysis.
c J. Amer. Statist. Assoc., 84, 165-175 (March).
c
c version of (7/4/88).
c
c coded and copyright (c) 1988 by:
c
c    Jerome H. Friedman
c    Department of Statistics
c    and
c    Stanford Linear Accelerator Center
c    Stanford University.
c
c
c estimate optimal regularized model (regularization parameter
c values).
c
c input:
c    p = number of predictor variables (integer).
c    n = number of training observations.
c    nk = number of classes.
c    x(p,n) = predictor data matrix (covariates).
c    y(n) = class label for each observation
c           (integer: .ge.1, .le.nk).
c    w(n) = weight (mass) for each observation.
c    pi(nk) = prior probability for each class
c           (.gt.0.0, .lt.1.0, sum=1.0).
c    fl(nk,nk) = misclassification loss matrix.
c       fl(i,j) = loss for classifying a class i as a class j.
c    al(4) = search control for covariance mixing parameter (lambda).
c       al(1) = lower search limit (.ge.0.0, typical value = 0.0).
c       al(2) = upper search limit (.le.1.0, typical value = 1.0).
c       al(3) = number of search locations (including end points)
c               (.ge.1.0, typical value = 5.0).
c       al(4) = exponent for power transformation determining
c               search scale (.gt.0.0, typical value = 1.0).
c    ga(4) = search control for covariance shrinkage parameter
c            (gamma).
c       ga(1) = lower search limit (.ge.0.0, typical value = 0.0).
c       ga(2) = upper search limit (.le.1.0, typical value = 1.0).
c       ga(3) = number of search locations (including end points)
c               (.ge.1.0, typical value = 10.0).
c       ga(4) = exponent for power transformation determining
c               search scale (.gt.0.0, typical value = 1.0).
c    ipr = fortran unit number for printed output
c          (.le.0 => no printed output, typical value = 6).
c
c output:
c    sln(3) = solution estimates for optimal regularization values.
c       sln(1) = covariance mixing parameter (lambda).
c       sln(2) = covariance shrinkage parameter (gamma).
c       sln(3) = cross-validated estimate of misclassification risk
c                for the corresponding model.
c    sln(1) and sln(2) are only assigned after a successful call to
c    eigen; if eigen fails for every lambda they are left untouched
c    and sln(3) keeps its initial value 1.e30.
c
c scratch workspace:
c    sp(nk*(p*(n+nk+1)+5)+3*p) : single precision.
c    dp(p*(p*(nk+2)+nk+1)) : double precision.
c
      integer p,y(n)
      real x(p,n),w(n),pi(nk),fl(nk,nk),al(4),ga(4),sln(3)
c     assumed-size workspaces: they are indexed well beyond element 1.
      real sp(*)
      double precision dp(*)
      data big /1.e30/
c
c     offsets of the sub-arrays carved out of sp:
c     i1 class weights (nk)      i2 class means (p*nk)
c     i3 adjusted priors (nk)    i4 mean diagonal per class (nk)
c     i5 rotated data (p*n*nk)   i6 rotated means (p*nk*nk)
c     i7,i8,i9 work vectors (p each)   i10 work array (2*nk)
      i1=1
      i2=i1+nk
      i3=i2+p*nk
      i4=i3+nk
      i5=i4+nk
      i6=i5+p*n*nk
      i7=i6+p*nk*nk
      i8=i7+p
      i9=i8+p
      i10=i9+p
c     offsets of the sub-arrays carved out of dp:
c     j1 class matrices (p*p*nk)  j2 pooled matrix (p*p)
c     j3 eigenvalues (p*nk)       j4 eigenvectors (p*p)
c     j5 work vector (p)
      j1=1
      j2=j1+p*p*nk
      j3=j2+p*p
      j4=j3+p*nk
      j5=j4+p*p
      if (ipr.gt.0) write (ipr,140)
      call altpi (nk,pi,fl,sp(i3))
      call covar (p,n,nk,x,y,w,dp(j1),dp(j2),sp(i1),sw,sp(i2),ipr)
c     clamp the search controls to their documented ranges.
      al1=amax1(0.0,al(1))
      al2=amin1(1.0,al(2))
      al3=amax1(1.0,al(3))
      ga1=amax1(0.0,ga(1))
      ga2=amin1(1.0,ga(2))
      ga3=amax1(1.0,ga(3))
c     real-to-integer conversion truncates; the 0.0001 offset
c     presumably protects an exact count against rounding.
      nlam=al3+0.0001
      ngam=ga3+0.0001
      if (ipr.gt.0) write (ipr,120)
      if (nlam.gt.1) then
         dlam=(al2-al1)/(al3-1.0)
      else
         dlam=0.0
      end if
      if (ngam.gt.1) then
         dga=(ga2-ga1)/(ga3-1.0)
      else
         dga=0.0
      end if
      sln(3)=big
      al4=amax1(0.0001,al(4))
      ga4=amax1(0.0001,ga(4))
      do 110 lam=1,nlam
         alamt=dlam*(lam-1)+al1
         alam=alamt**(al4)
c        vinc updates the matrices incrementally from the previous
c        lambda (alt) by the step dlt.
         if (lam.eq.1) then
            alt=0.0
            dlt=alam
         else
            dlt=alam-alt
         end if
         call vinc (p,nk,dp(j1),dp(j2),sp(i4),alt,dlt)
         alt=alam
         call eigen (p,n,nk,x,sp(i2),dp(j1),dp(j3),sp(i5),sp(i6),
     1               dp(j4),dp(j5),ierr)
c        skip this lambda when the eigen-decomposition failed.
         if (ierr.ne.0) go to 110
         do 100 mag=1,ngam
            gamt=dga*(mag-1)+ga1
            gam=gamt**(ga4)
c           dinc updates the eigenvalues incrementally from the
c           previous gamma (agt) by the step dgt.
            if (mag.eq.1) then
               agt=0.0
               dgt=gam
            else
               dgt=gam-agt
            end if
            call dinc (p,nk,dp(j3),sp(i4),agt,dgt)
            agt=gam
            rsb=resub(p,n,nk,alam,pi,fl,sp(i3),y,w,dp(j3),sp(i6),
     1                sp(i5),sp(i1),sw,sp(i10))
            r=risk(p,n,nk,alam,gam,pi,fl,sp(i3),y,w,dp(j3),sp(i6),
     1             sp(i5),sp(i1),sw,sp(i7),sp(i8),sp(i9),sdr)
c           keep the pair with the smallest risk; ties go to the
c           later grid point.
            if (r.gt.sln(3)) go to 90
            sln(3)=r
            sln(1)=alam
            sln(2)=gam
 90         if (ipr.gt.0) write (ipr,130) alam,gam,rsb,r,sdr
 100     continue
 110  continue
      return
 120  format(' lambda gamma resub cross-validate 1d')
 130  format(' ',4g12.4,' +/-',g12.4)
 140  format(' regularized discriminant analysis (7/4/88).')
      end
      subroutine rule (p,n,nk,x,y,w,pi,fl,alm,gam,sp,dp,ierr)
c
c constructs classification rule for a specific pair of
c regularization parameter values (lambda, gamma).
c
c input:
c    p, n, nk, x, y, w, pi, fl = same as in subroutine rda (above).
c    alm = covariance matrix mixing parameter value (lambda).
c    gam = covariance matrix shrinkage parameter value (gamma).
c
c output:
c    sp, dp : same dimension as in subroutine rda (above). to be used
c             as input to subroutine clsfy (see below).
c    ierr = covariance matrix diagonalization error flag
c           (.eq.0 => no error; .ne.0 => error, classification rule
c           not valid).
c
      integer p,y(n)
      real x(p,n),w(n),pi(nk),fl(nk,nk)
      real sp(*)
      double precision dp(*)
c     workspace offsets; same layout as in rda (only those used here).
      i1=1
      i2=i1+nk
      i3=i2+p*nk
      i4=i3+nk
      j1=1
      j2=j1+p*p*nk
      j3=j2+p*p
      j4=j3+p*nk
      j5=j4+p*p
      call altpi (nk,pi,fl,sp(i3))
      call covar (p,n,nk,x,y,w,dp(j1),dp(j2),sp(i1),sw,sp(i2),0)
      call vinc (p,nk,dp(j1),dp(j2),sp(i4),0.0,alm)
      call gave (p,nk,dp(j1),sp(i4),gam)
      call fnorm (p,nk,alm,dp(j1),sp(i1),sw)
      call inv (p,nk,dp(j1),dp(j3),sp(i3),dp(j5),ierr)
      return
      end
      subroutine clsfy (p,n,nk,x,y,sp,dp)
c
c classifies (test) observations according to the classification
c rule constructed by subroutine rule (see above).
c
c input:
c    p, nk = same as in subroutine rda above.
c    n = number of observations to be classified.
c    x(p,n) = predictor variables (covariates) for each (test)
c             observation.
c    sp, dp = output from subroutine rule (see above).
c
c output:
c    y(n) = class identification estimate for each observation
c           (integer). y(l) is the class with the smallest value of
c           dist; it is left unchanged if no class scores below
c           1.e30.
c
      integer p,y(n)
      real x(p,n)
      real sp(*)
      double precision dp(*)
      data big /1.e30/
c     workspace offsets; must agree with those used in rule.
      i1=1
      i2=i1+nk
      i3=i2+p*nk
      j1=1
      j2=j1+p*p*nk
      j3=j2+p*p
      j4=j3+p*nk
      j5=j4+p*p
      do 20 l=1,n
         dm=big
         do 10 k=1,nk
            dst=dist(p,nk,k,x(1,l),dp(j1),sp(i2),sp(i3),dp(j3),
     1               dp(j5))
            if (dst.ge.dm) go to 10
            dm=dst
            y(l)=k
 10      continue
 20   continue
      return
      end
      subroutine altpi (nk,pi,fl,pit)
c
c loss-adjusted priors on a -2*log scale.
c
c for each class k the prior pi(k) is multiplied by the row sum of
c fl(k,.), the products are normalized to sum to one, and
c pit(k) = -2*log(normalized product) is returned.
c
      real pi(nk),fl(nk,nk),pit(nk)
      do 20 k=1,nk
         a=0.0
         do 10 i=1,nk
            a=a+fl(k,i)
 10      continue
         pit(k)=a*pi(k)
 20   continue
      a=0.0
      do 30 k=1,nk
         a=a+pit(k)
 30   continue
      a=1.0/a
      do 40 k=1,nk
         pit(k)=-2.0*alog(a*pit(k))
 40   continue
      return
      end
      subroutine covar (p,n,nk,x,y,w,v,vb,s,sw,xb,ipr)
c
c weighted class statistics.
c
c output:
c    s(k)     = sum of the weights w(l) of observations with y(l)=k.
c    xb(.,k)  = weighted mean of x over class k.
c    v(.,.,k) = weighted sum of squares and cross products about
c               xb(.,k) (lower triangle only).
c    sw       = sum of s(k) over all classes.
c    vb       = sum of v(.,.,k) over all classes (lower triangle).
c
c if some class has s(k).le.0 a message is written to unit ipr (when
c ipr.gt.0) and the program is stopped.
c
      integer p,y(n)
      real x(p,n), w(n), s(nk),xb(p,nk)
      double precision v(p,p,nk),vb(p,p)
      do 30 k=1,nk
         s(k)=0.0
         do 20 i=1,p
            xb(i,k)=0.0
            do 10 j=1,i
               v(i,j,k)=0.d0
 10         continue
 20      continue
 30   continue
      do 50 l=1,n
         k=y(l)
         s(k)=s(k)+w(l)
         do 40 i=1,p
            xb(i,k)=xb(i,k)+w(l)*x(i,l)
 40      continue
 50   continue
      ifl=0
      do 80 k=1,nk
         if (s(k).gt.0.0) go to 60
         ifl=1
         if (ipr.gt.0) write (ipr,160) k
         go to 80
 60      do 70 i=1,p
            xb(i,k)=xb(i,k)/s(k)
 70      continue
 80   continue
      if (ifl.eq.1) stop
      do 110 l=1,n
         k=y(l)
         do 100 i=1,p
            do 90 j=1,i
               v(i,j,k)=v(i,j,k)
     1                  +w(l)*(x(i,l)-xb(i,k))*(x(j,l)-xb(j,k))
 90         continue
 100     continue
 110  continue
      sw=0.0
      do 120 k=1,nk
         sw=sw+s(k)
 120  continue
      do 150 i=1,p
         do 140 j=1,i
            vb(i,j)=0.d0
            do 130 k=1,nk
               vb(i,j)=vb(i,j)+v(i,j,k)
 130        continue
 140     continue
 150  continue
      return
 160  format(' there are no class(',i3,') observations.')
      end
      subroutine vinc (p,nk,v,vb,db,al,dl)
c
c incremental update of the class matrices for the mixing parameter.
c
c on entry v(.,.,k) corresponds to mixing value al; on return it
c corresponds to al+dl:
c    v <- ((1-al-dl)*v + dl*vb)/(1-al)     (lower triangle only).
c when al+dl is within 1.e-4 of 1.0, v(.,.,k) is set to vb instead
c (which also avoids the division by 1-al).
c db(k) returns the average diagonal element of v(.,.,k).
c
      integer p
      real db(nk)
      double precision v(p,p,nk),vb(p,p)
      data eps /1.e-4/
      if (abs(al+dl-1.0).ge.eps) go to 40
      do 30 k=1,nk
         do 20 i=1,p
            do 10 j=1,i
               v(i,j,k)=vb(i,j)
 10         continue
 20      continue
 30   continue
      go to 80
 40   a=1.0-al
      b=a-dl
      a=1.0/a
      do 70 k=1,nk
         do 60 i=1,p
            do 50 j=1,i
               v(i,j,k)=(b*v(i,j,k)+dl*vb(i,j))*a
 50         continue
 60      continue
 70   continue
 80   do 100 k=1,nk
         a=0.0
         do 90 i=1,p
            a=a+v(i,i,k)
 90      continue
         db(k)=a/p
 100  continue
      return
      end
c eigen: per-class eigen-decomposition and rotation of data/means.
      subroutine eigen (p,n,nk,x,xb,v,d,z,xm,vt,sc,ierr)
c
c for each class k the symmetric matrix v(.,.,k) (lower triangle) is
c diagonalized with tred2/imtql2.
c
c output:
c    d(.,k)    = eigenvalues of v(.,.,k) (imtql2 returns them in
c                decreasing order).
c    z(.,l,k)  = observation x(.,l) projected on the eigenvectors
c                of class k.
c    xm(.,m,k) = class mean xb(.,m) projected on the eigenvectors
c                of class k.
c    ierr      = error flag from imtql2; the routine returns at once
c                when it is nonzero.
c vt(p,p) and sc(p) are scratch; v itself is not modified.
c
      integer p
      real x(p,n),z(p,n,nk),xb(p,nk),xm(p,nk,nk)
      double precision v(p,p,nk),d(p,nk),vt(p,p),sc(p),s,machep,tol
      machep=0.5d0**(52)
      tol=1.d-60
      do 70 k=1,nk
         call tred2 (p,p,p,tol,v(1,1,k),d(1,k),sc,vt)
         call imtql2 (p,p,p,machep,d(1,k),sc,vt,ierr)
         if (ierr.ne.0) return
         do 30 l=1,n
            do 20 j=1,p
               s=0.d0
               do 10 i=1,p
                  s=s+vt(i,j)*x(i,l)
 10            continue
               z(j,l,k)=s
 20         continue
 30      continue
         do 60 m=1,nk
            do 50 j=1,p
               s=0.d0
               do 40 i=1,p
                  s=s+vt(i,j)*xb(i,m)
 40            continue
               xm(j,m,k)=s
 50         continue
 60      continue
 70   continue
      return
      end
      subroutine dinc (p,nk,d,db,ga,dg)
c
c incremental update of the eigenvalues for the shrinkage parameter.
c
c on entry d(.,k) corresponds to shrinkage value ga; on return it
c corresponds to ga+dg:
c    d(i,k) <- ((1-ga-dg)*d(i,k) + dg*db(k))/(1-ga).
c when ga+dg is within 1.e-4 of 1.0, d(i,k) is set to db(k) instead.
c
      integer p
      real db(nk)
      double precision d(p,nk)
      data eps /1.e-4/
      if (abs(ga+dg-1.0).ge.eps) go to 30
      do 20 k=1,nk
         do 10 i=1,p
            d(i,k)=db(k)
 10      continue
 20   continue
      go to 60
 30   a=1.0-ga
      b=a-dg
      a=1.0/a
      do 50 k=1,nk
         do 40 i=1,p
            d(i,k)=(b*d(i,k)+dg*db(k))*a
 40      continue
 50   continue
 60   return
      end
      function resub (p,n,nk,al,pi,fl,pit,y,w,d,xm,z,s,sw,slk)
c
c resubstitution estimate of misclassification risk.
c
c every observation is assigned to the class k minimizing
c    slk(k,1)*sum_i (z(i,l,k)-xm(i,k,k))**2/d(i,k) + slk(k,2)
c with slk(k,1) = (1-al)*s(k)+al*sw and
c      slk(k,2) = pit(k) + sum_i log(d(i,k)) - p*log(slk(k,1)).
c the result is the average loss fl(y(l),assigned class) weighted
c by w(l)*pi(y(l))/s(y(l)).
c
c returns 1.e30 when
c    - d(p,k) is below 1.e-15 for some class, or
c    - for some observation no class scores below 1.e30
c      (previously the class index was used uninitialized then).
c slk(nk,2) is scratch.
c
      integer p,y(n)
      real pi(nk),fl(nk,nk),pit(nk),w(n)
      real xm(p,nk,nk),z(p,n,nk),s(nk),slk(nk,2)
      double precision d(p,nk),t,rsk,fls,wl,wls
      data big,sml /1.e30,1.e-15/
      do 30 k=1,nk
         if (d(p,k).ge.sml) go to 10
         resub=big
         return
 10      slk(k,1)=(1.0-al)*s(k)+al*sw
         t=pit(k)
         do 20 i=1,p
            t=t+dlog(d(i,k))
 20      continue
         t=t-p*alog(slk(k,1))
         slk(k,2)=t
 30   continue
      rsk=0.d0
      wls=rsk
      do 60 l=1,n
         dm=big
c        ky=0 marks "no class selected yet".
         ky=0
         do 50 k=1,nk
            t=0.d0
            do 40 i=1,p
               t=t+(z(i,l,k)-xm(i,k,k))**2/d(i,k)
 40         continue
            t=slk(k,1)*t+slk(k,2)
            if (t.ge.dm) go to 50
            dm=t
            ky=k
 50      continue
c        no admissible class: report the degenerate value instead of
c        indexing fl with an undefined ky.
         if (ky.gt.0) go to 55
         resub=big
         return
 55      j=y(l)
         fls=fl(j,ky)
         wl=w(l)*pi(j)/s(j)
         rsk=rsk+wl*fls
         wls=wls+wl
 60   continue
      resub=rsk/wls
      return
      end
      function risk (p,n,nk,al,ga,pi,fl,pit,y,w,d,xm,z,s,sw,et,zt,ut,sdr
     1)
c
c cross-validated estimate of misclassification risk for the pair
c (al, ga), computed from the rotated data z, rotated means xm and
c eigenvalues d without refitting; for each observation l the class
c scores are corrected for the removal of that observation
c (NOTE(review): leave-one-out updating formulas as in the Friedman
c 1989 reference -- confirm against the paper).
c
c output:
c    risk = weighted average loss, weights w(l)*pi(y(l))/s(y(l)).
c    sdr  = sqrt(sum(wl**2)*variance of the loss)/sum(wl).
c
c both are set to 1.e30 when
c    - some corrected eigenvalue et(i) is not positive,
c    - the correction factor t falls below 1.d-5, or
c    - for some observation no class scores below 1.e30
c      (previously the class index was used uninitialized then).
c et(p), zt(p), ut(p) are scratch.
c
      integer p,y(n)
      real pi(nk),fl(nk,nk),pit(nk),w(n)
      real xm(p,nk,nk),z(p,n,nk),s(nk)
      real zt(p),ut(p),et(p)
      double precision d(p,nk),t,tol,dst,rsk,rsq,fls,wl,wls,wlq
      data tol,big /1.d-5,1.e30/
      gs=sqrt(1.0-ga)
      omal=1.0-al
      rsk=0.d0
      rsq=rsk
      wls=rsq
      wlq=wls
      do 110 l=1,n
         dm=big
c        ky=0 marks "no class selected yet".
         ky=0
         j=y(l)
         bb=s(j)*w(l)/(s(j)-w(l))
         do 100 k=1,nk
            b=bb
            if (j.ne.k) b=b*al
            b=sqrt(b)
            dk=0.0
            do 10 i=1,p
               c=b*(z(i,l,k)-xm(i,j,k))
               zt(i)=gs*c
               dk=dk+c**2
 10         continue
            dk=ga*dk/p
            if (j.eq.k) go to 20
            a=1.0
            b=1.0
            go to 30
 20         a=1.0+w(l)/(s(k)-w(l))
            b=s(k)/(s(k)-w(l))
 30         do 40 i=1,p
               ut(i)=a*z(i,l,k)-b*xm(i,k,k)
               et(i)=d(i,k)-dk
               if (et(i).gt.0.0) go to 40
               risk=big
               sdr=risk
               return
 40         continue
            t=1.d0
            do 50 i=1,p
               t=t-zt(i)**2/et(i)
 50         continue
            if (t.ge.tol) go to 60
            risk=big
            sdr=risk
            return
 60         dst=0.d0
            do 70 i=1,p
               dst=dst+zt(i)*ut(i)/et(i)
 70         continue
            dst=dst**2/t
            do 80 i=1,p
               dst=dst+ut(i)**2/et(i)
 80         continue
            a=dlog(t)
            do 90 i=1,p
               a=a+alog(et(i))
 90         continue
            sk=w(l)
            if (j.ne.k) sk=sk*al
            sk=omal*s(k)+al*sw-sk
            dst=sk*dst+a-p*alog(sk)+pit(k)
            if (dst.ge.dm) go to 100
            dm=dst
            ky=k
 100     continue
c        no admissible class: report the degenerate value instead of
c        indexing fl with an undefined ky.
         if (ky.gt.0) go to 105
         risk=big
         sdr=risk
         return
 105     fls=fl(j,ky)
         wl=w(l)*pi(j)/s(j)
         rsk=rsk+wl*fls
         rsq=rsq+wl*fls**2
         wls=wls+wl
         wlq=wlq+wl**2
 110  continue
      rsk=rsk/wls
      rsq=rsq/wls-rsk**2
      risk=rsk
      sdr=dsqrt(wlq*rsq)/wls
      return
      end
      subroutine gave (p,nk,v,db,ga)
c
c shrinks each class matrix toward a multiple of the identity:
c    v(.,.,k) <- (1-ga)*v(.,.,k) + ga*db(k)*I   (lower triangle).
c
      integer p
      real db(nk)
      double precision v(p,p,nk)
      ga1=1.0-ga
      do 30 k=1,nk
         do 20 i=1,p
            do 10 j=1,i
               v(i,j,k)=ga1*v(i,j,k)
 10         continue
            v(i,i,k)=v(i,i,k)+ga*db(k)
 20      continue
 30   continue
      return
      end
      subroutine fnorm (p,nk,al,v,s,sw)
c
c divides the lower triangle of each v(.,.,k) by
c (1-al)*s(k) + al*sw.
c
      integer p
      real s(nk)
      double precision v(p,p,nk),ski
      do 30 k=1,nk
         ski=1.0/((1.0-al)*s(k)+al*sw)
         do 20 i=1,p
            do 10 j=1,i
               v(i,j,k)=v(i,j,k)*ski
 10         continue
 20      continue
 30   continue
      return
      end
      subroutine inv (p,nk,v,d,dlt,sc,ierr)
c
c eigen-decomposition of each class matrix, prepared for dist.
c
c on return, for each class k:
c    v(.,.,k) = eigenvectors (overwriting the input matrix; tred2 is
c               called with input and output matrix coinciding).
c    d(.,k)   = reciprocals of the eigenvalues, after flooring each
c               eigenvalue at 1.d-15.
c    dlt(k)   = input dlt(k) plus the sum of log(eigenvalue).
c    ierr     = error flag from imtql2; the routine returns at once
c               when it is nonzero.
c sc(p) is scratch.
c
      integer p
      real dlt(nk)
      double precision v(p,p,nk),d(p,nk),sc(p),s,machep,tol,eps
      data eps /1.d-15/
      machep=0.5d0**(52)
      tol=1.d-60
      do 20 k=1,nk
         call tred2 (p,p,p,tol,v(1,1,k),d(1,k),sc,v(1,1,k))
         call imtql2 (p,p,p,machep,d(1,k),sc,v(1,1,k),ierr)
         if (ierr.ne.0) return
         s=0.d0
         do 10 i=1,p
            if (d(i,k).lt.eps) d(i,k)=eps
            s=s+dlog(d(i,k))
            d(i,k)=1.d0/d(i,k)
 10      continue
         dlt(k)=dlt(k)+s
 20   continue
      return
      end
      function dist (p,nk,k,x,v,xb,pit,d,dm)
c
c discriminant score of observation x for class k:
c    dist = sum_j d(j,k)*(sum_i (x(i)-xb(i,k))*v(i,j,k))**2 + pit(k)
c dm(p) is scratch (holds x - xb(.,k)).
c
      integer p
      real x(p),xb(p,nk),pit(nk)
      double precision v(p,p,nk),d(p,nk),dm(p),s,t
      do 10 i=1,p
         dm(i)=x(i)-xb(i,k)
 10   continue
      t=0.d0
      do 30 j=1,p
         s=0.d0
         do 20 i=1,p
            s=s+dm(i)*v(i,j,k)
 20      continue
         t=t+d(j,k)*s**2
 30   continue
      dist=t+pit(k)
      return
      end
      subroutine imtql2 (nt,nm,n,machep,d,e,z,error)
c
c eigenvalues and eigenvectors of a symmetric tridiagonal matrix by
c implicit ql iterations (companion to tred2).
c
c input:
c    nt     = leading dimension of z.
c    nm     = not referenced.
c    n      = order of the matrix.
c    machep = relative precision used in the deflation test.
c    d(n)   = diagonal elements.
c    e(n)   = subdiagonal elements in e(2..n).
c    z      = transformation matrix produced by tred2.
c
c output:
c    d      = eigenvalues, sorted in decreasing order.
c    z      = corresponding eigenvectors (columns).
c    e      = destroyed.
c    error  = 0 for normal return, or the index l of the eigenvalue
c             that did not converge within 30 iterations.
c
      implicit real*8 (a-h,o-z)
      real*8 machep,d(n),e(n),z(nt,nt)
      integer error
      error=0
      if (n.eq.1) go to 140
c     shift the subdiagonal down by one position.
      do 10 i=2,n
         e(i-1)=e(i)
 10   continue
      e(n)=0.0d0
      do 90 l=1,n
      j=0
c     look for a negligible subdiagonal element.
 20   do 30 m=l,n
      if (m.eq.n) go to 40
      if (dabs(e(m)).le.machep*(dabs(d(m))+dabs(d(m+1)))) go to 40
 30   continue
 40   p=d(l)
      if (m.eq.l) go to 90
      if (j.eq.30) go to 130
      j=j+1
c     form the shift.
      g=(d(l+1)-p)/(2.0d0*e(l))
      r=dsqrt(1.0d0+g*g)
      g=d(m)-p+e(l)/(g+dsign(r,g))
      s=1.0d0
      c=1.0d0
      p=d(m)
      mml=m-l
      do 80 ii=1,mml
      i=m-ii
      f=s*e(i)
      b=c*e(i)
      if (dabs(f).lt.dabs(g)) go to 50
      c=g/f
      r=dsqrt(c*c+1.0d0)
      e(i+1)=f*r
      s=1.0d0/r
      c=c/r
      go to 60
 50   c=f/g
      r=dsqrt(c*c+1.0d0)
      e(i+1)=g*r
      s=c/r
      c=1.0d0/r
 60   f=c*d(i)-s*b
      g=c*b-s*p
      r=d(i)+p
      p=c*f-s*g
      g=s*f+c*g
      d(i+1)=r-p
c     apply the rotation to the eigenvector matrix.
      do 70 ia=1,n
      f=z(ia,i+1)
      z(ia,i+1)=s*z(ia,i)+c*f
      z(ia,i)=c*z(ia,i)-s*f
 70   continue
 80   continue
      d(l)=p
      e(l)=g
      e(m)=0.0d0
      go to 20
 90   continue
c     selection sort of eigenvalues (decreasing) with their vectors.
      nm1=n-1
      do 120 i=1,nm1
      k=i
      p=d(i)
      ip1=i+1
      do 100 j=ip1,n
      if (d(j).le.p) go to 100
      k=j
      p=d(j)
 100  continue
      if (k.eq.i) go to 120
      d(k)=d(i)
      d(i)=p
      do 110 j=1,n
      p=z(j,i)
      z(j,i)=z(j,k)
      z(j,k)=p
 110  continue
 120  continue
      go to 140
 130  error=l
 140  return
      end
      subroutine tred2 (nt,nm,n,tol,a,d,e,z)
c
c householder reduction of a real symmetric matrix to tridiagonal
c form, accumulating the orthogonal transformation.
c
c input:
c    nt  = leading dimension of a and z.
c    nm  = not referenced.
c    n   = order of the matrix.
c    tol = threshold below which a row's off-diagonal sum of squares
c          is treated as zero (the transformation is skipped).
c    a   = symmetric matrix; only the lower triangle is read.
c
c output:
c    d(n) = diagonal of the tridiagonal matrix.
c    e(n) = subdiagonal in e(2..n); e(1) = 0.
c    z    = accumulated transformation matrix. a and z may be the
c           same array (inv calls it that way), in which case a is
c           overwritten.
c
      implicit real*8 (a-h,o-z)
      real*8 a(nt,nt),d(n),e(n),z(nt,nt)
      do 10 i=1,n
      do 5 j=1,i
      z(i,j)=a(i,j)
 5    continue
 10   continue
      if (n.eq.1) go to 120
      do 110 ii=2,n
      i=n+2-ii
      l=i-2
      f=z(i,i-1)
      g=0.0d0
      if (l.lt.1) go to 30
      do 20 k=1,l
      g=g+z(i,k)*z(i,k)
 20   continue
 30   h=g+f*f
      if (g.gt.tol) go to 40
      e(i)=f
      h=0.0d0
      go to 100
 40   l=l+1
      g=-dsign(dsqrt(h),f)
      e(i)=g
      h=h-f*g
      z(i,i-1)=f-g
      f=0.0d0
      do 80 j=1,l
      z(j,i)=z(i,j)/h
      g=0.0d0
      do 50 k=1,j
      g=g+z(j,k)*z(i,k)
 50   continue
      jp1=j+1
      if (l.lt.jp1) go to 70
      do 60 k=jp1,l
      g=g+z(k,j)*z(i,k)
 60   continue
 70   e(j)=g/h
      f=f+g*z(j,i)
 80   continue
      hh=f/(h+h)
      do 90 j=1,l
      f=z(i,j)
      g=e(j)-hh*f
      e(j)=g
      do 85 k=1,j
      z(j,k)=z(j,k)-f*e(k)-g*z(i,k)
 85   continue
 90   continue
 100  d(i)=h
 110  continue
 120  d(1)=0.0d0
      e(1)=0.0d0
c     accumulate the transformation matrices.
      do 170 i=1,n
      l=i-1
      if (d(i).eq.0.0d0) go to 150
      do 140 j=1,l
      g=0.0d0
      do 130 k=1,l
      g=g+z(i,k)*z(k,j)
 130  continue
      do 135 k=1,l
      z(k,j)=z(k,j)-g*z(k,i)
 135  continue
 140  continue
 150  d(i)=z(i,i)
      z(i,i)=1.0d0
      if (l.lt.1) go to 170
      do 160 j=1,l
      z(i,j)=0.0d0
      z(j,i)=0.0d0
 160  continue
 170  continue
      return
      end