90 lines
No EOL
3 KiB
Python
90 lines
No EOL
3 KiB
Python
"""
|
|
Network visualization utilities for the NetViz application.
|
|
"""
|
|
import networkx as nx
|
|
import matplotlib.pyplot as plt
|
|
import io
|
|
import base64
|
|
from typing import Dict, Any
|
|
|
|
from app.models.network import Network
|
|
|
|
|
|
def _build_network_graph(network: Network) -> nx.DiGraph:
    """Build a directed graph of a network's subnets, devices and firewall rules.

    Node ids are ``"subnet-<id>"`` and ``"device-<id>"``; every node added here
    carries a ``label`` and a ``type`` attribute. Device -> subnet membership
    edges carry no attributes. Firewall-rule edges (device -> device) carry a
    ``label`` and a ``color`` attribute.

    Args:
        network: The Network object whose ``subnets``, ``devices`` and
            ``firewall_rules`` are read.

    Returns:
        The populated directed graph.
    """
    graph = nx.DiGraph()

    for subnet in network.subnets:
        graph.add_node(f"subnet-{subnet.id}",
                       label=f"{subnet.name}\n{subnet.cidr}",
                       type="subnet")

    for device in network.devices:
        graph.add_node(f"device-{device.id}",
                       label=f"{device.name}\n{device.ip_address or ''}",
                       type="device")
        # Membership edge. If subnet_id does not match any subnet added above,
        # add_edge creates a bare node with no label/type (rendered gray).
        if device.subnet_id:
            graph.add_edge(f"device-{device.id}", f"subnet-{device.subnet_id}")

    # Index devices by IP once, instead of rescanning the device list per rule.
    devices_by_ip: Dict[Any, list] = {}
    for device in network.devices:
        devices_by_ip.setdefault(device.ip_address, []).append(device)

    for rule in network.firewall_rules:
        # NOTE: rule.source / rule.destination are compared literally with
        # device IPs; CIDRs, ranges or hostnames will not match any device.
        rule_label = f"{rule.protocol}/{rule.port_range}\n{rule.action}"
        is_allow = rule.action == "allow"
        for src in devices_by_ip.get(rule.source, ()):
            for dst in devices_by_ip.get(rule.destination, ()):
                u, v = f"device-{src.id}", f"device-{dst.id}"
                existing = graph.get_edge_data(u, v)
                if existing is None:
                    graph.add_edge(u, v, label=rule_label,
                                   color="green" if is_allow else "red")
                else:
                    # A DiGraph holds a single edge per ordered pair: merge the
                    # rule into it instead of silently overwriting earlier rules.
                    # The edge stays green only if every merged rule is "allow".
                    existing["label"] = f"{existing['label']}\n{rule_label}"
                    if not is_allow:
                        existing["color"] = "red"

    return graph


def _render_graph_png(graph: nx.DiGraph) -> bytes:
    """Draw *graph* with a spring layout and return the image as PNG bytes.

    Subnet nodes are sky blue, device nodes light green, untyped nodes light
    gray. Edges use their ``color`` attribute (default black), and any edge
    ``label`` attributes are drawn along the edges.
    """
    node_palette = {"subnet": "skyblue", "device": "lightgreen"}
    node_colors = [node_palette.get(attrs.get("type"), "lightgray")
                   for _, attrs in graph.nodes(data=True)]
    edge_colors = [attrs.get("color", "black")
                   for _, _, attrs in graph.edges(data=True)]
    node_labels = {node: attrs.get("label", node)
                   for node, attrs in graph.nodes(data=True)}
    edge_labels = {(u, v): attrs["label"]
                   for u, v, attrs in graph.edges(data=True) if "label" in attrs}

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        pos = nx.spring_layout(graph)
        nx.draw_networkx_nodes(graph, pos, node_color=node_colors,
                               node_size=500, ax=ax)
        nx.draw_networkx_edges(graph, pos, edge_color=edge_colors,
                               arrowstyle='->', arrowsize=15, ax=ax)
        nx.draw_networkx_labels(graph, pos, labels=node_labels,
                                font_size=10, ax=ax)
        if edge_labels:
            # label_pos < 0.5 keeps the labels of A->B and B->A apart.
            nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels,
                                         label_pos=0.3, font_size=8, ax=ax)
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed; close it even
        # when drawing fails so repeated calls cannot leak figures.
        plt.close(fig)

    return buf.getvalue()


def generate_network_diagram(network: Network) -> str:
    """
    Generate a network diagram for visualization.

    Subnets and devices become nodes. Each device is linked to its subnet.
    Firewall rules whose source and destination equal device IP addresses
    become colored, labelled edges (green for "allow", red otherwise).

    Args:
        network: The Network object to visualize

    Returns:
        A ``data:image/png;base64,...`` URI containing the PNG diagram,
        suitable for embedding directly in HTML.
    """
    png_bytes = _render_graph_png(_build_network_graph(network))
    image_base64 = base64.b64encode(png_bytes).decode('utf-8')
    return f"data:image/png;base64,{image_base64}"