You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

330 lines
12 KiB

  1. using IllusionInjector.Utilities;
  2. using Ionic.Zip;
  3. using SimpleJSON;
  4. using System;
  5. using System.Collections;
  6. using System.Collections.Generic;
  7. using System.IO;
  8. using System.Linq;
  9. using System.Security.Cryptography;
  10. using System.Text;
  11. using System.Text.RegularExpressions;
  12. using System.Threading;
  13. using System.Threading.Tasks;
  14. using UnityEngine;
  15. using UnityEngine.Networking;
  16. using Logger = IllusionInjector.Logging.Logger;
  17. namespace IllusionInjector.Updating.ModsaberML
  18. {
  19. class Updater : MonoBehaviour
  20. {
  21. public static Updater instance;
  22. public void Awake()
  23. {
  24. try
  25. {
  26. if (instance != null)
  27. Destroy(this);
  28. else
  29. {
  30. instance = this;
  31. CheckForUpdates();
  32. }
  33. }
  34. catch (Exception e)
  35. {
  36. Logger.log.Error(e);
  37. }
  38. }
  39. public void CheckForUpdates()
  40. {
  41. StartCoroutine(CheckForUpdatesCoroutine());
  42. }
  43. private struct UpdateStruct
  44. {
  45. public PluginManager.BSPluginMeta plugin;
  46. public ApiEndpoint.Mod externInfo;
  47. }
  48. IEnumerator CheckForUpdatesCoroutine()
  49. {
  50. Logger.log.Info("Checking for mod updates...");
  51. var toUpdate = new List<UpdateStruct>();
  52. var modList = new List<ApiEndpoint.Mod>();
  53. using (var request = UnityWebRequest.Get(ApiEndpoint.ApiBase+ApiEndpoint.GetApprovedEndpoint))
  54. {
  55. yield return request.SendWebRequest();
  56. if (request.isNetworkError)
  57. {
  58. Logger.log.Error("Network error while trying to update mods");
  59. Logger.log.Error(request.error);
  60. yield break;
  61. }
  62. if (request.isHttpError)
  63. {
  64. Logger.log.Error($"Server returned an error code while trying to update mods");
  65. Logger.log.Error(request.error);
  66. }
  67. var json = request.downloadHandler.text;
  68. JSONObject obj = null;
  69. try
  70. {
  71. obj = JSON.Parse(json).AsObject;
  72. }
  73. catch (InvalidCastException)
  74. {
  75. Logger.log.Error($"Parse error while trying to update mods");
  76. Logger.log.Error($"Response doesn't seem to be a JSON object");
  77. yield break;
  78. }
  79. catch (Exception e)
  80. {
  81. Logger.log.Error($"Parse error while trying to update mods");
  82. Logger.log.Error(e);
  83. yield break;
  84. }
  85. foreach (var modObj in obj["mods"].AsArray.Children)
  86. {
  87. try
  88. {
  89. modList.Add(ApiEndpoint.Mod.DecodeJSON(modObj.AsObject));
  90. }
  91. catch (Exception e)
  92. {
  93. Logger.log.Error($"Parse error while trying to update mods");
  94. Logger.log.Error($"Response doesn't seem to be correctly formatted");
  95. Logger.log.Error(e);
  96. break;
  97. }
  98. }
  99. }
  100. var GameVersion = new Version(Application.version);
  101. foreach (var plugin in PluginManager.BSMetas)
  102. {
  103. var info = plugin.ModsaberInfo;
  104. var modRegistry = modList.FirstOrDefault(o => o.Name == info.InternalName);
  105. if (modRegistry != null)
  106. { // a.k.a we found it
  107. Logger.log.Debug($"Found Modsaber.ML registration for {plugin.Plugin.Name} ({info.InternalName})");
  108. Logger.log.Debug($"Installed version: {info.CurrentVersion}; Latest version: {modRegistry.Version}");
  109. if (modRegistry.Version > info.CurrentVersion)
  110. {
  111. Logger.log.Debug($"{plugin.Plugin.Name} needs an update!");
  112. if (modRegistry.GameVersion == GameVersion)
  113. {
  114. Logger.log.Debug($"Queueing update...");
  115. toUpdate.Add(new UpdateStruct
  116. {
  117. plugin = plugin,
  118. externInfo = modRegistry
  119. });
  120. }
  121. else
  122. {
  123. Logger.log.Warn($"Update avaliable for {plugin.Plugin.Name}, but for a different Beat Saber version!");
  124. }
  125. }
  126. }
  127. }
  128. Logger.log.Info($"{toUpdate.Count} mods need updating");
  129. if (toUpdate.Count == 0) yield break;
  130. string tempDirectory = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName() + Path.GetRandomFileName());
  131. Directory.CreateDirectory(tempDirectory);
  132. Logger.log.Debug($"Created temp download dirtectory {tempDirectory}");
  133. foreach (var item in toUpdate)
  134. {
  135. StartCoroutine(UpdateModCoroutine(tempDirectory, item));
  136. }
  137. }
  138. class StreamDownloadHandler : DownloadHandlerScript
  139. {
  140. public MemoryStream Stream { get; set; }
  141. public StreamDownloadHandler(MemoryStream stream) : base()
  142. {
  143. Stream = stream;
  144. }
  145. protected override void ReceiveContentLength(int contentLength)
  146. {
  147. Stream.Capacity = contentLength;
  148. Logger.log.Debug($"Got content length: {contentLength}");
  149. }
  150. protected override void CompleteContent()
  151. {
  152. Logger.log.Debug("Download complete");
  153. }
  154. protected override bool ReceiveData(byte[] data, int dataLength)
  155. {
  156. Logger.log.Debug("ReceiveData");
  157. if (data == null || data.Length < 1)
  158. {
  159. Logger.log.Debug("CustomWebRequest :: ReceiveData - received a null/empty buffer");
  160. return false;
  161. }
  162. Stream.Write(data, 0, dataLength);
  163. return true;
  164. }
  165. protected override byte[] GetData() { return null; }
  166. protected override float GetProgress()
  167. {
  168. return 0f;
  169. }
  170. public override string ToString()
  171. {
  172. return $"{base.ToString()} ({Stream?.ToString()})";
  173. }
  174. }
  175. private void ExtractPluginAsync(MemoryStream stream, UpdateStruct item, ApiEndpoint.Mod.PlatformFile fileInfo)
  176. {
  177. Logger.log.Debug($"Getting ZIP file for {item.plugin.Plugin.Name}");
  178. //var stream = await httpClient.GetStreamAsync(url);
  179. var data = stream.GetBuffer();
  180. SHA1 sha = new SHA1CryptoServiceProvider();
  181. var hash = sha.ComputeHash(data);
  182. if (!LoneFunctions.UnsafeCompare(hash, fileInfo.Hash))
  183. throw new Exception("The hash for the file doesn't match what is defined");
  184. using (var zipFile = ZipFile.Read(stream))
  185. {
  186. Logger.log.Debug("Streams opened");
  187. foreach (var entry in zipFile)
  188. {
  189. Logger.log.Debug(entry?.FileName ?? "NULL");
  190. if (entry.IsDirectory)
  191. {
  192. Logger.log.Debug($"Creating directory {entry.FileName}");
  193. Directory.CreateDirectory(Path.Combine(Environment.CurrentDirectory, entry.FileName));
  194. }
  195. else
  196. {
  197. using (var ostream = new MemoryStream((int)entry.UncompressedSize))
  198. {
  199. entry.Extract(ostream);
  200. ostream.Seek(0, SeekOrigin.Begin);
  201. sha = new SHA1CryptoServiceProvider();
  202. var fileHash = sha.ComputeHash(ostream);
  203. if (!LoneFunctions.UnsafeCompare(fileHash, fileInfo.FileHashes[entry.FileName]))
  204. throw new Exception("The hash for the file doesn't match what is defined");
  205. ostream.Seek(0, SeekOrigin.Begin);
  206. FileInfo targetFile = new FileInfo(Path.Combine(Environment.CurrentDirectory, entry.FileName));
  207. if (targetFile.Exists)
  208. {
  209. Logger.log.Debug($"Target file {targetFile.FullName} exists");
  210. }
  211. var fstream = targetFile.Create();
  212. ostream.CopyTo(fstream);
  213. Logger.log.Debug($"Wrote file {targetFile.FullName}");
  214. }
  215. }
  216. }
  217. }
  218. Logger.log.Debug("Downloader exited");
  219. }
  220. IEnumerator UpdateModCoroutine(string tempdir, UpdateStruct item)
  221. {
  222. ApiEndpoint.Mod.PlatformFile platformFile;
  223. if (SteamCheck.IsAvailable || item.externInfo.OculusFile == null)
  224. platformFile = item.externInfo.SteamFile;
  225. else
  226. platformFile = item.externInfo.OculusFile;
  227. string url = platformFile.DownloadPath;
  228. Logger.log.Debug($"URL = {url}");
  229. const int MaxTries = 3;
  230. int maxTries = MaxTries;
  231. while (maxTries > 0)
  232. {
  233. if (maxTries-- != MaxTries)
  234. Logger.log.Info($"Re-trying download...");
  235. using (var stream = new MemoryStream())
  236. using (var request = UnityWebRequest.Get(url))
  237. using (var taskTokenSource = new CancellationTokenSource())
  238. {
  239. var dlh = new StreamDownloadHandler(stream);
  240. request.downloadHandler = dlh;
  241. Logger.log.Debug("Sending request");
  242. //Logger.log.Debug(request?.downloadHandler?.ToString() ?? "DLH==NULL");
  243. yield return request.SendWebRequest();
  244. Logger.log.Debug("Download finished");
  245. if (request.isNetworkError)
  246. {
  247. Logger.log.Error("Network error while trying to update mod");
  248. Logger.log.Error(request.error);
  249. taskTokenSource.Cancel();
  250. continue;
  251. }
  252. if (request.isHttpError)
  253. {
  254. Logger.log.Error($"Server returned an error code while trying to update mod");
  255. Logger.log.Error(request.error);
  256. taskTokenSource.Cancel();
  257. continue;
  258. }
  259. stream.Seek(0, SeekOrigin.Begin); // reset to beginning
  260. var downloadTask = Task.Run(() =>
  261. { // use slightly more multithreaded approach than coroutines
  262. ExtractPluginAsync(stream, item, platformFile);
  263. }, taskTokenSource.Token);
  264. while (!(downloadTask.IsCompleted || downloadTask.IsCanceled || downloadTask.IsFaulted))
  265. yield return null; // pause coroutine until task is done
  266. if (downloadTask.IsFaulted)
  267. {
  268. Logger.log.Error($"Error downloading mod {item.plugin.Plugin.Name}");
  269. Logger.log.Error(downloadTask.Exception);
  270. continue;
  271. }
  272. break;
  273. }
  274. }
  275. if (maxTries == 0)
  276. Logger.log.Warn($"Plugin download failed {MaxTries} times, not re-trying");
  277. else
  278. Logger.log.Debug("Download complete");
  279. }
  280. }
  281. }