// Program.cs (6.8 KB)
  1. using Microsoft.AspNetCore.Hosting;
  2. using Microsoft.Extensions.Configuration;
  3. using Microsoft.Extensions.Hosting;
  4. using Microsoft.Extensions.Logging;
  5. using System;
  6. using System.Collections.Generic;
  7. using System.Linq;
  8. using System.Threading.Tasks;
  9. using Microsoft.ML;
  10. using Microsoft.ML.Data;
  11. namespace Ropin.IOT.MLService
  12. {
  /// <summary>
  /// Application entry point. On startup it trains, saves and evaluates the
  /// accident-level classifier, prints one sample prediction, and then starts
  /// the web host.
  /// </summary>
  public class Program
  {
      // Locations of the training CSV and of the serialized model (relative paths).
      private static readonly string _dataPath = "./accident_data.csv";
      private static readonly string _modelPath = "./AccidentPredictionModel.zip";

      /// <summary>
      /// Loads the data set, trains and saves the model, prints evaluation
      /// metrics and a sample prediction, then builds and runs the host.
      /// </summary>
      /// <param name="args">Command-line arguments, forwarded to the host builder.</param>
      public static void Main(string[] args)
      {
          Console.OutputEncoding = System.Text.Encoding.UTF8;
          Console.WriteLine("工厂事故预测系统启动...");

          // Create the ML.NET context with a fixed seed for its random number generator.
          MLContext mlContext = new MLContext(seed: 0);

          // Load the data set.
          Console.WriteLine("正在加载数据...");
          IDataView dataView = mlContext.Data.LoadFromTextFile<AccidentData>(
              path: _dataPath,
              hasHeader: true,
              separatorChar: ',',
              allowQuoting: true);

          // Split into a training set and a 20% test set.
          var dataSplit = mlContext.Data.TrainTestSplit(dataView, testFraction: 0.2);
          IDataView trainingData = dataSplit.TrainSet;
          IDataView testData = dataSplit.TestSet;

          IEstimator<ITransformer> trainingPipeline = BuildTrainingPipeline(mlContext);

          // Train the model.
          ITransformer trainedModel = trainingPipeline.Fit(trainingData);
          Console.WriteLine("模型训练完成!");

          // Save the model together with the training input schema.
          mlContext.Model.Save(trainedModel, trainingData.Schema, _modelPath);
          Console.WriteLine($"模型已保存到: {_modelPath}");

          EvaluateModel(mlContext, trainedModel, testData);

          // Run one example prediction with the freshly trained model.
          PredictSample(mlContext, trainedModel);

          CreateHostBuilder(args).Build().Run();
      }

      /// <summary>
      /// Creates the generic host builder configured with web-host defaults and
      /// the <c>Startup</c> class.
      /// </summary>
      /// <param name="args">Command-line arguments passed to the default builder.</param>
      /// <returns>The configured host builder.</returns>
      public static IHostBuilder CreateHostBuilder(string[] args) =>
          Host.CreateDefaultBuilder(args)
              .ConfigureWebHostDefaults(webBuilder =>
              {
                  webBuilder.UseStartup<Startup>();
              });

      /// <summary>
      /// Builds the feature-engineering steps followed by the multiclass trainer
      /// and the key-to-value mapping of the predicted label.
      /// </summary>
      /// <param name="mlContext">Context used to create the transforms and the trainer.</param>
      /// <returns>The untrained estimator chain.</returns>
      private static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext)
      {
          // Data processing and feature engineering.
          Console.WriteLine("正在处理数据和提取特征...");
          var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey(
                  outputColumnName: "Label",        // new key-typed label column
                  inputColumnName: "AccidentLevel") // original string label
              .Append(mlContext.Transforms.Categorical.OneHotEncoding(
                  new[]
                  {
                      new InputOutputColumnPair("CountryEncoded", "Country"),
                      new InputOutputColumnPair("IndustrySectorEncoded", "IndustrySector"),
                      new InputOutputColumnPair("CriticalRiskEncoded", "CriticalRisk"),
                      new InputOutputColumnPair("GenreEncoded", "Genre"),
                      new InputOutputColumnPair("EmployeeTypeEncoded", "EmployeeOrThirdParty")
                  }))
              // Text featurization of the free-text description.
              .Append(mlContext.Transforms.Text.FeaturizeText("DescriptionFeaturized", "Description"))
              // Concatenate every feature into a single vector column.
              .Append(mlContext.Transforms.Concatenate("Features",
                  "CountryEncoded", "IndustrySectorEncoded", "CriticalRiskEncoded",
                  "GenreEncoded", "EmployeeTypeEncoded", "DescriptionFeaturized"));

          // Choose the training algorithm: a multiclass classifier.
          Console.WriteLine("构建和训练模型...");
          var trainer = mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy();

          // Full pipeline; the predicted key is mapped back to the original label value.
          return dataProcessPipeline.Append(trainer)
              .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
      }

      /// <summary>
      /// Scores the test set with the trained model and prints the multiclass metrics.
      /// </summary>
      /// <param name="mlContext">Context providing the multiclass evaluator.</param>
      /// <param name="trainedModel">The fitted model.</param>
      /// <param name="testData">Held-out rows to evaluate on.</param>
      private static void EvaluateModel(MLContext mlContext, ITransformer trainedModel, IDataView testData)
      {
          Console.WriteLine("评估模型性能...");
          var predictions = trainedModel.Transform(testData);
          var metrics = mlContext.MulticlassClassification.Evaluate(predictions);

          // Print the evaluation metrics.
          Console.WriteLine($"宏平均精确度: {metrics.MacroAccuracy:F2}");
          Console.WriteLine($"微平均精确度: {metrics.MicroAccuracy:F2}");
          Console.WriteLine($"对数损失: {metrics.LogLoss:F2}");
          Console.WriteLine($"混淆矩阵: \n{metrics.ConfusionMatrix.GetFormattedConfusionTable()}");
      }

      /// <summary>
      /// Predicts the accident level of one hard-coded sample and prints the
      /// predicted level plus the score of every class.
      /// </summary>
      /// <param name="mlContext">Context used to create the prediction engine.</param>
      /// <param name="model">The trained model.</param>
      private static void PredictSample(MLContext mlContext, ITransformer model)
      {
          // The prediction engine is disposable; release it once the sample is scored.
          using (var predictionEngine = mlContext.Model.CreatePredictionEngine<AccidentData, AccidentPrediction>(model))
          {
              // Build the test sample.
              var sampleAccident = new AccidentData
              {
                  Country = "Country_01",
                  IndustrySector = "Mining",
                  Genre = "Male",
                  EmployeeOrThirdParty = "Third Party",
                  CriticalRisk = "Pressed",
                  Description = "Worker operating drilling equipment without proper safety procedures, potential for hand injury."
              };

              // Run the prediction.
              var prediction = predictionEngine.Predict(sampleAccident);
              Console.WriteLine("\n预测示例:");
              Console.WriteLine($"行业: {sampleAccident.IndustrySector}");
              Console.WriteLine($"风险类型: {sampleAccident.CriticalRisk}");
              Console.WriteLine($"事故描述: {sampleAccident.Description}");
              Console.WriteLine($"预测事故等级: {prediction.PredictedAccidentLevel}");

              // Print the score of every class, labelled with the class name the
              // model learned rather than a hard-coded list whose order and size
              // need not match the key order derived from the training data.
              Console.WriteLine("各等级的概率分布:");
              string[] labels = GetScoreLabels(predictionEngine.OutputSchema, prediction.Score.Length);
              for (int i = 0; i < prediction.Score.Length; i++)
              {
                  Console.WriteLine($"等级 {labels[i]}: {prediction.Score[i]:P2}");
              }
          }

          // Console.ReadKey throws InvalidOperationException when standard input is
          // redirected (service, container, CI), which would abort startup before
          // the web host runs; only pause when an interactive console is attached.
          if (!Console.IsInputRedirected)
          {
              Console.ReadKey();
          }
      }

      /// <summary>
      /// Resolves one display label per score slot from the slot names of the
      /// "Score" column, falling back to the slot index when no matching slot
      /// names are available.
      /// </summary>
      /// <param name="outputSchema">Output schema of the prediction engine.</param>
      /// <param name="scoreCount">Number of entries in the score vector.</param>
      /// <returns>An array of exactly <paramref name="scoreCount"/> labels.</returns>
      private static string[] GetScoreLabels(DataViewSchema outputSchema, int scoreCount)
      {
          DataViewSchema.Column? scoreColumn = outputSchema.GetColumnOrNull("Score");
          if (scoreColumn.HasValue && scoreColumn.Value.HasSlotNames())
          {
              var slotNames = default(VBuffer<ReadOnlyMemory<char>>);
              scoreColumn.Value.GetSlotNames(ref slotNames);
              string[] names = slotNames.DenseValues().Select(name => name.ToString()).ToArray();
              if (names.Length == scoreCount)
              {
                  return names;
              }
          }

          // No usable slot names: keep the original index-based output.
          return Enumerable.Range(0, scoreCount).Select(index => index.ToString()).ToArray();
      }
  }
  /// <summary>
  /// One input row of the accident data set. Each property is bound to the
  /// text-file column whose index is given in its <c>LoadColumn</c> attribute.
  /// </summary>
  public class AccidentData
  {
      /// <summary>Column 0.</summary>
      [LoadColumn(0)] public float SeqNo { get; set; }

      /// <summary>Column 1, kept as text.</summary>
      [LoadColumn(1)] public string Date { get; set; }

      /// <summary>Column 2.</summary>
      [LoadColumn(2)] public string Country { get; set; }

      /// <summary>Column 3.</summary>
      [LoadColumn(3)] public string Local { get; set; }

      /// <summary>Column 4.</summary>
      [LoadColumn(4)] public string IndustrySector { get; set; }

      /// <summary>Column 5: the original string label.</summary>
      [LoadColumn(5)] public string AccidentLevel { get; set; }

      /// <summary>Column 6.</summary>
      [LoadColumn(6)] public string PotentialAccidentLevel { get; set; }

      /// <summary>Column 7.</summary>
      [LoadColumn(7)] public string Genre { get; set; }

      /// <summary>Column 8.</summary>
      [LoadColumn(8)] public string EmployeeOrThirdParty { get; set; }

      /// <summary>Column 9.</summary>
      [LoadColumn(9)] public string CriticalRisk { get; set; }

      /// <summary>Column 10: free-text description.</summary>
      [LoadColumn(10)] public string Description { get; set; }
  }
  /// <summary>
  /// Prediction output: the predicted label and the score vector.
  /// </summary>
  public class AccidentPrediction
  {
      /// <summary>Value of the "PredictedLabel" output column.</summary>
      [ColumnName("PredictedLabel")] public string PredictedAccidentLevel { get; set; }

      /// <summary>Score vector produced by the model.</summary>
      public float[] Score { get; set; }
  }
  151. }