import time
import subprocess
import pyshark
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager
from selenium.webdriver.chrome.options import Options
import numpy as np
import joblib
import pandas as pd
import scapy.all as scapy
import requests
import gradio as gr
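# Third-party dependencies assumed here (names as published on PyPI):
# pyshark, selenium, webdriver-manager, numpy, joblib, pandas, scapy,
# requests, and gradio — plus the tshark CLI, which pyshark and the
# capture step below both rely on.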
# Load the pre-trained model and the feature names it was trained on
model = joblib.load('extratrees.pkl')
all_features = joblib.load('featurenames.pkl')
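# Optional sanity check (an illustrative sketch, assuming a scikit-learn
# estimator saved with joblib): the model and the feature list should agree
# on dimensionality before any prediction is attempted.
# if hasattr(model, "n_features_in_"):
#     assert model.n_features_in_ == len(all_features)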
# Capture traffic with tshark while a headless browser visits the URL
def capture_packets(url, capture_duration=30, capture_file="capture.pcap"):
    tshark_process = None
    driver = None
    try:
        # Start tshark to capture HTTP/HTTPS/DNS packets
        tshark_process = subprocess.Popen(
            ["tshark", "-i", "any", "-f", "tcp port 80 or tcp port 443 or port 53", "-w", capture_file],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        # Give tshark a moment to start capturing
        time.sleep(2)

        # Set up Chrome options for a headless, container-friendly browser
        chrome_options = Options()
        chrome_options.add_argument("--headless")  # Run Chrome without a display
        chrome_options.add_argument("--no-sandbox")
        chrome_options.add_argument("--disable-dev-shm-usage")

        # Use Selenium to visit the URL
        service = Service(ChromeDriverManager().install())  # Ensure the driver is installed
        driver = webdriver.Chrome(service=service, options=chrome_options)
        driver.get(url)

        # Capture packets for the specified duration
        time.sleep(capture_duration)

        # Close the browser and stop tshark so the pcap is flushed to disk
        driver.quit()
        driver = None
        tshark_process.terminate()
        tshark_process.wait()
        tshark_process = None

        # Read the captured packets with pyshark for detailed packet information
        packets = []
        cap = pyshark.FileCapture(capture_file)
        for packet in cap:
            packets.append(str(packet))
        cap.close()
        return packets
    except Exception as e:
        print(f"Error in capturing packets: {e}")
        return None
    finally:
        # Clean up anything still running if an error interrupted the capture
        if driver is not None:
            driver.quit()
        if tshark_process is not None:
            tshark_process.terminate()
            tshark_process.wait()
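# A minimal environment check (an illustrative helper, not part of the
# original app): tshark must be installed and on PATH, and capturing on
# the "any" interface usually requires root or CAP_NET_RAW privileges.
import shutil

def tshark_available():
    """Return True if the tshark executable can be found on PATH."""
    return shutil.which("tshark") is not None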
# Extract flow-level features from the captured packets
def extract_features(capture_file):
    try:
        cap = pyshark.FileCapture(capture_file)

        # Initialize every model feature to 0
        features = {feature: 0 for feature in all_features}
        total_packets = 0
        total_bytes = 0
        start_time = None
        end_time = None
        packet_lengths = []
        protocol_counts = {'TCP': 0, 'UDP': 0, 'ICMP': 0}
        tcp_flags = {'SYN': 0, 'ACK': 0, 'FIN': 0, 'RST': 0}

        for packet in cap:
            total_packets += 1
            total_bytes += int(packet.length)
            packet_lengths.append(int(packet.length))
            timestamp = float(packet.sniff_time.timestamp())
            if start_time is None:
                start_time = timestamp
            end_time = timestamp
            # Count protocols and TCP flags; pyshark exposes tcp.flags as a
            # hex string (e.g. '0x0018'), so test the individual flag bits
            # rather than searching for substrings like 'SYN'
            if hasattr(packet, 'tcp'):
                protocol_counts['TCP'] += 1
                flags = int(packet.tcp.flags, 16)
                if flags & 0x02:  # SYN
                    tcp_flags['SYN'] += 1
                if flags & 0x10:  # ACK
                    tcp_flags['ACK'] += 1
                if flags & 0x01:  # FIN
                    tcp_flags['FIN'] += 1
                if flags & 0x04:  # RST
                    tcp_flags['RST'] += 1
            elif hasattr(packet, 'udp'):
                protocol_counts['UDP'] += 1
            elif hasattr(packet, 'icmp'):
                protocol_counts['ICMP'] += 1
        cap.close()
        duration = end_time - start_time if start_time is not None and end_time is not None else 0

        # Populate the extracted features
        features.update({
            "Flow Duration": duration,
            "Total Packets": total_packets,
            "Total Bytes": total_bytes,
            "Fwd Packet Length Mean": np.mean(packet_lengths) if packet_lengths else 0,
            "Bwd Packet Length Mean": 0,  # No forward/backward distinction is made here
            "Flow Bytes/s": total_bytes / duration if duration else 0,
            "Flow Packets/s": total_packets / duration if duration else 0,
            "Average Packet Size": np.mean(packet_lengths) if packet_lengths else 0,
            "Min Packet Size": min(packet_lengths) if packet_lengths else 0,
            "Max Packet Size": max(packet_lengths) if packet_lengths else 0,
            "Packet Length Variance": np.var(packet_lengths) if len(packet_lengths) > 1 else 0,
            "TCP Packets": protocol_counts['TCP'],
            "UDP Packets": protocol_counts['UDP'],
            "ICMP Packets": protocol_counts['ICMP'],
            "TCP SYN Flags": tcp_flags['SYN'],
            "TCP ACK Flags": tcp_flags['ACK'],
            "TCP FIN Flags": tcp_flags['FIN'],
            "TCP RST Flags": tcp_flags['RST']
        })
        return features
    except Exception as e:
        print(f"Error in extracting features: {e}")
        return None
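# Illustrative shape of the returned dict (the values below are made up):
# {"Flow Duration": 12.3, "Total Packets": 240, "Flow Bytes/s": 1850.0,
#  "TCP SYN Flags": 4, ...}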
# Compare the extracted features against the CIC-IDS-2017 feature set and predict
def compare_with_dataset(packet_features):
    # Align the extracted features with the feature order the model was trained on
    packet_features_series = pd.Series(packet_features)
    packet_features_series = packet_features_series.reindex(all_features, fill_value=0)
    # Predict with the loaded model; a single-row DataFrame preserves feature names
    prediction = model.predict(pd.DataFrame([packet_features_series]))[0]
    return "benign" if prediction == 0 else "malicious"
# Analyze the URL and predict whether its traffic looks malicious
def analyze_url(url):
    try:
        # Start an asynchronous Scapy sniffer before making the request so
        # the request's own traffic is actually captured
        sniffer = scapy.AsyncSniffer()
        sniffer.start()
        requests.get(url, timeout=10)
        time.sleep(5)  # let trailing packets arrive
        packets = sniffer.stop()

        capture_file = 'capture.pcap'
        scapy.wrpcap(capture_file, packets)

        # Extract features from the captured packets
        packet_features = extract_features(capture_file)
        if packet_features is not None:
            prediction = compare_with_dataset(packet_features)
            # Use pyshark + Selenium to capture HTTP/HTTPS/DNS packet details
            http_dns_packets = capture_packets(url)
            captured_packets = [str(packet) for packet in packets]
            return prediction, {"scapy_packets": captured_packets, "http_dns_packets": http_dns_packets}
        else:
            return "Error in feature extraction", {}
    except Exception as e:
        return str(e), {}
# Define the Gradio interface
iface = gr.Interface(
    fn=analyze_url,
    inputs=gr.Textbox(label="Enter URL"),
    outputs=[gr.Textbox(label="Prediction"), gr.JSON(label="Captured Packets")],
    title="URL Malicious Activity Detection",
    description="Enter a URL to predict if it's malicious or benign by analyzing the network traffic."
)

# Launch the interface
iface.launch(debug=True)
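# Note: live packet capture (scapy, tshark) generally needs elevated
# privileges, so this app assumes an environment with raw-socket access.
# On a locked-down host the capture steps will fail; the exception text is
# then returned by analyze_url and shown in the Prediction box.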