跳到主要内容

原子类与 CAS 机制详解

🎯 学习目标

通过本章学习,你将掌握:

  • 理解为什么需要原子类及其解决的核心问题
  • 掌握 CAS 机制的工作原理和局限性
  • 熟练使用各种原子类及其典型应用场景
  • 了解 ABA 问题及其解决方案
  • 掌握原子类在面试中的常见考点

📚 目录

  1. 从问题出发:为什么需要原子类?
  2. CAS 机制深度解析
  3. 原子类家族完全指南
  4. 基础原子类型实战
  5. 引用类型原子类详解
  6. 解决 ABA 问题
  7. 高性能计数器:LongAdder
  8. 原子字段更新器
  9. 实战应用案例
  10. 面试必备考点

从问题出发:为什么需要原子类?

❌ 传统方式的问题

让我们先看一个简单的计数器例子:

public class UnsafeCounter {
private int count = 0;

public void increment() {
count++; // 看起来简单,实则不安全
}

public int getCount() {
return count;
}
}

问题分析count++ 操作实际上包含三个步骤:

  1. 读取 count 的当前值
  2. 将值加 1
  3. 将新值写回 count

在多线程环境下,这些步骤可能被交错执行,导致数据不一致:

// 线程A和线程B同时执行 increment()
// 初始值: count = 5

// 理想情况:
// 线程A: 读取5 -> 计算6 -> 写入6
// 线程B: 读取6 -> 计算7 -> 写入7
// 结果: count = 7 ✓

// 实际可能的情况:
// 线程A: 读取5
// 线程B: 读取5 (都在A写入前读取)
// 线程A: 计算6 -> 写入6
// 线程B: 计算6 -> 写入6
// 结果: count = 6 ✗ (丢失了一次递增)

✅ 解决方案对比

方案1:使用 synchronized

public class SynchronizedCounter {
private int count = 0;

public synchronized void increment() {
count++;
}

public synchronized int getCount() {
return count;
}
}

优点

  • 简单直观,容易理解
  • 保证操作的原子性

缺点

  • 线程阻塞涉及内核态切换,开销较大
  • 锁竞争激烈时性能下降明显
  • 可能导致死锁(虽然这个例子中不会)

方案2:使用 AtomicInteger

public class AtomicCounter {
private final AtomicInteger count = new AtomicInteger(0);

public void increment() {
count.incrementAndGet(); // 原子操作
}

public int getCount() {
return count.get();
}
}

优点

  • 无锁算法,避免线程阻塞
  • 在竞争不激烈时性能优秀
  • JVM 层面优化,支持硬件级别的原子操作

适用场景

  • 高频的简单操作(如计数器)
  • 对性能要求较高的场景
  • 单个变量的原子操作

CAS 机制深度解析

🔍 什么是 CAS?

CAS(Compare-And-Swap,比较并交换)是一种无锁算法的核心操作。它包含三个操作数:

  • 内存位置 V:要更新的变量的内存地址
  • 预期原值 A:期望的当前值
  • 新值 B:要更新的新值

操作逻辑:只有当内存位置 V 的值等于预期原值 A 时,才将该位置更新为新值 B,否则什么都不做。

🧠 CAS 工作原理图解

// CAS 的伪代码实现
boolean compareAndSwap(int* memoryAddress, int expectedValue, int newValue) {
if (*memoryAddress == expectedValue) {
*memoryAddress = newValue;
return true; // 更新成功
}
return false; // 更新失败
}

🔄 实际执行流程

AtomicInteger.incrementAndGet() 为例:

public final int incrementAndGet() {
return unsafe.getAndAddInt(this, valueOffset, 1) + 1;
}

// getAndAddInt 的实现逻辑:
public final int getAndAddInt(Object o, long offset, int delta) {
int v;
do {
v = getIntVolatile(o, offset); // 1. 读取当前值
} while (!compareAndSwapInt(o, offset, v, v + delta)); // 2. CAS 操作
return v; // 返回旧值
}

执行步骤详解

  1. 读取当前值v = getIntVolatile(o, offset)

    • 从内存中读取 AtomicInteger 对象的当前值
    • 使用 volatile 语义保证可见性
  2. 尝试更新compareAndSwapInt(o, offset, v, v + delta)

    • 比较内存中的值是否还是 v
    • 如果是,更新为 v + delta 并返回 true
    • 如果不是(其他线程已经修改过),返回 false
  3. 自旋重试while (!...)

    • 如果 CAS 失败,回到步骤1重新读取当前值
    • 重复这个过程直到成功为止

⚡ CAS 的优势

  1. 无阻塞:线程不会进入阻塞状态,避免内核态切换
  2. 高并发性能:在竞争不激烈时性能优异
  3. 避免死锁:没有锁的获取和释放过程

⚠️ CAS 的局限性

1. ABA 问题

// 初始状态:共享变量 value = A

// 时间线:
// T1: 读取 value = A,准备 CAS(A, B)
// T2: 修改 value = B
// T3: 修改 value = A (回到原始值)
// T1: 执行 CAS(A, B) - 成功!但这可能不是我们想要的

2. 自旋消耗 CPU

在高并发场景下,多个线程同时进行 CAS 操作,大部分都会失败,导致 CPU 空转。

3. 复合操作困难

CAS 只能保证单个操作的原子性,无法处理需要多个步骤的复合操作。


原子类家族完全指南

Java 的原子类主要位于 java.util.concurrent.atomic 包中,可以分为以下几大类:

📊 原子类分类表

分类主要类适用场景特点
基础类型AtomicInteger, AtomicLong, AtomicBoolean简单计数器、状态标志直接操作基本类型
引用类型AtomicReference<V>, AtomicStampedReference<V>, AtomicMarkableReference<V>对象引用更新、缓存支持泛型,解决 ABA
数组类型AtomicIntegerArray, AtomicLongArray, AtomicReferenceArray<T>数组元素的原子操作对数组中单个元素操作
字段更新器AtomicIntegerFieldUpdater<T>, AtomicLongFieldUpdater<T>, AtomicReferenceFieldUpdater<T,V>对象字段原子更新无需创建额外对象
累加器LongAdder, LongAccumulator, DoubleAdder, DoubleAccumulator高并发统计分段累加,高性能

🎯 选择指南


基础原子类型实战

AtomicInteger 详解

核心方法介绍

// 创建实例
AtomicInteger atomicInt = new AtomicInteger(0);
AtomicInteger atomicInt2 = new AtomicInteger(10); // 指定初始值

// 基础操作
int get() // 获取当前值
void set(int newValue) // 设置值(立即生效)
void lazySet(int newValue) // 延迟设置(可能稍后生效,性能更好)

// 原子增减操作
int incrementAndGet() // ++i
int getAndIncrement() // i++
int decrementAndGet() // --i
int getAndDecrement() // i--
int addAndGet(int delta) // value += delta
int getAndAdd(int delta) // old = value; value += delta; return old

// 条件更新
boolean compareAndSet(int expect, int update) // CAS操作
int getAndUpdate(IntUnaryOperator updateFunction)
int updateAndGet(IntUnaryOperator updateFunction)

// 统计操作(Java 8+)
int accumulateAndGet(int x, IntBinaryOperator accumulatorFunction)
int getAndAccumulate(int x, IntBinaryOperator accumulatorFunction)

实战示例:简单的限流器

public class RateLimiter {
private final AtomicInteger counter = new AtomicInteger(0);
private final int maxPermits;
private final long resetInterval;
private volatile long lastResetTime;

public RateLimiter(int maxPermits, long resetInterval) {
this.maxPermits = maxPermits;
this.resetInterval = resetInterval;
this.lastResetTime = System.currentTimeMillis();
}

public boolean tryAcquire() {
// 检查是否需要重置计数器
long now = System.currentTimeMillis();
if (now - lastResetTime > resetInterval) {
// 使用 CAS 重置计数器
if (compareAndResetTime(now)) {
counter.set(0);
}
}

// 尝试获取许可
int current = counter.get();
if (current >= maxPermits) {
return false;
}

// CAS 增加计数器
return counter.compareAndSet(current, current + 1);
}

private synchronized boolean compareAndResetTime(long newTime) {
if (System.currentTimeMillis() - lastResetTime > resetInterval) {
lastResetTime = newTime;
return true;
}
return false;
}
}

实战示例:分布式 ID 生成器

public class AtomicIdGenerator {
private final AtomicInteger counter = new AtomicInteger(0);
private final int maxId;

public AtomicIdGenerator(int maxId) {
this.maxId = maxId;
}

public int nextId() {
while (true) {
int current = counter.get();
int next = (current + 1) % maxId;

if (counter.compareAndSet(current, next)) {
return next;
}
// CAS 失败,自旋重试
}
}

// 批量获取 ID,减少 CAS 次数
public int[] nextBatch(int batchSize) {
int[] ids = new int[batchSize];
int start = counter.getAndAdd(batchSize);

for (int i = 0; i < batchSize; i++) {
ids[i] = (start + i) % maxId;
}
return ids;
}
}

AtomicBoolean 实战

使用场景:状态标志控制

public class ServiceManager {
private final AtomicBoolean isRunning = new AtomicBoolean(false);
private final AtomicBoolean isShuttingDown = new AtomicBoolean(false);

public void start() {
if (isRunning.compareAndSet(false, true)) {
System.out.println("Service started successfully");
// 启动服务的逻辑
} else {
System.out.println("Service is already running");
}
}

public void shutdown() {
if (isShuttingDown.compareAndSet(false, true)) {
System.out.println("Initiating graceful shutdown...");
// 开始关闭流程
try {
// 关闭资源的逻辑
Thread.sleep(1000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} finally {
isRunning.set(false);
System.out.println("Service shutdown completed");
}
}
}

public boolean isServiceRunning() {
return isRunning.get() && !isShuttingDown.get();
}
}

引用类型原子类详解

AtomicReference 基础使用

核心方法

// 创建实例
AtomicReference<User> userRef = new AtomicReference<>();
AtomicReference<User> userRef2 = new AtomicReference<>(new User("Alice"));

// 基础操作
V get() // 获取当前引用
void set(V newValue) // 设置引用
void lazySet(V newValue) // 延迟设置
boolean compareAndSet(V expect, V update) // CAS操作

// 函数式更新(Java 8+)
V getAndUpdate(UnaryOperator<V> updateFunction)
V updateAndGet(UnaryOperator<V> updateFunction)

实战示例:配置热更新

public class ConfigurationManager {
private final AtomicReference<Config> configRef = new AtomicReference<>();

// 配置类
public static class Config {
private final String databaseUrl;
private final int maxConnections;
private final long timeout;

public Config(String databaseUrl, int maxConnections, long timeout) {
this.databaseUrl = databaseUrl;
this.maxConnections = maxConnections;
this.timeout = timeout;
}

// getters...
public String getDatabaseUrl() { return databaseUrl; }
public int getMaxConnections() { return maxConnections; }
public long getTimeout() { return timeout; }
}

// 更新配置
public void updateConfig(String newUrl, int newMaxConnections, long newTimeout) {
Config newConfig = new Config(newUrl, newMaxConnections, newTimeout);
configRef.set(newConfig);
System.out.println("Configuration updated");
}

// 原子性更新配置
public boolean atomicUpdateConfig(String oldUrl, String newUrl,
int newMaxConnections, long newTimeout) {
Config oldConfig = configRef.get();
if (oldConfig != null && oldConfig.getDatabaseUrl().equals(oldUrl)) {
Config newConfig = new Config(newUrl, newMaxConnections, newTimeout);
return configRef.compareAndSet(oldConfig, newConfig);
}
return false;
}

// 使用函数式更新(Java 8+)
public void updateTimeout(long newTimeout) {
configRef.updateAndGet(config ->
new Config(config.getDatabaseUrl(), config.getMaxConnections(), newTimeout));
}

// 获取当前配置
public Config getCurrentConfig() {
return configRef.get();
}
}

实战示例:无锁栈实现

public class LockFreeStack<T> {
private final AtomicReference<Node<T>> head = new AtomicReference<>();

private static class Node<T> {
final T data;
final AtomicReference<Node<T>> next;

Node(T data) {
this.data = data;
this.next = new AtomicReference<>(null);
}
}

public void push(T item) {
Node<T> newNode = new Node<>(item);
Node<T> currentHead;

do {
currentHead = head.get();
newNode.next.set(currentHead);
} while (!head.compareAndSet(currentHead, newNode));
}

public T pop() {
Node<T> currentHead;
Node<T> newHead;

do {
currentHead = head.get();
if (currentHead == null) {
return null; // 栈为空
}
newHead = currentHead.next.get();
} while (!head.compareAndSet(currentHead, newHead));

return currentHead.data;
}

public boolean isEmpty() {
return head.get() == null;
}
}

解决 ABA 问题

🔍 ABA 问题详解

ABA 问题是 CAS 操作的一个经典问题。让我们通过一个银行账户的例子来理解:

// 银行账户类
class BankAccount {
private int balance = 1000;

public void withdraw(int amount) {
// 模拟取款过程中的延迟
try {
Thread.sleep(100);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}

if (balance >= amount) {
balance -= amount;
}
}

public void deposit(int amount) {
balance += amount;
}

public int getBalance() {
return balance;
}
}

// ABA 问题演示
public class ABADemo {
public static void main(String[] args) throws InterruptedException {
BankAccount account = new BankAccount();
AtomicReference<BankAccount> accountRef = new AtomicReference<>(account);

// 线程1:准备转账,但中途被其他线程操作
Thread thread1 = new Thread(() -> {
BankAccount currentAccount = accountRef.get();
// 模拟长时间操作
try {
Thread.sleep(200);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}

// 此时 CAS 会成功,但这可能不是我们想要的结果
boolean success = accountRef.compareAndSet(currentAccount, currentAccount);
System.out.println("Thread1 CAS result: " + success);
});

// 线程2:快速操作账户
Thread thread2 = new Thread(() -> {
BankAccount currentAccount = accountRef.get();
currentAccount.withdraw(200); // 余额: 800
currentAccount.deposit(200); // 余额: 1000 (回到原始值)
System.out.println("Thread2 operations completed");
});

thread1.start();
thread2.start();

thread1.join();
thread2.join();
}
}

✅ 解决方案:AtomicStampedReference

AtomicStampedReference 通过引入版本号(stamp)来彻底解决 ABA 问题:

public class StampedReferenceExample {
public static void main(String[] args) {
// 初始值:100,版本号:1
AtomicStampedReference<Integer> stampedRef =
new AtomicStampedReference<>(100, 1);

// 获取当前值和版本号
int[] stampHolder = new int[1];
Integer currentValue = stampedRef.get(stampHolder);
int currentStamp = stampHolder[0];

System.out.println("Initial: value=" + currentValue + ", stamp=" + currentStamp);

// 尝试更新值和版本号
boolean success1 = stampedRef.compareAndSet(
currentValue, 200, currentStamp, currentStamp + 1
);
System.out.println("First update: " + success1);

// 获取新的值和版本号
currentValue = stampedRef.get(stampHolder);
currentStamp = stampHolder[0];
System.out.println("After first: value=" + currentValue + ", stamp=" + currentStamp);

// 尝试回到原始值,但版本号不同
boolean success2 = stampedRef.compareAndSet(
currentValue, 100, currentStamp, currentStamp + 1
);
System.out.println("Second update: " + success2);

// 获取最终状态
currentValue = stampedRef.get(stampHolder);
currentStamp = stampHolder[0];
System.out.println("Final: value=" + currentValue + ", stamp=" + currentStamp);

// 此时用旧的版本号尝试更新,会失败
boolean shouldFail = stampedRef.compareAndSet(
100, 300, 1, 2 // 使用过期的版本号
);
System.out.println("Should fail: " + shouldFail);
}
}

实战示例:银行转账系统

public class SafeBankTransfer {
private static class Account {
private final AtomicStampedReference<Integer> balance;

public Account(int initialBalance) {
this.balance = new AtomicStampedReference<>(initialBalance, 1);
}

public boolean transfer(Account to, int amount) {
while (true) {
// 获取源账户当前状态
int[] fromStamp = new int[1];
int fromBalance = this.balance.get(fromStamp);

// 获取目标账户当前状态
int[] toStamp = new int[1];
int toBalance = to.balance.get(toStamp);

// 检查源账户余额是否充足
if (fromBalance < amount) {
return false;
}

// 尝试原子性转账
boolean fromUpdated = this.balance.compareAndSet(
fromBalance, fromBalance - amount,
fromStamp[0], fromStamp[0] + 1
);

if (fromUpdated) {
// 源账户扣款成功,现在更新目标账户
boolean toUpdated = to.balance.compareAndSet(
toBalance, toBalance + amount,
toStamp[0], toStamp[0] + 1
);

if (toUpdated) {
// 转账完全成功
return true;
} else {
// 目标账户更新失败,回滚源账户
this.balance.compareAndSet(
fromBalance - amount, fromBalance,
fromStamp[0] + 1, fromStamp[0] + 2
);
}
}

// 如果到这里,说明有竞争,重试
Thread.yield();
}
}

public int getBalance() {
return balance.getReference();
}
}

public static void main(String[] args) throws InterruptedException {
Account account1 = new Account(1000);
Account account2 = new Account(500);

// 模拟多线程转账
Thread[] threads = new Thread[10];
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(() -> {
for (int j = 0; j < 100; j++) {
if (Math.random() > 0.5) {
account1.transfer(account2, 10);
} else {
account2.transfer(account1, 10);
}
}
});
threads[i].start();
}

// 等待所有转账完成
for (Thread thread : threads) {
thread.join();
}

System.out.println("Account1 balance: " + account1.getBalance());
System.out.println("Account2 balance: " + account2.getBalance());
System.out.println("Total balance: " + (account1.getBalance() + account2.getBalance()));
}
}

AtomicMarkableReference

有时我们不需要精确的版本号,只需要一个布尔标记:

public class MarkableReferenceExample {
public static void main(String[] args) {
// 使用标记来表示元素是否被删除
AtomicMarkableReference<String> markableRef =
new AtomicMarkableReference<>("Hello", false);

// 获取值和标记
boolean[] markHolder = new boolean[1];
String value = markableRef.get(markHolder);
System.out.println("Value: " + value + ", Marked: " + markHolder[0]);

// 标记为删除
boolean marked = markableRef.attemptMark(value, true);
System.out.println("Mark successful: " + marked);

// 再次获取,看看标记
value = markableRef.get(markHolder);
System.out.println("Value: " + value + ", Marked: " + markHolder[0]);

// 尝试更新,但标记需要匹配
boolean updated = markableRef.compareAndSet(
value, "World", true, false // 需要标记为true才能更新
);
System.out.println("Update successful: " + updated);
}
}

高性能计数器:LongAdder

🚀 为什么需要 LongAdder?

在高并发场景下,多个线程同时更新一个 AtomicLong 会导致大量的 CAS 竞争:

// AtomicLong 在高并发下的性能问题
public class AtomicLongPerformanceTest {
public static void main(String[] args) throws InterruptedException {
AtomicLong atomicLong = new AtomicLong(0);
int threadCount = 10;
int operationsPerThread = 1000000;

long startTime = System.nanoTime();

Thread[] threads = new Thread[threadCount];
for (int i = 0; i < threadCount; i++) {
threads[i] = new Thread(() -> {
for (int j = 0; j < operationsPerThread; j++) {
atomicLong.incrementAndGet();
}
});
threads[i].start();
}

for (Thread thread : threads) {
thread.join();
}

long endTime = System.nanoTime();
System.out.println("AtomicLong result: " + atomicLong.get());
System.out.println("Time taken: " + (endTime - startTime) / 1_000_000 + " ms");
}
}

🔧 LongAdder 工作原理

LongAdder 采用分段累加的策略来减少竞争:

📊 性能对比测试

public class CounterComparison {
private static final int THREAD_COUNT = 20;
private static final int OPERATIONS_PER_THREAD = 1_000_000;

public static void testAtomicLong() throws InterruptedException {
AtomicLong counter = new AtomicLong(0);

long startTime = System.nanoTime();

Thread[] threads = new Thread[THREAD_COUNT];
for (int i = 0; i < THREAD_COUNT; i++) {
threads[i] = new Thread(() -> {
for (int j = 0; j < OPERATIONS_PER_THREAD; j++) {
counter.incrementAndGet();
}
});
threads[i].start();
}

for (Thread thread : threads) {
thread.join();
}

long endTime = System.nanoTime();
System.out.println("AtomicLong:");
System.out.println(" Result: " + counter.get());
System.out.println(" Time: " + (endTime - startTime) / 1_000_000 + " ms");
}

public static void testLongAdder() throws InterruptedException {
LongAdder counter = new LongAdder();

long startTime = System.nanoTime();

Thread[] threads = new Thread[THREAD_COUNT];
for (int i = 0; i < THREAD_COUNT; i++) {
threads[i] = new Thread(() -> {
for (int j = 0; j < OPERATIONS_PER_THREAD; j++) {
counter.increment();
}
});
threads[i].start();
}

for (Thread thread : threads) {
thread.join();
}

long endTime = System.nanoTime();
System.out.println("LongAdder:");
System.out.println(" Result: " + counter.sum());
System.out.println(" Time: " + (endTime - startTime) / 1_000_000 + " ms");
}

public static void main(String[] args) throws InterruptedException {
testAtomicLong();
testLongAdder();
}
}

🎯 LongAdder 使用场景

适用场景:

  • 写多读少:如 QPS 统计、访问计数
  • 高并发统计:需要高性能的累加操作
  • 可以容忍最终一致性:读取时需要遍历所有 cells

不适用场景:

  • 需要精确实时值sum() 操作成本较高
  • 读多写少AtomicLong 更合适
  • 需要精确控制:如限流器的精确控制

🛠️ LongAdder 实战示例

高性能 QPS 统计

public class QpsMonitor {
private final LongAdder requestCounter = new LongAdder();
private final LongAdder errorCounter = new LongAdder();
private final LongAdder responseTimeSum = new LongAdder();

private volatile long lastResetTime = System.currentTimeMillis();
private volatile long lastResetSecond = System.currentTimeMillis() / 1000;

public void recordRequest(long responseTimeMs, boolean isError) {
requestCounter.increment();

if (isError) {
errorCounter.increment();
}

responseTimeSum.add(responseTimeMs);
}

public QpsMetrics getCurrentSecondMetrics() {
long currentSecond = System.currentTimeMillis() / 1000;

// 如果进入新的一秒,重置计数器
if (currentSecond > lastResetSecond) {
synchronized (this) {
if (currentSecond > lastResetSecond) {
resetCounters();
lastResetSecond = currentSecond;
}
}
}

long requests = requestCounter.sum();
long errors = errorCounter.sum();
long responseTimeTotal = responseTimeSum.sum();

double qps = requests; // 每秒请求数
double errorRate = requests > 0 ? (double) errors / requests : 0;
double avgResponseTime = requests > 0 ? (double) responseTimeTotal / requests : 0;

return new QpsMetrics(qps, errorRate, avgResponseTime);
}

private void resetCounters() {
// LongAdder 没有直接的 reset 方法,需要重新创建
LongAdder newRequestCounter = new LongAdder();
LongAdder newErrorCounter = new LongAdder();
LongAdder newResponseTimeSum = new LongAdder();

requestCounter = newRequestCounter;
errorCounter = newErrorCounter;
responseTimeSum = newResponseTimeSum;
}

public static class QpsMetrics {
private final double qps;
private final double errorRate;
private final double avgResponseTime;

public QpsMetrics(double qps, double errorRate, double avgResponseTime) {
this.qps = qps;
this.errorRate = errorRate;
this.avgResponseTime = avgResponseTime;
}

@Override
public String toString() {
return String.format("QPS: %.2f, Error Rate: %.2f%%, Avg Response Time: %.2fms",
qps, errorRate * 100, avgResponseTime);
}
}
}

分片计数器

public class ShardedCounter {
private final LongAdder[] counters;
private final int shardCount;

public ShardedCounter(int shardCount) {
this.shardCount = shardCount;
this.counters = new LongAdder[shardCount];
for (int i = 0; i < shardCount; i++) {
counters[i] = new LongAdder();
}
}

public void increment(Object key) {
int shard = Math.abs(key.hashCode()) % shardCount;
counters[shard].increment();
}

public void increment() {
// 随机选择一个分片
int shard = ThreadLocalRandom.current().nextInt(shardCount);
counters[shard].increment();
}

public long sum() {
long total = 0;
for (LongAdder counter : counters) {
total += counter.sum();
}
return total;
}

public Map<Integer, Long> getShardStats() {
Map<Integer, Long> stats = new HashMap<>();
for (int i = 0; i < shardCount; i++) {
stats.put(i, counters[i].sum());
}
return stats;
}
}

LongAccumulator:自定义累加器

LongAccumulator 提供了更灵活的累加操作:

public class AccumulatorExample {
public static void main(String[] args) throws InterruptedException {
// 最大值累加器
LongAccumulator maxAccumulator = new LongAccumulator(Math::max, Long.MIN_VALUE);

// 最小值累加器
LongAccumulator minAccumulator = new LongAccumulator(Math::min, Long.MAX_VALUE);

// 自定义累加器
LongAccumulator customAccumulator = new LongAccumulator((x, y) -> {
// 计算 x^2 + y
return x * x + y;
}, 0);

Thread[] threads = new Thread[10];
for (int i = 0; i < threads.length; i++) {
final int value = i;
threads[i] = new Thread(() -> {
for (int j = 0; j < 1000; j++) {
maxAccumulator.accumulate(value + j);
minAccumulator.accumulate(value + j);
customAccumulator.accumulate(value + j);
}
});
threads[i].start();
}

for (Thread thread : threads) {
thread.join();
}

System.out.println("Max: " + maxAccumulator.get());
System.out.println("Min: " + minAccumulator.get());
System.out.println("Custom: " + customAccumulator.get());
}
}

原子字段更新器

🔧 为什么要使用字段更新器?

原子字段更新器允许我们在不创建额外 Atomic* 对象的情况下,对普通类的字段进行原子操作。这在内存敏感或需要避免对象创建开销的场景中很有用。

📋 字段更新器类型

  • AtomicIntegerFieldUpdater<T> - 更新 int 字段
  • AtomicLongFieldUpdater<T> - 更新 long 字段
  • AtomicReferenceFieldUpdater<T,V> - 更新引用类型字段

🎯 使用要求

  1. 字段必须是 volatile:确保可见性
  2. 字段不能是 final:需要能够修改
  3. 字段访问权限:对于非 public 字段,更新器和目标类必须在同一个包中

💡 实战示例:任务状态管理

public class TaskManager {

// 任务类
public static class Task {
// 使用 volatile 确保可见性
private volatile int state;
private volatile long progress;
private volatile String errorMessage;

// 状态常量
public static final int STATE_PENDING = 0;
public static final int STATE_RUNNING = 1;
public static final int STATE_COMPLETED = 2;
public static final int STATE_FAILED = 3;

// 创建字段更新器
private static final AtomicIntegerFieldUpdater<Task> STATE_UPDATER =
AtomicIntegerFieldUpdater.newUpdater(Task.class, "state");

private static final AtomicLongFieldUpdater<Task> PROGRESS_UPDATER =
AtomicLongFieldUpdater.newUpdater(Task.class, "progress");

private static final AtomicReferenceFieldUpdater<Task, String> ERROR_UPDATER =
AtomicReferenceFieldUpdater.newUpdater(Task.class, "errorMessage", String.class);

// 构造函数
public Task() {
this.state = STATE_PENDING;
this.progress = 0;
this.errorMessage = null;
}

// 尝试开始任务
public boolean tryStart() {
return STATE_UPDATER.compareAndSet(this, STATE_PENDING, STATE_RUNNING);
}

// 更新进度
public void updateProgress(long newProgress) {
PROGRESS_UPDATER.set(this, newProgress);
}

// 完成任务
public boolean complete() {
return STATE_UPDATER.compareAndSet(this, STATE_RUNNING, STATE_COMPLETED);
}

// 标记失败
public boolean fail(String error) {
boolean updated = STATE_UPDATER.compareAndSet(this, STATE_RUNNING, STATE_FAILED);
if (updated) {
ERROR_UPDATER.set(this, error);
}
return updated;
}

// 获取状态信息
public int getState() { return state; }
public long getProgress() { return progress; }
public String getErrorMessage() { return errorMessage; }

@Override
public String toString() {
return String.format("Task[state=%d, progress=%d, error=%s]",
state, progress, errorMessage);
}
}

// 任务管理器
public static class TaskProcessor {
private final List<Task> tasks = new ArrayList<>();
private final ExecutorService executor = Executors.newFixedThreadPool(5);

public void addTask(Task task) {
tasks.add(task);
}

public void processAll() {
for (Task task : tasks) {
executor.submit(() -> {
if (task.tryStart()) {
try {
// 模拟任务执行
for (int i = 0; i <= 100; i += 10) {
task.updateProgress(i);
Thread.sleep(50);
}
task.complete();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
task.fail("Task interrupted");
} catch (Exception e) {
task.fail("Task failed: " + e.getMessage());
}
}
});
}
}

public void shutdown() {
executor.shutdown();
}
}

public static void main(String[] args) throws InterruptedException {
TaskProcessor processor = new TaskProcessor();

// 创建一些任务
for (int i = 0; i < 5; i++) {
processor.addTask(new Task());
}

// 处理任务
processor.processAll();

// 等待处理完成
Thread.sleep(3000);

// 打印任务状态
for (int i = 0; i < 5; i++) {
Task task = processor.getTasks().get(i);
System.out.println("Task " + i + ": " + task);
}

processor.shutdown();
}
}

💡 实战示例:高性能缓存实现

public class HighPerformanceCache<K, V> {

private static class CacheEntry<K, V> {
volatile K key;
volatile V value;
volatile long lastAccessTime;
volatile boolean isValid;

public CacheEntry(K key, V value) {
this.key = key;
this.value = value;
this.lastAccessTime = System.currentTimeMillis();
this.isValid = true;
}

// 字段更新器
private static final AtomicReferenceFieldUpdater<CacheEntry, Object> VALUE_UPDATER =
AtomicReferenceFieldUpdater.newUpdater(CacheEntry.class, "value", Object.class);

private static final AtomicLongFieldUpdater<CacheEntry> ACCESS_TIME_UPDATER =
AtomicLongFieldUpdater.newUpdater(CacheEntry.class, "lastAccessTime");

private static final AtomicIntegerFieldUpdater<CacheEntry> VALID_UPDATER =
AtomicIntegerFieldUpdater.newUpdater(CacheEntry.class, "isValid");

@SuppressWarnings("unchecked")
boolean updateValue(V newValue) {
VALUE_UPDATER.compareAndSet(this, this.value, newValue);
ACCESS_TIME_UPDATER.set(this, System.currentTimeMillis());
return true;
}

void updateAccessTime() {
ACCESS_TIME_UPDATER.set(this, System.currentTimeMillis());
}

boolean invalidate() {
return VALID_UPDATER.compareAndSet(this, 1, 0);
}

boolean isValid() {
return isValid;
}

@SuppressWarnings("unchecked")
V getValue() {
updateAccessTime();
return (V) value;
}
}

private final ConcurrentHashMap<K, CacheEntry<K, V>> cache = new ConcurrentHashMap<>();
private final AtomicLong hitCount = new AtomicLong(0);
private final AtomicLong missCount = new AtomicLong(0);
private final long maxAge; // 最大存活时间(毫秒)

public HighPerformanceCache(long maxAge) {
this.maxAge = maxAge;
}

public V get(K key) {
CacheEntry<K, V> entry = cache.get(key);

if (entry == null || !entry.isValid()) {
missCount.incrementAndGet();
return null;
}

// 检查是否过期
long currentTime = System.currentTimeMillis();
if (currentTime - entry.lastAccessTime > maxAge) {
if (entry.invalidate()) {
cache.remove(key);
missCount.incrementAndGet();
return null;
}
}

hitCount.incrementAndGet();
return entry.getValue();
}

public void put(K key, V value) {
CacheEntry<K, V> newEntry = new CacheEntry<>(key, value);
CacheEntry<K, V> existingEntry = cache.putIfAbsent(key, newEntry);

if (existingEntry != null) {
// 更新现有条目
existingEntry.updateValue(value);
}
}

public boolean remove(K key) {
CacheEntry<K, V> entry = cache.remove(key);
return entry != null && entry.isValid();
}

public int size() {
return cache.size();
}

public double getHitRate() {
long hits = hitCount.get();
long misses = missCount.get();
long total = hits + misses;
return total == 0 ? 0.0 : (double) hits / total;
}

public void cleanup() {
long currentTime = System.currentTimeMillis();

cache.entrySet().removeIf(entry -> {
CacheEntry<K, V> cacheEntry = entry.getValue();
return !cacheEntry.isValid() ||
(currentTime - cacheEntry.lastAccessTime > maxAge);
});
}
}

实战应用案例

🎯 案例1:分布式 ID 生成器

public class DistributedIdGenerator {
private final AtomicInteger counter;
private final String nodeId;
private final long timestamp;

public DistributedIdGenerator(String nodeId, int startValue) {
this.nodeId = nodeId;
this.timestamp = System.currentTimeMillis();
this.counter = new AtomicInteger(startValue);
}

public String generateId() {
// 格式: nodeId-timestamp-counter
int currentCounter = counter.incrementAndGet();
return String.format("%s-%d-%d", nodeId, timestamp, currentCounter);
}

// 生成批次 ID,减少竞争
public List<String> generateBatch(int batchSize) {
List<String> ids = new ArrayList<>(batchSize);
int startCounter = counter.getAndAdd(batchSize);

for (int i = 0; i < batchSize; i++) {
ids.add(String.format("%s-%d-%d", nodeId, timestamp, startCounter + i));
}

return ids;
}

public static void main(String[] args) throws InterruptedException {
DistributedIdGenerator idGenerator = new DistributedIdGenerator("node1", 1000);

// 多线程测试
Thread[] threads = new Thread[10];
Set<String> generatedIds = ConcurrentHashMap.newKeySet();

for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(() -> {
for (int j = 0; j < 100; j++) {
String id = idGenerator.generateId();
generatedIds.add(id);
}
});
threads[i].start();
}

for (Thread thread : threads) {
thread.join();
}

System.out.println("Generated " + generatedIds.size() + " unique IDs");

// 批量生成示例
List<String> batch = idGenerator.generateBatch(10);
System.out.println("Batch IDs: " + batch);
}
}

🎯 案例2:高性能限流器

public class HighPerformanceRateLimiter {
private final LongAdder requestCount = new LongAdder();
private final AtomicLong lastResetTime = new AtomicLong(System.currentTimeMillis());
private final int maxRequests;
private final long windowSizeMs;

public HighPerformanceRateLimiter(int maxRequests, long windowSizeMs) {
this.maxRequests = maxRequests;
this.windowSizeMs = windowSizeMs;
}

public boolean tryAcquire() {
// 检查是否需要重置窗口
long currentTime = System.currentTimeMillis();
long lastReset = lastResetTime.get();

if (currentTime - lastReset >= windowSizeMs) {
// 尝试更新重置时间
if (lastResetTime.compareAndSet(lastReset, currentTime)) {
// 成功重置窗口
requestCount.reset();
}
}

// 检查当前请求数
long current = requestCount.sum();
if (current >= maxRequests) {
return false;
}

// 增加计数
requestCount.increment();
return true;
}

public long getCurrentCount() {
return requestCount.sum();
}

public double getRemainingPercentage() {
long current = requestCount.sum();
return (double) (maxRequests - current) / maxRequests * 100;
}

// 测试
public static void main(String[] args) throws InterruptedException {
HighPerformanceRateLimiter limiter = new HighPerformanceRateLimiter(100, 1000); // 100 requests per second

Thread[] threads = new Thread[20];
AtomicInteger passedRequests = new AtomicInteger(0);
AtomicInteger blockedRequests = new AtomicInteger(0);

for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(() -> {
for (int j = 0; j < 10; j++) {
if (limiter.tryAcquire()) {
passedRequests.incrementAndGet();
} else {
blockedRequests.incrementAndGet();
}

try {
Thread.sleep(10);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
break;
}
}
});
threads[i].start();
}

for (Thread thread : threads) {
thread.join();
}

System.out.println("Passed requests: " + passedRequests.get());
System.out.println("Blocked requests: " + blockedRequests.get());
System.out.println("Current count: " + limiter.getCurrentCount());
System.out.println("Remaining: " + limiter.getRemainingPercentage() + "%");
}
}

🎯 案例3:无锁任务队列

public class LockFreeTaskQueue<E> {
private static class Node<E> {
final E item;
final AtomicStampedReference<Node<E>> next;

Node(E item) {
this.item = item;
this.next = new AtomicStampedReference<>(null, 1);
}

Node() {
this.item = null;
this.next = new AtomicStampedReference<>(null, 1);
}
}

private final AtomicStampedReference<Node<E>> head;
private final AtomicStampedReference<Node<E>> tail;

public LockFreeTaskQueue() {
Node<E> dummy = new Node<>();
head = new AtomicStampedReference<>(dummy, 1);
tail = new AtomicStampedReference<>(dummy, 1);
}

public boolean offer(E item) {
if (item == null) {
throw new NullPointerException();
}

Node<E> newNode = new Node<>(item);

while (true) {
int[] tailStamp = new int[1];
Node<E> currentTail = tail.get(tailStamp);
int[] nextStamp = new int[1];
Node<E> tailNext = currentTail.next.get(nextStamp);

if (tailNext == null) {
// 尾节点的 next 为空,可以插入新节点
int[] newNextStamp = new int[]{nextStamp[0] + 1};
if (currentTail.next.compareAndSet(tailNext, newNode,
nextStamp[0], newNextStamp[0])) {
// 成功插入,更新尾节点
int[] newTailStamp = new int[]{tailStamp[0] + 1};
tail.compareAndSet(currentTail, newNode, tailStamp[0], newTailStamp[0]);
return true;
}
} else {
// 尾节点的 next 不为空,说明有其他线程已经在插入
// 帮助推进尾节点
int[] newTailStamp = new int[]{tailStamp[0] + 1};
tail.compareAndSet(currentTail, tailNext, tailStamp[0], newTailStamp[0]);
}
}
}

public E poll() {
while (true) {
int[] headStamp = new int[1];
Node<E> currentHead = head.get(headStamp);
int[] tailStamp = new int[1];
Node<E> currentTail = tail.get(tailStamp);
int[] nextStamp = new int[1];
Node<E> headNext = currentHead.next.get(nextStamp);

if (currentHead == currentTail) {
if (headNext == null) {
// 队列为空
return null;
} else {
// 帮助推进尾节点
int[] newTailStamp = new int[]{tailStamp[0] + 1};
tail.compareAndSet(currentTail, headNext, tailStamp[0], newTailStamp[0]);
}
} else {
E item = headNext.item;
int[] newHeadStamp = new int[]{headStamp[0] + 1};
if (head.compareAndSet(currentHead, headNext, headStamp[0], newHeadStamp[0])) {
return item;
}
}
}
}

public boolean isEmpty() {
int[] headStamp = new int[1];
int[] tailStamp = new int[1];
Node<E> currentHead = head.get(headStamp);
Node<E> currentTail = tail.get(tailStamp);

return currentHead == currentTail && currentHead.next.getReference() == null;
}

public static void main(String[] args) throws InterruptedException {
LockFreeTaskQueue<Integer> queue = new LockFreeTaskQueue<>();

// 生产者线程
Thread producer = new Thread(() -> {
for (int i = 0; i < 1000; i++) {
queue.offer(i);
try {
Thread.sleep(1);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
break;
}
}
});

// 消费者线程
Thread consumer = new Thread(() -> {
int processed = 0;
while (processed < 1000) {
Integer item = queue.poll();
if (item != null) {
processed++;
if (processed % 100 == 0) {
System.out.println("Processed: " + processed);
}
} else {
Thread.yield();
}
}
System.out.println("Consumer finished processing all items");
});

producer.start();
consumer.start();

producer.join();
consumer.join();

System.out.println("Queue is empty: " + queue.isEmpty());
}
}

面试必备考点

🎯 高频面试题详解

1. 什么是 CAS?它有什么优缺点?

答案要点

  • 定义:CAS (Compare-And-Swap) 是一种原子操作,包含三个操作数:内存位置 V、预期值 A、新值 B
  • 操作:当且仅当 V 的值等于 A 时,才将 V 的值更新为 B
  • 优点:无锁、避免线程阻塞、高并发性能好
  • 缺点:ABA 问题、自旋消耗 CPU、只能保证单变量原子性

代码示例

// CAS 的简单实现模拟
public boolean compareAndSwap(int expected, int update) {
if (this.value == expected) {
this.value = update;
return true;
}
return false;
}

2. 什么是 ABA 问题?如何解决?

答案要点

  • 问题:值从 A -> B -> A,CAS 操作会成功,但中间可能发生过变化
  • 解决方案:使用版本号,如 AtomicStampedReference
  • 替代方案AtomicMarkableReference 使用布尔标记

代码演示

// ABA 问题示例
AtomicInteger atomicInt = new AtomicInteger(100);

// 线程1 准备 CAS(100, 200)
int current = atomicInt.get();

// 线程2: 100 -> 200 -> 100
atomicInt.compareAndSet(100, 200);
atomicInt.compareAndSet(200, 100);

// 线程1 执行 CAS,会成功(但可能不是期望的结果)
boolean success = atomicInt.compareAndSet(current, 200); // 会成功

// 解决方案:使用 AtomicStampedReference
AtomicStampedReference<Integer> stampedRef =
new AtomicStampedReference<>(100, 1);
int stamp = stampedRef.getStamp();
stampedRef.compareAndSet(100, 200, stamp, stamp + 1);

3. AtomicLong 和 LongAdder 的区别?如何选择?

答案要点

特性AtomicLongLongAdder
原理单一变量 + CAS分段累加 + 合并
适用场景读多写少写多读少
读取性能优秀需要遍历 cells
写入性能高竞争时下降高并发优秀
内存开销较大(多个 Cell)
精度精确实时最终一致

选择策略

  • 读多写少:使用 AtomicLong
  • 写多读少:使用 LongAdder
  • 需要精确控制:使用 AtomicLong
  • 高并发统计:使用 LongAdder

4. volatile 和原子类的区别?

答案要点

特性volatile原子类
可见性✅ 保证✅ 保证
原子性✅ 单次读写✅ 复合操作
禁止重排
复合操作❌ 不保证✅ 保证
CAS 支持

代码示例

// volatile 只能保证单次读写的原子性
class VolatileCounter {
private volatile int count = 0;

// 这个方法不是线程安全的!
public void increment() {
count++; // 实际上是 read-modify-write 三个操作
}
}

// 原子类保证复合操作的原子性
class AtomicCounter {
private final AtomicInteger count = new AtomicInteger(0);

// 这个方法是线程安全的
public void increment() {
count.incrementAndGet(); // 原子操作
}
}

5. 原子字段更新器的使用场景和限制?

答案要点

  • 场景:避免创建额外对象、内存敏感、框架开发
  • 限制:字段必须是 volatile、不能是 final、访问权限限制
  • 优势:减少内存开销、避免包装对象

代码示例

class Task {
private volatile int state;

private static final AtomicIntegerFieldUpdater<Task> STATE_UPDATER =
AtomicIntegerFieldUpdater.newUpdater(Task.class, "state");

public boolean tryComplete() {
return STATE_UPDATER.compareAndSet(this, 0, 1);
}
}

🎯 进阶面试题

6. 如何实现一个线程安全的单例模式?

// 使用原子类实现延迟初始化的单例
public class AtomicSingleton {
private static final AtomicReference<AtomicSingleton> INSTANCE =
new AtomicReference<>();

private AtomicSingleton() {}

public static AtomicSingleton getInstance() {
AtomicSingleton instance = INSTANCE.get();
if (instance == null) {
synchronized (AtomicSingleton.class) {
instance = INSTANCE.get();
if (instance == null) {
instance = new AtomicSingleton();
INSTANCE.set(instance);
}
}
}
return instance;
}
}

7. 如何实现一个高性能的计数器?

public class HighPerformanceCounter {
// 根据并发量选择合适的计数器
private final LongAdder adder = new LongAdder();
private final AtomicLong atomicLong = new AtomicLong(0);
private final boolean useAdder;

public HighPerformanceCounter(boolean highConcurrency) {
this.useAdder = highConcurrency;
}

public void increment() {
if (useAdder) {
adder.increment();
} else {
atomicLong.incrementAndGet();
}
}

public long get() {
return useAdder ? adder.sum() : atomicLong.get();
}
}

💡 面试技巧

  1. 由浅入深:先讲基本概念,再深入原理
  2. 对比分析:与其他方案对比,突出优势
  3. 实践案例:结合具体业务场景
  4. 源码理解:适当提及底层实现
  5. 性能考虑:讨论不同场景下的性能表现

总结与展望

🎯 核心要点回顾

  1. 原子类是轻量级并发工具:在简单操作场景下比锁更高效
  2. CAS 是核心机制:无锁算法的基础,但要了解其局限性
  3. 合理选择原子类类型:根据场景选择合适的实现
  4. 注意 ABA 问题:在关键场景下使用版本号解决
  5. 性能优化考虑:高并发场景下考虑使用 LongAdder

🚀 进阶学习方向

  1. 底层原理:深入学习 Unsafe 类和硬件原子指令
  2. 无锁数据结构:学习无锁队列、栈等高级数据结构
  3. 内存模型:深入理解 Java 内存模型和 happens-before 原则
  4. 性能调优:学习如何根据业务特点优化并发性能
  5. 源码分析:阅读 JUC 源码,理解设计思想

📚 推荐学习资源

  • 《Java 并发编程实战》- 经典并发编程书籍
  • 《Java 并发编程的艺术》- 深入理解并发机制
  • JDK 源码分析 - java.util.concurrent.atomic 包
  • OpenJDK 文档 - 了解最新特性

通过本章的学习,你已经掌握了 Java 原子类和 CAS 机制的核心知识。在实际开发中,要根据具体的业务场景选择合适的并发工具,既要保证线程安全,又要考虑性能表现。记住,没有银弹,只有在特定场景下的最优选择。