From 83d65228e980a80a4c64298440608c4b93b138b5 Mon Sep 17 00:00:00 2001 From: Moxie Marlinspike Date: Sun, 22 Feb 2015 22:49:55 -0800 Subject: [PATCH] Fix for race conditioned caused by OkHttpClient NPE. We catch OkHttpClient exceptions to deal with bugs in their code, but in some cases that was leaving our state information in a bad situation. // FREEBIE --- .../textsecure/api/TextSecureMessagePipe.java | 2 +- .../websocket/OkHttpClientWrapper.java | 142 +++++++++++++++++ .../websocket/WebSocketConnection.java | 144 ++---------------- .../websocket/WebSocketEventListener.java | 9 ++ 4 files changed, 162 insertions(+), 135 deletions(-) create mode 100644 libtextsecure/src/main/java/org/whispersystems/textsecure/internal/websocket/OkHttpClientWrapper.java create mode 100644 libtextsecure/src/main/java/org/whispersystems/textsecure/internal/websocket/WebSocketEventListener.java diff --git a/libtextsecure/src/main/java/org/whispersystems/textsecure/api/TextSecureMessagePipe.java b/libtextsecure/src/main/java/org/whispersystems/textsecure/api/TextSecureMessagePipe.java index 64a7f7408c..6036f0dee5 100644 --- a/libtextsecure/src/main/java/org/whispersystems/textsecure/api/TextSecureMessagePipe.java +++ b/libtextsecure/src/main/java/org/whispersystems/textsecure/api/TextSecureMessagePipe.java @@ -51,7 +51,7 @@ public class TextSecureMessagePipe { } } - public void shutdown() throws IOException { + public void shutdown() { websocket.disconnect(); } diff --git a/libtextsecure/src/main/java/org/whispersystems/textsecure/internal/websocket/OkHttpClientWrapper.java b/libtextsecure/src/main/java/org/whispersystems/textsecure/internal/websocket/OkHttpClientWrapper.java new file mode 100644 index 0000000000..419ad407af --- /dev/null +++ b/libtextsecure/src/main/java/org/whispersystems/textsecure/internal/websocket/OkHttpClientWrapper.java @@ -0,0 +1,142 @@ +package org.whispersystems.textsecure.internal.websocket; + +import android.util.Log; + +import com.squareup.okhttp.OkHttpClient; +import com.squareup.okhttp.Request; +import com.squareup.okhttp.Response; +import com.squareup.okhttp.internal.ws.WebSocket; +import com.squareup.okhttp.internal.ws.WebSocketListener; + +import org.whispersystems.textsecure.api.push.TrustStore; +import org.whispersystems.textsecure.api.util.CredentialsProvider; +import org.whispersystems.textsecure.internal.util.BlacklistingTrustManager; +import org.whispersystems.textsecure.internal.util.Util; + +import java.io.IOException; +import java.security.KeyManagementException; +import java.security.NoSuchAlgorithmException; +import java.util.concurrent.TimeUnit; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; + +import okio.Buffer; +import okio.BufferedSource; + +public class OkHttpClientWrapper implements WebSocketListener { + + private static final String TAG = OkHttpClientWrapper.class.getSimpleName(); + + private final String uri; + private final TrustStore trustStore; + private final CredentialsProvider credentialsProvider; + private final WebSocketEventListener listener; + + private WebSocket webSocket; + private boolean closed; + private boolean connected; + + public OkHttpClientWrapper(String uri, TrustStore trustStore, + CredentialsProvider credentialsProvider, + WebSocketEventListener listener) + { + Log.w(TAG, "Connecting to: " + uri); + + this.uri = uri; + this.trustStore = trustStore; + this.credentialsProvider = credentialsProvider; + this.listener = listener; + } + + public void connect() { + new Thread() { + @Override + public void run() { + int attempt = 0; + + while ((webSocket = newSocket()) != null) { + try { + Response response = webSocket.connect(OkHttpClientWrapper.this); + + if (response.code() == 101) { + synchronized (OkHttpClientWrapper.this) { + if (closed) webSocket.close(1000, "OK"); + else connected = true; + } + + listener.onConnected(); + return; + } + + Log.w(TAG, "WebSocket Response: " + response.code()); + } catch (IOException e) { + Log.w(TAG, e); + } + + Util.sleep(Math.min(++attempt * 200, TimeUnit.SECONDS.toMillis(15))); + } + } + }.start(); + } + + public synchronized void disconnect() { + Log.w(TAG, "Calling disconnect()..."); + try { + closed = true; + if (webSocket != null && connected) { + webSocket.close(1000, "OK"); + } + } catch (IOException e) { + Log.w(TAG, e); + } + } + + public void sendMessage(byte[] message) throws IOException { + webSocket.sendMessage(WebSocket.PayloadType.BINARY, new Buffer().write(message)); + } + + @Override + public void onMessage(BufferedSource payload, WebSocket.PayloadType type) throws IOException { + Log.w(TAG, "onMessage: " + type); + if (type.equals(WebSocket.PayloadType.BINARY)) { + listener.onMessage(payload.readByteArray()); + } + + payload.close(); + } + + @Override + public void onClose(int code, String reason) { + Log.w(TAG, String.format("onClose(%d, %s)", code, reason)); + listener.onClose(); + } + + @Override + public void onFailure(IOException e) { + Log.w(TAG, e); + listener.onClose(); + } + + private synchronized WebSocket newSocket() { + if (closed) return null; + + String filledUri = String.format(uri, credentialsProvider.getUser(), credentialsProvider.getPassword()); + SSLSocketFactory socketFactory = createTlsSocketFactory(trustStore); + + return WebSocket.newWebSocket(new OkHttpClient().setSslSocketFactory(socketFactory), + new Request.Builder().url(filledUri).build()); + } + + private SSLSocketFactory createTlsSocketFactory(TrustStore trustStore) { + try { + SSLContext context = SSLContext.getInstance("TLS"); + context.init(null, BlacklistingTrustManager.createFor(trustStore), null); + + return context.getSocketFactory(); + } catch (NoSuchAlgorithmException | KeyManagementException e) { + throw new AssertionError(e); + } + } + +} diff --git a/libtextsecure/src/main/java/org/whispersystems/textsecure/internal/websocket/WebSocketConnection.java b/libtextsecure/src/main/java/org/whispersystems/textsecure/internal/websocket/WebSocketConnection.java index 425a8a9dc7..25b18695e4 100644 --- a/libtextsecure/src/main/java/org/whispersystems/textsecure/internal/websocket/WebSocketConnection.java +++ b/libtextsecure/src/main/java/org/whispersystems/textsecure/internal/websocket/WebSocketConnection.java @@ -3,36 +3,22 @@ package org.whispersystems.textsecure.internal.websocket; import android.util.Log; import com.google.protobuf.InvalidProtocolBufferException; -import com.squareup.okhttp.OkHttpClient; -import com.squareup.okhttp.Request; -import com.squareup.okhttp.Response; -import com.squareup.okhttp.internal.ws.WebSocket; -import com.squareup.okhttp.internal.ws.WebSocketListener; import org.whispersystems.textsecure.api.push.TrustStore; import org.whispersystems.textsecure.api.util.CredentialsProvider; -import org.whispersystems.textsecure.internal.util.BlacklistingTrustManager; import org.whispersystems.textsecure.internal.util.Util; import java.io.IOException; -import java.security.KeyManagementException; -import java.security.NoSuchAlgorithmException; import java.util.LinkedList; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; -import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLSocketFactory; - -import okio.Buffer; -import okio.BufferedSource; - import static org.whispersystems.textsecure.internal.websocket.WebSocketProtos.WebSocketMessage; import static org.whispersystems.textsecure.internal.websocket.WebSocketProtos.WebSocketRequestMessage; import static org.whispersystems.textsecure.internal.websocket.WebSocketProtos.WebSocketResponseMessage; -public class WebSocketConnection { +public class WebSocketConnection implements WebSocketEventListener { private static final String TAG = WebSocketConnection.class.getSimpleName(); @@ -42,8 +28,8 @@ public class WebSocketConnection { private final TrustStore trustStore; private final CredentialsProvider credentialsProvider; - private Client client; - private KeepAliveSender keepAliveSender; + private OkHttpClientWrapper client; + private KeepAliveSender keepAliveSender; public WebSocketConnection(String httpUri, TrustStore trustStore, CredentialsProvider credentialsProvider) { this.trustStore = trustStore; @@ -56,12 +42,12 @@ public class WebSocketConnection { Log.w(TAG, "WSC connect()..."); if (client == null) { - client = new Client(wsUri, trustStore, credentialsProvider); + client = new OkHttpClientWrapper(wsUri, trustStore, credentialsProvider, this); client.connect(); } } - public synchronized void disconnect() throws IOException { + public synchronized void disconnect() { Log.w(TAG, "WSC disconnect()..."); if (client != null) { @@ -119,7 +105,7 @@ public class WebSocketConnection { } } - private synchronized void onMessage(byte[] payload) { + public synchronized void onMessage(byte[] payload) { Log.w(TAG, "WSC onMessage()"); try { WebSocketMessage message = WebSocketMessage.parseFrom(payload); @@ -136,10 +122,11 @@ public class WebSocketConnection { } } - private synchronized void onClose() { + public synchronized void onClose() { Log.w(TAG, "onClose()..."); if (client != null) { + client.disconnect(); client = null; connect(); } @@ -152,8 +139,8 @@ public class WebSocketConnection { notifyAll(); } - private synchronized void onConnected() { - if (client != null) { + public synchronized void onConnected() { + if (client != null && keepAliveSender == null) { keepAliveSender = new KeepAliveSender(); keepAliveSender.start(); } @@ -163,117 +150,6 @@ public class WebSocketConnection { return System.currentTimeMillis() - startTime; } - private class Client implements WebSocketListener { - - private final String uri; - private final TrustStore trustStore; - private final CredentialsProvider credentialsProvider; - - private WebSocket webSocket; - private boolean closed; - - public Client(String uri, TrustStore trustStore, CredentialsProvider credentialsProvider) { - Log.w(TAG, "Connecting to: " + uri); - - this.uri = uri; - this.trustStore = trustStore; - this.credentialsProvider = credentialsProvider; - } - - public void connect() { - new Thread() { - @Override - public void run() { - int attempt = 0; - - while (newSocket()) { - try { - Response response; - - try { - response = webSocket.connect(Client.this); - } catch (IllegalStateException e) { - throw new IOException(e); - } - - if (response.code() == 101) { - onConnected(); - return; - } - - Log.w(TAG, "WebSocket Response: " + response.code()); - } catch (IOException e) { - Log.w(TAG, e); - } - - Util.sleep(Math.min(++attempt * 200, TimeUnit.SECONDS.toMillis(15))); - } - } - }.start(); - } - - public synchronized void disconnect() { - Log.w(TAG, "Calling disconnect()..."); - try { - closed = true; - if (webSocket != null) { - webSocket.close(1000, "OK"); - } - } catch (IOException e) { - Log.w(TAG, e); - } - } - - public void sendMessage(byte[] message) throws IOException { - webSocket.sendMessage(WebSocket.PayloadType.BINARY, new Buffer().write(message)); - } - - @Override - public void onMessage(BufferedSource payload, WebSocket.PayloadType type) throws IOException { - Log.w(TAG, "onMessage: " + type); - if (type.equals(WebSocket.PayloadType.BINARY)) { - WebSocketConnection.this.onMessage(payload.readByteArray()); - } - - payload.close(); - } - - @Override - public void onClose(int code, String reason) { - Log.w(TAG, String.format("onClose(%d, %s)", code, reason)); - WebSocketConnection.this.onClose(); - } - - @Override - public void onFailure(IOException e) { - Log.w(TAG, e); - WebSocketConnection.this.onClose(); - } - - private synchronized boolean newSocket() { - if (closed) return false; - - String filledUri = String.format(uri, credentialsProvider.getUser(), credentialsProvider.getPassword()); - SSLSocketFactory socketFactory = createTlsSocketFactory(trustStore); - - this.webSocket = WebSocket.newWebSocket(new OkHttpClient().setSslSocketFactory(socketFactory), - new Request.Builder().url(filledUri).build()); - - return true; - } - - private SSLSocketFactory createTlsSocketFactory(TrustStore trustStore) { - try { - SSLContext context = SSLContext.getInstance("TLS"); - context.init(null, BlacklistingTrustManager.createFor(trustStore), null); - - return context.getSocketFactory(); - } catch (NoSuchAlgorithmException | KeyManagementException e) { - throw new AssertionError(e); - } - } - } - private class KeepAliveSender extends Thread { private AtomicBoolean stop = new AtomicBoolean(false); diff --git a/libtextsecure/src/main/java/org/whispersystems/textsecure/internal/websocket/WebSocketEventListener.java b/libtextsecure/src/main/java/org/whispersystems/textsecure/internal/websocket/WebSocketEventListener.java new file mode 100644 index 0000000000..0e87b9417e --- /dev/null +++ b/libtextsecure/src/main/java/org/whispersystems/textsecure/internal/websocket/WebSocketEventListener.java @@ -0,0 +1,9 @@ +package org.whispersystems.textsecure.internal.websocket; + +public interface WebSocketEventListener { + + public void onMessage(byte[] payload); + public void onClose(); + public void onConnected(); + +}