Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce integrations with AIPI for automated content scanning #644

Merged
merged 4 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion Refresh.GameServer/Configuration/IntegrationConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace Refresh.GameServer.Configuration;
/// </summary>
public class IntegrationConfig : Config
{
public override int CurrentConfigVersion => 4;
public override int CurrentConfigVersion => 5;
public override int Version { get; set; }
protected override void Migrate(int oldVer, dynamic oldConfig)
{
Expand All @@ -31,11 +31,32 @@ protected override void Migrate(int oldVer, dynamic oldConfig)

public bool DiscordWebhookEnabled { get; set; }
public string DiscordWebhookUrl { get; set; } = "https://discord.com/api/webhooks/id/key";

public bool DiscordStaffWebhookEnabled { get; set; }
public string DiscordStaffWebhookUrl { get; set; } = "https://discord.com/api/webhooks/id/key";
public int DiscordWorkerFrequencySeconds { get; set; } = 60;
public string DiscordNickname { get; set; } = "Refresh";
public string DiscordAvatarUrl { get; set; } = "https://raw.githubusercontent.com/LittleBigRefresh/Branding/main/icons/refresh_512x.png";

#endregion

#region AIPI

public bool AipiEnabled { get; set; } = false;
public string AipiBaseUrl { get; set; } = "http://localhost:5000";

/// <summary>
/// The threshold at which tags are discarded during EVA2 prediction.
/// </summary>
public float AipiThreshold { get; set; } = 0.85f;

// in DO we store this statically, but this exposing this as a config option allows us to obscure which tags
// are being blocked, because refresh is FOSS and DT could probably just look at it.
public string[] AipiBannedTags { get; set; } = [];

public bool AipiRestrictAccountOnDetection { get; set; } = false;

#endregion

public string? GrafanaDashboardUrl { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
namespace Refresh.GameServer.Endpoints.ApiV3.ApiTypes.Errors;

public class ApiModerationError : ApiError
{
public static readonly ApiModerationError Instance = new();

public ApiModerationError() : base("This content was flagged as potentially unsafe, and administrators have been alerted. If you believe this is an error, please contact an administrator.", UnprocessableContent)
{
}
}
13 changes: 12 additions & 1 deletion Refresh.GameServer/Endpoints/ApiV3/ResourceApiEndpoints.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
using Refresh.GameServer.Endpoints.ApiV3.DataTypes.Response;
using Refresh.GameServer.Endpoints.ApiV3.DataTypes.Response.Data;
using Refresh.GameServer.Importing;
using Refresh.GameServer.Services;
using Refresh.GameServer.Types.Assets;
using Refresh.GameServer.Types.Data;
using Refresh.GameServer.Types.Roles;
Expand Down Expand Up @@ -158,7 +159,11 @@ public ApiResponse<ApiGameAssetResponse> UploadImageAsset(RequestContext context
IDataStore dataStore, AssetImporter importer, GameServerConfig config,
[DocSummary("The SHA1 hash of the asset")]
string hash,
byte[] body, GameUser user, DataContext dataContext)
byte[] body, GameUser user, DataContext dataContext,
AipiService? aipi,
DiscordStaffService? discord,
IntegrationConfig integration
)
{
// If we're blocking asset uploads, throw unless the user is an admin.
// We also have the ability to block asset uploads for trusted users (when they would normally bypass this)
Expand Down Expand Up @@ -197,6 +202,12 @@ public ApiResponse<ApiGameAssetResponse> UploadImageAsset(RequestContext context
return ApiInternalError.CouldNotWriteAssetError;

gameAsset.OriginalUploader = user;

if (aipi != null && aipi.ScanAndHandleAsset(dataContext, gameAsset))
{
return ApiModerationError.Instance;
}

database.AddAssetToDatabase(gameAsset);

return new ApiResponse<ApiGameAssetResponse>(ApiGameAssetResponse.FromOld(gameAsset, dataContext)!, Created);
Expand Down
14 changes: 13 additions & 1 deletion Refresh.GameServer/Endpoints/Game/PhotoEndpoints.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
using Bunkum.Protocols.Http;
using Refresh.GameServer.Database;
using Refresh.GameServer.Extensions;
using Refresh.GameServer.Services;
using Refresh.GameServer.Types.Assets;
using Refresh.GameServer.Types.Data;
using Refresh.GameServer.Types.Levels;
using Refresh.GameServer.Types.Lists;
Expand All @@ -19,7 +21,8 @@ public class PhotoEndpoints : EndpointGroup
{
[GameEndpoint("uploadPhoto", HttpMethods.Post, ContentType.Xml)]
[RequireEmailVerified]
public Response UploadPhoto(RequestContext context, SerializedPhoto body, GameDatabaseContext database, GameUser user, IDataStore dataStore)
public Response UploadPhoto(RequestContext context, SerializedPhoto body, GameDatabaseContext database, GameUser user, IDataStore dataStore,
DataContext dataContext, AipiService aipi)
{
if (!dataStore.ExistsInStore(body.SmallHash) ||
!dataStore.ExistsInStore(body.MediumHash) ||
Expand All @@ -36,6 +39,15 @@ public Response UploadPhoto(RequestContext context, SerializedPhoto body, GameDa
return BadRequest;
}

List<string> hashes = [body.LargeHash, body.MediumHash, body.SmallHash];
foreach (string hash in hashes.Distinct())
{
GameAsset? gameAsset = database.GetAssetFromHash(hash);
if(gameAsset == null) continue;
if (aipi != null && aipi.ScanAndHandleAsset(dataContext, gameAsset))
return Unauthorized;
}

database.UploadPhoto(body, user);

return OK;
Expand Down
1 change: 1 addition & 0 deletions Refresh.GameServer/RefreshContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ public enum RefreshContext
LevelListOverride,
CoolLevels,
Publishing,
Aipi,
}
4 changes: 4 additions & 0 deletions Refresh.GameServer/RefreshGameServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ protected override void SetupServices()
this.Server.AddService<RequestStatisticTrackingService>();
this.Server.AddService<LevelListOverrideService>();
this.Server.AddService<CommandService>();
this.Server.AddService<DiscordStaffService>();

if(this._integrationConfig!.AipiEnabled)
this.Server.AddService<AipiService>();

#if DEBUG
this.Server.AddService<DebugService>();
Expand Down
160 changes: 160 additions & 0 deletions Refresh.GameServer/Services/AipiService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
using System.Diagnostics;
using System.Net.Http.Json;
using Bunkum.Core.Services;
using JetBrains.Annotations;
using NotEnoughLogs;
using Refresh.GameServer.Configuration;
using Refresh.GameServer.Importing;
using Refresh.GameServer.Types.Assets;
using Refresh.GameServer.Types.Data;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.Formats;
using SixLabors.ImageSharp.Processing;

namespace Refresh.GameServer.Services;

// Referenced from DO.
public class AipiService : EndpointService
{
private readonly HttpClient _client;
private readonly IntegrationConfig _config;
private readonly DiscordStaffService? _discord;

private readonly ImageImporter _importer;

[UsedImplicitly]
public AipiService(Logger logger, IntegrationConfig config, ImportService import, DiscordStaffService discord) : base(logger)
{
this._discord = discord;
this._config = config;

this._client = new HttpClient
{
BaseAddress = new Uri(config.AipiBaseUrl),
};

this._importer = import.ImageImporter;
}

public override void Initialize()
{
if (!this._config.DiscordStaffWebhookEnabled)
{
this.Logger.LogWarning(RefreshContext.Aipi,
"The Discord staff webhook is not enabled, but AIPI is. This is probably behavior you don't want.");
}
this.TestConnectivityAsync().Wait();
}

private async Task TestConnectivityAsync()
{
try
{
HttpResponseMessage response = await this._client.GetAsync("/");
string content = await response.Content.ReadAsStringAsync();

if (response.IsSuccessStatusCode && content == "AIPI scanning service")
this.Logger.LogInfo(RefreshContext.Aipi, "AIPI appears to be working correctly");
else
this.Logger.LogError(RefreshContext.Aipi,
$"AIPI seems to be down. Status code: {response.StatusCode}, content: {content}");
}
catch (Exception e)
{
this.Logger.LogError(RefreshContext.Aipi, "AIPI connection failed: {0}", e.ToString());
}
}

private async Task<TData> PostAsync<TData>(string endpoint, Stream data)
{
HttpResponseMessage response = await this._client.PostAsync(endpoint, new StreamContent(data));
AipiResponse<TData>? aipiResponse = await response.Content.ReadFromJsonAsync<AipiResponse<TData>>();

if (aipiResponse == null) throw new Exception("No response was received from the server.");
if (!aipiResponse.Success) throw new Exception($"{response.StatusCode}: {aipiResponse.Reason}");
jvyden marked this conversation as resolved.
Show resolved Hide resolved

return aipiResponse.Data!;
}

private async Task<Dictionary<string, float>> PredictEvaAsync(Stream data)
{
Stopwatch stopwatch = new();
this.Logger.LogTrace(RefreshContext.Aipi, "Pre-processing image data...");

DecoderOptions options = new()
{
MaxFrames = 1,
Configuration = SixLabors.ImageSharp.Configuration.Default,
};

Image image = await Image.LoadAsync(options, data);
// Technically, we don't read videos in Refresh like in DO, but a couple of users are currently using APNGs as their avatar.
// I don't want to break APNGs as they're harmless, so let's handle this by just reading the first frame for now.
if (image.Frames.Count > 0)
image = image.Frames.CloneFrame(0);

image.Mutate(x => x.Resize(new ResizeOptions
{
Size = new Size(512),
Mode = ResizeMode.Max,
}));

using MemoryStream processedData = new();
await image.SaveAsPngAsync(processedData);
// await image.SaveAsPngAsync($"/tmp/{DateTimeOffset.Now.ToUnixTimeMilliseconds()}.png");
processedData.Seek(0, SeekOrigin.Begin);

float threshold = this._config.AipiThreshold;

this.Logger.LogDebug(RefreshContext.Aipi, $"Running prediction for image @ threshold={threshold}...");

stopwatch.Start();
Dictionary<string, float> prediction = await this.PostAsync<Dictionary<string, float>>($"/eva/predict?threshold={threshold}", processedData);
stopwatch.Stop();

this.Logger.LogInfo(RefreshContext.Aipi, $"Got prediction result in {stopwatch.ElapsedMilliseconds}ms.");
this.Logger.LogDebug(RefreshContext.Aipi, JsonConvert.SerializeObject(prediction));
return prediction;
}

public bool ScanAndHandleAsset(DataContext context, GameAsset asset)
{
// guard the fact that assets have an owner
Debug.Assert(asset.OriginalUploader != null, $"Asset {asset.AssetHash} had no original uploader when trying to scan");
if (asset.OriginalUploader == null)
return false;

// import the asset as png
bool isPspAsset = asset.AssetHash.StartsWith("psp/");

if (!context.DataStore.ExistsInStore("png/" + asset.AssetHash))
jvyden marked this conversation as resolved.
Show resolved Hide resolved
{
this._importer.ImportAsset(asset.AssetHash, isPspAsset, asset.AssetType, context.DataStore);
}

// do actual prediction
using Stream stream = context.DataStore.GetStreamFromStore("png/" + asset.AssetHash);
Dictionary<string, float> results = this.PredictEvaAsync(stream).Result;

if (!results.Any(r => this._config.AipiBannedTags.Contains(r.Key)))
return false;

this._discord?.PostPredictionResult(results, asset);

if (this._config.AipiRestrictAccountOnDetection)
{
const string reason = "Automatic restriction for posting disallowed content. This will usually be undone within 24 hours if this is a mistake.";
context.Database.RestrictUser(asset.OriginalUploader, reason, DateTimeOffset.MaxValue);
}

return true;
}

private class AipiResponse<TData>
{
public bool Success { get; set; }

public TData? Data { get; set; }
public string? Reason { get; set; }
}
}
75 changes: 75 additions & 0 deletions Refresh.GameServer/Services/DiscordStaffService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
using System.Diagnostics;
using Bunkum.Core.Services;
using Discord;
using Discord.Webhook;
using NotEnoughLogs;
using Refresh.GameServer.Configuration;
using Refresh.GameServer.Types.UserData;
using GameAsset = Refresh.GameServer.Types.Assets.GameAsset;

namespace Refresh.GameServer.Services;

public class DiscordStaffService : EndpointService
{
private readonly DiscordWebhookClient? _client;
private readonly IntegrationConfig _config;

private readonly string _externalUrl;

private const string NameSuffix = " (Staff)";

private const string DefaultResultsDescription = "These are the results of the AI's best guess at deciphering the contents of the image. " +
"Take them with a grain of salt as the AI isn't perfect.";

internal DiscordStaffService(Logger logger, GameServerConfig gameConfig, IntegrationConfig config) : base(logger)
{
this._config = config;
this._externalUrl = gameConfig.WebExternalUrl;

if(config.DiscordStaffWebhookEnabled)
this._client = new DiscordWebhookClient(config.DiscordStaffWebhookUrl);
}

private string GetAssetUrl(string hash)
{
return $"{this._externalUrl}/api/v3/assets/{hash}/image";
}

jvyden marked this conversation as resolved.
Show resolved Hide resolved
private string GetAssetInfoUrl(string hash)
{
return $"{this._externalUrl}/api/v3/assets/{hash}";
}

private void PostMessage(string? message = null, IEnumerable<Embed>? embeds = null!)
{
if (this._client == null)
return;

embeds ??= [];

ulong id = this._client.SendMessageAsync(embeds: embeds,
username: this._config.DiscordNickname + NameSuffix, avatarUrl: this._config.DiscordAvatarUrl).Result;

this.Logger.LogInfo(RefreshContext.Discord, $"Posted webhook {id}: '{message}'");
}

public void PostPredictionResult(Dictionary<string, float> results, GameAsset asset)
{
GameUser author = asset.OriginalUploader!;

EmbedBuilder builder = new EmbedBuilder()
.WithAuthor($"Image posted by @{author.Username} (id: {author.UserId})", this.GetAssetUrl(author.IconHash))
.WithDescription(DefaultResultsDescription)
.WithUrl(this.GetAssetInfoUrl(asset.AssetHash))
.WithTitle($"AI Analysis of `{asset.AssetHash}`");

foreach ((string tag, float confidence) in results.OrderByDescending(r => r.Value).Take(25))
{
string tagFormatted = this._config.AipiBannedTags.Contains(tag) ? $"{tag} (flagged!)" : tag;
string confidenceFormatted = confidence.ToString("0.00%");
builder.AddField(tagFormatted, confidenceFormatted, true);
}

this.PostMessage($"Prediction result for {asset.AssetHash} ({author.Username}):", [builder.Build()]);
}
}
Loading
Loading