浏览代码

Add method to return both IP and port for SRV DNS lookup requests

Closes gh-9030
Kathryn Newbould 4 年之前
父节点
当前提交
2af322c06d

+ 38 - 4
remoting/src/main/java/org/springframework/security/remoting/dns/JndiDnsResolver.java

@@ -61,16 +61,28 @@ public class JndiDnsResolver implements DnsResolver {
 
 
 	@Override
 	@Override
 	public String resolveServiceEntry(String serviceType, String domain) {
 	public String resolveServiceEntry(String serviceType, String domain) {
-		return resolveServiceEntry(serviceType, domain, this.ctxFactory.getCtx());
+		return resolveServiceEntry(serviceType, domain, this.ctxFactory.getCtx()).getHostName();
 	}
 	}
 
 
 	@Override
 	@Override
 	public String resolveServiceIpAddress(String serviceType, String domain) {
 	public String resolveServiceIpAddress(String serviceType, String domain) {
 		DirContext ctx = this.ctxFactory.getCtx();
 		DirContext ctx = this.ctxFactory.getCtx();
-		String hostname = resolveServiceEntry(serviceType, domain, ctx);
+		String hostname = resolveServiceEntry(serviceType, domain, ctx).getHostName();
 		return resolveIpAddress(hostname, ctx);
 		return resolveIpAddress(hostname, ctx);
 	}
 	}
 
 
+	/**
+	 * @author Kathryn Newbould
+	 * @since 5.4.1
+	 * @return String of ip address and port, format [ip_address]:[port] of service if found
+	 * @throws DnsLookupException if not found
+	 */
+	public String resolveServiceIpAddressAndPort(String serviceType, String domain) {
+		DirContext ctx = this.ctxFactory.getCtx();
+		ConnectionInfo hostInfo = resolveServiceEntry(serviceType, domain, ctx);
+		return resolveIpAddress(hostInfo.getHostName(), ctx) + ":" + hostInfo.getPort();
+	}
+
 	// This method is needed, so that we can use only one DirContext for
 	// This method is needed, so that we can use only one DirContext for
 	// resolveServiceIpAddress().
 	// resolveServiceIpAddress().
 	private String resolveIpAddress(String hostname, DirContext ctx) {
 	private String resolveIpAddress(String hostname, DirContext ctx) {
@@ -88,8 +100,9 @@ public class JndiDnsResolver implements DnsResolver {
 
 
 	// This method is needed, so that we can use only one DirContext for
 	// This method is needed, so that we can use only one DirContext for
 	// resolveServiceIpAddress().
 	// resolveServiceIpAddress().
-	private String resolveServiceEntry(String serviceType, String domain, DirContext ctx) {
+	private ConnectionInfo resolveServiceEntry(String serviceType, String domain, DirContext ctx) {
 		String result = null;
 		String result = null;
+		String port = null;
 		try {
 		try {
 			String query = new StringBuilder("_").append(serviceType).append("._tcp.").append(domain).toString();
 			String query = new StringBuilder("_").append(serviceType).append("._tcp.").append(domain).toString();
 			Attribute dnsRecord = lookup(query, ctx, "SRV");
 			Attribute dnsRecord = lookup(query, ctx, "SRV");
@@ -107,15 +120,18 @@ public class JndiDnsResolver implements DnsResolver {
 				int priority = Integer.parseInt(record[0]);
 				int priority = Integer.parseInt(record[0]);
 				int weight = Integer.parseInt(record[1]);
 				int weight = Integer.parseInt(record[1]);
 				// we have a new highest Priority, so forget also the highest weight
 				// we have a new highest Priority, so forget also the highest weight
+				int SERVICE_RECORD_PORT_INDEX = 2;
 				if (priority < highestPriority || highestPriority == -1) {
 				if (priority < highestPriority || highestPriority == -1) {
 					highestPriority = priority;
 					highestPriority = priority;
 					highestWeight = weight;
 					highestWeight = weight;
 					result = record[3].trim();
 					result = record[3].trim();
+					port = record[SERVICE_RECORD_PORT_INDEX].trim();
 				}
 				}
 				// same priority, but higher weight
 				// same priority, but higher weight
 				if (priority == highestPriority && weight > highestWeight) {
 				if (priority == highestPriority && weight > highestWeight) {
 					highestWeight = weight;
 					highestWeight = weight;
 					result = record[3].trim();
 					result = record[3].trim();
+					port = record[SERVICE_RECORD_PORT_INDEX].trim();
 				}
 				}
 			}
 			}
 		}
 		}
@@ -126,7 +142,7 @@ public class JndiDnsResolver implements DnsResolver {
 		if (result.endsWith(".")) {
 		if (result.endsWith(".")) {
 			result = result.substring(0, result.length() - 1);
 			result = result.substring(0, result.length() - 1);
 		}
 		}
-		return result;
+		return new ConnectionInfo(result, port);
 	}
 	}
 
 
 	private Attribute lookup(String query, DirContext ictx, String recordType) {
 	private Attribute lookup(String query, DirContext ictx, String recordType) {
@@ -159,4 +175,22 @@ public class JndiDnsResolver implements DnsResolver {
 
 
 	}
 	}
 
 
+	private static class ConnectionInfo {
+		private final String hostName;
+		private final String port;
+
+		public ConnectionInfo(String hostName, String port) {
+			this.hostName = hostName;
+			this.port = port;
+		}
+
+		public String getHostName() {
+			return hostName;
+		}
+
+		public String getPort() {
+			return port;
+		}
+	}
+
 }
 }

+ 10 - 0
remoting/src/test/java/org/springframework/security/remoting/dns/JndiDnsResolverTests.java

@@ -95,6 +95,16 @@ public class JndiDnsResolverTests {
 		assertThat(ipAddress).isEqualTo("63.246.7.80");
 		assertThat(ipAddress).isEqualTo("63.246.7.80");
 	}
 	}
 
 
+	@Test
+	public void resolveServiceIpAddressWithPort() throws Exception {
+		BasicAttributes srvRecords = createSrvRecords();
+		BasicAttributes aRecords = new BasicAttributes("A", "63.246.7.80");
+		given(this.context.getAttributes("_ldap._tcp.springsource.com", new String[] { "SRV" })).willReturn(srvRecords);
+		given(this.context.getAttributes("kdc.springsource.com", new String[] { "A" })).willReturn(aRecords);
+		String ipAddress = this.dnsResolver.resolveServiceIpAddressAndPort("ldap", "springsource.com");
+		assertThat(ipAddress).isEqualTo("63.246.7.80:389");
+	}
+
 	@Test
 	@Test
 	public void testUnknowError() throws Exception {
 	public void testUnknowError() throws Exception {
 		given(this.context.getAttributes(any(String.class), any(String[].class)))
 		given(this.context.getAttributes(any(String.class), any(String[].class)))