diff --git a/agents-common/src/main/java/org/apache/ranger/plugin/client/BaseClient.java b/agents-common/src/main/java/org/apache/ranger/plugin/client/BaseClient.java index dbbebbf367..bab9643853 100644 --- a/agents-common/src/main/java/org/apache/ranger/plugin/client/BaseClient.java +++ b/agents-common/src/main/java/org/apache/ranger/plugin/client/BaseClient.java @@ -31,6 +31,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.regex.PatternSyntaxException; public abstract class BaseClient { private static final Logger LOG = LoggerFactory.getLogger(BaseClient.class); @@ -184,6 +185,168 @@ private void init() { } } + protected void validateSqlIdentifier(String identifier, String identifierType) throws HadoopException { + if (StringUtils.isBlank(identifier)) { + return; + } + if (identifier.contains("..") || identifier.contains("//") || identifier.contains("\\")) { + String msgDesc = "Invalid " + identifierType + ": [" + identifier + "]. Path traversal patterns are not allowed."; + HadoopException hdpException = new HadoopException(msgDesc); + hdpException.generateResponseDataMap(false, msgDesc, msgDesc + DEFAULT_ERROR_MESSAGE, null, null); + LOG.error(msgDesc); + throw hdpException; + } + if (!identifier.matches("^[a-zA-Z0-9*?\\[\\]\\-\\$%\\{\\}\\=\\/\\._]+$")) { + String msgDesc = "Invalid " + identifierType + ": [" + identifier + "]. Only alphanumeric characters along with ( ., _, -, *, ?, [], {}, %, $, = / ) are allowed."; + HadoopException hdpException = new HadoopException(msgDesc); + hdpException.generateResponseDataMap(false, msgDesc, msgDesc + DEFAULT_ERROR_MESSAGE, null, null); + LOG.error(msgDesc); + throw hdpException; + } + } + + protected String convertToSqlPattern(String pattern) throws HadoopException { + if (pattern == null || pattern.isEmpty()) { + return "%"; + } + // Convert custom wildcards to SQL LIKE pattern: + // '*' -> '%' (multi-character wildcard) + // '?' 
-> '_' (single-character wildcard) + String sqlPattern = pattern.replace("*", "%").replace("?", "_"); + return sqlPattern; + } + + protected boolean matchesSqlPattern(String value, String pattern) throws HadoopException { + if (pattern == null || pattern.equals("%")) { + return true; + } + + String regex = convertSqlPatternToRegex(pattern); + try { + return value.matches(regex); + } catch (PatternSyntaxException pe) { + String msgDesc = "Invalid value: [" + value + "]. Only alphanumeric characters along with ( ., _, -, *, ?, [], {}, %, $, = / ) are allowed."; + HadoopException hdpException = new HadoopException(msgDesc); + hdpException.generateResponseDataMap(false, msgDesc, msgDesc + DEFAULT_ERROR_MESSAGE, null, null); + LOG.error(msgDesc); + throw hdpException; + } + } + + protected void validateUrlResourceName(String resourceName, String resourceType) throws HadoopException { + if (resourceName == null) { + return; + } + if (resourceName.contains("..") || resourceName.contains("//") || resourceName.contains("\\")) { + String msgDesc = "Invalid " + resourceType + ": [" + resourceName + "]. Path traversal patterns are not allowed."; + HadoopException hdpException = new HadoopException(msgDesc); + hdpException.generateResponseDataMap(false, msgDesc, msgDesc + DEFAULT_ERROR_MESSAGE, null, null); + LOG.error(msgDesc); + throw hdpException; + } + if (!resourceName.matches("^[a-zA-Z0-9_.*\\-]+$")) { + String msgDesc = "Invalid " + resourceType + ": [" + resourceName + "]. 
Only alphanumeric characters with ( ., _, *, -) are allowed."; + HadoopException hdpException = new HadoopException(msgDesc); + hdpException.generateResponseDataMap(false, msgDesc, msgDesc + DEFAULT_ERROR_MESSAGE, null, null); + LOG.error(msgDesc); + throw hdpException; + } + } + + public void validateWildcardPattern(String pattern, String patternType) throws HadoopException { + if (pattern == null || pattern.isEmpty()) { + return; + } + if (pattern.contains("..") || pattern.contains("//") || pattern.contains("\\")) { + String msgDesc = "Invalid " + patternType + ": [" + pattern + "]. Path traversal patterns are not allowed."; + HadoopException hdpException = new HadoopException(msgDesc); + hdpException.generateResponseDataMap(false, msgDesc, msgDesc + DEFAULT_ERROR_MESSAGE, null, null); + LOG.error(msgDesc); + throw hdpException; + } + if (!pattern.matches("^[a-zA-Z0-9_.*?\\[\\]\\-\\$%\\{\\}\\=\\/]+$")) { + String msgDesc = "Invalid " + patternType + ": [" + pattern + "]. Only alphanumeric characters along with ( ., _, -, *, ?, [], {}, %, $, = / ) are allowed."; + HadoopException hdpException = new HadoopException(msgDesc); + hdpException.generateResponseDataMap(false, msgDesc, msgDesc + DEFAULT_ERROR_MESSAGE, null, null); + LOG.error(msgDesc); + throw hdpException; + } + } + + protected String convertSqlPatternToRegex(String pattern) { + StringBuilder regexBuilder = new StringBuilder("^"); + + for (int i = 0; i < pattern.length(); i++) { + char c = pattern.charAt(i); + switch (c) { + case '%': + // SQL LIKE wildcard: zero or more characters + regexBuilder.append(".*"); + break; + case '_': + // SQL LIKE wildcard: exactly one character + regexBuilder.append('.'); + break; + case '.': + case '^': + case '$': + case '+': + case '?': + case '{': + case '}': + case '[': + case ']': + case '(': + case ')': + case '|': + case '\\': + // Escape regex metacharacters so they are treated literally + regexBuilder.append('\\').append(c); + break; + default: + 
regexBuilder.append(c); + break; + } + } + + return regexBuilder.toString(); + } + + public String convertWildcardToRegex(String wildcard) { + if (wildcard == null || wildcard.isEmpty()) { + return ".*"; + } + StringBuilder regex = new StringBuilder("^"); + for (int i = 0; i < wildcard.length(); i++) { + char c = wildcard.charAt(i); + switch (c) { + case '*': + regex.append(".*"); + break; + case '?': + regex.append("."); + break; + case '.': + case '\\': + case '^': + case '$': + case '|': + regex.append('\\').append(c); + break; + case '{': + case '}': + case '[': + case ']': + regex.append('\\').append(c); + break; + default: + regex.append(c); + } + } + regex.append('$'); + return regex.toString(); + } + private HadoopException createException(Exception exp) { return createException("Unable to login to Hadoop environment [" + serviceName + "]", exp); } diff --git a/agents-common/src/test/java/org/apache/ranger/plugin/client/TestBaseClient.java b/agents-common/src/test/java/org/apache/ranger/plugin/client/TestBaseClient.java index 6d27ef920b..758f5ae53c 100644 --- a/agents-common/src/test/java/org/apache/ranger/plugin/client/TestBaseClient.java +++ b/agents-common/src/test/java/org/apache/ranger/plugin/client/TestBaseClient.java @@ -326,4 +326,114 @@ class TestClient extends BaseClient { assertEquals(IllegalArgumentException.class, ex.getClass()); } } + + @Test + public void test15_convertWildcardToRegex() { + class TestClient extends BaseClient { + TestClient() { + super("test", new HashMap<>()); + } + + @Override + protected void login() { + } + + public String convert(String s) { + return convertWildcardToRegex(s); + } + } + + TestClient client = new TestClient(); + assertEquals(".*", client.convert(null)); + assertEquals(".*", client.convert("")); + assertEquals("^atlas.*$", client.convert("atlas*")); + assertEquals("^atlas\\..*$", client.convert("atlas.*")); + assertEquals("^.*atlas.*$", client.convert("*atlas*")); + assertEquals("^at.as$", 
client.convert("at?as")); + assertEquals("^atlas\\.$", client.convert("atlas.")); + assertEquals("^atlas\\$$", client.convert("atlas$")); + assertEquals("^atlas\\^$", client.convert("atlas^")); + assertEquals("^atlas\\[\\]$", client.convert("atlas[]")); + } + + @Test + public void test16_convertToSqlPattern() throws Exception { + class TestClient extends BaseClient { + TestClient() { + super("test", new HashMap<>()); + } + + @Override + protected void login() { + } + + public String convert(String s) throws Exception { + return convertToSqlPattern(s); + } + } + + TestClient client = new TestClient(); + assertEquals("%", client.convert(null)); + assertEquals("%", client.convert("")); + assertEquals("atlas%", client.convert("atlas*")); + assertEquals("at_as", client.convert("at?as")); + } + + @Test + public void test17_matchesSqlPattern() throws Exception { + class TestClient extends BaseClient { + TestClient() { + super("test", new HashMap<>()); + } + + @Override + protected void login() { + } + + public boolean match(String v, String p) throws Exception { + return matchesSqlPattern(v, p); + } + } + + TestClient client = new TestClient(); + assertEquals(true, client.match("atlas", null)); + assertEquals(true, client.match("atlas", "%")); + assertEquals(true, client.match("atlas", "atlas%")); + assertEquals(true, client.match("atlas_test", "atlas%")); + assertEquals(true, client.match("atlas", "at_as")); + assertEquals(false, client.match("atlas", "at_a")); + } + + @Test + public void test18_validateWildcardPattern() { + class TestClient extends BaseClient { + TestClient() { + super("test", new HashMap<>()); + } + + @Override + protected void login() { + } + + public void validate(String s) throws Exception { + validateWildcardPattern(s, "test"); + } + } + + TestClient client = new TestClient(); + try { + client.validate("atlas*"); + client.validate("atlas.*"); + client.validate("atlas?"); + } catch (Exception e) { + org.junit.jupiter.api.Assertions.fail("Should not 
throw exception for valid patterns"); + } + + try { + client.validate("atlas../test"); + org.junit.jupiter.api.Assertions.fail("Should throw exception for path traversal"); + } catch (Exception e) { + // Expected + } + } } diff --git a/hbase-agent/src/main/java/org/apache/ranger/services/hbase/client/HBaseClient.java b/hbase-agent/src/main/java/org/apache/ranger/services/hbase/client/HBaseClient.java index 1b1100189a..ee2630aad5 100644 --- a/hbase-agent/src/main/java/org/apache/ranger/services/hbase/client/HBaseClient.java +++ b/hbase-agent/src/main/java/org/apache/ranger/services/hbase/client/HBaseClient.java @@ -191,6 +191,12 @@ public List getTableList(final String tableNameMatching, final List>() { @Override public List run() { + String wildcard = tableNameMatching; + if (wildcard != null) { + wildcard = wildcard.replace(".*", "*"); + } + validateWildcardPattern(wildcard, "table pattern"); + String safeTablePattern = convertWildcardToRegex(wildcard); List tableList = new ArrayList<>(); Admin admin = null; @@ -205,8 +211,7 @@ public List run() { LOG.info("getTableList: no exception: HbaseAvailability true"); admin = conn.getAdmin(); - - List htds = admin.listTableDescriptors(Pattern.compile(tableNameMatching)); + List htds = admin.listTableDescriptors(Pattern.compile(safeTablePattern)); if (htds != null) { for (TableDescriptor htd : htds) { @@ -240,6 +245,8 @@ public List run() { LOG.error(msgDesc + mnre); throw hdpException; + } catch (HadoopException he) { + throw he; } catch (IOException io) { String msgDesc = "getTableList: Unable to get HBase table List for [repository:" + getConfigHolder().getDatasourceName() + ",table-match:" + tableNameMatching + "]."; HadoopException hdpException = new HadoopException(msgDesc, io); @@ -291,14 +298,18 @@ public List getColumnFamilyList(final String columnFamilyMatching, final @Override public List run() { + String wildcard = columnFamilyMatching; + if (wildcard != null) { + wildcard = wildcard.replace(".*", "*"); + } + 
validateWildcardPattern(wildcard, "column family pattern"); + String safeColumnPattern = convertWildcardToRegex(wildcard); List colfList = new ArrayList<>(); Admin admin = null; try { LOG.info("getColumnFamilyList: setting config values from client"); - setClientConfigValues(conf); - LOG.info("getColumnFamilyList: checking HbaseAvailability with the new config"); try (Connection conn = ConnectionFactory.createConnection(conf)) { @@ -314,8 +325,7 @@ public List run() { if (htd != null) { for (ColumnFamilyDescriptor hcd : htd.getColumnFamilies()) { String colf = hcd.getNameAsString(); - - if (colf.matches(columnFamilyMatching)) { + if (colf.matches(safeColumnPattern)) { if (existingColumnFamilies != null && existingColumnFamilies.contains(colf)) { continue; } else { @@ -345,6 +355,8 @@ public List run() { LOG.error(msgDesc + mnre); throw hdpException; + } catch (HadoopException he) { + throw he; } catch (IOException io) { String msgDesc = "getColumnFamilyList: Unable to get HBase ColumnFamilyList for [repository:" + getConfigHolder().getDatasourceName() + ",table:" + tblName + ", table-match:" + columnFamilyMatching + "] "; HadoopException hdpException = new HadoopException(msgDesc, io); diff --git a/hbase-agent/src/test/java/org/apache/ranger/services/hbase/client/TestHBaseClient.java b/hbase-agent/src/test/java/org/apache/ranger/services/hbase/client/TestHBaseClient.java index 4edf5f993b..7c2f858bb0 100644 --- a/hbase-agent/src/test/java/org/apache/ranger/services/hbase/client/TestHBaseClient.java +++ b/hbase-agent/src/test/java/org/apache/ranger/services/hbase/client/TestHBaseClient.java @@ -314,11 +314,11 @@ public void test09_getColumnFamilyList_filtersAndExceptions() throws Exception { List tables = new ArrayList<>(Collections.singletonList("t1")); List existing = new ArrayList<>(Collections.singletonList("cf1")); - List ret = client.getColumnFamilyList("cf.*", tables, existing); + List ret = client.getColumnFamilyList("cf*", tables, existing); 
assertEquals(Collections.singletonList("cf2"), ret); Mockito.when(admin.getDescriptor(tn)).thenThrow(new IOException("io")); - assertThrows(HadoopException.class, () -> client.getColumnFamilyList("cf.*", tables, null)); + assertThrows(HadoopException.class, () -> client.getColumnFamilyList("cf*", tables, null)); } } @@ -588,6 +588,291 @@ private ColumnFamilyDescriptor mockCfd(String name) { return cfd; } + @Test + public void test19_validatePattern_validWildcards() throws Exception { + Map props = new HashMap<>(); + props.put("username", "user"); + + try (MockedStatic confStatic = Mockito.mockStatic(HBaseConfiguration.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class); + MockedStatic connFactoryStatic = Mockito.mockStatic(ConnectionFactory.class)) { + Configuration conf = Mockito.mock(Configuration.class); + confStatic.when(HBaseConfiguration::create).thenReturn(conf); + + Subject subject = Mockito.mock(Subject.class); + subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + Connection connection = Mockito.mock(Connection.class); + Admin admin = Mockito.mock(Admin.class); + connFactoryStatic.when(() -> ConnectionFactory.createConnection(Mockito.any(Configuration.class))) + .thenReturn(connection); + Mockito.when(connection.getAdmin()).thenReturn(admin); + Mockito.when(admin.listTableDescriptors(Mockito.any(Pattern.class))).thenReturn(new ArrayList<>()); + + HBaseClient client = new TestableHBaseClient("svc", props, subject); + + List validPatterns = Arrays.asList("user*", "test?", "table_name", "prefix-*", "test{user}"); + for (String pattern : validPatterns) { + List result = client.getTableList(pattern, null); + assertNotNull(result, "Valid pattern should not throw exception: " + pattern); + } + } + } + + @Test + public void test19_getTableList_wildcardReplacement() throws Exception { + Map props = new 
HashMap<>(); + props.put("username", "user"); + + try (MockedStatic confStatic = Mockito.mockStatic(HBaseConfiguration.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class); + MockedStatic connFactoryStatic = Mockito.mockStatic(ConnectionFactory.class)) { + Configuration conf = Mockito.mock(Configuration.class); + confStatic.when(HBaseConfiguration::create).thenReturn(conf); + + Subject subject = Mockito.mock(Subject.class); + subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + Connection connection = Mockito.mock(Connection.class); + Admin admin = Mockito.mock(Admin.class); + connFactoryStatic.when(() -> ConnectionFactory.createConnection(Mockito.any(Configuration.class))) + .thenReturn(connection); + Mockito.when(connection.getAdmin()).thenReturn(admin); + + TableDescriptor td1 = Mockito.mock(TableDescriptor.class); + TableName tn1 = TableName.valueOf("atlas_test"); + Mockito.when(td1.getTableName()).thenReturn(tn1); + + // We expect the pattern to be "^atlas.*$" because "atlas.*" should be converted to "atlas*" then to "^atlas.*$" + Mockito.when(admin.listTableDescriptors(Mockito.any(Pattern.class))).thenAnswer(inv -> { + Pattern p = inv.getArgument(0); + if (p.pattern().equals("^atlas.*$")) { + return Collections.singletonList(td1); + } + return Collections.emptyList(); + }); + + HBaseClient client = new TestableHBaseClient("svc", props, subject); + + List ret = client.getTableList("atlas.*", null); + assertEquals(Collections.singletonList("atlas_test"), ret); + } + } + + @Test + public void test20_getColumnFamilyList_wildcardReplacement() throws Exception { + Map props = new HashMap<>(); + props.put("username", "user"); + + try (MockedStatic confStatic = Mockito.mockStatic(HBaseConfiguration.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class); + MockedStatic connFactoryStatic = 
Mockito.mockStatic(ConnectionFactory.class)) { + Configuration conf = Mockito.mock(Configuration.class); + confStatic.when(HBaseConfiguration::create).thenReturn(conf); + + Subject subject = Mockito.mock(Subject.class); + subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + Connection connection = Mockito.mock(Connection.class); + Admin admin = Mockito.mock(Admin.class); + connFactoryStatic.when(() -> ConnectionFactory.createConnection(Mockito.any(Configuration.class))) + .thenReturn(connection); + Mockito.when(connection.getAdmin()).thenReturn(admin); + + TableDescriptor td = Mockito.mock(TableDescriptor.class); + TableName tn = TableName.valueOf("t1"); + Mockito.when(admin.getDescriptor(tn)).thenReturn(td); + ColumnFamilyDescriptor cfd1 = mockCfd("cf_test"); + Mockito.when(td.getColumnFamilies()).thenReturn(new ColumnFamilyDescriptor[] {cfd1}); + + HBaseClient client = new TestableHBaseClient("svc", props, subject); + + List tables = new ArrayList<>(Collections.singletonList("t1")); + // "cf.*" should be converted to "cf*" then to "^cf.*$" which matches "cf_test" + List ret = client.getColumnFamilyList("cf.*", tables, null); + assertEquals(Collections.singletonList("cf_test"), ret); + } + } + + @Test + public void test21_validatePattern_rejectsReDoSPatterns() throws Exception { + Map props = new HashMap<>(); + props.put("username", "user"); + + try (MockedStatic confStatic = Mockito.mockStatic(HBaseConfiguration.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Configuration conf = Mockito.mock(Configuration.class); + confStatic.when(HBaseConfiguration::create).thenReturn(conf); + + Subject subject = Mockito.mock(Subject.class); + subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + 
return action.run(); + }); + + HBaseClient client = new TestableHBaseClient("svc", props, subject); + + List redosPatterns = Arrays.asList("(a+)+", "(a|a)*", "(a+)+$", "a{100,200}", "(x+x+)+y"); + + for (String pattern : redosPatterns) { + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getTableList(pattern, null), + "ReDoS pattern should be rejected: " + pattern); + String msg = ex.getMessage(); + assertTrue(msg != null && msg.contains("Invalid") && msg.contains("Only alphanumeric"), "Error should indicate invalid pattern for: " + pattern + ", but got: " + msg); + } + } + } + + @Test + public void test22_validatePattern_rejectsComplexRegex() throws Exception { + Map props = new HashMap<>(); + props.put("username", "user"); + + try (MockedStatic confStatic = Mockito.mockStatic(HBaseConfiguration.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Configuration conf = Mockito.mock(Configuration.class); + confStatic.when(HBaseConfiguration::create).thenReturn(conf); + + Subject subject = Mockito.mock(Subject.class); + subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + HBaseClient client = new TestableHBaseClient("svc", props, subject); + + List maliciousPatterns = Arrays.asList("test(abc)", "a+b", "x|y", "$(command)"); + + for (String pattern : maliciousPatterns) { + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getTableList(pattern, null), + "Complex regex should be rejected: " + pattern); + assertTrue(ex.getMessage().contains("Invalid") && ex.getMessage().contains("Only alphanumeric"), "Error should indicate invalid pattern for: " + pattern); + } + } + } + + @Test + public void test23_validatePattern_rejectsInjectionAttempts() throws Exception { + Map props = new HashMap<>(); + props.put("username", "user"); + + try (MockedStatic confStatic = 
Mockito.mockStatic(HBaseConfiguration.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Configuration conf = Mockito.mock(Configuration.class); + confStatic.when(HBaseConfiguration::create).thenReturn(conf); + + Subject subject = Mockito.mock(Subject.class); + subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + HBaseClient client = new TestableHBaseClient("svc", props, subject); + + List injectionAttempts = Arrays.asList("'; DROP TABLE users; --", "../../../etc/passwd", "test", "table\nname", "test\0null"); + + for (String pattern : injectionAttempts) { + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getTableList(pattern, null), + "Injection attempt should be rejected: " + pattern); + assertTrue(ex.getMessage().contains("Invalid"), + "Error should indicate invalid pattern for: " + pattern); + } + } + } + + @Test + public void test24_columnFamilyMatching_rejectsReDoS() throws Exception { + Map props = new HashMap<>(); + props.put("username", "user"); + + try (MockedStatic confStatic = Mockito.mockStatic(HBaseConfiguration.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Configuration conf = Mockito.mock(Configuration.class); + confStatic.when(HBaseConfiguration::create).thenReturn(conf); + + Subject subject = Mockito.mock(Subject.class); + subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + HBaseClient client = new TestableHBaseClient("svc", props, subject); + + List tables = Arrays.asList("table1"); + List redosPatterns = Arrays.asList("(a+)+", "(x|x)*"); + + for (String pattern : redosPatterns) { + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getColumnFamilyList(pattern, 
tables, null), + "ReDoS pattern should be rejected in column family: " + pattern); + assertTrue(ex.getMessage().contains("Invalid") && ex.getMessage().contains("Only alphanumeric"), + "Error should indicate invalid pattern for: " + pattern); + } + } + } + + @Test + public void test25_convertWildcardToRegex_correctConversion() throws Exception { + Map props = new HashMap<>(); + props.put("username", "user"); + + try (MockedStatic confStatic = Mockito.mockStatic(HBaseConfiguration.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class); + MockedStatic connFactoryStatic = Mockito.mockStatic(ConnectionFactory.class)) { + Configuration conf = Mockito.mock(Configuration.class); + confStatic.when(HBaseConfiguration::create).thenReturn(conf); + + Subject subject = Mockito.mock(Subject.class); + subjectStatic.when(() -> Subject.doAs(Mockito.any(), Mockito.any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + Connection connection = Mockito.mock(Connection.class); + Admin admin = Mockito.mock(Admin.class); + connFactoryStatic.when(() -> ConnectionFactory.createConnection(Mockito.any(Configuration.class))) + .thenReturn(connection); + Mockito.when(connection.getAdmin()).thenReturn(admin); + + TableDescriptor td1 = Mockito.mock(TableDescriptor.class); + TableName tn1 = TableName.valueOf("test_table"); + Mockito.when(td1.getTableName()).thenReturn(tn1); + Mockito.when(admin.listTableDescriptors(Mockito.any(Pattern.class))).thenAnswer(inv -> { + Pattern pattern = inv.getArgument(0); + List result = new ArrayList<>(); + if (pattern.matcher("test_table").matches()) { + result.add(td1); + } + return result; + }); + + HBaseClient client = new TestableHBaseClient("svc", props, subject); + + List result = client.getTableList("test*", null); + assertEquals(1, result.size()); + assertTrue(result.contains("test_table")); + } + } + private static class TestableHBaseClient extends HBaseClient { 
private final Subject testSubject; diff --git a/hive-agent/src/main/java/org/apache/ranger/services/hive/client/HiveClient.java b/hive-agent/src/main/java/org/apache/ranger/services/hive/client/HiveClient.java index aa40b977c6..f739fef921 100644 --- a/hive-agent/src/main/java/org/apache/ranger/services/hive/client/HiveClient.java +++ b/hive-agent/src/main/java/org/apache/ranger/services/hive/client/HiveClient.java @@ -261,6 +261,8 @@ private List getDBListFromHM(String databaseMatching, List dbLis List ret = new ArrayList<>(); + validateSqlIdentifier(databaseMatching, "database pattern"); + try { if (hiveClient != null) { List hiveDBList; @@ -303,20 +305,15 @@ private List getDBList(String databaseMatching, List dbList) thr List ret = new ArrayList<>(); if (con != null) { - Statement stat = null; - ResultSet rs = null; - String sql = "show databases"; - - if (databaseMatching != null && !databaseMatching.isEmpty()) { - sql = sql + " like \"" + databaseMatching + "\""; - } + ResultSet rs = null; try { - stat = con.createStatement(); - rs = stat.executeQuery(sql); + validateSqlIdentifier(databaseMatching, "database pattern"); + String schemaPattern = convertToSqlPattern(databaseMatching); + rs = con.getMetaData().getSchemas(null, schemaPattern); while (rs.next()) { - String dbName = rs.getString(1); + String dbName = rs.getString("TABLE_SCHEM"); if (dbList != null && dbList.contains(dbName)) { continue; @@ -325,7 +322,7 @@ private List getDBList(String databaseMatching, List dbList) thr ret.add(dbName); } } catch (SQLTimeoutException sqlt) { - String msgDesc = "Time Out, Unable to execute SQL [" + sql + "]."; + String msgDesc = "Time Out, Unable to retrieve database list."; HadoopException hdpException = new HadoopException(msgDesc, sqlt); hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null); @@ -334,7 +331,7 @@ private List getDBList(String databaseMatching, List dbList) thr throw hdpException; } catch (SQLException sqle) { - 
String msgDesc = "Unable to execute SQL [" + sql + "]."; + String msgDesc = "Unable to retrieve database list."; HadoopException hdpException = new HadoopException(msgDesc, sqle); hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null); @@ -342,9 +339,10 @@ private List getDBList(String databaseMatching, List dbList) thr LOG.debug("<== HiveClient.getDBList() Error : ", sqle); throw hdpException; + } catch (HadoopException he) { + throw he; } finally { close(rs); - close(stat); } } @@ -358,8 +356,11 @@ private List getTblListFromHM(String tableNameMatching, List dbL List ret = new ArrayList<>(); + validateSqlIdentifier(tableNameMatching, "table pattern"); + if (hiveClient != null && dbList != null && !dbList.isEmpty()) { for (String dbName : dbList) { + validateSqlIdentifier(dbName, "database name"); try { List hiveTblList = hiveClient.getTables(dbName, tableNameMatching); @@ -394,55 +395,31 @@ private List getTblList(String tableNameMatching, List dbList, L List ret = new ArrayList<>(); if (con != null) { - Statement stat = null; - ResultSet rs = null; - String sql = null; + ResultSet rs = null; try { + validateSqlIdentifier(tableNameMatching, "table pattern"); if (dbList != null && !dbList.isEmpty()) { for (String db : dbList) { - sql = "use " + db; - - try { - stat = con.createStatement(); + validateSqlIdentifier(db, "database name"); + String tablePattern = convertToSqlPattern(tableNameMatching); + rs = con.getMetaData().getTables(null, db, tablePattern, new String[] {"TABLE", "VIEW"}); - stat.execute(sql); - } finally { - close(stat); + while (rs.next()) { + String tblName = rs.getString("TABLE_NAME"); - stat = null; - } - - sql = "show tables "; - - if (tableNameMatching != null && !tableNameMatching.isEmpty()) { - sql = sql + " like \"" + tableNameMatching + "\""; - } - - try { - stat = con.createStatement(); - rs = stat.executeQuery(sql); - - while (rs.next()) { - String tblName = rs.getString(1); - - if (tblList != null 
&& tblList.contains(tblName)) { - continue; - } - - ret.add(tblName); + if (tblList != null && tblList.contains(tblName)) { + continue; } - } finally { - close(rs); - close(stat); - rs = null; - stat = null; + ret.add(tblName); } + close(rs); + rs = null; } } } catch (SQLTimeoutException sqlt) { - String msgDesc = "Time Out, Unable to execute SQL [" + sql + "]."; + String msgDesc = "Time Out, Unable to retrieve table list."; HadoopException hdpException = new HadoopException(msgDesc, sqlt); hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null); @@ -451,7 +428,7 @@ private List getTblList(String tableNameMatching, List dbList, L throw hdpException; } catch (SQLException sqle) { - String msgDesc = "Unable to execute SQL [" + sql + "]."; + String msgDesc = "Unable to retrieve table list."; HadoopException hdpException = new HadoopException(msgDesc, sqle); hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null); @@ -459,6 +436,10 @@ private List getTblList(String tableNameMatching, List dbList, L LOG.debug("<== HiveClient.getTblList() Error : ", sqle); throw hdpException; + } catch (HadoopException he) { + throw he; + } finally { + close(rs); } } @@ -473,13 +454,17 @@ private List getClmListFromHM(String columnNameMatching, List db List ret = new ArrayList<>(); String columnNameMatchingRegEx = null; + validateSqlIdentifier(columnNameMatching, "column pattern"); + if (columnNameMatching != null && !columnNameMatching.isEmpty()) { columnNameMatchingRegEx = columnNameMatching; } if (hiveClient != null && dbList != null && !dbList.isEmpty() && tblList != null && !tblList.isEmpty()) { for (String db : dbList) { + validateSqlIdentifier(db, "database name"); for (String tbl : tblList) { + validateSqlIdentifier(tbl, "table name"); try { List hiveSch = hiveClient.getFields(db, tbl); @@ -529,30 +514,20 @@ private List getClmList(String columnNameMatching, List dbList, columnNameMatchingRegEx = 
columnNameMatching; } - Statement stat = null; - ResultSet rs = null; - String sql = null; - - if (dbList != null && !dbList.isEmpty() && tblList != null && !tblList.isEmpty()) { - for (String db : dbList) { - for (String tbl : tblList) { - try { - sql = "use " + db; - - try { - stat = con.createStatement(); + ResultSet rs = null; - stat.execute(sql); - } finally { - close(stat); - } - - sql = "describe " + tbl; - stat = con.createStatement(); - rs = stat.executeQuery(sql); + try { + validateSqlIdentifier(columnNameMatching, "column pattern"); + if (dbList != null && !dbList.isEmpty() && tblList != null && !tblList.isEmpty()) { + for (String db : dbList) { + validateSqlIdentifier(db, "database name"); + for (String tbl : tblList) { + validateSqlIdentifier(tbl, "table name"); + String columnPattern = convertToSqlPattern(columnNameMatching); + rs = con.getMetaData().getColumns(null, db, tbl, columnPattern); while (rs.next()) { - String columnName = rs.getString(1); + String columnName = rs.getString("COLUMN_NAME"); if (colList != null && colList.contains(columnName)) { continue; @@ -564,30 +539,33 @@ private List getClmList(String columnNameMatching, List dbList, ret.add(columnName); } } - } catch (SQLTimeoutException sqlt) { - String msgDesc = "Time Out, Unable to execute SQL [" + sql + "]."; - HadoopException hdpException = new HadoopException(msgDesc, sqlt); + close(rs); + rs = null; + } + } + } + } catch (SQLTimeoutException sqlt) { + String msgDesc = "Time Out, Unable to retrieve column list."; + HadoopException hdpException = new HadoopException(msgDesc, sqlt); - hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null); + hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null); - LOG.debug("<== HiveClient.getClmList() Error : ", sqlt); + LOG.debug("<== HiveClient.getClmList() Error : ", sqlt); - throw hdpException; - } catch (SQLException sqle) { - String msgDesc = "Unable to execute SQL 
[" + sql + "]."; - HadoopException hdpException = new HadoopException(msgDesc, sqle); + throw hdpException; + } catch (SQLException sqle) { + String msgDesc = "Unable to retrieve column list."; + HadoopException hdpException = new HadoopException(msgDesc, sqle); - hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null); + hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null); - LOG.debug("<== HiveClient.getClmList() Error : ", sqle); + LOG.debug("<== HiveClient.getClmList() Error : ", sqle); - throw hdpException; - } finally { - close(rs); - close(stat); - } - } - } + throw hdpException; + } catch (HadoopException he) { + throw he; + } finally { + close(rs); } } diff --git a/hive-agent/src/test/java/org/apache/ranger/services/hive/client/TestHiveClient.java b/hive-agent/src/test/java/org/apache/ranger/services/hive/client/TestHiveClient.java index fc74b9599a..69bc3e33d8 100644 --- a/hive-agent/src/test/java/org/apache/ranger/services/hive/client/TestHiveClient.java +++ b/hive-agent/src/test/java/org/apache/ranger/services/hive/client/TestHiveClient.java @@ -41,12 +41,12 @@ import java.security.Permission; import java.security.PrivilegedExceptionAction; import java.sql.Connection; +import java.sql.DatabaseMetaData; import java.sql.Driver; import java.sql.DriverPropertyInfo; import java.sql.ResultSet; import java.sql.SQLFeatureNotSupportedException; import java.sql.SQLTimeoutException; -import java.sql.Statement; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -220,12 +220,12 @@ public void test11_getDatabaseList_jdbcPath() throws Exception { Field fCon = HiveClient.class.getDeclaredField("con"); fCon.setAccessible(true); Connection con = Mockito.mock(Connection.class); - Statement stat = Mockito.mock(Statement.class); + DatabaseMetaData metadata = Mockito.mock(DatabaseMetaData.class); ResultSet rs = Mockito.mock(ResultSet.class); - 
when(con.createStatement()).thenReturn(stat); - when(stat.executeQuery(Mockito.anyString())).thenReturn(rs); + when(con.getMetaData()).thenReturn(metadata); + when(metadata.getSchemas(Mockito.isNull(), Mockito.anyString())).thenReturn(rs); when(rs.next()).thenReturn(true, false); - when(rs.getString(1)).thenReturn("db1"); + when(rs.getString("TABLE_SCHEM")).thenReturn("db1"); fCon.set(client, con); List out = client.getDatabaseList("db*", null); assertEquals(Collections.singletonList("db1"), out); @@ -251,17 +251,15 @@ public void test13_getTableList_jdbcPath_excludesAndPattern() throws Exception { Field fCon = HiveClient.class.getDeclaredField("con"); fCon.setAccessible(true); Connection con = Mockito.mock(Connection.class); - Statement stat = Mockito.mock(Statement.class); + DatabaseMetaData metadata = Mockito.mock(DatabaseMetaData.class); ResultSet rs = Mockito.mock(ResultSet.class); - when(con.createStatement()).thenReturn(stat); - when(stat.execute(Mockito.eq("use db1"))).thenReturn(true); - when(stat.executeQuery(Mockito.anyString())).thenReturn(rs); + when(con.getMetaData()).thenReturn(metadata); + when(metadata.getTables(Mockito.isNull(), Mockito.eq("db1"), Mockito.anyString(), Mockito.any())).thenReturn(rs); when(rs.next()).thenReturn(true, true, false); - when(rs.getString(1)).thenReturn("t1", "t2"); + when(rs.getString("TABLE_NAME")).thenReturn("t1", "t2"); fCon.set(client, con); List out = client.getTableList("t*", Collections.singletonList("db1"), Collections.singletonList("t2")); assertEquals(Collections.singletonList("t1"), out); - Mockito.verify(stat, Mockito.atLeastOnce()).execute(Mockito.eq("use db1")); } @Test @@ -273,10 +271,9 @@ public void test14_getTableList_jdbcPath_timeoutThrowsHadoopException() throws E Field fCon = HiveClient.class.getDeclaredField("con"); fCon.setAccessible(true); Connection con = Mockito.mock(Connection.class); - Statement stat = Mockito.mock(Statement.class); - when(con.createStatement()).thenReturn(stat); - 
when(stat.execute(Mockito.anyString())).thenReturn(true); - when(stat.executeQuery(Mockito.anyString())).thenThrow(new SQLTimeoutException("timeout")); + DatabaseMetaData metadata = Mockito.mock(DatabaseMetaData.class); + when(con.getMetaData()).thenReturn(metadata); + when(metadata.getTables(Mockito.isNull(), Mockito.anyString(), Mockito.anyString(), Mockito.any())).thenThrow(new SQLTimeoutException("timeout")); fCon.set(client, con); assertThrows(HadoopException.class, () -> client.getTableList("t*", Collections.singletonList("db1"), null)); } @@ -290,13 +287,12 @@ public void test17_getClmList_jdbcPath_excludesAndPattern() throws Exception { Field fCon = HiveClient.class.getDeclaredField("con"); fCon.setAccessible(true); Connection con = Mockito.mock(Connection.class); - Statement stat = Mockito.mock(Statement.class); + DatabaseMetaData metadata = Mockito.mock(DatabaseMetaData.class); ResultSet rs = Mockito.mock(ResultSet.class); - when(con.createStatement()).thenReturn(stat); - when(stat.execute(Mockito.eq("use db1"))).thenReturn(true); - when(stat.executeQuery(Mockito.eq("describe t1"))).thenReturn(rs); + when(con.getMetaData()).thenReturn(metadata); + when(metadata.getColumns(Mockito.isNull(), Mockito.eq("db1"), Mockito.eq("t1"), Mockito.anyString())).thenReturn(rs); when(rs.next()).thenReturn(true, true, false); - when(rs.getString(1)).thenReturn("c1", "c2"); + when(rs.getString("COLUMN_NAME")).thenReturn("c1", "c2"); fCon.set(client, con); List out = client.getColumnList("c*", Collections.singletonList("db1"), Collections.singletonList("t1"), Collections.singletonList("c2")); assertEquals(Collections.singletonList("c1"), out); @@ -311,10 +307,9 @@ public void test18_getClmList_jdbcPath_timeoutThrowsHadoopException() throws Exc Field fCon = HiveClient.class.getDeclaredField("con"); fCon.setAccessible(true); Connection con = Mockito.mock(Connection.class); - Statement stat = Mockito.mock(Statement.class); - when(con.createStatement()).thenReturn(stat); - 
when(stat.execute(Mockito.anyString())).thenReturn(true); - when(stat.executeQuery(Mockito.anyString())).thenThrow(new SQLTimeoutException("timeout")); + DatabaseMetaData metadata = Mockito.mock(DatabaseMetaData.class); + when(con.getMetaData()).thenReturn(metadata); + when(metadata.getColumns(Mockito.isNull(), Mockito.anyString(), Mockito.anyString(), Mockito.anyString())).thenThrow(new SQLTimeoutException("timeout")); fCon.set(client, con); assertThrows(HadoopException.class, () -> client.getColumnList("c*", Collections.singletonList("db1"), Collections.singletonList("t1"), null)); } @@ -476,6 +471,146 @@ public void test25_initHive_nonKerberosPath_invokesJdbcInitConnectionAndWraps() } } + @Test + public void test26_validateSqlIdentifier_validInput() throws Exception { + NoopHiveClient client = new NoopHiveClient("svc", new HashMap<>()); + Method m = HiveClient.class.getSuperclass().getDeclaredMethod("validateSqlIdentifier", String.class, String.class); + m.setAccessible(true); + + m.invoke(client, "test_db123", "database"); + m.invoke(client, "table_name", "table"); + m.invoke(client, "col*", "column pattern"); + m.invoke(client, "db%", "database pattern"); + m.invoke(client, "a_b_c_123", "identifier"); + } + + @Test + public void test27_validateSqlIdentifier_sqlInjectionAttempts() throws Exception { + NoopHiveClient client = new NoopHiveClient("svc", new HashMap<>()); + Method m = HiveClient.class.getSuperclass().getDeclaredMethod("validateSqlIdentifier", String.class, String.class); + m.setAccessible(true); + + assertThrows(HadoopException.class, () -> { + try { + m.invoke(client, "test\" OR 1=1 --", "database"); + } catch (Exception e) { + throw e.getCause(); + } + }, "Should reject double quote injection"); + + assertThrows(HadoopException.class, () -> { + try { + m.invoke(client, "testdb; DROP TABLE users; --", "database"); + } catch (Exception e) { + throw e.getCause(); + } + }, "Should reject semicolon command injection"); + + 
assertThrows(HadoopException.class, () -> { + try { + m.invoke(client, "test'; DROP DATABASE production; --", "table"); + } catch (Exception e) { + throw e.getCause(); + } + }, "Should reject single quote command injection"); + + assertThrows(HadoopException.class, () -> { + try { + m.invoke(client, "test\n--malicious", "identifier"); + } catch (Exception e) { + throw e.getCause(); + } + }, "Should reject newline injection"); + + assertThrows(HadoopException.class, () -> { + try { + m.invoke(client, "test`malicious`", "identifier"); + } catch (Exception e) { + throw e.getCause(); + } + }, "Should reject backtick injection"); + + assertThrows(HadoopException.class, () -> { + try { + m.invoke(client, "test$(whoami)", "identifier"); + } catch (Exception e) { + throw e.getCause(); + } + }, "Should reject shell command injection"); + + assertThrows(HadoopException.class, () -> { + try { + m.invoke(client, "../../../etc/passwd", "identifier"); + } catch (Exception e) { + throw e.getCause(); + } + }, "Should reject path traversal"); + } + + @Test + public void test28_convertToSqlPattern_convertsWildcards() throws Exception { + NoopHiveClient client = new NoopHiveClient("svc", new HashMap<>()); + Method m = HiveClient.class.getSuperclass().getDeclaredMethod("convertToSqlPattern", String.class); + m.setAccessible(true); + + String result1 = (String) m.invoke(client, "test*"); + assertEquals("test%", result1, "Should convert * to %"); + + String result2 = (String) m.invoke(client, "*"); + assertEquals("%", result2, "Should convert single * to %"); + + String result3 = (String) m.invoke(client, "test*pattern*"); + assertEquals("test%pattern%", result3, "Should convert multiple *"); + + String result4 = (String) m.invoke(client, (Object) null); + assertEquals("%", result4, "Should handle null as %"); + + String result5 = (String) m.invoke(client, ""); + assertEquals("%", result5, "Should handle empty string as %"); + } + + @Test + public void 
test29_getDatabaseList_rejectsInjection() throws Exception { + NoopHiveClient client = new NoopHiveClient("svc", new HashMap<>()); + Field fFlag = HiveClient.class.getDeclaredField("enableHiveMetastoreLookup"); + fFlag.setAccessible(true); + fFlag.set(client, false); + Field fCon = HiveClient.class.getDeclaredField("con"); + fCon.setAccessible(true); + Connection con = Mockito.mock(Connection.class); + fCon.set(client, con); + + assertThrows(HadoopException.class, () -> client.getDatabaseList("testdb\"; DROP DATABASE production; --", null), "Should reject SQL injection in database pattern"); + } + + @Test + public void test30_getTableList_rejectsInjectionInDbName() throws Exception { + NoopHiveClient client = new NoopHiveClient("svc", new HashMap<>()); + Field fFlag = HiveClient.class.getDeclaredField("enableHiveMetastoreLookup"); + fFlag.setAccessible(true); + fFlag.set(client, false); + Field fCon = HiveClient.class.getDeclaredField("con"); + fCon.setAccessible(true); + Connection con = Mockito.mock(Connection.class); + fCon.set(client, con); + + assertThrows(HadoopException.class, () -> client.getTableList("valid", Collections.singletonList("testdb; DROP TABLE users; --"), null), "Should reject SQL injection in database name"); + } + + @Test + public void test31_getColumnList_rejectsInjectionInTableName() throws Exception { + NoopHiveClient client = new NoopHiveClient("svc", new HashMap<>()); + Field fFlag = HiveClient.class.getDeclaredField("enableHiveMetastoreLookup"); + fFlag.setAccessible(true); + fFlag.set(client, false); + Field fCon = HiveClient.class.getDeclaredField("con"); + fCon.setAccessible(true); + Connection con = Mockito.mock(Connection.class); + fCon.set(client, con); + + assertThrows(HadoopException.class, () -> client.getColumnList("valid", Collections.singletonList("db1"), Collections.singletonList("test'; DROP TABLE users; --"), null), "Should reject SQL injection in table name"); + } + public static class NoopHiveClient extends HiveClient { 
public boolean initCalled; diff --git a/knox-agent/src/main/java/org/apache/ranger/services/knox/client/KnoxClient.java b/knox-agent/src/main/java/org/apache/ranger/services/knox/client/KnoxClient.java index 7e06de932a..ec6d79e370 100644 --- a/knox-agent/src/main/java/org/apache/ranger/services/knox/client/KnoxClient.java +++ b/knox-agent/src/main/java/org/apache/ranger/services/knox/client/KnoxClient.java @@ -340,6 +340,7 @@ public List getServiceList(List knoxTopologyList, String service client.addFilter(new HTTPBasicAuthFilter(userName, decryptedPwd)); for (String topologyName : knoxTopologyList) { + validateResourceName(topologyName, "topology name"); WebResource webResource = client.resource(knoxUrl + "/" + topologyName); response = webResource.accept(EXPECTED_MIME_TYPE).get(ClientResponse.class); @@ -420,4 +421,32 @@ public List getServiceList(List knoxTopologyList, String service } return serviceList; } + + private void validateResourceName(String resourceName, String resourceType) { + if (resourceName == null) { + return; + } + + if (resourceName.contains("..") || resourceName.contains("//") || resourceName.contains("\\")) { + String msgDesc = "Invalid " + resourceType + ": [" + resourceName + "]. Path traversal patterns are not allowed."; + HadoopException hdpException = new HadoopException(msgDesc); + + hdpException.generateResponseDataMap(false, msgDesc, msgDesc + ERROR_MSG, null, null); + + LOG.error(msgDesc); + + throw hdpException; + } + + if (!resourceName.matches("^[a-zA-Z0-9_.*\\-]+$")) { + String msgDesc = "Invalid " + resourceType + ": [" + resourceName + "]. 
Only alphanumeric characters, dots, underscores, hyphens, and wildcards are allowed."; + HadoopException hdpException = new HadoopException(msgDesc); + + hdpException.generateResponseDataMap(false, msgDesc, msgDesc + ERROR_MSG, null, null); + + LOG.error(msgDesc); + + throw hdpException; + } + } } diff --git a/knox-agent/src/test/java/org/apache/ranger/services/knox/client/TestKnoxClient.java b/knox-agent/src/test/java/org/apache/ranger/services/knox/client/TestKnoxClient.java index e70f7b6914..8d3c7aa369 100644 --- a/knox-agent/src/test/java/org/apache/ranger/services/knox/client/TestKnoxClient.java +++ b/knox-agent/src/test/java/org/apache/ranger/services/knox/client/TestKnoxClient.java @@ -37,6 +37,7 @@ import java.io.ByteArrayOutputStream; import java.io.PrintStream; +import java.lang.reflect.Method; import java.security.Permission; import java.util.ArrayList; import java.util.Arrays; @@ -435,4 +436,85 @@ public void test20_main_validArgs_happyPath_printsServices() { System.setOut(origOut); } } + + @Test + public void test12_validateResourceName_rejectsPathTraversal() throws Exception { + KnoxClient client = new KnoxClient("https://localhost:8443/gateway/admin/api/v1/topologies", "admin", "pwd"); + + List pathTraversalInputs = Arrays.asList("../etc/passwd", "../../sensitive", "topology/../admin", "topology//malicious", "test\\windows\\path", "..\\..\\config"); + + for (String input : pathTraversalInputs) { + HadoopException ex = Assertions.assertThrows(HadoopException.class, + () -> invokeValidateResourceName(client, input, "topology name"), + "Path traversal should be rejected: " + input); + Assertions.assertTrue(ex.getMessage().contains("Path traversal"), + "Error should indicate path traversal for: " + input); + } + } + + @Test + public void test13_validateResourceName_rejectsSpecialCharacters() throws Exception { + KnoxClient client = new KnoxClient("https://localhost:8443/gateway/admin/api/v1/topologies", "admin", "pwd"); + + List invalidInputs = 
Arrays.asList("'; DROP TABLE users; --", "topology<script>", "test@topology", "topology#name", "test!topology", "topology&name", "topology(with)parens", "topology{with}braces", "topology[with]brackets", "topology$name", "topology%encoded", "topology name", "topology\ttab", "topology\nnewline", "topology;rm -rf /", "topology|cat /etc/passwd", "topology`whoami`", "topology$(whoami)"); + + for (String input : invalidInputs) { + HadoopException ex = Assertions.assertThrows(HadoopException.class, + () -> invokeValidateResourceName(client, input, "topology name"), "Special characters should be rejected: " + input); + Assertions.assertTrue(ex.getMessage().contains("Invalid"), "Error should indicate invalid input for: " + input); + } + } + + @Test + public void test14_validateResourceName_acceptsValidNames() throws Exception { + KnoxClient client = new KnoxClient("https://localhost:8443/gateway/admin/api/v1/topologies", "admin", "pwd"); + + List validInputs = Arrays.asList("topology", "topology_name", "topology123", "TOPOLOGY", "Topology_Name_123", "_topology", "topology_", "topology.name", "topology-name", "topology*", "top*"); + + for (String input : validInputs) { + try { + invokeValidateResourceName(client, input, "topology name"); + } catch (Exception e) { + throw new AssertionError("Valid topology name should not throw exception: " + input, e); + } + } + } + + @Test + public void test15_validateResourceName_rejectsNullByteInjection() throws Exception { + KnoxClient client = new KnoxClient("https://localhost:8443/gateway/admin/api/v1/topologies", "admin", "pwd"); + + HadoopException ex = Assertions.assertThrows(HadoopException.class, + () -> invokeValidateResourceName(client, "topology\0null", "topology name")); + Assertions.assertTrue(ex.getMessage().contains("Invalid")); + } + + @Test + public void test16_validateResourceName_rejectsUrlEncoded() throws Exception { + KnoxClient client = new KnoxClient("https://localhost:8443/gateway/admin/api/v1/topologies", "admin", "pwd"); 
+ + List encodedInputs = Arrays.asList("%2e%2e%2f", "topology%00", "test%20space"); + + for (String input : encodedInputs) { + HadoopException ex = Assertions.assertThrows(HadoopException.class, + () -> invokeValidateResourceName(client, input, "topology name"), + "URL encoded attack should be rejected: " + input); + Assertions.assertTrue(ex.getMessage().contains("Invalid"), + "Error should indicate invalid input for: " + input); + } + } + + private void invokeValidateResourceName(KnoxClient client, String resourceName, String resourceType) throws Exception { + Method method = KnoxClient.class.getDeclaredMethod("validateResourceName", String.class, String.class); + method.setAccessible(true); + try { + method.invoke(client, resourceName, resourceType); + } catch (java.lang.reflect.InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof HadoopException) { + throw (HadoopException) cause; + } + throw e; + } + } } diff --git a/plugin-elasticsearch/pom.xml b/plugin-elasticsearch/pom.xml index 32f6203fc1..0461cd6859 100644 --- a/plugin-elasticsearch/pom.xml +++ b/plugin-elasticsearch/pom.xml @@ -79,5 +79,23 @@ + + org.junit.jupiter + junit-jupiter + ${junit.jupiter.version} + test + + + org.mockito + mockito-inline + ${mockito.version} + test + + + org.mockito + mockito-junit-jupiter + ${mockito.version} + test + diff --git a/plugin-elasticsearch/src/main/java/org/apache/ranger/services/elasticsearch/client/ElasticsearchClient.java b/plugin-elasticsearch/src/main/java/org/apache/ranger/services/elasticsearch/client/ElasticsearchClient.java index d7802e8f7b..2a5e7aaf93 100644 --- a/plugin-elasticsearch/src/main/java/org/apache/ranger/services/elasticsearch/client/ElasticsearchClient.java +++ b/plugin-elasticsearch/src/main/java/org/apache/ranger/services/elasticsearch/client/ElasticsearchClient.java @@ -133,6 +133,7 @@ public List getIndexList(final String indexMatching, final List String indexApi; if 
(StringUtils.isNotEmpty(indexMatching)) { + validateUrlResourceName(indexMatching, "index pattern"); indexApi = '/' + indexMatching; if (!indexApi.endsWith("*")) { diff --git a/plugin-elasticsearch/src/test/java/org/apache/ranger/services/elasticsearch/client/TestElasticsearchClient.java b/plugin-elasticsearch/src/test/java/org/apache/ranger/services/elasticsearch/client/TestElasticsearchClient.java new file mode 100644 index 0000000000..46120eeb2c --- /dev/null +++ b/plugin-elasticsearch/src/test/java/org/apache/ranger/services/elasticsearch/client/TestElasticsearchClient.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.ranger.services.elasticsearch.client; + +import org.apache.ranger.plugin.client.HadoopException; +import org.junit.jupiter.api.MethodOrderer; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestMethodOrder; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@ExtendWith(MockitoExtension.class) +@TestMethodOrder(MethodOrderer.MethodName.class) +public class TestElasticsearchClient { + @Test + public void test01_validateUrlResourceName_rejectsPathTraversal() throws Exception { + Map configs = new HashMap<>(); + configs.put("elasticsearch.url", "http://localhost:9200"); + configs.put("username", "test"); + ElasticsearchClient client = new ElasticsearchClient("svc", configs); + + List pathTraversalInputs = Arrays.asList( + "../etc/passwd", + "../../sensitive", + "test/../admin", + "index//malicious", + "test\\windows\\path", + "..\\..\\config"); + + for (String input : pathTraversalInputs) { + HadoopException ex = assertThrows(HadoopException.class, + () -> invokeValidateUrlResourceName(client, input, "index pattern"), + "Path traversal should be rejected: " + input); + assertTrue(ex.getMessage().contains("Path traversal"), + "Error should indicate path traversal for: " + input); + } + } + + @Test + public void test02_validateUrlResourceName_rejectsSpecialCharacters() throws Exception { + Map configs = new HashMap<>(); + configs.put("elasticsearch.url", "http://localhost:9200"); + configs.put("username", "test"); + ElasticsearchClient client = new ElasticsearchClient("svc", configs); + + List invalidInputs = Arrays.asList( + "'; DROP TABLE users; --", + "index<script>", + "test@index", + "index#name", + "test!index", + 
"index&name", + "index(with)parens", + "index{with}braces", + "index[with]brackets", + "index$name", + "index%encoded", + "index name", + "index\ttab", + "index\nnewline", + "index;rm -rf /", + "index|cat /etc/passwd", + "index`whoami`", + "index$(whoami)"); + + for (String input : invalidInputs) { + HadoopException ex = assertThrows(HadoopException.class, + () -> invokeValidateUrlResourceName(client, input, "index pattern"), + "Special characters should be rejected: " + input); + assertTrue(ex.getMessage().contains("Invalid"), + "Error should indicate invalid input for: " + input); + } + } + + @Test + public void test03_validateUrlResourceName_acceptsValidNames() throws Exception { + Map configs = new HashMap<>(); + configs.put("elasticsearch.url", "http://localhost:9200"); + configs.put("username", "test"); + ElasticsearchClient client = new ElasticsearchClient("svc", configs); + + List validInputs = Arrays.asList( + "index", + "index_name", + "index123", + "INDEX", + "Index_Name_123", + "_index", + "index_", + "index.name", + "index-name", + "index*", + "idx*"); + + for (String input : validInputs) { + try { + invokeValidateUrlResourceName(client, input, "index pattern"); + } catch (Exception e) { + throw new AssertionError("Valid index name should not throw exception: " + input, e); + } + } + } + + @Test + public void test04_validateUrlResourceName_rejectsNullByteInjection() throws Exception { + Map configs = new HashMap<>(); + configs.put("elasticsearch.url", "http://localhost:9200"); + configs.put("username", "test"); + ElasticsearchClient client = new ElasticsearchClient("svc", configs); + + HadoopException ex = assertThrows(HadoopException.class, + () -> invokeValidateUrlResourceName(client, "index\0null", "index pattern")); + assertTrue(ex.getMessage().contains("Invalid")); + } + + @Test + public void test05_validateUrlResourceName_rejectsUrlEncoded() throws Exception { + Map configs = new HashMap<>(); + configs.put("elasticsearch.url", 
"http://localhost:9200"); + configs.put("username", "test"); + ElasticsearchClient client = new ElasticsearchClient("svc", configs); + + List encodedInputs = Arrays.asList( + "%2e%2e%2f", + "index%00", + "test%20space"); + + for (String input : encodedInputs) { + HadoopException ex = assertThrows(HadoopException.class, + () -> invokeValidateUrlResourceName(client, input, "index pattern"), + "URL encoded attack should be rejected: " + input); + assertTrue(ex.getMessage().contains("Invalid"), + "Error should indicate invalid input for: " + input); + } + } + + private void invokeValidateUrlResourceName(ElasticsearchClient client, String resourceName, String resourceType) throws Exception { + Method method = ElasticsearchClient.class.getSuperclass().getDeclaredMethod("validateUrlResourceName", String.class, String.class); + method.setAccessible(true); + try { + method.invoke(client, resourceName, resourceType); + } catch (java.lang.reflect.InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof HadoopException) { + throw (HadoopException) cause; + } + throw e; + } + } +} diff --git a/plugin-presto/pom.xml b/plugin-presto/pom.xml index 3a01427150..15431c74c9 100644 --- a/plugin-presto/pom.xml +++ b/plugin-presto/pom.xml @@ -107,6 +107,18 @@ ${junit.jupiter.version} test + + org.mockito + mockito-inline + ${mockito.version} + test + + + org.mockito + mockito-junit-jupiter + ${mockito.version} + test + diff --git a/plugin-presto/src/main/java/org/apache/ranger/services/presto/client/PrestoClient.java b/plugin-presto/src/main/java/org/apache/ranger/services/presto/client/PrestoClient.java index 2492dd718a..c034824907 100644 --- a/plugin-presto/src/main/java/org/apache/ranger/services/presto/client/PrestoClient.java +++ b/plugin-presto/src/main/java/org/apache/ranger/services/presto/client/PrestoClient.java @@ -19,7 +19,6 @@ package org.apache.ranger.services.presto.client; import org.apache.commons.io.FilenameUtils; -import 
org.apache.commons.lang3.StringUtils; import org.apache.ranger.plugin.client.BaseClient; import org.apache.ranger.plugin.client.HadoopConfigHolder; import org.apache.ranger.plugin.client.HadoopException; @@ -118,6 +117,8 @@ public List getSchemaList(String needle, List catalogs, List getCatalogs(String needle, List catalogs) throws Ha List ret = new ArrayList<>(); if (con != null) { - Statement stat = null; - ResultSet rs = null; - String sql = "SHOW CATALOGS"; + ResultSet rs = null; try { - if (needle != null && !needle.isEmpty() && !needle.equals("*")) { - // Cannot use a prepared statement for this as presto does not support that - sql += " LIKE '" + escapeSql(needle) + "%'"; - } - - stat = con.createStatement(); - rs = stat.executeQuery(sql); + validateSqlIdentifier(needle, "catalog pattern"); + String catalogPattern = convertToSqlPattern(needle); + rs = con.getMetaData().getCatalogs(); while (rs.next()) { - String catalogName = rs.getString(1); + String catalogName = rs.getString("TABLE_CAT"); if (catalogs != null && catalogs.contains(catalogName)) { continue; } - ret.add(catalogName); + if (catalogPattern == null || catalogPattern.equals("%") || matchesSqlPattern(catalogName, catalogPattern)) { + ret.add(catalogName); + } } } catch (SQLTimeoutException sqlt) { - String msgDesc = "Time Out, Unable to execute SQL [" + sql + "]."; + String msgDesc = "Time Out, Unable to retrieve catalog list."; HadoopException hdpException = new HadoopException(msgDesc, sqlt); hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null); + throw hdpException; } catch (SQLException se) { - String msg = "Unable to execute SQL [" + sql + "]. "; + String msg = "Unable to retrieve catalog list. 
"; HadoopException he = new HadoopException(msg, se); he.generateResponseDataMap(false, getMessage(se), msg + ERR_MSG, null, null); + throw he; + } catch (HadoopException he) { throw he; } finally { close(rs); - close(stat); } } @@ -347,43 +346,31 @@ private List getSchemas(String needle, List catalogs, List ret = new ArrayList<>(); if (con != null) { - Statement stat = null; - ResultSet rs = null; - String sql = null; + ResultSet rs = null; try { + validateSqlIdentifier(needle, "schema pattern"); + String schemaPattern = convertToSqlPattern(needle); if (catalogs != null && !catalogs.isEmpty()) { for (String catalog : catalogs) { - sql = "SHOW SCHEMAS FROM \"" + escapeSql(catalog) + "\""; - - try { - if (needle != null && !needle.isEmpty() && !needle.equals("*")) { - sql += " LIKE '" + escapeSql(needle) + "%'"; - } - - stat = con.createStatement(); - rs = stat.executeQuery(sql); + validateSqlIdentifier(catalog, "catalog name"); + rs = con.getMetaData().getSchemas(catalog, schemaPattern); - while (rs.next()) { - String schema = rs.getString(1); + while (rs.next()) { + String schema = rs.getString("TABLE_SCHEM"); - if (schemas != null && schemas.contains(schema)) { - continue; - } - - ret.add(schema); + if (schemas != null && schemas.contains(schema)) { + continue; } - } finally { - close(rs); - close(stat); - rs = null; - stat = null; + ret.add(schema); } + close(rs); + rs = null; } } } catch (SQLTimeoutException sqlt) { - String msgDesc = "Time Out, Unable to execute SQL [" + sql + "]."; + String msgDesc = "Time Out, Unable to retrieve schema list."; HadoopException hdpException = new HadoopException(msgDesc, sqlt); hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null); @@ -392,7 +379,7 @@ private List getSchemas(String needle, List catalogs, List getSchemas(String needle, List catalogs, List getTables(String needle, List catalogs, List ret = new ArrayList<>(); if (con != null) { - Statement stat = null; - ResultSet rs = null; 
- String sql = null; + ResultSet rs = null; - if (catalogs != null && !catalogs.isEmpty() && schemas != null && !schemas.isEmpty()) { - try { + try { + validateSqlIdentifier(needle, "table pattern"); + String tablePattern = convertToSqlPattern(needle); + if (catalogs != null && !catalogs.isEmpty() && schemas != null && !schemas.isEmpty()) { for (String catalog : catalogs) { + validateSqlIdentifier(catalog, "catalog name"); for (String schema : schemas) { - sql = "SHOW tables FROM \"" + escapeSql(catalog) + "\".\"" + escapeSql(schema) + "\""; - - try { - if (needle != null && !needle.isEmpty() && !needle.equals("*")) { - sql += " LIKE '" + escapeSql(needle) + "%'"; - } + validateSqlIdentifier(schema, "schema name"); + rs = con.getMetaData().getTables(catalog, schema, tablePattern, new String[] {"TABLE", "VIEW"}); - stat = con.createStatement(); - rs = stat.executeQuery(sql); - - while (rs.next()) { - String table = rs.getString(1); - - if (tables != null && tables.contains(table)) { - continue; - } + while (rs.next()) { + String table = rs.getString("TABLE_NAME"); - ret.add(table); + if (tables != null && tables.contains(table)) { + continue; } - } finally { - close(rs); - close(stat); - rs = null; - stat = null; + ret.add(table); } + close(rs); + rs = null; } } - } catch (SQLTimeoutException sqlt) { - String msgDesc = "Time Out, Unable to execute SQL [" + sql + "]."; - HadoopException hdpException = new HadoopException(msgDesc, sqlt); + } + } catch (SQLTimeoutException sqlt) { + String msgDesc = "Time Out, Unable to retrieve table list."; + HadoopException hdpException = new HadoopException(msgDesc, sqlt); - hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null); + hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null); - LOG.debug("<== PrestoClient.getTables() Error : ", sqlt); + LOG.debug("<== PrestoClient.getTables() Error : ", sqlt); - throw hdpException; - } catch (SQLException sqle) { 
- String msgDesc = "Unable to execute SQL [" + sql + "]."; - HadoopException hdpException = new HadoopException(msgDesc, sqle); + throw hdpException; + } catch (SQLException sqle) { + String msgDesc = "Unable to retrieve table list."; + HadoopException hdpException = new HadoopException(msgDesc, sqle); - hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null); + hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null); - LOG.debug("<== PrestoClient.getTables() Error : ", sqle); + LOG.debug("<== PrestoClient.getTables() Error : ", sqle); - throw hdpException; - } + throw hdpException; + } catch (HadoopException he) { + throw he; + } finally { + close(rs); } } @@ -477,66 +461,64 @@ private List getColumns(String needle, List catalogs, List()); + } + + private PrestoClient createMockedClient(Map props) throws Exception { + Map config = new HashMap<>(); + config.put("username", "test"); + config.put("password", "test"); + config.put("jdbc.driverClassName", "com.facebook.presto.jdbc.PrestoDriver"); + config.put("jdbc.url", "jdbc:presto://localhost:8080"); + config.putAll(props); + + NoConnectionPrestoClient client = new NoConnectionPrestoClient("svc", config); + return client; + } + + @Test + public void test01_getCatalogList_normalOperation() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + DatabaseMetaData metadata = mock(DatabaseMetaData.class); + ResultSet rs = mock(ResultSet.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + when(mockCon.getMetaData()).thenReturn(metadata); + when(metadata.getCatalogs()).thenReturn(rs); + when(rs.next()).thenReturn(true, true, false); + when(rs.getString("TABLE_CAT")).thenReturn("catalog1", "catalog2"); + subjectStatic.when(() -> Subject.doAs(any(), 
any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + PrestoClient client = createMockedClient(); + List out = client.getCatalogList("*", null); + assertEquals(Arrays.asList("catalog1", "catalog2"), out); + } + } + + @Test + public void test02_getCatalogList_withExcludeList() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + DatabaseMetaData metadata = mock(DatabaseMetaData.class); + ResultSet rs = mock(ResultSet.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + when(mockCon.getMetaData()).thenReturn(metadata); + when(metadata.getCatalogs()).thenReturn(rs); + when(rs.next()).thenReturn(true, true, false); + when(rs.getString("TABLE_CAT")).thenReturn("catalog1", "catalog2"); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + PrestoClient client = createMockedClient(); + List out = client.getCatalogList("*", Collections.singletonList("catalog1")); + assertEquals(Collections.singletonList("catalog2"), out); + } + } + + @Test + public void test03_getCatalogList_timeoutThrowsHadoopException() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + DatabaseMetaData metadata = mock(DatabaseMetaData.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + when(mockCon.getMetaData()).thenReturn(metadata); + when(metadata.getCatalogs()).thenThrow(SQLTimeoutException.class); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> 
{ + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + PrestoClient client = createMockedClient(); + assertThrows(HadoopException.class, () -> client.getCatalogList("*", null)); + } + } + + @Test + public void test04_validateSqlIdentifier_rejectsSqlInjection() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + PrestoClient client = createMockedClient(); + + List maliciousInputs = Arrays.asList( + "'; DROP TABLE users; --", + "test\" OR 1=1 --", + "catalog'; DELETE FROM users; --", + "../../../etc/passwd", + "test"); + + for (String input : maliciousInputs) { + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getCatalogList(input, null), + "SQL injection attempt should be rejected: " + input); + assertTrue(ex.getMessage().contains("Invalid"), + "Error should indicate invalid input for: " + input); + } + } + } + + @Test + public void test05_validateSqlIdentifier_rejectsSpecialCharacters() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + PrestoClient client = createMockedClient(); + + List invalidInputs = Arrays.asList( + "test@catalog", + "catalog#name", + "table!name", + 
"name(with)parens"); + + for (String input : invalidInputs) { + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getCatalogList(input, null), + "Special characters should be rejected: " + input); + assertTrue(ex.getMessage().contains("Invalid"), + "Error should indicate invalid input for: " + input); + } + } + } + + @Test + public void test06_schemaValidation_rejectsSqlInjection() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + PrestoClient client = createMockedClient(); + Field fCon = PrestoClient.class.getDeclaredField("con"); + fCon.setAccessible(true); + fCon.set(client, mockCon); + + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getSchemaList("'; DROP TABLE users; --", Collections.singletonList("catalog1"), null)); + assertTrue(ex.getMessage().contains("Invalid")); + } + } + + @Test + public void test07_tableValidation_rejectsSqlInjection() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + PrestoClient client = createMockedClient(); + + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getTableList("'; DROP TABLE users; --", 
Collections.singletonList("catalog1"), Collections.singletonList("schema1"), null)); + assertTrue(ex.getMessage().contains("Invalid")); + } + } + + @Test + public void test08_columnValidation_rejectsSqlInjection() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + PrestoClient client = createMockedClient(); + + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getColumnList("'; DROP TABLE users; --", Collections.singletonList("catalog1"), Collections.singletonList("schema1"), Collections.singletonList("table1"), null)); + assertTrue(ex.getMessage().contains("Invalid")); + } + } + + @Test + public void test09_catalogName_validateInList() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + DatabaseMetaData metadata = mock(DatabaseMetaData.class); + ResultSet rs = mock(ResultSet.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + when(mockCon.getMetaData()).thenReturn(metadata); + when(metadata.getSchemas(anyString(), anyString())).thenReturn(rs); + when(rs.next()).thenReturn(false); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + PrestoClient client = createMockedClient(); + Field fCon = PrestoClient.class.getDeclaredField("con"); + fCon.setAccessible(true); + fCon.set(client, mockCon); + + 
HadoopException ex = assertThrows(HadoopException.class, + () -> client.getSchemaList("schema1", Arrays.asList("catalog1", "'; DROP TABLE --"), null)); + assertTrue(ex.getMessage().contains("Invalid")); + } + } + + @Test + public void test10_schemaName_validateInList() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + DatabaseMetaData metadata = mock(DatabaseMetaData.class); + ResultSet rs = mock(ResultSet.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + when(mockCon.getMetaData()).thenReturn(metadata); + when(metadata.getTables(anyString(), anyString(), anyString(), any())).thenReturn(rs); + when(rs.next()).thenReturn(false); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + PrestoClient client = createMockedClient(); + + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getTableList("table1", Collections.singletonList("catalog1"), Arrays.asList("schema1", "'; DROP --"), null)); + assertTrue(ex.getMessage().contains("Invalid")); + } + } + + private static class NoConnectionPrestoClient extends PrestoClient { + public NoConnectionPrestoClient(String serviceName, Map connectionProperties) { + super(serviceName, connectionProperties); + } + + @Override + protected Subject getLoginSubject() { + Subject subject = new Subject(); + return subject; + } + + @Override + protected void login() { + } + } +} diff --git a/plugin-schema-registry/src/main/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgent.java b/plugin-schema-registry/src/main/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgent.java index 074d0058bd..ac5d3a0b20 100644 --- 
a/plugin-schema-registry/src/main/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgent.java +++ b/plugin-schema-registry/src/main/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgent.java @@ -118,15 +118,63 @@ public String toString() { } List expandSchemaMetadataNameRegex(List schemaGroupList, String lookupSchemaMetadataName) { + validatePattern(lookupSchemaMetadataName, "schema metadata pattern"); + String safePattern = convertWildcardToRegex(lookupSchemaMetadataName); List res = new ArrayList<>(); Collection schemas = client.getSchemaNames(schemaGroupList); schemas.forEach(sName -> { - if (sName.matches(lookupSchemaMetadataName)) { + if (sName.matches(safePattern)) { res.add(sName); } }); return res; } + + private void validatePattern(String pattern, String patternType) { + if (pattern == null || pattern.isEmpty()) { + return; + } + if (!pattern.matches("^[a-zA-Z0-9*?\\[\\]\\-\\$%\\{\\}\\=\\/\\._]+$")) { + String msgDesc = "Invalid " + patternType + ": [" + pattern + "]. 
Only alphanumeric characters along with ( ., _, -, *, ?, [], {}, %, $, = / ) are allowed."; + LOG.error(msgDesc); + throw new IllegalArgumentException(msgDesc); + } + } + + protected String convertWildcardToRegex(String wildcard) { + if (wildcard == null || wildcard.isEmpty()) { + return ".*"; + } + StringBuilder regex = new StringBuilder("^"); + for (int i = 0; i < wildcard.length(); i++) { + char c = wildcard.charAt(i); + switch (c) { + case '*': + regex.append(".*"); + break; + case '?': + regex.append("."); + break; + case '.': + case '\\': + case '^': + case '$': + case '|': + regex.append('\\').append(c); + break; + case '{': + case '}': + case '[': + case ']': + regex.append('\\').append(c); + break; + default: + regex.append(c); + } + } + regex.append('$'); + return regex.toString(); + } } diff --git a/plugin-schema-registry/src/test/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgentTest.java b/plugin-schema-registry/src/test/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgentTest.java index e578ec82ac..779ce3bf30 100644 --- a/plugin-schema-registry/src/test/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgentTest.java +++ b/plugin-schema-registry/src/test/java/org/apache/ranger/services/schema/registry/client/AutocompletionAgentTest.java @@ -19,6 +19,7 @@ import org.apache.ranger.services.schema.registry.client.connection.ISchemaRegistryClient; import org.apache.ranger.services.schema.registry.client.util.DefaultSchemaRegistryClientForTesting; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.ArrayList; @@ -27,8 +28,6 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNull; public class AutocompletionAgentTest { @Test @@ -37,11 +36,11 @@ public void connectionTest() { AutocompletionAgent 
autocompletionAgent = new AutocompletionAgent("schema-registry", client); HashMap res = autocompletionAgent.connectionTest(); - assertEquals(true, res.get("connectivityStatus")); - assertEquals("ConnectionTest Successful", res.get("message")); - assertEquals("ConnectionTest Successful", res.get("description")); - assertNull(res.get("objectId")); - assertNull(res.get("fieldName")); + Assertions.assertEquals(true, res.get("connectivityStatus")); + Assertions.assertEquals("ConnectionTest Successful", res.get("message")); + Assertions.assertEquals("ConnectionTest Successful", res.get("description")); + Assertions.assertNull(res.get("objectId")); + Assertions.assertNull(res.get("fieldName")); client = new DefaultSchemaRegistryClientForTesting() { public void checkConnection() throws Exception { @@ -52,11 +51,11 @@ public void checkConnection() throws Exception { res = autocompletionAgent.connectionTest(); String errMessage = "You can still save the repository and start creating policies, but you would not be able to use autocomplete for resource names. 
Check server logs for more info."; - assertEquals(false, res.get("connectivityStatus")); + Assertions.assertEquals(false, res.get("connectivityStatus")); assertThat(res.get("message"), is(errMessage)); assertThat(res.get("description"), is(errMessage)); - assertNull(res.get("objectId")); - assertNull(res.get("fieldName")); + Assertions.assertNull(res.get("objectId")); + Assertions.assertNull(res.get("fieldName")); } @Test @@ -75,7 +74,7 @@ public List getSchemaGroups() { // doesn't contain any groups that starts with 'tesSome' List initialGroups = new ArrayList<>(); List res = autocompletionAgent.getSchemaGroupList("tesSome", initialGroups); - assertEquals(0, res.size()); + Assertions.assertEquals(0, res.size()); // Empty initialGroups and the list of groups returned by ISchemaRegistryClient // contains a group that starts with 'tes' @@ -83,7 +82,7 @@ public List getSchemaGroups() { res = autocompletionAgent.getSchemaGroupList("tes", initialGroups); List expected = new ArrayList<>(); expected.add("testGroup"); - assertEquals(1, res.size()); + Assertions.assertEquals(1, res.size()); assertThat(res, is(expected)); // initialGroups contains one element, list of the groups returned by ISchemaRegistryClient @@ -93,7 +92,7 @@ public List getSchemaGroups() { res = autocompletionAgent.getSchemaGroupList("tes", initialGroups); expected = new ArrayList<>(); expected.add("testGroup"); - assertEquals(1, res.size()); + Assertions.assertEquals(1, res.size()); assertThat(res, is(expected)); // initialGroups contains one element, list of the groups returned by ISchemaRegistryClient @@ -104,7 +103,7 @@ public List getSchemaGroups() { expected = new ArrayList<>(); expected.add("testGroup2"); expected.add("testGroup"); - assertEquals(2, res.size()); + Assertions.assertEquals(2, res.size()); assertThat(res, is(expected)); } @@ -129,11 +128,11 @@ public List getSchemaNames(List schemaGroup) { List res = autocompletionAgent.getSchemaMetadataList("tes", groupList, new ArrayList<>()); 
List expected = new ArrayList<>(); expected.add("testSchema"); - assertEquals(1, res.size()); + Assertions.assertEquals(1, res.size()); assertThat(res, is(expected)); res = autocompletionAgent.getSchemaMetadataList("tesSome", groupList, new ArrayList<>()); - assertEquals(0, res.size()); + Assertions.assertEquals(0, res.size()); } @Test @@ -168,10 +167,114 @@ public List getSchemaBranches(String schemaMetadataName) { List res = autocompletionAgent.getBranchList("tes", groups, schemaList, new ArrayList<>()); List expected = new ArrayList<>(); expected.add("testBranch"); - assertEquals(1, res.size()); + Assertions.assertEquals(1, res.size()); assertThat(res, is(expected)); res = autocompletionAgent.getSchemaMetadataList("tesSome", schemaList, new ArrayList<>()); - assertEquals(0, res.size()); + Assertions.assertEquals(0, res.size()); + } + + @Test + void testValidatePattern_validAlphanumeric() { + ISchemaRegistryClient client = new DefaultSchemaRegistryClientForTesting() { + public List getSchemaNames(List schemaGroup) { + List schemas = new ArrayList<>(); + schemas.add("mySchema123"); + return schemas; + } + }; + AutocompletionAgent agent = new AutocompletionAgent("test", client); + List groups = new ArrayList<>(); + groups.add("testGroup"); + List result = agent.expandSchemaMetadataNameRegex(groups, "mySchema123"); + Assertions.assertEquals(1, result.size()); + Assertions.assertEquals("mySchema123", result.get(0)); + } + + @Test + void testValidatePattern_validWildcards() { + ISchemaRegistryClient client = new DefaultSchemaRegistryClientForTesting() { + public List getSchemaNames(List schemaGroup) { + List schemas = new ArrayList<>(); + schemas.add("mySchema123"); + schemas.add("testSchema"); + return schemas; + } + }; + AutocompletionAgent agent = new AutocompletionAgent("test", client); + List groups = new ArrayList<>(); + groups.add("testGroup"); + List result = agent.expandSchemaMetadataNameRegex(groups, "my*"); + Assertions.assertEquals(1, result.size()); + 
Assertions.assertEquals("mySchema123", result.get(0)); + } + + @Test + void testValidatePattern_rejectsReDoSPattern() { + ISchemaRegistryClient client = new DefaultSchemaRegistryClientForTesting(); + AutocompletionAgent agent = new AutocompletionAgent("test", client); + List groups = new ArrayList<>(); + groups.add("testGroup"); + Assertions.assertThrows(IllegalArgumentException.class, () -> { + agent.expandSchemaMetadataNameRegex(groups, "(a+)+"); + }); + } + + @Test + void testValidatePattern_rejectsComplexRegex() { + ISchemaRegistryClient client = new DefaultSchemaRegistryClientForTesting(); + AutocompletionAgent agent = new AutocompletionAgent("test", client); + List groups = new ArrayList<>(); + groups.add("testGroup"); + Assertions.assertThrows(IllegalArgumentException.class, () -> { + agent.expandSchemaMetadataNameRegex(groups, "test{1,5}"); + }); + } + + @Test + void testValidatePattern_rejectsInjectionAttempt() { + ISchemaRegistryClient client = new DefaultSchemaRegistryClientForTesting(); + AutocompletionAgent agent = new AutocompletionAgent("test", client); + List groups = new ArrayList<>(); + groups.add("testGroup"); + Assertions.assertThrows(IllegalArgumentException.class, () -> { + agent.expandSchemaMetadataNameRegex(groups, "test'; DROP TABLE users--"); + }); + } + + @Test + void testConvertWildcardToRegex_asterisk() { + ISchemaRegistryClient client = new DefaultSchemaRegistryClientForTesting() { + public List getSchemaNames(List schemaGroup) { + List schemas = new ArrayList<>(); + schemas.add("testSchema"); + schemas.add("prodSchema"); + return schemas; + } + }; + AutocompletionAgent agent = new AutocompletionAgent("test", client); + List groups = new ArrayList<>(); + groups.add("testGroup"); + List result = agent.expandSchemaMetadataNameRegex(groups, "test*"); + Assertions.assertEquals(1, result.size()); + Assertions.assertEquals("testSchema", result.get(0)); + } + + @Test + void testConvertWildcardToRegex_questionMark() { + ISchemaRegistryClient 
client = new DefaultSchemaRegistryClientForTesting() { + public List getSchemaNames(List schemaGroup) { + List schemas = new ArrayList<>(); + schemas.add("schema1"); + schemas.add("schema12"); + return schemas; + } + }; + AutocompletionAgent agent = new AutocompletionAgent("test", client); + List groups = new ArrayList<>(); + groups.add("testGroup"); + List result = agent.expandSchemaMetadataNameRegex(groups, "schema?"); + Assertions.assertEquals(1, result.size()); + Assertions.assertEquals("schema1", result.get(0)); } } diff --git a/plugin-solr/pom.xml b/plugin-solr/pom.xml index f409216436..38b8c08223 100644 --- a/plugin-solr/pom.xml +++ b/plugin-solr/pom.xml @@ -119,5 +119,23 @@ + + org.junit.jupiter + junit-jupiter + ${junit.jupiter.version} + test + + + org.mockito + mockito-inline + ${mockito.version} + test + + + org.mockito + mockito-junit-jupiter + ${mockito.version} + test + diff --git a/plugin-solr/src/main/java/org/apache/ranger/services/solr/client/ServiceSolrClient.java b/plugin-solr/src/main/java/org/apache/ranger/services/solr/client/ServiceSolrClient.java index 3966c11beb..dd7f9be001 100644 --- a/plugin-solr/src/main/java/org/apache/ranger/services/solr/client/ServiceSolrClient.java +++ b/plugin-solr/src/main/java/org/apache/ranger/services/solr/client/ServiceSolrClient.java @@ -418,10 +418,10 @@ private List getCoresList(List ignoreCollectionList) throws Exce } private List getFieldList(String collection, List ignoreFieldList) throws Exception { - // TODO: Best is to get the collections based on the collection value which could contain wild cards String queryStr = ""; if (collection != null && !collection.isEmpty()) { + validateResourceName(collection, "collection name"); queryStr += "/" + collection; } @@ -619,6 +619,34 @@ private void login(Map configs) { } } + private void validateResourceName(String resourceName, String resourceType) { + if (resourceName == null) { + return; + } + + if (resourceName.contains("..") || 
resourceName.contains("//") || resourceName.contains("\\")) { + String msgDesc = "Invalid " + resourceType + ": [" + resourceName + "]. Path traversal patterns are not allowed."; + HadoopException hdpException = new HadoopException(msgDesc); + + hdpException.generateResponseDataMap(false, msgDesc, msgDesc + RangerSolrConstants.errMessage, null, null); + + LOG.error(msgDesc); + + throw hdpException; + } + + if (!resourceName.matches("^[a-zA-Z0-9_.*\\-]+$")) { + String msgDesc = "Invalid " + resourceType + ": [" + resourceName + "]. Only alphanumeric characters, dots, underscores, hyphens, and wildcards are allowed."; + HadoopException hdpException = new HadoopException(msgDesc); + + hdpException.generateResponseDataMap(false, msgDesc, msgDesc + RangerSolrConstants.errMessage, null, null); + + LOG.error(msgDesc); + + throw hdpException; + } + } + private HadoopException createException(String msgDesc, Exception exp) { HadoopException hdpException = new HadoopException(msgDesc, exp); final String fullDescription = exp != null ? BaseClient.getMessage(exp) : msgDesc; diff --git a/plugin-solr/src/test/java/org/apache/ranger/services/solr/client/TestServiceSolrClient.java b/plugin-solr/src/test/java/org/apache/ranger/services/solr/client/TestServiceSolrClient.java new file mode 100644 index 0000000000..0e022582d9 --- /dev/null +++ b/plugin-solr/src/test/java/org/apache/ranger/services/solr/client/TestServiceSolrClient.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.ranger.services.solr.client; + +import org.apache.ranger.plugin.client.HadoopException; +import org.junit.jupiter.api.MethodOrderer; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestMethodOrder; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@ExtendWith(MockitoExtension.class) +@TestMethodOrder(MethodOrderer.MethodName.class) +public class TestServiceSolrClient { + @Test + public void test01_validateResourceName_rejectsPathTraversal() throws Exception { + Map configs = new HashMap<>(); + configs.put("username", "test"); + configs.put("password", "test"); + NoopServiceSolrClient client = new NoopServiceSolrClient("svc", configs, "http://localhost:8983/solr", false); + + List pathTraversalInputs = Arrays.asList("../etc/passwd", "../../sensitive", "test/../admin", "collection//malicious", "test\\windows\\path", "..\\..\\config"); + + for (String input : pathTraversalInputs) { + HadoopException ex = assertThrows(HadoopException.class, + () -> invokeGetFieldList(client, input, null), + "Path traversal should be rejected: " + input); + assertTrue(ex.getMessage().contains("Path traversal"), + "Error should indicate path traversal for: " + input); + } + } + + @Test + public void 
test02_validateResourceName_rejectsSpecialCharacters() throws Exception { + Map configs = new HashMap<>(); + configs.put("username", "test"); + configs.put("password", "test"); + NoopServiceSolrClient client = new NoopServiceSolrClient("svc", configs, "http://localhost:8983/solr", false); + + List invalidInputs = Arrays.asList("'; DROP TABLE users; --", "collection", "test@collection", "collection#name", "test!collection", "collection&name", "collection(with)parens", "collection{with}braces", "collection[with]brackets", "collection$name", "collection%encoded", "collection name", "collection\ttab", "collection\nnewline"); + + for (String input : invalidInputs) { + HadoopException ex = assertThrows(HadoopException.class, + () -> invokeGetFieldList(client, input, null), + "Special characters should be rejected: " + input); + assertTrue(ex.getMessage().contains("Invalid"), + "Error should indicate invalid input for: " + input); + } + } + + @Test + public void test03_validateResourceName_acceptsValidNames() throws Exception { + Map configs = new HashMap<>(); + configs.put("username", "test"); + configs.put("password", "test"); + NoopServiceSolrClient client = new NoopServiceSolrClient("svc", configs, "http://localhost:8983/solr", false); + + List validInputs = Arrays.asList("collection", "collection_name", "collection123", "COLLECTION", "Collection_Name_123", "_collection", "collection_", "collection.name", "collection-name", "collection*", "coll*"); + + Method validateMethod = ServiceSolrClient.class.getDeclaredMethod("validateResourceName", String.class, String.class); + validateMethod.setAccessible(true); + + for (String input : validInputs) { + try { + validateMethod.invoke(client, input, "collection name"); + } catch (Exception e) { + throw new AssertionError("Valid collection name should not throw exception: " + input, e); + } + } + } + + @Test + public void test04_validateResourceName_rejectsNullByteInjection() throws Exception { + Map configs = new HashMap<>(); 
+ configs.put("username", "test"); + configs.put("password", "test"); + NoopServiceSolrClient client = new NoopServiceSolrClient("svc", configs, "http://localhost:8983/solr", false); + + HadoopException ex = assertThrows(HadoopException.class, + () -> invokeGetFieldList(client, "collection\0null", null)); + assertTrue(ex.getMessage().contains("Invalid")); + } + + @Test + public void test05_validateResourceName_rejectsCommandInjection() throws Exception { + Map configs = new HashMap<>(); + configs.put("username", "test"); + configs.put("password", "test"); + NoopServiceSolrClient client = new NoopServiceSolrClient("svc", configs, "http://localhost:8983/solr", false); + + List commandInjectionInputs = Arrays.asList("collection;rm -rf /", "collection|cat /etc/passwd", "collection`whoami`", "collection$(whoami)", "collection&&ls"); + + for (String input : commandInjectionInputs) { + HadoopException ex = assertThrows(HadoopException.class, + () -> invokeGetFieldList(client, input, null), + "Command injection should be rejected: " + input); + assertTrue(ex.getMessage().contains("Invalid"), + "Error should indicate invalid input for: " + input); + } + } + + @Test + public void test06_validateResourceName_rejectsUrlEncoded() throws Exception { + Map configs = new HashMap<>(); + configs.put("username", "test"); + configs.put("password", "test"); + NoopServiceSolrClient client = new NoopServiceSolrClient("svc", configs, "http://localhost:8983/solr", false); + + List encodedInputs = Arrays.asList("%2e%2e%2f", "collection%00", "test%20space"); + + for (String input : encodedInputs) { + HadoopException ex = assertThrows(HadoopException.class, + () -> invokeGetFieldList(client, input, null), + "URL encoded attack should be rejected: " + input); + assertTrue(ex.getMessage().contains("Invalid"), + "Error should indicate invalid input for: " + input); + } + } + + private List invokeGetFieldList(ServiceSolrClient client, String collection, List ignoreList) throws Exception { + 
Method method = ServiceSolrClient.class.getDeclaredMethod("getFieldList", String.class, List.class); + method.setAccessible(true); + try { + return (List) method.invoke(client, collection, ignoreList); + } catch (java.lang.reflect.InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof HadoopException) { + throw (HadoopException) cause; + } + throw e; + } + } + + private static class NoopServiceSolrClient extends ServiceSolrClient { + public NoopServiceSolrClient(String serviceName, Map configs, String url, boolean isSolrCloud) { + super(serviceName, configs, url, isSolrCloud); + } + } +} diff --git a/plugin-trino/pom.xml b/plugin-trino/pom.xml index 3db390b253..cfb6d628c2 100644 --- a/plugin-trino/pom.xml +++ b/plugin-trino/pom.xml @@ -167,6 +167,18 @@ ${junit.jupiter.version} test + + org.mockito + mockito-inline + ${mockito.version} + test + + + org.mockito + mockito-junit-jupiter + ${mockito.version} + test + diff --git a/plugin-trino/src/main/java/org/apache/ranger/services/trino/client/TrinoClient.java b/plugin-trino/src/main/java/org/apache/ranger/services/trino/client/TrinoClient.java index 7325347b8c..ef988695b4 100644 --- a/plugin-trino/src/main/java/org/apache/ranger/services/trino/client/TrinoClient.java +++ b/plugin-trino/src/main/java/org/apache/ranger/services/trino/client/TrinoClient.java @@ -14,7 +14,6 @@ package org.apache.ranger.services.trino.client; import org.apache.commons.io.FilenameUtils; -import org.apache.commons.lang3.StringUtils; import org.apache.ranger.plugin.client.BaseClient; import org.apache.ranger.plugin.client.HadoopConfigHolder; import org.apache.ranger.plugin.client.HadoopException; @@ -115,6 +114,8 @@ public List getSchemaList(String needle, List catalogs, List getCatalogs(String needle, List catalogs) List ret = new ArrayList<>(); if (con != null) { - Statement stat = null; - ResultSet rs = null; - String sql = "SHOW CATALOGS"; + ResultSet rs = null; try { - if (needle != null && 
!needle.isEmpty() && !needle.equals("*")) { - // Cannot use a prepared statement for this as trino does not support that - sql += " LIKE '" + escapeSql(needle) + "%'"; - } - - stat = con.createStatement(); - rs = stat.executeQuery(sql); + validateSqlIdentifier(needle, "catalog pattern"); + String catalogPattern = convertToSqlPattern(needle); + rs = con.getMetaData().getCatalogs(); while (rs.next()) { - String catalogName = rs.getString(1); + String catalogName = rs.getString("TABLE_CAT"); if (catalogs != null && catalogs.contains(catalogName)) { continue; } - ret.add(catalogName); + if (catalogPattern == null || catalogPattern.equals("%") || matchesSqlPattern(catalogName, catalogPattern)) { + ret.add(catalogName); + } } } catch (SQLTimeoutException sqlt) { - String msgDesc = "Time Out, Unable to execute SQL [" + sql + "]."; + String msgDesc = "Time Out, Unable to retrieve catalog list."; HadoopException hdpException = new HadoopException(msgDesc, sqlt); hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null); throw hdpException; } catch (SQLException se) { - String msg = "Unable to execute SQL [" + sql + "]. "; + String msg = "Unable to retrieve catalog list. 
"; HadoopException he = new HadoopException(msg, se); he.generateResponseDataMap(false, getMessage(se), msg + ERR_MSG, null, null); + throw he; + } catch (HadoopException he) { throw he; } finally { close(rs); - close(stat); } } return ret; @@ -350,43 +348,31 @@ private List getSchemas(String needle, List catalogs, List ret = new ArrayList<>(); if (con != null) { - Statement stat = null; - ResultSet rs = null; - String sql = null; + ResultSet rs = null; try { + validateSqlIdentifier(needle, "schema pattern"); + String schemaPattern = convertToSqlPattern(needle); if (catalogs != null && !catalogs.isEmpty()) { for (String catalog : catalogs) { - sql = "SHOW SCHEMAS FROM \"" + escapeSql(catalog) + "\""; - - try { - if (needle != null && !needle.isEmpty() && !needle.equals("*")) { - sql += " LIKE '" + escapeSql(needle) + "%'"; - } + validateSqlIdentifier(catalog, "catalog name"); + rs = con.getMetaData().getSchemas(catalog, schemaPattern); - stat = con.createStatement(); - rs = stat.executeQuery(sql); - - while (rs.next()) { - String schema = rs.getString(1); + while (rs.next()) { + String schema = rs.getString("TABLE_SCHEM"); - if (schemas != null && schemas.contains(schema)) { - continue; - } - - ret.add(schema); + if (schemas != null && schemas.contains(schema)) { + continue; } - } finally { - close(rs); - close(stat); - rs = null; - stat = null; + ret.add(schema); } + close(rs); + rs = null; } } } catch (SQLTimeoutException sqlt) { - String msgDesc = "Time Out, Unable to execute SQL [" + sql + "]."; + String msgDesc = "Time Out, Unable to retrieve schema list."; HadoopException hdpException = new HadoopException(msgDesc, sqlt); hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null); @@ -397,7 +383,7 @@ private List getSchemas(String needle, List catalogs, List getSchemas(String needle, List catalogs, List getTables(String needle, List catalogs, List ret = new ArrayList<>(); if (con != null) { - Statement stat = null; - 
ResultSet rs = null; - String sql = null; + ResultSet rs = null; - if (catalogs != null && !catalogs.isEmpty() && schemas != null && !schemas.isEmpty()) { - try { + try { + validateSqlIdentifier(needle, "table pattern"); + String tablePattern = convertToSqlPattern(needle); + if (catalogs != null && !catalogs.isEmpty() && schemas != null && !schemas.isEmpty()) { for (String catalog : catalogs) { + validateSqlIdentifier(catalog, "catalog name"); for (String schema : schemas) { - sql = "SHOW tables FROM \"" + escapeSql(catalog) + "\".\"" + escapeSql(schema) + "\""; + validateSqlIdentifier(schema, "schema name"); + rs = con.getMetaData().getTables(catalog, schema, tablePattern, new String[] {"TABLE", "VIEW"}); - try { - if (needle != null && !needle.isEmpty() && !needle.equals("*")) { - sql += " LIKE '" + escapeSql(needle) + "%'"; - } - - stat = con.createStatement(); - rs = stat.executeQuery(sql); - - while (rs.next()) { - String table = rs.getString(1); - - if (tables != null && tables.contains(table)) { - continue; - } + while (rs.next()) { + String table = rs.getString("TABLE_NAME"); - ret.add(table); + if (tables != null && tables.contains(table)) { + continue; } - } finally { - close(rs); - close(stat); - rs = null; - stat = null; + ret.add(table); } + close(rs); + rs = null; } } - } catch (SQLTimeoutException sqlt) { - String msgDesc = "Time Out, Unable to execute SQL [" + sql + "]."; - HadoopException hdpException = new HadoopException(msgDesc, sqlt); - - hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null); + } + } catch (SQLTimeoutException sqlt) { + String msgDesc = "Time Out, Unable to retrieve table list."; + HadoopException hdpException = new HadoopException(msgDesc, sqlt); - if (LOG.isDebugEnabled()) { - LOG.debug("<== TrinoClient.getTables() Error : ", sqlt); - } + hdpException.generateResponseDataMap(false, getMessage(sqlt), msgDesc + ERR_MSG, null, null); - throw hdpException; - } catch (SQLException sqle) { - 
String msgDesc = "Unable to execute SQL [" + sql + "]."; - HadoopException hdpException = new HadoopException(msgDesc, sqle); + if (LOG.isDebugEnabled()) { + LOG.debug("<== TrinoClient.getTables() Error : ", sqlt); + } - hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null); + throw hdpException; + } catch (SQLException sqle) { + String msgDesc = "Unable to retrieve table list."; + HadoopException hdpException = new HadoopException(msgDesc, sqle); - if (LOG.isDebugEnabled()) { - LOG.debug("<== TrinoClient.getTables() Error : ", sqle); - } + hdpException.generateResponseDataMap(false, getMessage(sqle), msgDesc + ERR_MSG, null, null); - throw hdpException; + if (LOG.isDebugEnabled()) { + LOG.debug("<== TrinoClient.getTables() Error : ", sqle); } + + throw hdpException; + } catch (HadoopException he) { + throw he; + } finally { + close(rs); } } @@ -490,72 +473,68 @@ private List getColumns(String needle, List catalogs, List()); + } + + private TrinoClient createMockedClient(Map props) throws Exception { + Map config = new HashMap<>(); + config.put("username", "test"); + config.put("password", "test"); + config.put("jdbc.driverClassName", "io.trino.jdbc.TrinoDriver"); + config.put("jdbc.url", "jdbc:trino://localhost:8080"); + config.putAll(props); + + NoConnectionTrinoClient client = new NoConnectionTrinoClient("svc", config); + return client; + } + + @Test + public void test01_getCatalogList_normalOperation() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + DatabaseMetaData metadata = mock(DatabaseMetaData.class); + ResultSet rs = mock(ResultSet.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + when(mockCon.getMetaData()).thenReturn(metadata); + when(metadata.getCatalogs()).thenReturn(rs); + 
when(rs.next()).thenReturn(true, true, false); + when(rs.getString("TABLE_CAT")).thenReturn("catalog1", "catalog2"); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + TrinoClient client = createMockedClient(); + List out = client.getCatalogList("*", null); + assertEquals(Arrays.asList("catalog1", "catalog2"), out); + } + } + + @Test + public void test02_getCatalogList_withExcludeList() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + DatabaseMetaData metadata = mock(DatabaseMetaData.class); + ResultSet rs = mock(ResultSet.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + when(mockCon.getMetaData()).thenReturn(metadata); + when(metadata.getCatalogs()).thenReturn(rs); + when(rs.next()).thenReturn(true, true, false); + when(rs.getString("TABLE_CAT")).thenReturn("catalog1", "catalog2"); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + TrinoClient client = createMockedClient(); + List out = client.getCatalogList("*", Collections.singletonList("catalog1")); + assertEquals(Collections.singletonList("catalog2"), out); + } + } + + @Test + public void test03_getCatalogList_timeoutThrowsHadoopException() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + DatabaseMetaData metadata = mock(DatabaseMetaData.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + when(mockCon.getMetaData()).thenReturn(metadata); + 
when(metadata.getCatalogs()).thenThrow(SQLTimeoutException.class); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + TrinoClient client = createMockedClient(); + assertThrows(HadoopException.class, () -> client.getCatalogList("*", null)); + } + } + + @Test + public void test04_validateSqlIdentifier_rejectsSqlInjection() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + TrinoClient client = createMockedClient(); + + List maliciousInputs = Arrays.asList("'; DROP TABLE users; --", "test\" OR 1=1 --", "catalog'; DELETE FROM users; --", "../../../etc/passwd", "test"); + + for (String input : maliciousInputs) { + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getCatalogList(input, null), + "SQL injection attempt should be rejected: " + input); + assertTrue(ex.getMessage().contains("Invalid"), + "Error should indicate invalid input for: " + input); + } + } + } + + @Test + public void test05_validateSqlIdentifier_rejectsSpecialCharacters() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); 
+ }); + + TrinoClient client = createMockedClient(); + + List invalidInputs = Arrays.asList("test@catalog", "catalog#name", "table!name", "name(with)parens"); + + for (String input : invalidInputs) { + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getCatalogList(input, null), + "Special characters should be rejected: " + input); + assertTrue(ex.getMessage().contains("Invalid"), + "Error should indicate invalid input for: " + input); + } + } + } + + @Test + public void test06_schemaValidation_rejectsSqlInjection() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + TrinoClient client = createMockedClient(); + Field fCon = TrinoClient.class.getDeclaredField("con"); + fCon.setAccessible(true); + fCon.set(client, mockCon); + + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getSchemaList("'; DROP TABLE users; --", Collections.singletonList("catalog1"), null)); + assertTrue(ex.getMessage().contains("Invalid")); + } + } + + @Test + public void test07_tableValidation_rejectsSqlInjection() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + TrinoClient client = createMockedClient(); + + 
HadoopException ex = assertThrows(HadoopException.class, + () -> client.getTableList("'; DROP TABLE users; --", Collections.singletonList("catalog1"), Collections.singletonList("schema1"), null)); + assertTrue(ex.getMessage().contains("Invalid")); + } + } + + @Test + public void test08_columnValidation_rejectsSqlInjection() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + TrinoClient client = createMockedClient(); + + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getColumnList("'; DROP TABLE users; --", Collections.singletonList("catalog1"), Collections.singletonList("schema1"), Collections.singletonList("table1"), null)); + assertTrue(ex.getMessage().contains("Invalid")); + } + } + + @Test + public void test09_catalogName_validateInList() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + DatabaseMetaData metadata = mock(DatabaseMetaData.class); + ResultSet rs = mock(ResultSet.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + when(mockCon.getMetaData()).thenReturn(metadata); + when(metadata.getSchemas(anyString(), anyString())).thenReturn(rs); + when(rs.next()).thenReturn(false); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + TrinoClient client = createMockedClient(); + Field fCon 
= TrinoClient.class.getDeclaredField("con"); + fCon.setAccessible(true); + fCon.set(client, mockCon); + + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getSchemaList("schema1", Arrays.asList("catalog1", "'; DROP TABLE --"), null)); + assertTrue(ex.getMessage().contains("Invalid")); + } + } + + @Test + public void test10_schemaName_validateInList() throws Exception { + try (MockedStatic dmStatic = Mockito.mockStatic(DriverManager.class); + MockedStatic subjectStatic = Mockito.mockStatic(Subject.class)) { + Connection mockCon = mock(Connection.class); + DatabaseMetaData metadata = mock(DatabaseMetaData.class); + ResultSet rs = mock(ResultSet.class); + dmStatic.when(() -> DriverManager.getConnection(anyString(), any())).thenReturn(mockCon); + when(mockCon.getMetaData()).thenReturn(metadata); + when(metadata.getTables(anyString(), anyString(), anyString(), any())).thenReturn(rs); + when(rs.next()).thenReturn(false); + subjectStatic.when(() -> Subject.doAs(any(), any(PrivilegedAction.class))) + .thenAnswer(inv -> { + PrivilegedAction action = inv.getArgument(1); + return action.run(); + }); + + TrinoClient client = createMockedClient(); + Field fCon = TrinoClient.class.getDeclaredField("con"); + fCon.setAccessible(true); + fCon.set(client, mockCon); + + HadoopException ex = assertThrows(HadoopException.class, + () -> client.getTableList("table1", Collections.singletonList("catalog1"), Arrays.asList("schema1", "'; DROP --"), null)); + assertTrue(ex.getMessage().contains("Invalid")); + } + } + + private static class NoConnectionTrinoClient extends TrinoClient { + public NoConnectionTrinoClient(String serviceName, Map connectionProperties) { + super(serviceName, connectionProperties); + } + + @Override + protected Subject getLoginSubject() { + Subject subject = new Subject(); + return subject; + } + + @Override + protected void login() { + } + } +}