Browse code

Catch sshpass authentication errors and don't retry multiple times to prevent account lockout (#50776)

* Catch SSH authentication errors and don't retry multiple times to prevent account lock out

Signed-off-by: Sam Doran <sdoran@redhat.com>

* Subclass AnsibleAuthenticationFailure from AnsibleConnectionFailure

Use comparison rather than range() because it's much more efficient.

Signed-off-by: Sam Doran <sdoran@redhat.com>

* Add tests

Signed-off-by: Sam Doran <sdoran@redhat.com>

* Make paramiko_ssh connection plugin behave the same way

Signed-off-by: Sam Doran <sdoran@redhat.com>

* Add changelog

Signed-off-by: Sam Doran <sdoran@redhat.com>
(cherry picked from commit 9d4c0dc1116f0bbd01bac6580a61dc28e314eec4)

Sam Doran authored on 2019/01/24 01:32:25
Showing 5 changed files
1 1
new file mode 100644
... ...
@@ -0,0 +1,2 @@
0
+bugfixes:
1
+  - ssh connection - do not retry with invalid credentials to prevent account lockout (https://github.com/ansible/ansible/issues/48422)
... ...
@@ -209,6 +209,11 @@ class AnsibleConnectionFailure(AnsibleRuntimeError):
209 209
     pass
210 210
 
211 211
 
212
+class AnsibleAuthenticationFailure(AnsibleConnectionFailure):
213
+    '''invalid username/password/key'''
214
+    pass
215
+
216
+
212 217
 class AnsibleFilterError(AnsibleRuntimeError):
213 218
     ''' a templating failure '''
214 219
     pass
... ...
@@ -140,12 +140,17 @@ from termios import tcflush, TCIFLUSH
140 140
 from binascii import hexlify
141 141
 
142 142
 from ansible import constants as C
143
-from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
143
+from ansible.errors import (
144
+    AnsibleAuthenticationFailure,
145
+    AnsibleConnectionFailure,
146
+    AnsibleError,
147
+    AnsibleFileNotFound,
148
+)
144 149
 from ansible.module_utils.six import iteritems
145 150
 from ansible.module_utils.six.moves import input
146 151
 from ansible.plugins.connection import ConnectionBase
147 152
 from ansible.utils.path import makedirs_safe
148
-from ansible.module_utils._text import to_bytes, to_native
153
+from ansible.module_utils._text import to_bytes, to_native, to_text
149 154
 
150 155
 try:
151 156
     from __main__ import display
... ...
@@ -353,6 +358,9 @@ class Connection(ConnectionBase):
353 353
             )
354 354
         except paramiko.ssh_exception.BadHostKeyException as e:
355 355
             raise AnsibleConnectionFailure('host key mismatch for %s' % e.hostname)
356
+        except paramiko.ssh_exception.AuthenticationException as e:
357
+            msg = 'Invalid/incorrect username/password. {0}'.format(to_text(e))
358
+            raise AnsibleAuthenticationFailure(msg)
356 359
         except Exception as e:
357 360
             msg = str(e)
358 361
             if "PID check failed" in msg:
... ...
@@ -223,7 +223,12 @@ import time
223 223
 
224 224
 from functools import wraps
225 225
 from ansible import constants as C
226
-from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
226
+from ansible.errors import (
227
+    AnsibleAuthenticationFailure,
228
+    AnsibleConnectionFailure,
229
+    AnsibleError,
230
+    AnsibleFileNotFound,
231
+)
227 232
 from ansible.errors import AnsibleOptionsError
228 233
 from ansible.compat import selectors
229 234
 from ansible.module_utils.six import PY3, text_type, binary_type
... ...
@@ -240,6 +245,11 @@ except ImportError:
240 240
     display = Display()
241 241
 
242 242
 
243
+b_NOT_SSH_ERRORS = (b'Traceback (most recent call last):',  # Python-2.6 when there's an exception
244
+                                                            # while invoking a script via -m
245
+                    b'PHP Parse error:',  # Php always returns error 255
246
+                    )
247
+
243 248
 SSHPASS_AVAILABLE = None
244 249
 
245 250
 
... ...
@@ -248,6 +258,55 @@ class AnsibleControlPersistBrokenPipeError(AnsibleError):
248 248
     pass
249 249
 
250 250
 
251
+def _handle_error(remaining_retries, command, return_tuple, no_log, host, display=display):
252
+
253
+    # sshpass errors
254
+    if command == b'sshpass':
255
+        # Error 5 is invalid/incorrect password. Raise an exception to prevent retries from locking the account.
256
+        if return_tuple[0] == 5:
257
+            msg = 'Invalid/incorrect username/password. Skipping remaining {0} retries to prevent account lockout:'.format(remaining_retries)
258
+            if remaining_retries <= 0:
259
+                msg = 'Invalid/incorrect password:'
260
+            if no_log:
261
+                msg = '{0} <error censored due to no log>'.format(msg)
262
+            else:
263
+                msg = '{0} {1}'.format(msg, to_native(return_tuple[2].rstrip()))
264
+            raise AnsibleAuthenticationFailure(msg)
265
+
266
+        # sshpass returns codes are 1-6. We handle 5 previously, so this catches other scenarios.
267
+        # No exception is raised, so the connection is retried.
268
+        elif return_tuple[0] in [1, 2, 3, 4, 6]:
269
+            msg = 'sshpass error:'
270
+            if no_log:
271
+                msg = '{0} <error censored due to no log>'.format(msg)
272
+            else:
273
+                msg = '{0} {1}'.format(msg, to_native(return_tuple[2].rstrip()))
274
+
275
+    if return_tuple[0] == 255:
276
+        SSH_ERROR = True
277
+        for signature in b_NOT_SSH_ERRORS:
278
+            if signature in return_tuple[1]:
279
+                SSH_ERROR = False
280
+                break
281
+
282
+        if SSH_ERROR:
283
+            msg = "Failed to connect to the host via ssh:"
284
+            if no_log:
285
+                msg = '{0} <error censored due to no log>'.format(msg)
286
+            else:
287
+                msg = '{0} {1}'.format(msg, to_native(return_tuple[2]).rstrip())
288
+            raise AnsibleConnectionFailure(msg)
289
+
290
+    # For other errors, no execption is raised so the connection is retried and we only log the messages
291
+    if 1 <= return_tuple[0] <= 254:
292
+        msg = "Failed to connect to the host via ssh:"
293
+        if no_log:
294
+            msg = '{0} <error censored due to no log>'.format(msg)
295
+        else:
296
+            msg = '{0} {1}'.format(msg, to_native(return_tuple[2]).rstrip())
297
+        display.vvv(msg, host=host)
298
+
299
+
251 300
 def _ssh_retry(func):
252 301
     """
253 302
     Decorator to retry ssh/scp/sftp in the case of a connection failure
... ...
@@ -256,7 +315,8 @@ def _ssh_retry(func):
256 256
     * an exception is caught
257 257
     * ssh returns 255
258 258
     Will not retry if
259
-    * remaining_tries is <2
259
+    * sshpass returns 5 (invalid password, to prevent account lockouts)
260
+    * remaining_tries is < 2
260 261
     * retries limit reached
261 262
     """
262 263
     @wraps(func)
... ...
@@ -274,7 +334,7 @@ def _ssh_retry(func):
274 274
                 try:
275 275
                     return_tuple = func(self, *args, **kwargs)
276 276
                     if self._play_context.no_log:
277
-                        display.vvv('rc=%s, stdout & stderr censored due to no log' % return_tuple[0], host=self.host)
277
+                        display.vvv('rc=%s, stdout and stderr censored due to no log' % return_tuple[0], host=self.host)
278 278
                     else:
279 279
                         display.vvv(return_tuple, host=self.host)
280 280
                     # 0 = success
... ...
@@ -289,18 +349,19 @@ def _ssh_retry(func):
289 289
                         cmd[1] = b'-d' + to_bytes(self.sshpass_pipe[0], nonstring='simplerepr', errors='surrogate_or_strict')
290 290
                     display.vvv(u"RETRYING BECAUSE OF CONTROLPERSIST BROKEN PIPE")
291 291
                     return_tuple = func(self, *args, **kwargs)
292
+                remaining_retries = remaining_tries - attempt - 1
292 293
 
293
-                if return_tuple[0] != 255:
294
-                    break
295
-                else:
296
-                    msg = "Failed to connect to the host via ssh: "
297
-                    if self._play_context.no_log:
298
-                        msg += '<error censored due to no log>'
299
-                    else:
300
-                        msg += to_native(return_tuple[2])
301
-                    raise AnsibleConnectionFailure(msg)
294
+                _handle_error(remaining_retries, cmd[0], return_tuple, self._play_context.no_log, self.host)
295
+
296
+                break
297
+
298
+            # 5 = Invalid/incorrect password from sshpass
299
+            except AnsibleAuthenticationFailure as e:
300
+                # Raising this exception, which is subclassed from AnsibleConnectionFailure, prevents further retries
301
+                raise
302 302
 
303 303
             except (AnsibleConnectionFailure, Exception) as e:
304
+
304 305
                 if attempt == remaining_tries - 1:
305 306
                     raise
306 307
                 else:
... ...
@@ -309,9 +370,9 @@ def _ssh_retry(func):
309 309
                         pause = 30
310 310
 
311 311
                     if isinstance(e, AnsibleConnectionFailure):
312
-                        msg = "ssh_retry: attempt: %d, ssh return code is 255. cmd (%s), pausing for %d seconds" % (attempt, cmd_summary, pause)
312
+                        msg = "ssh_retry: attempt: %d, ssh return code is 255. cmd (%s), pausing for %d seconds" % (attempt + 1, cmd_summary, pause)
313 313
                     else:
314
-                        msg = "ssh_retry: attempt: %d, caught exception(%s) from cmd (%s), pausing for %d seconds" % (attempt, e, cmd_summary, pause)
314
+                        msg = "ssh_retry: attempt: %d, caught exception(%s) from cmd (%s), pausing for %d seconds" % (attempt + 1, e, cmd_summary, pause)
315 315
 
316 316
                     display.vv(msg, host=self.host)
317 317
 
... ...
@@ -25,6 +25,7 @@ import pytest
25 25
 
26 26
 
27 27
 from ansible import constants as C
28
+from ansible.errors import AnsibleAuthenticationFailure
28 29
 from ansible.compat.selectors import SelectorKey, EVENT_READ
29 30
 from ansible.compat.tests import unittest
30 31
 from ansible.compat.tests.mock import patch, MagicMock, PropertyMock
... ...
@@ -501,6 +502,33 @@ class TestSSHConnectionRun(object):
501 501
 
502 502
 @pytest.mark.usefixtures('mock_run_env')
503 503
 class TestSSHConnectionRetries(object):
504
+    def test_incorrect_password(self, monkeypatch):
505
+        monkeypatch.setattr(C, 'HOST_KEY_CHECKING', False)
506
+        monkeypatch.setattr(C, 'ANSIBLE_SSH_RETRIES', 5)
507
+        monkeypatch.setattr('time.sleep', lambda x: None)
508
+
509
+        self.mock_popen_res.stdout.read.side_effect = [b'']
510
+        self.mock_popen_res.stderr.read.side_effect = [b'Permission denied, please try again.\r\n']
511
+        type(self.mock_popen_res).returncode = PropertyMock(side_effect=[5] * 4)
512
+
513
+        self.mock_selector.select.side_effect = [
514
+            [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
515
+            [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
516
+            [],
517
+        ]
518
+
519
+        self.mock_selector.get_map.side_effect = lambda: True
520
+
521
+        self.conn._build_command = MagicMock()
522
+        self.conn._build_command.return_value = [b'sshpass', b'-d41', b'ssh', b'-C']
523
+        self.conn.get_option = MagicMock()
524
+        self.conn.get_option.return_value = True
525
+
526
+        exception_info = pytest.raises(AnsibleAuthenticationFailure, self.conn.exec_command, 'sshpass', 'some data')
527
+        assert exception_info.value.message == ('Invalid/incorrect username/password. Skipping remaining 5 retries to prevent account lockout: '
528
+                                                'Permission denied, please try again.')
529
+        assert self.mock_popen.call_count == 1
530
+
504 531
     def test_retry_then_success(self, monkeypatch):
505 532
         monkeypatch.setattr(C, 'HOST_KEY_CHECKING', False)
506 533
         monkeypatch.setattr(C, 'ANSIBLE_SSH_RETRIES', 3)