Fix for race conditioned caused by OkHttpClient NPE.

We catch OkHttpClient exceptions to deal with bugs in their code,
but in some cases that was leaving our state information in a
bad situation.

// FREEBIE
This commit is contained in:
Moxie Marlinspike 2015-02-22 22:49:55 -08:00
parent 8a2caeef3d
commit 83d65228e9
4 changed files with 162 additions and 135 deletions

View File

@ -51,7 +51,7 @@ public class TextSecureMessagePipe {
}
}
public void shutdown() throws IOException {
public void shutdown() {
websocket.disconnect();
}

View File

@ -0,0 +1,142 @@
package org.whispersystems.textsecure.internal.websocket;
import android.util.Log;
import com.squareup.okhttp.OkHttpClient;
import com.squareup.okhttp.Request;
import com.squareup.okhttp.Response;
import com.squareup.okhttp.internal.ws.WebSocket;
import com.squareup.okhttp.internal.ws.WebSocketListener;
import org.whispersystems.textsecure.api.push.TrustStore;
import org.whispersystems.textsecure.api.util.CredentialsProvider;
import org.whispersystems.textsecure.internal.util.BlacklistingTrustManager;
import org.whispersystems.textsecure.internal.util.Util;
import java.io.IOException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import okio.Buffer;
import okio.BufferedSource;
public class OkHttpClientWrapper implements WebSocketListener {
private static final String TAG = OkHttpClientWrapper.class.getSimpleName();
private final String uri;
private final TrustStore trustStore;
private final CredentialsProvider credentialsProvider;
private final WebSocketEventListener listener;
private WebSocket webSocket;
private boolean closed;
private boolean connected;
public OkHttpClientWrapper(String uri, TrustStore trustStore,
CredentialsProvider credentialsProvider,
WebSocketEventListener listener)
{
Log.w(TAG, "Connecting to: " + uri);
this.uri = uri;
this.trustStore = trustStore;
this.credentialsProvider = credentialsProvider;
this.listener = listener;
}
public void connect() {
new Thread() {
@Override
public void run() {
int attempt = 0;
while ((webSocket = newSocket()) != null) {
try {
Response response = webSocket.connect(OkHttpClientWrapper.this);
if (response.code() == 101) {
synchronized (OkHttpClientWrapper.this) {
if (closed) webSocket.close(1000, "OK");
else connected = true;
}
listener.onConnected();
return;
}
Log.w(TAG, "WebSocket Response: " + response.code());
} catch (IOException e) {
Log.w(TAG, e);
}
Util.sleep(Math.min(++attempt * 200, TimeUnit.SECONDS.toMillis(15)));
}
}
}.start();
}
public synchronized void disconnect() {
Log.w(TAG, "Calling disconnect()...");
try {
closed = true;
if (webSocket != null && connected) {
webSocket.close(1000, "OK");
}
} catch (IOException e) {
Log.w(TAG, e);
}
}
public void sendMessage(byte[] message) throws IOException {
webSocket.sendMessage(WebSocket.PayloadType.BINARY, new Buffer().write(message));
}
@Override
public void onMessage(BufferedSource payload, WebSocket.PayloadType type) throws IOException {
Log.w(TAG, "onMessage: " + type);
if (type.equals(WebSocket.PayloadType.BINARY)) {
listener.onMessage(payload.readByteArray());
}
payload.close();
}
@Override
public void onClose(int code, String reason) {
Log.w(TAG, String.format("onClose(%d, %s)", code, reason));
listener.onClose();
}
@Override
public void onFailure(IOException e) {
Log.w(TAG, e);
listener.onClose();
}
private synchronized WebSocket newSocket() {
if (closed) return null;
String filledUri = String.format(uri, credentialsProvider.getUser(), credentialsProvider.getPassword());
SSLSocketFactory socketFactory = createTlsSocketFactory(trustStore);
return WebSocket.newWebSocket(new OkHttpClient().setSslSocketFactory(socketFactory),
new Request.Builder().url(filledUri).build());
}
private SSLSocketFactory createTlsSocketFactory(TrustStore trustStore) {
try {
SSLContext context = SSLContext.getInstance("TLS");
context.init(null, BlacklistingTrustManager.createFor(trustStore), null);
return context.getSocketFactory();
} catch (NoSuchAlgorithmException | KeyManagementException e) {
throw new AssertionError(e);
}
}
}

View File

@ -3,36 +3,22 @@ package org.whispersystems.textsecure.internal.websocket;
import android.util.Log;
import com.google.protobuf.InvalidProtocolBufferException;
import com.squareup.okhttp.OkHttpClient;
import com.squareup.okhttp.Request;
import com.squareup.okhttp.Response;
import com.squareup.okhttp.internal.ws.WebSocket;
import com.squareup.okhttp.internal.ws.WebSocketListener;
import org.whispersystems.textsecure.api.push.TrustStore;
import org.whispersystems.textsecure.api.util.CredentialsProvider;
import org.whispersystems.textsecure.internal.util.BlacklistingTrustManager;
import org.whispersystems.textsecure.internal.util.Util;
import java.io.IOException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.util.LinkedList;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import okio.Buffer;
import okio.BufferedSource;
import static org.whispersystems.textsecure.internal.websocket.WebSocketProtos.WebSocketMessage;
import static org.whispersystems.textsecure.internal.websocket.WebSocketProtos.WebSocketRequestMessage;
import static org.whispersystems.textsecure.internal.websocket.WebSocketProtos.WebSocketResponseMessage;
public class WebSocketConnection {
public class WebSocketConnection implements WebSocketEventListener {
private static final String TAG = WebSocketConnection.class.getSimpleName();
@ -42,7 +28,7 @@ public class WebSocketConnection {
private final TrustStore trustStore;
private final CredentialsProvider credentialsProvider;
private Client client;
private OkHttpClientWrapper client;
private KeepAliveSender keepAliveSender;
public WebSocketConnection(String httpUri, TrustStore trustStore, CredentialsProvider credentialsProvider) {
@ -56,12 +42,12 @@ public class WebSocketConnection {
Log.w(TAG, "WSC connect()...");
if (client == null) {
client = new Client(wsUri, trustStore, credentialsProvider);
client = new OkHttpClientWrapper(wsUri, trustStore, credentialsProvider, this);
client.connect();
}
}
public synchronized void disconnect() throws IOException {
public synchronized void disconnect() {
Log.w(TAG, "WSC disconnect()...");
if (client != null) {
@ -119,7 +105,7 @@ public class WebSocketConnection {
}
}
private synchronized void onMessage(byte[] payload) {
public synchronized void onMessage(byte[] payload) {
Log.w(TAG, "WSC onMessage()");
try {
WebSocketMessage message = WebSocketMessage.parseFrom(payload);
@ -136,10 +122,11 @@ public class WebSocketConnection {
}
}
private synchronized void onClose() {
public synchronized void onClose() {
Log.w(TAG, "onClose()...");
if (client != null) {
client.disconnect();
client = null;
connect();
}
@ -152,8 +139,8 @@ public class WebSocketConnection {
notifyAll();
}
private synchronized void onConnected() {
if (client != null) {
public synchronized void onConnected() {
if (client != null && keepAliveSender == null) {
keepAliveSender = new KeepAliveSender();
keepAliveSender.start();
}
@ -163,117 +150,6 @@ public class WebSocketConnection {
return System.currentTimeMillis() - startTime;
}
private class Client implements WebSocketListener {
private final String uri;
private final TrustStore trustStore;
private final CredentialsProvider credentialsProvider;
private WebSocket webSocket;
private boolean closed;
public Client(String uri, TrustStore trustStore, CredentialsProvider credentialsProvider) {
Log.w(TAG, "Connecting to: " + uri);
this.uri = uri;
this.trustStore = trustStore;
this.credentialsProvider = credentialsProvider;
}
public void connect() {
new Thread() {
@Override
public void run() {
int attempt = 0;
while (newSocket()) {
try {
Response response;
try {
response = webSocket.connect(Client.this);
} catch (IllegalStateException e) {
throw new IOException(e);
}
if (response.code() == 101) {
onConnected();
return;
}
Log.w(TAG, "WebSocket Response: " + response.code());
} catch (IOException e) {
Log.w(TAG, e);
}
Util.sleep(Math.min(++attempt * 200, TimeUnit.SECONDS.toMillis(15)));
}
}
}.start();
}
public synchronized void disconnect() {
Log.w(TAG, "Calling disconnect()...");
try {
closed = true;
if (webSocket != null) {
webSocket.close(1000, "OK");
}
} catch (IOException e) {
Log.w(TAG, e);
}
}
public void sendMessage(byte[] message) throws IOException {
webSocket.sendMessage(WebSocket.PayloadType.BINARY, new Buffer().write(message));
}
@Override
public void onMessage(BufferedSource payload, WebSocket.PayloadType type) throws IOException {
Log.w(TAG, "onMessage: " + type);
if (type.equals(WebSocket.PayloadType.BINARY)) {
WebSocketConnection.this.onMessage(payload.readByteArray());
}
payload.close();
}
@Override
public void onClose(int code, String reason) {
Log.w(TAG, String.format("onClose(%d, %s)", code, reason));
WebSocketConnection.this.onClose();
}
@Override
public void onFailure(IOException e) {
Log.w(TAG, e);
WebSocketConnection.this.onClose();
}
private synchronized boolean newSocket() {
if (closed) return false;
String filledUri = String.format(uri, credentialsProvider.getUser(), credentialsProvider.getPassword());
SSLSocketFactory socketFactory = createTlsSocketFactory(trustStore);
this.webSocket = WebSocket.newWebSocket(new OkHttpClient().setSslSocketFactory(socketFactory),
new Request.Builder().url(filledUri).build());
return true;
}
private SSLSocketFactory createTlsSocketFactory(TrustStore trustStore) {
try {
SSLContext context = SSLContext.getInstance("TLS");
context.init(null, BlacklistingTrustManager.createFor(trustStore), null);
return context.getSocketFactory();
} catch (NoSuchAlgorithmException | KeyManagementException e) {
throw new AssertionError(e);
}
}
}
private class KeepAliveSender extends Thread {
private AtomicBoolean stop = new AtomicBoolean(false);

View File

@ -0,0 +1,9 @@
package org.whispersystems.textsecure.internal.websocket;
public interface WebSocketEventListener {
public void onMessage(byte[] payload);
public void onClose();
public void onConnected();
}