tom

Working on IO loop stuff.

......@@ -12,6 +12,10 @@ public final class TestTools {
private TestTools() {
}
public static void print(String msg) {
System.out.print(msg);
}
/**
* Suspends the current thread for a specified number of millis.
*
......
......@@ -20,7 +20,10 @@
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava-testlib</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.onlab.onos</groupId>
<artifactId>onlab-junit</artifactId>
</dependency>
<dependency>
<groupId>io.netty</groupId>
......
package org.onlab.util;
import java.util.Objects;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
/**
* Counting mechanism capable of tracking occurrences and rates.
*/
public class Counter {
private long total = 0;
private long start = System.currentTimeMillis();
private long end = 0;
/**
* Creates a new counter.
*/
public Counter() {
}
/**
* Creates a new counter in a specific state. If non-zero end time is
* specified, the counter will be frozen.
*
* @param start start time
* @param total total number of items to start with
* @param end end time; if non-ze
*/
public Counter(long start, long total, long end) {
checkArgument(start <= end, "Malformed interval: start > end");
checkArgument(total >= 0, "Total must be non-negative");
this.start = start;
this.total = total;
this.end = end;
}
/**
* Resets the counter, by zeroing out the count and restarting the timer.
*/
public synchronized void reset() {
end = 0;
total = 0;
start = System.currentTimeMillis();
}
/**
* Freezes the counter in the current state including the counts and times.
*/
public synchronized void freeze() {
end = System.currentTimeMillis();
}
/**
* Adds the specified number of occurrences to the counter. No-op if the
* counter has been frozen.
*
* @param count number of occurrences
*/
public synchronized void add(long count) {
checkArgument(count >= 0, "Count must be non-negative");
if (end == 0L) {
total += count;
}
}
/**
* Returns the number of occurrences per second.
*
* @return throughput in occurrences per second
*/
public synchronized double throughput() {
return total / duration();
}
/**
* Returns the total number of occurrences counted.
*
* @return number of counted occurrences
*/
public synchronized long total() {
return total;
}
/**
* Returns the duration expressed in fractional number of seconds.
*
* @return fractional number of seconds since the last reset
*/
public synchronized double duration() {
// Protect against 0 return by artificially setting duration to 1ms
long duration = (end == 0L ? System.currentTimeMillis() : end) - start;
return (duration == 0 ? 1 : duration) / 1000.0;
}
@Override
public int hashCode() {
return Objects.hash(total, start, end);
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj instanceof Counter) {
final Counter other = (Counter) obj;
return Objects.equals(this.total, other.total) &&
Objects.equals(this.start, other.start) &&
Objects.equals(this.end, other.end);
}
return false;
}
@Override
public String toString() {
return toStringHelper(this)
.add("total", total)
.add("start", start)
.add("end", end)
.toString();
}
}
package org.onlab.util;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.onlab.junit.TestTools.delay;
/**
* Tests of the Counter utility.
*/
public class CounterTest {
@Test
public void basics() {
Counter tt = new Counter();
assertEquals("incorrect number of bytes", 0L, tt.total());
assertEquals("incorrect throughput", 0.0, tt.throughput(), 0.0001);
tt.add(1234567890L);
assertEquals("incorrect number of bytes", 1234567890L, tt.total());
assertTrue("incorrect throughput", 1234567890.0 < tt.throughput());
delay(1500);
tt.add(1L);
assertEquals("incorrect number of bytes", 1234567891L, tt.total());
assertTrue("incorrect throughput", 1234567891.0 > tt.throughput());
tt.reset();
assertEquals("incorrect number of bytes", 0L, tt.total());
assertEquals("incorrect throughput", 0.0, tt.throughput(), 0.0001);
}
@Test
public void freeze() {
Counter tt = new Counter();
tt.add(123L);
assertEquals("incorrect number of bytes", 123L, tt.total());
delay(1000);
tt.freeze();
tt.add(123L);
assertEquals("incorrect number of bytes", 123L, tt.total());
double d = tt.duration();
double t = tt.throughput();
assertEquals("incorrect duration", d, tt.duration(), 0.0001);
assertEquals("incorrect throughput", t, tt.throughput(), 0.0001);
assertEquals("incorrect number of bytes", 123L, tt.total());
}
@Test
public void reset() {
Counter tt = new Counter();
tt.add(123L);
assertEquals("incorrect number of bytes", 123L, tt.total());
double d = tt.duration();
double t = tt.throughput();
assertEquals("incorrect duration", d, tt.duration(), 0.0001);
assertEquals("incorrect throughput", t, tt.throughput(), 0.0001);
assertEquals("incorrect number of bytes", 123L, tt.total());
tt.reset();
assertEquals("incorrect throughput", 0.0, tt.throughput(), 0.0001);
assertEquals("incorrect number of bytes", 0, tt.total());
}
@Test
public void syntheticTracker() {
Counter tt = new Counter(5000, 1000, 6000);
assertEquals("incorrect duration", 1, tt.duration(), 0.1);
assertEquals("incorrect throughput", 1000, tt.throughput(), 1.0);
}
}
......@@ -22,6 +22,15 @@
<artifactId>guava-testlib</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.onlab.onos</groupId>
<artifactId>onlab-misc</artifactId>
</dependency>
<dependency>
<groupId>org.onlab.onos</groupId>
<artifactId>onlab-junit</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
......
package org.onlab.nio;
/**
* Base {@link Message} implementation.
*/
public abstract class AbstractMessage implements Message {
protected int length;
@Override
public int length() {
return length;
}
}
......@@ -28,7 +28,7 @@ public abstract class AcceptorLoop extends SelectorLoop {
public AcceptorLoop(long selectTimeout, SocketAddress listenAddress)
throws IOException {
super(selectTimeout);
this.listenAddress = checkNotNull(this.listenAddress, "Address cannot be null");
this.listenAddress = checkNotNull(listenAddress, "Address cannot be null");
}
/**
......
package org.onlab.nio;
import java.io.IOException;
import java.nio.channels.ByteChannel;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.Iterator;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CopyOnWriteArraySet;
/**
* I/O loop for driving inbound &amp; outbound {@link Message} transfer via
* {@link MessageStream}.
*
* @param <M> message type
* @param <S> message stream type
*/
public abstract class IOLoop<M extends Message, S extends MessageStream<M>>
extends SelectorLoop {
// Queue of requests for new message streams to enter the IO loop processing.
private final Queue<NewStreamRequest> newStreamRequests = new ConcurrentLinkedQueue<>();
// Carries information required for admitting a new message stream.
private class NewStreamRequest {
private final S stream;
private final SelectableChannel channel;
private final int op;
public NewStreamRequest(S stream, SelectableChannel channel, int op) {
this.stream = stream;
this.channel = channel;
this.op = op;
}
}
// Set of message streams currently admitted into the IO loop.
private final Set<MessageStream<M>> streams = new CopyOnWriteArraySet<>();
/**
* Creates an IO loop with the given selection timeout.
*
* @param timeout selection timeout in milliseconds
* @throws IOException if the backing selector cannot be opened
*/
public IOLoop(long timeout) throws IOException {
super(timeout);
}
/**
* Creates a new message stream backed by the specified socket channel.
*
* @param byteChannel backing byte channel
* @return newly created message stream
*/
protected abstract S createStream(ByteChannel byteChannel);
/**
* Removes the specified message stream from the IO loop.
*
* @param stream message stream to remove
*/
void removeStream(MessageStream<M> stream) {
streams.remove(stream);
}
/**
* Processes the list of messages extracted from the specified message
* stream.
*
* @param messages non-empty list of received messages
* @param stream message stream from which the messages were extracted
*/
protected abstract void processMessages(List<M> messages, MessageStream<M> stream);
/**
* Completes connection request pending on the given selection key.
*
* @param key selection key holding the pending connect operation.
*/
protected void connect(SelectionKey key) {
try {
SocketChannel ch = (SocketChannel) key.channel();
ch.finishConnect();
} catch (IOException | IllegalStateException e) {
log.warn("Unable to complete connection", e);
}
if (key.isValid()) {
key.interestOps(SelectionKey.OP_READ);
}
}
/**
* Processes an IO operation pending on the specified key.
*
* @param key selection key holding the pending I/O operation.
*/
protected void processKeyOperation(SelectionKey key) {
@SuppressWarnings("unchecked")
S stream = (S) key.attachment();
try {
// If the key is not valid, bail out.
if (!key.isValid()) {
stream.close();
return;
}
// If there is a pending connect operation, complete it.
if (key.isConnectable()) {
connect(key);
}
// If there is a read operation, slurp as much data as possible.
if (key.isReadable()) {
List<M> messages = stream.read();
// No messages or failed flush imply disconnect; bail.
if (messages == null || stream.hadError()) {
stream.close();
return;
}
// If there were any messages read, process them.
if (!messages.isEmpty()) {
try {
processMessages(messages, stream);
} catch (RuntimeException e) {
onError(stream, e);
}
}
}
// If there are pending writes, flush them
if (key.isWritable()) {
stream.flushIfPossible();
}
// If there were any issued flushing, close the stream.
if (stream.hadError()) {
stream.close();
}
} catch (CancelledKeyException e) {
// Key was cancelled, so silently close the stream
stream.close();
} catch (IOException e) {
if (!stream.isClosed() && !isResetByPeer(e)) {
log.warn("Unable to process IO", e);
}
stream.close();
}
}
// Indicates whether or not this exception is caused by 'reset by peer'.
private boolean isResetByPeer(IOException e) {
Throwable cause = e.getCause();
return cause != null && cause instanceof IOException &&
cause.getMessage().contains("reset by peer");
}
/**
* Hook to allow intercept of any errors caused during message processing.
* Default behaviour is to rethrow the error.
*
* @param stream message stream involved in the error
* @param error the runtime exception
*/
protected void onError(S stream, RuntimeException error) {
throw error;
}
/**
* Admits a new message stream backed by the specified socket channel
* with a pending accept operation.
*
* @param channel backing socket channel
*/
public void acceptStream(SocketChannel channel) {
createAndAdmit(channel, SelectionKey.OP_READ);
}
/**
* Admits a new message stream backed by the specified socket channel
* with a pending connect operation.
*
* @param channel backing socket channel
*/
public void connectStream(SocketChannel channel) {
createAndAdmit(channel, SelectionKey.OP_CONNECT);
}
/**
* Creates a new message stream backed by the specified socket channel
* and admits it into the IO loop.
*
* @param channel socket channel
* @param op pending operations mask to be applied to the selection
* key as a set of initial interestedOps
*/
private synchronized void createAndAdmit(SocketChannel channel, int op) {
S stream = createStream(channel);
streams.add(stream);
newStreamRequests.add(new NewStreamRequest(stream, channel, op));
selector.wakeup();
}
/**
* Safely admits new streams into the IO loop.
*/
private void admitNewStreams() {
Iterator<NewStreamRequest> it = newStreamRequests.iterator();
while (isRunning() && it.hasNext()) {
try {
NewStreamRequest request = it.next();
it.remove();
SelectionKey key = request.channel.register(selector, request.op,
request.stream);
request.stream.setKey(key);
} catch (ClosedChannelException e) {
log.warn("Unable to admit new message stream", e);
}
}
}
@Override
protected void loop() throws IOException {
notifyReady();
// Keep going until told otherwise.
while (isRunning()) {
admitNewStreams();
// Process flushes & write selects on all streams
for (MessageStream<M> stream : streams) {
stream.flushIfWriteNotPending();
}
// Select keys and process them.
int count = selector.select(selectTimeout);
if (count > 0 && isRunning()) {
Iterator<SelectionKey> it = selector.selectedKeys().iterator();
while (it.hasNext()) {
SelectionKey key = it.next();
it.remove();
processKeyOperation(key);
}
}
}
}
/**
* Prunes the registered streams by discarding any stale ones.
*/
public synchronized void pruneStaleStreams() {
for (MessageStream<M> stream : streams) {
if (stream.isStale()) {
stream.close();
}
}
}
}
package org.onlab.nio;
/**
* Representation of a message transferred via {@link MessageStream}.
*/
public interface Message {
/**
* Gets the message length in bytes.
*
* @return number of bytes
*/
int length();
}
package org.onlab.nio;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.SelectionKey;
import java.util.ArrayList;
import java.util.List;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static java.lang.System.currentTimeMillis;
import static java.nio.ByteBuffer.allocateDirect;
/**
* Bi-directional message stream for transferring messages to &amp; from the
* network via two byte buffers.
*
* @param <M> message type
*/
public abstract class MessageStream<M extends Message> {
protected Logger log = LoggerFactory.getLogger(getClass());
private final IOLoop<M, ?> loop;
private final ByteChannel channel;
private final int maxIdleMillis;
private final ByteBuffer inbound;
private ByteBuffer outbound;
private SelectionKey key;
private volatile boolean closed = false;
private volatile boolean writePending;
private volatile boolean writeOccurred;
private Exception ioError;
private long lastActiveTime;
/**
* Creates a message stream associated with the specified IO loop and
* backed by the given byte channel.
*
* @param loop IO loop
* @param byteChannel backing byte channel
* @param bufferSize size of the backing byte buffers
* @param maxIdleMillis maximum number of millis the stream can be idle
* before it will be closed
*/
protected MessageStream(IOLoop<M, ?> loop, ByteChannel byteChannel,
int bufferSize, int maxIdleMillis) {
this.loop = checkNotNull(loop, "Loop cannot be null");
this.channel = checkNotNull(byteChannel, "Byte channel cannot be null");
checkArgument(maxIdleMillis > 0, "Idle time must be positive");
this.maxIdleMillis = maxIdleMillis;
inbound = allocateDirect(bufferSize);
outbound = allocateDirect(bufferSize);
}
/**
* Gets a single message from the specified byte buffer; this is
* to be done without manipulating the buffer via flip, reset or clear.
*
* @param buffer byte buffer
* @return read message or null if there are not enough bytes to read
* a complete message
*/
protected abstract M read(ByteBuffer buffer);
/**
* Puts the specified message into the specified byte buffer; this is
* to be done without manipulating the buffer via flip, reset or clear.
*
* @param message message to be write into the buffer
* @param buffer byte buffer
*/
protected abstract void write(M message, ByteBuffer buffer);
/**
* Closes the message buffer.
*/
public void close() {
synchronized (this) {
if (closed) {
return;
}
closed = true;
}
loop.removeStream(this);
if (key != null) {
try {
key.cancel();
key.channel().close();
} catch (IOException e) {
log.warn("Unable to close stream", e);
}
}
}
/**
* Indicates whether this buffer has been closed.
*
* @return true if this stream has been closed
*/
public synchronized boolean isClosed() {
return closed;
}
/**
* Returns the stream IO selection key.
*
* @return socket channel registration selection key
*/
public SelectionKey key() {
return key;
}
/**
* Binds the selection key to be used for driving IO operations on the stream.
*
* @param key IO selection key
*/
public void setKey(SelectionKey key) {
this.key = key;
this.lastActiveTime = currentTimeMillis();
}
/**
* Returns the IO loop to which this stream is bound.
*
* @return I/O loop used to drive this stream
*/
public IOLoop<M, ?> loop() {
return loop;
}
/**
* Indicates whether the any prior IO encountered an error.
*
* @return true if a write failed
*/
public boolean hadError() {
return ioError != null;
}
/**
* Gets the prior IO error, if one occurred.
*
* @return IO error; null if none occurred
*/
public Exception getError() {
return ioError;
}
/**
* Reads, withouth blocking, a list of messages from the stream.
* The list will be empty if there were not messages pending.
*
* @return list of messages or null if backing channel has been closed
* @throws IOException if messages could not be read
*/
public List<M> read() throws IOException {
try {
int read = channel.read(inbound);
if (read != -1) {
// Read the messages one-by-one and add them to the list.
List<M> messages = new ArrayList<>();
M message;
inbound.flip();
while ((message = read(inbound)) != null) {
messages.add(message);
}
inbound.compact();
// Mark the stream with current time to indicate liveness.
lastActiveTime = currentTimeMillis();
return messages;
}
return null;
} catch (Exception e) {
throw new IOException("Unable to read messages", e);
}
}
/**
* Writes the specified list of messages to the stream.
*
* @param messages list of messages to write
* @throws IOException if error occurred while writing the data
*/
public void write(List<M> messages) throws IOException {
synchronized (this) {
// First write all messages.
for (M m : messages) {
append(m);
}
flushUnlessAlreadyPlanningTo();
}
}
/**
* Writes the given message to the stream.
*
* @param message message to write
* @throws IOException if error occurred while writing the data
*/
public void write(M message) throws IOException {
synchronized (this) {
append(message);
flushUnlessAlreadyPlanningTo();
}
}
// Appends the specified message into the internal buffer, growing the
// buffer if required.
private void append(M message) {
// If the buffer does not have sufficient length double it.
while (outbound.remaining() < message.length()) {
doubleSize();
}
// Place the message into the buffer and bump the output trackers.
write(message, outbound);
}
// Forces a flush, unless one is planned already.
private void flushUnlessAlreadyPlanningTo() throws IOException {
if (!writeOccurred && !writePending) {
flush();
}
}
/**
* Flushes any pending writes.
*
* @throws IOException if flush failed
*/
public void flush() throws IOException {
synchronized (this) {
if (!writeOccurred && !writePending) {
outbound.flip();
try {
channel.write(outbound);
} catch (IOException e) {
if (!closed && !e.getMessage().equals("Broken pipe")) {
log.warn("Unable to write data", e);
ioError = e;
}
}
lastActiveTime = currentTimeMillis();
writeOccurred = true;
writePending = outbound.hasRemaining();
outbound.compact();
}
}
}
/**
* Indicates whether the stream has bytes to be written to the channel.
*
* @return true if there are bytes to be written
*/
boolean isWritePending() {
synchronized (this) {
return writePending;
}
}
/**
* Attempts to flush data, internal stream state and channel availability
* permitting. Invoked by the driver I/O loop during handling of writable
* selection key.
* <p/>
* Resets the internal state flags {@code writeOccurred} and
* {@code writePending}.
*
* @throws IOException if implicit flush failed
*/
void flushIfPossible() throws IOException {
synchronized (this) {
writePending = false;
writeOccurred = false;
if (outbound.position() > 0) {
flush();
}
}
key.interestOps(SelectionKey.OP_READ);
}
/**
* Attempts to flush data, internal stream state and channel availability
* permitting and if other writes are not pending. Invoked by the driver
* I/O loop prior to entering select wait. Resets the internal
* {@code writeOccurred} state flag.
*
* @throws IOException if implicit flush failed
*/
void flushIfWriteNotPending() throws IOException {
synchronized (this) {
writeOccurred = false;
if (!writePending && outbound.position() > 0) {
flush();
}
}
if (isWritePending()) {
key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
}
}
/**
* Doubles the size of the outbound buffer.
*/
private void doubleSize() {
ByteBuffer newBuffer = allocateDirect(outbound.capacity() * 2);
outbound.flip();
newBuffer.put(outbound);
outbound = newBuffer;
}
/**
* Returns the maximum number of milliseconds the stream is allowed
* without any read/write operations.
*
* @return number if millis of permissible idle time
*/
protected int maxIdleMillis() {
return maxIdleMillis;
}
/**
* Returns true if the given stream has gone stale.
*
* @return true if the stream is stale
*/
boolean isStale() {
return currentTimeMillis() - lastActiveTime > maxIdleMillis() && key != null;
}
}
package org.onlab.nio;
import org.junit.Before;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import static java.util.concurrent.Executors.newSingleThreadExecutor;
import static org.junit.Assert.fail;
import static org.onlab.util.Tools.namedThreads;
/**
* Base class for various NIO loop unit tests.
*/
public abstract class AbstractLoopTest {
protected static final long MAX_MS_WAIT = 500;
/** Block on specified countdown latch. Return when countdown reaches
* zero, or fail the test if the {@value #MAX_MS_WAIT} ms timeout expires.
*
* @param latch the latch
* @param label an identifying label
*/
protected void waitForLatch(CountDownLatch latch, String label) {
try {
boolean ok = latch.await(MAX_MS_WAIT, TimeUnit.MILLISECONDS);
if (!ok) {
fail("Latch await timeout! [" + label + "]");
}
} catch (InterruptedException e) {
System.out.println("Latch interrupt [" + label + "] : " + e);
fail("Unexpected interrupt");
}
}
protected ExecutorService exec;
@Before
public void setUp() {
exec = newSingleThreadExecutor(namedThreads("test"));
}
}
package org.onlab.nio;
import org.junit.Test;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.channels.ServerSocketChannel;
import java.util.concurrent.CountDownLatch;
import static org.junit.Assert.assertEquals;
import static org.onlab.junit.TestTools.delay;
/**
* Unit tests for AcceptLoop.
*/
public class AcceptorLoopTest extends AbstractLoopTest {
private static final int PORT = 9876;
private static final SocketAddress SOCK_ADDR = new InetSocketAddress("127.0.0.1", PORT);
private static class MyAcceptLoop extends AcceptorLoop {
private final CountDownLatch loopStarted = new CountDownLatch(1);
private final CountDownLatch loopFinished = new CountDownLatch(1);
private final CountDownLatch runDone = new CountDownLatch(1);
private final CountDownLatch ceaseLatch = new CountDownLatch(1);
private int acceptCount = 0;
MyAcceptLoop() throws IOException {
super(500, SOCK_ADDR);
}
@Override
protected void acceptConnection(ServerSocketChannel ssc) throws IOException {
acceptCount++;
}
@Override
public void loop() throws IOException {
loopStarted.countDown();
super.loop();
loopFinished.countDown();
}
@Override
public void run() {
super.run();
runDone.countDown();
}
@Override
public void shutdown() {
super.shutdown();
ceaseLatch.countDown();
}
}
@Test
// @Ignore("Doesn't shut down the socket")
public void basic() throws IOException {
MyAcceptLoop myAccLoop = new MyAcceptLoop();
AcceptorLoop accLoop = myAccLoop;
exec.execute(accLoop);
waitForLatch(myAccLoop.loopStarted, "loopStarted");
delay(200); // take a quick nap
accLoop.shutdown();
waitForLatch(myAccLoop.loopFinished, "loopFinished");
waitForLatch(myAccLoop.runDone, "runDone");
assertEquals(0, myAccLoop.acceptCount);
}
}
package org.onlab.nio;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import java.net.InetAddress;
import java.text.DecimalFormat;
import java.util.Random;
import static org.onlab.junit.TestTools.delay;
/**
* Integration test for the select, accept and IO loops.
*/
public class IOLoopIntegrationTest {
private static final int MILLION = 1000000;
private static final int TIMEOUT = 60;
private static final int THREADS = 6;
private static final int MSG_COUNT = 20 * MILLION;
private static final int MSG_SIZE = 128;
private static final long MIN_MPS = 10 * MILLION;
@Before
public void warmUp() throws Exception {
try {
run(MILLION, MSG_SIZE, 15, 0);
} catch (Throwable e) {
System.err.println("Failed warmup but moving on.");
e.printStackTrace();
}
}
@Ignore
@Test
public void basic() throws Exception {
run(MSG_COUNT, MSG_SIZE, TIMEOUT, MIN_MPS);
}
private void run(int count, int size, int timeout, double mps) throws Exception {
DecimalFormat f = new DecimalFormat("#,##0");
System.out.print(f.format(count * THREADS) +
(mps > 0.0 ? " messages: " : " message warm-up: "));
// Setup the test on a random port to avoid intermittent test failures
// due to the port being already bound.
int port = StandaloneSpeedServer.PORT + new Random().nextInt(100);
InetAddress ip = InetAddress.getLoopbackAddress();
StandaloneSpeedServer sss = new StandaloneSpeedServer(ip, THREADS, size, port);
StandaloneSpeedClient ssc = new StandaloneSpeedClient(ip, THREADS, count, size, port);
sss.start();
ssc.start();
delay(250); // give the server and client a chance to go
ssc.await(timeout);
ssc.report();
delay(1000);
sss.stop();
sss.report();
// Note that the client and server will have potentially significantly
// differing rates. This is due to the wide variance in how tightly
// the throughput tracking starts & stops relative to to the short
// test duration.
// System.out.println(f.format(ssc.messages.throughput()) + " mps");
// // Make sure client sent everything.
// assertEquals("incorrect client message count sent",
// (long) count * THREADS, ssc.messages.total());
// assertEquals("incorrect client bytes count sent",
// (long) size * count * THREADS, ssc.bytes.total());
//
// // Make sure server received everything.
// assertEquals("incorrect server message count received",
// (long) count * THREADS, sss.messages.total());
// assertEquals("incorrect server bytes count received",
// (long) size * count * THREADS, sss.bytes.total());
//
// // Make sure speeds were reasonable.
// if (mps > 0.0) {
// assertAboveThreshold("insufficient client speed", mps,
// ssc.messages.throughput());
// assertAboveThreshold("insufficient server speed", mps / 2,
// sss.messages.throughput());
// }
}
}
This diff is collapsed. Click to expand it.
package org.onlab.nio;
import java.io.IOException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.spi.AbstractSelectableChannel;
import java.nio.channels.spi.AbstractSelector;
import java.util.Set;
/**
* A selector instrumented for unit tests.
*/
public class MockSelector extends AbstractSelector {
int wakeUpCount = 0;
/**
* Creates a mock selector, specifying null as the SelectorProvider.
*/
public MockSelector() {
super(null);
}
@Override
public String toString() {
return "{MockSelector: wake=" + wakeUpCount + "}";
}
@Override
protected void implCloseSelector() throws IOException {
}
@Override
protected SelectionKey register(AbstractSelectableChannel ch, int ops,
Object att) {
return null;
}
@Override
public Set<SelectionKey> keys() {
return null;
}
@Override
public Set<SelectionKey> selectedKeys() {
return null;
}
@Override
public int selectNow() throws IOException {
return 0;
}
@Override
public int select(long timeout) throws IOException {
return 0;
}
@Override
public int select() throws IOException {
return 0;
}
@Override
public Selector wakeup() {
wakeUpCount++;
return null;
}
}
package org.onlab.nio;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.channels.ByteChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import static org.onlab.junit.TestTools.delay;
import static org.onlab.util.Tools.namedThreads;
/**
* Auxiliary test fixture to measure speed of NIO-based channels.
*/
public class StandaloneSpeedClient {
private static Logger log = LoggerFactory.getLogger(StandaloneSpeedClient.class);
private final InetAddress ip;
private final int port;
private final int msgCount;
private final int msgLength;
private final List<CustomIOLoop> iloops = new ArrayList<>();
private final ExecutorService ipool;
private final ExecutorService wpool;
// ThroughputTracker messages;
// ThroughputTracker bytes;
/**
* Main entry point to launch the client.
*
* @param args command-line arguments
* @throws IOException if unable to connect to server
* @throws InterruptedException if latch wait gets interrupted
* @throws ExecutionException if wait gets interrupted
* @throws TimeoutException if timeout occurred while waiting for completion
*/
public static void main(String[] args)
throws IOException, InterruptedException, ExecutionException, TimeoutException {
InetAddress ip = InetAddress.getByName(args.length > 0 ? args[0] : "127.0.0.1");
int wc = args.length > 1 ? Integer.parseInt(args[1]) : 6;
int mc = args.length > 2 ? Integer.parseInt(args[2]) : 50 * 1000000;
int ml = args.length > 3 ? Integer.parseInt(args[3]) : 128;
int to = args.length > 4 ? Integer.parseInt(args[4]) : 30;
log.info("Setting up client with {} workers sending {} {}-byte messages to {} server... ",
wc, mc, ml, ip);
StandaloneSpeedClient sc = new StandaloneSpeedClient(ip, wc, mc, ml, StandaloneSpeedServer.PORT);
sc.start();
delay(2000);
sc.await(to);
sc.report();
System.exit(0);
}
/**
* Creates a speed client.
*
* @param ip ip address of server
* @param wc worker count
* @param mc message count to send per client
* @param ml message length in bytes
* @param port socket port
* @throws IOException if unable to create IO loops
*/
public StandaloneSpeedClient(InetAddress ip, int wc, int mc, int ml, int port) throws IOException {
this.ip = ip;
this.port = port;
this.msgCount = mc;
this.msgLength = ml;
this.wpool = Executors.newFixedThreadPool(wc, namedThreads("worker"));
this.ipool = Executors.newFixedThreadPool(wc, namedThreads("io-loop"));
for (int i = 0; i < wc; i++) {
iloops.add(new CustomIOLoop());
}
}
/**
* Starts the client workers.
*
* @throws IOException if unable to open connection
*/
public void start() throws IOException {
// messages = new ThroughputTracker();
// bytes = new ThroughputTracker();
// First start up all the IO loops
for (CustomIOLoop l : iloops) {
ipool.execute(l);
}
// // Wait for all of them to get going
// for (CustomIOLoop l : iloops)
// l.waitForStart(TIMEOUT);
// ... and Next open all connections; one-per-loop
for (CustomIOLoop l : iloops) {
openConnection(l);
}
}
/**
* Initiates open connection request and registers the pending socket
* channel with the given IO loop.
*
* @param loop loop with which the channel should be registered
* @throws IOException if the socket could not be open or connected
*/
private void openConnection(CustomIOLoop loop) throws IOException {
SocketAddress sa = new InetSocketAddress(ip, port);
SocketChannel ch = SocketChannel.open();
ch.configureBlocking(false);
loop.connectStream(ch);
ch.connect(sa);
}
/**
* Waits for the client workers to complete.
*
* @param secs timeout in seconds
* @throws ExecutionException if execution failed
* @throws InterruptedException if interrupt occurred while waiting
* @throws TimeoutException if timeout occurred
*/
public void await(int secs) throws InterruptedException,
ExecutionException, TimeoutException {
for (CustomIOLoop l : iloops) {
if (l.worker.task != null) {
l.worker.task.get(secs, TimeUnit.SECONDS);
}
}
// messages.freeze();
// bytes.freeze();
}
/**
* Reports on the accumulated throughput trackers.
*/
public void report() {
// DecimalFormat f = new DecimalFormat("#,##0");
// log.info("{} messages; {} bytes; {} mps; {} Mbs",
// f.format(messages.total()),
// f.format(bytes.total()),
// f.format(messages.throughput()),
// f.format(bytes.throughput() / (1024 * 128)));
}
// Loop for transfer of fixed-length messages
private class CustomIOLoop extends IOLoop<TestMessage, TestMessageStream> {
Worker worker = new Worker();
public CustomIOLoop() throws IOException {
super(500);
}
@Override
protected TestMessageStream createStream(ByteChannel channel) {
return new TestMessageStream(msgLength, channel, this);
}
@Override
protected synchronized void removeStream(MessageStream<TestMessage> b) {
super.removeStream(b);
// messages.add(b.inMessages().total());
// bytes.add(b.inBytes().total());
// b.inMessages().reset();
// b.inBytes().reset();
// log.info("Disconnected client; inbound {} mps, {} Mbps; outbound {} mps, {} Mbps",
// StandaloneSpeedServer.format.format(b.inMessages().throughput()),
// StandaloneSpeedServer.format.format(b.inBytes().throughput() / (1024 * 128)),
// StandaloneSpeedServer.format.format(b.outMessages().throughput()),
// StandaloneSpeedServer.format.format(b.outBytes().throughput() / (1024 * 128)));
}
@Override
protected void processMessages(List<TestMessage> messages,
MessageStream<TestMessage> b) {
worker.release(messages.size());
}
@Override
protected void connect(SelectionKey key) {
super.connect(key);
TestMessageStream b = (TestMessageStream) key.attachment();
Worker w = ((CustomIOLoop) b.loop()).worker;
w.pump(b);
}
}
/**
* Auxiliary worker to connect and pump batched messages using blocking I/O.
*/
private class Worker implements Runnable {
private static final int BATCH_SIZE = 1000;
private static final int PERMITS = 2 * BATCH_SIZE;
private TestMessageStream b;
private FutureTask<Worker> task;
// Stuff to throttle pump
private final Semaphore semaphore = new Semaphore(PERMITS);
private int msgWritten;
void pump(TestMessageStream b) {
this.b = b;
task = new FutureTask<>(this, this);
wpool.execute(task);
}
@Override
public void run() {
try {
log.info("Worker started...");
List<TestMessage> batch = new ArrayList<>();
for (int i = 0; i < BATCH_SIZE; i++) {
batch.add(new TestMessage(msgLength));
}
while (msgWritten < msgCount) {
msgWritten += writeBatch(b, batch);
}
// Now try to get all the permits back before sending poison pill
semaphore.acquireUninterruptibly(PERMITS);
b.close();
log.info("Worker done...");
} catch (IOException e) {
log.error("Worker unable to perform I/O", e);
}
}
private int writeBatch(TestMessageStream b, List<TestMessage> batch)
throws IOException {
int count = Math.min(BATCH_SIZE, msgCount - msgWritten);
acquire(count);
if (count == BATCH_SIZE) {
b.write(batch);
} else {
for (int i = 0; i < count; i++) {
b.write(batch.get(i));
}
}
return count;
}
// Release permits based on the specified number of message credits
private void release(int permits) {
semaphore.release(permits);
}
// Acquire permit for a single batch
private void acquire(int permits) {
semaphore.acquireUninterruptibly(permits);
}
}
}
package org.onlab.nio;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.channels.ByteChannel;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import static org.onlab.junit.TestTools.delay;
import static org.onlab.util.Tools.namedThreads;
/**
* Auxiliary test fixture to measure speed of NIO-based channels.
*/
public class StandaloneSpeedServer {
private static Logger log = LoggerFactory.getLogger(StandaloneSpeedServer.class);
private static final int PRUNE_FREQUENCY = 1000;
static final int PORT = 9876;
static final long TIMEOUT = 1000;
static final boolean SO_NO_DELAY = false;
static final int SO_SEND_BUFFER_SIZE = 1024 * 1024;
static final int SO_RCV_BUFFER_SIZE = 1024 * 1024;
static final DecimalFormat FORMAT = new DecimalFormat("#,##0");
private final AcceptorLoop aloop;
private final ExecutorService apool = Executors.newSingleThreadExecutor(namedThreads("accept"));
private final List<CustomIOLoop> iloops = new ArrayList<>();
private final ExecutorService ipool;
private final int workerCount;
private final int msgLength;
private int lastWorker = -1;
// ThroughputTracker messages;
// ThroughputTracker bytes;
/**
* Main entry point to launch the server.
*
* @param args command-line arguments
* @throws IOException if unable to crate IO loops
*/
public static void main(String[] args) throws IOException {
InetAddress ip = InetAddress.getByName(args.length > 0 ? args[0] : "127.0.0.1");
int wc = args.length > 1 ? Integer.parseInt(args[1]) : 6;
int ml = args.length > 2 ? Integer.parseInt(args[2]) : 128;
log.info("Setting up the server with {} workers, {} byte messages on {}... ",
wc, ml, ip);
StandaloneSpeedServer ss = new StandaloneSpeedServer(ip, wc, ml, PORT);
ss.start();
// Start pruning clients.
while (true) {
delay(PRUNE_FREQUENCY);
ss.prune();
}
}
/**
* Creates a speed server.
*
* @param ip optional ip of the adapter where to bind
* @param wc worker count
* @param ml message length in bytes
* @param port listen port
* @throws IOException if unable to create IO loops
*/
public StandaloneSpeedServer(InetAddress ip, int wc, int ml, int port) throws IOException {
this.workerCount = wc;
this.msgLength = ml;
this.ipool = Executors.newFixedThreadPool(workerCount, namedThreads("io-loop"));
this.aloop = new CustomAcceptLoop(new InetSocketAddress(ip, port));
for (int i = 0; i < workerCount; i++) {
iloops.add(new CustomIOLoop());
}
}
/**
* Start the server IO loops and kicks off throughput tracking.
*/
public void start() {
// messages = new ThroughputTracker();
// bytes = new ThroughputTracker();
for (CustomIOLoop l : iloops) {
ipool.execute(l);
}
apool.execute(aloop);
//
// for (CustomIOLoop l : iloops)
// l.waitForStart(TIMEOUT);
// aloop.waitForStart(TIMEOUT);
}
/**
* Stop the server IO loops and freezes throughput tracking.
*/
public void stop() {
aloop.shutdown();
for (CustomIOLoop l : iloops) {
l.shutdown();
}
// for (CustomIOLoop l : iloops)
// l.waitForFinish(TIMEOUT);
// aloop.waitForFinish(TIMEOUT);
//
// messages.freeze();
// bytes.freeze();
}
/**
* Reports on the accumulated throughput trackers.
*/
public void report() {
// DecimalFormat f = new DecimalFormat("#,##0");
// log.info("{} messages; {} bytes; {} mps; {} Mbs",
// f.format(messages.total()),
// f.format(bytes.total()),
// f.format(messages.throughput()),
// f.format(bytes.throughput() / (1024 * 128)));
}
/**
* Prunes the IO loops of stale message buffers.
*/
public void prune() {
for (CustomIOLoop l : iloops) {
l.pruneStaleStreams();
}
}
// Get the next worker to which a client should be assigned
private synchronized CustomIOLoop nextWorker() {
lastWorker = (lastWorker + 1) % workerCount;
return iloops.get(lastWorker);
}
// Loop for transfer of fixed-length messages
private class CustomIOLoop extends IOLoop<TestMessage, TestMessageStream> {
public CustomIOLoop() throws IOException {
super(500);
}
@Override
protected TestMessageStream createStream(ByteChannel channel) {
return new TestMessageStream(msgLength, channel, this);
}
@Override
protected void removeStream(MessageStream<TestMessage> stream) {
super.removeStream(stream);
//
// messages.add(b.inMessages().total());
// bytes.add(b.inBytes().total());
//
// log.info("Disconnected client; inbound {} mps, {} Mbps; outbound {} mps, {} Mbps",
// format.format(b.inMessages().throughput()),
// format.format(b.inBytes().throughput() / (1024 * 128)),
// format.format(b.outMessages().throughput()),
// format.format(b.outBytes().throughput() / (1024 * 128)));
}
@Override
protected void processMessages(List<TestMessage> messages,
MessageStream<TestMessage> stream) {
try {
stream.write(messages);
} catch (IOException e) {
log.error("Unable to echo messages", e);
}
}
}
// Loop for accepting client connections
private class CustomAcceptLoop extends AcceptorLoop {
public CustomAcceptLoop(SocketAddress address) throws IOException {
super(500, address);
}
@Override
protected void acceptConnection(ServerSocketChannel channel) throws IOException {
SocketChannel sc = channel.accept();
sc.configureBlocking(false);
Socket so = sc.socket();
so.setTcpNoDelay(SO_NO_DELAY);
so.setReceiveBufferSize(SO_RCV_BUFFER_SIZE);
so.setSendBufferSize(SO_SEND_BUFFER_SIZE);
nextWorker().acceptStream(sc);
log.info("Connected client");
}
}
}
package org.onlab.nio;
/**
* Fixed-length message.
*/
public class TestMessage extends AbstractMessage {
private final byte[] data;
/**
* Creates a new message with the specified length.
*
* @param length message length
*/
public TestMessage(int length) {
this.length = length;
data = new byte[length];
}
/**
* Creates a new message with the specified data.
*
* @param data message data
*/
TestMessage(byte[] data) {
this.length = data.length;
this.data = data;
}
/**
* Gets the backing byte array data.
*
* @return backing byte array
*/
public byte[] data() {
return data;
}
}
package org.onlab.nio;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
/**
* Fixed-length message transfer buffer.
*/
public class TestMessageStream extends MessageStream<TestMessage> {
private static final String E_WRONG_LEN = "Illegal message length: ";
private final int length;
/**
* Create a new buffer for transferring messages of the specified length.
*
* @param length message length
* @param ch backing channel
* @param loop driver loop
*/
public TestMessageStream(int length, ByteChannel ch,
IOLoop<TestMessage, ?> loop) {
super(loop, ch, 64 * 1024, 500);
this.length = length;
}
@Override
protected TestMessage read(ByteBuffer rb) {
if (rb.remaining() < length) {
return null;
}
TestMessage message = new TestMessage(length);
rb.get(message.data());
return message;
}
/**
* {@inheritDoc}
* <p/>
* This implementation enforces the message length against the buffer
* supported length.
*
* @throws IllegalArgumentException if message size does not match the
* supported buffer size
*/
@Override
protected void write(TestMessage message, ByteBuffer wb) {
if (message.length() != length) {
throw new IllegalArgumentException(E_WRONG_LEN + message.length());
}
wb.put(message.data());
}
}