# Copyright (C) 1997-2000  Adrian Trapletti
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Library General Public
# License as published by the Free Software Foundation; either
# version 2 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Library General Public License for more details.
#
# You should have received a copy of the GNU Library General Public
# License along with this library; if not, write to the Free
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#
# model selection tests for ffnet 
#


# S3 generic for the ffnet Teraesvirta test: dispatches on the class of
# `obj` to an ffnet.terasvirta.test.<class> method (this file provides
# default, ts and formula methods).
ffnet.terasvirta.test <- function (obj, ...)
{
  UseMethod("ffnet.terasvirta.test")
}

ffnet.terasvirta.test.default <- function (x, y, nhid = 0, alpha = 0.9, scale = TRUE,
                                           errfunc, outtype, init = c("null", "train"),
                                           trace = c("PRINT","NO","PLOT"), ...) 
{
  # Teraesvirta-type test of an ffnet regression of y on x against an
  # alternative whose inputs are augmented with nonlinear terms of x.
  #
  # The null model is ffnet(y ~ x).  The alternative is ffnet(y ~ x + xx),
  # where xx holds the first q principal components of all distinct second-
  # and third-order products of the columns of x.  The weights indexed by
  # fwts (q per hidden unit) are held fixed at zero, which leaves q free
  # weights per output, i.e. q*ncol(y) extra free weights compared with the
  # null model.  The statistic
  #   n * log(det(S0) / det(S1)),
  # with S0, S1 the ML residual covariance matrices of the two fits and n the
  # number of observations, is compared with a chi-squared distribution on
  # q*ncol(y) degrees of freedom.
  #
  # Arguments:
  #   x, y     inputs and targets; coerced with as.matrix(), must not contain
  #            NAs and must have the same (positive) number of rows.
  #   nhid     number of hidden units (>= 0).
  #   alpha    in (0,1); q is the smallest number of principal components
  #            whose cumulative standard deviation exceeds alpha times the
  #            total standard deviation.
  #   scale    if TRUE, x and y are standardized with scale() first.
  #   errfunc  error function for ffnet; defaults to "GSSE" when y has more
  #            than one column and "SSE" otherwise.  Must be one of "SSE",
  #            "GSSE", "MAD" or "ENTROPY".
  #   outtype  output type passed on to ffnet.
  #   init     "null": the weights shared with the null model start from the
  #            null fit; "train": the alternative starts from random weights.
  #   trace    trace mode passed on to ffnet.
  #   ...      further arguments for the first fit of each model.
  #
  # Returns an object of class "htest".

  # Determinant via the QR decomposition; only the ratio of two
  # determinants of equal dimension is used below.
  det <- function (x)
    prod(diag(qr(x)$qr))*(-1)^(NCOL(x)-1)

  # ML (divide-by-n, uncentred) covariance of a residual matrix.
  cov.mle <- function (x)
    crossprod(x)/NROW(x)
  
  DNAME <- paste(deparse(substitute(x)), "and", deparse(substitute(y)))
  if (any(is.na(x))) stop("NAs in x")
  if (any(is.na(y))) stop("NAs in y")
  if ((alpha <= 0) || (alpha >= 1)) stop("alpha is not in (0,1)")
  if (nhid < 0) stop ("nhid must be >= 0")
  # prcomp() used to live in package mva; in current R it is in stats and mva
  # no longer exists, so only fall back to mva when prcomp() is not visible.
  if (!exists("prcomp", mode = "function") && !require(mva, quietly = TRUE))
    stop("Package mva is needed. Stopping")
  trace <- match.arg(trace)
  init <- match.arg(init)
  x <- as.matrix(x)
  y <- as.matrix(y)
  if (nrow(x) != nrow(y)) stop("number of rows of x and y must match")
  if (nrow(x) <= 0) stop("no observations in x and y")
  if (ncol(x) < 1) stop ("invalid x")
  n.obs <- nrow(x)
  m <- ncol(x)
  if (missing(errfunc))
    errfunc <- if (ncol(y) > 1) "GSSE" else "SSE"
  if (scale)
  {
    x <- scale(x)
    y <- scale(y)
  }
  # All distinct second-order products x[,a]*x[,b] with a <= b ...
  ga <- kronecker(1:m,rep(1,m))
  gb <- kronecker(rep(1,m),1:m)
  guniq <- (ga <= gb)
  xx <- x[,ga[guniq]]*x[,gb[guniq]] 
  # ... and all distinct third-order products with a <= b <= c.
  ga <- kronecker(1:m,rep(1,m*m))
  gb <- rep(kronecker(1:m,rep(1,m)),m)
  gc <- kronecker(rep(1,m*m),1:m)
  guniq <- ((ga <= gb) & (gb <= gc))
  xx <- cbind(xx,x[,ga[guniq]]*x[,gb[guniq]]*x[,gc[guniq]])
  # Reduce the nonlinear terms to their first q principal components.
  prc <- prcomp(xx, scale = TRUE)
  q <- min(which(cumsum(prc$sdev) > alpha*sum(prc$sdev)))
  xx <- prc$x[,1:q]
  # Null model, fitted once with the user's extra arguments and then
  # refitted from the resulting weights.
  rr0 <- ffnet(y~x, nhid=nhid, errfunc=errfunc, outtype=outtype, trace=trace, ...)  
  rr0 <- ffnet(y~x, nhid=nhid, errfunc=errfunc, outtype=outtype, trace=trace, wts=rr0$wts)
  # NOTE(review): the index arithmetic below assumes ffnet lays its weights
  # out as nhid blocks of (nin+q+1) hidden-unit weights, then (nhid+1)*nout
  # hidden-to-output weights, then nout blocks of (nin+q) input-to-output
  # weights -- confirm against ffnet.
  nw <- (rr0$nin+1+q)*nhid+(nhid+1)*rr0$nout+(rr0$nin+q)*rr0$nout
  wts <- rnorm(nw)
  # fwts: q weights per hidden unit (after its first nin+1 weights) that are
  # fixed at zero in the alternative.
  if (nhid > 0)
    fwts <- rep(seq(rr0$nin+1,length.out=nhid,by=rr0$nin+q+1),rep(q,nhid))+rep(seq(1,q),nhid)
  else
    fwts <- numeric(0)
  wts[fwts] <- 0
  if (init == "null")
  {
    # fwts2: the q new weights per output; every weight that is neither in
    # fwts nor in fwts2 is initialized from the null fit.
    fwts2 <- rep(seq((rr0$nin+q+1)*nhid+(nhid+1)*rr0$nout+rr0$nin,length.out=rr0$nout,by=rr0$nin+q),
                 rep(q,rr0$nout))+rep(seq(1,q),rr0$nout)
    wts[-c(fwts,fwts2)] <- rr0$wts
    rr1 <- ffnet(y~x+xx, nhid=nhid, errfunc=errfunc, outtype=outtype, 
                 trace=trace, wts=wts, fwts=fwts)  
  }
  else if (init == "train")
  {
    rr1 <- ffnet(y~x+xx, nhid=nhid, errfunc=errfunc, outtype=outtype, 
                 trace=trace, wts=wts, fwts=fwts, ...)  
    rr1 <- ffnet(y~x+xx, nhid=nhid, errfunc=errfunc, outtype=outtype, 
                 trace=trace, wts=rr1$wts, fwts=fwts)
  }
  # The statistic has the same form for every supported error function.
  if (!(errfunc %in% c("SSE", "GSSE", "MAD", "ENTROPY"))) stop ("invalid errfunc")
  STAT <- n.obs*log(det(cov.mle(residuals(rr0)))/det(cov.mle(residuals(rr1))))
  # Extra free weights: q per output (the q per hidden unit are fixed).
  dof <- q*ncol(y)
  PVAL <- pchisq(STAT, dof, lower.tail = FALSE)
  PARAMETER <- c(dof, nhid)
  names(STAT) <- "X-squared"
  names(PARAMETER) <- c("df", "nhid")
  METHOD <- "ffnet Teraesvirta Regression Test"
  ARG <- c(alpha, scale, rr0$value, rr1$value)
  names(ARG) <- c("alpha", "scale", "null fitted value", "alternative fitted value")
  structure(list(statistic = STAT, parameter = PARAMETER, p.value = PVAL, 
                 method = METHOD, data.name = DNAME, arguments = ARG), 
            class = "htest")
}

ffnet.terasvirta.test.ts <- function (y, x, lag = 1, lagx = 1, nhid = 0, alpha = 0.9,
                                      scale = TRUE, errfunc, init = c("null", "train"),
                                      trace = c("PRINT","NO","PLOT"), ...) 
{
  # Teraesvirta-type test of an ffnet time series model.
  #
  # y is regressed on its own `lag` lags and, if x is given, on `lagx` lags
  # of the exogenous series x.  As in the default method, the alternative
  # adds the first q principal components of all distinct second- and
  # third-order products of these regressors.  q per hidden unit of the new
  # weights are fixed at zero, leaving q*NCOL(y) extra free weights.  The
  # statistic
  #   n * log(det(S0) / det(S1)),
  # with S0, S1 the ML residual covariance matrices and n the number of
  # observations remaining after lagging, is compared with a chi-squared
  # distribution on q*NCOL(y) degrees of freedom.
  #
  # Arguments:
  #   y        response time series ("ts" object, no NAs).
  #   x        optional exogenous "ts" with the same start, end and
  #            frequency as y.
  #   lag      number of lags of y (>= 1).
  #   lagx     number of lags of x (>= 1); only used when x is given.
  #   nhid     number of hidden units (>= 0).
  #   alpha    in (0,1); controls the number q of principal components kept.
  #   scale    if TRUE, y (and x) are standardized with scale() first.
  #   errfunc  "SSE", "GSSE" or "MAD"; defaults to "GSSE" for multivariate y
  #            and "SSE" otherwise.  "ENTROPY" is rejected.
  #   init     "null" or "train": how the alternative's weights start.
  #   trace    trace mode passed on to ffnet.
  #   ...      further arguments for the first fit of each model.
  #
  # Returns an object of class "htest".

  # Determinant via the QR decomposition; only ratios of determinants of
  # equal dimension are used below.
  det <- function (x)
    prod(diag(qr(x)$qr))*(-1)^(NCOL(x)-1)
  
  # ML (divide-by-n, uncentred) covariance of a residual matrix.
  cov.mle <- function (x)
    crossprod(x)/NROW(x)
  
  if (!inherits(y, "ts")) stop ("method is only for tseries objects")
  if (any(is.na(y))) stop("NAs in y")
  mx <- missing(x)
  if (!mx)
  {
    if (!is.ts(x)) stop ("x is not a ts object")
    if (any(is.na(x))) stop("NAs in x")
    # start()/end() usually return c(time, cycle); compare every component
    # (|| on these vectors would only look at the first one, or fail).
    if (any(start(x) != start(y)) || any(end(x) != end(y)) ||
        (frequency(x) != frequency(y)))
      stop ("ts attributes of y and x do not match")
    if (lagx < 1) stop ("wrong lagx") 
  }
  if (lag < 1) stop("minimum lag is 1")
  if ((alpha <= 0) || (alpha >= 1)) stop("alpha is not in (0,1)")
  if (nhid < 0) stop ("nhid must be >= 0")
  # prcomp() used to live in package mva; in current R it is in stats and mva
  # no longer exists, so only fall back to mva when prcomp() is not visible.
  if (!exists("prcomp", mode = "function") && !require(mva, quietly = TRUE))
    stop("Package mva is needed. Stopping")
  DNAME <- deparse(substitute(y))
  trace <- match.arg(trace)
  init <- match.arg(init)
  k <- NCOL(y)
  if (missing(errfunc))
    errfunc <- if (k > 1) "GSSE" else "SSE"
  if (errfunc == "ENTROPY") stop ("entropy is not supported for time series modelling")
  if (scale) y <- scale(y)
  y <- embed (y, lag+1)
  if (!mx)
  {
    p <- NCOL(x)
    if (scale) x <- scale(x)
    x <- embed (x, lagx+1)
    # Align the lagged y and x on their common (most recent) rows.
    nx <- NROW(x) 
    ny <- NROW(y) 
    nxy <- min(nx,ny) 
    x <- cbind(as.matrix(y[(ny-nxy+1):ny,-(1:k)]),as.matrix(x[(nx-nxy+1):nx,-(1:p)]))
    y <- as.matrix(y[(ny-nxy+1):ny,1:k])
    m <- k*lag+p*lagx
  }
  else
  {
    x <- as.matrix(y[,-(1:k)])
    y <- as.matrix(y[,1:k])
    m <- k*lag
  }
  # Effective sample size: the rows left after lagging, i.e. the number of
  # residuals the covariance estimates below are based on.
  n.obs <- NROW(y)
  # All distinct second-order products of the regressors ...
  ga <- kronecker(1:m,rep(1,m))
  gb <- kronecker(rep(1,m),1:m)
  guniq <- (ga <= gb)
  xx <- x[,ga[guniq]]*x[,gb[guniq]] 
  # ... and all distinct third-order products.
  ga <- kronecker(1:m,rep(1,m*m))
  gb <- rep(kronecker(1:m,rep(1,m)),m)
  gc <- kronecker(rep(1,m*m),1:m)
  guniq <- ((ga <= gb) & (gb <= gc))
  xx <- cbind(xx,x[,ga[guniq]]*x[,gb[guniq]]*x[,gc[guniq]])
  prc <- prcomp(xx, scale = TRUE)
  q <- min(which(cumsum(prc$sdev) > alpha*sum(prc$sdev)))
  xx <- prc$x[,1:q]
  rr0 <- ffnet(y~x, nhid=nhid, errfunc=errfunc, trace=trace, ...)  
  rr0 <- ffnet(y~x, nhid=nhid, errfunc=errfunc, trace=trace, wts=rr0$wts)
  # NOTE(review): the index arithmetic below assumes ffnet lays its weights
  # out as nhid blocks of (nin+q+1) hidden-unit weights, then (nhid+1)*nout
  # hidden-to-output weights, then nout blocks of (nin+q) input-to-output
  # weights -- confirm against ffnet.
  nw <- (rr0$nin+1+q)*nhid+(nhid+1)*rr0$nout+(rr0$nin+q)*rr0$nout
  wts <- rnorm(nw)
  if (nhid > 0)
    fwts <- rep(seq(rr0$nin+1,length.out=nhid,by=rr0$nin+q+1),rep(q,nhid))+rep(seq(1,q),nhid)
  else
    fwts <- numeric(0)
  wts[fwts] <- 0
  if (init == "null")
  {
    fwts2 <- rep(seq((rr0$nin+q+1)*nhid+(nhid+1)*rr0$nout+rr0$nin,length.out=rr0$nout,by=rr0$nin+q),
                 rep(q,rr0$nout))+rep(seq(1,q),rr0$nout)
    wts[-c(fwts,fwts2)] <- rr0$wts
    rr1 <- ffnet(y~x+xx, nhid=nhid, errfunc=errfunc, 
                 trace=trace, wts=wts, fwts=fwts)  
  }
  else if (init == "train")
  {
    rr1 <- ffnet(y~x+xx, nhid=nhid, errfunc=errfunc,
                 trace=trace, wts=wts, fwts=fwts, ...)  
    rr1 <- ffnet(y~x+xx, nhid=nhid, errfunc=errfunc,
                 trace=trace, wts=rr1$wts, fwts=fwts)
  }
  # The statistic has the same form for every supported error function.
  if (!(errfunc %in% c("SSE", "GSSE", "MAD"))) stop ("invalid errfunc")
  STAT <- n.obs*log(det(cov.mle(residuals(rr0)))/det(cov.mle(residuals(rr1))))
  names(STAT) <- "X-squared"
  PVAL <- pchisq(STAT, q*k, lower.tail = FALSE)
  if (!mx)
  {
    PARAMETER <- c(q*k, nhid, lag, lagx)
    names(PARAMETER) <- c("df", "nhid", "lag", "lagx")
  }
  else
  {
    PARAMETER <- c(q*k, nhid, lag)
    names(PARAMETER) <- c("df", "nhid", "lag")
  }
  METHOD <- "ffnet Teraesvirta Time Series Model Test"
  ARG <- c(alpha, scale, rr0$value, rr1$value)
  names(ARG) <- c("alpha", "scale", "null fitted value", "alternative fitted value")
  structure(list(statistic = STAT, parameter = PARAMETER, p.value = PVAL, 
                 method = METHOD, data.name = DNAME, arguments = ARG), 
            class = "htest")
}

ffnet.terasvirta.test.formula <- function (formula, data = NULL, errfunc, outtype, ...)
{
  # Formula interface to ffnet.terasvirta.test.default.
  #
  # Builds the model frame of `formula` (in `data`, or in the caller's
  # environment), drops the intercept, and forwards the design matrix and
  # response to the default method.
  #   - Factor response: empty levels are dropped with a warning, the
  #     response is converted with as.cl.code(), errfunc defaults to
  #     "ENTROPY", outtype to "SOFT", and the data are not rescaled.
  #   - Otherwise: errfunc defaults to "GSSE" for a multi-column response and
  #     "SSE" otherwise; outtype defaults to "LIN".
  # Further arguments in `...` go to the default method.
  if (!inherits(formula, "formula")) stop("method is only for formula objects")
  env <- parent.frame()
  m <- match.call(expand.dots = FALSE)
  # model.frame() expects a data frame, so convert a matrix `data`.
  if (is.matrix(eval(m$data, env))) 
    m$data <- as.data.frame(data)
  m$... <- m$errfunc <- m$outtype <- NULL
  m[[1]] <- as.name("model.frame")
  m <- eval(m, env)
  Terms <- attr(m, "terms")
  attr(Terms, "intercept") <- 0
  x <- model.matrix(Terms, m)
  y <- model.extract(m, response)
  if (is.factor(y))
  {
    lev <- levels(y)
    counts <- table(y)
    if (any(counts == 0))
    {
      warning(paste("group(s)", paste(lev[counts == 0], collapse = " "), "are empty"))
      y <- factor(y, levels = lev[counts > 0])
    }
    y <- as.cl.code(y)
    if (missing(errfunc)) errfunc <- "ENTROPY"
    if (missing(outtype)) outtype <- "SOFT"
    ffnet.terasvirta.test.default (x, y, errfunc=errfunc, outtype=outtype, scale=FALSE, ...)
  }
  else
  {
    if (missing(errfunc))
      errfunc <- if (NCOL(y) > 1) "GSSE" else "SSE"
    if (missing(outtype)) outtype <- "LIN"
    ffnet.terasvirta.test.default (x, y, errfunc=errfunc, outtype=outtype, ...)
  }
}

prune <- function (nn)
{
  # One pruning step for a fitted ffnet: the weight with the largest p-value
  # in summary(nn)$coef[,4] is set to zero and added to the fixed weights.
  #
  # Arguments:
  #   nn  an object of class "ffnet".
  #
  # Returns list(wts, fwts): nn$wts with the selected weight zeroed, and
  # nn$fwts with the selected weight's index prepended.  Presumably meant to
  # be passed back to ffnet(..., wts=, fwts=) -- confirm with callers.
  #
  # Stops if nn is not an "ffnet" object or if any p-value is NA/NaN.
  if (!inherits(nn, "ffnet")) stop ("method is only for ffnet objects")
  pvals <- summary(nn)$coef[,4]
  if (any(is.na(pvals))) stop ("NaNs in the p-values")
  # Name of the least significant weight.  which.max() yields exactly one
  # name even when several p-values tie; comparing names(coef(nn)) against
  # several tied names at once would silently recycle and pick wrong weights.
  worst <- names(pvals)[which.max(pvals)]
  insign <- which(names(coef(nn)) == worst)
  fwts <- c(insign, nn$fwts)
  wts <- nn$wts
  wts[insign] <- 0
  list(wts = wts, fwts = fwts)
}
