Java时间轮算法的实现代码示例
程序员文章站
2024-02-25 11:15:34
考虑这样一个场景:现在有5000个任务,要让这5000个任务每隔5分钟触发某个操作,该如何实现这个需求?大部分人首先想到的是使用定时器,但是5000个任务,你就要用5000...
考虑这样一个场景:现在有5000个任务,要让这5000个任务每隔5分钟触发某个操作,该如何实现这个需求?大部分人首先想到的是使用定时器,但是5000个任务就要用5000个定时器,而一个定时器(例如 java.util.Timer)就对应一个后台线程,5000个线程的开销显然无法接受,所以这种方法肯定是不行的。
针对这个场景,催生了时间轮算法,时间轮到底是什么?我一贯的风格,自行谷歌去。大发慈悲,发个时间轮介绍你们看看,看文字和图就好了,代码不要看了,那个文章里的代码运行不起来,时间轮介绍。
看好了介绍,我们就开始动手吧。
开发环境:IntelliJ IDEA + JDK 1.8 + Maven
新建一个 Maven 工程
创建如下的目录结构
不要忘了在 pom.xml 中添加 Netty 依赖:
<dependencies>
    <!-- Netty supplies the internal utilities (PlatformDependent, internal logging,
         resource-leak detection) that timerwheel.java imports. POM element names and
         the artifact version are case-sensitive: groupId / artifactId / 4.1.5.Final. -->
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-all</artifactId>
        <version>4.1.5.Final</version>
    </dependency>
</dependencies>
代码如下
timeout.java
package com.tanghuachun.timer;

/**
 * Handle to a task that has been scheduled on a {@link timer}.
 *
 * <p>A handle is returned by {@link timer#newtimeout}; it lets the caller find out which timer
 * and task it belongs to, query its state, and try to cancel it.
 */
public interface timeout {

    /**
     * @return the {@link timer} that created this handle
     */
    timer timer();

    /**
     * @return the {@link timertask} associated with this handle
     */
    timertask task();

    /**
     * @return {@code true} if the associated task has expired
     */
    boolean isexpired();

    /**
     * @return {@code true} if this handle has been cancelled
     */
    boolean iscancelled();

    /**
     * Attempts to cancel the associated task.
     *
     * @return {@code true} if the cancellation took effect; the exact conditions are defined by
     *         the implementation
     */
    boolean cancel();
}
timer.java
package com.tanghuachun.timer;

import java.util.Set;
import java.util.concurrent.TimeUnit;

/**
 * Schedules {@link timertask}s for one-time execution after a delay.
 *
 * <p>Type and method names are kept lowercase because the other files of this package refer to
 * them that way; only the JDK types ({@link Set}, {@link TimeUnit}, {@link String}) need their
 * real casing to compile.
 */
public interface timer {

    /**
     * Schedules {@code task} to run once after {@code delay}.
     *
     * @param task  the task to execute
     * @param delay the delay before {@code task} is executed
     * @param unit  the time unit of {@code delay}
     * @param argv  caller-supplied string handed back to {@link timertask#run}
     * @return a handle that can be used to inspect or cancel the scheduled task
     */
    timeout newtimeout(timertask task, long delay, TimeUnit unit, String argv);

    /**
     * Stops this timer.
     *
     * @return the timeouts that were scheduled but not processed before the timer stopped
     *         (see the implementation for the exact semantics)
     */
    Set<timeout> stop();
}
timertask.java
package com.tanghuachun.timer; public interface timertask { void run(timeout timeout, string argv) throws exception; }
timerwheel.java
/* * copyright 2012 the netty project * * the netty project licenses this file to you under the apache license, * version 2.0 (the "license"); you may not use this file except in compliance * with the license. you may obtain a copy of the license at: * * http://www.apache.org/licenses/license-2.0 * * unless required by applicable law or agreed to in writing, software * distributed under the license is distributed on an "as is" basis, without * warranties or conditions of any kind, either express or implied. see the * license for the specific language governing permissions and limitations * under the license. */ package com.tanghuachun.timer; import io.netty.util.*; import io.netty.util.internal.platformdependent; import io.netty.util.internal.stringutil; import io.netty.util.internal.logging.internallogger; import io.netty.util.internal.logging.internalloggerfactory; import java.util.collections; import java.util.hashset; import java.util.queue; import java.util.set; import java.util.concurrent.countdownlatch; import java.util.concurrent.executors; import java.util.concurrent.threadfactory; import java.util.concurrent.timeunit; import java.util.concurrent.atomic.atomicintegerfieldupdater; public class timerwheel implements timer { static final internallogger logger = internalloggerfactory.getinstance(timerwheel.class); private static final resourceleakdetector<timerwheel> leakdetector = resourceleakdetectorfactory.instance() .newresourceleakdetector(timerwheel.class, 1, runtime.getruntime().availableprocessors() * 4l); private static final atomicintegerfieldupdater<timerwheel> worker_state_updater; static { atomicintegerfieldupdater<timerwheel> workerstateupdater = platformdependent.newatomicintegerfieldupdater(timerwheel.class, "workerstate"); if (workerstateupdater == null) { workerstateupdater = atomicintegerfieldupdater.newupdater(timerwheel.class, "workerstate"); } worker_state_updater = workerstateupdater; } private final resourceleak leak; private final 
worker worker = new worker(); private final thread workerthread; public static final int worker_state_init = 0; public static final int worker_state_started = 1; public static final int worker_state_shutdown = 2; @suppresswarnings({ "unused", "fieldmaybefinal", "redundantfieldinitialization" }) private volatile int workerstate = worker_state_init; // 0 - init, 1 - started, 2 - shut down private final long tickduration; private final hashedwheelbucket[] wheel; private final int mask; private final countdownlatch starttimeinitialized = new countdownlatch(1); private final queue<hashedwheeltimeout> timeouts = platformdependent.newmpscqueue(); private final queue<hashedwheeltimeout> cancelledtimeouts = platformdependent.newmpscqueue(); private volatile long starttime; /** * creates a new timer with the default thread factory * ({@link executors#defaultthreadfactory()}), default tick duration, and * default number of ticks per wheel. */ public timerwheel() { this(executors.defaultthreadfactory()); } /** * creates a new timer with the default thread factory * ({@link executors#defaultthreadfactory()}) and default number of ticks * per wheel. * * @param tickduration the duration between tick * @param unit the time unit of the {@code tickduration} * @throws nullpointerexception if {@code unit} is {@code null} * @throws illegalargumentexception if {@code tickduration} is <= 0 */ public timerwheel(long tickduration, timeunit unit) { this(executors.defaultthreadfactory(), tickduration, unit); } /** * creates a new timer with the default thread factory * ({@link executors#defaultthreadfactory()}). 
* * @param tickduration the duration between tick * @param unit the time unit of the {@code tickduration} * @param ticksperwheel the size of the wheel * @throws nullpointerexception if {@code unit} is {@code null} * @throws illegalargumentexception if either of {@code tickduration} and {@code ticksperwheel} is <= 0 */ public timerwheel(long tickduration, timeunit unit, int ticksperwheel) { this(executors.defaultthreadfactory(), tickduration, unit, ticksperwheel); } /** * creates a new timer with the default tick duration and default number of * ticks per wheel. * * @param threadfactory a {@link threadfactory} that creates a * background {@link thread} which is dedicated to * {@link timertask} execution. * @throws nullpointerexception if {@code threadfactory} is {@code null} */ public timerwheel(threadfactory threadfactory) { this(threadfactory, 100, timeunit.milliseconds); } /** * creates a new timer with the default number of ticks per wheel. * * @param threadfactory a {@link threadfactory} that creates a * background {@link thread} which is dedicated to * {@link timertask} execution. * @param tickduration the duration between tick * @param unit the time unit of the {@code tickduration} * @throws nullpointerexception if either of {@code threadfactory} and {@code unit} is {@code null} * @throws illegalargumentexception if {@code tickduration} is <= 0 */ public timerwheel( threadfactory threadfactory, long tickduration, timeunit unit) { this(threadfactory, tickduration, unit, 512); } /** * creates a new timer. * * @param threadfactory a {@link threadfactory} that creates a * background {@link thread} which is dedicated to * {@link timertask} execution. 
* @param tickduration the duration between tick * @param unit the time unit of the {@code tickduration} * @param ticksperwheel the size of the wheel * @throws nullpointerexception if either of {@code threadfactory} and {@code unit} is {@code null} * @throws illegalargumentexception if either of {@code tickduration} and {@code ticksperwheel} is <= 0 */ public timerwheel( threadfactory threadfactory, long tickduration, timeunit unit, int ticksperwheel) { this(threadfactory, tickduration, unit, ticksperwheel, true); } /** * creates a new timer. * * @param threadfactory a {@link threadfactory} that creates a * background {@link thread} which is dedicated to * {@link timertask} execution. * @param tickduration the duration between tick * @param unit the time unit of the {@code tickduration} * @param ticksperwheel the size of the wheel * @param leakdetection {@code true} if leak detection should be enabled always, if false it will only be enabled * if the worker thread is not a daemon thread. * @throws nullpointerexception if either of {@code threadfactory} and {@code unit} is {@code null} * @throws illegalargumentexception if either of {@code tickduration} and {@code ticksperwheel} is <= 0 */ public timerwheel( threadfactory threadfactory, long tickduration, timeunit unit, int ticksperwheel, boolean leakdetection) { if (threadfactory == null) { throw new nullpointerexception("threadfactory"); } if (unit == null) { throw new nullpointerexception("unit"); } if (tickduration <= 0) { throw new illegalargumentexception("tickduration must be greater than 0: " + tickduration); } if (ticksperwheel <= 0) { throw new illegalargumentexception("ticksperwheel must be greater than 0: " + ticksperwheel); } // normalize ticksperwheel to power of two and initialize the wheel. wheel = createwheel(ticksperwheel); mask = wheel.length - 1; // convert tickduration to nanos. this.tickduration = unit.tonanos(tickduration); // prevent overflow. 
if (this.tickduration >= long.max_value / wheel.length) { throw new illegalargumentexception(string.format( "tickduration: %d (expected: 0 < tickduration in nanos < %d", tickduration, long.max_value / wheel.length)); } workerthread = threadfactory.newthread(worker); leak = leakdetection || !workerthread.isdaemon() ? leakdetector.open(this) : null; } private static hashedwheelbucket[] createwheel(int ticksperwheel) { if (ticksperwheel <= 0) { throw new illegalargumentexception( "ticksperwheel must be greater than 0: " + ticksperwheel); } if (ticksperwheel > 1073741824) { throw new illegalargumentexception( "ticksperwheel may not be greater than 2^30: " + ticksperwheel); } ticksperwheel = normalizeticksperwheel(ticksperwheel); hashedwheelbucket[] wheel = new hashedwheelbucket[ticksperwheel]; for (int i = 0; i < wheel.length; i ++) { wheel[i] = new hashedwheelbucket(); } return wheel; } private static int normalizeticksperwheel(int ticksperwheel) { int normalizedticksperwheel = 1; while (normalizedticksperwheel < ticksperwheel) { normalizedticksperwheel <<= 1; } return normalizedticksperwheel; } /** * starts the background thread explicitly. the background thread will * start automatically on demand even if you did not call this method. * * @throws illegalstateexception if this timer has been * {@linkplain #stop() stopped} already */ public void start() { switch (worker_state_updater.get(this)) { case worker_state_init: if (worker_state_updater.compareandset(this, worker_state_init, worker_state_started)) { workerthread.start(); } break; case worker_state_started: break; case worker_state_shutdown: throw new illegalstateexception("cannot be started once stopped"); default: throw new error("invalid workerstate"); } // wait until the starttime is initialized by the worker. while (starttime == 0) { try { starttimeinitialized.await(); } catch (interruptedexception ignore) { // ignore - it will be ready very soon. 
} } } @override public set<timeout> stop() { if (thread.currentthread() == workerthread) { throw new illegalstateexception( timerwheel.class.getsimplename() + ".stop() cannot be called from " + timertask.class.getsimplename()); } if (!worker_state_updater.compareandset(this, worker_state_started, worker_state_shutdown)) { // workerstate can be 0 or 2 at this moment - let it always be 2. worker_state_updater.set(this, worker_state_shutdown); if (leak != null) { leak.close(); } return collections.emptyset(); } boolean interrupted = false; while (workerthread.isalive()) { workerthread.interrupt(); try { workerthread.join(100); } catch (interruptedexception ignored) { interrupted = true; } } if (interrupted) { thread.currentthread().interrupt(); } if (leak != null) { leak.close(); } return worker.unprocessedtimeouts(); } @override public timeout newtimeout(timertask task, long delay, timeunit unit, string argv) { if (task == null) { throw new nullpointerexception("task"); } if (unit == null) { throw new nullpointerexception("unit"); } start(); // add the timeout to the timeout queue which will be processed on the next tick. // during processing all the queued hashedwheeltimeouts will be added to the correct hashedwheelbucket. long deadline = system.nanotime() + unit.tonanos(delay) - starttime; hashedwheeltimeout timeout = new hashedwheeltimeout(this, task, deadline, argv); timeouts.add(timeout); return timeout; } private final class worker implements runnable { private final set<timeout> unprocessedtimeouts = new hashset<timeout>(); private long tick; @override public void run() { // initialize the starttime. starttime = system.nanotime(); if (starttime == 0) { // we use 0 as an indicator for the uninitialized value here, so make sure it's not 0 when initialized. starttime = 1; } // notify the other threads waiting for the initialization at start(). 
starttimeinitialized.countdown(); do { final long deadline = waitfornexttick(); if (deadline > 0) { int idx = (int) (tick & mask); processcancelledtasks(); hashedwheelbucket bucket = wheel[idx]; transfertimeoutstobuckets(); bucket.expiretimeouts(deadline); tick++; } } while (worker_state_updater.get(timerwheel.this) == worker_state_started); // fill the unprocessedtimeouts so we can return them from stop() method. for (hashedwheelbucket bucket: wheel) { bucket.cleartimeouts(unprocessedtimeouts); } for (;;) { hashedwheeltimeout timeout = timeouts.poll(); if (timeout == null) { break; } if (!timeout.iscancelled()) { unprocessedtimeouts.add(timeout); } } processcancelledtasks(); } private void transfertimeoutstobuckets() { // transfer only max. 100000 timeouts per tick to prevent a thread to stale the workerthread when it just // adds new timeouts in a loop. for (int i = 0; i < 100000; i++) { hashedwheeltimeout timeout = timeouts.poll(); if (timeout == null) { // all processed break; } if (timeout.state() == hashedwheeltimeout.st_cancelled) { // was cancelled in the meantime. continue; } long calculated = timeout.deadline / tickduration; timeout.remainingrounds = (calculated - tick) / wheel.length; final long ticks = math.max(calculated, tick); // ensure we don't schedule for past. int stopindex = (int) (ticks & mask); hashedwheelbucket bucket = wheel[stopindex]; bucket.addtimeout(timeout); } } private void processcancelledtasks() { for (;;) { hashedwheeltimeout timeout = cancelledtimeouts.poll(); if (timeout == null) { // all processed break; } try { timeout.remove(); } catch (throwable t) { if (logger.iswarnenabled()) { logger.warn("an exception was thrown while process a cancellation task", t); } } } } /** * calculate goal nanotime from starttime and current tick number, * then wait until that goal has been reached. 
* @return long.min_value if received a shutdown request, * current time otherwise (with long.min_value changed by +1) */ private long waitfornexttick() { long deadline = tickduration * (tick + 1); for (;;) { final long currenttime = system.nanotime() - starttime; long sleeptimems = (deadline - currenttime + 999999) / 1000000; if (sleeptimems <= 0) { if (currenttime == long.min_value) { return -long.max_value; } else { return currenttime; } } // check if we run on windows, as if thats the case we will need // to round the sleeptime as workaround for a bug that only affect // the jvm if it runs on windows. // // see https://github.com/netty/netty/issues/356 if (platformdependent.iswindows()) { sleeptimems = sleeptimems / 10 * 10; } try { thread.sleep(sleeptimems); } catch (interruptedexception ignored) { if (worker_state_updater.get(timerwheel.this) == worker_state_shutdown) { return long.min_value; } } } } public set<timeout> unprocessedtimeouts() { return collections.unmodifiableset(unprocessedtimeouts); } } private static final class hashedwheeltimeout implements timeout { private static final int st_init = 0; private static final int st_cancelled = 1; private static final int st_expired = 2; private static final atomicintegerfieldupdater<hashedwheeltimeout> state_updater; static { atomicintegerfieldupdater<hashedwheeltimeout> updater = platformdependent.newatomicintegerfieldupdater(hashedwheeltimeout.class, "state"); if (updater == null) { updater = atomicintegerfieldupdater.newupdater(hashedwheeltimeout.class, "state"); } state_updater = updater; } private final timerwheel timer; private final timertask task; private final long deadline; @suppresswarnings({"unused", "fieldmaybefinal", "redundantfieldinitialization" }) private volatile int state = st_init; // remainingrounds will be calculated and set by worker.transfertimeoutstobuckets() before the // hashedwheeltimeout will be added to the correct hashedwheelbucket. 
long remainingrounds; string argv; // this will be used to chain timeouts in hashedwheeltimerbucket via a double-linked-list. // as only the workerthread will act on it there is no need for synchronization / volatile. hashedwheeltimeout next; hashedwheeltimeout prev; // the bucket to which the timeout was added hashedwheelbucket bucket; hashedwheeltimeout(timerwheel timer, timertask task, long deadline, string argv) { this.timer = timer; this.task = task; this.deadline = deadline; this.argv = argv; } @override public timer timer() { return timer; } @override public timertask task() { return task; } @override public boolean cancel() { // only update the state it will be removed from hashedwheelbucket on next tick. if (!compareandsetstate(st_init, st_cancelled)) { return false; } // if a task should be canceled we put this to another queue which will be processed on each tick. // so this means that we will have a gc latency of max. 1 tick duration which is good enough. this way // we can make again use of our mpsclinkedqueue and so minimize the locking / overhead as much as possible. 
timer.cancelledtimeouts.add(this); return true; } void remove() { hashedwheelbucket bucket = this.bucket; if (bucket != null) { bucket.remove(this); } } public boolean compareandsetstate(int expected, int state) { return state_updater.compareandset(this, expected, state); } public int state() { return state; } @override public boolean iscancelled() { return state() == st_cancelled; } @override public boolean isexpired() { return state() == st_expired; } public void expire() { if (!compareandsetstate(st_init, st_expired)) { return; } try { task.run(this, argv); } catch (throwable t) { if (logger.iswarnenabled()) { logger.warn("an exception was thrown by " + timertask.class.getsimplename() + '.', t); } } } @override public string tostring() { final long currenttime = system.nanotime(); long remaining = deadline - currenttime + timer.starttime; stringbuilder buf = new stringbuilder(192) .append(stringutil.simpleclassname(this)) .append('(') .append("deadline: "); if (remaining > 0) { buf.append(remaining) .append(" ns later"); } else if (remaining < 0) { buf.append(-remaining) .append(" ns ago"); } else { buf.append("now"); } if (iscancelled()) { buf.append(", cancelled"); } return buf.append(", task: ") .append(task()) .append(')') .tostring(); } } /** * bucket that stores hashedwheeltimeouts. these are stored in a linked-list like datastructure to allow easy * removal of hashedwheeltimeouts in the middle. also the hashedwheeltimeout act as nodes themself and so no * extra object creation is needed. */ private static final class hashedwheelbucket { // used for the linked-list datastructure private hashedwheeltimeout head; private hashedwheeltimeout tail; /** * add {@link hashedwheeltimeout} to this bucket. 
*/ public void addtimeout(hashedwheeltimeout timeout) { assert timeout.bucket == null; timeout.bucket = this; if (head == null) { head = tail = timeout; } else { tail.next = timeout; timeout.prev = tail; tail = timeout; } } /** * expire all {@link hashedwheeltimeout}s for the given {@code deadline}. */ public void expiretimeouts(long deadline) { hashedwheeltimeout timeout = head; // process all timeouts while (timeout != null) { boolean remove = false; if (timeout.remainingrounds <= 0) { if (timeout.deadline <= deadline) { timeout.expire(); } else { // the timeout was placed into a wrong slot. this should never happen. throw new illegalstateexception(string.format( "timeout.deadline (%d) > deadline (%d)", timeout.deadline, deadline)); } remove = true; } else if (timeout.iscancelled()) { remove = true; } else { timeout.remainingrounds --; } // store reference to next as we may null out timeout.next in the remove block. hashedwheeltimeout next = timeout.next; if (remove) { remove(timeout); } timeout = next; } } public void remove(hashedwheeltimeout timeout) { hashedwheeltimeout next = timeout.next; // remove timeout that was either processed or cancelled by updating the linked-list if (timeout.prev != null) { timeout.prev.next = next; } if (timeout.next != null) { timeout.next.prev = timeout.prev; } if (timeout == head) { // if timeout is also the tail we need to adjust the entry too if (timeout == tail) { tail = null; head = null; } else { head = next; } } else if (timeout == tail) { // if the timeout is the tail modify the tail to be the prev node. tail = timeout.prev; } // null out prev, next and bucket to allow for gc. timeout.prev = null; timeout.next = null; timeout.bucket = null; } /** * clear this bucket and return all not expired / cancelled {@link timeout}s. 
*/ public void cleartimeouts(set<timeout> set) { for (;;) { hashedwheeltimeout timeout = polltimeout(); if (timeout == null) { return; } if (timeout.isexpired() || timeout.iscancelled()) { continue; } set.add(timeout); } } private hashedwheeltimeout polltimeout() { hashedwheeltimeout head = this.head; if (head == null) { return null; } hashedwheeltimeout next = head.next; if (next == null) { tail = this.head = null; } else { this.head = next; next.prev = null; } // null out prev and next to allow for gc. head.next = null; head.prev = null; head.bucket = null; return head; } } }
编写测试类main.java
package com.tanghuachun.timer; import java.util.concurrent.timeunit; /** * created by darren on 2016/11/17. */ public class main implements timertask{ final static timer timer = new timerwheel(); public static void main(string[] args) { timertask timertask = new main(); for (int i = 0; i < 10; i++) { timer.newtimeout(timertask, 5, timeunit.seconds, "" + i ); } } @override public void run(timeout timeout, string argv) throws exception { system.out.println("timeout, argv = " + argv ); } }
然后就可以看到运行结果啦。
(以maven的方式导入)。
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。