/*
 * Copyright  2006-2008 Bernd Rinn
 *
 *  Licensed under the Apache License, Version 2.0 (the "License"); 
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License. 
 *
 */
package ch.rinn.restrictions;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

import org.apache.bcel.Constants;
import org.apache.bcel.Repository;
import org.apache.bcel.classfile.AnnotationEntry;
import org.apache.bcel.classfile.ArrayElementValue;
import org.apache.bcel.classfile.ConstantClass;
import org.apache.bcel.classfile.ConstantPool;
import org.apache.bcel.classfile.DescendingVisitor;
import org.apache.bcel.classfile.ElementValue;
import org.apache.bcel.classfile.ElementValuePair;
import org.apache.bcel.classfile.EmptyVisitor;
import org.apache.bcel.classfile.Field;
import org.apache.bcel.classfile.InnerClass;
import org.apache.bcel.classfile.JavaClass;
import org.apache.bcel.classfile.Method;

/**
 * A class representing the restrictions.
 * 
 * @author Bernd Rinn
 */
public class Restrictions
{

    private final String PRIVATE_ANNOTATION_NAME = getNameInClassFile(Private.class);

    private final String FINAL_ANNOTATION_NAME = getNameInClassFile(Final.class);

    private final String PRIVILEGED_ANNOTATION_NAME = getNameInClassFile(Privileged.class);

    private final String FRIEND_ANNOTATION_NAME = getNameInClassFile(Friend.class);

    final private Map<String, ClassRestriction> restrictions =
            new HashMap<String, ClassRestriction>();

    final private Map<String, Set<String>> superClassSet = new HashMap<String, Set<String>>();

    final private Set<String> missingClasses;

    private boolean errors = false;

    private static String getNameInClassFile(Class<?> clazz)
    {
        return "L" + clazz.getCanonicalName().replaceAll("\\.", "/") + ";";
    }
    
    static String getClassNameFromException(ClassNotFoundException ex)
    {
        final String prefix = "Exception while looking for class ";
        String msg = ex.getMessage();
        if (msg.startsWith(prefix))
        {
            int endIdx = msg.indexOf(':');
            if (endIdx > 0)
            {
                return msg.substring(prefix.length(), endIdx);
            }
        }
        return null;
    }

    static String outerClass(String someClass)
    {
        int posDollar = someClass.indexOf('$');
        if (posDollar > 0)
        {
            return someClass.substring(0, posDollar);
        } else
        {
            return someClass;
        }
    }

    private class Restriction
    {
        boolean _private;

        boolean _final;

        Restriction()
        {
        }

        Restriction(boolean _final, boolean _private)
        {
            this._final = _final;
            this._private = _private;
        }

        void setPrivate(boolean _private)
        {
            this._private = _private;
        }

        void setFinal(boolean _final)
        {
            this._final = _final;
        }

        void setPrivileged(boolean _test)
        {
            throw new UnsupportedOperationException();
        }

        void addFriend(String clazzName)
        {
            throw new UnsupportedOperationException();
        }

        boolean isPrivate()
        {
            return _private;
        }

        boolean isFinal()
        {
            return _final;
        }

        boolean isPrivileged()
        {
            return false;
        }

        boolean isFriend(String clazzName)
        {
            return false;
        }

        @Override
        public String toString()
        {
            return "[private=" + _private + ",final=" + _final + "]";
        }

    }

    private class ClassRestriction extends Restriction
    {
        boolean privileged;

        Set<String> friends;

        Map<String, Restriction> fieldRestrictions;

        Map<String, Restriction> methodRestrictions;

        ClassRestriction()
        {
            this.friends = new HashSet<String>();
            this.fieldRestrictions = new HashMap<String, Restriction>();
            this.methodRestrictions = new HashMap<String, Restriction>();
        }

        @Override
        void setPrivileged(boolean privileged)
        {
            this.privileged = privileged;
        }

        @Override
        boolean isPrivileged()
        {
            return privileged;
        }

        @Override
        void addFriend(String clazzName)
        {
            // remove the leading 'L' and the trailing ';' and replace '/' with '.'
            friends.add(clazzName.substring(1, clazzName.length() - 1).replaceAll("/", "."));
        }

        @Override
        boolean isFriend(String clazzName)
        {
            return friends.contains(clazzName);
        }

        Restriction getRestriction(Object entity)
        {
            Restriction restriction = null;
            if (entity instanceof Method)
            {
                Method method = (Method) entity;
                String methodName =
                        method.getName() + " " + method.getSignature().replaceAll("/", ".");
                restriction = getMethodRestriction(methodName);
                if (restriction == null)
                {
                    restriction = new Restriction();
                    methodRestrictions.put(methodName, restriction);
                }
            } else if (entity instanceof Field)
            {
                Field field = (Field) entity;
                String fieldName =
                        field.getName() + " " + field.getSignature().replaceAll("/", ".");
                restriction = getFieldRestriction(fieldName);
                if (restriction == null)
                {
                    restriction = new Restriction();
                    fieldRestrictions.put(fieldName, restriction);
                }
            } else if (entity instanceof JavaClass)
            {
                restriction = this;
            } else
            {
                throw new IllegalArgumentException("Unknown entity " + entity.getClass());
            }
            return restriction;
        }

        public Restriction getFieldRestriction(String fieldName)
        {
            return fieldRestrictions.get(fieldName);
        }

        public Restriction getMethodRestriction(String methodName)
        {
            return methodRestrictions.get(methodName);
        }

        @Override
        public String toString()
        {
            return "[private=" + _private + ",final=" + _final + ", test=" + privileged + "]";
        }
    }

    private class RestrictionVisitor extends EmptyVisitor
    {

        private Set<String> dependantClasses;

        protected JavaClass clazz;

        private DescendingVisitor descendingVisitor;

        RestrictionVisitor(JavaClass clazz, Set<String> dependantClasses)
        {
            this.clazz = clazz;
            this.dependantClasses = dependantClasses;
        }

        void setDescendingVisitor(DescendingVisitor descendingVisitor)
        {
            this.descendingVisitor = descendingVisitor;
        }

        protected void registerClassRestriction()
        {
            String clazzName = clazz.getClassName();
            ClassRestriction classRestriction = restrictions.get(clazzName);
            if (classRestriction == null)
            {
                classRestriction = new ClassRestriction();
                restrictions.put(clazzName, classRestriction);
            }
        }

        protected ClassRestriction getClassRestriction()
        {
            ClassRestriction classRestriction = restrictions.get(clazz.getClassName());
            if (classRestriction == null)
            {
                throw new RuntimeException("class " + clazz.getClassName() + " not yet registered");
            }
            return classRestriction;
        }

        protected void addClassToScan(String clazzName)
        {
            ClassRestriction classRestriction = restrictions.get(clazzName);
            if (classRestriction == null)
            {
                dependantClasses.add(clazzName);
            }
        }

        protected void addClassesToScan(JavaClass... classes)
        {
            for (JavaClass classToAdd : classes)
            {
                addClassToScan(classToAdd.getClassName());
            }
        }

        @Override
        public void visitJavaClass(JavaClass obj)
        {
            registerClassRestriction();
            JavaClass[] superClasses = null;
            JavaClass[] allInterfaces = null;
            try
            {
                superClasses = obj.getSuperClasses();
                addClassesToScan(superClasses);
                Set<String> superClassNames = new LinkedHashSet<String>();
                superClassSet.put(obj.getClassName(), superClassNames);
                for (JavaClass jc : superClasses)
                {
                    String clazzName = jc.getClassName();
                    superClassNames.add(clazzName);
                    if (clazzName.equals("junit.framework.TestCase"))
                    {
                        getClassRestriction().setPrivileged(true);
                    }
                }
            } catch (ClassNotFoundException ex)
            {
                missingClasses.add(getClassNameFromException(ex));
            }
            try
            {
                allInterfaces = obj.getAllInterfaces();
                addClassesToScan(allInterfaces);
            } catch (ClassNotFoundException ex)
            {
                missingClasses.add(getClassNameFromException(ex));
            }
        }

        @Override
        public void visitAnnotationEntry(AnnotationEntry obj)
        {
            Object entity = descendingVisitor.predecessor(1);
            if (entity == null)
            {
                entity = clazz;
            }
            if (obj.getAnnotationType().equals(FINAL_ANNOTATION_NAME))
            {
                if (entity instanceof Method && clazz.isClass() == false)
                {
                    Method method = (Method) entity;
                    String name = method.getName() + " " + method.getSignature();
                    errors = true;
                    System.err.println("Error: @Final method '" + name + "' on interface "
                            + clazz.getClassName());
                }
                getClassRestriction().getRestriction(entity).setFinal(true);
            }
            if (obj.getAnnotationType().equals(PRIVATE_ANNOTATION_NAME))
            {
                if (entity instanceof Method)
                {
                    Method method = (Method) entity;
                    JavaClass definingSuperClass = getDefiningSuperClassForMethod(clazz, method);
                    if (definingSuperClass != null)
                    {
                        String name = method.getName() + " " + method.getSignature();
                        errors = true;
                        System.err.println("Error: @Private method '" + name + "' of class "
                                + clazz.getClassName()
                                + " tries to reduce visibility of super method in class "
                                + definingSuperClass.getClassName());
                    }
                }
                getClassRestriction().getRestriction(entity).setPrivate(true);
            }
            if (obj.getAnnotationType().equals(PRIVILEGED_ANNOTATION_NAME))
            {
                getClassRestriction().getRestriction(entity).setPrivileged(true);
            }
            if (obj.getAnnotationType().equals(FRIEND_ANNOTATION_NAME))
            {
                for (ElementValuePair e : obj.getElementValuePairs())
                {
                    if (e.getNameString().equals("toClasses")
                            && e.getValue().getElementValueType() == ElementValue.ARRAY)
                    {
                        final ArrayElementValue valueArray = (ArrayElementValue) e.getValue();
                        for (ElementValue value : valueArray.getElementValuesArray())
                        {
                            getClassRestriction().getRestriction(entity)
                                    .addFriend(value.toString());
                        }
                    }
                }
            }
        }

    }

    private class RestrictionVisitorRegisterUsedClasses extends RestrictionVisitor
    {

        private ConstantPool pool;

        RestrictionVisitorRegisterUsedClasses(JavaClass clazz, Set<String> dependantClasses)
        {
            super(clazz, dependantClasses);
            pool = clazz.getConstantPool();
        }

        private void addDependantClassesToScan(String clazzName)
        {
            JavaClass classToScan = null;
            try
            {
                classToScan = Repository.lookupClass(clazzName);
                addClassesToScan(classToScan.getSuperClasses());
            } catch (ClassNotFoundException ex)
            {
                missingClasses.add(getClassNameFromException(ex));
            }
            if (classToScan != null)
            {
                try
                {
                    addClassesToScan(classToScan.getAllInterfaces());
                } catch (ClassNotFoundException ex)
                {
                    missingClasses.add(getClassNameFromException(ex));
                }
            }
        }

        @Override
        public void visitConstantClass(ConstantClass obj)
        {
            String clazzName = obj.getBytes(pool).replaceAll("/", ".");
            if (clazzName.startsWith("[") == false)
            {
                addClassToScan(clazzName);
                addDependantClassesToScan(clazzName);
            }
        }

        @Override
        public void visitInnerClass(InnerClass obj)
        {
            String innerClazzName =
                    pool.getConstantString(obj.getInnerClassIndex(), Constants.CONSTANT_Class)
                            .replaceAll("/", ".");
            addClassToScan(innerClazzName);
            addDependantClassesToScan(innerClazzName);
        }

    }

    public Restrictions(Set<String> missingClasses)
    {
        this.missingClasses = missingClasses;
    }

    private ClassRestriction getClassRestriction(String className)
    {
        ClassRestriction restriction = restrictions.get(className);
        return restriction == null ? new ClassRestriction() : restriction;
    }

    public boolean isClassPrivate(String className)
    {
        if (primIsClassPrivate(className) == false)
        {
            return primIsClassPrivate(outerClass(className));
        } else
        {
            return true;
        }
    }

    private boolean primIsClassPrivate(String className)
    {
        return getClassRestriction(className).isPrivate();
    }

    public boolean isClassFinal(String className)
    {
        return getClassRestriction(className).isFinal();
    }

    public boolean isClassPrivilegedWithRespectTo(String classNameToCheck, String classNameToAccess)
    {
        return isClassPrivileged(classNameToCheck)
                || isClassFriendOf(classNameToCheck, classNameToAccess);
    }

    public boolean isClassPrivileged(String className)
    {
        if (primIsClassPrivileged(className) == false)
        {
            return primIsClassPrivileged(outerClass(className));
        } else
        {
            return true;
        }
    }

    private boolean primIsClassPrivileged(String className)
    {
        return getClassRestriction(className).isPrivileged();
    }

    public boolean isClassFriendOf(String classNameToCheck, String classNameOfFriend)
    {
        if (primIsClassFriendOf(classNameToCheck, classNameOfFriend) == false)
        {
            return primIsClassFriendOf(outerClass(classNameToCheck), classNameOfFriend);
        } else
        {
            return true;
        }
    }

    private boolean primIsClassFriendOf(String classNameToCheck, String classNameOfFriend)
    {
        return getClassRestriction(classNameToCheck).isFriend(classNameOfFriend);
    }

    public boolean isSuperClass(String className, String classToTest)
    {
        Set<String> superClasses = superClassSet.get(className);
        return superClasses != null && superClasses.contains(classToTest);
    }

    private Restriction getMethodRestriction(String className, String methodName)
    {
        return getClassRestriction(className).getMethodRestriction(methodName);
    }

    private Restriction getFieldRestriction(String className, String fieldName)
    {
        return getClassRestriction(className).getFieldRestriction(fieldName);
    }

    private JavaClass getDefiningSuperClassForMethod(JavaClass clazz, Method method)
    {
        if (method.isStatic())
        {
            return null;
        }
        try
        {
            JavaClass superClass =
                    primGetDefiningSuperClassForMethod(clazz.getAllInterfaces(), method);
            // if clazz is an interface, clazz.getAllInterfaces() will contain itself
            if (superClass != null && clazz.equals(superClass) == false)
            {
                return superClass;
            }
        } catch (ClassNotFoundException ex)
        {
            missingClasses.add(getClassNameFromException(ex));
        }
        try
        {
            JavaClass superClass =
                    primGetDefiningSuperClassForMethod(clazz.getSuperClasses(), method);
            if (superClass != null)
            {
                return superClass;
            }
        } catch (ClassNotFoundException ex)
        {
            missingClasses.add(getClassNameFromException(ex));
        }
        return null;
    }

    private JavaClass primGetDefiningSuperClassForMethod(JavaClass[] superClasses, Method method)
    {
        if (superClasses != null)
        {
            for (JavaClass clazz2 : superClasses)
            {
                for (Method method2 : clazz2.getMethods())
                {
                    if (method2.equals(method))
                    {
                        return clazz2;
                    }
                }
            }
        }
        return null;
    }

    public String getDefiningClassForMethod(String className, String methodName)
    {
        if (isMethodDefined(className, methodName) == false)
        {
            return getDefiningSuperClassForMethod(className, methodName);
        } else
        {
            return className;
        }
    }

    private String getDefiningSuperClassForMethod(String className, String methodName)
    {
        Set<String> superClasses = superClassSet.get(className);
        if (superClasses != null)
        {
            for (String clazzName : superClasses)
            {
                if (isMethodDefined(clazzName, methodName))
                {
                    return clazzName;
                }
            }
        }
        return null;
    }

    private boolean isMethodDefined(String className, String methodName)
    {
        return (getMethodRestriction(className, methodName) != null);
    }

    public String getDefiningClassForField(String className, String fieldName)
    {
        if (isFieldDefined(className, fieldName) == false)
        {
            Set<String> superClasses = superClassSet.get(className);
            if (superClasses != null)
            {
                for (String clazzName : superClasses)
                {
                    if (isFieldDefined(clazzName, fieldName))
                    {
                        return clazzName;
                    }
                }
            }
            return null;
        } else
        {
            return className;
        }
    }

    private boolean isFieldDefined(String className, String fieldName)
    {
        return (getFieldRestriction(className, fieldName) != null);
    }

    public boolean isMethodPrivate(String className, String methodName)
    {
        Restriction restriction = getMethodRestriction(className, methodName);
        return (restriction == null) ? false : restriction.isPrivate();
    }

    public boolean isMethodFinal(String className, String methodName)
    {
        Restriction restriction = getMethodRestriction(className, methodName);
        return (restriction == null) ? false : restriction.isFinal();
    }

    public boolean isFieldPrivate(String className, String fieldName)
    {
        Restriction restriction = getFieldRestriction(className, fieldName);
        return (restriction == null) ? false : restriction.isPrivate();
    }

    public boolean scanClass(JavaClass clazz, Set<String> allClasses)
    {
        if (restrictions.containsKey(clazz.getClassName()))
        {
            return false;
        }
        Set<String> dependantClasses = new HashSet<String>();
        RestrictionVisitorRegisterUsedClasses rvisitor =
                new RestrictionVisitorRegisterUsedClasses(clazz, dependantClasses);
        DescendingVisitor visitor = new DescendingVisitor(clazz, rvisitor);
        rvisitor.setDescendingVisitor(visitor);
        visitor.visit();
        while (dependantClasses.isEmpty() == false)
        {
            Iterator<String> it = dependantClasses.iterator();
            String clazzName = it.next();
            it.remove();
            if (missingClasses.contains(clazzName))
            {
                continue;
            }
            try
            {
                final JavaClass clazzFromRepos = Repository.lookupClass(clazzName);
                RestrictionVisitor bvisitor;
                if (allClasses.contains(clazzName))
                {
                    bvisitor =
                            new RestrictionVisitorRegisterUsedClasses(clazzFromRepos,
                                    dependantClasses);
                } else
                {
                    bvisitor = new RestrictionVisitor(clazzFromRepos, dependantClasses);
                }
                visitor = new DescendingVisitor(clazzFromRepos, bvisitor);
                bvisitor.setDescendingVisitor(visitor);
                visitor.visit();
            } catch (ClassNotFoundException ex)
            {
                missingClasses.add(getClassNameFromException(ex));
            }
        }
        return errors;
    }

}
