java中用djl加载和运行预训练模型只需三步:添加依赖(如djl-api、pytorch-engine等)、选择模型(url/本地路径/模型id)、构建predictor执行推理;djl自动适配pytorch等引擎,无需编写底层计算逻辑。

Java中用DJL加载和运行预训练模型,核心是三步:添加依赖、选择模型(本地或远程)、构建Predictor执行推理。不需要写底层计算逻辑,DJL自动处理引擎适配(如PyTorch、TensorFlow、ONNX Runtime)。
1. 添加DJL依赖(Maven)
DJL支持多引擎,推荐从PyTorch开始(生态成熟、模型丰富)。在pom.xml中引入:
-
核心API:
djl-api -
PyTorch引擎:
model-zoo+pytorch-engine -
预编译本地库(免编译):
pytorch-native-auto(自动匹配系统架构)
示例依赖片段:
<dependency> <groupId>ai.djl</groupId> <artifactId>api</artifactId> <version>0.27.0</version> </dependency> <dependency> <groupId>ai.djl.pytorch</groupId> <artifactId>pytorch-engine</artifactId> <version>0.27.0</version> </dependency> <dependency> <groupId>ai.djl.pytorch</groupId> <artifactId>pytorch-native-auto</artifactId> <version>2.1.2</version> </dependency>
2. 加载预训练模型(支持URL/本地路径/模型ID)
DJL内置ModelZoo,可直接用HuggingFace ID或DJL Model Zoo地址加载。例如加载bert-base-uncased文本分类模型:
立即学习“Java免费学习笔记(深入)”;
- 用
Criteria声明输入输出类型、模型来源、设备(CPU/GPU) - 调用
ModelLoader.loadModel()获得Model实例 - 注意:首次加载会自动下载模型权重到本地缓存(
~/.djl.ai/cache)
代码示例:
Criteria<String, Classifications> criteria = Criteria.builder()
.setTypes(String.class, Classifications.class)
.optModelUrls("https://resources.djl.ai/test-models/pytorch/transformers/bert-base-uncased.zip")
.optEngine("PyTorch")
.optTranslator(new BertTranslator())
.build();
Model model = Model.newInstance("bert");
model = ModelLoader.loadModel(criteria);3. 构建Predictor并运行推理
Predictor是执行推理的入口,封装了预处理、前向传播、后处理。创建后调用predict()即可:
-
Translator负责输入转NDArray、输出转业务对象(如Classifications) - 支持批量输入(List),也支持单条字符串
- 用完记得
close()释放资源(推荐try-with-resources)
完整推理示例:
try (Predictor<String, Classifications> predictor = model.newPredictor(new BertTranslator())) {
Classifications result = predictor.predict("I love DJL!");
System.out.println(result);
// 输出类似:positive: 0.982, negative: 0.018
}4. 常见问题与建议
实际使用时容易卡在环境或格式上,注意以下几点:
- GPU支持需安装CUDA驱动+cuDNN,并用
pytorch-native-cu118等对应版本依赖 - 模型输入必须匹配
Translator定义(如Bert要tokenize,CNN图像要resize+normalize) - 自定义模型:把
model.pt和synset.txt等放在同目录,用optModelPath(Paths.get("models/my-model")) - 性能优化:启用
setLimit(1)限制线程数,或用Model.setBlock()手动指定计算图










