Program.cs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. using Microsoft.AspNetCore.Hosting;
  2. using Microsoft.Extensions.Configuration;
  3. using Microsoft.Extensions.Hosting;
  4. using Microsoft.Extensions.Logging;
  5. using System;
  6. using System.Collections.Generic;
  7. using System.Linq;
  8. using System.Threading.Tasks;
  9. using Microsoft.ML;
  10. using Microsoft.ML.Data;
  11. using InfluxData.Net.Common.Enums;
  12. using InfluxData.Net.InfluxDb;
  13. using System.CodeDom.Compiler;
  14. using AdysTech.InfluxDB.Client.Net;
  15. using System.Drawing;
  16. using InfluxData.Net.InfluxDb.Models;
  17. namespace Ropin.IOT.MLService
  18. {
  /// <summary>
  /// Entry point of the ML service. At start-up it trains, saves and evaluates a multiclass model that
  /// predicts the accident level of a factory incident, then hands control to the ASP.NET Core host.
  /// </summary>
  public class Program
  {
      // Historical accident records used for training, and where the trained model is written.
      // Both paths are relative to the current working directory.
      private static readonly string _dataPath = "./accident_data.csv";
      private static readonly string _modelPath = "./AccidentPredictionModel.zip";

      /// <summary>
      /// Trains the accident-level model from <c>_dataPath</c>, saves it to <c>_modelPath</c>, prints
      /// evaluation metrics and one sample prediction, then runs the web host until it shuts down.
      /// </summary>
      /// <param name="args">Command-line arguments, forwarded to <see cref="CreateHostBuilder"/>.</param>
      public static void Main(string[] args)
      {
          Console.OutputEncoding = System.Text.Encoding.UTF8;
          Console.WriteLine("工厂事故预测系统启动...");

          // Fixed seed so the train/test split and the trainer behave the same on every run.
          MLContext mlContext = new MLContext(seed: 0);

          // Load the CSV: header row, comma separated, quoted fields allowed.
          Console.WriteLine("正在加载数据...");
          IDataView dataView = mlContext.Data.LoadFromTextFile<AccidentData>(
              path: _dataPath,
              hasHeader: true,
              separatorChar: ',',
              allowQuoting: true);

          // Hold out 20% of the rows for evaluation.
          var dataSplit = mlContext.Data.TrainTestSplit(dataView, testFraction: 0.2);
          var trainingData = dataSplit.TrainSet;
          var testData = dataSplit.TestSet;

          // Feature engineering.
          Console.WriteLine("正在处理数据和提取特征...");
          var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey(
                  outputColumnName: "Label",         // new key-typed label column consumed by the trainer
                  inputColumnName: "AccidentLevel")  // raw string label from the CSV
              // One-hot encode the categorical columns.
              .Append(mlContext.Transforms.Categorical.OneHotEncoding(new[]
              {
                  new InputOutputColumnPair("CountryEncoded", "Country"),
                  new InputOutputColumnPair("IndustrySectorEncoded", "IndustrySector"),
                  new InputOutputColumnPair("CriticalRiskEncoded", "CriticalRisk"),
                  new InputOutputColumnPair("GenreEncoded", "Genre"),
                  new InputOutputColumnPair("EmployeeTypeEncoded", "EmployeeOrThirdParty")
              }))
              // Turn the free-text accident description into a numeric feature vector.
              .Append(mlContext.Transforms.Text.FeaturizeText("DescriptionFeaturized", "Description"))
              // Merge every feature into the single "Features" column the trainer reads.
              .Append(mlContext.Transforms.Concatenate("Features",
                  "CountryEncoded", "IndustrySectorEncoded", "CriticalRiskEncoded",
                  "GenreEncoded", "EmployeeTypeEncoded", "DescriptionFeaturized"));

          // Multiclass classifier over the accident levels.
          Console.WriteLine("构建和训练模型...");
          var trainer = mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy();
          var trainingPipeline = dataProcessPipeline.Append(trainer)
              // Map the key-typed prediction back to the original accident-level string.
              .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

          var trainedModel = trainingPipeline.Fit(trainingData);
          Console.WriteLine("模型训练完成!");

          mlContext.Model.Save(trainedModel, trainingData.Schema, _modelPath);
          Console.WriteLine($"模型已保存到: {_modelPath}");

          // Evaluate on the held-out rows.
          Console.WriteLine("评估模型性能...");
          var predictions = trainedModel.Transform(testData);
          var metrics = mlContext.MulticlassClassification.Evaluate(predictions);
          Console.WriteLine($"宏平均精确度: {metrics.MacroAccuracy:F2}");
          Console.WriteLine($"微平均精确度: {metrics.MicroAccuracy:F2}");
          Console.WriteLine($"对数损失: {metrics.LogLoss:F2}");
          Console.WriteLine($"混淆矩阵: \n{metrics.ConfusionMatrix.GetFormattedConfusionTable()}");

          PredictSample(mlContext, trainedModel);
          CreateHostBuilder(args).Build().Run();
      }

      /// <summary>Builds the generic host with ASP.NET Core defaults and the <c>Startup</c> class.</summary>
      /// <param name="args">Command-line arguments forwarded to the default host configuration.</param>
      public static IHostBuilder CreateHostBuilder(string[] args) =>
          Host.CreateDefaultBuilder(args)
              .ConfigureWebHostDefaults(webBuilder => webBuilder.UseStartup<Startup>());

      #region InfluxDB experiment: train and predict from device readings

      // NOTE(review): server address and credentials are hard-coded in source; move them to
      // configuration or a secret store before deploying.
      private const string InfluxDbUrl = "http://60.204.212.71:8085/";
      private const string InfluxDbUser = "admin";
      private const string InfluxDbPassword = "123456";

      /// <summary>
      /// Experimental: queries the last hour of the <c>fanyidev</c> measurement in database
      /// <c>fanyidb</c>, trains a binary fault classifier on temperature and pressure, then prints a
      /// sample prediction and the test-set accuracy. Not called from <see cref="Main"/>.
      /// </summary>
      /// <remarks>
      /// NOTE(review): the query result is currently only used to check that data exists. The model
      /// is trained on two hard-coded placeholder rows, because the mapping of the returned columns
      /// (e.g. row[2] as Pressure) is not confirmed and the query has no temperature or fault label.
      /// With only two rows, the random 80/20 split may leave the train or test set without one of
      /// the classes, in which case training or evaluation can still fail.
      /// </remarks>
      /// <param name="args">Unused.</param>
      public static async Task GetData(string[] args)
      {
          var queries = new[] { " SELECT * FROM fanyidev WHERE time> now() - 1h " };
          var dbName = "fanyidb";
          InfluxDbClient influxDbClient = new InfluxDbClient(InfluxDbUrl, InfluxDbUser, InfluxDbPassword, InfluxDbVersion.Latest);

          var response = await influxDbClient.Client.QueryAsync(queries, dbName);
          if (response?.Any() != true)
          {
              Console.WriteLine("Failed to retrieve data from InfluxDB.");
              return;
          }

          // Placeholder training data (see remarks): one faulty and one healthy reading.
          var dataPoints = new List<DeviceStatusDataPoint>
          {
              new DeviceStatusDataPoint
              {
                  TimeStamp = DateTime.Now.AddMilliseconds(-1),
                  Temperature = 70,
                  Pressure = 150,
                  IsFaulty = true
              },
              new DeviceStatusDataPoint
              {
                  TimeStamp = DateTime.Now,
                  Temperature = 60,
                  Pressure = 140,
                  IsFaulty = false
              }
          };

          var mlContext = new MLContext();
          var dataView = mlContext.Data.LoadFromEnumerable(dataPoints);
          var trainTestSplit = mlContext.Data.TrainTestSplit(dataView, testFraction: 0.2);

          // SDCA trainers require a Single (float) feature vector, but Temperature and Pressure are
          // double. Convert them in place first; otherwise Fit fails with a schema mismatch.
          var pipeline = mlContext.Transforms.Conversion.ConvertType(
                  new[]
                  {
                      new InputOutputColumnPair(nameof(DeviceStatusDataPoint.Temperature)),
                      new InputOutputColumnPair(nameof(DeviceStatusDataPoint.Pressure))
                  },
                  DataKind.Single)
              .Append(mlContext.Transforms.Concatenate("Features",
                  nameof(DeviceStatusDataPoint.Temperature), nameof(DeviceStatusDataPoint.Pressure)))
              .Append(mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(
                  labelColumnName: nameof(DeviceStatusDataPoint.IsFaulty)));

          var model = pipeline.Fit(trainTestSplit.TrainSet);

          using var predictionEngine = mlContext.Model.CreatePredictionEngine<DeviceStatusDataPoint, FaultPrediction>(model);
          var sampleData = new DeviceStatusDataPoint { Temperature = 75.0, Pressure = 150.0 };
          var prediction = predictionEngine.Predict(sampleData);
          Console.WriteLine($"Predicted IsFaulty: {prediction.PredictedLabel}, Probability: {prediction.Probability}");

          // The label column is IsFaulty, not the evaluator's default "Label".
          var metrics = mlContext.BinaryClassification.Evaluate(
              model.Transform(trainTestSplit.TestSet),
              labelColumnName: nameof(DeviceStatusDataPoint.IsFaulty));
          Console.WriteLine($"Accuracy: {metrics.Accuracy}");
      }

      #endregion

      /// <summary>
      /// Runs one hand-written accident record through <paramref name="model"/> and prints the
      /// predicted accident level plus the score of every level.
      /// Waits for a key press afterwards when running in an interactive console.
      /// </summary>
      /// <param name="mlContext">Context used to create the prediction engine.</param>
      /// <param name="model">Trained accident-level pipeline.</param>
      private static void PredictSample(MLContext mlContext, ITransformer model)
      {
          using var predictionEngine = mlContext.Model.CreatePredictionEngine<AccidentData, AccidentPrediction>(model);

          var sampleAccident = new AccidentData
          {
              Country = "Country_01",
              IndustrySector = "Mining",
              Genre = "Male",
              EmployeeOrThirdParty = "Third Party",
              CriticalRisk = "Pressed",
              Description = "Worker operating drilling equipment without proper safety procedures, potential for hand injury."
          };

          var prediction = predictionEngine.Predict(sampleAccident);
          Console.WriteLine("\n预测示例:");
          Console.WriteLine($"行业: {sampleAccident.IndustrySector}");
          Console.WriteLine($"风险类型: {sampleAccident.CriticalRisk}");
          Console.WriteLine($"事故描述: {sampleAccident.Description}");
          Console.WriteLine($"预测事故等级: {prediction.PredictedAccidentLevel}");

          Console.WriteLine("各等级的概率分布:");
          // Score[i] belongs to the i-th key of the "Label" column. MapValueToKey orders keys by first
          // occurrence in the training data, not I, II, III, ... so read the level names from the
          // Score column's slot names, falling back to the slot index when they are unavailable.
          var levelNames = GetScoreSlotNames(predictionEngine.OutputSchema);
          for (int i = 0; i < prediction.Score.Length; i++)
          {
              string level = i < levelNames.Length ? levelNames[i] : i.ToString();
              Console.WriteLine($"等级 {level}: {prediction.Score[i]:P2}");
          }

          // Console.ReadKey throws InvalidOperationException when stdin is redirected (service,
          // container, CI); only pause when a real console is attached.
          if (!Console.IsInputRedirected)
          {
              Console.ReadKey();
          }
      }

      /// <summary>
      /// Returns the text "SlotNames" annotation of the "Score" column in <paramref name="schema"/>,
      /// or an empty array when the column or a text-vector annotation is not present.
      /// </summary>
      private static string[] GetScoreSlotNames(DataViewSchema schema)
      {
          var scoreColumn = schema.GetColumnOrNull("Score");
          if (!scoreColumn.HasValue)
          {
              return Array.Empty<string>();
          }

          var annotations = scoreColumn.Value.Annotations;
          var slotNamesColumn = annotations.Schema.GetColumnOrNull("SlotNames");
          if (!slotNamesColumn.HasValue
              || !(slotNamesColumn.Value.Type is VectorDataViewType vectorType)
              || !(vectorType.ItemType is TextDataViewType))
          {
              return Array.Empty<string>();
          }

          VBuffer<ReadOnlyMemory<char>> slotNames = default;
          annotations.GetValue("SlotNames", ref slotNames);
          return slotNames.DenseValues().Select(name => name.ToString()).ToArray();
      }
  }
  /// <summary>
  /// One row of the accident CSV loaded with <c>LoadFromTextFile</c>.
  /// Each <c>[LoadColumn(n)]</c> value is the zero-based position of that field in the file.
  /// Apart from the label, only Country, IndustrySector, Genre, EmployeeOrThirdParty,
  /// CriticalRisk and Description are consumed as features by the training pipeline in Program.Main.
  /// </summary>
  public class AccidentData
  {
      /// <summary>Column 0; presumably a row sequence number.</summary>
      [LoadColumn(0)] public float SeqNo { get; set; }

      /// <summary>Column 1; the date, kept as raw text.</summary>
      [LoadColumn(1)] public string Date { get; set; }

      /// <summary>Column 2; country, one-hot encoded during training.</summary>
      [LoadColumn(2)] public string Country { get; set; }

      /// <summary>Column 3; raw text, not used as a feature.</summary>
      [LoadColumn(3)] public string Local { get; set; }

      /// <summary>Column 4; industry sector, one-hot encoded during training.</summary>
      [LoadColumn(4)] public string IndustrySector { get; set; }

      /// <summary>Column 5; accident level as a raw string — the label mapped to the key-typed "Label" column.</summary>
      [LoadColumn(5)] public string AccidentLevel { get; set; }

      /// <summary>Column 6; raw text, not used as a feature.</summary>
      [LoadColumn(6)] public string PotentialAccidentLevel { get; set; }

      /// <summary>Column 7; one-hot encoded during training.</summary>
      [LoadColumn(7)] public string Genre { get; set; }

      /// <summary>Column 8; one-hot encoded during training.</summary>
      [LoadColumn(8)] public string EmployeeOrThirdParty { get; set; }

      /// <summary>Column 9; critical risk category, one-hot encoded during training.</summary>
      [LoadColumn(9)] public string CriticalRisk { get; set; }

      /// <summary>Column 10; free-text description, featurized with FeaturizeText.</summary>
      [LoadColumn(10)] public string Description { get; set; }
  }
  /// <summary>
  /// Output row of the accident-level model: the predicted level and the per-level scores.
  /// </summary>
  public class AccidentPrediction
  {
      /// <summary>The model's "PredictedLabel" column, mapped back to the original level string.</summary>
      [ColumnName("PredictedLabel")] public string PredictedAccidentLevel { get; set; }

      /// <summary>One score per accident level, indexed by the key order of the training label.</summary>
      public float[] Score { get; set; }
  }
  /// <summary>
  /// One device reading used by the InfluxDB fault-classification experiment.
  /// The <c>[LoadColumn]</c> indices only matter when loading from a text file; the experiment
  /// loads instances with <c>LoadFromEnumerable</c>.
  /// </summary>
  public class DeviceStatusDataPoint
  {
      /// <summary>Time of the reading.</summary>
      [LoadColumn(0)] public DateTime TimeStamp { get; set; }

      /// <summary>
      /// Temperature reading. NOTE(review): double-typed; ML.NET trainers expect Single features,
      /// so convert it before training.
      /// </summary>
      [LoadColumn(1)] public double Temperature { get; set; }

      /// <summary>Pressure reading (double; same conversion note as <see cref="Temperature"/>).</summary>
      [LoadColumn(2)] public double Pressure { get; set; }

      /// <summary>Binary label: whether the device was faulty at this reading.</summary>
      [LoadColumn(3)] public bool IsFaulty { get; set; }
  }
  /// <summary>
  /// Output row of the binary fault classifier.
  /// </summary>
  public class FaultPrediction
  {
      /// <summary>Predicted fault flag, read from the "PredictedLabel" column.</summary>
      [ColumnName("PredictedLabel")] public bool PredictedLabel { get; set; }

      /// <summary>Value of the "Probability" column produced for the prediction.</summary>
      [ColumnName("Probability")] public float Probability { get; set; }
  }
  289. }