custom-hashicorp-vault/snapshot-server/server.py
2025-05-14 10:15:55 -05:00

288 lines
11 KiB
Python

import sys, subprocess, tempfile, os, cgi, json
from http.server import BaseHTTPRequestHandler, HTTPServer
class VaultSnapshotHandler(BaseHTTPRequestHandler):
    """HTTP request handler that backs up and restores a Vault raft cluster.

    Endpoints:
        GET  /snapshot  Run ``vault operator raft snapshot save`` and return the
                        snapshot as an ``application/octet-stream`` attachment.
        POST /restore   Accept ``multipart/form-data`` with a ``file`` field (the
                        snapshot) and an ``unseal_keys`` field (a text file with
                        one unseal key per line), restore the snapshot with
                        ``--force`` and then unseal the vault.

    All Vault interaction goes through the ``vault`` CLI, which must therefore
    be available on PATH.

    NOTE(review): neither endpoint performs any authentication; anyone who can
    reach the port can download or overwrite the vault's data.
    NOTE(review): ``cgi`` (used by ``do_POST``) is deprecated and was removed
    in Python 3.13; the upload parsing must be ported before upgrading.
    """

    def create_and_serve_snapshot(self, snapshot_file):
        """Create a Vault raft snapshot and return its contents.

        On failure a 500 response carrying the CLI's stderr is written to the
        client, the snapshot file is removed and ``None`` is returned; the
        caller must not send another response in that case.

        Args:
            snapshot_file (str): Path the snapshot is written to.

        Returns:
            bytes | None: The snapshot contents, or ``None`` if the ``vault``
            command failed (an error response has already been sent).
        """
        cmd = ['vault', 'operator', 'raft', 'snapshot', 'save', snapshot_file]
        process = subprocess.run(cmd, capture_output=True, text=True)
        if process.returncode != 0:
            self.send_response(500)
            self.send_header('Content-Type', 'text/plain')
            self.end_headers()
            error_message = f"Error generating snapshot: {process.stderr}"
            self.wfile.write(error_message.encode('utf-8'))
            # Clean up the temporary file if the command failed
            if os.path.exists(snapshot_file):
                os.remove(snapshot_file)
            return None
        print(f'Snapshot output: {process.stdout.strip()}')
        with open(snapshot_file, 'rb') as f:
            return f.read()

    def do_GET(self):
        """Serve ``GET /snapshot`` as a downloadable raft snapshot; 404 otherwise."""
        if self.path != '/snapshot':
            # If a path other than `/snapshot`
            self.send_error(404, "Not Found")
            return
        # delete=False: the file must outlive this `with` so the vault CLI can
        # write to it by path; it is removed in the `finally` below.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            snapshot_file = tmp_file.name
        print(f"Temporary file created for snapshot: {snapshot_file}")
        try:
            snapshot_data = self.create_and_serve_snapshot(snapshot_file)
            if snapshot_data is None:
                # A 500 response has already been written; sending a second
                # status line would corrupt the response.
                return
            # Send response headers to indicate a file attachment
            self.send_response(200)
            self.send_header('Content-Type', 'application/octet-stream')
            self.send_header('Content-Disposition', 'attachment; filename="raft_snapshot.snap"')
            self.send_header('Content-Length', str(len(snapshot_data)))
            self.end_headers()
            self.wfile.write(snapshot_data)
        finally:
            # Clean up the temporary file whether the command succeeded or failed
            if os.path.exists(snapshot_file):
                os.remove(snapshot_file)

    def _save_form_file(self, form, field_name, created_message):
        """Copy the uploaded file in ``form[field_name]`` to a temporary file.

        Args:
            form (cgi.FieldStorage): The parsed multipart form data.
            field_name (str): Name of the form field holding the file.
            created_message (str): Log prefix printed with the temp file path.

        Returns:
            str | None: Path to the temporary file, or ``None`` if the field is
            missing or carries no file (a 400 response has already been sent).
        """
        if field_name not in form:
            message = f"Missing '{field_name}' field in the form data."
            print(message)
            self.send_error(400, message)
            return None
        file_item = form[field_name]
        if not file_item.file:
            message = f"No file uploaded in the '{field_name}' field."
            print(message)
            self.send_error(400, message)
            return None
        # delete=False: the caller uses the file by path and removes it itself.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            tmp_filename = tmp_file.name
            print(f"{created_message}: {tmp_filename}")
            file_item.file.seek(0)
            tmp_file.write(file_item.file.read())
        return tmp_filename

    def upload_snapshot(self, form):
        """Save the uploaded snapshot (form field ``file``) to a temporary file.

        Args:
            form (cgi.FieldStorage): The form data containing the snapshot file

        Returns:
            str | None: The path to the temporary file containing the snapshot,
            or ``None`` if the field was missing/empty (400 already sent).
        """
        return self._save_form_file(form, 'file', "Temporary file created for restore")

    def upload_unseal_keys(self, form):
        """Save the uploaded unseal keys (form field ``unseal_keys``) to a temporary file.

        Args:
            form (cgi.FieldStorage): The form data containing the unseal keys file

        Returns:
            str | None: The path to the temporary file containing the unseal
            keys, or ``None`` if the field was missing/empty (400 already sent).
        """
        return self._save_form_file(form, 'unseal_keys', "Temporary unseal_keys file created for restore")

    def restore_backup(self, snapshot_file) -> bool:
        """Restore the vault from a snapshot.

        On failure a 500 response with the CLI's stderr is sent before
        returning ``False``.

        Args:
            snapshot_file (str): The path to the snapshot file to be restored

        Returns:
            bool: True if the restore was successful, False otherwise
        """
        print('+--------------------------+')
        print('| Restoring Vault Snapshot |')
        print('+--------------------------+')
        print(f"Restoring snapshot from: {snapshot_file}")
        cmd = ['vault', 'operator', 'raft', 'snapshot', 'restore', '--force', snapshot_file]
        process = subprocess.run(cmd, capture_output=True, text=True)
        if process.returncode != 0:
            self.send_response(500)
            self.send_header('Content-Type', 'text/plain')
            self.end_headers()
            error_message = f"Error restoring snapshot: {process.stderr}"
            print(error_message.strip())
            self.wfile.write(error_message.encode('utf-8'))
            return False
        print('Snapshot restored successfully')
        return True

    @staticmethod
    def _unsealed_from_status(status_json):
        """Return True if ``vault status -format=json`` output reports the vault unsealed."""
        return not json.loads(status_json)['sealed']

    def is_vault_unsealed(self) -> bool:
        """Check whether the vault is currently unsealed.

        Returns:
            bool: True if the vault is unsealed, False if it is sealed.

        Raises:
            RuntimeError: If ``vault status`` exits with anything other than
                0 (unsealed) or 2 (sealed).
        """
        # `vault status` exits non-zero (2) when the vault is sealed, so the
        # exit code is checked by hand instead of using check=True.
        process = subprocess.run(['vault', 'status', '-format=json'], capture_output=True, text=True)
        if process.returncode not in (0, 2):
            raise RuntimeError(f'Failed to get the status of the vault: {process.stderr.strip()}')
        status = process.stdout.strip()
        print(status)
        unsealed = self._unsealed_from_status(status)
        print(f'Is Sealed: {not unsealed}')
        return unsealed

    @staticmethod
    def _parse_unseal_keys(text):
        """Split unseal-key file contents into keys: one per non-blank line, whitespace stripped."""
        return [line.strip() for line in text.splitlines() if line.strip()]

    def unseal_vault(self, unseal_keys):
        """Unseal a restored vault by submitting keys until it reports unsealed.

        Args:
            unseal_keys (Iterable[str]): Unseal keys, tried in order.

        Returns:
            bool: True if the vault is unsealed when this returns.

        Raises:
            RuntimeError: Propagated from ``is_vault_unsealed``.
        """
        print('+--------------------------+')
        print('| Unsealing restored Vault |')
        print('+--------------------------+')
        for key in unseal_keys:
            # NOTE(review): the key is passed on the command line and may be
            # visible to other local users via the process list.
            process = subprocess.run(['vault', 'operator', 'unseal', key], capture_output=True, text=True)
            print(process.stdout.strip())
            if process.returncode != 0:
                print(process.stderr.strip())
            if self.is_vault_unsealed():
                print('Vault is unsealed')
                return True
        # Covers an empty key list (e.g. the vault never sealed) as well as
        # running out of keys.
        return self.is_vault_unsealed()

    def _send_json(self, code, status, message):
        """Send a complete JSON response of the form {"status": ..., "message": ...}."""
        body = json.dumps({'status': status, 'message': message}).encode('utf-8')
        self.send_response(code)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_POST(self):
        """Handle ``POST /restore``: restore an uploaded snapshot, then unseal.

        Responds 400 for a malformed request, 500 if the restore or unseal
        fails, 404 for any other path, and 200 with a JSON body on success.
        Temporary files are always removed.
        """
        if self.path != '/restore':
            # If a path other than `/restore` is requested
            self.send_error(404, "Endpoint not found.")
            return
        # Default to '' so a missing header yields a 400 instead of a crash.
        content_type = self.headers.get('Content-Type', '')
        ctype, _params = cgi.parse_header(content_type)
        if ctype != 'multipart/form-data':
            print(f"Invalid Content-Type: {content_type}")
            self.send_error(400, 'Content-Type must be multipart/form-data')
            return
        # FieldStorage takes the multipart boundary and Content-Length from
        # `headers`, so they need not be passed separately.
        try:
            form = cgi.FieldStorage(
                fp=self.rfile,
                headers=self.headers,
                environ={'REQUEST_METHOD': 'POST', 'CONTENT_TYPE': content_type},
            )
        except ValueError as err:
            # e.g. a missing or invalid multipart boundary
            print(f"Invalid multipart body: {err}")
            self.send_error(400, 'Malformed multipart/form-data body')
            return
        snapshot_file = None
        unseal_keys_file = None
        try:
            # Each upload helper sends its own 400 and returns None on bad
            # input; stop before touching the vault so nothing is restored
            # without the keys needed to unseal it.
            snapshot_file = self.upload_snapshot(form)
            if snapshot_file is None:
                return
            unseal_keys_file = self.upload_unseal_keys(form)
            if unseal_keys_file is None:
                return
            # restore_backup has already sent a 500 when it returns False.
            if not self.restore_backup(snapshot_file):
                return
            with open(unseal_keys_file, 'r') as f:
                unseal_keys = self._parse_unseal_keys(f.read())
            try:
                unsealed = self.unseal_vault(unseal_keys)
            except RuntimeError as err:
                print(err)
                self._send_json(500, 'error', f'Snapshot restored but vault status check failed: {err}')
                return
            if not unsealed:
                self._send_json(500, 'error', 'Snapshot restored but the vault is still sealed')
                return
            self._send_json(200, 'success', 'Snapshot restored successfully')
        finally:
            # Remove the temporary files regardless of success or failure
            for path in (snapshot_file, unseal_keys_file):
                if path is not None and os.path.exists(path):
                    os.remove(path)

    # Optionally override logging to avoid default console messages
    def log_message(self, format, *args):
        pass
def run_server(port=8300):
    """Run the snapshot HTTP server until interrupted.

    Binds to all interfaces ('') and serves VaultSnapshotHandler until
    Ctrl+C, then closes the listening socket.

    Args:
        port (int | str): TCP port to listen on. Strings (such as a raw
            ``sys.argv`` value) are converted with ``int()``; a non-numeric
            value raises ``ValueError``.
    """
    # socket.bind rejects a str port, so normalise it before building the server.
    port = int(port)
    server_address = ('', port)
    httpd = HTTPServer(server_address, VaultSnapshotHandler)
    print(f"Starting HTTP server on port {port}...")
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        print("\nServer interrupted by user, shutting down.")
    finally:
        httpd.server_close()
        print("Server stopped.")
if __name__ == '__main__':
    print('+--------------------------+')
    print('| Starting snapshot server |')
    print('+--------------------------+')
    # Command-line arguments are strings; the server needs an integer port.
    run_server(int(sys.argv[1]) if len(sys.argv) > 1 else 8300)