Skip to content

Commit 1727360

Browse files
Updating tests and implementation to allow 3 parts to SRV string
1 parent f405c06 commit 1727360

3 files changed

Lines changed: 82 additions & 7 deletions

File tree

lib/mongo/srv/result.rb

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,17 +110,28 @@ def normalize_hostname(host)
110110
# A hostname's domain name consists of each of the '.' delineated
111111
# parts after the first. For example, the hostname 'foo.bar.baz'
112112
# has the domain name 'bar.baz'.
113+
#
114+
# If the hostname has less than three parts, its domain name is the hostname itself.
113115
#
114116
# @param [ String ] record_host The host of the SRV record.
115117
#
116118
# @raise [ Mongo::Error::MismatchedDomain ] If the record's domain name doesn't match that of
117119
# the hostname.
118120
def validate_same_origin!(record_host)
119-
domain_name ||= query_hostname.split('.')[1..-1]
120-
host_parts = record_host.split('.')
121+
srv_is_less_than_three_parts = query_hostname.split('.').length < 3
122+
srv_host_domain = if srv_is_less_than_three_parts
123+
query_hostname.split('.')
124+
else
125+
query_hostname.split('.')[1..-1]
126+
end
127+
record_host_parts = record_host.split('.')
128+
129+
if (srv_is_less_than_three_parts && record_host_parts.length <= srv_host_domain.length)
130+
raise Error::MismatchedDomain.new(MISMATCHED_DOMAINNAME % [record_host, srv_host_domain])
131+
end
121132

122-
unless (host_parts.size > domain_name.size) && (domain_name == host_parts[-domain_name.length..-1])
123-
raise Error::MismatchedDomain.new(MISMATCHED_DOMAINNAME % [record_host, domain_name])
133+
unless (srv_host_domain == record_host_parts[-srv_host_domain.size..-1])
134+
raise Error::MismatchedDomain.new(MISMATCHED_DOMAINNAME % [record_host, srv_host_domain])
124135
end
125136
end
126137
end

lib/mongo/uri/srv_protocol.rb

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,6 @@ def validate_srv_hostname(hostname)
184184
if parts.any?(&:empty?)
185185
raise_invalid_error!("Hostname cannot have consecutive dots: #{hostname}")
186186
end
187-
if parts.length < 1
188-
raise_invalid_error!("Hostname must have a minimum of 1 component (tld): #{hostname}")
189-
end
190187
end
191188

192189
# Obtains the TXT options of a host.

spec/mongo/srv/result_spec.rb

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,73 @@
2323
expect(result.address_strs).to eq(['foo.bar.com:42'])
2424
end
2525
end
26+
27+
exampleSrvName = ['i-love-rb', 'i-love-rb.mongodb', 'i-love-ruby.mongodb.io'];
28+
exampleHostName = [
29+
'rb-00.i-love-rb',
30+
'rb-00.i-love-rb.mongodb',
31+
'i-love-ruby-00.mongodb.io'
32+
];
33+
exampleHostNameThatDoNotMatchParent = [
34+
'rb-00.i-love-rb-a-little',
35+
'rb-00.i-love-rb-a-little.mongodb',
36+
'i-love-ruby-00.evil-mongodb.io'
37+
];
38+
39+
(0..2).each do |i|
40+
context "when srvName has #{i+1} part#{i != 0 ? 's' : ''}" do
41+
let(:srv_name) { exampleSrvName[i] }
42+
let(:host_name) { exampleHostName[i] }
43+
let(:mismatched_host_name) { exampleHostNameThatDoNotMatchParent[i] }
44+
45+
context 'when address does not match parent domain' do
46+
it 'raises MismatchedDomain error' do
47+
record = double('record').tap do |record|
48+
allow(record).to receive(:target).and_return(mismatched_host_name)
49+
allow(record).to receive(:port).and_return(42)
50+
allow(record).to receive(:ttl).and_return(1)
51+
end
52+
53+
expect {
54+
result = described_class.new(srv_name)
55+
result.add_record(record)
56+
}.to raise_error(Mongo::Error::MismatchedDomain)
57+
end
58+
end
59+
60+
context 'when address matches parent domain' do
61+
it 'adds the record' do
62+
record = double('record').tap do |record|
63+
allow(record).to receive(:target).and_return(host_name)
64+
allow(record).to receive(:port).and_return(42)
65+
allow(record).to receive(:ttl).and_return(1)
66+
end
67+
68+
result = described_class.new(srv_name)
69+
result.add_record(record)
70+
71+
expect(result.address_strs).to eq([host_name + ':42'])
72+
end
73+
end
74+
75+
if i < 2
76+
context 'when the address is less than 3 parts' do
77+
it 'does not accept address if it does not contain an extra domain level' do
78+
record = double('record').tap do |record|
79+
allow(record).to receive(:target).and_return(srv_name)
80+
allow(record).to receive(:port).and_return(42)
81+
allow(record).to receive(:ttl).and_return(1)
82+
end
83+
84+
expect {
85+
result = described_class.new(srv_name)
86+
result.add_record(record)
87+
}.to raise_error(Mongo::Error::MismatchedDomain)
88+
end
89+
end
90+
end
91+
end
92+
end
2693
end
2794

2895
describe '#normalize_hostname' do

0 commit comments

Comments
 (0)