/*
 * Copyright  2006-2008 Bernd Rinn
 *
 *  Licensed under the Apache License, Version 2.0 (the "License"); 
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License. 
 *
 */
package ch.rinn.restrictions;

import static ch.rinn.restrictions.Restrictions.outerClass;

import java.io.File;
import java.io.FileFilter;
import java.io.FilenameFilter;
import java.io.IOException;
import java.util.Collection;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Set;
import java.util.TreeSet;
import java.util.regex.Pattern;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;

import org.apache.bcel.Repository;
import org.apache.bcel.classfile.ConstantCP;
import org.apache.bcel.classfile.ConstantClass;
import org.apache.bcel.classfile.ConstantFieldref;
import org.apache.bcel.classfile.ConstantInterfaceMethodref;
import org.apache.bcel.classfile.ConstantMethodref;
import org.apache.bcel.classfile.ConstantNameAndType;
import org.apache.bcel.classfile.ConstantPool;
import org.apache.bcel.classfile.ConstantUtf8;
import org.apache.bcel.classfile.DescendingVisitor;
import org.apache.bcel.classfile.EmptyVisitor;
import org.apache.bcel.classfile.JavaClass;
import org.apache.bcel.classfile.Method;
import org.apache.bcel.util.ClassPath;
import org.apache.bcel.util.SyntheticRepository;

/**
 * The class that performs the restriction check.
 * <p>
 * Typical use: register the classes to check with {@link #addClasses},
 * {@link #addClassesFromPackageRecursively} or
 * {@link #addClassesFromDirectoryOrJarfileRecursively}, make sure the BCEL {@link Repository} can
 * find them (see {@link #main(String[])}), call {@link #check(boolean)} and finally query
 * {@link #hasErrors(boolean)}. Every violation found is reported on <code>System.err</code>.
 * <p>
 * Instances are not thread-safe: the class that is currently being checked is kept in instance
 * fields, and classes are looked up via the global BCEL {@link Repository}.
 * 
 * @author Bernd Rinn
 */
public class RestrictionChecker
{

    private final static String pathSep = System.getProperty("path.separator");

    private final static String CLASS_SUFFIX = ".class";

    final private Restrictions restrictions;

    /** Fully qualified names of all classes that will be checked (sorted). */
    final private Set<String> clazzNames;

    /**
     * Names of classes that could not be found. The set is handed to {@link Restrictions}, which
     * presumably adds to it while classes are scanned; this class only reads it.
     */
    final private Set<String> missingClasses;

    /** The class that is currently being checked. */
    private JavaClass clazz;

    /** The constant pool of {@link #clazz}. */
    private ConstantPool pool;

    /** Set to <code>true</code> as soon as any violation or failure has been detected. */
    private boolean errors;

    /**
     * BCEL visitor that checks the class in {@link RestrictionChecker#clazz} (with constant pool
     * {@link RestrictionChecker#pool}) and sets {@link RestrictionChecker#errors} for every
     * violation it reports.
     */
    private class CheckingVisitor extends EmptyVisitor
    {

        /** Whether a super class that cannot be loaded counts as an error. */
        private final boolean missingClassCountsAsError;

        private CheckingVisitor(boolean missingClassCountsAsError)
        {
            this.missingClassCountsAsError = missingClassCountsAsError;
        }

        /** No-op: this visitor does not need the descending visitor that drives it. */
        public void setDescendingVisitor(DescendingVisitor descendingVisitor)
        {
        }

        /**
         * Checks that a non-privileged class neither extends a <code>@Final</code> class nor a
         * <code>@Private</code> class outside of its private context.
         */
        @Override
        public void visitJavaClass(JavaClass obj)
        {
            final String clazzName = obj.getClassName();
            if (restrictions.isClassPrivileged(clazzName))
            {
                return;
            }
            try
            {
                for (JavaClass superClazz : obj.getSuperClasses())
                {
                    final String superClazzName = superClazz.getClassName();
                    final boolean classIsFinal = restrictions.isClassFinal(superClazzName);
                    final boolean classIsPrivate = restrictions.isClassPrivate(superClazzName);
                    if (classIsFinal
                            || (classIsPrivate && isPrivateContext(superClazzName, clazzName) == false))
                    {
                        errors = true;
                        System.err.println("Error: class " + clazzName + " extends "
                                + getModifier(classIsFinal, classIsPrivate) + "class "
                                + superClazzName);
                    }
                }
            } catch (ClassNotFoundException ex)
            {
                // A super class could not be loaded; whether this fails the check is configurable.
                if (missingClassCountsAsError)
                {
                    errors = true;
                    System.err.println("Error: cannot load super classes of class " + clazzName
                            + ": " + ex.getMessage());
                }
            }
        }

        /** Returns the annotation names for the given flags, each followed by a space. */
        private String getModifier(boolean _final, boolean _private)
        {
            final StringBuilder modifier = new StringBuilder();
            if (_final)
            {
                modifier.append("@Final ");
            }
            if (_private)
            {
                modifier.append("@Private ");
            }
            return modifier.toString();
        }

        /** Reports a reference to a <code>@Private</code> class outside of its private context. */
        @Override
        public void visitConstantClass(ConstantClass obj)
        {
            final String thisClazzName = clazz.getClassName();
            final String clazzName = obj.getBytes(pool).replace('/', '.');
            // Array types ("[L...;", "[I", ...) are not checked.
            if (clazzName.startsWith("["))
            {
                return;
            }
            final boolean privateContext = isPrivateContext(clazzName, thisClazzName);
            if (restrictions.isClassPrivilegedWithRespectTo(thisClazzName, clazzName) == false
                    && privateContext == false)
            {
                // Naming the super class is reported by visitJavaClass(), not here.
                final boolean ignore = restrictions.isSuperClass(thisClazzName, clazzName);
                if (ignore == false && restrictions.isClassPrivate(clazzName))
                {
                    errors = true;
                    System.err.println("Error: class " + thisClazzName
                            + " uses @Private class " + clazzName);
                }
            }
        }

        /** Reports methods that override a <code>@Final</code> or <code>@Private</code> method. */
        @Override
        public void visitMethod(Method obj)
        {
            final String thisClazzName = clazz.getClassName();
            final String name = obj.getName() + " " + obj.getSignature().replace('/', '.');
            final String definingClazzName =
                    restrictions.getDefiningClassForMethod(thisClazzName, name);
            // No defining class known: nothing can be violated, so avoid passing null on.
            if (definingClazzName == null)
            {
                return;
            }
            final boolean privateContext = isPrivateContext(definingClazzName, thisClazzName);
            if (restrictions.isClassPrivilegedWithRespectTo(thisClazzName, definingClazzName) == false
                    && privateContext == false)
            {
                final boolean methodIsFinal = restrictions.isMethodFinal(definingClazzName, name);
                final boolean methodIsPrivate =
                        restrictions.isMethodPrivate(definingClazzName, name);
                if (methodIsFinal || methodIsPrivate)
                {
                    errors = true;
                    System.err.println("Error: class " + thisClazzName + " overrides "
                            + getModifier(methodIsFinal, methodIsPrivate) + "method '" + name
                            + "' of class " + definingClazzName);
                }
            }
        }

        /** Returns the dotted name of the class referenced by <var>constant</var>. */
        private String getClassName(ConstantCP constant)
        {
            final ConstantClass constantClass =
                    (ConstantClass) pool.getConstant(constant.getClassIndex());
            final String clazzName =
                    ((ConstantUtf8) pool.getConstant(constantClass.getNameIndex())).getBytes();
            return clazzName.replace('/', '.');
        }

        /**
         * Returns "<i>name</i> <i>signature</i>" of the member referenced by <var>constant</var>,
         * with '/' in the signature replaced by '.'.
         */
        private String getName(ConstantCP constant)
        {
            final ConstantNameAndType nameAndType =
                    (ConstantNameAndType) pool.getConstant(constant.getNameAndTypeIndex());
            final String name =
                    ((ConstantUtf8) pool.getConstant(nameAndType.getNameIndex())).getBytes();
            final String signature =
                    ((ConstantUtf8) pool.getConstant(nameAndType.getSignatureIndex())).getBytes();
            return name + " " + signature.replace('/', '.');
        }

        /** Returns <code>true</code> if both classes have the same outer class. */
        private boolean isPrivateContext(String definingClazzName, final String thisClazzName)
        {
            return outerClass(thisClazzName).equals(outerClass(definingClazzName));
        }

        /** Reports accesses to <code>@Private</code> fields outside of their private context. */
        @Override
        public void visitConstantFieldref(ConstantFieldref obj)
        {
            final String thisClazzName = clazz.getClassName();
            final String clazzName = getClassName(obj);
            if (clazzName.startsWith("[")
                    || restrictions.isClassPrivilegedWithRespectTo(thisClazzName, clazzName))
            {
                return;
            }
            final String name = getName(obj);
            final String definingClazzName = restrictions.getDefiningClassForField(clazzName, name);
            if (definingClazzName != null
                    && isPrivateContext(definingClazzName, thisClazzName) == false
                    && restrictions.isFieldPrivate(definingClazzName, name))
            {
                errors = true;
                System.err.println("Error: class " + thisClazzName + " uses @Private field '"
                        + name + "' of class " + definingClazzName);
            }
        }

        @Override
        public void visitConstantMethodref(ConstantMethodref obj)
        {
            primVisitConstantMethodref(obj);
        }

        @Override
        public void visitConstantInterfaceMethodref(ConstantInterfaceMethodref obj)
        {
            primVisitConstantMethodref(obj);
        }

        /**
         * Reports calls of <code>@Private</code> methods outside of their private context (shared by
         * class and interface method references).
         */
        private void primVisitConstantMethodref(ConstantCP obj)
        {
            final String thisClazzName = clazz.getClassName();
            final String clazzName = getClassName(obj);
            if (clazzName.startsWith("[")
                    || restrictions.isClassPrivilegedWithRespectTo(thisClazzName, clazzName))
            {
                return;
            }
            final String name = getName(obj);
            final String definingClazzName =
                    restrictions.getDefiningClassForMethod(clazzName, name);
            if (definingClazzName != null
                    && isPrivateContext(definingClazzName, thisClazzName) == false
                    && restrictions.isMethodPrivate(definingClazzName, name))
            {
                errors = true;
                System.err.println("Error: class " + thisClazzName + " uses @Private method '"
                        + name + "' of class " + definingClazzName);
            }
        }

    }

    public RestrictionChecker()
    {
        clazzNames = new TreeSet<String>();
        missingClasses = new TreeSet<String>();
        restrictions = new Restrictions(missingClasses);
    }

    /**
     * Adds the names of all classes in <var>jarfile</var> below the package directory
     * <var>directory</var> to <var>clazzNamesFound</var>.
     * 
     * @param directory Package directory ("." or "" for all classes); may use '/' or the
     *            platform's file separator.
     * @throws IllegalArgumentException If the jar file cannot be read.
     */
    private void scanJarFile(Collection<String> clazzNamesFound, String jarfile, String directory)
    {
        final boolean scanAll = directory.length() == 0 || directory.equals(".");
        // Zip entry names always use '/', independent of the platform.
        String entryPrefix = directory.replace(File.separatorChar, '/');
        if (entryPrefix.endsWith("/") == false)
        {
            // Ensures that "org/apache" does not match "org/apachefoo/...".
            entryPrefix += "/";
        }
        try
        {
            final ZipFile zip = new ZipFile(jarfile);
            try
            {
                final Enumeration<? extends ZipEntry> entries = zip.entries();
                while (entries.hasMoreElements())
                {
                    final ZipEntry zipEntry = entries.nextElement();
                    final String entryName = zipEntry.getName();
                    if (zipEntry.isDirectory() == false && entryName.endsWith(CLASS_SUFFIX)
                            && (scanAll || entryName.startsWith(entryPrefix)))
                    {
                        clazzNamesFound.add(entryName.substring(0,
                                entryName.length() - CLASS_SUFFIX.length()).replace('/', '.'));
                    }
                }
            } finally
            {
                zip.close();
            }
        } catch (IOException ex)
        {
            throw new IllegalArgumentException("Error reading zip file '" + jarfile + "'.", ex);
        }
    }

    /**
     * Adds the names of the classes found in <var>directories</var> to
     * <var>clazzNamesFound</var>.
     * 
     * @param rootDir The class path root; class names are derived from the file paths relative to
     *            it.
     * @param prefix If not <code>null</code>, only files whose name starts with it are taken.
     * @param recursive If <code>true</code>, sub-directories are scanned as well.
     * @param directories The directories to scan; entries that are not directories are ignored.
     */
    private void scanDirectoriesForClassFiles(final Collection<String> clazzNamesFound,
            final String rootDir, final String prefix, final boolean recursive,
            final File... directories)
    {
        // Skip the separator after rootDir, unless rootDir already ends with one (e.g. "/").
        final int rootLength =
                rootDir.endsWith(File.separator) ? rootDir.length() : rootDir.length() + 1;
        for (File directory : directories)
        {
            if (directory.isDirectory() == false)
            {
                continue;
            }
            final File[] classFiles = directory.listFiles(new FilenameFilter()
                {
                    public boolean accept(File dir, String name)
                    {
                        return name.endsWith(CLASS_SUFFIX)
                                && (prefix == null || name.startsWith(prefix));
                    }
                });
            // listFiles() returns null on I/O errors; treat such a directory as empty.
            if (classFiles != null)
            {
                for (File f : classFiles)
                {
                    final String path = f.getPath();
                    String name =
                            path.substring(rootLength, path.length() - CLASS_SUFFIX.length())
                                    .replace(File.separatorChar, '.');
                    // Paths like "<root>/./pkg/Foo.class" leave leading dots behind.
                    while (name.length() > 0 && name.charAt(0) == '.')
                    {
                        name = name.substring(1);
                    }
                    clazzNamesFound.add(name);
                }
            }
            if (recursive)
            {
                final File[] subDirectories = directory.listFiles(new FileFilter()
                    {
                        public boolean accept(File pathname)
                        {
                            return pathname.isDirectory();
                        }
                    });
                if (subDirectories != null)
                {
                    scanDirectoriesForClassFiles(clazzNamesFound, rootDir, prefix, recursive,
                            subDirectories);
                }
            }
        }
    }

    /**
     * Adds the inner classes of the class at <var>path</var> (relative to <var>rootDir</var>,
     * without ".class" suffix). Inner class files (<code>Outer$Inner.class</code>) live in the same
     * directory as their outer class, so only that directory is scanned.
     */
    private void addInnerClasses(String rootDir, String path)
    {
        final File outerClassFile = new File(rootDir, path);
        scanDirectoriesForClassFiles(clazzNames, rootDir, outerClassFile.getName() + "$", false,
                outerClassFile.getParentFile());
    }

    /**
     * Checks all classes that have been added for restriction violations.
     * <p>
     * Classes are looked up via the global BCEL {@link Repository}, so its class path has to be
     * set up beforehand. Any violation or failure is reported on <code>System.err</code> and makes
     * {@link #hasErrors(boolean)} return <code>true</code>.
     * 
     * @param missingClassCountsAsError If <code>true</code>, a super class that cannot be loaded
     *            counts as an error.
     */
    public void check(boolean missingClassCountsAsError)
    {
        for (String clazzName : clazzNames)
        {
            if (missingClasses.contains(clazzName))
            {
                continue;
            }
            try
            {
                clazz = Repository.lookupClass(clazzName);
                if (clazzName.equals(clazz.getClassName()) == false) // wrong root directory
                {
                    System.err.println("Wrong root directory for class '" + clazz.getClassName()
                            + "' (expected class name '" + clazzName + "')");
                    errors = true;
                    continue;
                }
                errors |= restrictions.scanClass(clazz, clazzNames);
                pool = clazz.getConstantPool();
                final CheckingVisitor rvisitor = new CheckingVisitor(missingClassCountsAsError);
                final DescendingVisitor visitor = new DescendingVisitor(clazz, rvisitor);
                rvisitor.setDescendingVisitor(visitor);
                visitor.visit();
            } catch (Exception ex)
            {
                // Boundary of the per-class check: report and continue with the next class.
                System.err.println("An exception occurred while checking class '" + clazzName
                        + "'");
                ex.printStackTrace();
                errors = true;
            }
        }
    }

    /**
     * Adds all classes in package <var>pkgName</var> and its sub-packages.
     * 
     * @param rootDir Root directory or jar file that contains the package.
     * @param pkgName Package in Java notation (<code>org.apache</code>) or directory notation.
     */
    public void addClassesFromPackageRecursively(String rootDir, String pkgName)
    {
        addClassesFromDirectoryOrJarfileRecursively(rootDir,
                pkgName.replace('.', File.separatorChar));
    }

    /**
     * Adds all classes in <var>pkgNameAsDirectory</var> (recursively) of the root directory or jar
     * file <var>rootDirOrJarfile</var>.
     * 
     * @throws IllegalArgumentException If the jar file or directory does not exist.
     */
    public void addClassesFromDirectoryOrJarfileRecursively(String rootDirOrJarfile,
            String pkgNameAsDirectory)
    {
        final int sizeBefore = clazzNames.size();
        final File root = new File(rootDirOrJarfile);
        if (rootDirOrJarfile.endsWith(".jar"))
        {
            if (root.isFile() == false)
            {
                throw new IllegalArgumentException("Jarfile " + root.getAbsolutePath());
            }
            scanJarFile(clazzNames, rootDirOrJarfile, pkgNameAsDirectory);
        } else
        {
            final File directoryToScan = new File(rootDirOrJarfile, pkgNameAsDirectory);
            if (directoryToScan.isDirectory() == false)
            {
                throw new IllegalArgumentException("Directory " + directoryToScan.getAbsolutePath());
            }
            scanDirectoriesForClassFiles(clazzNames, rootDirOrJarfile, null, true,
                    directoryToScan);
        }
        if (clazzNames.size() == sizeBefore)
        {
            System.err.println("No classes found in directory / jarfile '" + rootDirOrJarfile
                    + "', package directory '" + pkgNameAsDirectory + "'.");
        }
    }

    /**
     * Adds all classes in <var>classes[start:end]</var>, together with their inner classes found
     * below <var>rootDir</var>.
     * 
     * @param classes Class names, either in Java notation (<code>org.apache.Foo</code>) or as
     *            class file paths relative to <var>rootDir</var> (<code>org/apache/Foo.class</code>).
     */
    public void addClasses(String rootDir, String[] classes, int start, int end)
    {
        for (int i = start; i < end; ++i)
        {
            final String clazzName = classes[i];
            final String path;
            if (clazzName.endsWith(CLASS_SUFFIX))
            {
                path = clazzName.substring(0, clazzName.length() - CLASS_SUFFIX.length());
                // Accept both the platform separator and '/' in class file paths.
                clazzNames.add(path.replace(File.separatorChar, '.').replace('/', '.'));
            } else
            {
                path = clazzName.replace('.', File.separatorChar);
                clazzNames.add(clazzName);
            }
            addInnerClasses(rootDir, path);
        }
    }

    /**
     * Returns <code>true</code> if the check found errors or, if
     * <var>missingClassesCountAsError</var> is set, if any class was missing.
     */
    public boolean hasErrors(boolean missingClassesCountAsError)
    {
        return errors || (missingClassesCountAsError && getNumberOfMissingClasses() > 0);
    }

    /** Returns the (live) set of names of classes that could not be found. */
    public Set<String> getMissingClasses()
    {
        return missingClasses;
    }

    public int getNumberOfMissingClasses()
    {
        return missingClasses.size();
    }

    public int getNumberOfClasses()
    {
        return clazzNames.size();
    }

    /**
     * Replaces the path separator of the "other" platform in <var>classpath</var> by
     * {@link #pathSep}. When ':' has to be replaced (Windows), colons of drive letters such as in
     * <code>C:\lib</code> are kept.
     */
    private static String repairClasspath(String classpath)
    {
        if (pathSep.equals(":"))
        {
            return classpath.replace(";", pathSep);
        }
        final StringBuilder repaired = new StringBuilder(classpath.length());
        for (int i = 0; i < classpath.length(); ++i)
        {
            final char c = classpath.charAt(i);
            if (c == ':' && isDriveLetterColon(classpath, i) == false)
            {
                repaired.append(pathSep);
            } else
            {
                repaired.append(c);
            }
        }
        return repaired.toString();
    }

    /**
     * Returns <code>true</code> if the colon at <var>index</var> directly follows a single letter
     * that starts a class path entry, i.e. if it is the colon of a drive letter.
     */
    private static boolean isDriveLetterColon(String classpath, int index)
    {
        if (index < 1 || Character.isLetter(classpath.charAt(index - 1)) == false)
        {
            return false;
        }
        if (index == 1)
        {
            return true;
        }
        final char beforeLetter = classpath.charAt(index - 2);
        return beforeLetter == ';' || beforeLetter == ':';
    }

    /**
     * Appends the absolute paths of <var>rootDirs</var> to <var>classpath</var> and installs the
     * result as class path of the global BCEL {@link Repository}.
     */
    private static void setClasspath(StringBuilder classpath, final Set<File> rootDirs,
            boolean verbose)
    {
        for (File dir : rootDirs)
        {
            if (classpath.length() > 0)
            {
                classpath.append(pathSep);
            }
            classpath.append(dir.getAbsoluteFile());
        }
        final String classpathString = repairClasspath(classpath.toString());
        if (verbose)
        {
            System.err.println("Classpath: " + classpathString);
        }
        Repository.setRepository(SyntheticRepository.getInstance(new ClassPath(classpathString)));
    }

    /** Appends the canonical paths of all jar files below <var>dir</var> to <var>builder</var>. */
    private static void findJarFiles(StringBuilder builder, File dir)
    {
        final File[] files = dir.listFiles();
        // listFiles() returns null on I/O errors; such a directory contributes nothing.
        if (files == null)
        {
            return;
        }
        for (File f : files)
        {
            if (f.isDirectory())
            {
                findJarFiles(builder, f);
            } else
            {
                final String path = getCanonicalPath(f);
                if (path.endsWith(".jar"))
                {
                    if (builder.length() > 0)
                    {
                        builder.append(pathSep);
                    }
                    builder.append(path);
                }
            }
        }
    }

    /** Returns the canonical path of <var>dir</var>, falling back to its absolute path. */
    private static String getCanonicalPath(File dir)
    {
        try
        {
            return dir.getCanonicalPath();
        } catch (IOException ex)
        {
            return dir.getAbsolutePath();
        }
    }

    private static void printHelp()
    {
        System.err.println("RestrictionChecker: No classes to check.");
        System.err.println("Options: -v: verbose mode");
        System.err.println("Options: -m: ignore missing classes");
        System.err.println("Options: -cp <classpath> [...]: add all <classpath> entries to the "
                + "path to search for .class files");
        System.err.println("Options: -jd <jardir> [...]: add all jar files below all <jardir> "
                + "directories to the path to search for .class files");
        System.err.println("Options: -c <classfile> [...]: adds class files to be checked");
        System.err.println("Options: -r <rootdir> [...]: specifies root directories, either for "
                + "following -p option or for all packages (default: .)");
        System.err.println("Options: -p <package> [...]: adds all class files in <package> "
                + "(accepts Java package or directory notation), relative to <rootdir> "
                + "recursively");
        System.err.println();
        System.err.println("Examples:");
        System.err.println("jrc -- checks all java classes below the current working directory");
        System.err.println("jrc -r bin -- checks all java classes below the bin/ subdirectory");
        System.err.println("jrc -r bin -p org.apache -- checks all java classes in package "
                + "org.apache below the bin/ subdirectory");
        System.err.println("jrc -r bin -cp utils.jar -- checks all java classes below the "
                + "bin/ subdirectory and uses utils.jar as additional classpath entry "
                + "for searching referenced classes.");
        System.err.println("jrc -r bin -jd libs -- checks all java classes below the "
                + "bin/ subdirectory and uses all jar files in the libs/ subdirectory as "
                + "additional classpath entry for searching referenced classes.");
        System.err.println("jrc -r classes1 utils.jar -- checks all java classes below the "
                + "subdirectories classes1/ and in the jar file utils.jar.");
    }

    /**
     * Command line entry point; exits with status 1 if the check fails (or on <code>-h</code>),
     * with status 0 otherwise. Some examples for calling this program (program itself abbreviated
     * as jrc):
     * 
     * <pre>
     * jrc -- checks all classes below the current directory
     * jrc -r bin -- checks all classes below the bin/ sub-directory
     * jrc -p org.apache -- checks all classes in the package org.apache  
     * </pre>
     */
    public static void main(String[] args)
    {
        final RestrictionChecker checker = new RestrictionChecker();
        boolean verbose = false;
        boolean missingClassesCountAsError = true;
        final long start = System.currentTimeMillis();
        int idx = 0;
        final StringBuilder classpath = new StringBuilder(256);
        final Set<File> rootDirs = new HashSet<File>();
        File rootDir = new File(".");
        boolean defaultRootDir = true;
        boolean classesExplicitlyGiven = false;
        while (idx < args.length)
        {
            if (args[idx].equals("-h"))
            {
                printHelp();
                System.exit(1);
            }
            if (args[idx].equals("-c"))
            {
                int eidx = ++idx;
                while (eidx < args.length && args[eidx].startsWith("-") == false)
                {
                    ++eidx;
                }
                checker.addClasses(getCanonicalPath(rootDir), args, idx, eidx);
                rootDirs.add(rootDir);
                classesExplicitlyGiven = true;
                idx = --eidx;
            } else if (args[idx].equals("-p"))
            {
                int eidx = ++idx;
                while (eidx < args.length && args[eidx].startsWith("-") == false)
                {
                    checker.addClassesFromPackageRecursively(getCanonicalPath(rootDir), args[eidx]);
                    ++eidx;
                }
                rootDirs.add(rootDir);
                classesExplicitlyGiven = true;
                idx = --eidx;
            } else if (args[idx].equals("-r"))
            {
                int eidx = ++idx;
                while (eidx < args.length && args[eidx].startsWith("-") == false)
                {
                    // A previous root without -c / -p: check all of its classes.
                    if (classesExplicitlyGiven == false && defaultRootDir == false)
                    {
                        rootDirs.add(rootDir);
                        checker.addClassesFromDirectoryOrJarfileRecursively(
                                getCanonicalPath(rootDir), ".");
                    }
                    rootDir = new File(args[eidx]);
                    ++eidx;
                    classesExplicitlyGiven = false;
                    defaultRootDir = false;
                }
                rootDirs.add(rootDir);
                idx = --eidx;
            } else if (args[idx].equals("-v"))
            {
                verbose = true;
            } else if (args[idx].equals("-m"))
            {
                missingClassesCountAsError = false;
            } else if (args[idx].equals("-cp"))
            {
                int eidx = ++idx;
                while (eidx < args.length && args[eidx].startsWith("-") == false)
                {
                    if (classpath.length() > 0)
                    {
                        classpath.append(pathSep);
                    }
                    classpath.append(args[eidx]);
                    ++eidx;
                }
                idx = --eidx;
            } else if (args[idx].equals("-jd"))
            {
                int eidx = ++idx;
                while (eidx < args.length && args[eidx].startsWith("-") == false)
                {
                    final File cpAsFile = new File(args[eidx]);
                    if (cpAsFile.isDirectory())
                    {
                        findJarFiles(classpath, cpAsFile);
                    } else
                    {
                        throw new IllegalArgumentException("Directory " + cpAsFile.getPath());
                    }
                    ++eidx;
                }
                idx = --eidx;
            } else
            {
                System.err.println("Wrong argument " + args[idx]);
                break;
            }
            ++idx;
        }
        if (classesExplicitlyGiven == false)
        {
            rootDirs.add(rootDir);
            checker.addClassesFromDirectoryOrJarfileRecursively(getCanonicalPath(rootDir), ".");
        }
        if (checker.getNumberOfClasses() == 0)
        {
            printHelp();
            System.exit(1);
        }
        if (verbose)
        {
            System.err.println("Scanning " + checker.getNumberOfClasses() + " classes");
        }
        rootDirs.add(rootDir);
        setClasspath(classpath, rootDirs, verbose);
        checker.check(missingClassesCountAsError);
        final boolean errors = checker.hasErrors(missingClassesCountAsError);
        if (missingClassesCountAsError && checker.getNumberOfMissingClasses() > 0)
        {
            System.err.println(checker.getNumberOfMissingClasses() + " classes were missing");
            if (verbose)
            {
                for (String clazzName : checker.getMissingClasses())
                {
                    System.err.println("  " + clazzName);
                }
            }
        }
        if (verbose)
        {
            System.err.println(errors ? "Check failed" : "Check OK");
            System.err.println((System.currentTimeMillis() - start) / 1000.0 + " s");
        }
        System.exit(errors ? 1 : 0);
    }

}
