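SpringBoot + Batch Partitioning + Distributed Coordination: Processing Tens of Millions of Records in Parallel Partitions Without a Single-Point Bottleneck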

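Introduction

While working on a user data migration project recently, I ran into a thorny problem: tens of millions of user records had to be moved from the old system to the new one. A traditional single-threaded batch job was estimated to need several days. Worse, if any step failed, the whole migration would have to start over.

Is there a way to make a batch job behave like a distributed system, splitting one big job into small tasks that run in parallel? There is. This article looks at how Spring Boot can use batch partitioning and distributed coordination to remove the performance bottleneck of processing tens of millions of records.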

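Why Partition Batch Jobs?

Pain points of traditional batch processing

First, let's look at the limitations of the traditional approach:

Single-node processing bottleneck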

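// Traditional single-threaded batch processing
@Service
public class UserDataMigrationService {

    @Autowired
    private UserMapper userMapper;

    public void migrateAllUsers() {
        List<User> allUsers = userMapper.selectAllUsers(); // 10 million rows loaded into memory at once

        for (User user : allUsers) {
            // Records are processed one by one: extremely slow
            processUser(user);
        }
    }
}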

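This approach has obvious problems:

  1. Long processing time - a single thread needs days to get through tens of millions of records
  2. Low resource utilization - CPU and memory capacity cannot be fully used
  3. Poor fault tolerance - a single failure means restarting the whole run
  4. Poor scalability - the compute power of multiple machines cannot be used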

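What partitioning buys you

Partitioned processing brings these benefits:

Performance

  • A large job is split into many small tasks that run in parallel
  • Multi-core CPUs and multiple machines are put to full use
  • Processing time can drop from days to hours, depending on worker count and database throughput

Reliability

  • A failed partition does not affect the others
  • Resumable processing and retries are supported
  • Each partition tracks its own progress

Scalability

  • Processing nodes can be added dynamically
  • Horizontal scaling is supported
  • Load balances itself: idle workers pull the next pending partition from the queue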

Core Architecture

Our batch partitioning architecture:

┌─────────────────┐    ┌───────────────────┐    ┌─────────────────┐
│ Scheduler       │───▶│ Coordinator       │───▶│ Execution node  │
│ (Master)        │    │ (PartitionHandler)│    │ (Worker)        │
└─────────────────┘    └───────────────────┘    └─────────────────┘
         │                       │                       │
         │ assign partitions     │                       │
         │──────────────────────▶│                       │
         │                       │ dispatch a partition  │
         │                       │──────────────────────▶│
         │                       │                       │
         │                       │ return results        │
         │                       │◀──────────────────────│
         │                       │                       │
         │ aggregate all results │                       │
         │◀──────────────────────│                       │
         │                       │                       │
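The flow in the diagram can be made concrete with a minimal single-JVM sketch of the same pull model. It assumes plain threads stand in for worker nodes and a `ConcurrentLinkedQueue` stands in for the Redis queue. `PullModelDemo` and `run` are illustrative names, not part of the article's code, and the block uses a Java 16+ record.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical single-JVM model of the Master -> Coordinator -> Worker flow.
// Each partition is an inclusive ID range; workers pull ranges until the queue is empty.
class PullModelDemo {

    record Range(long start, long end) {}

    static long run(long minId, long maxId, int partitions, int workers) {
        // Master: split the ID space and enqueue every partition
        ConcurrentLinkedQueue<Range> queue = new ConcurrentLinkedQueue<>();
        long size = (maxId - minId + partitions) / partitions; // ceiling of span / partitions
        for (long start = minId; start <= maxId; start += size) {
            queue.add(new Range(start, Math.min(start + size - 1, maxId)));
        }

        // Workers: each thread keeps claiming the next partition (poll is atomic, like RPOP)
        ExecutorService pool = Executors.newFixedThreadPool(workers);
        List<Future<Long>> results = new ArrayList<>();
        for (int w = 0; w < workers; w++) {
            results.add(pool.submit(() -> {
                long processed = 0;
                Range r;
                while ((r = queue.poll()) != null) {
                    processed += r.end() - r.start() + 1; // stand-in for real per-row work
                }
                return processed;
            }));
        }

        // Master: aggregate the per-worker results
        long total = 0;
        try {
            for (Future<Long> f : results) {
                total += f.get();
            }
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
        return total;
    }
}
```

Because workers pull rather than being assigned a fixed share, a fast worker simply claims more partitions. This is the load-balancing property listed above.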

Key Design Points

1. Partition strategy design

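// Partition strategy interface
public interface PartitionStrategy<T> {
    List<Partition<T>> createPartitions(T dataSource, int partitionCount);

    // partitionCount is passed in so implementations hold no hidden state
    Partition<T> getPartition(T dataSource, int partitionIndex, int partitionCount);
}

// ID-range partition strategy
@Component
public class IdRangePartitionStrategy implements PartitionStrategy<User> {

    private final UserMapper userMapper;

    public IdRangePartitionStrategy(UserMapper userMapper) {
        this.userMapper = userMapper;
    }

    @Override
    public List<Partition<User>> createPartitions(User dataSource, int partitionCount) {
        List<Partition<User>> partitions = new ArrayList<>();

        // Split the actual ID range rather than the row count, so IDs that have gaps
        // or do not start at 1 are still fully covered.
        // selectMinId/selectMaxId are assumed mapper methods (SELECT MIN(id) / MAX(id) FROM users).
        Long minId = userMapper.selectMinId();
        Long maxId = userMapper.selectMaxId();
        if (minId == null || maxId == null) {
            return partitions; // empty table
        }
        long partitionSize = (maxId - minId + partitionCount) / partitionCount; // ceiling of span / partitionCount

        // Create the partitions
        for (int i = 0; i < partitionCount; i++) {
            long startId = minId + i * partitionSize;
            if (startId > maxId) {
                break; // fewer IDs than partitions
            }
            long endId = Math.min(startId + partitionSize - 1, maxId);

            Partition<User> partition = Partition.<User>builder()
                .partitionId(i)
                .startId(startId)
                .endId(endId)
                .totalCount(endId - startId + 1) // size of the ID range; actual rows may be fewer
                .build();

            partitions.add(partition);
        }

        return partitions;
    }

    @Override
    public Partition<User> getPartition(User dataSource, int partitionIndex, int partitionCount) {
        // Look up a specific partition by index
        return createPartitions(dataSource, partitionCount).get(partitionIndex);
    }
}

// Hash-based partition strategy: a row belongs to partition floorMod(hash(id), partitionCount)
@Component
public class HashPartitionStrategy implements PartitionStrategy<User> {

    @Override
    public List<Partition<User>> createPartitions(User dataSource, int partitionCount) {
        List<Partition<User>> partitions = new ArrayList<>();

        for (int i = 0; i < partitionCount; i++) {
            Partition<User> partition = Partition.<User>builder()
                .partitionId(i)
                .hashMod(partitionCount)
                .hashValue(i)
                .build();

            partitions.add(partition);
        }

        return partitions;
    }

    @Override
    public Partition<User> getPartition(User dataSource, int partitionIndex, int partitionCount) {
        // The target partition is derived from the record's own ID (partitionIndex is not used);
        // floorMod keeps it in [0, partitionCount) even when hashCode() is negative, where % would not
        int hashValue = Math.floorMod(dataSource.getId().hashCode(), partitionCount);
        return createPartitions(dataSource, partitionCount).get(hashValue);
    }
}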
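Here is a worked example of the arithmetic behind both strategies. It is written as plain Java so it runs without Spring or a database. It assumes the ID range is known up front (`minId`/`maxId`) and that IDs are `Long`, so the hash is `Long.hashCode`. `PartitionMath` is a hypothetical helper name.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper showing the arithmetic behind both partition strategies
class PartitionMath {

    // ID-range strategy: split [minId, maxId] into at most `count` inclusive ranges
    static List<long[]> idRanges(long minId, long maxId, int count) {
        List<long[]> ranges = new ArrayList<>();
        long span = maxId - minId + 1;
        long size = (span + count - 1) / count; // ceiling division, so no range is oversized
        for (long start = minId; start <= maxId; start += size) {
            ranges.add(new long[] {start, Math.min(start + size - 1, maxId)});
        }
        return ranges;
    }

    // Hash strategy: the partition a given ID belongs to; floorMod keeps the result in [0, count)
    static int hashBucket(long id, int count) {
        return Math.floorMod(Long.hashCode(id), count);
    }
}
```

For example, `idRanges(101, 250, 4)` yields [101, 138], [139, 176], [177, 214] and [215, 250]. For the ID 2147483648, `Long.hashCode` returns `Integer.MIN_VALUE`. In that case `hashCode() % 10` is -8, which is not a valid list index, while `floorMod` gives 2.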

2. Distributed coordinator

// Distributed coordinator
@Component
public class DistributedPartitionCoordinator {

    private final RedisTemplate<String, Object> redisTemplate;
    private final PartitionStrategy<User> partitionStrategy;
    private final TaskRegistry taskRegistry;

    // Two PartitionStrategy<User> beans exist, so the one to inject is named explicitly
    public DistributedPartitionCoordinator(RedisTemplate<String, Object> redisTemplate,
                                           @Qualifier("idRangePartitionStrategy") PartitionStrategy<User> partitionStrategy,
                                           TaskRegistry taskRegistry) {
        this.redisTemplate = redisTemplate;
        this.partitionStrategy = partitionStrategy;
        this.taskRegistry = taskRegistry;
    }

    // Register a batch task
    public String registerBatchTask(BatchTaskConfig config) {
        String taskId = UUID.randomUUID().toString();

        // Create the partitions
        List<Partition<User>> partitions = partitionStrategy.createPartitions(
            config.getDataSource(), config.getPartitionCount());

        // Store the task metadata in Redis
        BatchTask task = BatchTask.builder()
            .taskId(taskId)
            .taskName(config.getTaskName())
            .totalPartitions(partitions.size())
            .createdTime(System.currentTimeMillis())
            .status(TaskStatus.PENDING)
            .build();

        redisTemplate.opsForValue().set("batch:task:" + taskId, task);

        // Register each partition
        for (int i = 0; i < partitions.size(); i++) {
            Partition<User> partition = partitions.get(i);
            PartitionTask partitionTask = PartitionTask.builder()
                .taskId(taskId)
                .partitionId(i)
                .partitionData(partition)
                .status(PartitionStatus.PENDING)
                .build();

            redisTemplate.opsForValue().set(
                "batch:partition:" + taskId + ":" + i, partitionTask);

            // Add the partition index to this task's pending queue
            redisTemplate.opsForList().leftPush(
                "batch:queue:" + taskId, i);
        }

        return taskId;
    }

    // Claim a pending partition of the given task.
    // Redis list commands do not expand wildcards, so the worker must name the task's queue.
    public PartitionTask claimPartition(String taskId, String workerId) {
        // RPOP is atomic: each queued partition index is handed to exactly one worker
        Object popped = redisTemplate.opsForList().rightPop("batch:queue:" + taskId);
        if (popped == null) {
            return null; // nothing left to claim
        }
        int partitionId = Integer.parseInt(popped.toString());

        // Load the partition record
        String partitionKey = "batch:partition:" + taskId + ":" + partitionId;
        PartitionTask partitionTask = (PartitionTask) redisTemplate.opsForValue().get(partitionKey);

        if (partitionTask != null && partitionTask.getStatus() == PartitionStatus.PENDING) {
            // Mark the partition as in progress
            partitionTask.setStatus(PartitionStatus.PROCESSING);
            partitionTask.setWorkerId(workerId);
            partitionTask.setStartTime(System.currentTimeMillis());
            redisTemplate.opsForValue().set(partitionKey, partitionTask);

            return partitionTask;
        }

        return null;
    }

    // Mark a partition as completed
    public void completePartition(String taskId, int partitionId, PartitionResult result) {
        String partitionKey = "batch:partition:" + taskId + ":" + partitionId;
        PartitionTask partitionTask = (PartitionTask) redisTemplate.opsForValue().get(partitionKey);

        if (partitionTask != null) {
            // Update the partition status
            partitionTask.setStatus(PartitionStatus.COMPLETED);
            partitionTask.setResult(result);
            partitionTask.setEndTime(System.currentTimeMillis());
            redisTemplate.opsForValue().set(partitionKey, partitionTask);

            // SADD is idempotent, so a partition completed twice (e.g. after a timeout retry) counts once
            redisTemplate.opsForSet().add("batch:completed:" + taskId, partitionId);

            // Check whether the whole task is done
            checkTaskCompletion(taskId);
        }
    }

    // Check whether all partitions of the task have completed
    private void checkTaskCompletion(String taskId) {
        String taskKey = "batch:task:" + taskId;
        BatchTask task = (BatchTask) redisTemplate.opsForValue().get(taskKey);
        if (task == null) {
            return;
        }

        // SCARD is O(1), unlike scanning every partition key with KEYS
        Long completedCount = redisTemplate.opsForSet().size("batch:completed:" + taskId);
        if (completedCount == null || completedCount < task.getTotalPartitions()) {
            return;
        }

        // SETNX guard: if the last partitions finish concurrently, only one caller runs the callback
        Boolean first = redisTemplate.opsForValue().setIfAbsent("batch:done:" + taskId, "1");
        if (Boolean.TRUE.equals(first)) {
            task.setStatus(TaskStatus.COMPLETED);
            task.setEndTime(System.currentTimeMillis());
            redisTemplate.opsForValue().set(taskKey, task);

            // Fire the task-completion callback
            taskRegistry.invokeCompletionCallback(taskId);
        }
    }
}
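The claim/complete protocol can be exercised without Redis. The sketch below swaps in JDK concurrent collections as stand-ins:

  • a `ConcurrentLinkedDeque` for the list used with LPUSH/RPOP
  • a concurrent set for SADD/SCARD
  • `putIfAbsent` for the SETNX-style guard

It runs several workers at once to check two things: every partition is processed once, and the completion callback fires once. `InMemoryCoordinator` and its methods are hypothetical.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical in-memory stand-in for the Redis-backed coordinator (one task)
class InMemoryCoordinator {
    private final ConcurrentLinkedDeque<Integer> queue = new ConcurrentLinkedDeque<>();     // batch:queue:{taskId}
    private final Set<Integer> completed = ConcurrentHashMap.newKeySet();                  // batch:completed:{taskId}
    private final ConcurrentHashMap<String, Boolean> doneFlag = new ConcurrentHashMap<>(); // batch:done:{taskId}
    private final int totalPartitions;
    final AtomicInteger callbacks = new AtomicInteger();
    final AtomicInteger processed = new AtomicInteger();

    InMemoryCoordinator(int totalPartitions) {
        this.totalPartitions = totalPartitions;
        for (int i = 0; i < totalPartitions; i++) {
            queue.addFirst(i); // LPUSH
        }
    }

    Integer claim() {
        return queue.pollLast(); // RPOP: atomic, so no two workers get the same partition
    }

    void complete(int partitionId) {
        completed.add(partitionId); // SADD (idempotent)
        if (completed.size() == totalPartitions                          // SCARD
                && doneFlag.putIfAbsent("done", Boolean.TRUE) == null) { // SETNX guard
            callbacks.incrementAndGet();                                 // completion callback
        }
    }

    void runWorkers(int workers) {
        ExecutorService pool = Executors.newFixedThreadPool(workers);
        for (int w = 0; w < workers; w++) {
            pool.execute(() -> {
                Integer p;
                while ((p = claim()) != null) {
                    processed.incrementAndGet(); // stand-in for processing the partition
                    complete(p);
                }
            });
        }
        pool.shutdown();
        try {
            pool.awaitTermination(10, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}
```

The guard matters because two workers can finish the last two partitions at almost the same moment. Both may then observe a full completed set, and without SETNX both would fire the callback.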

3. Spring Batch partitioning configuration

// Spring Batch partitioning configuration
// (Spring Batch 4.x API; JobBuilderFactory/StepBuilderFactory are deprecated in Spring Batch 5)
@Configuration
@EnableBatchProcessing
public class BatchPartitionConfig {

    @Autowired
    private JobBuilderFactory jobBuilderFactory;

    @Autowired
    private StepBuilderFactory stepBuilderFactory;

    @Autowired
    @Qualifier("idRangePartitionStrategy")
    private PartitionStrategy<User> partitionStrategy;

    @Autowired
    private DataSource dataSource;

    @Autowired
    private UserMapper userMapper;

    // Master step: splits the work and hands partitions to the PartitionHandler
    @Bean
    public Step masterStep() {
        return stepBuilderFactory.get("masterStep")
            .partitioner("slaveStep", partitioner())
            .partitionHandler(partitionHandler())
            .build();
    }

    // Partitioner: turns each partition into an ExecutionContext
    @Bean
    public Partitioner partitioner() {
        return new Partitioner() {
            @Override
            public Map<String, ExecutionContext> partition(int gridSize) {
                Map<String, ExecutionContext> partitions = new HashMap<>();

                // Create the partitions
                List<Partition<User>> partitionList = partitionStrategy.createPartitions(
                    new User(), gridSize);

                for (int i = 0; i < partitionList.size(); i++) {
                    ExecutionContext context = new ExecutionContext();
                    Partition<User> partition = partitionList.get(i);

                    context.putLong("startId", partition.getStartId());
                    context.putLong("endId", partition.getEndId());
                    context.putInt("partitionId", i);
                    context.putString("partitionName", "partition-" + i);

                    partitions.put("partition-" + i, context);
                }

                return partitions;
            }
        };
    }

    // Partition handler: runs the partitions on threads of this JVM.
    // TaskExecutorPartitionHandler is local only; spreading partitions across machines needs
    // remote partitioning (e.g. MessageChannelPartitionHandler from spring-batch-integration).
    @Bean
    public PartitionHandler partitionHandler() {
        TaskExecutorPartitionHandler handler = new TaskExecutorPartitionHandler();
        handler.setStep(slaveStep());
        handler.setTaskExecutor(taskExecutor());
        handler.setGridSize(10); // 10 partitions run in parallel
        return handler;
    }

    // Slave (worker) step
    @Bean
    public Step slaveStep() {
        return stepBuilderFactory.get("slaveStep")
            .<User, User>chunk(1000)
            .reader(userItemReader(null, null)) // real values are injected at step scope
            .processor(userItemProcessor())
            .writer(userItemWriter())
            .build();
    }

    // User reader, bound to this partition's ID range at step scope.
    // Returning the concrete type lets Spring Batch register it as an ItemStream (open/close).
    @Bean
    @StepScope
    public JdbcPagingItemReader<User> userItemReader(
            @Value("#{stepExecutionContext['startId']}") Long startId,
            @Value("#{stepExecutionContext['endId']}") Long endId) {
        Map<String, Object> params = new HashMap<>();
        params.put("startId", startId);
        params.put("endId", endId);

        return new JdbcPagingItemReaderBuilder<User>()
            .name("userItemReader")
            .dataSource(dataSource)
            .selectClause("SELECT *")
            .fromClause("FROM users")
            .whereClause("WHERE id >= :startId AND id <= :endId") // bound parameters, no string concatenation
            .parameterValues(params)
            .sortKeys(Collections.singletonMap("id", Order.ASCENDING))
            .rowMapper(new UserRowMapper())
            .pageSize(1000)
            .build();
    }

    // User processor
    @Bean
    public ItemProcessor<User, User> userItemProcessor() {
        return user -> {
            // Transform the user record (here: just stamp the processing time)
            user.setProcessedTime(new Date());
            return user;
        };
    }

    // User writer
    @Bean
    public ItemWriter<User> userItemWriter() {
        // Batch-insert the processed users (Spring Batch 4: write(List<? extends User>))
        return users -> userMapper.batchInsert(new ArrayList<>(users));
    }

    // Thread pool
    @Bean
    public TaskExecutor taskExecutor() {
        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
        executor.setCorePoolSize(10);
        executor.setMaxPoolSize(20);
        executor.setQueueCapacity(100);
        executor.setThreadNamePrefix("batch-worker-");
        executor.initialize();
        return executor;
    }

    // Job
    @Bean
    public Job userMigrationJob() {
        return jobBuilderFactory.get("userMigrationJob")
            .start(masterStep())
            .build();
    }
}
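A stripped-down, framework-free loop shows what `.chunk(1000)` means for the slave step:

  • Items are read one at a time until the reader returns null.
  • Each item goes through the processor; a null result filters it out.
  • The writer receives the surviving items once per chunk of at most `chunkSize` reads.

In Spring Batch each chunk is also one transaction. `ChunkLoop` is a hypothetical name, and the sketch deliberately ignores skip/retry and restart handling.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

// Hypothetical, framework-free model of a chunk-oriented step; returns the number of write calls
class ChunkLoop {

    static <I, O> int run(Supplier<I> reader, Function<I, O> processor,
                          Consumer<List<O>> writer, int chunkSize) {
        int chunks = 0;
        while (true) {
            List<O> buffer = new ArrayList<>();
            int read = 0;
            I item = null;
            // Read until chunkSize items were read or the reader signals the end with null
            while (read < chunkSize && (item = reader.get()) != null) {
                read++;
                O out = processor.apply(item);
                if (out != null) {
                    buffer.add(out); // a null result filters the item out
                }
            }
            if (!buffer.isEmpty()) {
                writer.accept(buffer); // one write call (one transaction in Spring Batch) per chunk
                chunks++;
            }
            if (item == null) {
                return chunks; // reader exhausted
            }
        }
    }
}
```

With 2,500 input items and a chunk size of 1000, the writer is called three times with 1000, 1000 and 500 items.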

Key Implementation Details

1. Optimized partition reading

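// Optimized partition reader: keyset pagination (id > lastId ORDER BY id LIMIT n)
// instead of LIMIT/OFFSET, whose cost grows with the offset on large tables.
// Not a singleton bean: create one instance per partition (e.g. from a @StepScope @Bean method).
public class OptimizedPartitionItemReader implements ItemReader<User> {

    private final JdbcTemplate jdbcTemplate;
    private final PartitionContext partitionContext;

    private List<User> currentBatch;
    private int currentIndex = 0;
    private boolean exhausted = false;

    public OptimizedPartitionItemReader(JdbcTemplate jdbcTemplate,
                                        PartitionContext partitionContext) {
        this.jdbcTemplate = jdbcTemplate;
        this.partitionContext = partitionContext;
    }

    @Override
    public User read() {
        // Current batch consumed: load the next one (unless the partition is exhausted)
        if (!exhausted && (currentBatch == null || currentIndex >= currentBatch.size())) {
            loadNextBatch();
            currentIndex = 0;
        }

        // Return the next record; null tells Spring Batch the partition is finished
        if (currentBatch != null && currentIndex < currentBatch.size()) {
            User user = currentBatch.get(currentIndex++);
            partitionContext.setLastProcessedId(user.getId()); // advance the keyset cursor
            return user;
        }

        return null;
    }

    private void loadNextBatch() {
        // ORDER BY id makes pages deterministic; the cursor replaces OFFSET
        String sql = "SELECT * FROM users WHERE id > ? AND id <= ? ORDER BY id LIMIT ?";

        currentBatch = jdbcTemplate.query(sql, new UserRowMapper(),
            partitionContext.getLastProcessedId(),
            partitionContext.getEndId(),
            partitionContext.getBatchSize());

        if (currentBatch.isEmpty()) {
            exhausted = true;
        }
    }
}

// Per-partition reading state (one instance per partition, never a shared singleton)
@Data
public class PartitionContext {
    private long startId;
    private long endId;
    private int batchSize = 1000;
    private long lastProcessedId; // keyset cursor: last ID already read
    private int partitionId;

    public PartitionContext(long startId, long endId, int partitionId) {
        this.startId = startId;
        this.endId = endId;
        this.partitionId = partitionId;
        this.lastProcessedId = startId - 1; // start just before the first ID of the partition
    }

    public boolean isWithinRange(long id) {
        return id >= startId && id <= endId;
    }
}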
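Keyset pagination can be checked against an in-memory table. The sketch below assumes a `TreeSet` of IDs stands in for the indexed `id` column. It issues the same logical query as `loadNextBatch` (`id > lastId AND id <= endId ORDER BY id LIMIT batchSize`) and advances the cursor to the last ID of each page. `KeysetDemo` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.NavigableSet;

// Hypothetical in-memory model of keyset pagination over an indexed id column
class KeysetDemo {

    // Equivalent of: SELECT id FROM users WHERE id > ? AND id <= ? ORDER BY id LIMIT ?
    static List<Long> page(NavigableSet<Long> ids, long lastId, long endId, int limit) {
        List<Long> result = new ArrayList<>(limit);
        for (Long id : ids.subSet(lastId, false, endId, true)) {
            if (result.size() == limit) {
                break;
            }
            result.add(id);
        }
        return result;
    }

    // Read a whole partition [startId, endId] page by page, returning every ID seen
    static List<Long> readPartition(NavigableSet<Long> ids, long startId, long endId, int batchSize) {
        List<Long> all = new ArrayList<>();
        long lastId = startId - 1; // cursor starts just before the partition
        while (true) {
            List<Long> batch = page(ids, lastId, endId, batchSize);
            if (batch.isEmpty()) {
                return all;
            }
            all.addAll(batch);
            lastId = batch.get(batch.size() - 1); // advance the cursor; no OFFSET needed
        }
    }
}
```

Each page starts from an index seek on `id`, so page 1,000 costs about the same as page 1. This holds even when the ID column has gaps.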

2. Error handling and retries

// Partition error handler
@Component
public class PartitionErrorHandler {

    private final RedisTemplate<String, Object> redisTemplate;
    private final AlertService alertService;

    public PartitionErrorHandler(RedisTemplate<String, Object> redisTemplate, AlertService alertService) {
        this.redisTemplate = redisTemplate;
        this.alertService = alertService;
    }

    // Handle an exception thrown while processing a partition
    public void handlePartitionError(String taskId, int partitionId, Exception exception) {
        String partitionKey = "batch:partition:" + taskId + ":" + partitionId;
        PartitionTask partitionTask = (PartitionTask) redisTemplate.opsForValue().get(partitionKey);

        if (partitionTask != null) {
            // Record the error
            partitionTask.setStatus(PartitionStatus.FAILED);
            partitionTask.setErrorMessage(exception.getMessage());
            partitionTask.setErrorTime(System.currentTimeMillis());
            redisTemplate.opsForValue().set(partitionKey, partitionTask);

            // Decide whether to retry
            if (shouldRetry(partitionTask)) {
                retryPartition(taskId, partitionId, partitionTask.getRetryCount() + 1);
            } else {
                // Raise an alert
                alertService.sendAlert("Batch partition failed",
                    String.format("Task %s, partition %d failed: %s",
                        taskId, partitionId, exception.getMessage()));

                // Mark the whole task as failed
                markTaskAsFailed(taskId);
            }
        }
    }

    private boolean shouldRetry(PartitionTask partitionTask) {
        return partitionTask.getRetryCount() < 3; // at most 3 retries
    }

    private void retryPartition(String taskId, int partitionId, int retryCount) {
        String partitionKey = "batch:partition:" + taskId + ":" + partitionId;
        PartitionTask partitionTask = (PartitionTask) redisTemplate.opsForValue().get(partitionKey);

        if (partitionTask != null) {
            // Bump the retry count and reset the partition to PENDING
            partitionTask.setRetryCount(retryCount);
            partitionTask.setStatus(PartitionStatus.PENDING);
            partitionTask.setErrorMessage(null);
            redisTemplate.opsForValue().set(partitionKey, partitionTask);

            // Put it back on the task's queue
            redisTemplate.opsForList().leftPush("batch:queue:" + taskId, partitionId);
        }
    }

    private void markTaskAsFailed(String taskId) {
        String taskKey = "batch:task:" + taskId;
        BatchTask task = (BatchTask) redisTemplate.opsForValue().get(taskKey);

        if (task != null) {
            task.setStatus(TaskStatus.FAILED);
            task.setEndTime(System.currentTimeMillis());
            redisTemplate.opsForValue().set(taskKey, task);
        }
    }
}
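The retry rule allows up to 3 retries after the first failure; after that the partition is marked FAILED and an alert is raised. That gives each partition at most four attempts. A small simulation makes the attempt count explicit. It assumes each attempt's outcome is supplied as a predicate. `RetrySimulation` and its status strings are hypothetical stand-ins for `PartitionTask` and `PartitionStatus`, and the block uses a Java 16+ record.

```java
import java.util.function.IntPredicate;

// Hypothetical simulation of the handlePartitionError / retryPartition cycle for one partition
class RetrySimulation {
    static final int MAX_RETRIES = 3; // same limit as shouldRetry()

    record Outcome(String status, int attempts, int retryCount) {}

    // succeedsOnAttempt.test(n) returns true if attempt n (1-based) succeeds
    static Outcome run(IntPredicate succeedsOnAttempt) {
        int retryCount = 0;
        int attempt = 0;
        while (true) {
            attempt++;
            if (succeedsOnAttempt.test(attempt)) {
                return new Outcome("COMPLETED", attempt, retryCount);
            }
            if (retryCount < MAX_RETRIES) {
                retryCount++; // back to PENDING and re-queued
            } else {
                return new Outcome("FAILED", attempt, retryCount); // alert + mark task FAILED
            }
        }
    }
}
```

A partition that always fails ends as FAILED after 4 attempts with a retry count of 3. One that succeeds on its third attempt completes with a retry count of 2.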

3. Progress monitoring and statistics

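// Batch progress monitor
@Slf4j
@Service
public class BatchProgressMonitor {

    private final RedisTemplate<String, Object> redisTemplate;
    private final MeterRegistry meterRegistry;

    public BatchProgressMonitor(RedisTemplate<String, Object> redisTemplate, MeterRegistry meterRegistry) {
        this.redisTemplate = redisTemplate;
        this.meterRegistry = meterRegistry;
    }

    // Register Micrometer gauges: Gauge.builder(name, obj, valueFunction).register(registry)
    @PostConstruct
    public void registerMetrics() {
        Gauge.builder("batch.task.count", this, BatchProgressMonitor::getTotalTaskCount)
            .register(meterRegistry);

        Gauge.builder("batch.completed.task.count", this, BatchProgressMonitor::getCompletedTaskCount)
            .register(meterRegistry);

        Gauge.builder("batch.failed.task.count", this, BatchProgressMonitor::getFailedTaskCount)
            .register(meterRegistry);
    }

    // Get the progress of one task
    public TaskProgress getTaskProgress(String taskId) {
        String taskKey = "batch:task:" + taskId;
        BatchTask task = (BatchTask) redisTemplate.opsForValue().get(taskKey);

        if (task == null) {
            return null;
        }

        // Count partitions in each state
        List<PartitionTask> partitionTasks = getAllPartitionTasks(taskId);

        long completedCount = partitionTasks.stream()
            .filter(pt -> pt.getStatus() == PartitionStatus.COMPLETED)
            .count();

        long processingCount = partitionTasks.stream()
            .filter(pt -> pt.getStatus() == PartitionStatus.PROCESSING)
            .count();

        long pendingCount = partitionTasks.stream()
            .filter(pt -> pt.getStatus() == PartitionStatus.PENDING)
            .count();

        long failedCount = partitionTasks.stream()
            .filter(pt -> pt.getStatus() == PartitionStatus.FAILED)
            .count();

        // Progress percentage (guarded against a task with no partitions)
        double progress = task.getTotalPartitions() == 0 ? 0.0
            : (double) completedCount / task.getTotalPartitions() * 100;

        return TaskProgress.builder()
            .taskId(taskId)
            .taskName(task.getTaskName())
            .totalPartitions(task.getTotalPartitions())
            .completedPartitions(completedCount)
            .processingPartitions(processingCount)
            .pendingPartitions(pendingCount)
            .failedPartitions(failedCount)
            .progressPercentage(progress)
            .status(task.getStatus())
            .startTime(task.getCreatedTime())
            .build();
    }

    // Load all partition records of a task.
    // KEYS blocks Redis while it scans the whole keyspace; prefer SCAN or a per-task index set in production.
    private List<PartitionTask> getAllPartitionTasks(String taskId) {
        Set<String> keys = redisTemplate.keys("batch:partition:" + taskId + ":*");

        return keys.stream()
            .map(key -> (PartitionTask) redisTemplate.opsForValue().get(key))
            .filter(Objects::nonNull) // a key may disappear between KEYS and GET
            .collect(Collectors.toList());
    }

    // Gauge value functions
    private double getTotalTaskCount() { return countTasks(null); }
    private double getCompletedTaskCount() { return countTasks(TaskStatus.COMPLETED); }
    private double getFailedTaskCount() { return countTasks(TaskStatus.FAILED); }

    // Count tasks, optionally filtered by status (null = all tasks)
    private long countTasks(TaskStatus status) {
        Set<String> keys = redisTemplate.keys("batch:task:*");
        return keys.stream()
            .map(key -> (BatchTask) redisTemplate.opsForValue().get(key))
            .filter(Objects::nonNull)
            .filter(task -> status == null || task.getStatus() == status)
            .count();
    }

    // Periodic progress push (requires @EnableScheduling)
    @Scheduled(fixedRate = 5000) // every 5 seconds
    public void pushProgressUpdates() {
        Set<String> taskKeys = redisTemplate.keys("batch:task:*");

        taskKeys.forEach(taskKey -> {
            String taskId = taskKey.substring("batch:task:".length());
            TaskProgress progress = getTaskProgress(taskId);

            if (progress != null) {
                // Push over WebSocket or a message queue
                pushProgressToClients(progress);
            }
        });
    }

    private void pushProgressToClients(TaskProgress progress) {
        // Push to the front end over WebSocket or a message queue.
        // Simplified here; a real project needs a concrete push mechanism.
        // SLF4J placeholders are plain {}, so the percentage is formatted first
        log.info("Task progress update: {} - {}%", progress.getTaskName(),
            String.format("%.2f", progress.getProgressPercentage()));
    }
}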
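`getTaskProgress` walks the partition list four times, once per status. If that matters, a single pass with `Collectors.groupingBy` produces all counts at once. The sketch below assumes a local `Status` enum in place of `PartitionStatus`. It also guards the percentage against a task with zero partitions. `ProgressCounter` is a hypothetical name.

```java
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical single-pass version of the per-status counting in getTaskProgress
class ProgressCounter {
    enum Status { PENDING, PROCESSING, COMPLETED, FAILED }

    static Map<Status, Long> countByStatus(List<Status> statuses) {
        Map<Status, Long> counts = statuses.stream()
            .collect(Collectors.groupingBy(s -> s,
                () -> new EnumMap<Status, Long>(Status.class), Collectors.counting()));
        for (Status s : Status.values()) {
            counts.putIfAbsent(s, 0L); // report statuses with no partitions as 0
        }
        return counts;
    }

    static double percentComplete(Map<Status, Long> counts, int totalPartitions) {
        if (totalPartitions == 0) {
            return 0.0; // avoid 0/0 = NaN
        }
        return counts.get(Status.COMPLETED) * 100.0 / totalPartitions;
    }
}
```

With three of six partitions COMPLETED, `percentComplete` returns 50.0.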

Business Scenarios

1. Data migration

// User data migration service
@Slf4j
@Service
public class UserMigrationService {

    @Autowired
    private DistributedPartitionCoordinator coordinator;

    @Autowired
    private BatchProgressMonitor progressMonitor;

    // Start the user data migration
    public String startUserMigration(UserMigrationConfig config) {
        BatchTaskConfig taskConfig = BatchTaskConfig.builder()
            .taskName("User data migration")
            .dataSource(new User())
            .partitionCount(config.getWorkerCount())
            .partitionStrategy(config.getPartitionStrategy())
            .build();

        String taskId = coordinator.registerBatchTask(taskConfig);

        log.info("User data migration task started, task ID: {}", taskId);
        return taskId;
    }

    // Get the migration progress
    public MigrationProgress getMigrationProgress(String taskId) {
        TaskProgress taskProgress = progressMonitor.getTaskProgress(taskId);

        if (taskProgress == null) {
            return null;
        }

        return MigrationProgress.builder()
            .taskId(taskId)
            // Rough placeholder: assumes about 1000 records per partition. ID-range partitions
            // usually hold far more, so real code should sum per-partition row counts.
            .migratedCount(taskProgress.getCompletedPartitions() * 1000L)
            .totalCount(taskProgress.getTotalPartitions() * 1000L)
            .progressPercentage(taskProgress.getProgressPercentage())
            .status(convertStatus(taskProgress.getStatus()))
            .estimatedTimeRemaining(calculateEstimatedTime(taskProgress))
            .build();
    }

    private MigrationStatus convertStatus(TaskStatus status) {
        switch (status) {
            case PENDING: return MigrationStatus.NOT_STARTED;
            case PROCESSING: return MigrationStatus.IN_PROGRESS;
            case COMPLETED: return MigrationStatus.COMPLETED;
            case FAILED: return MigrationStatus.FAILED;
            default: return MigrationStatus.UNKNOWN;
        }
    }

    // Linear extrapolation from the partition completion rate so far (result in ms)
    private long calculateEstimatedTime(TaskProgress progress) {
        if (progress.getCompletedPartitions() == 0) {
            return -1; // no partition finished yet: cannot estimate
        }

        long elapsed = System.currentTimeMillis() - progress.getStartTime();
        double rate = (double) progress.getCompletedPartitions() / elapsed; // partitions per ms

        long remainingPartitions = progress.getPendingPartitions() + progress.getProcessingPartitions();
        return (long) (remainingPartitions / rate);
    }
}
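`calculateEstimatedTime` extrapolates linearly. With rate = completed / elapsed, the remaining time is remaining / rate, which simplifies to remaining × elapsed / completed. The sketch below uses that integer form; choosing it over the double-based version is an assumption made to avoid floating-point rounding. For example, 4 of 10 partitions done after 20 minutes (1,200,000 ms) with 6 left gives 6 × 1,200,000 / 4 = 1,800,000 ms, or 30 more minutes. `EtaMath` is a hypothetical name.

```java
// Hypothetical integer-arithmetic version of calculateEstimatedTime
class EtaMath {

    // Estimated remaining time in ms, or -1 when no partition has finished yet
    static long estimateRemainingMs(long completedPartitions, long elapsedMs, long remainingPartitions) {
        if (completedPartitions == 0) {
            return -1; // no throughput data yet
        }
        // remaining / (completed / elapsed) == remaining * elapsed / completed
        return Math.multiplyExact(remainingPartitions, elapsedMs) / completedPartitions;
    }
}
```

The estimate is only as good as the assumption that partitions cost roughly the same. Equal-width ID ranges over unevenly dense data will skew it.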

2. Configuration management

# application.yml
batch:
  partition:
    enabled: true
    default-partition-count: 10
    max-retry-count: 3
    batch-size: 1000
    queue-capacity: 10000
    
  monitoring:
    enabled: true
    progress-update-interval: 5000  # progress update interval in ms (5 seconds)
    metrics-export: true
    
  strategies:
    - name: id-range
      class: com.example.batch.partition.IdRangePartitionStrategy
      description: ID-range partition strategy
    - name: hash
      class: com.example.batch.partition.HashPartitionStrategy
      description: hash-based partition strategy
    - name: modulo
      class: com.example.batch.partition.ModuloPartitionStrategy
      description: modulo-based partition strategy

3. Monitoring dashboard

// Batch monitoring controller
@RestController
@RequestMapping("/api/batch")
public class BatchMonitorController {
    
    @Autowired
    private BatchProgressMonitor progressMonitor;
    
    @Autowired
    private UserMigrationService migrationService;
    
    // List all batch tasks
    @GetMapping("/tasks")
    public ResponseEntity<List<BatchTaskSummary>> getAllTasks() {
        // getAllTaskSummaries() is not shown in this article
        List<BatchTaskSummary> summaries = progressMonitor.getAllTaskSummaries();
        return ResponseEntity.ok(summaries);
    }
    
    // Get the details of one task
    @GetMapping("/tasks/{taskId}")
    public ResponseEntity<TaskDetail> getTaskDetail(@PathVariable String taskId) {
        TaskProgress progress = progressMonitor.getTaskProgress(taskId);
        if (progress == null) {
            return ResponseEntity.notFound().build();
        }
        
        TaskDetail detail = TaskDetail.builder()
            .taskId(taskId)
            .taskName(progress.getTaskName())
            .status(progress.getStatus())
            .progress(progress.getProgressPercentage())
            .totalPartitions(progress.getTotalPartitions())
            .completedPartitions(progress.getCompletedPartitions())
            .failedPartitions(progress.getFailedPartitions())
            .startTime(new Date(progress.getStartTime()))
            .build();
            
        return ResponseEntity.ok(detail);
    }
    
    // Start a data migration
    @PostMapping("/migration/start")
    public ResponseEntity<MigrationResponse> startMigration(@RequestBody MigrationRequest request) {
        try {
            String taskId = migrationService.startUserMigration(
                UserMigrationConfig.builder()
                    .workerCount(request.getWorkerCount())
                    .partitionStrategy(request.getStrategy())
                    .build());
            
            MigrationResponse response = MigrationResponse.builder()
                .taskId(taskId)
                .message("Data migration task started")
                .build();
                
            return ResponseEntity.ok(response);
        } catch (Exception e) {
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
                .body(MigrationResponse.builder()
                    .message("Failed to start migration task: " + e.getMessage())
                    .build());
        }
    }
    
    // Get migration progress
    @GetMapping("/migration/progress/{taskId}")
    public ResponseEntity<MigrationProgress> getMigrationProgress(@PathVariable String taskId) {
        MigrationProgress progress = migrationService.getMigrationProgress(taskId);
        if (progress == null) {
            return ResponseEntity.notFound().build();
        }
        return ResponseEntity.ok(progress);
    }
}

Best Practices

1. Choosing a partition strategy

@Component
public class SmartPartitionStrategySelector {

    private final List<PartitionStrategy<User>> strategies;

    public SmartPartitionStrategySelector(List<PartitionStrategy<User>> strategies) {
        this.strategies = strategies; // Spring injects every PartitionStrategy<User> bean
    }

    public PartitionStrategy<User> selectStrategy(User dataSource, int dataSize, int workerCount) {
        // Pick the strategy that best fits the shape of the data
        if (dataSize > 1000000 && hasSequentialId(dataSource)) {
            return strategies.stream()
                .filter(s -> s instanceof IdRangePartitionStrategy)
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("No suitable strategy found"));
        } else if (dataSize > 100000 && hasGoodHashDistribution(dataSource)) {
            return strategies.stream()
                .filter(s -> s instanceof HashPartitionStrategy)
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("No suitable strategy found"));
        } else {
            // Fall back to the modulo strategy (ModuloPartitionStrategy is listed in the configuration but not shown here)
            return strategies.stream()
                .filter(s -> s instanceof ModuloPartitionStrategy)
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("No suitable strategy found"));
        }
    }

    private boolean hasSequentialId(User dataSource) {
        // Check whether the IDs are (mostly) contiguous
        return true; // simplified
    }

    private boolean hasGoodHashDistribution(User dataSource) {
        // Check whether the IDs hash evenly across partitions
        return true; // simplified
    }
}
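`hasGoodHashDistribution` is stubbed to return true. One way to fill it in, offered as an assumption rather than the author's approach, is a skew check:

  • bucket a sample of IDs with the same `floorMod(hash, n)` rule the hash strategy uses
  • compare the fullest bucket with the average

A ratio near 1.0 means an even spread. `HashSkew` is a hypothetical helper.

```java
import java.util.List;

// Hypothetical skew check for the hash partition strategy
class HashSkew {

    // Ratio of the largest bucket to the mean bucket size (1.0 = perfectly even)
    static double skew(List<Long> sampleIds, int partitionCount) {
        long[] buckets = new long[partitionCount];
        for (long id : sampleIds) {
            buckets[Math.floorMod(Long.hashCode(id), partitionCount)]++;
        }
        long max = 0;
        for (long b : buckets) {
            max = Math.max(max, b);
        }
        double mean = (double) sampleIds.size() / partitionCount;
        return mean == 0 ? 1.0 : max / mean;
    }
}
```

With IDs 0-9999 and 10 partitions the skew is exactly 1.0. Suppose instead the sample holds only multiples of 10, as can happen when IDs are allocated with MySQL's `auto_increment_increment = 10`. Then every ID lands in bucket 0 and the skew is 10.0. Where to draw the "good enough" line (for instance 1.2) is a judgment call the article does not specify.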

2. Performance tuning

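@Configuration
public class BatchPerformanceConfig {

    // Tuned thread pool
    @Bean
    public TaskExecutor optimizedBatchTaskExecutor() {
        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
        executor.setCorePoolSize(20);
        executor.setMaxPoolSize(50);      // extra threads are only created once the queue is full
        executor.setQueueCapacity(1000);
        executor.setKeepAliveSeconds(300);
        executor.setThreadNamePrefix("batch-opt-");
        // When saturated, run the task on the submitting thread instead of dropping it (back-pressure)
        executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
        executor.setWaitForTasksToCompleteOnShutdown(true);
        executor.setAwaitTerminationSeconds(300);
        executor.initialize();
        return executor;
    }

    // Connection pool sized for batch work
    // (declaring a DataSource bean makes Spring Boot skip its auto-configured one)
    @Bean
    public DataSource batchDataSource() {
        HikariDataSource dataSource = new HikariDataSource();
        // useCursorFetch lets MySQL Connector/J honor fetchSize;
        // rewriteBatchedStatements turns JDBC batch inserts into multi-row INSERTs
        dataSource.setJdbcUrl("jdbc:mysql://localhost:3306/batch_db?useCursorFetch=true&rewriteBatchedStatements=true");
        dataSource.setUsername("username");  // placeholder: load credentials from configuration
        dataSource.setPassword("password");
        dataSource.setMaximumPoolSize(50);  // batch jobs need more connections (at least one per concurrent partition)
        dataSource.setMinimumIdle(10);
        dataSource.setConnectionTimeout(30000);
        dataSource.setIdleTimeout(600000);
        dataSource.setMaxLifetime(1800000);
        return dataSource;
    }

    // JdbcTemplate tuned for batch reads
    @Bean
    public JdbcTemplate batchJdbcTemplate(DataSource batchDataSource) {
        JdbcTemplate jdbcTemplate = new JdbcTemplate(batchDataSource);
        jdbcTemplate.setFetchSize(1000);  // rows per round trip (on MySQL requires useCursorFetch=true)
        return jdbcTemplate;
    }
}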
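`CallerRunsPolicy` is what provides back-pressure in the pool above. Once the pool is at `maxPoolSize` and the queue is full, the submitting thread runs the task itself instead of losing it, which naturally slows the producer. The sketch below reproduces that with a deliberately tiny JDK `ThreadPoolExecutor` (one thread, a queue of one); Spring's `ThreadPoolTaskExecutor` delegates to the same class. `CallerRunsDemo` is a hypothetical name.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical demo: with the pool saturated, CallerRunsPolicy runs the task on the submitting thread
class CallerRunsDemo {

    // Returns true if the task rejected by the saturated pool ran on the calling thread
    static boolean rejectedTaskRanOnCaller() {
        ThreadPoolExecutor pool = new ThreadPoolExecutor(
            1, 1, 0L, TimeUnit.MILLISECONDS,              // core = max = 1 thread
            new ArrayBlockingQueue<>(1),                   // room for a single queued task
            new ThreadPoolExecutor.CallerRunsPolicy());

        CountDownLatch release = new CountDownLatch(1);
        AtomicReference<Thread> ranOn = new AtomicReference<>();
        Thread caller = Thread.currentThread();

        pool.execute(() -> awaitQuietly(release));             // task 1 occupies the only worker
        pool.execute(() -> { });                               // task 2 fills the queue
        pool.execute(() -> ranOn.set(Thread.currentThread())); // task 3 is rejected and runs right here

        release.countDown();
        pool.shutdown();
        try {
            pool.awaitTermination(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return ranOn.get() == caller;
    }

    private static void awaitQuietly(CountDownLatch latch) {
        try {
            latch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}
```

The same ordering explains the pool configuration above. Threads beyond `corePoolSize` are created only after the queue is full, so with a queue of 1000 the pool may sit at 20 threads for a long time before growing toward 50.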

3. Failure recovery

@Slf4j
@Component
public class BatchRecoveryManager {

    // A partition still PROCESSING this long after it was claimed is treated as stuck
    private static final long PARTITION_TIMEOUT_MS = 3_600_000L; // 1 hour

    private final RedisTemplate<String, Object> redisTemplate;

    public BatchRecoveryManager(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    // Detect and recover stuck tasks
    @Scheduled(fixedRate = 60000) // check once a minute
    public void detectAndRecoverFailedTasks() {
        Set<String> taskKeys = redisTemplate.keys("batch:task:*");

        taskKeys.forEach(taskKey -> {
            String taskId = taskKey.substring("batch:task:".length());
            BatchTask task = (BatchTask) redisTemplate.opsForValue().get(taskKey);

            // Only unfinished tasks (PENDING or PROCESSING) are candidates
            if (task != null && (task.getStatus() == TaskStatus.PENDING
                    || task.getStatus() == TaskStatus.PROCESSING)) {
                recoverTask(taskId);
            }
        });
    }

    private void recoverTask(String taskId) {
        long now = System.currentTimeMillis();
        int requeued = 0;

        for (Integer partitionId : getIncompletePartitions(taskId)) {
            String partitionKey = "batch:partition:" + taskId + ":" + partitionId;
            PartitionTask partitionTask = (PartitionTask) redisTemplate.opsForValue().get(partitionKey);

            // The timeout is measured per partition from its claim time, not from task creation,
            // so healthy partitions of a long-running task are left alone
            if (partitionTask != null
                    && partitionTask.getStatus() == PartitionStatus.PROCESSING
                    && now - partitionTask.getStartTime() > PARTITION_TIMEOUT_MS) {
                // Put the stuck partition back to PENDING
                partitionTask.setStatus(PartitionStatus.PENDING);
                partitionTask.setWorkerId(null);
                redisTemplate.opsForValue().set(partitionKey, partitionTask);

                // Re-queue it
                redisTemplate.opsForList().leftPush("batch:queue:" + taskId, partitionId);
                requeued++;
            }
        }

        if (requeued > 0) {
            log.info("Task {} recovered: {} partition(s) re-queued", taskId, requeued);
        }
    }

    private List<Integer> getIncompletePartitions(String taskId) {
        Set<String> partitionKeys = redisTemplate.keys("batch:partition:" + taskId + ":*");

        return partitionKeys.stream()
            .map(key -> {
                PartitionTask task = (PartitionTask) redisTemplate.opsForValue().get(key);
                return task != null && task.getStatus() != PartitionStatus.COMPLETED ?
                    task.getPartitionId() : null;
            })
            .filter(Objects::nonNull)
            .collect(Collectors.toList());
    }
}
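The recovery rule boils down to a pure function. A partition is re-queued only if it is still PROCESSING and its own start time is older than the timeout. The sketch below isolates that decision so it can be tested without Redis. `StalePartitionFinder`, `PartitionState` and the local `Status` enum are hypothetical stand-ins for the article's types, and the block uses a Java 16+ record.

```java
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical, Redis-free version of the recovery decision
class StalePartitionFinder {

    enum Status { PENDING, PROCESSING, COMPLETED, FAILED }

    record PartitionState(int partitionId, Status status, long startTimeMs) {}

    // Partitions to put back on the queue: still PROCESSING and started more than timeoutMs ago
    static List<Integer> findStale(List<PartitionState> partitions, long nowMs, long timeoutMs) {
        return partitions.stream()
            .filter(p -> p.status() == Status.PROCESSING)
            .filter(p -> nowMs - p.startTimeMs() > timeoutMs)
            .map(PartitionState::partitionId)
            .collect(Collectors.toList());
    }
}
```

A partition that started exactly `timeoutMs` ago is not yet considered stale, because the comparison is strictly greater-than.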

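Expected Results

With this batch partitioning and distributed coordination scheme, we can achieve:

  • Performance: processing time drops from days to hours. The estimated speedup is 10-50x; the real figure is bounded by the number of workers and by how fast the source and target databases can serve reads and absorb writes
  • Resource utilization: multi-core CPUs and multiple machines are put to full use
  • Reliability: resumable processing and automatic failure recovery
  • Scalability: processing nodes can be added or removed on the fly
  • Operability: built-in monitoring and alerting

This approach turns batch processing from "single-machine serial" into "distributed parallel", making it an important tool for large-scale data processing.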




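Title: SpringBoot + Batch Partitioning + Distributed Coordination: Processing Tens of Millions of Records in Parallel Partitions Without a Single-Point Bottleneck
Author: jiangyi
URL: http://jiangyi.space/articles/2026/02/13/1770787208712.html
WeChat official account: 服务端技术精选 (Server-Side Tech Picks)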