import sys
import subprocess
import tempfile
import os
import cgi  # NOTE: deprecated since Python 3.11 and removed in 3.13 (PEP 594)
import json
from http.server import BaseHTTPRequestHandler, HTTPServer


class VaultSnapshotHandler(BaseHTTPRequestHandler):
    """HTTP handler exposing Vault raft snapshot download and restore.

    Endpoints:
        GET  /snapshot -- run ``vault operator raft snapshot save`` and return the
                          file as an ``application/octet-stream`` attachment.
        POST /restore  -- accept ``multipart/form-data`` with a ``file`` field (the
                          snapshot) and an ``unseal_keys`` field (one key per line),
                          force-restore the snapshot, then unseal Vault.

    All Vault interaction goes through the ``vault`` CLI, which must be on PATH;
    it runs with this process's environment (no ``env=`` is passed).

    NOTE(review): there is no authentication on either endpoint and the server
    binds to all interfaces by default -- anyone who can reach the port can
    download snapshots or overwrite Vault. Restrict access externally.
    """

    # Chunk size used when copying uploaded files to disk, so a large snapshot
    # is never held in memory as one piece.
    _COPY_CHUNK_SIZE = 1024 * 1024

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _remove_quietly(path):
        """Delete *path* if it is set and still exists."""
        if path and os.path.exists(path):
            os.remove(path)

    def _send_text(self, code, message):
        """Send a ``text/plain`` response with status *code* and body *message*."""
        self.send_response(code)
        self.send_header('Content-Type', 'text/plain')
        self.end_headers()
        self.wfile.write(message.encode('utf-8'))

    def _send_json(self, code, body):
        """Send an ``application/json`` response whose body is the bytes *body*."""
        self.send_response(code)
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        self.wfile.write(body)

    @staticmethod
    def parse_unseal_keys(text):
        """Return the unseal keys in *text*: one per line, stripped, blanks skipped.

        Blank lines (e.g. a trailing newline) would otherwise be submitted to
        ``vault operator unseal`` as empty keys.
        """
        return [line.strip() for line in text.splitlines() if line.strip()]

    @staticmethod
    def _form_field_error(form, field):
        """Return an error message if *field* is not a single upload in *form*, else None.

        Args:
            form (cgi.FieldStorage): Parsed multipart form.
            field (str): Name of the required field.
        """
        if field not in form:
            return f"Missing '{field}' field in the form data."
        item = form[field]
        # cgi returns a list when the same field name appears more than once.
        if isinstance(item, list):
            return f"Multiple '{field}' fields in the form data."
        if not item.file:
            return f"No file uploaded in the '{field}' field."
        return None

    def _save_form_file(self, form, field, log_label):
        """Copy the upload in *field* to a new temporary file and return its path.

        On a missing/invalid field a 400 is sent and ``None`` is returned. The
        caller owns (and must delete) the returned file.
        """
        error = self._form_field_error(form, field)
        if error is not None:
            print(error)
            self.send_error(400, error)
            return None

        source = form[field].file
        source.seek(0)
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            tmp_filename = tmp_file.name
            print(f"{log_label}: {tmp_filename}")
            while True:
                chunk = source.read(self._COPY_CHUNK_SIZE)
                if not chunk:
                    break
                # cgi stores parts sent without a filename as text (StringIO).
                if isinstance(chunk, str):
                    chunk = chunk.encode('utf-8')
                tmp_file.write(chunk)
        return tmp_filename

    # ----------------------------------------------------------------- snapshot

    def create_and_serve_snapshot(self, snapshot_file):
        """Create a Vault raft snapshot at *snapshot_file* and return its contents.

        On failure a 500 ``text/plain`` response has already been sent, the
        snapshot file has been removed, and ``None`` is returned -- the caller
        must not send another response.

        Args:
            snapshot_file (str): The path to the snapshot file to be created

        Returns:
            bytes | None: The snapshot contents, or ``None`` if creation failed.
        """
        cmd = ['vault', 'operator', 'raft', 'snapshot', 'save', snapshot_file]
        try:
            process = subprocess.run(cmd, capture_output=True, text=True)
        except OSError as err:
            # e.g. the `vault` binary is not installed or not on PATH
            error_message = f"Error generating snapshot: {err}"
        else:
            if process.returncode == 0:
                with open(snapshot_file, 'rb') as f:
                    return f.read()
            error_message = f"Error generating snapshot: {process.stderr}"

        self._send_text(500, error_message)
        self._remove_quietly(snapshot_file)
        return None

    def do_GET(self):
        """Handle ``GET /snapshot``; any other path gets a 404."""
        if self.path != '/snapshot':
            self.send_error(404, "Not Found")
            return

        # Reserve a temporary path for the snapshot (vault writes into it).
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            snapshot_file = tmp_file.name
            print(f"Temporary file created for snapshot: {snapshot_file}")

        try:
            snapshot_data = self.create_and_serve_snapshot(snapshot_file)
            if snapshot_data is None:
                return  # the 500 response was already sent

            # Send the snapshot as a file attachment
            self.send_response(200)
            self.send_header('Content-Type', 'application/octet-stream')
            self.send_header('Content-Disposition', 'attachment; filename="raft_snapshot.snap"')
            self.send_header('Content-Length', str(len(snapshot_data)))
            self.end_headers()
            self.wfile.write(snapshot_data)
        finally:
            # Clean up the temporary file whether the command succeeded or failed
            self._remove_quietly(snapshot_file)

    # ------------------------------------------------------------------ restore

    def upload_snapshot(self, form):
        """Save the uploaded snapshot (form field ``file``) to a temporary file.

        Args:
            form (cgi.FieldStorage): The form data containing the snapshot file

        Returns:
            str | None: Path of the temporary file, or ``None`` after sending a 400.
        """
        return self._save_form_file(form, 'file', 'Temporary file created for restore')

    def upload_unseal_keys(self, form):
        """Save the uploaded unseal keys (form field ``unseal_keys``) to a temporary file.

        Args:
            form (cgi.FieldStorage): The form data containing the unseal keys file

        Returns:
            str | None: Path of the temporary file, or ``None`` after sending a 400.
        """
        return self._save_form_file(
            form, 'unseal_keys', 'Temporary unseal_keys file created for restore'
        )

    def restore_backup(self, snapshot_file) -> bool:
        """Force-restore Vault from *snapshot_file*.

        On failure a 500 ``text/plain`` response is sent before returning False.

        Args:
            snapshot_file (str): The path to the snapshot file to be restored

        Returns:
            bool: True if the restore was successful, False otherwise
        """
        print('+--------------------------+')
        print('| Restoring Vault Snapshot |')
        print('+--------------------------+')
        print(f"Restoring snapshot from: {snapshot_file}")

        cmd = ['vault', 'operator', 'raft', 'snapshot', 'restore', '--force', snapshot_file]
        process = subprocess.run(cmd, capture_output=True, text=True)

        if process.returncode != 0:
            error_message = f"Error restoring snapshot: {process.stderr}"
            print(error_message.strip())
            self._send_text(500, error_message)
            return False

        print('Snapshot restored successfully')
        return True

    def is_vault_unsealed(self) -> bool:
        """Return True if Vault reports itself unsealed, False if it is sealed.

        Raises:
            RuntimeError: ``vault status`` failed (exit code other than 0 or 2).
            ValueError / KeyError: the status output is not the expected JSON.
        """
        # `vault status` exits 0 when unsealed and 2 when sealed; subprocess.run
        # does not raise on non-zero exit (check defaults to False), so the code
        # is inspected manually.
        process = subprocess.run(['vault', 'status', '-format=json'], capture_output=True, text=True)
        if process.returncode not in (0, 2):
            raise RuntimeError(f'Failed to get the status of the vault: {process.stderr.strip()}')

        print(process.stdout.strip())
        sealed = json.loads(process.stdout.strip())['sealed']
        print(f'Is Sealed: {sealed}')
        return not sealed

    def unseal_vault(self, unseal_keys) -> bool:
        """Submit unseal keys one at a time until Vault reports unsealed.

        Args:
            unseal_keys (list[str]): Keys to submit, in order; blank entries are skipped.

        Returns:
            bool: True once Vault is unsealed, False if it is still sealed after
            all keys were submitted.
        """
        print('+--------------------------+')
        print('| Unsealing restored Vault |')
        print('+--------------------------+')

        for key in unseal_keys:
            key = key.strip()
            if not key:
                continue
            # NOTE(review): the key is passed on argv and is therefore visible in
            # the process list while the command runs.
            # stdin=DEVNULL guarantees the CLI can never block waiting for a prompt.
            process = subprocess.run(
                ['vault', 'operator', 'unseal', key],
                capture_output=True, text=True, stdin=subprocess.DEVNULL,
            )
            print(process.stdout.strip())
            if process.returncode != 0:
                print(process.stderr.strip())

            if self.is_vault_unsealed():
                print('Vault is unsealed')
                return True
        return False

    def do_POST(self):
        """Handle ``POST /restore``: restore an uploaded snapshot and unseal Vault.

        The request and the key list are fully validated *before* the
        destructive ``restore --force`` runs, so a bad request can never leave
        Vault restored but sealed. Exactly one response is sent per request.
        """
        if self.path != '/restore':
            # If a path other than `/restore` is requested
            self.send_error(404, "Endpoint not found.")
            return

        # Default to '' -- cgi.parse_header(None) would raise TypeError.
        content_type = self.headers.get('Content-Type', '')
        ctype, pdict = cgi.parse_header(content_type)
        if ctype != 'multipart/form-data':
            print(f"Invalid Content-Type: {content_type}")
            self.send_error(400, 'Content-Type must be multipart/form-data')
            return
        if 'boundary' not in pdict:
            print(f"Missing multipart boundary: {content_type}")
            self.send_error(400, 'multipart/form-data Content-Type must include a boundary')
            return

        try:
            form = cgi.FieldStorage(
                fp=self.rfile,
                headers=self.headers,
                environ={'REQUEST_METHOD': 'POST', 'CONTENT_TYPE': content_type},
            )
        except ValueError as err:
            print(f"Malformed multipart body: {err}")
            self.send_error(400, 'Malformed multipart/form-data body')
            return

        # Reject the request up front unless every required upload is present.
        for field in ('file', 'unseal_keys'):
            error = self._form_field_error(form, field)
            if error is not None:
                print(error)
                self.send_error(400, error)
                return

        snapshot_file = unseal_keys_file = None
        try:
            snapshot_file = self.upload_snapshot(form)
            if snapshot_file is None:
                return  # 400 already sent
            unseal_keys_file = self.upload_unseal_keys(form)
            if unseal_keys_file is None:
                return  # 400 already sent

            # Load the keys before restoring so an unusable key file cannot
            # leave a freshly restored vault sealed. utf-8-sig drops a BOM that
            # would otherwise corrupt the first key.
            try:
                with open(unseal_keys_file, 'r', encoding='utf-8-sig') as f:
                    unseal_keys = self.parse_unseal_keys(f.read())
            except UnicodeDecodeError:
                self.send_error(400, "The 'unseal_keys' file must be UTF-8 text.")
                return
            if not unseal_keys:
                self.send_error(400, "The 'unseal_keys' file contains no keys.")
                return

            try:
                if not self.restore_backup(snapshot_file):
                    return  # restore_backup already sent the 500
                unsealed = self.unseal_vault(unseal_keys)
            except (OSError, RuntimeError, ValueError, KeyError) as err:
                # Raised before any response was written (vault missing,
                # `vault status` failing, unexpected status output).
                error_message = f"Error restoring or unsealing Vault: {err}"
                print(error_message)
                self._send_text(500, error_message)
                return

            if not unsealed:
                print('Vault is still sealed after submitting all unseal keys')
                self._send_json(500, json.dumps(
                    {"status": "error", "message": "Snapshot restored but Vault is still sealed"}
                ).encode('utf-8'))
                return

            self._send_json(200, b'{"status": "success", "message": "Snapshot restored successfully"}')
        finally:
            # Remove the temporary files regardless of success or failure
            self._remove_quietly(snapshot_file)
            self._remove_quietly(unseal_keys_file)

    # Optionally override logging to avoid default console messages
    def log_message(self, format, *args):
        pass


def run_server(port=8300, host=''):
    """Serve VaultSnapshotHandler until interrupted with Ctrl+C.

    Args:
        port (int | str): TCP port; strings (e.g. a command-line argument) are
            converted with ``int()``.
        host (str): Interface to bind; ``''`` (the default) means all interfaces.
    """
    port = int(port)  # HTTPServer rejects a str port with TypeError
    server_address = (host, port)
    httpd = HTTPServer(server_address, VaultSnapshotHandler)
    print(f"Starting HTTP server on port {port}...")
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        print("\nServer interrupted by user, shutting down.")
    finally:
        httpd.server_close()
        print("Server stopped.")


if __name__ == '__main__':
    print('+--------------------------+')
    print('| Starting snapshot server |')
    print('+--------------------------+')
    run_server(int(sys.argv[1]) if len(sys.argv) > 1 else 8300)