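A PyTorch-based image classification system that supports switching between multiple models, real-time inference, and result visualization.
Project Overview
The Deep Learning Image Recognition System is a web application built with PyTorch. It supports several models, such as ResNet, VGG, and EfficientNet, and provides online inference and model management.
System Architecture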
```text
┌─────────────────────────────────────────────────┐
│                  Web Frontend                   │
│              (Vue.js + Element UI)              │
└────────────────────┬────────────────────────────┘
                     │ REST API
┌────────────────────▼────────────────────────────┐
│                  Flask Backend                  │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────┐  │
│  │  Inference  │  │    Model    │  │  Auth   │  │
│  │   Engine    │  │   Manager   │  │ Service │  │
│  └─────────────┘  └─────────────┘  └─────────┘  │
└────────────────────┬────────────────────────────┘
                     │
┌────────────────────▼────────────────────────────┐
│                 PyTorch Models                  │
│       ResNet50 / VGG16 / EfficientNet-B0        │
└─────────────────────────────────────────────────┘
```
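The diagram shows a Model Manager inside the Flask backend, but the post does not include its code. The following is a minimal sketch, assuming models are loaded lazily and cached by name so that switching back to a model that was already used does not reload it from disk. `ModelManager` and its `loader` callback are hypothetical names. With the engine from the Inference Engine section, the loader could wrap `torch.jit.load`.

```python
from typing import Any, Callable, Dict, Iterable, Optional


class ModelManager:
    """Hypothetical Model Manager: loads each model once and tracks the active one."""

    def __init__(self, loader: Callable[[str], Any], allowed: Iterable[str]):
        self._loader = loader              # assumption: maps a model name to a loaded model
        self._allowed = set(allowed)       # whitelist of model names that may be loaded
        self._cache: Dict[str, Any] = {}
        self.active: Optional[str] = None

    def get(self, name: str) -> Any:
        if name not in self._allowed:
            raise KeyError(f'Unknown model: {name}')
        if name not in self._cache:
            # Load on first use only; later switches reuse the cached object
            self._cache[name] = self._loader(name)
        return self._cache[name]

    def switch(self, name: str) -> Any:
        model = self.get(name)
        self.active = name
        return model

    @property
    def current(self) -> Any:
        if self.active is None:
            raise RuntimeError('No model selected')
        return self._cache[self.active]
```

The trade-off is memory: every model that has been used stays resident, which may matter for a large backbone such as VGG16.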
Features
- Online model switching
- Real-time inference
- Result visualization
- REST API
- Batch prediction
- Model comparison
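Batch prediction appears in the feature list, but the post does not show it. All three backbones accept a leading batch dimension, so one approach is to concatenate several preprocessed `[1, 3, 224, 224]` tensors, such as the ones returned by `preprocess_image` in the Data Preprocessing section, and run a single forward pass. The following is a sketch, not the project's code. `predict_batch` is a hypothetical helper, and the model is assumed to already be in eval mode.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def predict_batch(model: nn.Module, image_tensors, top_k: int = 5):
    """Run several preprocessed images ([1, C, H, W] each) through the model in one forward pass."""
    batch = torch.cat(image_tensors, dim=0)           # -> [N, C, H, W]
    probs = F.softmax(model(batch), dim=1)            # per-image class probabilities
    top_probs, top_indices = torch.topk(probs, top_k, dim=1)
    # One list of (class_index, probability) pairs per input image, best first
    return [
        [(int(i), float(p)) for p, i in zip(row_p, row_i)]
        for row_p, row_i in zip(top_probs.tolist(), top_indices.tolist())
    ]
```

On a GPU, batching spreads the fixed per-call overhead across several images. Memory use grows with the batch size, so very large uploads would need to be split into chunks.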
Core Code
1. Model Definition
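```python
import torch
import torch.nn as nn
from torchvision import models


class ImageClassifier(nn.Module):
    """Wraps a torchvision backbone; the head is replaced only for a custom class count."""

    def __init__(self, model_name='resnet50', num_classes=1000, pretrained=True):
        super().__init__()
        # torchvision >= 0.13 uses `weights=` instead of the deprecated `pretrained=`;
        # IMAGENET1K_V1 are the same weights that pretrained=True used to load
        if model_name == 'resnet50':
            weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.resnet50(weights=weights)
        elif model_name == 'vgg16':
            weights = models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.vgg16(weights=weights)
        elif model_name == 'efficientnet_b0':
            weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.efficientnet_b0(weights=weights)
        else:
            raise ValueError(f'Unsupported model: {model_name}')

        # Replacing the last layer unconditionally would discard the pretrained
        # ImageNet head, so only do it when fine-tuning for a different class count
        if num_classes != 1000:
            if model_name == 'resnet50':
                self.backbone.fc = nn.Linear(2048, num_classes)
            else:
                # Final Linear layer: 4096 inputs for VGG16, 1280 for EfficientNet-B0
                in_features = self.backbone.classifier[-1].in_features
                self.backbone.classifier[-1] = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.backbone(x)
```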
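The inference engine loads models with `torch.jit.load`, which expects a TorchScript archive rather than a plain `state_dict`. The post does not show how `models/resnet50.pth` is produced. One option is tracing, assuming the backbones' `forward` has no data-dependent control flow in eval mode. `export_torchscript` is a hypothetical helper, and the 224×224 default matches the crop size in `get_transforms()`.

```python
import torch
import torch.nn as nn


def export_torchscript(model: nn.Module, path: str, input_shape=(1, 3, 224, 224)) -> None:
    """Trace a model with a dummy input and save it as TorchScript for torch.jit.load."""
    model.eval()                                  # fix dropout / batch-norm to inference behaviour
    example = torch.randn(*input_shape)           # dummy input; only its shape and dtype matter
    with torch.no_grad():
        traced = torch.jit.trace(model, example)  # records the operations executed on `example`
    torch.jit.save(traced, path)
```

For example, `export_torchscript(ImageClassifier('resnet50'), 'models/resnet50.pth')` would produce a file that the engine can load without the `ImageClassifier` class definition.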
2. Data Preprocessing
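```python
from PIL import Image
from torchvision import transforms


def get_transforms():
    # Standard ImageNet evaluation pipeline: resize the shorter side to 256,
    # center-crop to 224x224, scale pixels to [0, 1], then normalize per channel
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])


def preprocess_image(image_path, transform):
    # Accepts a file path or a file-like object; convert('RGB') handles grayscale/RGBA inputs
    image = Image.open(image_path).convert('RGB')
    return transform(image).unsqueeze(0)  # add a batch dimension -> [1, 3, 224, 224]
```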
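The `Normalize` step uses the ImageNet per-channel mean and standard deviation. To make the arithmetic concrete, the NumPy function below reproduces `ToTensor()` followed by `Normalize(...)` for illustration only. `Resize` and `CenterCrop` are omitted, and `to_normalized_chw` is a hypothetical name. With these constants, a pure red pixel (255, 0, 0) maps to roughly (2.249, -2.036, -1.804).

```python
import numpy as np

# ImageNet channel statistics, as passed to transforms.Normalize in get_transforms()
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_normalized_chw(image_hwc_uint8: np.ndarray) -> np.ndarray:
    """NumPy equivalent of ToTensor() followed by Normalize(MEAN, STD) for a uint8 HxWx3 image."""
    x = image_hwc_uint8.astype(np.float32) / 255.0           # ToTensor: [0, 255] -> [0, 1]
    x = x.transpose(2, 0, 1)                                  # ToTensor: HWC -> CHW
    return (x - MEAN[:, None, None]) / STD[:, None, None]     # Normalize: (x - mean) / std per channel
```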
3. Inference Engine
```python
import torch
import torch.nn.functional as F


class InferenceEngine:
    def __init__(self, model_path, device='cuda'):
        # Fall back to CPU when CUDA is unavailable
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.load_model(model_path)

    def load_model(self, model_path):
        """Load a TorchScript model (saved with torch.jit.save) onto the device in eval mode."""
        # map_location lets a model exported on GPU be loaded on a CPU-only host
        self.model = torch.jit.load(model_path, map_location=self.device)
        self.model.eval()

    @torch.no_grad()
    def predict(self, image_tensor, top_k=5):
        image_tensor = image_tensor.to(self.device)
        outputs = self.model(image_tensor)            # logits, shape [N, num_classes]
        probs = F.softmax(outputs, dim=1)
        top_probs, top_indices = torch.topk(probs, top_k, dim=1)
        return top_probs.cpu().numpy(), top_indices.cpu().numpy()
```
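`predict()` applies a softmax over the class dimension and keeps the `top_k` highest probabilities, returning two arrays shaped `[N, top_k]`. That is why the API below reads `probs[0]` for a single image. The NumPy version below shows the same computation only to illustrate what those arrays contain. The engine itself uses `F.softmax` and `torch.topk`, and `softmax_topk` is a hypothetical name.

```python
import numpy as np


def softmax_topk(logits: np.ndarray, k: int = 5):
    """Row-wise softmax over [N, C] logits, then the k most probable classes per row."""
    shifted = logits - logits.max(axis=1, keepdims=True)   # subtract the row max for numerical stability
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)
    # argsort is ascending: reverse the columns, then keep the first k (highest probabilities)
    top_idx = np.argsort(probs, axis=1)[:, ::-1][:, :k]
    top_probs = np.take_along_axis(probs, top_idx, axis=1)
    return top_probs, top_idx
```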
4. Flask API
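```python
from flask import Flask, request, jsonify

# InferenceEngine, get_transforms and preprocess_image are the snippets above,
# assumed to be defined in (or imported into) app.py.

app = Flask(__name__)
engine = InferenceEngine('models/resnet50.pth')   # must be a TorchScript file
transform = get_transforms()
# Only known model names may be loaded; this also blocks path traversal via model_name
ALLOWED_MODELS = {'resnet50', 'vgg16', 'efficientnet_b0'}


@app.route('/predict', methods=['POST'])
def predict():
    if 'image' not in request.files or request.files['image'].filename == '':
        return jsonify({'error': 'No image provided'}), 400
    file = request.files['image']
    # PIL reads the uploaded stream directly, so no shared /tmp file is needed
    # (concurrent uploads with the same name would otherwise overwrite each other)
    image = preprocess_image(file.stream, transform)
    probs, indices = engine.predict(image)
    return jsonify({
        'predictions': [
            # Cast NumPy scalars to built-in types so jsonify can serialize them
            {'class': int(idx), 'probability': float(prob)}
            for prob, idx in zip(probs[0], indices[0])
        ]
    })


@app.route('/switch_model', methods=['POST'])
def switch_model():
    data = request.get_json(silent=True) or {}
    model_name = data.get('model_name')
    if model_name not in ALLOWED_MODELS:
        return jsonify({'error': f'Unknown model: {model_name}'}), 400
    engine.load_model(f'models/{model_name}.pth')
    return jsonify({'status': 'success', 'model': model_name})


if __name__ == '__main__':
    # Listen on all interfaces so the port exposed by Docker is reachable
    app.run(host='0.0.0.0', port=5000)
```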
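`/predict` returns bare class indices, which the frontend cannot display meaningfully. For the 1000 ImageNet classes, torchvision ships the category names with each weights enum, for example `models.ResNet50_Weights.IMAGENET1K_V1.meta["categories"]`. One option is to attach those names before serializing. `attach_labels` and the extra `label` field are assumptions, not part of the original API. For a head fine-tuned with a custom `num_classes`, you would supply your own category list instead.

```python
from typing import List, Sequence


def attach_labels(probs: Sequence[float], indices: Sequence[int],
                  categories: Sequence[str]) -> List[dict]:
    """Turn one image's top-k output into JSON-ready dicts that include the class name."""
    return [
        {'class': int(idx), 'label': categories[int(idx)], 'probability': float(p)}
        for p, idx in zip(probs, indices)
    ]
```

Inside `predict()`, this would be called as `attach_labels(probs[0], indices[0], categories)`.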
5. Frontend
```javascript
export default {
  data() {
    return {
      image: null,
      predictions: [],
      loading: false,
      currentModel: 'resnet50'
    };
  },
  methods: {
    async uploadImage(event) {
      const file = event.target.files[0];
      if (!file) return;                       // the user cancelled the file dialog
      this.image = URL.createObjectURL(file);  // local preview of the selected image
      const formData = new FormData();
      formData.append('image', file);          // field name must match request.files['image']
      this.loading = true;
      try {
        const response = await fetch('/predict', { method: 'POST', body: formData });
        if (!response.ok) throw new Error(`Prediction failed: ${response.status}`);
        this.predictions = (await response.json()).predictions;
      } finally {
        this.loading = false;                  // reset the spinner even if the request fails
      }
    },
    async switchModel(model) {
      const response = await fetch('/switch_model', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ model_name: model })
      });
      // Only update the UI after the backend has actually switched models
      if (response.ok) this.currentModel = model;
    }
  }
};
```
Deployment Configuration
Docker
```dockerfile
FROM python:3.9-slim

WORKDIR /app

# Install dependencies first so this layer is cached across code changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 5000
CMD ["python", "app.py"]
```
requirements.txt
```text
torch==2.0.0
torchvision==0.15.0
flask==2.3.0
Pillow==9.5.0
numpy==1.24.0
```
Results
- Supports switching among three mainstream models
- Inference latency < 100 ms (GPU)
- Accuracy: 76.5% ImageNet Top-1
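The latency figure above is reported without a measurement method. The helper below sketches one way to measure it, assuming `engine.predict` is wrapped in a zero-argument callable. It warms up first, so one-time setup costs are excluded. Because CUDA work is queued asynchronously, it synchronizes before stopping the clock. `measure_latency_ms` is a hypothetical helper, and its numbers will depend on the hardware, the model, and the batch size.

```python
import statistics
import time
from typing import Callable, Optional


def measure_latency_ms(fn: Callable[[], object], warmup: int = 10, iters: int = 100,
                       sync: Optional[Callable[[], None]] = None) -> dict:
    """Time repeated calls of fn in milliseconds; `sync` should block until queued GPU work is done."""
    for _ in range(warmup):          # warm-up: lazy initialisation, kernel selection, caches
        fn()
    if sync:
        sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync:
            sync()                   # wait for asynchronous GPU work before reading the clock
        samples.append((time.perf_counter() - start) * 1000.0)
    samples.sort()
    return {
        'mean_ms': statistics.mean(samples),
        'p50_ms': statistics.median(samples),
        'p95_ms': samples[int(0.95 * (len(samples) - 1))],
    }
```

For example: `measure_latency_ms(lambda: engine.predict(image), sync=torch.cuda.synchronize)`. Reporting p50 and p95 alongside the mean makes a "< 100 ms" claim easier to interpret.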
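Summary
The Deep Learning Image Recognition System shows how PyTorch can be applied in practice: its modular design provides flexible model management and efficient inference.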