Java 获取包中所有的类

刚接触 Java，没想到连这种基础功能都需要自己实现，而且实现起来还如此复杂。虽然最终完成了实现代码，不过最后还是改用了第三方包 org.reflections。

在此记录实现代码，以备日后之需：

import java.io.File;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.net.JarURLConnection;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.*;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;


/**
 * Utility methods for discovering the classes that belong to a package by scanning the class-path
 * resources (directories and jar files) that the context class loader reports for that package.
 */
public class ClassAssembly {

  /** File-name suffix of compiled class files. */
  private static final String CLASS_SUFFIX = ".class";

  /**
   * Finds the classes in {@code pkg} (including its sub-packages) that are assignable to
   * {@code cls}, excluding {@code cls} itself.
   *
   * @param pkg the package to scan
   * @param cls the supertype (class or interface) to match against
   * @return the matching classes; never {@code null}
   * @throws IOException if the class-path resources for the package cannot be read
   * @throws ClassNotFoundException if a class file found on the class path cannot be loaded
   */
  public static Set<Class<?>> getSubclasses(Package pkg, Class<?> cls) throws IOException, ClassNotFoundException {
    Set<Class<?>> classes = new HashSet<>();

    for (Class<?> clazz : getClasses(pkg)) {
      if (cls.isAssignableFrom(clazz) && !cls.equals(clazz)) {
        classes.add(clazz);
      }
    }

    return classes;
  }

  /**
   * Finds the classes in {@code pkg} (including its sub-packages) for which
   * {@link Class#getAnnotation(Class)} returns a non-null value for {@code annotation}.
   *
   * @param pkg the package to scan
   * @param annotation the annotation type to look for
   * @param <A> the annotation type
   * @return the annotated classes; never {@code null}
   * @throws IOException if the class-path resources for the package cannot be read
   * @throws ClassNotFoundException if a class file found on the class path cannot be loaded
   */
  public static <A extends Annotation> Set<Class<?>> getAnnotatedClasses(Package pkg, Class<A> annotation)
    throws IOException, ClassNotFoundException {
    Set<Class<?>> classes = new HashSet<>();

    for (Class<?> clazz : getClasses(pkg)) {
      if (null != clazz.getAnnotation(annotation)) {
        classes.add(clazz);
      }
    }

    return classes;
  }

  /**
   * Returns the top-level classes of {@code pkg} and its sub-packages.
   *
   * @param pkg the package to scan
   * @return the discovered classes; never {@code null}
   * @throws IOException if the class-path resources for the package cannot be read
   * @throws ClassNotFoundException if a class file found on the class path cannot be loaded
   */
  public static Set<Class<?>> getClasses(Package pkg) throws IOException, ClassNotFoundException {
    return getClasses(pkg.getName());
  }

  /**
   * Returns the top-level classes of the named package and its sub-packages.
   *
   * <p>Only {@code file:} (directory) and {@code jar:} class-path locations are scanned; resources
   * reached through any other URL protocol are skipped. Class files whose names contain {@code '$'}
   * (nested/anonymous classes) are ignored. Every discovered class is loaded and initialized through
   * the same class loader that located it.
   *
   * @param packageName the dotted package name, e.g. {@code "com.example"}; {@code ""} for the
   *     default package
   * @return the discovered classes; never {@code null}
   * @throws IOException if the class-path resources for the package cannot be read
   * @throws ClassNotFoundException if a class file found on the class path cannot be loaded
   */
  public static Set<Class<?>> getClasses(String packageName) throws IOException, ClassNotFoundException {
    ClassLoader loader = resolveClassLoader();
    Set<Class<?>> classes = new HashSet<>();
    String packagePath = packageName.replace('.', '/');
    Enumeration<URL> dirs = loader.getResources(packagePath);

    while (dirs.hasMoreElements()) {
      URL url = dirs.nextElement();
      String protocol = url.getProtocol();

      if ("file".equals(protocol)) {
        classes.addAll(getClassesFromDir(toFile(url), packageName, loader));
      } else if ("jar".equals(protocol)) {
        // Not closed: with URL caching enabled (the default) this JarFile may be shared with other
        // users of the same jar URL.
        JarFile jarFile = ((JarURLConnection) url.openConnection()).getJarFile();
        classes.addAll(getClassesFromJar(jarFile.entries(), packageName, loader));
      }
    }

    return classes;
  }

  /** Returns the thread's context class loader, falling back to this class's own loader. */
  private static ClassLoader resolveClassLoader() {
    ClassLoader loader = Thread.currentThread().getContextClassLoader();
    return loader != null ? loader : ClassAssembly.class.getClassLoader();
  }

  /**
   * Converts a {@code file:} URL to a {@link File}, decoding percent-escapes such as {@code %20}.
   * Falls back to the raw URL path when the URL is not a well-formed file URI.
   */
  private static File toFile(URL url) {
    try {
      return new File(url.toURI());
    } catch (URISyntaxException | IllegalArgumentException e) {
      return new File(url.getPath());
    }
  }

  /** Recursively collects the top-level classes found under {@code dir}. */
  private static List<Class<?>> getClassesFromDir(File dir, String packageName, ClassLoader loader)
    throws ClassNotFoundException {
    List<Class<?>> classes = new ArrayList<>();
    File[] files = dir.listFiles();
    if (files == null) {
      // Not a directory, or it could not be listed: nothing to collect.
      return classes;
    }

    for (File file : files) {
      String fileName = file.getName();
      if (file.isDirectory()) {
        classes.addAll(getClassesFromDir(file, qualify(packageName, fileName), loader));
      } else if (fileName.endsWith(CLASS_SUFFIX) && !fileName.contains("$")) {
        String simpleName = fileName.substring(0, fileName.length() - CLASS_SUFFIX.length());
        classes.add(Class.forName(qualify(packageName, simpleName), true, loader));
      }
    }
    return classes;
  }

  /** Collects the top-level classes among {@code jarEntries} that live in the package or below. */
  private static List<Class<?>> getClassesFromJar(Enumeration<JarEntry> jarEntries, String packageName,
    ClassLoader loader) throws ClassNotFoundException {
    String packagePath = packageName.replace('.', '/');
    // Require the trailing '/' so that "com/foo" does not also match "com/foobar/...".
    String prefix = packagePath.isEmpty() ? "" : packagePath + "/";
    List<Class<?>> classes = new ArrayList<>();

    while (jarEntries.hasMoreElements()) {
      JarEntry jarEntry = jarEntries.nextElement();
      String entryName = jarEntry.getName();
      if (jarEntry.isDirectory()
          || !entryName.startsWith(prefix)
          || !entryName.endsWith(CLASS_SUFFIX)
          || entryName.contains("$")) {
        continue;
      }
      String className = entryName.substring(0, entryName.length() - CLASS_SUFFIX.length()).replace('/', '.');
      classes.add(Class.forName(className, true, loader));
    }

    return classes;
  }

  /** Joins a package name and a simple name, handling the default (empty) package. */
  private static String qualify(String packageName, String simpleName) {
    return packageName.isEmpty() ? simpleName : packageName + "." + simpleName;
  }
}