using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace Ropin.IOT.MLService
{
    /// <summary>
    /// Entry point: trains and evaluates a multiclass accident-level classifier with ML.NET,
    /// prints a sample prediction, and then starts the web host.
    /// </summary>
    public class Program
    {
        // File locations, relative to the process working directory.
        private static readonly string _dataPath = "./accident_data.csv";
        private static readonly string _modelPath = "./AccidentPredictionModel.zip";

        /// <summary>
        /// Loads the CSV data set, trains the model, saves it to disk, prints evaluation
        /// metrics and a sample prediction, and finally builds and runs the host.
        /// </summary>
        /// <param name="args">Command-line arguments, forwarded to the host builder.</param>
        public static void Main(string[] args)
        {
            Console.OutputEncoding = System.Text.Encoding.UTF8;
            Console.WriteLine("工厂事故预测系统启动...");

            // Fixed seed so the train/test split and training are repeatable.
            MLContext mlContext = new MLContext(seed: 0);

            // Load the data; the column layout comes from the LoadColumn attributes on AccidentData.
            Console.WriteLine("正在加载数据...");
            IDataView dataView = mlContext.Data.LoadFromTextFile<AccidentData>(
                path: _dataPath,
                hasHeader: true,
                separatorChar: ',',
                allowQuoting: true);

            // Hold out 20% of the rows for evaluation.
            var dataSplit = mlContext.Data.TrainTestSplit(dataView, testFraction: 0.2);
            var trainingData = dataSplit.TrainSet;
            var testData = dataSplit.TestSet;

            // Data preparation and feature engineering.
            Console.WriteLine("正在处理数据和提取特征...");
            var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey(
                    outputColumnName: "Label",          // new key-typed label column
                    inputColumnName: "AccidentLevel")   // original string label
                .Append(mlContext.Transforms.Categorical.OneHotEncoding(
                    new[]
                    {
                        new InputOutputColumnPair("CountryEncoded", "Country"),
                        new InputOutputColumnPair("IndustrySectorEncoded", "IndustrySector"),
                        new InputOutputColumnPair("CriticalRiskEncoded", "CriticalRisk"),
                        new InputOutputColumnPair("GenreEncoded", "Genre"),
                        new InputOutputColumnPair("EmployeeTypeEncoded", "EmployeeOrThirdParty")
                    }))
                // Text featurization of the free-form description.
                .Append(mlContext.Transforms.Text.FeaturizeText("DescriptionFeaturized", "Description"))
                // Merge every feature column into the single "Features" vector.
                .Append(mlContext.Transforms.Concatenate(
                    "Features",
                    "CountryEncoded",
                    "IndustrySectorEncoded",
                    "CriticalRiskEncoded",
                    "GenreEncoded",
                    "EmployeeTypeEncoded",
                    "DescriptionFeaturized"));

            // Training algorithm: a multiclass classifier.
            Console.WriteLine("构建和训练模型...");
            var trainer = mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy();

            // Full training pipeline; the final step maps the predicted key back to the original label value.
            var trainingPipeline = dataProcessPipeline
                .Append(trainer)
                .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

            // Train the model.
            var trainedModel = trainingPipeline.Fit(trainingData);
            Console.WriteLine("模型训练完成!");

            // Save the model together with the input schema.
            mlContext.Model.Save(trainedModel, trainingData.Schema, _modelPath);
            Console.WriteLine($"模型已保存到: {_modelPath}");

            // Evaluate on the held-out test set.
            Console.WriteLine("评估模型性能...");
            var predictions = trainedModel.Transform(testData);
            var metrics = mlContext.MulticlassClassification.Evaluate(predictions);

            // Print the evaluation metrics.
            Console.WriteLine($"宏平均精确度: {metrics.MacroAccuracy:F2}");
            Console.WriteLine($"微平均精确度: {metrics.MicroAccuracy:F2}");
            Console.WriteLine($"对数损失: {metrics.LogLoss:F2}");
            Console.WriteLine($"混淆矩阵: \n{metrics.ConfusionMatrix.GetFormattedConfusionTable()}");

            // Demonstrate a single prediction with the trained model.
            PredictSample(mlContext, trainedModel);

            CreateHostBuilder(args).Build().Run();
        }

        /// <summary>
        /// Creates the default host builder configured with web host defaults.
        /// </summary>
        /// <param name="args">Command-line arguments passed to <c>Host.CreateDefaultBuilder</c>.</param>
        /// <returns>The configured host builder.</returns>
        public static IHostBuilder CreateHostBuilder(string[] args) =>
            Host.CreateDefaultBuilder(args)
                .ConfigureWebHostDefaults(webBuilder =>
                {
                    // NOTE(review): Startup is not defined in this file; it is assumed to be the
                    // project's startup class declared elsewhere - confirm.
                    webBuilder.UseStartup<Startup>();
                });

        /// <summary>
        /// Runs one hard-coded sample through the model and prints the predicted accident
        /// level together with the score of every level.
        /// </summary>
        /// <param name="mlContext">ML.NET context used to create the prediction engine.</param>
        /// <param name="model">Trained model to predict with.</param>
        private static void PredictSample(MLContext mlContext, ITransformer model)
        {
            // The prediction engine is disposable, so scope it with a using statement.
            using (var predictionEngine = mlContext.Model.CreatePredictionEngine<AccidentData, AccidentPrediction>(model))
            {
                // Sample input.
                var sampleAccident = new AccidentData
                {
                    Country = "Country_01",
                    IndustrySector = "Mining",
                    Genre = "Male",
                    EmployeeOrThirdParty = "Third Party",
                    CriticalRisk = "Pressed",
                    Description = "Worker operating drilling equipment without proper safety procedures, potential for hand injury."
                };

                // Predict.
                var prediction = predictionEngine.Predict(sampleAccident);

                Console.WriteLine("\n预测示例:");
                Console.WriteLine($"行业: {sampleAccident.IndustrySector}");
                Console.WriteLine($"风险类型: {sampleAccident.CriticalRisk}");
                Console.WriteLine($"事故描述: {sampleAccident.Description}");
                Console.WriteLine($"预测事故等级: {prediction.PredictedAccidentLevel}");

                // Print the score of every level. Level names are read from the trained key
                // mapping instead of a hard-coded list, so their order follows the model.
                Console.WriteLine("各等级的概率分布:");
                string[] levelNames = GetLevelNames(predictionEngine.OutputSchema);
                for (int i = 0; i < prediction.Score.Length; i++)
                {
                    // Fall back to the raw index when no name is available for this slot.
                    string levelName = i < levelNames.Length ? levelNames[i] : i.ToString();
                    Console.WriteLine($"等级 {levelName}: {prediction.Score[i]:P2}");
                }
            }

            // Console.ReadKey throws InvalidOperationException when stdin is redirected,
            // so only pause when a real console is attached.
            if (!Console.IsInputRedirected)
            {
                Console.ReadKey();
            }
        }

        /// <summary>
        /// Reads the original label values from the key-value annotation of the "Label" column.
        /// </summary>
        /// <param name="schema">Output schema of the trained model.</param>
        /// <returns>
        /// The label values in key order, or an empty array when the column or its
        /// text key-value annotation is not present.
        /// </returns>
        private static string[] GetLevelNames(DataViewSchema schema)
        {
            DataViewSchema.Column? labelColumn = schema.GetColumnOrNull("Label");
            if (!labelColumn.HasValue || !labelColumn.Value.HasKeyValues(TextDataViewType.Instance))
            {
                return Array.Empty<string>();
            }

            VBuffer<ReadOnlyMemory<char>> keyValues = default;
            labelColumn.Value.GetKeyValues(ref keyValues);
            return keyValues.DenseValues().Select(value => value.ToString()).ToArray();
        }
    }

    /// <summary>
    /// One row of the accident data file; each property is bound to a column index via LoadColumn.
    /// </summary>
    public class AccidentData
    {
        [LoadColumn(0)]
        public float SeqNo { get; set; }

        [LoadColumn(1)]
        public string Date { get; set; }

        [LoadColumn(2)]
        public string Country { get; set; }

        [LoadColumn(3)]
        public string Local { get; set; }

        [LoadColumn(4)]
        public string IndustrySector { get; set; }

        // Original string label; the training pipeline maps it to the key-typed "Label" column.
        [LoadColumn(5)]
        public string AccidentLevel { get; set; }

        [LoadColumn(6)]
        public string PotentialAccidentLevel { get; set; }

        [LoadColumn(7)]
        public string Genre { get; set; }

        [LoadColumn(8)]
        public string EmployeeOrThirdParty { get; set; }

        [LoadColumn(9)]
        public string CriticalRisk { get; set; }

        [LoadColumn(10)]
        public string Description { get; set; }
    }

    /// <summary>
    /// Prediction output: the predicted label mapped back to its original value, plus the per-class scores.
    /// </summary>
    public class AccidentPrediction
    {
        [ColumnName("PredictedLabel")]
        public string PredictedAccidentLevel { get; set; }

        public float[] Score { get; set; }
    }
}