You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

363 lines
14 KiB

  1. using IllusionInjector.Updating.Backup;
  2. using IllusionInjector.Utilities;
  3. using Ionic.Zip;
  4. using SimpleJSON;
  5. using System;
  6. using System.Collections;
  7. using System.Collections.Generic;
  8. using System.Diagnostics;
  9. using System.IO;
  10. using System.Linq;
  11. using System.Security.Cryptography;
  12. using System.Text;
  13. using System.Text.RegularExpressions;
  14. using System.Threading;
  15. using System.Threading.Tasks;
  16. using UnityEngine;
  17. using UnityEngine.Networking;
  18. using Logger = IllusionInjector.Logging.Logger;
  19. namespace IllusionInjector.Updating.ModsaberML
  20. {
  21. class Updater : MonoBehaviour
  22. {
  23. public static Updater instance;
  24. public void Awake()
  25. {
  26. try
  27. {
  28. if (instance != null)
  29. Destroy(this);
  30. else
  31. {
  32. instance = this;
  33. CheckForUpdates();
  34. }
  35. }
  36. catch (Exception e)
  37. {
  38. Logger.log.Error(e);
  39. }
  40. }
  41. public void CheckForUpdates()
  42. {
  43. StartCoroutine(CheckForUpdatesCoroutine());
  44. }
  45. private struct UpdateStruct
  46. {
  47. public PluginManager.BSPluginMeta plugin;
  48. public ApiEndpoint.Mod externInfo;
  49. }
  50. IEnumerator CheckForUpdatesCoroutine()
  51. {
  52. Logger.log.Info("Checking for mod updates...");
  53. var toUpdate = new List<UpdateStruct>();
  54. var GameVersion = new Version(Application.version);
  55. foreach (var plugin in PluginManager.BSMetas)
  56. {
  57. var info = plugin.ModsaberInfo;
  58. if (info == null) continue;
  59. using (var request = UnityWebRequest.Get(ApiEndpoint.ApiBase + string.Format(ApiEndpoint.GetApprovedEndpoint, info.InternalName)))
  60. {
  61. yield return request.SendWebRequest();
  62. if (request.isNetworkError)
  63. {
  64. Logger.log.Error("Network error while trying to update mods");
  65. Logger.log.Error(request.error);
  66. continue;
  67. }
  68. if (request.isHttpError)
  69. {
  70. if (request.responseCode == 404)
  71. {
  72. Logger.log.Error($"Mod {plugin.Plugin.Name} not found under name {info.InternalName}");
  73. continue;
  74. }
  75. Logger.log.Error($"Server returned an error code while trying to update mod {plugin.Plugin.Name}");
  76. Logger.log.Error(request.error);
  77. continue;
  78. }
  79. var json = request.downloadHandler.text;
  80. JSONObject obj = null;
  81. try
  82. {
  83. obj = JSON.Parse(json).AsObject;
  84. }
  85. catch (InvalidCastException)
  86. {
  87. Logger.log.Error($"Parse error while trying to update mods");
  88. Logger.log.Error($"Response doesn't seem to be a JSON object");
  89. continue;
  90. }
  91. catch (Exception e)
  92. {
  93. Logger.log.Error($"Parse error while trying to update mods");
  94. Logger.log.Error(e);
  95. continue;
  96. }
  97. ApiEndpoint.Mod modRegistry;
  98. try
  99. {
  100. modRegistry = ApiEndpoint.Mod.DecodeJSON(obj);
  101. }
  102. catch (Exception e)
  103. {
  104. Logger.log.Error($"Parse error while trying to update mods");
  105. Logger.log.Error(e);
  106. continue;
  107. }
  108. Logger.log.Debug($"Found Modsaber.ML registration for {plugin.Plugin.Name} ({info.InternalName})");
  109. Logger.log.Debug($"Installed version: {info.CurrentVersion}; Latest version: {modRegistry.Version}");
  110. if (modRegistry.Version > info.CurrentVersion)
  111. {
  112. Logger.log.Debug($"{plugin.Plugin.Name} needs an update!");
  113. if (modRegistry.GameVersion == GameVersion)
  114. {
  115. Logger.log.Debug($"Queueing update...");
  116. toUpdate.Add(new UpdateStruct
  117. {
  118. plugin = plugin,
  119. externInfo = modRegistry
  120. });
  121. }
  122. else
  123. {
  124. Logger.log.Warn($"Update avaliable for {plugin.Plugin.Name}, but for a different Beat Saber version!");
  125. }
  126. }
  127. }
  128. }
  129. Logger.log.Info($"{toUpdate.Count} mods need updating");
  130. if (toUpdate.Count == 0) yield break;
  131. string tempDirectory = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName() + Path.GetRandomFileName());
  132. Directory.CreateDirectory(tempDirectory);
  133. foreach (var item in toUpdate)
  134. {
  135. StartCoroutine(UpdateModCoroutine(item, tempDirectory));
  136. }
  137. }
  138. class StreamDownloadHandler : DownloadHandlerScript
  139. {
  140. public MemoryStream Stream { get; set; }
  141. public StreamDownloadHandler(MemoryStream stream) : base()
  142. {
  143. Stream = stream;
  144. }
  145. protected override void ReceiveContentLength(int contentLength)
  146. {
  147. Stream.Capacity = contentLength;
  148. Logger.log.Debug($"Got content length: {contentLength}");
  149. }
  150. protected override void CompleteContent()
  151. {
  152. Logger.log.Debug("Download complete");
  153. }
  154. protected override bool ReceiveData(byte[] data, int dataLength)
  155. {
  156. if (data == null || data.Length < 1)
  157. {
  158. Logger.log.Debug("CustomWebRequest :: ReceiveData - received a null/empty buffer");
  159. return false;
  160. }
  161. Stream.Write(data, 0, dataLength);
  162. return true;
  163. }
  164. protected override byte[] GetData() { return null; }
  165. protected override float GetProgress()
  166. {
  167. return 0f;
  168. }
  169. public override string ToString()
  170. {
  171. return $"{base.ToString()} ({Stream?.ToString()})";
  172. }
  173. }
  174. private void ExtractPluginAsync(MemoryStream stream, UpdateStruct item, ApiEndpoint.Mod.PlatformFile fileInfo, string tempDirectory)
  175. {
  176. Logger.log.Debug($"Extracting ZIP file for {item.plugin.Plugin.Name}");
  177. var data = stream.GetBuffer();
  178. SHA1 sha = new SHA1CryptoServiceProvider();
  179. var hash = sha.ComputeHash(data);
  180. if (!LoneFunctions.UnsafeCompare(hash, fileInfo.Hash))
  181. throw new Exception("The hash for the file doesn't match what is defined");
  182. var newFiles = new List<FileInfo>();
  183. var backup = new BackupUnit(tempDirectory, $"backup-{item.plugin.ModsaberInfo.InternalName}");
  184. try
  185. {
  186. bool shouldDeleteOldFile = true;
  187. using (var zipFile = ZipFile.Read(stream))
  188. {
  189. Logger.log.Debug("Streams opened");
  190. foreach (var entry in zipFile)
  191. {
  192. if (entry.IsDirectory)
  193. {
  194. Logger.log.Debug($"Creating directory {entry.FileName}");
  195. Directory.CreateDirectory(Path.Combine(Environment.CurrentDirectory, entry.FileName));
  196. }
  197. else
  198. {
  199. using (var ostream = new MemoryStream((int)entry.UncompressedSize))
  200. {
  201. entry.Extract(ostream);
  202. ostream.Seek(0, SeekOrigin.Begin);
  203. sha = new SHA1CryptoServiceProvider();
  204. var fileHash = sha.ComputeHash(ostream);
  205. if (!LoneFunctions.UnsafeCompare(fileHash, fileInfo.FileHashes[entry.FileName]))
  206. throw new Exception("The hash for the file doesn't match what is defined");
  207. ostream.Seek(0, SeekOrigin.Begin);
  208. FileInfo targetFile = new FileInfo(Path.Combine(Environment.CurrentDirectory, entry.FileName));
  209. Directory.CreateDirectory(targetFile.DirectoryName);
  210. if (targetFile.FullName == item.plugin.Filename)
  211. shouldDeleteOldFile = false; // overwriting old file, no need to delete
  212. if (targetFile.Exists)
  213. backup.Add(targetFile);
  214. else
  215. newFiles.Add(targetFile);
  216. Logger.log.Debug($"Extracting file {targetFile.FullName}");
  217. var fstream = targetFile.Create();
  218. ostream.CopyTo(fstream);
  219. }
  220. }
  221. }
  222. }
  223. if (item.plugin.Plugin is SelfPlugin)
  224. { // currently updating self
  225. Process.Start(new ProcessStartInfo
  226. {
  227. FileName = item.plugin.Filename,
  228. Arguments = $"--waitfor={Process.GetCurrentProcess().Id} --nowait",
  229. UseShellExecute = false
  230. });
  231. }
  232. else if (shouldDeleteOldFile)
  233. File.Delete(item.plugin.Filename);
  234. }
  235. catch (Exception)
  236. { // something failed; restore
  237. foreach (var file in newFiles)
  238. file.Delete();
  239. backup.Restore();
  240. backup.Delete();
  241. throw;
  242. }
  243. backup.Delete();
  244. Logger.log.Debug("Downloader exited");
  245. }
  246. IEnumerator UpdateModCoroutine(UpdateStruct item, string tempDirectory)
  247. {
  248. Logger.log.Debug($"Steam avaliable: {SteamCheck.IsAvailable}");
  249. ApiEndpoint.Mod.PlatformFile platformFile;
  250. if (SteamCheck.IsAvailable || item.externInfo.OculusFile == null)
  251. platformFile = item.externInfo.SteamFile;
  252. else
  253. platformFile = item.externInfo.OculusFile;
  254. string url = platformFile.DownloadPath;
  255. Logger.log.Debug($"URL = {url}");
  256. const int MaxTries = 3;
  257. int maxTries = MaxTries;
  258. while (maxTries > 0)
  259. {
  260. if (maxTries-- != MaxTries)
  261. Logger.log.Info($"Re-trying download...");
  262. using (var stream = new MemoryStream())
  263. using (var request = UnityWebRequest.Get(url))
  264. using (var taskTokenSource = new CancellationTokenSource())
  265. {
  266. var dlh = new StreamDownloadHandler(stream);
  267. request.downloadHandler = dlh;
  268. Logger.log.Debug("Sending request");
  269. //Logger.log.Debug(request?.downloadHandler?.ToString() ?? "DLH==NULL");
  270. yield return request.SendWebRequest();
  271. Logger.log.Debug("Download finished");
  272. if (request.isNetworkError)
  273. {
  274. Logger.log.Error("Network error while trying to update mod");
  275. Logger.log.Error(request.error);
  276. taskTokenSource.Cancel();
  277. continue;
  278. }
  279. if (request.isHttpError)
  280. {
  281. Logger.log.Error($"Server returned an error code while trying to update mod");
  282. Logger.log.Error(request.error);
  283. taskTokenSource.Cancel();
  284. continue;
  285. }
  286. stream.Seek(0, SeekOrigin.Begin); // reset to beginning
  287. var downloadTask = Task.Run(() =>
  288. { // use slightly more multithreaded approach than coroutines
  289. ExtractPluginAsync(stream, item, platformFile, tempDirectory);
  290. }, taskTokenSource.Token);
  291. while (!(downloadTask.IsCompleted || downloadTask.IsCanceled || downloadTask.IsFaulted))
  292. yield return null; // pause coroutine until task is done
  293. if (downloadTask.IsFaulted)
  294. {
  295. Logger.log.Error($"Error downloading mod {item.plugin.Plugin.Name}");
  296. Logger.log.Error(downloadTask.Exception);
  297. continue;
  298. }
  299. break;
  300. }
  301. }
  302. if (maxTries == 0)
  303. Logger.log.Warn($"Plugin download failed {MaxTries} times, not re-trying");
  304. else
  305. Logger.log.Debug("Download complete");
  306. }
  307. }
  308. }