import socket
import random
import math
import csv
import torch
import torch.nn as nn

HOST = '127.0.0.1'
PORT = 5006
WEIGHTS_PATH = "road_danger_mlp_weights.pth"
N_INPUTS = 6  # number of comma-separated features read from each message
LEAKY_SLOPE = 0.001  # slope applied to non-positive pre-activations in relu()


class SimpleMLP(nn.Module):
    """Parameter container for a 6 -> 32 -> 16 -> 1 fully connected network.

    No ``forward`` is defined: the class is only instantiated so that
    ``load_state_dict`` can check the checkpoint's keys and tensor shapes.
    Inference itself is done in plain Python by ``predict``.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(N_INPUTS, 32)
        self.layer2 = nn.Linear(32, 16)
        self.output = nn.Linear(16, 1)


def sigmoid(x):
    """Return the logistic function 1 / (1 + e**-x) without overflowing.

    The naive formula makes math.exp overflow for x below about -709. For
    negative x the algebraically equal form e**x / (1 + e**x) is used instead,
    so math.exp only ever receives a non-positive argument.
    """
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    z = math.exp(x)
    return z / (1 + z)


def relu(x, slope=LEAKY_SLOPE):
    """Leaky ReLU: return x if x > 0, else slope * x.

    NOTE(review): the 0.001 slope must match the activation used at training
    time; SimpleMLP has no forward() to check this against -- verify.
    """
    return x if x > 0 else slope * x


def dense_layer(values, weights, biases):
    """Return relu(row . values + bias) for each (weight row, bias) pair.

    Args:
        values: input activations (list of floats).
        weights: list of weight rows, each the same length as ``values``.
        biases: one bias per weight row.
    """
    return [
        relu(sum(w * v for w, v in zip(row, values)) + b)
        for row, b in zip(weights, biases)
    ]


def predict(features, params):
    """Run the MLP forward pass on a single feature vector.

    Args:
        features: list of floats, one per input of the first layer.
        params: dict as returned by ``load_weights``.

    Returns:
        The sigmoid of the output neuron, a float in [0, 1].

    Raises:
        ValueError: if len(features) does not match the first layer's width.
    """
    n_in = len(params["w1"][0])
    if len(features) != n_in:
        raise ValueError(f"expected {n_in} features, got {len(features)}")
    h1 = dense_layer(features, params["w1"], params["b1"])
    h2 = dense_layer(h1, params["w2"], params["b2"])
    logit = sum(w * h for w, h in zip(params["w_out"], h2)) + params["b_out"]
    return sigmoid(logit)


def load_weights(path=WEIGHTS_PATH):
    """Load the trained state dict from ``path`` and return it as plain lists.

    Returns:
        dict with keys ``w1``/``b1`` (layer1), ``w2``/``b2`` (layer2),
        ``w_out`` (the single output weight row) and ``b_out`` (float bias).
    """
    model = SimpleMLP()
    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # machine. NOTE: depending on the torch version, torch.load may unpickle
    # arbitrary objects -- only load trusted weight files.
    model.load_state_dict(torch.load(path, map_location="cpu"))
    state = model.state_dict()
    return {
        "w1": state['layer1.weight'].tolist(),
        "b1": state['layer1.bias'].tolist(),
        "w2": state['layer2.weight'].tolist(),
        "b2": state['layer2.bias'].tolist(),
        "w_out": state['output.weight'].tolist()[0],
        "b_out": float(state['output.bias'].tolist()[0]),
    }


def parse_features(message, n=N_INPUTS):
    """Parse the first ``n`` comma-separated floats from ``message``.

    Extra trailing fields are ignored. float() tolerates surrounding
    whitespace, so a trailing newline is fine.

    Raises:
        ValueError: fewer than ``n`` fields, or a field is not a number.
    """
    fields = message.split(',')
    if len(fields) < n:
        raise ValueError(
            f"expected {n} comma-separated values, got {len(fields)}"
        )
    return [float(f) for f in fields[:n]]


def serve(params, host=HOST, port=PORT):
    """Accept one client and answer each received message with a prediction.

    Each reply is ``"Echo: <probability>"``. Malformed messages are logged
    and skipped instead of crashing the server.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        # Allow an immediate restart without "Address already in use".
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((host, port))
        server.listen(1)
        print("Python server listening...")
        conn, addr = server.accept()
        with conn:
            print("Connected by", addr)
            while True:
                # NOTE(review): assumes each recv() returns exactly one whole
                # message. TCP may split or merge sends; if the client
                # terminates messages with '\n', buffer and split on it instead.
                data = conn.recv(1024)
                if not data:
                    break
                try:
                    message = data.decode('utf-8')
                    print("Received from Processing:", message)
                    outp = predict(parse_features(message), params)
                except ValueError as err:  # also covers UnicodeDecodeError
                    print("Ignoring malformed message:", err)
                    continue
                response = f"Echo: {outp}"
                conn.sendall(response.encode('utf-8'))


def main():
    """Load the weights, then serve predictions over TCP."""
    params = load_weights()
    print("model and weights loaded.")
    serve(params)


if __name__ == "__main__":
    main()