This commit is contained in:
pika 2025-03-25 23:41:13 +01:00
commit 66f9ce3614
33 changed files with 2271 additions and 0 deletions

49
app/utils/commands.py Normal file
View file

@ -0,0 +1,49 @@
"""
CLI commands for the NetViz application.
"""
import click
from flask.cli import with_appcontext
from werkzeug.security import generate_password_hash
from app.extensions import db
from app.models.user import User
def register_commands(app):
"""
Register CLI commands with the Flask application.
Args:
app: The Flask application
"""
@app.cli.command("create-admin")
@click.argument("username")
@click.argument("email")
@click.password_option()
@with_appcontext
def create_admin(username, email, password):
"""Create an admin user."""
user = User.query.filter_by(username=username).first()
if user:
click.echo(f"User {username} already exists.")
return
user = User(
username=username,
email=email,
is_admin=True
)
user.set_password(password)
db.session.add(user)
db.session.commit()
click.echo(f"Admin user {username} created successfully.")
@app.cli.command("init-db")
@with_appcontext
def init_db():
"""Initialize the database."""
db.create_all()
click.echo("Database initialized.")

56
app/utils/email.py Normal file
View file

@ -0,0 +1,56 @@
"""
Email utilities for the NetViz application.
"""
from flask import current_app, render_template
from flask_mail import Message
from threading import Thread
from app.extensions import mail
from app.models.user import User
def send_async_email(app, msg):
"""Send email asynchronously."""
with app.app_context():
mail.send(msg)
def send_email(subject, sender, recipients, text_body, html_body):
"""
Send an email.
Args:
subject: Email subject
sender: Sender email address
recipients: List of recipient email addresses
text_body: Plain text email body
html_body: HTML email body
"""
msg = Message(subject, sender=sender, recipients=recipients)
msg.body = text_body
msg.html = html_body
# Send email asynchronously to not block the request
Thread(
target=send_async_email,
args=(current_app._get_current_object(), msg)
).start()
def send_password_reset_email(user: User):
"""
Send a password reset email to a user.
Args:
user: The user requesting password reset
"""
token = user.generate_reset_token()
reset_url = f"{current_app.config['SERVER_NAME']}/auth/reset-password/{token}"
send_email(
subject="[NetViz] Reset Your Password",
sender=current_app.config['MAIL_DEFAULT_SENDER'],
recipients=[user.email],
text_body=render_template("email/reset_password.txt", user=user, reset_url=reset_url),
html_body=render_template("email/reset_password.html", user=user, reset_url=reset_url)
)

View file

@ -0,0 +1,32 @@
"""
Error handlers for the NetViz application.
"""
import traceback
from flask import render_template, current_app
def register_error_handlers(app):
"""
Register error handlers for the application.
Args:
app: The Flask application
"""
@app.errorhandler(403)
def forbidden_error(error):
return render_template('errors/403.html'), 403
@app.errorhandler(404)
def not_found_error(error):
return render_template('errors/404.html'), 404
@app.errorhandler(500)
def internal_error(error):
# Log the error
current_app.logger.error(f"Server Error: {error}")
current_app.logger.error(traceback.format_exc())
return render_template('errors/500.html'), 500
@app.errorhandler(429)
def ratelimit_error(error):
return render_template('errors/429.html'), 429

71
app/utils/security.py Normal file
View file

@ -0,0 +1,71 @@
"""
Security utilities for the NetViz application.
"""
from typing import Dict, Any
import secrets
import string
def get_secure_headers() -> Dict[str, Any]:
"""
Get secure headers configuration for Flask-Talisman.
Returns:
Dict with security header configuration
"""
return {
'content_security_policy': {
'default-src': "'self'",
'img-src': "'self' data:",
'style-src': "'self' 'unsafe-inline'", # Needed for Tailwind
'script-src': "'self' 'unsafe-inline'", # Needed for HTMX
'font-src': "'self'"
},
'force_https': False, # Set to True in production
'strict_transport_security': True,
'strict_transport_security_max_age': 31536000,
'strict_transport_security_include_subdomains': True,
'referrer_policy': 'strict-origin-when-cross-origin',
'frame_options': 'DENY',
'session_cookie_secure': False, # Set to True in production
'session_cookie_http_only': True
}
def generate_password() -> str:
"""
Generate a secure random password.
Returns:
A secure random password string
"""
alphabet = string.ascii_letters + string.digits + string.punctuation
password = ''.join(secrets.choice(alphabet) for _ in range(16))
return password
def sanitize_input(input_string: str) -> str:
"""
Sanitize user input to prevent XSS attacks.
Args:
input_string: The input string to sanitize
Returns:
Sanitized string
"""
# Replace problematic characters with HTML entities
replacements = {
'<': '&lt;',
'>': '&gt;',
'"': '&quot;',
"'": '&#x27;',
'/': '&#x2F;',
'\\': '&#x5C;',
'\n': '<br>',
}
for char, replacement in replacements.items():
input_string = input_string.replace(char, replacement)
return input_string

View file

@ -0,0 +1,90 @@
"""
Network visualization utilities for the NetViz application.
"""
import networkx as nx
import matplotlib.pyplot as plt
import io
import base64
from typing import Dict, Any
from app.models.network import Network
def generate_network_diagram(network: Network) -> str:
"""
Generate a network diagram for visualization.
Args:
network: The Network object to visualize
Returns:
Base64 encoded PNG image of the network diagram
"""
# Create a directed graph
G = nx.DiGraph()
# Add nodes for subnets
for subnet in network.subnets:
G.add_node(f"subnet-{subnet.id}",
label=f"{subnet.name}\n{subnet.cidr}",
type="subnet")
# Add nodes for devices
for device in network.devices:
G.add_node(f"device-{device.id}",
label=f"{device.name}\n{device.ip_address or ''}",
type="device")
# Connect devices to their subnets
if device.subnet_id:
G.add_edge(f"device-{device.id}", f"subnet-{device.subnet_id}")
# Add firewall rules as edges
for rule in network.firewall_rules:
# For simplicity, we're assuming source and destination are device IPs
# In a real implementation, you'd need to resolve these to actual devices
source_devices = [d for d in network.devices if d.ip_address == rule.source]
dest_devices = [d for d in network.devices if d.ip_address == rule.destination]
for src in source_devices:
for dst in dest_devices:
G.add_edge(f"device-{src.id}", f"device-{dst.id}",
label=f"{rule.protocol}/{rule.port_range}\n{rule.action}",
color="green" if rule.action == "allow" else "red")
# Set node colors based on type
node_colors = []
for node in G.nodes():
node_type = G.nodes[node].get("type")
if node_type == "subnet":
node_colors.append("skyblue")
elif node_type == "device":
node_colors.append("lightgreen")
else:
node_colors.append("lightgray")
# Create the plot
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(G)
# Draw nodes
nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=500)
# Draw edges
edge_colors = [G.edges[edge].get("color", "black") for edge in G.edges()]
nx.draw_networkx_edges(G, pos, edge_color=edge_colors, arrowstyle='->', arrowsize=15)
# Draw labels
node_labels = {node: G.nodes[node].get("label", node) for node in G.nodes()}
nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10)
# Save the plot to a bytes buffer
buf = io.BytesIO()
plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
plt.close()
# Encode the image as base64 for embedding in HTML
buf.seek(0)
image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
return f"data:image/png;base64,{image_base64}"