/*
 * Copyright (C) 2005 Alfresco, Inc.
 *
 * Licensed under the Mozilla Public License version 1.1 
 * with a permitted attribution clause. You may obtain a
 * copy of the License at
 *
 *   http://www.alfresco.org/legal/license.txt
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
 * either express or implied. See the License for the specific
 * language governing permissions and limitations under the
 * License.
 */
package org.alfresco.repo.audit;

import java.io.Serializable;
import java.lang.reflect.Method;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Date;
import java.util.List;

import org.alfresco.repo.security.authentication.AuthenticationUtil;
import org.alfresco.repo.transaction.AlfrescoTransactionSupport;
import org.alfresco.service.Auditable;
import org.alfresco.service.NotAuditable;
import org.alfresco.service.cmr.audit.AuditInfo;
import org.alfresco.service.cmr.repository.NodeRef;
import org.alfresco.service.cmr.repository.StoreRef;
import org.aopalliance.intercept.MethodInvocation;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * The default audit component implementation.
 * <p>
 * TODO: Implement before, after and exception filtering. At the moment these filters are ignored.
 * <p>
 * TODO: Respect audit internal - at the moment audit internal is fixed to false.
 * 
 * @author Andy Hind
 */
public class AuditComponentImpl implements AuditComponent
{
    /**
     * The application name to use for audit entries generated by method interception around public services.
     * Application-level audit calls are not allowed to use this name.
     */
    private static final String SYSTEM_APPLICATION = "SystemMethodInterceptor";

    /**
     * Value recorded in place of a method argument that the {@link Auditable#recordable()} mask marks as not
     * recordable.
     */
    private static final String MASKED_ARGUMENT = "********";

    /**
     * Logging
     */
    private static final Log s_logger = LogFactory.getLog(AuditComponentImpl.class);

    /**
     * Per-thread re-entrancy guard. It holds TRUE while this thread is inside an audited interception, so that
     * nested intercepted calls on the same thread are passed straight through rather than audited again.
     */
    private static final ThreadLocal<Boolean> auditFlag = new ThreadLocal<Boolean>();

    /**
     * IOC
     */
    private PublicServiceIdentifier publicServiceIdentifier;

    private AuditConfiguration auditConfiguration;

    private AuditDAO auditDAO;

    private AuditDAO auditFailedDAO;

    private AuditModel auditModel;

    /**
     * Keep hold of the host where the audit occurs. Stays null if the local host address cannot be resolved.
     * TODO: Check that we get the correct address ...
     */
    private InetAddress auditHost;

    public AuditComponentImpl()
    {
        super();
        // Initialise the host address
        try
        {
            auditHost = InetAddress.getLocalHost();
        }
        catch (UnknownHostException e)
        {
            s_logger.error("Failed to get local host address", e);
        }
    }

    /*
     * IOC property setters
     */

    public void setAuditDAO(AuditDAO auditDAO)
    {
        this.auditDAO = auditDAO;
    }

    public void setAuditFailedDAO(AuditDAO auditFailedDAO)
    {
        this.auditFailedDAO = auditFailedDAO;
    }

    public void setAuditConfiguration(AuditConfiguration auditConfiguration)
    {
        this.auditConfiguration = auditConfiguration;
    }

    public void setPublicServiceIdentifier(PublicServiceIdentifier publicServiceIdentifier)
    {
        this.publicServiceIdentifier = publicServiceIdentifier;
    }

    public void setAuditModel(AuditModel auditModel)
    {
        this.auditModel = auditModel;
    }

    /**
     * Intercept a public service method call and audit it according to its annotations.
     * <p>
     * Only the outermost intercepted call on a thread is audited. Calls that arrive while this thread is already
     * inside an audited interception go straight to {@link MethodInvocation#proceed()}.
     * 
     * @param mi
     *            the intercepted invocation
     * @return the value returned by the invocation
     * @throws RuntimeException
     *             if the invoked method is annotated with neither {@link Auditable} nor {@link NotAuditable}
     * @throws Throwable
     *             anything thrown by the invocation or by auditing it
     */
    public Object audit(MethodInvocation mi) throws Throwable
    {
        if (Boolean.TRUE.equals(auditFlag.get()))
        {
            // Already auditing on this thread - do not audit nested service calls.
            return mi.proceed();
        }

        try
        {
            auditFlag.set(Boolean.TRUE);

            Method method = mi.getMethod();
            String methodName = method.getName();
            String serviceName = publicServiceIdentifier.getPublicServiceName(mi);
            if (method.isAnnotationPresent(Auditable.class))
            {
                if (s_logger.isDebugEnabled())
                {
                    if (serviceName != null)
                    {
                        s_logger.debug("Auditing - " + serviceName + "." + methodName);
                    }
                    else
                    {
                        s_logger.debug("UnknownService." + methodName);
                    }
                }
                return auditImpl(mi);
            }
            else if (method.isAnnotationPresent(NotAuditable.class))
            {
                if (s_logger.isDebugEnabled())
                {
                    s_logger.debug("Not Audited. " + serviceName + "." + methodName);
                }
                return mi.proceed();
            }
            else
            {
                if (s_logger.isDebugEnabled())
                {
                    s_logger.debug("Unannotated service method " + serviceName + "." + methodName);
                }
                throw new RuntimeException("Unannotated service method " + serviceName + "." + methodName);
            }
        }
        finally
        {
            // Clear the entry instead of leaving FALSE behind, so pooled threads do not retain it.
            auditFlag.remove();
        }
    }

    /**
     * Audit a method invocation.
     * <p>
     * The audit model is consulted before the call. After a successful call, the entry is written to the audit DAO
     * when the effective mode is ALL or SUCCESS. If anything in the try block throws, the entry is written to the
     * failed-audit DAO when the mode is ALL or FAIL, and the throwable is rethrown.
     * <p>
     * NOTE(review): an exception raised while writing the success entry is also handled by the catch block. It is
     * therefore recorded as a failed invocation and propagated to the caller even though the audited method itself
     * completed. Confirm this "audit is mandatory" behaviour is intended.
     * 
     * @param mi
     *            the invocation to proceed with and audit
     * @return the value returned by the invocation
     * @throws AuditException
     *             if writing the failed-audit entry itself fails (built with the original throwable)
     * @throws Throwable
     *             anything thrown by the invocation, the audit model or the audit DAO
     */
    public Object auditImpl(MethodInvocation mi) throws Throwable
    {
        AuditState auditInfo = new AuditState(auditConfiguration);
        AuditMode auditMode = AuditMode.UNSET;
        try
        {
            auditMode = beforeInvocation(auditMode, auditInfo, mi);
            Object o = mi.proceed();
            auditMode = postInvocation(auditMode, auditInfo, mi, o);
            if ((auditMode == AuditMode.ALL) || (auditMode == AuditMode.SUCCESS))
            {
                auditDAO.audit(auditInfo);
            }
            return o;
        }
        catch (Throwable t)
        {
            auditMode = onError(auditMode, auditInfo, mi, t);
            if ((auditMode == AuditMode.ALL) || (auditMode == AuditMode.FAIL))
            {
                try
                {
                    auditFailedDAO.audit(auditInfo);
                }
                catch (Throwable tt)
                {
                    throw new AuditException("Failed to audit exception", new Object[] { tt }, t);
                }
            }
            throw t;
        }
    }

    /**
     * Helper method to set auditable properties and to determine if auditing is required when an exception is caught
     * in the audited method.
     * 
     * @param auditMode
     *            the effective audit mode determined so far
     * @param auditInfo
     *            the audit state to update
     * @param mi
     *            the audited invocation (currently unused)
     * @param t
     *            the throwable that was caught
     * @return the audit mode to apply (currently the mode passed in)
     */
    private AuditMode onError(AuditMode auditMode, AuditState auditInfo, MethodInvocation mi, Throwable t)
    {
        if ((auditMode == AuditMode.ALL) || (auditMode == AuditMode.FAIL))
        {
            auditInfo.setFail(true);
            auditInfo.setThrowable(t);
        }

        return auditMode;
    }

    /**
     * Helper method to set audited information after method invocation and to determine if auditing should take
     * place based on the method return value.
     * <p>
     * Nothing is recorded when the mode is NONE. This avoids converting the return value (which may call its
     * {@code toString()}) for a call that will not be audited.
     * 
     * @param auditMode
     *            the effective audit mode determined before the call
     * @param auditInfo
     *            the audit state to update
     * @param mi
     *            the audited invocation
     * @param returnObject
     *            the value returned by the invocation (may be null)
     * @return the audit mode to apply (currently the mode passed in)
     */
    private AuditMode postInvocation(AuditMode auditMode, AuditState auditInfo, MethodInvocation mi, Object returnObject)
    {
        if (auditMode == AuditMode.NONE)
        {
            return auditMode;
        }

        auditInfo.setReturnObject(toSerializable(returnObject));

        Auditable auditable = mi.getMethod().getAnnotation(Auditable.class);
        if ((auditable != null) && (auditable.key() == Auditable.Key.RETURN))
        {
            setKey(auditInfo, returnObject);
        }

        // If the user name is not set, try and set it after the method call.
        // This covers authentication when the user is only known after the call.
        if (auditInfo.getUserIdentifier() == null)
        {
            auditInfo.setUserIdentifier(AuthenticationUtil.getCurrentUserName());
        }

        return auditMode;
    }

    /**
     * Set auditable information and determine if auditing is required before method invocation. This would normally
     * be based on the method arguments.
     * <p>
     * The audit model decides the effective mode. Audit information is only gathered when that mode is not NONE.
     * 
     * @param auditMode
     *            the initial audit mode
     * @param auditInfo
     *            the audit state to populate
     * @param mi
     *            the invocation about to proceed
     * @return the effective audit mode returned by the audit model
     */
    private AuditMode beforeInvocation(AuditMode auditMode, AuditState auditInfo, MethodInvocation mi)
    {
        AuditMode effectiveAuditMode = auditModel.beforeExecution(auditMode, mi);

        if (effectiveAuditMode != AuditMode.NONE)
        {
            Method method = mi.getMethod();
            Object[] args = mi.getArguments();
            Auditable auditable = method.getAnnotation(Auditable.class);

            auditInfo.setAuditApplication(SYSTEM_APPLICATION);
            auditInfo.setAuditConfiguration(auditConfiguration);
            auditInfo.setAuditMethod(method.getName());
            auditInfo.setAuditService(publicServiceIdentifier.getPublicServiceName(mi));
            auditInfo.setClientAddress(null);
            auditInfo.setDate(new Date());
            auditInfo.setFail(false);
            auditInfo.setFiltered(false);
            auditInfo.setHostAddress(auditHost);
            if (auditable != null)
            {
                setKey(auditInfo, getKeyArgument(auditable.key(), args));
            }
            auditInfo.setKeyPropertiesAfter(null);
            auditInfo.setKeyPropertiesBefore(null);
            auditInfo.setMessage(null);
            if (args != null)
            {
                boolean[] recordable = (auditable == null) ? null : auditable.recordable();
                auditInfo.setMethodArguments(toSerializableArguments(args, recordable));
            }
            auditInfo.setPath(null);
            auditInfo.setReturnObject(null);
            auditInfo.setSessionId(null);
            auditInfo.setThrowable(null);
            auditInfo.setTxId(AlfrescoTransactionSupport.getTransactionId());
            auditInfo.setUserIdentifier(AuthenticationUtil.getCurrentUserName());
        }

        return effectiveAuditMode;
    }

    /**
     * A simple audit entry. Currently we ignore filtering here.
     * 
     * @param source
     *            the application name to record; must not be the reserved system application name
     * @param description
     *            message recorded with the entry
     * @param key
     *            optional node used as the audit key
     * @param args
     *            optional values to record with the entry
     * @throws AuditException
     *             if gathering or writing the entry fails
     */
    public void audit(String source, String description, NodeRef key, Object... args)
    {
        AuditState auditInfo = new AuditState(auditConfiguration);
        AuditMode auditMode = AuditMode.UNSET;
        try
        {
            auditMode = onApplicationAudit(auditMode, auditInfo, source, description, key, args);
            if ((auditMode == AuditMode.ALL) || (auditMode == AuditMode.SUCCESS))
            {
                auditDAO.audit(auditInfo);
            }
        }
        catch (Throwable t)
        {
            auditMode = onError(auditMode, auditInfo, t, source, description, key, args);
            if ((auditMode == AuditMode.ALL) || (auditMode == AuditMode.FAIL))
            {
                try
                {
                    auditFailedDAO.audit(auditInfo);
                }
                catch (Throwable tt)
                {
                    throw new AuditException("Failed to audit exception", new Object[] { tt }, t);
                }
            }
            throw new AuditException("Application audit failed", t);
        }
    }

    /**
     * Get the audit trail for a node, as returned by the audit DAO.
     * 
     * @param nodeRef
     *            the node whose audit trail is required
     * @return the audit entries returned by the audit DAO
     */
    public List<AuditInfo> getAuditTrail(NodeRef nodeRef)
    {
        return auditDAO.getAuditTrail(nodeRef);
    }

    /**
     * Set auditable information for an application audit entry and determine the effective audit mode.
     * <p>
     * The reserved system application name is rejected whatever the mode. The remaining information is only
     * gathered when the effective mode is not NONE.
     * 
     * @return the effective audit mode returned by the audit model
     * @throws AuditException
     *             if {@code source} is the reserved system application name
     */
    private AuditMode onApplicationAudit(AuditMode auditMode, AuditState auditInfo, String source, String description,
            NodeRef key, Object... args)
    {
        AuditMode effectiveAuditMode = auditModel.beforeExecution(auditMode, source, description, key, args);

        if (source.equals(SYSTEM_APPLICATION))
        {
            throw new AuditException("Application audit can not use the reserved identifier " + SYSTEM_APPLICATION);
        }

        if (effectiveAuditMode != AuditMode.NONE)
        {
            auditInfo.setAuditApplication(source);
            auditInfo.setAuditConfiguration(auditConfiguration);
            auditInfo.setAuditMethod(null);
            auditInfo.setAuditService(null);
            auditInfo.setClientAddress(null);
            auditInfo.setDate(new Date());
            auditInfo.setFail(false);
            auditInfo.setFiltered(false);
            auditInfo.setHostAddress(auditHost);
            setKey(auditInfo, key);
            auditInfo.setKeyPropertiesAfter(null);
            auditInfo.setKeyPropertiesBefore(null);
            auditInfo.setMessage(description);
            if (args != null)
            {
                // Application audits have no recordable mask: every argument is recorded.
                auditInfo.setMethodArguments(toSerializableArguments(args, null));
            }
            auditInfo.setPath(null);
            auditInfo.setReturnObject(null);
            auditInfo.setSessionId(null);
            auditInfo.setThrowable(null);
            auditInfo.setTxId(AlfrescoTransactionSupport.getTransactionId());
            auditInfo.setUserIdentifier(AuthenticationUtil.getCurrentUserName());
        }

        return effectiveAuditMode;
    }

    /**
     * Mark an application audit entry as failed when the mode records failures.
     * 
     * @return the audit mode to apply (currently the mode passed in)
     */
    private AuditMode onError(AuditMode auditMode, AuditState auditInfo, Throwable t, String source,
            String description, NodeRef key, Object... args)
    {
        if ((auditMode == AuditMode.ALL) || (auditMode == AuditMode.FAIL))
        {
            auditInfo.setFail(true);
            auditInfo.setThrowable(t);
        }

        return auditMode;
    }

    /**
     * Record the audit key. A {@link NodeRef} sets both store and GUID, a {@link StoreRef} sets only the store, and
     * any other value (including null) is ignored.
     */
    private static void setKey(AuditState auditInfo, Object key)
    {
        if (key instanceof NodeRef)
        {
            NodeRef nodeRef = (NodeRef) key;
            auditInfo.setKeyStore(nodeRef.getStoreRef());
            auditInfo.setKeyGUID(nodeRef.getId());
        }
        else if (key instanceof StoreRef)
        {
            auditInfo.setKeyStore((StoreRef) key);
        }
    }

    /**
     * Return the argument selected by an {@code ARG_n} audit key. Returns null for keys that do not name an
     * argument, or when the index is beyond the supplied arguments. A misconfigured annotation therefore cannot
     * stop the audited call.
     */
    private static Object getKeyArgument(Auditable.Key key, Object[] args)
    {
        int index = getArgumentIndex(key);
        if ((index < 0) || (args == null) || (index >= args.length))
        {
            return null;
        }
        return args[index];
    }

    /**
     * Map an {@code ARG_n} audit key to its argument index; other keys (NO_KEY, RETURN, ...) map to -1.
     */
    private static int getArgumentIndex(Auditable.Key key)
    {
        switch (key)
        {
        case ARG_0:
            return 0;
        case ARG_1:
            return 1;
        case ARG_2:
            return 2;
        case ARG_3:
            return 3;
        case ARG_4:
            return 4;
        case ARG_5:
            return 5;
        case ARG_6:
            return 6;
        case ARG_7:
            return 7;
        case ARG_8:
            return 8;
        case ARG_9:
            return 9;
        case NO_KEY:
        default:
            return -1;
        }
    }

    /**
     * Convert arguments for recording. An argument is masked when {@code recordable} is non-null, covers its index
     * and is false there. Arguments beyond the end of the mask are recorded.
     */
    private static Serializable[] toSerializableArguments(Object[] args, boolean[] recordable)
    {
        Serializable[] serArgs = new Serializable[args.length];
        for (int i = 0; i < args.length; i++)
        {
            boolean record = (recordable == null) || (recordable.length <= i) || recordable[i];
            serArgs[i] = record ? toSerializable(args[i]) : MASKED_ARGUMENT;
        }
        return serArgs;
    }

    /**
     * Convert a value for recording. Null stays null, Serializable values are kept as-is, and anything else is
     * recorded as its {@code toString()}.
     */
    private static Serializable toSerializable(Object value)
    {
        if (value == null)
        {
            return null;
        }
        if (value instanceof Serializable)
        {
            return (Serializable) value;
        }
        return value.toString();
    }
}