diff --git a/bramble-core/src/main/java/org/briarproject/bramble/connection/ConnectionRegistryImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/connection/ConnectionRegistryImpl.java index 0f580a4cd..3563c4850 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/connection/ConnectionRegistryImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/connection/ConnectionRegistryImpl.java @@ -1,6 +1,5 @@ package org.briarproject.bramble.connection; -import org.briarproject.bramble.api.Multiset; import org.briarproject.bramble.api.Pair; import org.briarproject.bramble.api.connection.ConnectionRegistry; import org.briarproject.bramble.api.contact.ContactId; @@ -22,6 +21,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Set; import java.util.logging.Logger; @@ -41,22 +41,30 @@ class ConnectionRegistryImpl implements ConnectionRegistry { getLogger(ConnectionRegistryImpl.class.getName()); private final EventBus eventBus; - private final List> preferences; + private final Map> betterTransports; private final Object lock = new Object(); @GuardedBy("lock") - private final Map> contactConnections; - @GuardedBy("lock") - private final Multiset contactCounts; + private final Map> contactConnections; @GuardedBy("lock") private final Set connectedPendingContacts; @Inject ConnectionRegistryImpl(EventBus eventBus, PluginConfig pluginConfig) { this.eventBus = eventBus; - preferences = pluginConfig.getTransportPreferences(); + betterTransports = new HashMap<>(); + for (Pair pair : + pluginConfig.getTransportPreferences()) { + TransportId better = pair.getFirst(); + TransportId worse = pair.getSecond(); + List list = betterTransports.get(worse); + if (list == null) { + list = new ArrayList<>(); + betterTransports.put(worse, list); + } + list.add(better); + } contactConnections = new HashMap<>(); - contactCounts = new Multiset<>(); connectedPendingContacts = new HashSet<>(); } @@ -69,13 +77,13 @@ class ConnectionRegistryImpl implements ConnectionRegistry { } boolean firstConnection = false; synchronized (lock) { - Multiset m = contactConnections.get(t); - if (m == null) { - m = new Multiset<>(); - contactConnections.put(t, m); + List recs = contactConnections.get(c); + if (recs == null) { + recs = new ArrayList<>(); + contactConnections.put(c, recs); } - m.add(c); - if (contactCounts.add(c) == 1) firstConnection = true; + if (recs.isEmpty()) firstConnection = true; + recs.add(new ConnectionRecord(t)); } eventBus.broadcast(new ConnectionOpenedEvent(c, t, incoming)); if (firstConnection) { @@ -93,11 +101,10 @@ class ConnectionRegistryImpl implements ConnectionRegistry { } boolean lastConnection = false; synchronized (lock) { - Multiset m = contactConnections.get(t); - if (m == null || !m.contains(c)) + List recs = contactConnections.get(c); + if (recs == null || !recs.remove(new ConnectionRecord(t))) throw new IllegalArgumentException(); - m.remove(c); - if (contactCounts.remove(c) == 0) lastConnection = true; + if (recs.isEmpty()) lastConnection = true; } eventBus.broadcast(new ConnectionClosedEvent(c, t, incoming)); if (lastConnection) { @@ -109,12 +116,20 @@ class ConnectionRegistryImpl implements ConnectionRegistry { @Override public Collection getConnectedContacts(TransportId t) { synchronized (lock) { - Multiset m = contactConnections.get(t); - if (m == null) return emptyList(); - List ids = new ArrayList<>(m.keySet()); - if (LOG.isLoggable(INFO)) - LOG.info(ids.size() + " contacts connected: " + t); - return ids; + List contactIds = new ArrayList<>(); + for (Entry> e : + contactConnections.entrySet()) { + for (ConnectionRecord rec : e.getValue()) { + if (rec.transportId.equals(t)) { + contactIds.add(e.getKey()); + break; + } + } + } + if (LOG.isLoggable(INFO)) { + LOG.info(contactIds.size() + " contacts connected: " + t); + } + return contactIds; } } @@ -122,34 +137,43 @@ class ConnectionRegistryImpl implements ConnectionRegistry { public Collection getConnectedOrPreferredContacts( TransportId t) { synchronized (lock) { - Multiset m = contactConnections.get(t); - if (m == null) return emptyList(); - Set ids = new HashSet<>(m.keySet()); - for (Pair pair : preferences) { - if (pair.getSecond().equals(t)) { - TransportId better = pair.getFirst(); - Multiset m1 = contactConnections.get(better); - if (m1 != null) ids.addAll(m1.keySet()); + List better = betterTransports.get(t); + if (better == null) better = emptyList(); + List contactIds = new ArrayList<>(); + for (Entry> e : + contactConnections.entrySet()) { + for (ConnectionRecord rec : e.getValue()) { + if (rec.transportId.equals(t) || + better.contains(rec.transportId)) { + contactIds.add(e.getKey()); + break; + } } } - if (LOG.isLoggable(INFO)) - LOG.info(ids.size() + " contacts connected or preferred: " + t); - return ids; + if (LOG.isLoggable(INFO)) { + LOG.info(contactIds.size() + + " contacts connected or preferred: " + t); + } + return contactIds; } } @Override public boolean isConnected(ContactId c, TransportId t) { synchronized (lock) { - Multiset m = contactConnections.get(t); - return m != null && m.contains(c); + List recs = contactConnections.get(c); + if (recs == null) return false; + for (ConnectionRecord rec : recs) { + if (rec.transportId.equals(t)) return true; + } + return false; } } @Override public boolean isConnected(ContactId c) { synchronized (lock) { - return contactCounts.contains(c); + return contactConnections.containsKey(c); } } @@ -171,4 +195,27 @@ class ConnectionRegistryImpl implements ConnectionRegistry { } eventBus.broadcast(new RendezvousConnectionClosedEvent(p, success)); } + + private static class ConnectionRecord { + + private final TransportId transportId; + + private ConnectionRecord(TransportId transportId) { + this.transportId = transportId; + } + + @Override + public boolean equals(Object o) { + if (o instanceof ConnectionRecord) { + ConnectionRecord rec = (ConnectionRecord) o; + return transportId.equals(rec.transportId); + } + return false; + } + + @Override + public int hashCode() { + return transportId.hashCode(); + } + } } diff --git a/bramble-core/src/test/java/org/briarproject/bramble/connection/ConnectionRegistryImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/connection/ConnectionRegistryImplTest.java index 51da97ae2..3b1dcf208 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/connection/ConnectionRegistryImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/connection/ConnectionRegistryImplTest.java @@ -19,7 +19,6 @@ import org.junit.Test; import java.util.Collection; import static java.util.Collections.emptyList; -import static java.util.Collections.emptyMap; import static java.util.Collections.singletonList; import static org.briarproject.bramble.test.TestUtils.getContactId; import static org.briarproject.bramble.test.TestUtils.getRandomId; @@ -45,7 +44,7 @@ public class ConnectionRegistryImplTest extends BrambleMockTestCase { public void testRegisterAndUnregister() { context.checking(new Expectations() {{ allowing(pluginConfig).getTransportPreferences(); - will(returnValue(emptyMap())); + will(returnValue(emptyList())); }}); ConnectionRegistry c = @@ -133,7 +132,7 @@ public class ConnectionRegistryImplTest extends BrambleMockTestCase { public void testRegisterAndUnregisterPendingContacts() { context.checking(new Expectations() {{ allowing(pluginConfig).getTransportPreferences(); - will(returnValue(emptyMap())); + will(returnValue(emptyList())); }}); ConnectionRegistry c =