123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182 |
- using Microsoft.AspNetCore.Hosting;
- using Microsoft.Extensions.Configuration;
- using Microsoft.Extensions.Hosting;
- using Microsoft.Extensions.Logging;
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Threading.Tasks;
- using Microsoft.ML;
- using Microsoft.ML.Data;
- namespace Ropin.IOT.MLService
- {
/// <summary>
/// Application entry point. On startup it trains a multiclass accident-level
/// classifier from a CSV file, saves and evaluates it, prints one sample
/// prediction, and then runs the web host.
/// </summary>
public class Program
{
    // File locations, relative to the process working directory.
    private static readonly string _dataPath = "./accident_data.csv";
    private static readonly string _modelPath = "./AccidentPredictionModel.zip";

    /// <summary>
    /// Trains, saves and evaluates the model, prints a sample prediction, then
    /// builds and runs the host. Blocks until the host shuts down.
    /// </summary>
    /// <param name="args">Command-line arguments, forwarded to the host builder.</param>
    public static void Main(string[] args)
    {
        Console.OutputEncoding = System.Text.Encoding.UTF8;
        Console.WriteLine("工厂事故预测系统启动...");

        // Fixed seed so the train/test split and training are repeatable.
        MLContext mlContext = new MLContext(seed: 0);

        // Load the data set.
        Console.WriteLine("正在加载数据...");
        IDataView dataView = mlContext.Data.LoadFromTextFile<AccidentData>(
            path: _dataPath,
            hasHeader: true,
            separatorChar: ',',
            allowQuoting: true
        );

        // Hold out 20% of the rows for evaluation.
        var dataSplit = mlContext.Data.TrainTestSplit(dataView, testFraction: 0.2);
        var trainingData = dataSplit.TrainSet;
        var testData = dataSplit.TestSet;

        // Data preparation and feature engineering.
        Console.WriteLine("正在处理数据和提取特征...");
        var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey(
                outputColumnName: "Label",          // new key-typed label column
                inputColumnName: "AccidentLevel"    // original string label
            )
            .Append(mlContext.Transforms.Categorical.OneHotEncoding(
                new[]
                {
                    new InputOutputColumnPair("CountryEncoded", "Country"),
                    new InputOutputColumnPair("IndustrySectorEncoded", "IndustrySector"),
                    new InputOutputColumnPair("CriticalRiskEncoded", "CriticalRisk"),
                    new InputOutputColumnPair("GenreEncoded", "Genre"),
                    new InputOutputColumnPair("EmployeeTypeEncoded", "EmployeeOrThirdParty")
                }))
            // Text featurization of the free-form description.
            .Append(mlContext.Transforms.Text.FeaturizeText("DescriptionFeaturized", "Description"))
            // Merge every feature column into a single vector.
            .Append(mlContext.Transforms.Concatenate("Features",
                "CountryEncoded", "IndustrySectorEncoded", "CriticalRiskEncoded",
                "GenreEncoded", "EmployeeTypeEncoded", "DescriptionFeaturized"));

        // Training algorithm: a multiclass classifier.
        Console.WriteLine("构建和训练模型...");
        var trainer = mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy();

        // Full training pipeline; the last step maps the predicted key back to
        // the original label value.
        var trainingPipeline = dataProcessPipeline.Append(trainer)
            .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

        // Train the model.
        var trainedModel = trainingPipeline.Fit(trainingData);
        Console.WriteLine("模型训练完成!");

        // Persist the model together with the input schema.
        mlContext.Model.Save(trainedModel, trainingData.Schema, _modelPath);
        Console.WriteLine($"模型已保存到: {_modelPath}");

        // Evaluate on the held-out rows.
        Console.WriteLine("评估模型性能...");
        var predictions = trainedModel.Transform(testData);
        var metrics = mlContext.MulticlassClassification.Evaluate(predictions);

        // Report the evaluation metrics.
        Console.WriteLine($"宏平均精确度: {metrics.MacroAccuracy:F2}");
        Console.WriteLine($"微平均精确度: {metrics.MicroAccuracy:F2}");
        Console.WriteLine($"对数损失: {metrics.LogLoss:F2}");
        Console.WriteLine($"混淆矩阵: \n{metrics.ConfusionMatrix.GetFormattedConfusionTable()}");

        // Demonstrate a single prediction, then start the web host.
        PredictSample(mlContext, trainedModel);
        CreateHostBuilder(args).Build().Run();
    }

    /// <summary>
    /// Creates the default host builder configured to use <c>Startup</c>.
    /// </summary>
    /// <param name="args">Command-line arguments passed to the default builder.</param>
    /// <returns>The configured host builder.</returns>
    public static IHostBuilder CreateHostBuilder(string[] args) =>
        Host.CreateDefaultBuilder(args)
            .ConfigureWebHostDefaults(webBuilder =>
            {
                webBuilder.UseStartup<Startup>();
            });

    /// <summary>
    /// Runs one hard-coded sample through the model and prints the predicted
    /// level and the score of every class.
    /// </summary>
    /// <param name="mlContext">ML.NET context used to create the prediction engine.</param>
    /// <param name="model">Trained model to predict with.</param>
    private static void PredictSample(MLContext mlContext, ITransformer model)
    {
        // Create the prediction engine.
        var predictionEngine = mlContext.Model.CreatePredictionEngine<AccidentData, AccidentPrediction>(model);

        // Build the sample input.
        var sampleAccident = new AccidentData
        {
            Country = "Country_01",
            IndustrySector = "Mining",
            Genre = "Male",
            EmployeeOrThirdParty = "Third Party",
            CriticalRisk = "Pressed",
            Description = "Worker operating drilling equipment without proper safety procedures, potential for hand injury."
        };

        // Predict.
        var prediction = predictionEngine.Predict(sampleAccident);
        Console.WriteLine("\n预测示例:");
        Console.WriteLine($"行业: {sampleAccident.IndustrySector}");
        Console.WriteLine($"风险类型: {sampleAccident.CriticalRisk}");
        Console.WriteLine($"事故描述: {sampleAccident.Description}");
        Console.WriteLine($"预测事故等级: {prediction.PredictedAccidentLevel}");

        // Print the score of every class. Labels come from the model's output
        // schema rather than a hard-coded list, because the slot order follows
        // the key mapping learned during training, not a fixed I..IV order.
        Console.WriteLine("各等级的概率分布:");
        string[] labels = GetScoreLabels(predictionEngine.OutputSchema);
        for (int i = 0; i < prediction.Score.Length; i++)
        {
            // Fall back to the slot index when the schema exposes no name for it.
            string label = i < labels.Length ? labels[i] : i.ToString();
            Console.WriteLine($"等级 {label}: {prediction.Score[i]:P2}");
        }

        // Console.ReadKey throws InvalidOperationException when stdin is
        // redirected (service, container, CI), so only pause when interactive.
        if (!Console.IsInputRedirected)
        {
            Console.ReadKey();
        }
    }

    /// <summary>
    /// Reads the slot names attached to the "Score" column of a schema.
    /// </summary>
    /// <param name="schema">Output schema of the prediction engine.</param>
    /// <returns>
    /// One name per score slot, or an empty array when the column is missing
    /// or carries no slot names.
    /// </returns>
    private static string[] GetScoreLabels(DataViewSchema schema)
    {
        DataViewSchema.Column? scoreColumn = schema.GetColumnOrNull("Score");
        if (scoreColumn == null || !scoreColumn.Value.HasSlotNames())
        {
            return Array.Empty<string>();
        }

        VBuffer<ReadOnlyMemory<char>> slotNames = default;
        scoreColumn.Value.GetSlotNames(ref slotNames);
        return slotNames.DenseValues().Select(name => name.ToString()).ToArray();
    }
}
/// <summary>
/// Input row for the accident model. Each property is bound to a zero-based
/// source column index through its <c>LoadColumn</c> attribute.
/// NOTE(review): the column meanings below are inferred from the property
/// names; confirm them against the data file's header.
/// </summary>
public class AccidentData
{
    // Column 0; presumably a row sequence number.
    [LoadColumn(0)] public float SeqNo { get; set; }

    // Column 1; kept as text, not parsed into a date type.
    [LoadColumn(1)] public string Date { get; set; }

    // Column 2.
    [LoadColumn(2)] public string Country { get; set; }

    // Column 3.
    [LoadColumn(3)] public string Local { get; set; }

    // Column 4.
    [LoadColumn(4)] public string IndustrySector { get; set; }

    // Column 5; the original string label.
    [LoadColumn(5)] public string AccidentLevel { get; set; }

    // Column 6.
    [LoadColumn(6)] public string PotentialAccidentLevel { get; set; }

    // Column 7.
    [LoadColumn(7)] public string Genre { get; set; }

    // Column 8.
    [LoadColumn(8)] public string EmployeeOrThirdParty { get; set; }

    // Column 9.
    [LoadColumn(9)] public string CriticalRisk { get; set; }

    // Column 10; free-form text.
    [LoadColumn(10)] public string Description { get; set; }
}
/// <summary>
/// Prediction output produced for one <c>AccidentData</c> input.
/// </summary>
public class AccidentPrediction
{
    // Bound to the model's "PredictedLabel" output column.
    [ColumnName("PredictedLabel")] public string PredictedAccidentLevel { get; set; }

    // Score values from the model; presumably one entry per class (confirm
    // the slot order against the trained model's output schema).
    public float[] Score { get; set; }
}
- }
|