287 lines
11 KiB
Python
287 lines
11 KiB
Python
import sys, subprocess, tempfile, os, cgi, json
|
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
|
|
|
|
class VaultSnapshotHandler(BaseHTTPRequestHandler):
    """HTTP handler that backs up and restores a Vault raft cluster via the CLI.

    Endpoints:
        GET /snapshot: runs ``vault operator raft snapshot save`` and returns
            the snapshot as an ``application/octet-stream`` attachment.
        POST /restore: expects ``multipart/form-data`` with a ``file`` field
            (the snapshot) and an ``unseal_keys`` field (one key per line).
            The snapshot is force-restored, then the keys are submitted to
            ``vault operator unseal`` until Vault reports being unsealed.

    Any other path receives a 404.

    NOTE(review): no authentication is performed here; anyone who can reach
    the server can download a full snapshot or overwrite the cluster.
    """

    def _send_text_error(self, code, message):
        """Send a complete ``text/plain`` error response with *message* as body."""
        self.send_response(code)
        self.send_header('Content-Type', 'text/plain')
        self.end_headers()
        self.wfile.write(message.encode('utf-8'))

    @staticmethod
    def _remove_files(*paths):
        """Delete every path in *paths* that exists; missing files are ignored."""
        for path in paths:
            if os.path.exists(path):
                os.remove(path)

    def create_and_serve_snapshot(self, snapshot_file):
        """Create a snapshot of the vault and return its contents.

        Args:
            snapshot_file (str): Path of the snapshot file to create.

        Returns:
            bytes | None: The snapshot contents, or ``None`` when the Vault
            command failed. In that case a 500 response carrying Vault's
            stderr has already been sent and the file has been removed, so
            the caller must not write another response.
        """
        cmd = ['vault', 'operator', 'raft', 'snapshot', 'save', snapshot_file]
        process = subprocess.run(cmd, capture_output=True, text=True)

        if process.returncode != 0:
            error_message = f"Error generating snapshot: {process.stderr}"
            print(error_message.strip())
            self._send_text_error(500, error_message)
            self._remove_files(snapshot_file)
            return None

        with open(snapshot_file, 'rb') as f:
            return f.read()

    def do_GET(self):
        """Serve ``GET /snapshot``; any other path gets a 404."""
        if self.path != '/snapshot':
            self.send_error(404, "Not Found")
            return

        # delete=False keeps the file after the handle is closed so the vault
        # CLI can write to it; it is removed in the finally clause below.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            snapshot_file = tmp_file.name
        print(f"Temporary file created for snapshot: {snapshot_file}")

        try:
            snapshot_data = self.create_and_serve_snapshot(snapshot_file)
            if snapshot_data is None:
                # The 500 error response has already been written.
                return

            # Send the snapshot as a downloadable attachment.
            self.send_response(200)
            self.send_header('Content-Type', 'application/octet-stream')
            self.send_header('Content-Disposition', 'attachment; filename="raft_snapshot.snap"')
            self.send_header('Content-Length', str(len(snapshot_data)))
            self.end_headers()
            self.wfile.write(snapshot_data)
        finally:
            self._remove_files(snapshot_file)

    def _save_upload(self, form, field_name, created_label):
        """Copy the file uploaded in *field_name* into a new temporary file.

        Args:
            form: Parsed form data supporting ``in`` and ``[]`` lookups whose
                items expose their payload as a ``.file`` stream.
            field_name (str): Name of the form field to read.
            created_label (str): Prefix of the log line naming the temp file.

        Returns:
            str | None: Path of the temporary file, or ``None`` after a 400
            response was sent because the field is missing or carries no file.
        """
        if field_name not in form:
            message = f"Missing '{field_name}' field in the form data."
            print(message)
            self.send_error(400, message)
            return None

        # getattr: cgi.FieldStorage returns a list (no ``.file``) when a
        # field name is repeated; treat that like a missing upload.
        upload = getattr(form[field_name], 'file', None)
        if not upload:
            message = f"No file uploaded in the '{field_name}' field."
            print(message)
            self.send_error(400, message)
            return None

        upload.seek(0)
        data = upload.read()
        if isinstance(data, str):
            # cgi buffers fields sent without a filename as text.
            data = data.encode('utf-8')

        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            tmp_filename = tmp_file.name
            print(f"{created_label}: {tmp_filename}")
            tmp_file.write(data)
        return tmp_filename

    def upload_snapshot(self, form):
        """Save the snapshot uploaded in the ``file`` form field.

        Args:
            form (cgi.FieldStorage): The form data containing the snapshot file.

        Returns:
            str | None: Path of the temporary file holding the snapshot, or
            ``None`` if a 400 response was sent instead.
        """
        return self._save_upload(form, 'file', "Temporary file created for restore")

    def upload_unseal_keys(self, form):
        """Save the unseal keys uploaded in the ``unseal_keys`` form field.

        Args:
            form (cgi.FieldStorage): The form data containing the unseal keys file.

        Returns:
            str | None: Path of the temporary file holding the unseal keys, or
            ``None`` if a 400 response was sent instead.
        """
        return self._save_upload(form, 'unseal_keys', "Temporary unseal_keys file created for restore")

    def restore_backup(self, snapshot_file) -> bool:
        """Force-restore the vault from a snapshot file.

        Args:
            snapshot_file (str): Path of the snapshot file to restore.

        Returns:
            bool: True on success. False if the restore failed, in which case
            a 500 response with Vault's stderr has already been sent.
        """
        print('+--------------------------+')
        print('| Restoring Vault Snapshot |')
        print('+--------------------------+')
        print(f"Restoring snapshot from: {snapshot_file}")

        cmd = ['vault', 'operator', 'raft', 'snapshot', 'restore', '--force', snapshot_file]
        process = subprocess.run(cmd, capture_output=True, text=True)

        if process.returncode != 0:
            error_message = f"Error restoring snapshot: {process.stderr}"
            print(error_message.strip())
            self._send_text_error(500, error_message)
            return False

        print('Snapshot restored successfully')
        return True

    def is_vault_unsealed(self) -> bool:
        """Report whether the vault is currently unsealed.

        Returns:
            bool: True when ``vault status`` reports ``"sealed": false``.

        Raises:
            RuntimeError: If ``vault status`` exits with a code other than 0
                or 2, or its output has no parsable ``sealed`` field.
        """
        # `vault status` exits with 2 when the vault is sealed, so a non-zero
        # exit code alone is not an error; only codes other than 0/2 are.
        process = subprocess.run(['vault', 'status', '-format=json'], capture_output=True, text=True)
        if process.returncode not in (0, 2):
            raise RuntimeError('Failed to get the status of the vault')

        print(process.stdout.strip())
        try:
            sealed = json.loads(process.stdout)['sealed']
        except (ValueError, KeyError, TypeError) as err:
            raise RuntimeError('Failed to parse the status of the vault') from err

        print(f'Is Sealed: {sealed}')
        return not sealed

    def unseal_vault(self, unseal_keys):
        """Unseal a restored vault by submitting key shares one at a time.

        Args:
            unseal_keys (Iterable[str]): Unseal key shares. Surrounding
                whitespace is stripped and blank entries are skipped.

        Returns:
            bool: True if the vault ends up unsealed, False otherwise.

        Raises:
            RuntimeError: Propagated from ``is_vault_unsealed``.
        """
        print('+--------------------------+')
        print('| Unsealing restored Vault |')
        print('+--------------------------+')

        # Skip blank lines (e.g. a trailing newline in the keys file) instead
        # of handing the CLI an empty key.
        keys = [key.strip() for key in unseal_keys if key.strip()]
        if not keys:
            # Nothing to submit: just report the current state.
            return self.is_vault_unsealed()

        for key in keys:
            # NOTE(review): the key is passed on the command line, so it may be
            # visible to other local users (e.g. via ps) while this runs.
            process = subprocess.run(['vault', 'operator', 'unseal', key], capture_output=True, text=True)
            print(process.stdout.strip())
            if process.returncode != 0:
                print(process.stderr.strip())

            if self.is_vault_unsealed():
                print('Vault is unsealed')
                return True

        return False

    def do_POST(self):
        """Handle ``POST /restore``: restore an uploaded snapshot, then unseal.

        Responds 400 for a malformed request, 500 if the restore or unseal
        fails, and 200 with a JSON body on success. Other paths get a 404.
        """
        if self.path != '/restore':
            self.send_error(404, "Endpoint not found.")
            return

        content_type = self.headers.get('Content-Type', '')

        # Use the headers object's email.message.Message API rather than the
        # deprecated cgi.parse_header; a missing header reads as text/plain.
        if self.headers.get_content_type() != 'multipart/form-data':
            print(f"Invalid Content-Type: {content_type}")
            self.send_error(400, 'Content-Type must be multipart/form-data')
            return
        if not self.headers.get_param('boundary'):
            print(f"Missing multipart boundary in Content-Type: {content_type}")
            self.send_error(400, 'Content-Type is missing the multipart boundary')
            return

        # NOTE(review): the cgi module is deprecated (PEP 594) and removed in
        # Python 3.13; this parsing will need replacing on newer interpreters.
        form = cgi.FieldStorage(
            fp=self.rfile,
            headers=self.headers,
            environ={'REQUEST_METHOD': 'POST', 'CONTENT_TYPE': content_type},
        )

        # Each upload helper sends its own 400 and returns None on failure.
        snapshot_file = self.upload_snapshot(form)
        if snapshot_file is None:
            return
        unseal_keys_file = self.upload_unseal_keys(form)
        if unseal_keys_file is None:
            self._remove_files(snapshot_file)
            return

        try:
            if not self.restore_backup(snapshot_file):
                # restore_backup already sent the 500 response.
                return

            # utf-8-sig tolerates a byte-order mark written by some editors.
            with open(unseal_keys_file, 'r', encoding='utf-8-sig') as f:
                unseal_keys = f.read().splitlines()

            try:
                unsealed = self.unseal_vault(unseal_keys)
            except RuntimeError as err:
                print(f"Error unsealing vault: {err}")
                self._send_text_error(500, f"Error unsealing vault: {err}")
                return
            if not unsealed:
                print("Snapshot restored but the vault is still sealed")
                self._send_text_error(500, "Snapshot restored but the vault is still sealed")
                return

            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()
            self.wfile.write(b'{"status": "success", "message": "Snapshot restored successfully"}')
        finally:
            # Remove the temporary files regardless of success or failure.
            self._remove_files(snapshot_file, unseal_keys_file)

    def log_message(self, format, *args):
        """Silence BaseHTTPRequestHandler's default per-request console logging.

        This also silences ``log_error`` output, which the base class routes
        through this method.
        """
        pass
|
|
|
|
def run_server(port=8300, host=''):
    """Serve VaultSnapshotHandler on *host*:*port* until interrupted.

    Args:
        port (int | str): TCP port to listen on. Strings, such as a raw
            command-line argument, are converted with ``int()`` because the
            socket layer rejects a str port.
        host (str): Interface to bind; the default ``''`` binds all interfaces.

    Raises:
        ValueError: If *port* cannot be converted to an integer.
    """
    port = int(port)
    httpd = HTTPServer((host, port), VaultSnapshotHandler)
    print(f"Starting HTTP server on port {port}...")
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        print("\nServer interrupted by user, shutting down.")
    finally:
        httpd.server_close()
        print("Server stopped.")
|
|
|
|
if __name__ == '__main__':
    print('+--------------------------+')
    print('| Starting snapshot server |')
    print('+--------------------------+')

    # Command-line arguments are strings; the server needs an integer port.
    port = int(sys.argv[1]) if len(sys.argv) > 1 else 8300
    run_server(port)
|