netviz/app/utils/visualization.py
2025-03-25 23:41:13 +01:00

90 lines
No EOL
3 KiB
Python

"""
Network visualization utilities for the NetViz application.
"""
import networkx as nx
import matplotlib.pyplot as plt
import io
import base64
from typing import Dict, Any
from app.models.network import Network
def generate_network_diagram(network: Network) -> str:
    """
    Generate a network diagram for visualization.

    Subnets and devices become nodes, every device with a subnet is linked
    to that subnet, and every firewall rule whose source and destination
    both equal a device IP address becomes a colored, labeled
    device-to-device edge.

    Args:
        network: The Network object to visualize. Its ``subnets``,
            ``devices`` and ``firewall_rules`` collections are read.

    Returns:
        A ``data:image/png;base64,...`` URI holding the rendered PNG, so the
        value can be used directly as the ``src`` of an HTML ``<img>``.
    """
    graph = _build_network_graph(network)
    png_bytes = _render_graph_png(graph)

    # Encode the image as base64 for embedding in HTML
    image_base64 = base64.b64encode(png_bytes).decode('utf-8')
    return f"data:image/png;base64,{image_base64}"


def _build_network_graph(network: Network) -> nx.DiGraph:
    """Build the directed graph of subnets, devices and firewall rules."""
    graph = nx.DiGraph()

    # Add nodes for subnets
    for subnet in network.subnets:
        graph.add_node(
            f"subnet-{subnet.id}",
            label=f"{subnet.name}\n{subnet.cidr}",
            type="subnet",
        )

    # Add nodes for devices and, in the same pass, index the devices by IP
    # address so firewall rules can be matched without rescanning the whole
    # device list once per rule.
    devices_by_ip = {}
    for device in network.devices:
        graph.add_node(
            f"device-{device.id}",
            label=f"{device.name}\n{device.ip_address or ''}",
            type="device",
        )
        # Connect devices to their subnets. If the subnet is not part of
        # network.subnets, networkx creates the node implicitly without
        # attributes; the renderer then falls back to a gray node labeled
        # with its id.
        if device.subnet_id:
            graph.add_edge(f"device-{device.id}", f"subnet-{device.subnet_id}")
        devices_by_ip.setdefault(device.ip_address, []).append(device)

    # Add firewall rules as edges.
    # For simplicity, source and destination are assumed to be device IPs and
    # are matched by exact equality; a real implementation would need to
    # resolve them (e.g. CIDR ranges) to actual devices.
    # NOTE: a DiGraph keeps a single edge per (source, target) pair, so when
    # several rules connect the same two devices the last rule wins.
    for rule in network.firewall_rules:
        rule_label = f"{rule.protocol}/{rule.port_range}\n{rule.action}"
        rule_color = "green" if rule.action == "allow" else "red"
        for src in devices_by_ip.get(rule.source, ()):
            for dst in devices_by_ip.get(rule.destination, ()):
                graph.add_edge(
                    f"device-{src.id}",
                    f"device-{dst.id}",
                    label=rule_label,
                    color=rule_color,
                )

    return graph


def _render_graph_png(graph: nx.DiGraph) -> bytes:
    """Draw the graph with matplotlib and return the PNG file contents."""
    node_size = 500

    # Set node colors based on type; untyped nodes are drawn in gray
    colors_by_type = {"subnet": "skyblue", "device": "lightgreen"}
    node_colors = [
        colors_by_type.get(data.get("type"), "lightgray")
        for _, data in graph.nodes(data=True)
    ]
    edge_colors = [
        data.get("color", "black") for _, _, data in graph.edges(data=True)
    ]
    node_labels = {
        node: data.get("label", node) for node, data in graph.nodes(data=True)
    }
    # Only firewall-rule edges carry a label; device-to-subnet links do not.
    edge_labels = {
        (src, dst): data["label"]
        for src, dst, data in graph.edges(data=True)
        if "label" in data
    }

    # Use an explicit figure/axes pair instead of pyplot's implicit "current
    # figure", and always close it so a failure while drawing or saving does
    # not leave the figure registered in pyplot.
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        pos = nx.spring_layout(graph)

        # Draw nodes
        nx.draw_networkx_nodes(
            graph, pos, node_color=node_colors, node_size=node_size, ax=ax
        )

        # Draw edges; node_size is passed so arrowheads stop at the border of
        # the node markers instead of being hidden underneath them.
        nx.draw_networkx_edges(
            graph,
            pos,
            edge_color=edge_colors,
            arrowstyle='->',
            arrowsize=15,
            node_size=node_size,
            ax=ax,
        )

        # Draw labels
        nx.draw_networkx_labels(
            graph, pos, labels=node_labels, font_size=10, ax=ax
        )
        if edge_labels:
            nx.draw_networkx_edge_labels(
                graph, pos, edge_labels=edge_labels, font_size=8, ax=ax
            )

        # Save the plot to a bytes buffer
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    finally:
        plt.close(fig)

    return buf.getvalue()