diff --git a/app/android/build.gradle.kts b/app/android/build.gradle.kts index 218e84834..37cb86c22 100755 --- a/app/android/build.gradle.kts +++ b/app/android/build.gradle.kts @@ -32,6 +32,7 @@ android { buildConfigField("String", "OPEN_AI_INFO_URL", "\"https://platform.openai.com/api-keys\"") buildConfigField("String", "STABILITY_AI_INFO_URL", "\"https://platform.stability.ai/\"") buildConfigField("String", "FAL_AI_INFO_URL", "\"https://fal.ai/dashboard/keys\"") + buildConfigField("String", "ARLI_AI_INFO_URL", "\"https://www.arliai.com/quick-start\"") buildConfigField("String", "UPDATE_API_URL", "\"https://sdai.moroz.cc\"") buildConfigField("String", "REPORT_API_URL", "\"https://sdai-report.moroz.cc\"") buildConfigField("String", "DEMO_MODE_API_URL", "\"https://sdai.moroz.cc\"") diff --git a/app/android/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt b/app/android/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt index cc2a3109c..0d87ef808 100755 --- a/app/android/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt +++ b/app/android/src/main/java/com/shifthackz/aisdv1/app/di/ProvidersModule.kt @@ -31,6 +31,7 @@ val providersModule = module { override val openAiInfoUrl: String = BuildConfig.OPEN_AI_INFO_URL override val stabilityAiInfoUrl: String = BuildConfig.STABILITY_AI_INFO_URL override val falAiInfoUrl: String = BuildConfig.FAL_AI_INFO_URL + override val arliAiInfoUrl: String = BuildConfig.ARLI_AI_INFO_URL override val privacyPolicyUrl: String = BuildConfig.POLICY_URL override val donateUrl: String = BuildConfig.DONATE_URL override val projectWebsiteUrl: String = BuildConfig.PROJECT_WEBSITE_URL diff --git a/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/links/DefaultLinksProvider.kt b/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/links/DefaultLinksProvider.kt index 11d3a731f..7f563a4f7 100644 --- a/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/links/DefaultLinksProvider.kt +++ b/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/links/DefaultLinksProvider.kt @@ -42,6 +42,12 @@ object DefaultLinksProvider : LinksProvider { * @author Dmitriy Moroz */ override val falAiInfoUrl: String = "https://fal.ai/dashboard/keys" + /** + * Exposes the `arliAiInfoUrl` value used by the SDAI core common layer. + * + * @author Dmitriy Moroz + */ + override val arliAiInfoUrl: String = "https://www.arliai.com/quick-start" /** * Exposes the `privacyPolicyUrl` value used by the SDAI core common layer. * diff --git a/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/links/LinksProvider.kt b/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/links/LinksProvider.kt index e17d9fe03..bb42a3ade 100644 --- a/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/links/LinksProvider.kt +++ b/core/common/src/commonMain/kotlin/com/shifthackz/aisdv1/core/common/links/LinksProvider.kt @@ -42,6 +42,12 @@ interface LinksProvider { * @author Dmitriy Moroz */ val falAiInfoUrl: String + /** + * Exposes the `arliAiInfoUrl` value used by the SDAI core common layer. + * + * @author Dmitriy Moroz + */ + val arliAiInfoUrl: String /** * Exposes the `privacyPolicyUrl` value used by the SDAI core common layer. * diff --git a/core/localization/src/androidMain/res/values-ru/strings.xml b/core/localization/src/androidMain/res/values-ru/strings.xml index af35ac695..de1229af5 100644 --- a/core/localization/src/androidMain/res/values-ru/strings.xml +++ b/core/localization/src/androidMain/res/values-ru/strings.xml @@ -217,6 +217,11 @@ Fal.ai предоставляет hosted API моделей генерации изображений с очередью выполнения. API ключи Fal.ai + Подключиться к ArliAI + ArliAI предоставляет облачный API генерации изображений, совместимый с SDNext. + Документация API ArliAI + Checkpoint модели ArliAI + Подключиться к Stability AI StabilityAI — это сервис генерации изображений от DreamStudio. О Stability AI diff --git a/core/localization/src/androidMain/res/values-tr/strings.xml b/core/localization/src/androidMain/res/values-tr/strings.xml index b7dffc754..0c533077a 100644 --- a/core/localization/src/androidMain/res/values-tr/strings.xml +++ b/core/localization/src/androidMain/res/values-tr/strings.xml @@ -217,6 +217,11 @@ Fal.ai, kuyruk tabanlı çıkarım kullanan barındırılan görüntü oluşturma modeli API\'leri sağlar. Fal.ai API anahtarları + ArliAI\'ye bağlanın + ArliAI, SDNext uyumlu bulut görüntü oluşturma API\'leri sağlar. + ArliAI API belgeleri + ArliAI model checkpoint + Stability AI\'ya bağlanın StabilityAI - DreamStudio tarafından sağlanan görüntü oluşturma hizmetidir. Hakkında Stability AI diff --git a/core/localization/src/androidMain/res/values-uk/strings.xml b/core/localization/src/androidMain/res/values-uk/strings.xml index ffe28907a..f51078a1e 100644 --- a/core/localization/src/androidMain/res/values-uk/strings.xml +++ b/core/localization/src/androidMain/res/values-uk/strings.xml @@ -217,6 +217,11 @@ Fal.ai надає hosted API моделей генерації зображень із чергою виконання. API ключі Fal.ai + Підключитися до ArliAI + ArliAI надає хмарний API генерації зображень, сумісний із SDNext. + Документація API ArliAI + Checkpoint моделі ArliAI + Підключитися до Stability AI StabilityAI — це сервіс генерації зображень від DreamStudio. Про Stability AI diff --git a/core/localization/src/androidMain/res/values-zh/strings.xml b/core/localization/src/androidMain/res/values-zh/strings.xml index c2b7b2e18..04c1bf04d 100644 --- a/core/localization/src/androidMain/res/values-zh/strings.xml +++ b/core/localization/src/androidMain/res/values-zh/strings.xml @@ -239,6 +239,11 @@ Fal.ai 提供托管的图像生成模型 API,并使用基于队列的推理。 Fal.ai API 密钥 + 连接到 ArliAI + ArliAI 提供兼容 SDNext 的云端图像生成 API。 + ArliAI API 文档 + ArliAI 模型 checkpoint + 连接到 Stability AI StabilityAI 是由 DreamStudio 提供的图像生成服务。 diff --git a/core/localization/src/androidMain/res/values/strings.xml b/core/localization/src/androidMain/res/values/strings.xml index 3027471a5..74124c672 100755 --- a/core/localization/src/androidMain/res/values/strings.xml +++ b/core/localization/src/androidMain/res/values/strings.xml @@ -122,6 +122,7 @@ OpenAI Stability AI Fal.ai + ArliAI Swarm UI Prompt @@ -239,6 +240,11 @@ Fal.ai provides hosted image generation model APIs with queue-based inference. Fal.ai API keys + Connect to ArliAI + ArliAI provides SDNext-compatible cloud image generation APIs. + ArliAI API docs + ArliAI model checkpoint + Connect to Stability AI StabilityAI is the image generation service provided by DreamStudio. About Stability AI diff --git a/data/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/data/mappers/KtorArliAiGenerationMappersTest.kt b/data/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/data/mappers/KtorArliAiGenerationMappersTest.kt new file mode 100644 index 000000000..838b81465 --- /dev/null +++ b/data/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/data/mappers/KtorArliAiGenerationMappersTest.kt @@ -0,0 +1,67 @@ +package com.shifthackz.aisdv1.data.mappers + +import com.shifthackz.aisdv1.data.mocks.mockImageToImagePayload +import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload +import com.shifthackz.aisdv1.domain.entity.ArliAiSampler +import org.junit.Assert +import org.junit.Test + +class KtorArliAiGenerationMappersTest { + + @Test + fun `given text payload, expected arli ai request contract`() { + val payload = mockTextToImagePayload.copy( + samplingSteps = 80, + seed = " 5598 ", + sampler = "", + batchCount = 3, + ) + + val request = payload.mapToArliAiRequest(MODEL) + + Assert.assertEquals(MODEL, request.sdModelCheckpoint) + Assert.assertEquals(payload.prompt, request.prompt) + Assert.assertEquals(payload.negativePrompt, request.negativePrompt) + Assert.assertEquals(40, request.steps) + Assert.assertEquals(ArliAiSampler.default.key, request.samplerName) + Assert.assertEquals(5598L, request.seed) + Assert.assertEquals(3, request.batchSize) + Assert.assertEquals(payload.restoreFaces, request.restoreFaces) + } + + @Test + fun `given image payload, expected arli ai image request contract`() { + val payload = mockImageToImagePayload.copy( + base64Image = "AQID", + base64MaskImage = "BAUG", + samplingSteps = 0, + seed = "-1", + sampler = "Euler a", + batchCount = 2, + maskBlur = 8, + inPaintingFill = 1, + inPaintFullRes = true, + inPaintFullResPadding = 32, + inPaintingMaskInvert = 1, + ) + + val request = payload.mapToArliAiRequest(MODEL) + + Assert.assertEquals(MODEL, request.sdModelCheckpoint) + Assert.assertEquals(listOf("AQID"), request.initImages) + Assert.assertEquals("BAUG", request.mask) + Assert.assertEquals(1, request.steps) + Assert.assertEquals("Euler a", request.samplerName) + Assert.assertEquals(-1L, request.seed) + Assert.assertEquals(2, request.batchSize) + Assert.assertEquals(8, request.maskBlur) + Assert.assertEquals(1, request.inPaintingFill) + Assert.assertTrue(request.inPaintFullRes == true) + Assert.assertEquals(32, request.inPaintFullResPadding) + Assert.assertEquals(1, request.inPaintingMaskInvert) + } + + private companion object { + const val MODEL = "Illustrious-XL-v2.0" + } +} diff --git a/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/di/DataCoreModule.kt b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/di/DataCoreModule.kt index 1a9c890d2..d668a3bfc 100644 --- a/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/di/DataCoreModule.kt +++ b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/di/DataCoreModule.kt @@ -2,6 +2,7 @@ package com.shifthackz.aisdv1.data.di import com.shifthackz.aisdv1.core.common.extensions.fixUrlSlashes import com.shifthackz.aisdv1.data.gateway.ServerConnectivityGatewayImpl +import com.shifthackz.aisdv1.data.local.ArliAiModelsLocalDataSource import com.shifthackz.aisdv1.data.local.DownloadableModelFileStore import com.shifthackz.aisdv1.data.local.DownloadableModelLocalDataSource import com.shifthackz.aisdv1.data.local.EmbeddingsLocalDataSource @@ -19,6 +20,8 @@ import com.shifthackz.aisdv1.data.provider.ServerUrlProvider import com.shifthackz.aisdv1.data.remote.DownloadableModelFileDownloader import com.shifthackz.aisdv1.data.remote.DownloadableModelRemoteDataSource import com.shifthackz.aisdv1.data.remote.HordeStatusSource +import com.shifthackz.aisdv1.data.remote.KtorArliAiGenerationRemoteDataSource +import com.shifthackz.aisdv1.data.remote.KtorArliAiModelsRemoteDataSource import com.shifthackz.aisdv1.data.remote.KtorForgeModulesRemoteDataSource import com.shifthackz.aisdv1.data.remote.KtorFalAiGenerationRemoteDataSource import com.shifthackz.aisdv1.data.remote.KtorHordeGenerationRemoteDataSource @@ -44,6 +47,8 @@ import com.shifthackz.aisdv1.data.remote.KtorSwarmUiModelsRemoteDataSource import com.shifthackz.aisdv1.data.remote.NoOpDownloadableModelFileDownloader import com.shifthackz.aisdv1.data.remote.RandomImageRemoteDataSource import com.shifthackz.aisdv1.data.remote.ReportRemoteDataSource +import com.shifthackz.aisdv1.data.repository.ArliAiGenerationRepositoryImpl +import com.shifthackz.aisdv1.data.repository.ArliAiModelsRepositoryImpl import com.shifthackz.aisdv1.data.repository.DownloadableModelRepositoryImpl import com.shifthackz.aisdv1.data.repository.EmbeddingsRepositoryImpl import com.shifthackz.aisdv1.data.repository.FalAiGenerationRepositoryImpl @@ -71,6 +76,8 @@ import com.shifthackz.aisdv1.data.repository.StableDiffusionScriptsRepositoryImp import com.shifthackz.aisdv1.data.repository.SupportersRepositoryImpl import com.shifthackz.aisdv1.data.repository.SwarmUiGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.TemporaryGenerationResultRepositoryImpl +import com.shifthackz.aisdv1.domain.datasource.ArliAiGenerationDataSource +import com.shifthackz.aisdv1.domain.datasource.ArliAiModelsDataSource import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.datasource.EmbeddingsDataSource import com.shifthackz.aisdv1.domain.datasource.FalAiGenerationDataSource @@ -107,6 +114,8 @@ import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway import com.shifthackz.aisdv1.domain.gateway.NoOpMediaStoreGateway import com.shifthackz.aisdv1.domain.gateway.ServerConnectivityGateway import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.ArliAiGenerationRepository +import com.shifthackz.aisdv1.domain.repository.ArliAiModelsRepository import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository import com.shifthackz.aisdv1.domain.repository.EmbeddingsRepository import com.shifthackz.aisdv1.domain.repository.FalAiGenerationRepository @@ -184,6 +193,9 @@ val coreDataModule = module { single { StableDiffusionModelsLocalDataSource(dao = get()) } + single { + ArliAiModelsLocalDataSource(dao = get()) + } single { StableDiffusionSamplersLocalDataSource(dao = get()) } @@ -337,6 +349,28 @@ val coreDataModule = module { remoteDataSource = get(), ) } + single { + KtorArliAiGenerationRemoteDataSource(api = get()) + } + single { + KtorArliAiModelsRemoteDataSource(api = get()) + } + single { + ArliAiModelsRepositoryImpl( + remoteDataSource = get(), + localDataSource = get(), + preferenceManager = get(), + ) + } + single { + ArliAiGenerationRepositoryImpl( + mediaStoreGateway = get(), + localDataSource = get(), + backgroundWorkObserver = get(), + preferenceManager = get(), + remoteDataSource = get(), + ) + } single { KtorStabilityAiGenerationRemoteDataSource(api = get()) } diff --git a/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/local/ArliAiModelsLocalDataSource.kt b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/local/ArliAiModelsLocalDataSource.kt new file mode 100644 index 000000000..b10c80c8b --- /dev/null +++ b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/local/ArliAiModelsLocalDataSource.kt @@ -0,0 +1,43 @@ +package com.shifthackz.aisdv1.data.local + +import com.shifthackz.aisdv1.data.mappers.mapArliAiEntityToDomain +import com.shifthackz.aisdv1.data.mappers.mapDomainToArliAiEntity +import com.shifthackz.aisdv1.domain.datasource.ArliAiModelsDataSource +import com.shifthackz.aisdv1.domain.entity.StableDiffusionModel +import com.shifthackz.aisdv1.storage.db.cache.dao.ArliAiModelDao +import com.shifthackz.aisdv1.storage.db.cache.entity.ArliAiModelEntity + +/** + * Reads and writes cached ArliAI checkpoint metadata. + * + * @param dao Room DAO for the ArliAI model cache table. + * + * @author Dmitriy Moroz + */ +internal class ArliAiModelsLocalDataSource( + private val dao: ArliAiModelDao, +) : ArliAiModelsDataSource.Local { + + /** + * Reads all cached ArliAI checkpoints. + * + * @return locally stored ArliAI checkpoints mapped into domain models. + * + * @author Dmitriy Moroz + */ + override suspend fun getModels(): List = dao + .queryAll() + .let(List::mapArliAiEntityToDomain) + + /** + * Replaces the cached ArliAI checkpoint list. + * + * @param models latest checkpoint metadata returned by the provider. + * + * @author Dmitriy Moroz + */ + override suspend fun insertModels(models: List) { + dao.deleteAll() + dao.insertList(models.mapDomainToArliAiEntity()) + } +} diff --git a/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/mappers/ArliAiModelsMappers.kt b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/mappers/ArliAiModelsMappers.kt new file mode 100644 index 000000000..d6556b510 --- /dev/null +++ b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/mappers/ArliAiModelsMappers.kt @@ -0,0 +1,66 @@ +package com.shifthackz.aisdv1.data.mappers + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionModel +import com.shifthackz.aisdv1.storage.db.cache.entity.ArliAiModelEntity + +/** + * Converts domain checkpoint metadata into distinct ArliAI cache rows. + * + * @return Room entities keyed by the best available ArliAI checkpoint name. + * + * @author Dmitriy Moroz + */ +fun List.mapDomainToArliAiEntity(): List = + distinctBy(StableDiffusionModel::arliAiCheckpointName) + .map(StableDiffusionModel::mapDomainToArliAiEntity) + +/** + * Converts one domain checkpoint into an ArliAI cache row. + * + * @author Dmitriy Moroz + */ +fun StableDiffusionModel.mapDomainToArliAiEntity(): ArliAiModelEntity = with(this) { + ArliAiModelEntity( + id = arliAiCheckpointName, + title = title, + name = modelName, + hash = hash, + sha256 = sha256, + filename = filename, + config = config, + ) +} + +/** + * Converts cached ArliAI model rows into domain checkpoint metadata. + * + * @return domain models shown by setup and generation screens. + * + * @author Dmitriy Moroz + */ +fun List.mapArliAiEntityToDomain(): List = + map(ArliAiModelEntity::mapArliAiEntityToDomain) + +/** + * Converts one cached ArliAI model row into domain checkpoint metadata. + * + * @author Dmitriy Moroz + */ +fun ArliAiModelEntity.mapArliAiEntityToDomain(): StableDiffusionModel = with(this) { + StableDiffusionModel( + title = title, + modelName = name, + hash = hash, + sha256 = sha256, + filename = filename, + config = config, + ) +} + +/** + * Returns the stable key used for ArliAI model cache de-duplication. + * + * @author Dmitriy Moroz + */ +private val StableDiffusionModel.arliAiCheckpointName: String + get() = title.ifBlank { modelName }.ifBlank { filename } diff --git a/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/mappers/KtorArliAiGenerationMappers.kt b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/mappers/KtorArliAiGenerationMappers.kt new file mode 100644 index 000000000..f3263c24a --- /dev/null +++ b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/mappers/KtorArliAiGenerationMappers.kt @@ -0,0 +1,115 @@ +package com.shifthackz.aisdv1.data.mappers + +import com.shifthackz.aisdv1.domain.entity.ADetailerConfig +import com.shifthackz.aisdv1.domain.entity.ArliAiSampler +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.network.request.ArliAiImageToImageRequest +import com.shifthackz.aisdv1.network.request.ArliAiTextToImageRequest + +/** + * Converts text-to-image settings into an ArliAI SDNext payload. + * + * Sampling steps are clamped to the ArliAI-supported range and ADetailer fields are sent only + * when the feature is enabled. + * + * @param model checkpoint name sent as `sd_model_checkpoint`. + * @return provider request payload ready for JSON serialization. + * + * @author Dmitriy Moroz + */ +fun TextToImagePayload.mapToArliAiRequest(model: String): ArliAiTextToImageRequest = with(this) { + ArliAiTextToImageRequest( + sdModelCheckpoint = model, + prompt = prompt, + negativePrompt = negativePrompt, + steps = samplingSteps.coerceIn(MIN_STEPS, MAX_STEPS), + samplerName = sampler.ifBlank { ArliAiSampler.default.key }, + width = width, + height = height, + seed = seed.mapToArliAiSeed(), + cfgScale = cfgScale, + batchSize = batchCount.coerceAtLeast(1), + restoreFaces = restoreFaces, + detailerEnabled = aDetailer.enabled.takeIf { it }, + detailerPrompt = aDetailer.promptIfEnabled(), + detailerNegative = aDetailer.negativePromptIfEnabled(), + detailerSteps = samplingSteps.coerceIn(MIN_STEPS, MAX_STEPS).takeIf { aDetailer.enabled }, + detailerStrength = aDetailer.denoisingStrengthIfEnabled(), + detailerModel = aDetailer.modelIfEnabled(), + detailerConfidence = aDetailer.confidenceIfEnabled(), + detailerPadding = aDetailer.paddingIfEnabled(), + detailerBlur = aDetailer.blurIfEnabled(), + ) +} + +/** + * Converts image-to-image settings into an ArliAI SDNext payload. + * + * Sampling steps are clamped to the ArliAI-supported range, the optional mask is omitted when + * blank, and ADetailer fields are sent only when the feature is enabled. + * + * @param model checkpoint name sent as `sd_model_checkpoint`. + * @return provider request payload ready for JSON serialization. + * + * @author Dmitriy Moroz + */ +fun ImageToImagePayload.mapToArliAiRequest(model: String): ArliAiImageToImageRequest = with(this) { + ArliAiImageToImageRequest( + sdModelCheckpoint = model, + prompt = prompt, + negativePrompt = negativePrompt, + initImages = listOf(base64Image), + mask = base64MaskImage.takeIf(String::isNotBlank), + denoisingStrength = denoisingStrength, + steps = samplingSteps.coerceIn(MIN_STEPS, MAX_STEPS), + samplerName = sampler.ifBlank { ArliAiSampler.default.key }, + width = width, + height = height, + seed = seed.mapToArliAiSeed(), + cfgScale = cfgScale, + batchSize = batchCount.coerceAtLeast(1), + restoreFaces = restoreFaces, + maskBlur = maskBlur, + inPaintingFill = inPaintingFill, + inPaintFullRes = inPaintFullRes, + inPaintFullResPadding = inPaintFullResPadding, + inPaintingMaskInvert = inPaintingMaskInvert, + detailerEnabled = aDetailer.enabled.takeIf { it }, + detailerPrompt = aDetailer.promptIfEnabled(), + detailerNegative = aDetailer.negativePromptIfEnabled(), + detailerSteps = samplingSteps.coerceIn(MIN_STEPS, MAX_STEPS).takeIf { aDetailer.enabled }, + detailerStrength = aDetailer.denoisingStrengthIfEnabled(), + detailerModel = aDetailer.modelIfEnabled(), + detailerConfidence = aDetailer.confidenceIfEnabled(), + detailerPadding = aDetailer.paddingIfEnabled(), + detailerBlur = aDetailer.blurIfEnabled(), + ) +} + +private fun String.mapToArliAiSeed(): Long? = + trim().takeIf(String::isNotEmpty)?.toLongOrNull() + +private fun ADetailerConfig.promptIfEnabled() = + prompt.takeIf { enabled } + +private fun ADetailerConfig.negativePromptIfEnabled() = + negativePrompt.takeIf { enabled } + +private fun ADetailerConfig.denoisingStrengthIfEnabled() = + denoisingStrength.takeIf { enabled } + +private fun ADetailerConfig.modelIfEnabled() = + model.takeIf { enabled } + +private fun ADetailerConfig.confidenceIfEnabled() = + confidence.takeIf { enabled } + +private fun ADetailerConfig.paddingIfEnabled() = + inpaintPadding.takeIf { enabled } + +private fun ADetailerConfig.blurIfEnabled() = + maskBlur.takeIf { enabled } + +private const val MIN_STEPS = 1 +private const val MAX_STEPS = 40 diff --git a/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/preference/KeyValueConfigurationStore.kt b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/preference/KeyValueConfigurationStore.kt index f26d2f1d4..943444ddf 100644 --- a/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/preference/KeyValueConfigurationStore.kt +++ b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/preference/KeyValueConfigurationStore.kt @@ -117,6 +117,15 @@ internal class KeyValueConfigurationStore( get() = keyValueStore.getString(KEY_FAL_AI_API_KEY) set(value) = keyValueStore.putString(KEY_FAL_AI_API_KEY, value) + /** + * Exposes the `arliAiApiKey` value used by the SDAI data layer. + * + * @author Dmitriy Moroz + */ + override var arliAiApiKey: String + get() = keyValueStore.getString(KEY_ARLI_AI_API_KEY) + set(value) = keyValueStore.putString(KEY_ARLI_AI_API_KEY, value) + /** * Exposes the `stabilityAiEngineId` value used by the SDAI data layer. * @@ -308,6 +317,12 @@ internal class KeyValueConfigurationStore( * @author Dmitriy Moroz */ const val KEY_FAL_AI_API_KEY = "key_fal_ai_api_key" + /** + * Exposes the `KEY_ARLI_AI_API_KEY` value used by the SDAI data layer. + * + * @author Dmitriy Moroz + */ + const val KEY_ARLI_AI_API_KEY = "key_arli_ai_api_key" /** * Exposes the `KEY_STABILITY_AI_ENGINE_ID_KEY` value used by the SDAI data layer. * diff --git a/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt index d0591f10b..6a999f14e 100644 --- a/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt +++ b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt @@ -258,6 +258,24 @@ internal class PreferenceManagerImpl( get() = keyValueStore.getString(KEY_FAL_AI_API_KEY) set(value) = putString(KEY_FAL_AI_API_KEY, value) + /** + * Exposes the `arliAiApiKey` value used by the SDAI data layer. + * + * @author Dmitriy Moroz + */ + override var arliAiApiKey: String + get() = keyValueStore.getString(KEY_ARLI_AI_API_KEY) + set(value) = putString(KEY_ARLI_AI_API_KEY, value) + + /** + * Exposes the `arliAiModel` value used by the SDAI data layer. + * + * @author Dmitriy Moroz + */ + override var arliAiModel: String + get() = keyValueStore.getString(KEY_ARLI_AI_MODEL_KEY) + set(value) = putString(KEY_ARLI_AI_MODEL_KEY, value) + /** * Exposes the `stabilityAiEngineId` value used by the SDAI data layer. * @@ -455,6 +473,7 @@ internal class PreferenceManagerImpl( formPromptTaggedInput = formPromptTaggedInput, source = source, hordeApiKey = hordeApiKey, + arliAiModel = arliAiModel, localUseNNAPI = localOnnxUseNNAPI, designUseSystemColorPalette = designUseSystemColorPalette, designUseSystemDarkTheme = designUseSystemDarkTheme, @@ -681,6 +700,18 @@ internal class PreferenceManagerImpl( * @author Dmitriy Moroz */ const val KEY_FAL_AI_API_KEY = "key_fal_ai_api_key" + /** + * Exposes the `KEY_ARLI_AI_API_KEY` value used by the SDAI data layer. + * + * @author Dmitriy Moroz + */ + const val KEY_ARLI_AI_API_KEY = "key_arli_ai_api_key" + /** + * Exposes the `KEY_ARLI_AI_MODEL_KEY` value used by the SDAI data layer. + * + * @author Dmitriy Moroz + */ + const val KEY_ARLI_AI_MODEL_KEY = "key_arli_ai_model_key" /** * Exposes the `KEY_STABILITY_AI_ENGINE_ID_KEY` value used by the SDAI data layer. * diff --git a/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/remote/KtorArliAiGenerationRemoteDataSource.kt b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/remote/KtorArliAiGenerationRemoteDataSource.kt new file mode 100644 index 000000000..699ff6bb4 --- /dev/null +++ b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/remote/KtorArliAiGenerationRemoteDataSource.kt @@ -0,0 +1,76 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mappers.mapStableDiffusionImageToImageResult +import com.shifthackz.aisdv1.data.mappers.mapStableDiffusionTextToImageResult +import com.shifthackz.aisdv1.data.mappers.mapToArliAiRequest +import com.shifthackz.aisdv1.domain.datasource.ArliAiGenerationDataSource +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.network.api.arliai.ArliAiGenerationApi +import kotlin.time.Clock +import kotlin.time.ExperimentalTime + +/** + * Maps ArliAI network responses into domain generation results. + * + * The provider response is Automatic1111-compatible, so this data source reuses the Stable + * Diffusion response mappers after converting domain payloads into ArliAI requests. + * + * @param api ArliAI network API used for validation and generation calls. + * + * @author Dmitriy Moroz + */ +@OptIn(ExperimentalTime::class) +class KtorArliAiGenerationRemoteDataSource( + private val api: ArliAiGenerationApi, +) : ArliAiGenerationDataSource.Remote { + + /** + * Treats any successful model-list request as a valid API key check. + * + * @param apiKey ArliAI API key entered by the user. + * @return `true` when the provider request succeeds. + * + * @author Dmitriy Moroz + */ + override suspend fun validateApiKey(apiKey: String): Boolean = try { + api.validateApiKey(apiKey) + true + } catch (_: Throwable) { + false + } + + /** + * Sends text-to-image generation and maps returned images with the current timestamp. + * + * @param apiKey ArliAI API key entered by the user. + * @param model checkpoint name sent to ArliAI. + * @param payload domain generation settings. + * @return mapped generation records returned by the provider. + * + * @author Dmitriy Moroz + */ + override suspend fun textToImage( + apiKey: String, + model: String, + payload: TextToImagePayload, + ) = (payload to api.textToImage(apiKey, payload.mapToArliAiRequest(model))) + .mapStableDiffusionTextToImageResult(Clock.System.now().toEpochMilliseconds()) + + /** + * Sends image-to-image generation and maps returned images with the current timestamp. + * + * @param apiKey ArliAI API key entered by the user. + * @param model checkpoint name sent to ArliAI. + * @param payload domain generation settings and source image data. + * @return mapped generation records returned by the provider. + * + * @author Dmitriy Moroz + */ + override suspend fun imageToImage( + apiKey: String, + model: String, + payload: ImageToImagePayload, + ) = (payload to api.imageToImage(apiKey, payload.mapToArliAiRequest(model))) + .mapStableDiffusionImageToImageResult(Clock.System.now().toEpochMilliseconds()) +} diff --git a/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/remote/KtorArliAiModelsRemoteDataSource.kt b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/remote/KtorArliAiModelsRemoteDataSource.kt new file mode 100644 index 000000000..4eb3462c5 --- /dev/null +++ b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/remote/KtorArliAiModelsRemoteDataSource.kt @@ -0,0 +1,29 @@ +package com.shifthackz.aisdv1.data.remote + +import com.shifthackz.aisdv1.data.mappers.mapKtorRawToCheckpointDomain +import com.shifthackz.aisdv1.domain.datasource.ArliAiModelsDataSource +import com.shifthackz.aisdv1.domain.entity.StableDiffusionModel +import com.shifthackz.aisdv1.network.api.arliai.ArliAiGenerationApi + +/** + * Loads ArliAI checkpoint metadata from the network API. + * + * @param api ArliAI network API used for model discovery. + * + * @author Dmitriy Moroz + */ +class KtorArliAiModelsRemoteDataSource( + private val api: ArliAiGenerationApi, +) : ArliAiModelsDataSource.Remote { + + /** + * Fetches ArliAI checkpoints and maps them into shared Stable Diffusion model metadata. + * + * @param apiKey ArliAI API key entered by the user. + * @return checkpoint metadata available to the supplied key. + * + * @author Dmitriy Moroz + */ + override suspend fun fetchModels(apiKey: String): List = + api.fetchModels(apiKey).mapKtorRawToCheckpointDomain() +} diff --git a/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/repository/ArliAiGenerationRepositoryImpl.kt b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/repository/ArliAiGenerationRepositoryImpl.kt new file mode 100644 index 000000000..f62b46ecb --- /dev/null +++ b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/repository/ArliAiGenerationRepositoryImpl.kt @@ -0,0 +1,79 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.data.core.CoreGenerationRepository +import com.shifthackz.aisdv1.data.mappers.withModelName +import com.shifthackz.aisdv1.domain.datasource.ArliAiGenerationDataSource +import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver +import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.ArliAiGenerationRepository + +/** + * Generates images through ArliAI and persists returned gallery records. + * + * Each generation uses a payload model override when present, otherwise it falls back to the + * selected ArliAI model saved in preferences. + * + * @author Dmitriy Moroz + */ +internal class ArliAiGenerationRepositoryImpl( + mediaStoreGateway: MediaStoreGateway, + localDataSource: GenerationResultDataSource.Local, + backgroundWorkObserver: BackgroundWorkObserver, + private val preferenceManager: PreferenceManager, + private val remoteDataSource: ArliAiGenerationDataSource.Remote, +) : CoreGenerationRepository( + mediaStoreGateway = mediaStoreGateway, + localDataSource = localDataSource, + preferenceManager = preferenceManager, + backgroundWorkObserver = backgroundWorkObserver, +), ArliAiGenerationRepository { + + /** + * Validates the currently saved ArliAI API key. + * + * @return `true` when ArliAI accepts the key. + * + * @author Dmitriy Moroz + */ + override suspend fun validateApiKey(): Boolean = + remoteDataSource.validateApiKey(preferenceManager.arliAiApiKey) + + /** + * Generates text-to-image results and saves each returned image to local history. + * + * @param payload generation settings and optional ArliAI model override. + * @return persisted generation records. + * + * @author Dmitriy Moroz + */ + override suspend fun generateFromText(payload: TextToImagePayload): List { + val model = payload.arliAiModel.ifBlank { preferenceManager.arliAiModel } + return remoteDataSource + .textToImage(preferenceManager.arliAiApiKey, model, payload) + .map { result -> + insertGenerationResult(result.withModelName(model)) + } + } + + /** + * Generates image-to-image results and saves each returned image to local history. + * + * @param payload generation settings, source image data, and optional ArliAI model override. + * @return persisted generation records. + * + * @author Dmitriy Moroz + */ + override suspend fun generateFromImage(payload: ImageToImagePayload): List { + val model = payload.arliAiModel.ifBlank { preferenceManager.arliAiModel } + return remoteDataSource + .imageToImage(preferenceManager.arliAiApiKey, model, payload) + .map { result -> + insertGenerationResult(result.withModelName(model)) + } + } +} diff --git a/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/repository/ArliAiModelsRepositoryImpl.kt b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/repository/ArliAiModelsRepositoryImpl.kt new file mode 100644 index 000000000..45e0aa4ee --- /dev/null +++ b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/repository/ArliAiModelsRepositoryImpl.kt @@ -0,0 +1,56 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.domain.datasource.ArliAiModelsDataSource +import com.shifthackz.aisdv1.domain.entity.StableDiffusionModel +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.ArliAiModelsRepository + +/** + * Synchronizes ArliAI checkpoint metadata between the provider and local cache. + * + * @param remoteDataSource ArliAI model-list source backed by the network API. + * @param localDataSource ArliAI model-list source backed by the cache database. + * @param preferenceManager supplies the saved ArliAI API key. + * + * @author Dmitriy Moroz + */ +internal class ArliAiModelsRepositoryImpl( + private val remoteDataSource: ArliAiModelsDataSource.Remote, + private val localDataSource: ArliAiModelsDataSource.Local, + private val preferenceManager: PreferenceManager, +) : ArliAiModelsRepository { + + /** + * Refreshes provider checkpoints and replaces the local cache. + * + * @author Dmitriy Moroz + */ + override suspend fun fetchModels() { + val models = remoteDataSource.fetchModels(preferenceManager.arliAiApiKey) + localDataSource.insertModels(models) + } + + /** + * Attempts a refresh and always returns the cached checkpoint list. + * + * Network errors are ignored here so UI can continue using stale cached models. + * + * @return cached ArliAI checkpoints after the refresh attempt. + * + * @author Dmitriy Moroz + */ + override suspend fun fetchAndGetModels(): List { + runCatching { fetchModels() } + return getModels() + } + + /** + * Reads cached ArliAI checkpoint metadata. + * + * @return locally stored ArliAI checkpoints. + * + * @author Dmitriy Moroz + */ + override suspend fun getModels(): List = + localDataSource.getModels() +} diff --git a/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/repository/ReportRepositoryImpl.kt b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/repository/ReportRepositoryImpl.kt index 78fe17fd2..f209780d4 100644 --- a/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/repository/ReportRepositoryImpl.kt +++ b/data/src/commonMain/kotlin/com/shifthackz/aisdv1/data/repository/ReportRepositoryImpl.kt @@ -39,6 +39,7 @@ internal class ReportRepositoryImpl( val model = when (source) { ServerSource.HUGGING_FACE -> preferenceManager.huggingFaceModel ServerSource.STABILITY_AI -> preferenceManager.stabilityAiEngineId + ServerSource.ARLI_AI -> preferenceManager.arliAiModel ServerSource.LOCAL_MICROSOFT_ONNX -> preferenceManager.localOnnxModelId ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> preferenceManager.localMediaPipeModelId ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> preferenceManager.localSdxlModelId diff --git a/domain/src/androidMain/kotlin/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt b/domain/src/androidMain/kotlin/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt index bcda17650..80082158f 100644 --- a/domain/src/androidMain/kotlin/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt +++ b/domain/src/androidMain/kotlin/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.domain.interactor.settings import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToA1111UseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToArliAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToFalAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHordeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase @@ -64,6 +65,12 @@ interface SetupConnectionInterActor { * @author Dmitriy Moroz */ val connectToFalAi: ConnectToFalAiUseCase + /** + * Exposes the `connectToArliAi` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + val connectToArliAi: ConnectToArliAiUseCase /** * Exposes the `connectToSwarmUi` value used by the SDAI domain layer. * diff --git a/domain/src/androidMain/kotlin/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt b/domain/src/androidMain/kotlin/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt index 609ac8ce5..2305b313e 100644 --- a/domain/src/androidMain/kotlin/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt +++ b/domain/src/androidMain/kotlin/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.domain.interactor.settings import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToA1111UseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToArliAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToFalAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHordeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase @@ -64,6 +65,12 @@ internal data class SetupConnectionInterActorImpl( * @author Dmitriy Moroz */ override val connectToFalAi: ConnectToFalAiUseCase, + /** + * Exposes the `connectToArliAi` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + override val connectToArliAi: ConnectToArliAiUseCase, /** * Exposes the `connectToSwarmUi` value used by the SDAI domain layer. * diff --git a/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt b/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt index fc5a05fef..fae624379 100644 --- a/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt +++ b/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt @@ -13,6 +13,7 @@ val mockConfiguration = Configuration( stabilityAiApiKey = "5598", stabilityAiEngineId = "5598", falAiApiKey = "5598", + arliAiApiKey = "5598", localOnnxModelId = "5598", localOnnxModelPath = "/storage/emulated/0/5598", ) diff --git a/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/mocks/ConfigurationStoreStub.kt b/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/mocks/ConfigurationStoreStub.kt index 504c2cb63..b518f20ef 100644 --- a/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/mocks/ConfigurationStoreStub.kt +++ b/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/mocks/ConfigurationStoreStub.kt @@ -19,6 +19,7 @@ class ConfigurationStoreStub( override var stabilityAiApiKey: String = configuration.stabilityAiApiKey override var stabilityAiEngineId: String = configuration.stabilityAiEngineId override var falAiApiKey: String = configuration.falAiApiKey + override var arliAiApiKey: String = configuration.arliAiApiKey override var localOnnxModelId: String = configuration.localOnnxModelId override var localOnnxModelPath: String = configuration.localOnnxModelPath override var localMediaPipeModelId: String = configuration.localMediaPipeModelId diff --git a/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt b/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt index efab58e9e..363bcd643 100644 --- a/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt +++ b/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt @@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.mocks.mockAiGenerationResult import com.shifthackz.aisdv1.domain.mocks.mockImageToImagePayload import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.ArliAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.CoreMlGenerationRepository import com.shifthackz.aisdv1.domain.repository.FalAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository @@ -31,6 +32,7 @@ class ImageToImageUseCaseImplTest { private val stubStabilityAiGenerationRepository = mockk() private val stubCoreMlGenerationRepository = mockk() private val stubFalAiGenerationRepository = mockk() + private val stubArliAiGenerationRepository = mockk() private val stubPreferenceManager = mockk() private val useCase = ImageToImageUseCaseImpl( @@ -41,6 +43,7 @@ class ImageToImageUseCaseImplTest { stabilityAiGenerationRepository = stubStabilityAiGenerationRepository, coreMlGenerationRepository = stubCoreMlGenerationRepository, falAiGenerationRepository = stubFalAiGenerationRepository, + arliAiGenerationRepository = stubArliAiGenerationRepository, preferenceManager = stubPreferenceManager, ) @@ -121,6 +124,17 @@ class ImageToImageUseCaseImplTest { coVerify(exactly = 1) { stubFalAiGenerationRepository.generateFromImage(any()) } } + @Test + fun `given source is ARLI_AI, expected arli ai generation`() = runTest { + every { stubPreferenceManager.source } returns ServerSource.ARLI_AI + coEvery { stubArliAiGenerationRepository.generateFromImage(any()) } returns listOf(mockAiGenerationResult) + + val actual = useCase(mockImageToImagePayload.copy(batchCount = 1)) + + assertEquals(listOf(mockAiGenerationResult), actual) + coVerify(exactly = 1) { stubArliAiGenerationRepository.generateFromImage(any()) } + } + @Test fun `given automatic1111 batch count is 10, expected batch generated by repository`() = runTest { every { stubPreferenceManager.source } returns ServerSource.AUTOMATIC1111 @@ -154,6 +168,17 @@ class ImageToImageUseCaseImplTest { coVerify(exactly = 1) { stubFalAiGenerationRepository.generateFromImage(any()) } } + @Test + fun `given arli ai batch count is 4, expected batch generated by repository`() = runTest { + every { stubPreferenceManager.source } returns ServerSource.ARLI_AI + coEvery { stubArliAiGenerationRepository.generateFromImage(any()) } returns List(4) { mockAiGenerationResult } + + val actual = useCase(mockImageToImagePayload.copy(batchCount = 4)) + + assertEquals(List(4) { mockAiGenerationResult }, actual) + coVerify(exactly = 1) { stubArliAiGenerationRepository.generateFromImage(any()) } + } + @Test fun `given generation fails, expected error propagated`() = runTest { every { stubPreferenceManager.source } returns ServerSource.AUTOMATIC1111 diff --git a/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt b/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt index 2f3caa967..32500442d 100644 --- a/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt +++ b/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt @@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.mocks.mockAiGenerationResult import com.shifthackz.aisdv1.domain.mocks.mockTextToImagePayload import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.ArliAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.CoreMlGenerationRepository import com.shifthackz.aisdv1.domain.repository.FalAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository @@ -33,6 +34,7 @@ class TextToImageUseCaseImplTest { private val stubOpenAiGenerationRepository = mockk() private val stubStabilityAiGenerationRepository = mockk() private val stubFalAiGenerationRepository = mockk() + private val stubArliAiGenerationRepository = mockk() private val stubSwarmUiGenerationRepository = mockk() private val stubLocalDiffusionGenerationRepository = mockk() private val stubMediaPipeGenerationRepository = mockk() @@ -47,6 +49,7 @@ class TextToImageUseCaseImplTest { openAiGenerationRepository = stubOpenAiGenerationRepository, stabilityAiGenerationRepository = stubStabilityAiGenerationRepository, falAiGenerationRepository = stubFalAiGenerationRepository, + arliAiGenerationRepository = stubArliAiGenerationRepository, localDiffusionGenerationRepository = stubLocalDiffusionGenerationRepository, swarmUiGenerationRepository = stubSwarmUiGenerationRepository, mediaPipeGenerationRepository = stubMediaPipeGenerationRepository, @@ -121,6 +124,17 @@ class TextToImageUseCaseImplTest { coVerify(exactly = 1) { stubFalAiGenerationRepository.generateFromText(any()) } } + @Test + fun `given source is ARLI_AI, expected arli ai generation`() = runTest { + every { stubPreferenceManager.source } returns ServerSource.ARLI_AI + coEvery { stubArliAiGenerationRepository.generateFromText(any()) } returns listOf(mockAiGenerationResult) + + val actual = useCase(mockTextToImagePayload.copy(batchCount = 1)) + + assertEquals(listOf(mockAiGenerationResult), actual) + coVerify(exactly = 1) { stubArliAiGenerationRepository.generateFromText(any()) } + } + @Test fun `given source is SWARM_UI, expected swarm ui generation`() = runTest { every { stubPreferenceManager.source } returns ServerSource.SWARM_UI @@ -209,6 +223,17 @@ class TextToImageUseCaseImplTest { coVerify(exactly = 1) { stubFalAiGenerationRepository.generateFromText(any()) } } + @Test + fun `given arli ai batch count is 4, expected batch generated by repository`() = runTest { + every { stubPreferenceManager.source } returns ServerSource.ARLI_AI + coEvery { stubArliAiGenerationRepository.generateFromText(any()) } returns List(4) { mockAiGenerationResult } + + val actual = useCase(mockTextToImagePayload.copy(batchCount = 4)) + + assertEquals(List(4) { mockAiGenerationResult }, actual) + coVerify(exactly = 1) { stubArliAiGenerationRepository.generateFromText(any()) } + } + @Test fun `given generation fails, expected error propagated`() = runTest { every { stubPreferenceManager.source } returns ServerSource.AUTOMATIC1111 diff --git a/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToArliAiUseCaseImplTest.kt b/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToArliAiUseCaseImplTest.kt new file mode 100644 index 000000000..c9ea23427 --- /dev/null +++ b/domain/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToArliAiUseCaseImplTest.kt @@ -0,0 +1,95 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials +import com.shifthackz.aisdv1.domain.mocks.mockConfiguration +import com.shifthackz.aisdv1.domain.usecase.connectivity.TestArliAiApiKeyUseCase +import io.mockk.coEvery +import io.mockk.coVerify +import io.mockk.mockk +import kotlinx.coroutines.test.runTest +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Test + +class DefaultConnectToArliAiUseCaseImplTest { + + private val stubThrowable = Throwable("Something went wrong.") + private val stubGetConfigurationUseCase = mockk() + private val stubSetServerConfigurationUseCase = mockk() + private val stubTestArliAiApiKeyUseCase = mockk() + + private val useCase = DefaultConnectToArliAiUseCaseImpl( + getConfigurationUseCase = stubGetConfigurationUseCase, + setServerConfigurationUseCase = stubSetServerConfigurationUseCase, + testArliAiApiKeyUseCase = stubTestArliAiApiKeyUseCase, + ) + + @Test + fun `given connection process successful, API key is valid, expected success result value`() = runTest { + coEvery { + stubGetConfigurationUseCase() + } returns mockConfiguration + + coEvery { + stubSetServerConfigurationUseCase(any()) + } returns Unit + + coEvery { + stubTestArliAiApiKeyUseCase() + } returns true + + assertEquals(Result.success(Unit), useCase("arli-key")) + coVerify { + stubSetServerConfigurationUseCase( + match { + it.source == ServerSource.ARLI_AI && + it.arliAiApiKey == "arli-key" && + it.authCredentials == AuthorizationCredentials.None + }, + ) + } + } + + @Test + fun `given connection process successful, API key is NOT valid, expected failure result value`() = runTest { + coEvery { + stubGetConfigurationUseCase() + } returns mockConfiguration + + coEvery { + stubSetServerConfigurationUseCase(any()) + } returns Unit + + coEvery { + stubTestArliAiApiKeyUseCase() + } returns false + + val actual = useCase("arli-key") + + assertTrue(actual.isFailure) + assertTrue(actual.exceptionOrNull() is IllegalStateException) + assertEquals("Bad key", actual.exceptionOrNull()?.message) + } + + @Test + fun `given connection process failed, expected error result value`() = runTest { + coEvery { + stubGetConfigurationUseCase() + } throws stubThrowable + + coEvery { + stubSetServerConfigurationUseCase(any()) + } throws stubThrowable + + coEvery { + stubTestArliAiApiKeyUseCase() + } throws stubThrowable + + val actual = useCase("arli-key") + + assertTrue(actual.isFailure) + assertEquals(stubThrowable.message, actual.exceptionOrNull()?.message) + } + +} diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/datasource/ArliAiGenerationDataSource.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/datasource/ArliAiGenerationDataSource.kt new file mode 100644 index 000000000..f74890568 --- /dev/null +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/datasource/ArliAiGenerationDataSource.kt @@ -0,0 +1,61 @@ +package com.shifthackz.aisdv1.domain.datasource + +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload + +/** + * Groups ArliAI generation data sources. + * + * @author Dmitriy Moroz + */ +sealed interface ArliAiGenerationDataSource { + /** + * Sends ArliAI validation and generation requests to the network layer. + * + * @author Dmitriy Moroz + */ + interface Remote : ArliAiGenerationDataSource { + /** + * Checks whether ArliAI accepts the supplied API key. + * + * @param apiKey ArliAI API key entered by the user. + * @return `true` when the provider accepts the key. + * + * @author Dmitriy Moroz + */ + suspend fun validateApiKey(apiKey: String): Boolean + + /** + * Generates text-to-image results with a resolved ArliAI checkpoint. + * + * @param apiKey ArliAI API key entered by the user. + * @param model checkpoint name sent to ArliAI. + * @param payload domain generation settings. + * @return mapped generation records returned by the provider. + * + * @author Dmitriy Moroz + */ + suspend fun textToImage( + apiKey: String, + model: String, + payload: TextToImagePayload, + ): List + + /** + * Generates image-to-image results with a resolved ArliAI checkpoint. + * + * @param apiKey ArliAI API key entered by the user. + * @param model checkpoint name sent to ArliAI. + * @param payload domain generation settings and source image data. + * @return mapped generation records returned by the provider. + * + * @author Dmitriy Moroz + */ + suspend fun imageToImage( + apiKey: String, + model: String, + payload: ImageToImagePayload, + ): List + } +} diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/datasource/ArliAiModelsDataSource.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/datasource/ArliAiModelsDataSource.kt new file mode 100644 index 000000000..01494df18 --- /dev/null +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/datasource/ArliAiModelsDataSource.kt @@ -0,0 +1,53 @@ +package com.shifthackz.aisdv1.domain.datasource + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionModel + +/** + * Groups ArliAI model-list data sources. + * + * @author Dmitriy Moroz + */ +sealed interface ArliAiModelsDataSource { + + /** + * Loads ArliAI checkpoint metadata from the provider. + * + * @author Dmitriy Moroz + */ + interface Remote : ArliAiModelsDataSource { + /** + * Fetches the checkpoint list available to the supplied API key. + * + * @param apiKey ArliAI API key entered by the user. + * @return checkpoint metadata mapped into the domain model. + * + * @author Dmitriy Moroz + */ + suspend fun fetchModels(apiKey: String): List + } + + /** + * Stores ArliAI checkpoint metadata in the local cache. + * + * @author Dmitriy Moroz + */ + interface Local : ArliAiModelsDataSource { + /** + * Reads cached ArliAI checkpoint metadata. + * + * @return locally stored checkpoint metadata. + * + * @author Dmitriy Moroz + */ + suspend fun getModels(): List + + /** + * Replaces cached ArliAI checkpoint metadata. + * + * @param models checkpoint metadata returned by the provider. + * + * @author Dmitriy Moroz + */ + suspend fun insertModels(models: List) + } +} diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/di/DomainCoreModule.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/di/DomainCoreModule.kt index 0d07f5f54..c4a9bb75b 100644 --- a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/di/DomainCoreModule.kt +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/di/DomainCoreModule.kt @@ -16,6 +16,8 @@ import com.shifthackz.aisdv1.domain.repository.NoOpLocalDiffusionGenerationRepos import com.shifthackz.aisdv1.domain.repository.NoOpMediaPipeGenerationRepository import com.shifthackz.aisdv1.domain.repository.NoOpNetworkUsageRepository import com.shifthackz.aisdv1.domain.repository.NoOpStableDiffusionCppGenerationRepository +import com.shifthackz.aisdv1.domain.usecase.arliai.FetchAndGetArliAiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.arliai.FetchAndGetArliAiModelsUseCaseImpl import com.shifthackz.aisdv1.domain.repository.NoOpWakeLockRepository import com.shifthackz.aisdv1.domain.repository.NetworkUsageRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionCppGenerationRepository @@ -30,6 +32,7 @@ import com.shifthackz.aisdv1.domain.usecase.caching.GetLastResultFromCacheUseCas import com.shifthackz.aisdv1.domain.usecase.caching.NoOpAppCacheCleaner import com.shifthackz.aisdv1.domain.usecase.caching.SaveLastResultToCacheUseCase import com.shifthackz.aisdv1.domain.usecase.caching.SaveLastResultToCacheUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.connectivity.DefaultTestArliAiApiKeyUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.connectivity.DefaultTestConnectivityUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.connectivity.DefaultTestFalAiApiKeyUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.connectivity.DefaultTestHordeApiKeyUseCaseImpl @@ -45,6 +48,7 @@ import com.shifthackz.aisdv1.domain.usecase.connectivity.ObserveSeverConnectivit import com.shifthackz.aisdv1.domain.usecase.connectivity.ObserveSeverConnectivityUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.connectivity.PingStableDiffusionServiceUseCase import com.shifthackz.aisdv1.domain.usecase.connectivity.PingStableDiffusionServiceUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.connectivity.TestArliAiApiKeyUseCase import com.shifthackz.aisdv1.domain.usecase.connectivity.TestConnectivityUseCase import com.shifthackz.aisdv1.domain.usecase.connectivity.TestFalAiApiKeyUseCase import com.shifthackz.aisdv1.domain.usecase.connectivity.TestHordeApiKeyUseCase @@ -139,6 +143,7 @@ import com.shifthackz.aisdv1.domain.usecase.sdscript.IsADetailerAvailableUseCase import com.shifthackz.aisdv1.domain.usecase.sdsampler.GetStableDiffusionSamplersUseCase import com.shifthackz.aisdv1.domain.usecase.sdsampler.GetStableDiffusionSamplersUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToA1111UseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToArliAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToCoreMlUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToCoreMlUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToFalAiUseCase @@ -154,6 +159,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSdxlUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.DefaultConnectToA1111UseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.settings.DefaultConnectToArliAiUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.DefaultConnectToFalAiUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.DefaultConnectToHordeUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.DefaultConnectToHuggingFaceUseCaseImpl @@ -271,6 +277,12 @@ val coreDomainModule = module { remoteDataSource = get(), ) } + factory { + DefaultTestArliAiApiKeyUseCaseImpl( + configurationStore = get(), + remoteDataSource = get(), + ) + } factory { DataPreLoaderUseCaseImpl( serverConfigurationRepository = get(), @@ -341,6 +353,7 @@ val coreDomainModule = module { openAiGenerationRepository = get(), stabilityAiGenerationRepository = get(), falAiGenerationRepository = get(), + arliAiGenerationRepository = get(), swarmUiGenerationRepository = get(), localDiffusionGenerationRepository = get(), mediaPipeGenerationRepository = get(), @@ -358,6 +371,7 @@ val coreDomainModule = module { stabilityAiGenerationRepository = get(), coreMlGenerationRepository = get(), falAiGenerationRepository = get(), + arliAiGenerationRepository = get(), preferenceManager = get(), ) } @@ -520,6 +534,13 @@ val coreDomainModule = module { testFalAiApiKeyUseCase = get(), ) } + factory { + DefaultConnectToArliAiUseCaseImpl( + getConfigurationUseCase = get(), + setServerConfigurationUseCase = get(), + testArliAiApiKeyUseCase = get(), + ) + } factory { FetchSupportersUseCaseImpl(repository = get()) } @@ -553,6 +574,12 @@ val coreDomainModule = module { repository = get(), ) } + factory { + FetchAndGetArliAiModelsUseCaseImpl( + preferenceManager = get(), + repository = get(), + ) + } factory { FetchAndGetLorasUseCaseImpl(lorasRepository = get()) } diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/ArliAiSampler.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/ArliAiSampler.kt new file mode 100644 index 000000000..74cc1f276 --- /dev/null +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/ArliAiSampler.kt @@ -0,0 +1,51 @@ +package com.shifthackz.aisdv1.domain.entity + +/** + * Lists ArliAI sampler names supported by the setup and generation forms. + * + * @property key sampler name sent to ArliAI. + * + * @author Dmitriy Moroz + */ +enum class ArliAiSampler( + val key: String, +) { + DPM_PLUS_PLUS_2M_KARRAS("DPM++ 2M Karras"), + EULER_A("Euler a"), + UNIPC("UniPC"), + DDIM("DDIM"), + EULER("Euler"), + EULER_SGM("Euler SGM"), + EULER_EDM("Euler EDM"), + DPM_PLUS_PLUS_2M("DPM++ 2M"), + DPM_PLUS_PLUS_3M("DPM++ 3M"), + DPM_PLUS_PLUS_1S("DPM++ 1S"), + DPM_PLUS_PLUS_SDE("DPM++ SDE"), + DPM_PLUS_PLUS_2M_SDE("DPM++ 2M SDE"), + DPM_PLUS_PLUS_2M_EDM("DPM++ 2M EDM"), + DPM_PLUS_PLUS_COSINE("DPM++ Cosine"), + DPM_SDE("DPM SDE"), + DPM_PLUS_PLUS_INVERSE("DPM++ Inverse"), + DPM_PLUS_PLUS_2M_INVERSE("DPM++ 2M Inverse"), + DPM_PLUS_PLUS_3M_INVERSE("DPM++ 3M Inverse"), + HEUN("Heun"), + DEIS("DEIS"), + LCM("LCM"), + ; + + companion object { + /** + * Default sampler used when the payload does not specify one. + * + * @author Dmitriy Moroz + */ + val default: ArliAiSampler = DPM_PLUS_PLUS_2M_KARRAS + + /** + * Provider-facing sampler names shown in UI selectors. + * + * @author Dmitriy Moroz + */ + val supported: List = entries.map(ArliAiSampler::key) + } +} diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/Configuration.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/Configuration.kt index e232df347..613b557e3 100644 --- a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/Configuration.kt +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/Configuration.kt @@ -80,6 +80,12 @@ data class Configuration( * @author Dmitriy Moroz */ val falAiApiKey: String = "", + /** + * Exposes the `arliAiApiKey` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + val arliAiApiKey: String = "", /** * Exposes the `authCredentials` value used by the SDAI domain layer. * diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/ImageToImagePayload.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/ImageToImagePayload.kt index 8554676b8..34dafa506 100644 --- a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/ImageToImagePayload.kt +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/ImageToImagePayload.kt @@ -180,4 +180,10 @@ data class ImageToImagePayload( * @author Dmitriy Moroz */ val falAiSyncMode: Boolean = false, + /** + * Exposes the `arliAiModel` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + val arliAiModel: String = "", ) diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/ServerSource.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/ServerSource.kt index f69c7d011..16e45d0ba 100644 --- a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/ServerSource.kt +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/ServerSource.kt @@ -155,6 +155,18 @@ enum class ServerSource( FeatureTag.MultipleModels, FeatureTag.Batch, ), + ), + ARLI_AI( + key = "arli_ai", + type = ServerSourceType.CLOUD, + readiness = ServerSourceReadiness.ALPHA, + version = "2026.6.13", + featureTags = setOf( + FeatureTag.Txt2Img, + FeatureTag.Img2Img, + FeatureTag.MultipleModels, + FeatureTag.Batch, + ), ); companion object { diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/Settings.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/Settings.kt index 62ec3f7c5..5aff570ef 100644 --- a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/Settings.kt +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/Settings.kt @@ -92,6 +92,12 @@ data class Settings( * @author Dmitriy Moroz */ val hordeApiKey: String = "", + /** + * Exposes the `arliAiModel` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + val arliAiModel: String = "", /** * Exposes the `localUseNNAPI` value used by the SDAI domain layer. * diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/TextToImagePayload.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/TextToImagePayload.kt index e5e043731..a555b16c7 100644 --- a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/TextToImagePayload.kt +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/entity/TextToImagePayload.kt @@ -168,4 +168,10 @@ data class TextToImagePayload( * @author Dmitriy Moroz */ val falAiSyncMode: Boolean = false, + /** + * Exposes the `arliAiModel` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + val arliAiModel: String = "", ) diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/preference/ConfigurationStore.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/preference/ConfigurationStore.kt index 1ce45038b..55150c752 100644 --- a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/preference/ConfigurationStore.kt +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/preference/ConfigurationStore.kt @@ -82,6 +82,12 @@ interface ConfigurationStore { * @author Dmitriy Moroz */ var falAiApiKey: String + /** + * Exposes the `arliAiApiKey` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + var arliAiApiKey: String /** * Exposes the `localOnnxModelId` value used by the SDAI domain layer. * @@ -152,6 +158,7 @@ interface ConfigurationStore { stabilityAiApiKey = stabilityAiApiKey, stabilityAiEngineId = stabilityAiEngineId, falAiApiKey = falAiApiKey, + arliAiApiKey = arliAiApiKey, authCredentials = authCredentials, localOnnxModelId = localOnnxModelId, localOnnxModelPath = localOnnxModelPath, @@ -182,6 +189,7 @@ interface ConfigurationStore { stabilityAiApiKey = configuration.stabilityAiApiKey stabilityAiEngineId = configuration.stabilityAiEngineId falAiApiKey = configuration.falAiApiKey + arliAiApiKey = configuration.arliAiApiKey localOnnxModelId = configuration.localOnnxModelId localOnnxModelPath = configuration.localOnnxModelPath localMediaPipeModelId = configuration.localMediaPipeModelId diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt index 41639494e..cb690c6e5 100644 --- a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt @@ -156,6 +156,18 @@ interface PreferenceManager { * @author Dmitriy Moroz */ var falAiApiKey: String + /** + * Exposes the `arliAiApiKey` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + var arliAiApiKey: String + /** + * Exposes the `arliAiModel` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + var arliAiModel: String /** * Exposes the `stabilityAiEngineId` value used by the SDAI domain layer. * diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/repository/ArliAiGenerationRepository.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/repository/ArliAiGenerationRepository.kt new file mode 100644 index 000000000..53bcf573f --- /dev/null +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/repository/ArliAiGenerationRepository.kt @@ -0,0 +1,44 @@ +package com.shifthackz.aisdv1.domain.repository + +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload + +/** + * Generates images through the configured ArliAI account. + * + * Implementations resolve the selected checkpoint from the payload first and then from + * persisted ArliAI settings, so callers can override the model per request. + * + * @author Dmitriy Moroz + */ +interface ArliAiGenerationRepository { + /** + * Checks whether the persisted ArliAI API key is accepted by the provider. + * + * @return `true` when ArliAI accepts the current key. + * + * @author Dmitriy Moroz + */ + suspend fun validateApiKey(): Boolean + + /** + * Generates one or more text-to-image results. + * + * @param payload generation settings and optional ArliAI model override. + * @return persisted generation records produced from the provider response. + * + * @author Dmitriy Moroz + */ + suspend fun generateFromText(payload: TextToImagePayload): List + + /** + * Generates one or more image-to-image results. + * + * @param payload generation settings, source image data, and optional ArliAI model override. + * @return persisted generation records produced from the provider response. + * + * @author Dmitriy Moroz + */ + suspend fun generateFromImage(payload: ImageToImagePayload): List +} diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/repository/ArliAiModelsRepository.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/repository/ArliAiModelsRepository.kt new file mode 100644 index 000000000..f17f6c426 --- /dev/null +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/repository/ArliAiModelsRepository.kt @@ -0,0 +1,38 @@ +package com.shifthackz.aisdv1.domain.repository + +import com.shifthackz.aisdv1.domain.entity.StableDiffusionModel + +/** + * Synchronizes and reads ArliAI checkpoint metadata. + * + * Remote models are cached locally so setup and generation screens can keep showing the last + * known ArliAI list when a refresh fails. + * + * @author Dmitriy Moroz + */ +interface ArliAiModelsRepository { + /** + * Refreshes ArliAI checkpoint metadata from the provider and stores it locally. + * + * @author Dmitriy Moroz + */ + suspend fun fetchModels() + + /** + * Attempts a provider refresh and then returns the locally cached checkpoint list. + * + * @return cached ArliAI checkpoints after the refresh attempt. + * + * @author Dmitriy Moroz + */ + suspend fun fetchAndGetModels(): List + + /** + * Reads cached ArliAI checkpoint metadata. + * + * @return locally stored ArliAI checkpoints. + * + * @author Dmitriy Moroz + */ + suspend fun getModels(): List +} diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/arliai/FetchAndGetArliAiModelsUseCase.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/arliai/FetchAndGetArliAiModelsUseCase.kt new file mode 100644 index 000000000..18e9075d8 --- /dev/null +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/arliai/FetchAndGetArliAiModelsUseCase.kt @@ -0,0 +1,72 @@ +package com.shifthackz.aisdv1.domain.usecase.arliai + +import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.entity.StableDiffusionModel +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.ArliAiModelsRepository + +/** + * Refreshes and returns ArliAI checkpoints for active ArliAI sessions. + * + * @author Dmitriy Moroz + */ +interface FetchAndGetArliAiModelsUseCase { + + /** + * Loads ArliAI models when ArliAI is the selected source. + * + * @return cached ArliAI checkpoints, or an empty list for other server sources. + * + * @author Dmitriy Moroz + */ + suspend operator fun invoke(): List +} + +/** + * Keeps the selected ArliAI checkpoint valid after model-list refreshes. + * + * If the saved model no longer exists, the first available checkpoint becomes the selected model. + * + * @author Dmitriy Moroz + */ +internal class FetchAndGetArliAiModelsUseCaseImpl( + /** + * Exposes the `preferenceManager` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + private val preferenceManager: PreferenceManager, + /** + * Exposes the `repository` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + private val repository: ArliAiModelsRepository, +) : FetchAndGetArliAiModelsUseCase { + + /** + * Loads ArliAI models and updates the persisted selected checkpoint if needed. + * + * @return cached ArliAI checkpoints, or an empty list for other server sources. + * + * @author Dmitriy Moroz + */ + override suspend fun invoke(): List { + if (preferenceManager.source != ServerSource.ARLI_AI) return emptyList() + + val models = repository.fetchAndGetModels() + val modelNames = models.map(StableDiffusionModel::checkpointName) + if (!modelNames.contains(preferenceManager.arliAiModel)) { + preferenceManager.arliAiModel = modelNames.firstOrNull().orEmpty() + } + return models + } +} + +/** + * Returns the best checkpoint identifier available in ArliAI model metadata. + * + * @author Dmitriy Moroz + */ +internal val StableDiffusionModel.checkpointName: String + get() = title.ifBlank { modelName }.ifBlank { filename } diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/connectivity/ConnectivityUseCaseImpls.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/connectivity/ConnectivityUseCaseImpls.kt index 714cf4348..12457ca62 100644 --- a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/connectivity/ConnectivityUseCaseImpls.kt +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/connectivity/ConnectivityUseCaseImpls.kt @@ -1,5 +1,6 @@ package com.shifthackz.aisdv1.domain.usecase.connectivity +import com.shifthackz.aisdv1.domain.datasource.ArliAiGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.HordeGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.HuggingFaceGenerationDataSource import com.shifthackz.aisdv1.domain.datasource.FalAiGenerationDataSource @@ -227,3 +228,33 @@ internal class DefaultTestFalAiApiKeyUseCaseImpl( override suspend fun invoke(): Boolean = remoteDataSource.validateApiKey(configurationStore.falAiApiKey) } + +/** + * Implements `DefaultTestArliAiApiKeyUseCase` behavior in the SDAI domain layer. + * + * @author Dmitriy Moroz + */ +internal class DefaultTestArliAiApiKeyUseCaseImpl( + /** + * Exposes the `configurationStore` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + private val configurationStore: ConfigurationStore, + /** + * Exposes the `remoteDataSource` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + private val remoteDataSource: ArliAiGenerationDataSource.Remote, +) : TestArliAiApiKeyUseCase { + + /** + * Executes the `invoke` step in the SDAI domain layer. + * + * @return Result produced by `invoke`. + * @author Dmitriy Moroz + */ + override suspend fun invoke(): Boolean = + remoteDataSource.validateApiKey(configurationStore.arliAiApiKey) +} diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/connectivity/TestArliAiApiKeyUseCase.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/connectivity/TestArliAiApiKeyUseCase.kt new file mode 100644 index 000000000..aff13b0c9 --- /dev/null +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/connectivity/TestArliAiApiKeyUseCase.kt @@ -0,0 +1,17 @@ +package com.shifthackz.aisdv1.domain.usecase.connectivity + +/** + * Validates the ArliAI API key currently stored in setup configuration. + * + * @author Dmitriy Moroz + */ +fun interface TestArliAiApiKeyUseCase { + /** + * Checks whether the configured ArliAI key can reach provider endpoints. + * + * @return `true` when the stored key is accepted by ArliAI. + * + * @author Dmitriy Moroz + */ + suspend operator fun invoke(): Boolean +} diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImpl.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImpl.kt index 0bec547db..ba6591d89 100644 --- a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImpl.kt +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImpl.kt @@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.domain.usecase.generation import com.shifthackz.aisdv1.domain.entity.ImageToImagePayload import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.ArliAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.CoreMlGenerationRepository import com.shifthackz.aisdv1.domain.repository.FalAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository @@ -67,6 +68,13 @@ internal class ImageToImageUseCaseImpl( * @author Dmitriy Moroz */ private val falAiGenerationRepository: FalAiGenerationRepository, + /** + * Exposes the `arliAiGenerationRepository` value used by the SDAI domain layer. + * + * @throws IllegalStateException when the delegated operation cannot complete. + * @author Dmitriy Moroz + */ + private val arliAiGenerationRepository: ArliAiGenerationRepository, /** * Exposes the `preferenceManager` value used by the SDAI domain layer. * @@ -85,6 +93,7 @@ internal class ImageToImageUseCaseImpl( override suspend fun invoke(payload: ImageToImagePayload) = when (preferenceManager.source) { ServerSource.AUTOMATIC1111 -> stableDiffusionGenerationRepository.generateFromImage(payload) ServerSource.FAL_AI -> falAiGenerationRepository.generateFromImage(payload) + ServerSource.ARLI_AI -> arliAiGenerationRepository.generateFromImage(payload) else -> List(payload.batchCount.coerceAtLeast(1)) { generateSingle(payload) } @@ -104,6 +113,7 @@ internal class ImageToImageUseCaseImpl( ServerSource.LOCAL_APPLE_CORE_ML -> coreMlGenerationRepository.generateFromImage(payload) ServerSource.AUTOMATIC1111 -> error("Automatic1111 batch must be generated through generateFromImage(payload).") ServerSource.FAL_AI -> error("Fal.ai batch must be generated through generateFromImage(payload).") + ServerSource.ARLI_AI -> error("ArliAI batch must be generated through generateFromImage(payload).") else -> throw IllegalStateException("Img2Img not yet supported on ${preferenceManager.source}!") } } diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt index 1cde6eec5..d581f4e9c 100755 --- a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt @@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.domain.entity.AiGenerationResult import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.entity.TextToImagePayload import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.ArliAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.CoreMlGenerationRepository import com.shifthackz.aisdv1.domain.repository.FalAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository @@ -65,6 +66,13 @@ internal class TextToImageUseCaseImpl( * @author Dmitriy Moroz */ private val falAiGenerationRepository: FalAiGenerationRepository, + /** + * Exposes the `arliAiGenerationRepository` value used by the SDAI domain layer. + * + * @throws IllegalStateException when the current state is invalid. + * @author Dmitriy Moroz + */ + private val arliAiGenerationRepository: ArliAiGenerationRepository, /** * Exposes the `swarmUiGenerationRepository` value used by the SDAI domain layer. * @@ -120,6 +128,7 @@ internal class TextToImageUseCaseImpl( ): List = when (preferenceManager.source) { ServerSource.AUTOMATIC1111 -> stableDiffusionGenerationRepository.generateFromText(payload) ServerSource.FAL_AI -> falAiGenerationRepository.generateFromText(payload) + ServerSource.ARLI_AI -> arliAiGenerationRepository.generateFromText(payload) else -> List(payload.batchCount.coerceAtLeast(1)) { generateSingle(payload) } @@ -143,5 +152,6 @@ internal class TextToImageUseCaseImpl( ServerSource.LOCAL_APPLE_CORE_ML -> coreMlGenerationRepository.generateFromText(payload) ServerSource.AUTOMATIC1111 -> error("Automatic1111 batch must be generated through generateFromText(payload).") ServerSource.FAL_AI -> error("Fal.ai batch must be generated through generateFromText(payload).") + ServerSource.ARLI_AI -> error("ArliAI batch must be generated through generateFromText(payload).") } } diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToArliAiUseCase.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToArliAiUseCase.kt new file mode 100644 index 000000000..9dc99414d --- /dev/null +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToArliAiUseCase.kt @@ -0,0 +1,18 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +/** + * Saves ArliAI setup data and verifies provider connectivity. + * + * @author Dmitriy Moroz + */ +fun interface ConnectToArliAiUseCase { + /** + * Attempts to connect to ArliAI with the supplied API key. + * + * @param apiKey ArliAI API key entered by the user. + * @return success when the key is saved and accepted by ArliAI. + * + * @author Dmitriy Moroz + */ + suspend operator fun invoke(apiKey: String): Result +} diff --git a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToCloudUseCaseImpls.kt b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToCloudUseCaseImpls.kt index d773f1ca8..96a98b163 100644 --- a/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToCloudUseCaseImpls.kt +++ b/domain/src/commonMain/kotlin/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToCloudUseCaseImpls.kt @@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.domain.usecase.settings import com.shifthackz.aisdv1.domain.entity.Configuration import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials +import com.shifthackz.aisdv1.domain.usecase.connectivity.TestArliAiApiKeyUseCase import com.shifthackz.aisdv1.domain.usecase.connectivity.TestHordeApiKeyUseCase import com.shifthackz.aisdv1.domain.usecase.connectivity.TestHuggingFaceApiKeyUseCase import com.shifthackz.aisdv1.domain.usecase.connectivity.TestFalAiApiKeyUseCase @@ -11,6 +12,7 @@ import com.shifthackz.aisdv1.domain.usecase.connectivity.TestStabilityAiApiKeyUs import com.shifthackz.aisdv1.domain.usecase.connectivity.TestSwarmUiConnectivityUseCase import kotlinx.coroutines.delay import kotlinx.coroutines.withTimeout +import kotlin.time.Duration.Companion.milliseconds /** * Implements `DefaultConnectToSwarmUiUseCase` behavior in the SDAI domain layer. @@ -52,7 +54,7 @@ internal class DefaultConnectToSwarmUiUseCaseImpl( ): Result { var configuration: Configuration? = null return try { - withTimeout(CONNECTION_TIMEOUT_MILLIS) { + withTimeout(CONNECTION_TIMEOUT_MILLIS.milliseconds) { val originalConfiguration = getConfigurationUseCase() configuration = originalConfiguration val newConfiguration = originalConfiguration.copy( @@ -61,7 +63,7 @@ internal class DefaultConnectToSwarmUiUseCaseImpl( authCredentials = credentials, ) setServerConfigurationUseCase(newConfiguration) - delay(CONNECTION_DELAY_MILLIS) + delay(CONNECTION_DELAY_MILLIS.milliseconds) withLocalNetworkPermissionRetry(url) { testSwarmUiConnectivityUseCase(url) } @@ -310,6 +312,53 @@ internal class DefaultConnectToFalAiUseCaseImpl( } } +/** + * Implements `DefaultConnectToArliAiUseCase` behavior in the SDAI domain layer. + * + * @author Dmitriy Moroz + */ +internal class DefaultConnectToArliAiUseCaseImpl( + /** + * Exposes the `getConfigurationUseCase` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + private val getConfigurationUseCase: GetConfigurationUseCase, + /** + * Exposes the `setServerConfigurationUseCase` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + private val setServerConfigurationUseCase: SetServerConfigurationUseCase, + /** + * Exposes the `testArliAiApiKeyUseCase` value used by the SDAI domain layer. + * + * @author Dmitriy Moroz + */ + private val testArliAiApiKeyUseCase: TestArliAiApiKeyUseCase, +) : ConnectToArliAiUseCase { + + /** + * Executes the `invoke` step in the SDAI domain layer. + * + * @param apiKey api key value consumed by the API. + * @return Result produced by `invoke`. + * @author Dmitriy Moroz + */ + override suspend fun invoke(apiKey: String): Result = + connectWithApiKey( + getConfigurationUseCase = getConfigurationUseCase, + setServerConfigurationUseCase = setServerConfigurationUseCase, + testApiKey = testArliAiApiKeyUseCase::invoke, + ) { configuration -> + configuration.copy( + source = ServerSource.ARLI_AI, + arliAiApiKey = apiKey, + authCredentials = AuthorizationCredentials.None, + ) + } +} + /** * Executes the `connectWithApiKey` step in the SDAI domain layer. * @@ -331,7 +380,7 @@ private suspend fun connectWithApiKey( val originalConfiguration = getConfigurationUseCase() configuration = originalConfiguration setServerConfigurationUseCase(updateConfiguration(originalConfiguration)) - delay(CONNECTION_DELAY_MILLIS) + delay(CONNECTION_DELAY_MILLIS.milliseconds) requireRemoteValidApiKey(testApiKey()) Result.success(Unit) } catch (t: Throwable) { diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 345ec168e..0fa492520 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -77,6 +77,7 @@ okhttp-core = { group = "com.squareup.okhttp3", name = "okhttp", version.ref = " ktor-client-core = { group = "io.ktor", name = "ktor-client-core", version.ref = "ktor" } ktor-client-content-negotiation = { group = "io.ktor", name = "ktor-client-content-negotiation", version.ref = "ktor" } ktor-client-logging = { group = "io.ktor", name = "ktor-client-logging", version.ref = "ktor" } +ktor-client-mock = { group = "io.ktor", name = "ktor-client-mock", version.ref = "ktor" } ktor-client-okhttp = { group = "io.ktor", name = "ktor-client-okhttp", version.ref = "ktor" } ktor-client-darwin = { group = "io.ktor", name = "ktor-client-darwin", version.ref = "ktor" } ktor-serialization-kotlinx-json = { group = "io.ktor", name = "ktor-serialization-kotlinx-json", version.ref = "ktor" } diff --git a/network/build.gradle.kts b/network/build.gradle.kts index e697a3b12..060914e89 100755 --- a/network/build.gradle.kts +++ b/network/build.gradle.kts @@ -29,6 +29,7 @@ kotlin { androidUnitTest.dependencies { implementation(libs.test.junit) implementation(libs.test.mockk) + implementation(libs.ktor.client.mock) } } } diff --git a/network/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/network/api/arliai/KtorArliAiGenerationApiTest.kt b/network/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/network/api/arliai/KtorArliAiGenerationApiTest.kt new file mode 100644 index 000000000..fe907035e --- /dev/null +++ b/network/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/network/api/arliai/KtorArliAiGenerationApiTest.kt @@ -0,0 +1,145 @@ +package com.shifthackz.aisdv1.network.api.arliai + +import com.shifthackz.aisdv1.network.client.NetworkUsageCategory +import com.shifthackz.aisdv1.network.client.NetworkUsageCounter +import com.shifthackz.aisdv1.network.client.defaultNetworkJson +import com.shifthackz.aisdv1.network.request.ArliAiImageToImageRequest +import com.shifthackz.aisdv1.network.request.ArliAiTextToImageRequest +import io.ktor.client.HttpClient +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.respond +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.headersOf +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.encodeToString +import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Test + +class KtorArliAiGenerationApiTest { + + private val records = mutableListOf>() + + @After + fun tearDown() { + NetworkUsageCounter.recorder = null + } + + @Test + fun `given models request, expected config traffic recorded`() = runNetworkUsageTest { + val response = """[{"title":"Dream","model_name":"dream"}]""" + val api = createApi(response) + + val models = api.fetchModels(API_KEY) + + assertEquals("Dream", models.first().title) + assertEquals( + listOf(NetworkUsageCategory.CONFIGS to response.byteSize), + records, + ) + } + + @Test + fun `given text to image request, expected inference traffic recorded`() = runNetworkUsageTest { + val request = textToImageRequest() + val response = """{"images":["AQID"],"info":"{}"}""" + val api = createApi(response) + + val result = api.textToImage(API_KEY, request) + + assertEquals(listOf("AQID"), result.images) + assertEquals( + listOf( + NetworkUsageCategory.INFERENCE to defaultNetworkJson.encodeToString(request).byteSize, + NetworkUsageCategory.INFERENCE to response.byteSize, + ), + records, + ) + } + + @Test + fun `given image to image request, expected inference traffic recorded`() = runNetworkUsageTest { + val request = imageToImageRequest() + val response = """{"images":["BAUG"],"info":"{}"}""" + val api = createApi(response) + + val result = api.imageToImage(API_KEY, request) + + assertEquals(listOf("BAUG"), result.images) + assertEquals( + listOf( + NetworkUsageCategory.INFERENCE to defaultNetworkJson.encodeToString(request).byteSize, + NetworkUsageCategory.INFERENCE to response.byteSize, + ), + records, + ) + } + + private fun runNetworkUsageTest(block: suspend () -> Unit) = runBlocking { + NetworkUsageCounter.recorder = { category, bytes -> + records += category to bytes + } + block() + } + + private fun createApi(response: String): KtorArliAiGenerationApi { + val engine = MockEngine { + respond( + content = response, + status = HttpStatusCode.OK, + headers = headersOf(HttpHeaders.ContentType, ContentType.Application.Json.toString()), + ) + } + return KtorArliAiGenerationApi( + httpClient = HttpClient(engine), + baseUrl = BASE_URL, + ) + } + + private val String.byteSize: Long + get() = encodeToByteArray().size.toLong() + + private fun textToImageRequest() = ArliAiTextToImageRequest( + sdModelCheckpoint = MODEL, + prompt = "a cozy treehouse", + negativePrompt = "low quality", + steps = 20, + samplerName = "Euler", + width = 512, + height = 512, + seed = 5598L, + cfgScale = 7.5f, + batchSize = 1, + restoreFaces = false, + ) + + private fun imageToImageRequest() = ArliAiImageToImageRequest( + sdModelCheckpoint = MODEL, + prompt = "a cozy treehouse", + negativePrompt = "low quality", + initImages = listOf("AQID"), + mask = null, + denoisingStrength = 0.65f, + steps = 20, + samplerName = "Euler", + width = 512, + height = 512, + seed = 5598L, + cfgScale = 7.5f, + batchSize = 1, + restoreFaces = false, + maskBlur = null, + inPaintingFill = null, + inPaintFullRes = null, + inPaintFullResPadding = null, + inPaintingMaskInvert = null, + ) + + private companion object { + const val API_KEY = "key" + const val BASE_URL = "https://api.arliai.example" + const val MODEL = "Illustrious-XL-v2.0" + } +} diff --git a/network/src/commonMain/kotlin/com/shifthackz/aisdv1/network/api/arliai/ArliAiGenerationApi.kt b/network/src/commonMain/kotlin/com/shifthackz/aisdv1/network/api/arliai/ArliAiGenerationApi.kt new file mode 100644 index 000000000..7e74eb295 --- /dev/null +++ b/network/src/commonMain/kotlin/com/shifthackz/aisdv1/network/api/arliai/ArliAiGenerationApi.kt @@ -0,0 +1,63 @@ +package com.shifthackz.aisdv1.network.api.arliai + +import com.shifthackz.aisdv1.network.model.KtorStableDiffusionModelRaw +import com.shifthackz.aisdv1.network.request.ArliAiImageToImageRequest +import com.shifthackz.aisdv1.network.request.ArliAiTextToImageRequest +import com.shifthackz.aisdv1.network.response.SdGenerationResponse + +/** + * Describes the ArliAI SDNext-compatible generation endpoints. + * + * Implementations use bearer authorization and return Automatic1111-style + * generation responses so the data layer can reuse Stable Diffusion mappers. + * + * @author Dmitriy Moroz + */ +interface ArliAiGenerationApi { + /** + * Verifies that the supplied ArliAI key can access the provider API. + * + * @param apiKey ArliAI API key sent as bearer authorization. + * + * @author Dmitriy Moroz + */ + suspend fun validateApiKey(apiKey: String) + + /** + * Loads the checkpoint list exposed by ArliAI. + * + * @param apiKey ArliAI API key sent as bearer authorization. + * @return raw checkpoint metadata returned by the provider. + * + * @author Dmitriy Moroz + */ + suspend fun fetchModels(apiKey: String): List + + /** + * Sends a text-to-image request to ArliAI. + * + * @param apiKey ArliAI API key sent as bearer authorization. + * @param request SDNext-compatible text-to-image payload. + * @return generated image payload returned by ArliAI. + * + * @author Dmitriy Moroz + */ + suspend fun textToImage( + apiKey: String, + request: ArliAiTextToImageRequest, + ): SdGenerationResponse + + /** + * Sends an image-to-image request to ArliAI. + * + * @param apiKey ArliAI API key sent as bearer authorization. + * @param request SDNext-compatible image-to-image payload. + * @return generated image payload returned by ArliAI. + * + * @author Dmitriy Moroz + */ + suspend fun imageToImage( + apiKey: String, + request: ArliAiImageToImageRequest, + ): SdGenerationResponse +} diff --git a/network/src/commonMain/kotlin/com/shifthackz/aisdv1/network/api/arliai/KtorArliAiGenerationApi.kt b/network/src/commonMain/kotlin/com/shifthackz/aisdv1/network/api/arliai/KtorArliAiGenerationApi.kt new file mode 100644 index 000000000..d564c41cb --- /dev/null +++ b/network/src/commonMain/kotlin/com/shifthackz/aisdv1/network/api/arliai/KtorArliAiGenerationApi.kt @@ -0,0 +1,126 @@ +package com.shifthackz.aisdv1.network.api.arliai + +import com.shifthackz.aisdv1.network.client.createConfiguredHttpClient +import com.shifthackz.aisdv1.network.client.NetworkUsageCategory +import com.shifthackz.aisdv1.network.client.setTrackedJsonBody +import com.shifthackz.aisdv1.network.client.trackUsage +import com.shifthackz.aisdv1.network.client.trackedJsonBody +import com.shifthackz.aisdv1.network.model.KtorStableDiffusionModelRaw +import com.shifthackz.aisdv1.network.request.ArliAiImageToImageRequest +import com.shifthackz.aisdv1.network.request.ArliAiTextToImageRequest +import com.shifthackz.aisdv1.network.response.SdGenerationResponse +import io.ktor.client.HttpClient +import io.ktor.client.request.get +import io.ktor.client.request.header +import io.ktor.client.request.post +import io.ktor.http.HttpHeaders +import io.ktor.http.appendPathSegments +import io.ktor.http.takeFrom + +/** + * Ktor implementation of ArliAI validation, model discovery, and image generation. + * + * Model-list traffic is counted as [NetworkUsageCategory.CONFIGS]. Text-to-image and + * image-to-image request and response bodies are counted as [NetworkUsageCategory.INFERENCE]. + * + * @param httpClient configured Ktor client used to send provider requests. + * @param baseUrl ArliAI SDNext-compatible API base URL. + * + * @author Dmitriy Moroz + */ +class KtorArliAiGenerationApi( + private val httpClient: HttpClient, + private val baseUrl: String, +) : ArliAiGenerationApi { + + /** + * Creates an ArliAI API client with the shared application HTTP configuration. + * + * @param baseUrl ArliAI SDNext-compatible API base URL. + * + * @author Dmitriy Moroz + */ + constructor(baseUrl: String) : this( + httpClient = createConfiguredHttpClient(), + baseUrl = baseUrl, + ) + + /** + * Validates the key by loading the ArliAI model list. + * + * @param apiKey ArliAI API key sent as bearer authorization. + * + * @author Dmitriy Moroz + */ + override suspend fun validateApiKey(apiKey: String) { + fetchModels(apiKey) + } + + /** + * Loads available ArliAI checkpoints and records the response as configuration traffic. + * + * @param apiKey ArliAI API key sent as bearer authorization. + * @return raw checkpoint metadata returned by the provider. + * + * @author Dmitriy Moroz + */ + override suspend fun fetchModels(apiKey: String): List = httpClient.get { + url.takeFrom(baseUrl) + url.appendPathSegments(PATH_SD_API, PATH_V1, PATH_SD_MODELS) + header(HttpHeaders.Authorization, apiKey.headerValue) + trackUsage(NetworkUsageCategory.CONFIGS) + }.trackedJsonBody(NetworkUsageCategory.CONFIGS) + + /** + * Sends text-to-image generation and records request plus response bytes as inference traffic. + * + * @param apiKey ArliAI API key sent as bearer authorization. + * @param request SDNext-compatible text-to-image payload. + * @return generated image payload returned by ArliAI. + * + * @author Dmitriy Moroz + */ + override suspend fun textToImage( + apiKey: String, + request: ArliAiTextToImageRequest, + ): SdGenerationResponse = httpClient + .post { + url.takeFrom(baseUrl) + url.appendPathSegments(PATH_SD_API, PATH_V1, PATH_TXT_TO_IMG) + header(HttpHeaders.Authorization, apiKey.headerValue) + setTrackedJsonBody(NetworkUsageCategory.INFERENCE, request) + } + .trackedJsonBody(NetworkUsageCategory.INFERENCE) + + /** + * Sends image-to-image generation and records request plus response bytes as inference traffic. + * + * @param apiKey ArliAI API key sent as bearer authorization. + * @param request SDNext-compatible image-to-image payload. + * @return generated image payload returned by ArliAI. + * + * @author Dmitriy Moroz + */ + override suspend fun imageToImage( + apiKey: String, + request: ArliAiImageToImageRequest, + ): SdGenerationResponse = httpClient + .post { + url.takeFrom(baseUrl) + url.appendPathSegments(PATH_SD_API, PATH_V1, PATH_IMG_TO_IMG) + header(HttpHeaders.Authorization, apiKey.headerValue) + setTrackedJsonBody(NetworkUsageCategory.INFERENCE, request) + } + .trackedJsonBody(NetworkUsageCategory.INFERENCE) + + private val String.headerValue: String + get() = "Bearer $this" + + private companion object { + const val PATH_SD_API = "sdapi" + const val PATH_V1 = "v1" + const val PATH_SD_MODELS = "sd-models" + const val PATH_TXT_TO_IMG = "txt2img" + const val PATH_IMG_TO_IMG = "img2img" + } +} diff --git a/network/src/commonMain/kotlin/com/shifthackz/aisdv1/network/di/NetworkCoreModule.kt b/network/src/commonMain/kotlin/com/shifthackz/aisdv1/network/di/NetworkCoreModule.kt index be1ab692d..ff0eaaebd 100644 --- a/network/src/commonMain/kotlin/com/shifthackz/aisdv1/network/di/NetworkCoreModule.kt +++ b/network/src/commonMain/kotlin/com/shifthackz/aisdv1/network/di/NetworkCoreModule.kt @@ -4,6 +4,8 @@ import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111GenerationAp import com.shifthackz.aisdv1.network.api.automatic1111.Automatic1111MetadataApi import com.shifthackz.aisdv1.network.api.automatic1111.KtorAutomatic1111GenerationApi import com.shifthackz.aisdv1.network.api.automatic1111.KtorAutomatic1111MetadataApi +import com.shifthackz.aisdv1.network.api.arliai.ArliAiGenerationApi +import com.shifthackz.aisdv1.network.api.arliai.KtorArliAiGenerationApi import com.shifthackz.aisdv1.network.api.falai.FalAiGenerationApi import com.shifthackz.aisdv1.network.api.falai.KtorFalAiGenerationApi import com.shifthackz.aisdv1.network.api.horde.HordeGenerationApi @@ -64,6 +66,9 @@ val coreNetworkModule = module { single { KtorFalAiGenerationApi(FAL_AI_API_URL, FAL_AI_QUEUE_API_URL) } + single { + KtorArliAiGenerationApi(ARLI_AI_API_URL) + } single { KtorSwarmUiModelsApi() } @@ -138,3 +143,9 @@ private const val FAL_AI_API_URL = "https://api.fal.ai" * @author Dmitriy Moroz */ private const val FAL_AI_QUEUE_API_URL = "https://queue.fal.run" +/** + * Exposes the `ARLI_AI_API_URL` value used by the SDAI network layer. + * + * @author Dmitriy Moroz + */ +private const val ARLI_AI_API_URL = "https://api.arliai.com" diff --git a/network/src/commonMain/kotlin/com/shifthackz/aisdv1/network/request/ArliAiGenerationRequest.kt b/network/src/commonMain/kotlin/com/shifthackz/aisdv1/network/request/ArliAiGenerationRequest.kt new file mode 100644 index 000000000..8bc66734f --- /dev/null +++ b/network/src/commonMain/kotlin/com/shifthackz/aisdv1/network/request/ArliAiGenerationRequest.kt @@ -0,0 +1,124 @@ +package com.shifthackz.aisdv1.network.request + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +/** + * JSON payload for ArliAI text-to-image generation. + * + * The field names follow the SDNext-compatible ArliAI API. Optional detailer fields are omitted + * unless ADetailer is enabled in the domain payload mapper. + * + * @author Dmitriy Moroz + */ +@Serializable +data class ArliAiTextToImageRequest( + @SerialName("sd_model_checkpoint") + val sdModelCheckpoint: String, + @SerialName("prompt") + val prompt: String, + @SerialName("negative_prompt") + val negativePrompt: String, + @SerialName("steps") + val steps: Int, + @SerialName("sampler_name") + val samplerName: String, + @SerialName("width") + val width: Int, + @SerialName("height") + val height: Int, + @SerialName("seed") + val seed: Long?, + @SerialName("cfg_scale") + val cfgScale: Float, + @SerialName("batch_size") + val batchSize: Int, + @SerialName("restore_faces") + val restoreFaces: Boolean, + @SerialName("detailer_enabled") + val detailerEnabled: Boolean? = null, + @SerialName("detailer_prompt") + val detailerPrompt: String? = null, + @SerialName("detailer_negative") + val detailerNegative: String? = null, + @SerialName("detailer_steps") + val detailerSteps: Int? = null, + @SerialName("detailer_strength") + val detailerStrength: Float? = null, + @SerialName("detailer_model") + val detailerModel: String? = null, + @SerialName("detailer_conf") + val detailerConfidence: Float? = null, + @SerialName("detailer_padding") + val detailerPadding: Int? = null, + @SerialName("detailer_blur") + val detailerBlur: Int? = null, +) + +/** + * JSON payload for ArliAI image-to-image generation. + * + * The field names follow the SDNext-compatible ArliAI API. Optional mask and detailer fields are + * omitted when the source payload does not provide them. + * + * @author Dmitriy Moroz + */ +@Serializable +data class ArliAiImageToImageRequest( + @SerialName("sd_model_checkpoint") + val sdModelCheckpoint: String, + @SerialName("prompt") + val prompt: String, + @SerialName("negative_prompt") + val negativePrompt: String, + @SerialName("init_images") + val initImages: List, + @SerialName("mask") + val mask: String?, + @SerialName("denoising_strength") + val denoisingStrength: Float, + @SerialName("steps") + val steps: Int, + @SerialName("sampler_name") + val samplerName: String, + @SerialName("width") + val width: Int, + @SerialName("height") + val height: Int, + @SerialName("seed") + val seed: Long?, + @SerialName("cfg_scale") + val cfgScale: Float, + @SerialName("batch_size") + val batchSize: Int, + @SerialName("restore_faces") + val restoreFaces: Boolean, + @SerialName("mask_blur") + val maskBlur: Int?, + @SerialName("inpainting_fill") + val inPaintingFill: Int?, + @SerialName("inpaint_full_res") + val inPaintFullRes: Boolean?, + @SerialName("inpaint_full_res_padding") + val inPaintFullResPadding: Int?, + @SerialName("inpainting_mask_invert") + val inPaintingMaskInvert: Int?, + @SerialName("detailer_enabled") + val detailerEnabled: Boolean? = null, + @SerialName("detailer_prompt") + val detailerPrompt: String? = null, + @SerialName("detailer_negative") + val detailerNegative: String? = null, + @SerialName("detailer_steps") + val detailerSteps: Int? = null, + @SerialName("detailer_strength") + val detailerStrength: Float? = null, + @SerialName("detailer_model") + val detailerModel: String? = null, + @SerialName("detailer_conf") + val detailerConfidence: Float? = null, + @SerialName("detailer_padding") + val detailerPadding: Int? = null, + @SerialName("detailer_blur") + val detailerBlur: Int? = null, +) diff --git a/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt b/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt index c04c77864..7c7b68c0d 100644 --- a/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt +++ b/presentation/src/androidUnitTest/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt @@ -346,6 +346,7 @@ private class TestLinksProvider : LinksProvider { override val openAiInfoUrl = "" override val stabilityAiInfoUrl = "" override val falAiInfoUrl = "" + override val arliAiInfoUrl = "" override val privacyPolicyUrl = "https://policy.test" override val donateUrl = "" override val projectWebsiteUrl = "https://project.test" diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/di/PresentationViewModelBindings.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/di/PresentationViewModelBindings.kt index c54f1df07..d0f92c230 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/di/PresentationViewModelBindings.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/di/PresentationViewModelBindings.kt @@ -238,6 +238,7 @@ internal fun Module.registerPresentationViewModelBindings() { getConfigurationUseCase = get(), getStableDiffusionSamplersUseCase = get(), getForgeModulesUseCase = get(), + fetchAndGetArliAiModelsUseCase = get(), isADetailerAvailableUseCase = get(), textToImageUseCase = get(), saveGenerationResultUseCase = get(), @@ -267,6 +268,7 @@ internal fun Module.registerPresentationViewModelBindings() { getConfigurationUseCase = get(), getStableDiffusionSamplersUseCase = get(), getForgeModulesUseCase = get(), + fetchAndGetArliAiModelsUseCase = get(), isADetailerAvailableUseCase = get(), getRandomImageUseCase = get(), imageToImageUseCase = get(), @@ -353,6 +355,7 @@ internal fun Module.registerPresentationViewModelBindings() { connectToOpenAiUseCase = get(), connectToStabilityAiUseCase = get(), connectToFalAiUseCase = get(), + connectToArliAiUseCase = get(), downloadModelUseCase = get(), deleteModelUseCase = get(), downloadGuard = get(), diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkScreen.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkScreen.kt index 5b5dc3fd7..da10debcb 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkScreen.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkScreen.kt @@ -793,6 +793,7 @@ private fun ServerSource.displayName(): String = when (this) { ServerSource.OPEN_AI -> Localization.string("srv_type_open_ai") ServerSource.STABILITY_AI -> Localization.string("srv_type_stability_ai") ServerSource.FAL_AI -> Localization.string("srv_type_fal_ai") + ServerSource.ARLI_AI -> Localization.string("srv_type_arli_ai") } private fun SdxlBackend.displayName(): String = displayName diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkViewModel.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkViewModel.kt index 70c9dbcc3..1b7b382f5 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkViewModel.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/benchmark/BenchmarkViewModel.kt @@ -220,6 +220,7 @@ private fun ServerSource.displayName(): String = when (this) { ServerSource.OPEN_AI -> Localization.string("srv_type_open_ai") ServerSource.STABILITY_AI -> Localization.string("srv_type_stability_ai") ServerSource.FAL_AI -> Localization.string("srv_type_fal_ai") + ServerSource.ARLI_AI -> Localization.string("srv_type_arli_ai") } private fun SdxlBackend.displayName(): String = displayName diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageContentHelpers.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageContentHelpers.kt index 0fa1417c7..627390881 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageContentHelpers.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageContentHelpers.kt @@ -4,7 +4,6 @@ package com.shifthackz.aisdv1.presentation.screen.img2img import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.runtime.Composable -import androidx.compose.ui.graphics.ImageBitmap import androidx.compose.ui.text.input.TextFieldValue import com.shifthackz.aisdv1.core.localization.Localization import com.shifthackz.aisdv1.domain.entity.AiGenerationResult @@ -54,6 +53,7 @@ internal fun GenerationInputFormEvent.toImageToImageIntent(): ImageToImageIntent is GenerationInputFormEvent.UpdateFalAiAcceleration -> ImageToImageIntent.UpdateFalAiAcceleration(value) is GenerationInputFormEvent.UpdateSdxlBackend -> null is GenerationInputFormEvent.UpdateFalAiSyncMode -> ImageToImageIntent.UpdateFalAiSyncMode(value) + is GenerationInputFormEvent.UpdateArliAiModel -> ImageToImageIntent.UpdateArliAiModel(value) is GenerationInputFormEvent.UpdateStabilityAiStyle -> ImageToImageIntent.UpdateStabilityAiStyle(value) is GenerationInputFormEvent.UpdateStabilityAiClipGuidance -> ImageToImageIntent.UpdateStabilityAiClipGuidance(value) @@ -190,14 +190,6 @@ internal sealed interface ImageToImagePanel { internal val AiGenerationResult.aspectRatio: Float get() = if (width > 0 && height > 0) width.toFloat() / height.toFloat() else 1f -/** - * Exposes the `ImageBitmap` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ -internal val ImageBitmap.safeAspectRatio: Float - get() = if (width > 0 && height > 0) width.toFloat() / height.toFloat() else 1f - /** * Exposes the `ServerSource` value used by the SDAI presentation layer. * @@ -212,6 +204,7 @@ internal val ServerSource.displayName: String ServerSource.OPEN_AI -> Localization.string("srv_type_open_ai") ServerSource.STABILITY_AI -> Localization.string("srv_type_stability_ai") ServerSource.FAL_AI -> Localization.string("srv_type_fal_ai") + ServerSource.ARLI_AI -> Localization.string("srv_type_arli_ai") ServerSource.LOCAL_MICROSOFT_ONNX -> Localization.string("srv_type_local_short") ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Localization.string("srv_type_media_pipe_short") ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> Localization.string("srv_type_sdxl_short") diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageIntent.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageIntent.kt index c6db1bf94..bf37cd442 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageIntent.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageIntent.kt @@ -418,6 +418,13 @@ sealed interface ImageToImageIntent : MviIntent { * @author Dmitriy Moroz */ data class UpdateFalAiSyncMode(val value: Boolean) : ImageToImageIntent + /** + * Carries `UpdateArliAiModel` data through the SDAI presentation layer. + * + * @param value value consumed by the API. + * @author Dmitriy Moroz + */ + data class UpdateArliAiModel(val value: String) : ImageToImageIntent /** * Carries `UpdateStabilityAiStyle` data through the SDAI presentation layer. * diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageIntentProcessor.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageIntentProcessor.kt index 840555883..6f69e8cc4 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageIntentProcessor.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageIntentProcessor.kt @@ -224,7 +224,7 @@ internal class ImageToImageIntentProcessor( } ?: MIN_STEPS val maxSteps = it.falAiModel.maxInferenceSteps.takeIf { _ -> it.mode == ServerSource.FAL_AI - } ?: MAX_STEPS + } ?: if (it.mode == ServerSource.ARLI_AI) MAX_ARLI_AI_STEPS else MAX_STEPS it.copy(samplingSteps = intent.value.coerceIn(minSteps, maxSteps), message = null) } is ImageToImageIntent.UpdateCfgScale -> updateState { @@ -315,6 +315,9 @@ internal class ImageToImageIntentProcessor( is ImageToImageIntent.UpdateFalAiSyncMode -> updateState { it.copy(falAiSyncMode = intent.value, message = null) } + is ImageToImageIntent.UpdateArliAiModel -> updateState { + it.copy(arliAiModel = intent.value, message = null, error = null) + } is ImageToImageIntent.UpdateStabilityAiStyle -> updateState { it.copy(selectedStylePreset = intent.value, message = null) } @@ -401,3 +404,4 @@ private const val MIN_SIZE = 64 private const val MAX_SIZE = 2048 private const val SIZE_STEP = 64 private const val MAX_FAL_AI_BATCH_COUNT = 4 +private const val MAX_ARLI_AI_STEPS = 40 diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt index 8f6bed591..12304fddc 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt @@ -288,12 +288,9 @@ data class ImageToImageState( * @author Dmitriy Moroz */ override val falAiSyncMode: Boolean = false, - /** - * Exposes the `sdxlBackend` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val sdxlBackend: SdxlBackend = SdxlBackend.AUTO, + override val arliAiModels: List = emptyList(), + override val arliAiModel: String = "", /** * Exposes the `widthValidationError` value used by the SDAI presentation layer. * @@ -354,7 +351,8 @@ data class ImageToImageState( mode == ServerSource.HUGGING_FACE || mode == ServerSource.STABILITY_AI || mode == ServerSource.LOCAL_APPLE_CORE_ML || - mode == ServerSource.FAL_AI + mode == ServerSource.FAL_AI || + mode == ServerSource.ARLI_AI val sourceSupportsInPaint: Boolean get() = mode != ServerSource.LOCAL_APPLE_CORE_ML @@ -427,6 +425,7 @@ internal fun ImageToImageState.mapToPayload( falAiImageSize = falAiImageSize, falAiAcceleration = falAiAcceleration, falAiSyncMode = falAiSyncMode, + arliAiModel = arliAiModel.takeIf { mode == ServerSource.ARLI_AI }.orEmpty(), ) /** diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModel.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModel.kt index 01ceb7579..93dd7dbce 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModel.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModel.kt @@ -6,14 +6,17 @@ import com.shifthackz.aisdv1.core.localization.Localization import com.shifthackz.aisdv1.core.mvi.BaseMviViewModel import com.shifthackz.aisdv1.core.mvi.EmptyEffect import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult import com.shifthackz.aisdv1.domain.entity.ForgeModule import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.entity.Settings +import com.shifthackz.aisdv1.domain.entity.StableDiffusionModel import com.shifthackz.aisdv1.domain.entity.StableDiffusionSampler import com.shifthackz.aisdv1.domain.feature.work.BackgroundTaskManager import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.usecase.arliai.FetchAndGetArliAiModelsUseCase import com.shifthackz.aisdv1.domain.usecase.caching.SaveLastResultToCacheUseCase import com.shifthackz.aisdv1.domain.usecase.forgemodule.GetForgeModulesUseCase import com.shifthackz.aisdv1.domain.usecase.generation.GetRandomImageUseCase @@ -65,6 +68,12 @@ class ImageToImageViewModel( * @author Dmitriy Moroz */ private val getForgeModulesUseCase: GetForgeModulesUseCase, + /** + * Exposes the `fetchAndGetArliAiModelsUseCase` value used by the SDAI presentation layer. + * + * @author Dmitriy Moroz + */ + private val fetchAndGetArliAiModelsUseCase: FetchAndGetArliAiModelsUseCase, /** * Exposes the `isADetailerAvailableUseCase` value used by the SDAI presentation layer. * @@ -216,6 +225,8 @@ class ImageToImageViewModel( private var forgeModulesKey: StableDiffusionSamplersKey? = null private var aDetailerAvailable: Boolean? = null private var aDetailerAvailabilityKey: StableDiffusionSamplersKey? = null + private var arliAiModels: List? = null + private var arliAiModelsKey: String? = null private val actionHandler = ImageToImageActionHandler( dispatchersProvider = dispatchersProvider, @@ -269,16 +280,16 @@ class ImageToImageViewModel( ImageToImageIntent.SuppressBenchmarkWarningAndContinue -> actionHandler.continueAfterBenchmarkWarning(suppressFutureWarnings = true) ImageToImageIntent.DismissModal -> actionHandler.dismissBenchmarkDialog() + is ImageToImageIntent.UpdateArliAiModel -> { + preferenceManager.arliAiModel = intent.value + intentProcessor.process(intent) + } else -> intentProcessor.process(intent) } } - private fun applyGenerationResult(ai: com.shifthackz.aisdv1.domain.entity.AiGenerationResult) { - applyGenerationResult(ai, imageBase64 = null) - } - private fun applyGenerationResult( - ai: com.shifthackz.aisdv1.domain.entity.AiGenerationResult, + ai: AiGenerationResult, inputImage: Boolean, ) { applyGenerationResult( @@ -292,7 +303,7 @@ class ImageToImageViewModel( } private fun applyGenerationResult( - ai: com.shifthackz.aisdv1.domain.entity.AiGenerationResult, + ai: AiGenerationResult, imageBase64: String?, ) { updateState { state -> @@ -386,6 +397,7 @@ class ImageToImageViewModel( stableDiffusionSamplers = stableDiffusionSamplers, forgeModules = forgeModules, aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, ) } } @@ -414,12 +426,14 @@ class ImageToImageViewModel( } .collect { settings -> refreshStableDiffusionMetadataIfNeeded(settings) + refreshArliAiModelsIfNeeded(settings) updateState { state -> state.withSettings( settings = settings, stableDiffusionSamplers = stableDiffusionSamplers, forgeModules = forgeModules, aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, ) } } @@ -457,6 +471,7 @@ class ImageToImageViewModel( stableDiffusionSamplers = stableDiffusionSamplers, forgeModules = forgeModules, aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, ) } loadSamplers() @@ -478,6 +493,7 @@ class ImageToImageViewModel( stableDiffusionSamplers = stableDiffusionSamplers, forgeModules = forgeModules, aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, ) } } @@ -503,6 +519,7 @@ class ImageToImageViewModel( stableDiffusionSamplers = stableDiffusionSamplers, forgeModules = forgeModules, aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, ) } } @@ -516,6 +533,7 @@ class ImageToImageViewModel( stableDiffusionSamplers = stableDiffusionSamplers, forgeModules = forgeModules, aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, ) } } @@ -555,4 +573,61 @@ class ImageToImageViewModel( } } } + + private fun refreshArliAiModelsIfNeeded(settings: Settings) { + if (settings.source != ServerSource.ARLI_AI) { + arliAiModelsKey = null + return + } + val key = preferenceManager.arliAiApiKey + if (arliAiModelsKey == key) return + + arliAiModelsKey = key + arliAiModels = emptyList() + updateState { + it.withSource( + source = ServerSource.ARLI_AI, + stableDiffusionSamplers = stableDiffusionSamplers, + forgeModules = forgeModules, + aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, + ) + } + loadArliAiModels() + } + + private fun loadArliAiModels() { + launch(dispatchersProvider.io) { + runCatching { + fetchAndGetArliAiModelsUseCase() + .map(StableDiffusionModel::arliAiCheckpointName) + .filter(String::isNotBlank) + .distinct() + } + .onSuccess { models -> + arliAiModels = models + withContext(dispatchersProvider.immediate) { + updateState { + it.copy(arliAiModel = preferenceManager.arliAiModel) + .withSource( + source = it.mode, + stableDiffusionSamplers = stableDiffusionSamplers, + forgeModules = forgeModules, + aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, + ) + } + } + } + .onFailure { t -> + withContext(dispatchersProvider.immediate) { + updateState { it.copy(error = t.localizedMessageString()) } + } + onError(t) + } + } + } } + +private val StableDiffusionModel.arliAiCheckpointName: String + get() = title.ifBlank { modelName }.ifBlank { filename } diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModelHelpers.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModelHelpers.kt index 42e523472..ce19985c1 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModelHelpers.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageViewModelHelpers.kt @@ -6,6 +6,7 @@ import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.core.validation.ValidationResult import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.ArliAiSampler import com.shifthackz.aisdv1.domain.entity.ForgeModule import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.entity.Settings @@ -157,6 +158,8 @@ internal fun ImageToImageState.validated( Localization.string("error_img2img_openai_unsupported") !sourceSupportsImageToImage -> Localization.string("error_img2img_local_android_only") + mode == ServerSource.ARLI_AI && arliAiModel.isBlank() -> + Localization.string("error_invalid") imageBase64.isBlank() -> Localization.string("error_img2img_select_input") else -> null @@ -202,8 +205,11 @@ internal fun ImageToImageState.withSettings( stableDiffusionSamplers: List?, forgeModules: List?, aDetailerAvailable: Boolean?, + arliAiModels: List?, ): ImageToImageState = - withSource(settings.source, stableDiffusionSamplers, forgeModules, aDetailerAvailable).copy( + copy(arliAiModel = settings.arliAiModel) + .withSource(settings.source, stableDiffusionSamplers, forgeModules, aDetailerAvailable, arliAiModels) + .copy( advancedToggleButtonVisible = !settings.formAdvancedOptionsAlwaysShow, advancedOptionsVisible = if (settings.formAdvancedOptionsAlwaysShow) { true @@ -227,14 +233,17 @@ internal fun ImageToImageState.withSource( stableDiffusionSamplers: List?, forgeModules: List?, aDetailerAvailable: Boolean?, + arliAiModels: List?, ): ImageToImageState { val samplers = when (source) { ServerSource.STABILITY_AI -> StabilityAiSampler.entries.map { "$it" } + ServerSource.ARLI_AI -> ArliAiSampler.supported else -> stableDiffusionSamplers.orEmpty() } val modules = forgeModules.orEmpty().takeIf { source == ServerSource.AUTOMATIC1111 }.orEmpty() + val arliModels = arliAiModels.orEmpty() return copy( mode = source, availableSamplers = samplers, @@ -243,6 +252,14 @@ internal fun ImageToImageState.withSource( ?: selectedSampler, availableForgeModules = modules, selectedForgeModules = selectedForgeModules.filter(modules::contains), + arliAiModels = arliModels.takeIf { source == ServerSource.ARLI_AI }.orEmpty(), + arliAiModel = if (source == ServerSource.ARLI_AI) { + arliAiModel.takeIf(arliModels::contains) + ?: arliModels.firstOrNull() + ?: arliAiModel + } else { + arliAiModel + }, aDetailerAvailable = if (source == ServerSource.AUTOMATIC1111) { aDetailerAvailable ?: this.aDetailerAvailable } else { diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt index beb51b705..015cce654 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt @@ -103,6 +103,7 @@ private fun Configuration.requiresRemotePreload(): Boolean = when (source) { ServerSource.OPEN_AI, ServerSource.STABILITY_AI, ServerSource.FAL_AI, + ServerSource.ARLI_AI, ServerSource.LOCAL_MICROSOFT_ONNX, ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, ServerSource.LOCAL_STABLE_DIFFUSION_CPP, diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/onboarding/page/ProvidersPageContent.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/onboarding/page/ProvidersPageContent.kt index d131f2355..c5226198b 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/onboarding/page/ProvidersPageContent.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/onboarding/page/ProvidersPageContent.kt @@ -58,6 +58,7 @@ import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi import com.shifthackz.aisdv1.presentation.widget.frame.PhoneFrame import kotlinx.coroutines.delay import kotlinx.coroutines.launch +import kotlin.time.Duration.Companion.milliseconds @Composable fun ProviderPageContent( @@ -74,6 +75,7 @@ fun ProviderPageContent( ServerSource.HUGGING_FACE, ServerSource.OPEN_AI, ServerSource.STABILITY_AI, + ServerSource.ARLI_AI, ) } var selectedSource by remember { mutableStateOf(ServerSource.STABILITY_AI) } @@ -104,7 +106,7 @@ fun ProviderPageContent( val job = scope.launch { var index = previewSources.indexOf(selectedSource).coerceAtLeast(0) while (isPageVisible) { - delay(1200) + delay(1200.milliseconds) index = (index + 1) % previewSources.size selectedSource = previewSources[index] } @@ -326,6 +328,7 @@ private fun ServerSource.previewTitle(strings: ServerSetupStrings): String = whe ServerSource.HUGGING_FACE -> strings.huggingFaceTitle ServerSource.OPEN_AI -> strings.openAiTitle ServerSource.STABILITY_AI -> strings.stabilityTitle + ServerSource.ARLI_AI -> strings.arliAiTitle else -> strings.automaticTitle } @@ -333,5 +336,6 @@ private fun ServerSource.previewSubtitle(strings: ServerSetupStrings): String = ServerSource.HUGGING_FACE -> strings.huggingFaceSubtitle ServerSource.OPEN_AI -> strings.openAiSubtitle ServerSource.STABILITY_AI -> strings.stabilitySubtitle + ServerSource.ARLI_AI -> strings.arliAiSubtitle else -> strings.automaticSubtitle } diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/model/SettingsUiText.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/model/SettingsUiText.kt index 661652a44..54d6bb1f3 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/model/SettingsUiText.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/settings/model/SettingsUiText.kt @@ -28,6 +28,7 @@ internal fun ServerSource.shortTitle(): String = when (this) { ServerSource.OPEN_AI -> Localization.string("srv_type_open_ai") ServerSource.STABILITY_AI -> Localization.string("srv_type_stability_ai") ServerSource.FAL_AI -> Localization.string("srv_type_fal_ai") + ServerSource.ARLI_AI -> Localization.string("srv_type_arli_ai") ServerSource.LOCAL_MICROSOFT_ONNX -> Localization.string("srv_type_local_short") ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Localization.string("srv_type_media_pipe_short") ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> Localization.string("srv_type_sdxl_short") diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt index 621eb76b5..0b4f56144 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt @@ -22,6 +22,7 @@ import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCa import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalSdxlModelsUseCase import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchHuggingFaceModelsUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToA1111UseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToArliAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToCoreMlUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToFalAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHordeUseCase @@ -56,6 +57,7 @@ import kotlinx.coroutines.flow.distinctUntilChanged import kotlinx.coroutines.withContext import kotlinx.coroutines.withTimeout import kotlin.coroutines.cancellation.CancellationException +import kotlin.time.Duration.Companion.milliseconds /** * Owns provider setup state, validation, local model downloads, and final connection side effects. @@ -118,6 +120,7 @@ class ServerSetupViewModel( private val connectToOpenAiUseCase: ConnectToOpenAiUseCase, private val connectToStabilityAiUseCase: ConnectToStabilityAiUseCase, private val connectToFalAiUseCase: ConnectToFalAiUseCase, + private val connectToArliAiUseCase: ConnectToArliAiUseCase, private val downloadModelUseCase: DownloadModelUseCase, private val deleteModelUseCase: DeleteModelUseCase, private val downloadGuard: ServerSetupDownloadGuard, @@ -163,7 +166,7 @@ class ServerSetupViewModel( emptyList() } val models = runCatching { - withTimeout(HUGGING_FACE_MODELS_TIMEOUT_MILLIS) { + withTimeout(HUGGING_FACE_MODELS_TIMEOUT_MILLIS.milliseconds) { fetchHuggingFaceModelsUseCase() } } @@ -247,6 +250,7 @@ class ServerSetupViewModel( ServerSource.OPEN_AI -> connectToOpenAi() ServerSource.STABILITY_AI -> connectToStabilityAi() ServerSource.FAL_AI -> connectToFalAi() + ServerSource.ARLI_AI -> connectToArliAi() ServerSource.LOCAL_MICROSOFT_ONNX -> connectToLocalDiffusion() ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> connectToMediaPipe() ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> connectToSdxl() @@ -318,6 +322,10 @@ class ServerSetupViewModel( apiKey = currentState.falAiApiKey, ) + private suspend fun connectToArliAi(): Result = connectToArliAiUseCase( + apiKey = currentState.arliAiApiKey, + ) + private suspend fun connectToLocalDiffusion(): Result = connectToLocalDiffusionUseCase( modelId = currentState.localOnnxModels.find { it.selected }?.id.orEmpty(), modelPath = currentState.localOnnxCustomModelPath, diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/component/ServerSetupWidgets.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/component/ServerSetupWidgets.kt index dd7e9979a..8b09e311d 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/component/ServerSetupWidgets.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/component/ServerSetupWidgets.kt @@ -182,6 +182,7 @@ internal val ServerSource.icon: ImageVector ServerSource.OPEN_AI, ServerSource.STABILITY_AI, ServerSource.FAL_AI, + ServerSource.ARLI_AI, -> Icons.Default.Cloud ServerSource.LOCAL_MICROSOFT_ONNX, @@ -200,6 +201,7 @@ internal fun ServerSource.title(strings: ServerSetupStrings): String = when (thi ServerSource.OPEN_AI -> strings.openAiTitle ServerSource.STABILITY_AI -> strings.stabilityTitle ServerSource.FAL_AI -> strings.falAiTitle + ServerSource.ARLI_AI -> strings.arliAiTitle ServerSource.LOCAL_MICROSOFT_ONNX -> strings.localDiffusionTitle ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> strings.mediaPipeTitle ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> strings.sdxlTitle @@ -214,6 +216,7 @@ internal fun ServerSource.subtitle(strings: ServerSetupStrings): String = when ( ServerSource.OPEN_AI -> strings.openAiSubtitle ServerSource.STABILITY_AI -> strings.stabilitySubtitle ServerSource.FAL_AI -> strings.falAiSubtitle + ServerSource.ARLI_AI -> strings.arliAiSubtitle ServerSource.LOCAL_MICROSOFT_ONNX -> strings.localDiffusionSubtitle ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> strings.mediaPipeSubtitle ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> strings.sdxlSubtitle diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/content/ServerSetupStrings.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/content/ServerSetupStrings.kt index 634058602..7ca06af51 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/content/ServerSetupStrings.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/content/ServerSetupStrings.kt @@ -64,6 +64,7 @@ data class ServerSetupStrings( val huggingFaceAbout: String = Localization.string("hint_hugging_face_about"), val openAiAbout: String = Localization.string("hint_open_ai_about"), val falAiAbout: String = Localization.string("hint_fal_ai_about"), + val arliAiAbout: String = Localization.string("hint_arli_ai_about"), val stabilityAbout: String = Localization.string("hint_stability_ai_about"), val automaticTitle: String = Localization.string("srv_type_own"), val automaticFormTitle: String = Localization.string("hint_server_setup_title"), @@ -85,6 +86,8 @@ data class ServerSetupStrings( val openAiSubtitle: String = Localization.string("hint_open_ai_sub_title"), val falAiTitle: String = Localization.string("hint_fal_ai_title"), val falAiSubtitle: String = Localization.string("hint_fal_ai_sub_title"), + val arliAiTitle: String = Localization.string("hint_arli_ai_title"), + val arliAiSubtitle: String = Localization.string("hint_arli_ai_sub_title"), val stabilityTitle: String = Localization.string("hint_stability_ai_title"), val stabilitySubtitle: String = Localization.string("hint_stability_ai_sub_title"), val localDiffusionTitle: String = Localization.string("hint_local_diffusion_title"), diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/form/local/ServerSetupLocalForms.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/form/local/ServerSetupLocalForms.kt index 072fa4b16..5bfef754b 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/form/local/ServerSetupLocalForms.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/form/local/ServerSetupLocalForms.kt @@ -28,6 +28,7 @@ import com.shifthackz.aisdv1.presentation.screen.setup.component.SwitchRow import com.shifthackz.aisdv1.presentation.screen.setup.component.isCustom import com.shifthackz.aisdv1.presentation.screen.setup.component.message import com.shifthackz.aisdv1.presentation.screen.setup.content.ServerSetupStrings +import com.shifthackz.aisdv1.presentation.screen.setup.form.remote.ArliAiForm import com.shifthackz.aisdv1.presentation.screen.setup.form.remote.Automatic1111Form import com.shifthackz.aisdv1.presentation.screen.setup.form.remote.FalAiForm import com.shifthackz.aisdv1.presentation.screen.setup.form.remote.FormTitle @@ -96,6 +97,12 @@ internal fun ConfigurationStep( processIntent = processIntent, ) + ServerSource.ARLI_AI -> ArliAiForm( + state = state, + strings = strings, + processIntent = processIntent, + ) + ServerSource.STABILITY_AI -> StabilityAiForm( state = state, strings = strings, diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/form/remote/ServerSetupRemoteForms.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/form/remote/ServerSetupRemoteForms.kt index 13b64068e..79eb75f07 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/form/remote/ServerSetupRemoteForms.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/form/remote/ServerSetupRemoteForms.kt @@ -233,6 +233,34 @@ internal fun FalAiForm( } } +@Composable +internal fun ArliAiForm( + state: ServerSetupState, + strings: ServerSetupStrings, + processIntent: (ServerSetupIntent) -> Unit, +) { + RemoteFormScaffold( + title = strings.arliAiTitle, + subtitle = strings.arliAiSubtitle, + ) { + SetupTextField( + value = state.arliAiApiKey, + onValueChange = { processIntent(ServerSetupIntent.UpdateArliAiApiKey(it)) }, + label = strings.apiKey, + keyboardType = KeyboardType.Password, + error = state.arliAiApiKeyValidationError?.message(strings), + ) + SettingsItem( + modifier = Modifier + .fillMaxWidth() + .padding(top = 8.dp), + startIcon = Icons.Default.Cloud, + text = strings.arliAiAbout.asUiText(), + onClick = { processIntent(ServerSetupIntent.LaunchUrl(ServerSetupLink.ArliAiInfo)) }, + ) + } +} + @Composable internal fun StabilityAiForm( state: ServerSetupState, diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/model/ServerSetupIntent.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/model/ServerSetupIntent.kt index 18a811598..a57c0a0da 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/model/ServerSetupIntent.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/model/ServerSetupIntent.kt @@ -30,6 +30,7 @@ sealed interface ServerSetupIntent : MviIntent { data class UpdateOpenAiApiKey(val key: String) : ServerSetupIntent data class UpdateStabilityAiApiKey(val key: String) : ServerSetupIntent data class UpdateFalAiApiKey(val key: String) : ServerSetupIntent + data class UpdateArliAiApiKey(val key: String) : ServerSetupIntent data class UpdateHuggingFaceApiKey(val key: String) : ServerSetupIntent data class UpdateHuggingFaceModel(val model: String) : ServerSetupIntent data class UpdateDemoMode(val value: Boolean) : ServerSetupIntent @@ -65,4 +66,5 @@ enum class ServerSetupLink { OpenAiInfo, StabilityAiInfo, FalAiInfo, + ArliAiInfo, } diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/model/ServerSetupState.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/model/ServerSetupState.kt index 07cfbe62c..3c3e1423b 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/model/ServerSetupState.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/model/ServerSetupState.kt @@ -40,6 +40,7 @@ data class ServerSetupState( val openAiApiKey: String = "", val stabilityAiApiKey: String = "", val falAiApiKey: String = "", + val arliAiApiKey: String = "", val hordeDefaultApiKey: Boolean = false, val demoMode: Boolean = false, val demoModeUrl: String = "", @@ -69,6 +70,7 @@ data class ServerSetupState( val openAiApiKeyValidationError: ValidationError? = null, val stabilityAiApiKeyValidationError: ValidationError? = null, val falAiApiKeyValidationError: ValidationError? = null, + val arliAiApiKeyValidationError: ValidationError? = null, val localCustomOnnxPathValidationError: ValidationError? = null, val localCustomMediaPipePathValidationError: ValidationError? = null, val localCustomSdxlPathValidationError: ValidationError? = null, @@ -312,6 +314,7 @@ data class ServerSetupState( ServerSource.OPEN_AI, ServerSource.STABILITY_AI, ServerSource.FAL_AI, + ServerSource.ARLI_AI, ) } } @@ -366,6 +369,7 @@ fun Configuration.toServerSetupState( openAiApiKey = openAiApiKey, stabilityAiApiKey = stabilityAiApiKey, falAiApiKey = falAiApiKey, + arliAiApiKey = arliAiApiKey, demoMode = demoMode, demoModeUrl = demoModeUrl, ).withCredentials( diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/reducer/ServerSetupIntentProcessor.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/reducer/ServerSetupIntentProcessor.kt index b8dcb354d..7e71436d7 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/reducer/ServerSetupIntentProcessor.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/reducer/ServerSetupIntentProcessor.kt @@ -181,6 +181,10 @@ internal class ServerSetupIntentProcessor( it.copy(falAiApiKey = intent.key, falAiApiKeyValidationError = null) } + is ServerSetupIntent.UpdateArliAiApiKey -> updateState { + it.copy(arliAiApiKey = intent.key, arliAiApiKeyValidationError = null) + } + is ServerSetupIntent.UpdateHuggingFaceApiKey -> updateState { it.copy(huggingFaceApiKey = intent.key, huggingFaceApiKeyValidationError = null) } diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/validation/ServerSetupValidation.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/validation/ServerSetupValidation.kt index c54fb5515..b6a749c74 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/validation/ServerSetupValidation.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/validation/ServerSetupValidation.kt @@ -68,6 +68,12 @@ internal fun ServerSetupState.validateServerSetup( update = { error -> copy(falAiApiKeyValidationError = error) }, ) + ServerSource.ARLI_AI -> validateApiKey( + key = arliAiApiKey, + stringValidator = stringValidator, + update = { error -> copy(arliAiApiKeyValidationError = error) }, + ) + ServerSource.LOCAL_MICROSOFT_ONNX -> validateLocalModel( customModel = localOnnxCustomModel, customModelPath = localOnnxCustomModelPath, diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/viewmodel/ServerSetupViewModelHelpers.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/viewmodel/ServerSetupViewModelHelpers.kt index cd5a44558..1c3276dd2 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/viewmodel/ServerSetupViewModelHelpers.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/setup/viewmodel/ServerSetupViewModelHelpers.kt @@ -26,6 +26,7 @@ internal fun ServerSetupLink.url(linksProvider: LinksProvider): String = when (t ServerSetupLink.OpenAiInfo -> linksProvider.openAiInfoUrl ServerSetupLink.StabilityAiInfo -> linksProvider.stabilityAiInfoUrl ServerSetupLink.FalAiInfo -> linksProvider.falAiInfoUrl + ServerSetupLink.ArliAiInfo -> linksProvider.arliAiInfoUrl } internal fun ValidationResult.mapStringToValidationError(): diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageContent.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageContent.kt index b23bfe4a0..7c43c0ab8 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageContent.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageContent.kt @@ -402,6 +402,7 @@ internal fun GenerationInputFormEvent.toTextToImageIntent(): TextToImageIntent? is GenerationInputFormEvent.UpdateFalAiAcceleration -> TextToImageIntent.UpdateFalAiAcceleration(value) is GenerationInputFormEvent.UpdateSdxlBackend -> TextToImageIntent.UpdateSdxlBackend(value) is GenerationInputFormEvent.UpdateFalAiSyncMode -> TextToImageIntent.UpdateFalAiSyncMode(value) + is GenerationInputFormEvent.UpdateArliAiModel -> TextToImageIntent.UpdateArliAiModel(value) is GenerationInputFormEvent.UpdatePrompt -> TextToImageIntent.UpdatePrompt(value) is GenerationInputFormEvent.UpdateRestoreFaces -> TextToImageIntent.UpdateRestoreFaces(value) is GenerationInputFormEvent.UpdateSampler -> TextToImageIntent.UpdateSampler(value) diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageControls.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageControls.kt index bad3976e5..cc9fa39e0 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageControls.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageControls.kt @@ -308,6 +308,7 @@ internal val ServerSource.displayName: String ServerSource.OPEN_AI -> Localization.string("srv_type_open_ai") ServerSource.STABILITY_AI -> Localization.string("srv_type_stability_ai") ServerSource.FAL_AI -> Localization.string("srv_type_fal_ai") + ServerSource.ARLI_AI -> Localization.string("srv_type_arli_ai") ServerSource.LOCAL_MICROSOFT_ONNX -> Localization.string("srv_type_local_short") ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Localization.string("srv_type_media_pipe_short") ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> Localization.string("srv_type_sdxl_short") diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntent.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntent.kt index f98fe60c3..d3499bd75 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntent.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntent.kt @@ -380,6 +380,13 @@ sealed interface TextToImageIntent : MviIntent { * @author Dmitriy Moroz */ data class UpdateFalAiSyncMode(val value: Boolean) : TextToImageIntent + /** + * Carries `UpdateArliAiModel` data through the SDAI presentation layer. + * + * @param value value consumed by the API. + * @author Dmitriy Moroz + */ + data class UpdateArliAiModel(val value: String) : TextToImageIntent /** * Carries `UpdateStabilityAiStyle` data through the SDAI presentation layer. * diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntentProcessor.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntentProcessor.kt index 72903883c..225bc84d8 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntentProcessor.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageIntentProcessor.kt @@ -174,7 +174,7 @@ internal class TextToImageIntentProcessor( } ?: MIN_STEPS val maxSteps = it.falAiModel.maxInferenceSteps.takeIf { _ -> it.mode == ServerSource.FAL_AI - } ?: MAX_STEPS + } ?: if (it.mode == ServerSource.ARLI_AI) MAX_ARLI_AI_STEPS else MAX_STEPS it.copy(samplingSteps = intent.value.coerceIn(minSteps, maxSteps), message = null) } is TextToImageIntent.UpdateCfgScale -> updateState { @@ -273,6 +273,9 @@ internal class TextToImageIntentProcessor( is TextToImageIntent.UpdateFalAiSyncMode -> updateState { it.copy(falAiSyncMode = intent.value, message = null) } + is TextToImageIntent.UpdateArliAiModel -> updateState { + it.copy(arliAiModel = intent.value, message = null, error = null) + } is TextToImageIntent.UpdateStabilityAiStyle -> updateState { it.copy(selectedStylePreset = intent.value, message = null) } @@ -314,3 +317,4 @@ private const val MIN_SIZE = 64 private const val MAX_SIZE = 2048 private const val SIZE_STEP = 64 private const val MAX_FAL_AI_BATCH_COUNT = 4 +private const val MAX_ARLI_AI_STEPS = 40 diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt index efddb8079..9c6f76f4a 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt @@ -264,12 +264,9 @@ data class TextToImageState( * @author Dmitriy Moroz */ override val falAiSyncMode: Boolean = false, - /** - * Exposes the `sdxlBackend` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ override val sdxlBackend: SdxlBackend = SdxlBackend.AUTO, + override val arliAiModels: List = emptyList(), + override val arliAiModel: String = "", /** * Exposes the `widthValidationError` value used by the SDAI presentation layer. * @@ -399,6 +396,7 @@ internal fun TextToImageState.mapToPayload(): TextToImagePayload = TextToImagePa mode == ServerSource.LOCAL_STABLE_DIFFUSION_CPP } ?: SdxlBackend.AUTO, falAiSyncMode = falAiSyncMode, + arliAiModel = arliAiModel.takeIf { mode == ServerSource.ARLI_AI }.orEmpty(), ) /** diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt index e34b12417..be295d59f 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt @@ -9,11 +9,13 @@ import com.shifthackz.aisdv1.domain.entity.ForgeModule import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.entity.Settings +import com.shifthackz.aisdv1.domain.entity.StableDiffusionModel import com.shifthackz.aisdv1.domain.entity.StableDiffusionSampler import com.shifthackz.aisdv1.domain.feature.work.BackgroundTaskManager import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.usecase.arliai.FetchAndGetArliAiModelsUseCase import com.shifthackz.aisdv1.domain.usecase.caching.SaveLastResultToCacheUseCase import com.shifthackz.aisdv1.domain.usecase.forgemodule.GetForgeModulesUseCase import com.shifthackz.aisdv1.domain.usecase.generation.InterruptGenerationUseCase @@ -64,6 +66,12 @@ class TextToImageViewModel( * @author Dmitriy Moroz */ private val getForgeModulesUseCase: GetForgeModulesUseCase, + /** + * Exposes the `fetchAndGetArliAiModelsUseCase` value used by the SDAI presentation layer. + * + * @author Dmitriy Moroz + */ + private val fetchAndGetArliAiModelsUseCase: FetchAndGetArliAiModelsUseCase, /** * Exposes the `isADetailerAvailableUseCase` value used by the SDAI presentation layer. * @@ -215,6 +223,8 @@ class TextToImageViewModel( private var forgeModulesKey: StableDiffusionSamplersKey? = null private var aDetailerAvailable: Boolean? = null private var aDetailerAvailabilityKey: StableDiffusionSamplersKey? = null + private var arliAiModels: List? = null + private var arliAiModelsKey: String? = null private val actionHandler = TextToImageActionHandler( dispatchersProvider = dispatchersProvider, @@ -264,6 +274,10 @@ class TextToImageViewModel( TextToImageIntent.SuppressBenchmarkWarningAndContinue -> actionHandler.continueAfterBenchmarkWarning(suppressFutureWarnings = true) TextToImageIntent.DismissModal -> actionHandler.dismissBenchmarkDialog() + is TextToImageIntent.UpdateArliAiModel -> { + preferenceManager.arliAiModel = intent.value + intentProcessor.process(intent) + } else -> intentProcessor.process(intent) } } @@ -361,6 +375,7 @@ class TextToImageViewModel( stableDiffusionSamplers = stableDiffusionSamplers, forgeModules = forgeModules, aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, ) } } @@ -389,12 +404,14 @@ class TextToImageViewModel( } .collect { settings -> refreshStableDiffusionMetadataIfNeeded(settings) + refreshArliAiModelsIfNeeded(settings) updateState { state -> state.withSettings( settings = settings, stableDiffusionSamplers = stableDiffusionSamplers, forgeModules = forgeModules, aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, ) } } @@ -432,6 +449,7 @@ class TextToImageViewModel( stableDiffusionSamplers = stableDiffusionSamplers, forgeModules = forgeModules, aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, ) } loadSamplers() @@ -453,6 +471,7 @@ class TextToImageViewModel( stableDiffusionSamplers = stableDiffusionSamplers, forgeModules = forgeModules, aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, ) } } @@ -478,6 +497,7 @@ class TextToImageViewModel( stableDiffusionSamplers = stableDiffusionSamplers, forgeModules = forgeModules, aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, ) } } @@ -491,6 +511,7 @@ class TextToImageViewModel( stableDiffusionSamplers = stableDiffusionSamplers, forgeModules = forgeModules, aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, ) } } @@ -530,4 +551,61 @@ class TextToImageViewModel( } } } + + private fun refreshArliAiModelsIfNeeded(settings: Settings) { + if (settings.source != ServerSource.ARLI_AI) { + arliAiModelsKey = null + return + } + val key = preferenceManager.arliAiApiKey + if (arliAiModelsKey == key) return + + arliAiModelsKey = key + arliAiModels = emptyList() + updateState { + it.withSource( + source = ServerSource.ARLI_AI, + stableDiffusionSamplers = stableDiffusionSamplers, + forgeModules = forgeModules, + aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, + ) + } + loadArliAiModels() + } + + private fun loadArliAiModels() { + launch(dispatchersProvider.io) { + runCatching { + fetchAndGetArliAiModelsUseCase() + .map(StableDiffusionModel::arliAiCheckpointName) + .filter(String::isNotBlank) + .distinct() + } + .onSuccess { models -> + arliAiModels = models + withContext(dispatchersProvider.immediate) { + updateState { + it.copy(arliAiModel = preferenceManager.arliAiModel) + .withSource( + source = it.mode, + stableDiffusionSamplers = stableDiffusionSamplers, + forgeModules = forgeModules, + aDetailerAvailable = aDetailerAvailable, + arliAiModels = arliAiModels, + ) + } + } + } + .onFailure { t -> + withContext(dispatchersProvider.immediate) { + updateState { it.copy(error = t.localizedMessageText()) } + } + onError(t) + } + } + } } + +private val StableDiffusionModel.arliAiCheckpointName: String + get() = title.ifBlank { modelName }.ifBlank { filename } diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModelHelpers.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModelHelpers.kt index 48df557fc..599dca394 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModelHelpers.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModelHelpers.kt @@ -6,6 +6,7 @@ import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.core.validation.ValidationResult import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.ArliAiSampler import com.shifthackz.aisdv1.domain.entity.ForgeModule import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.entity.Settings @@ -133,11 +134,15 @@ internal fun TextToImageState.validated( val validateDimensions = mode != ServerSource.OPEN_AI val widthResult = dimensionValidator(width).takeIf { validateDimensions } val heightResult = dimensionValidator(height).takeIf { validateDimensions } + val sourceError = when { + mode == ServerSource.ARLI_AI && arliAiModel.isBlank() -> Localization.string("error_invalid").asUiText() + else -> null + } return copy( promptValidationError = null, widthValidationError = widthResult?.errorMessage(), heightValidationError = heightResult?.errorMessage(), - error = null, + error = sourceError, ) } @@ -186,8 +191,11 @@ internal fun TextToImageState.withSettings( stableDiffusionSamplers: List?, forgeModules: List?, aDetailerAvailable: Boolean?, + arliAiModels: List?, ): TextToImageState = - withSource(settings.source, stableDiffusionSamplers, forgeModules, aDetailerAvailable).copy( + copy(arliAiModel = settings.arliAiModel) + .withSource(settings.source, stableDiffusionSamplers, forgeModules, aDetailerAvailable, arliAiModels) + .copy( advancedToggleButtonVisible = !settings.formAdvancedOptionsAlwaysShow, advancedOptionsVisible = if (settings.formAdvancedOptionsAlwaysShow) { true @@ -211,14 +219,17 @@ internal fun TextToImageState.withSource( stableDiffusionSamplers: List?, forgeModules: List?, aDetailerAvailable: Boolean?, + arliAiModels: List?, ): TextToImageState { val samplers = when (source) { ServerSource.STABILITY_AI -> StabilityAiSampler.entries.map { "$it" } + ServerSource.ARLI_AI -> ArliAiSampler.supported else -> stableDiffusionSamplers.orEmpty() } val modules = forgeModules.orEmpty().takeIf { source == ServerSource.AUTOMATIC1111 }.orEmpty() + val arliModels = arliAiModels.orEmpty() return copy( mode = source, availableSamplers = samplers, @@ -227,6 +238,14 @@ internal fun TextToImageState.withSource( ?: selectedSampler, availableForgeModules = modules, selectedForgeModules = selectedForgeModules.filter(modules::contains), + arliAiModels = arliModels.takeIf { source == ServerSource.ARLI_AI }.orEmpty(), + arliAiModel = if (source == ServerSource.ARLI_AI) { + arliAiModel.takeIf(arliModels::contains) + ?: arliModels.firstOrNull() + ?: arliAiModel + } else { + arliAiModel + }, aDetailerAvailable = if (source == ServerSource.AUTOMATIC1111) { aDetailerAvailable ?: this.aDetailerAvailable } else { diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt index 3748fb4fe..98e2f9c5b 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt @@ -139,6 +139,7 @@ fun EngineSelectionContent( ServerSource.HORDE, ServerSource.OPEN_AI, ServerSource.FAL_AI, + ServerSource.ARLI_AI, -> Unit } } diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt index ab204bbed..2f33afd94 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt @@ -7,6 +7,9 @@ import com.shifthackz.aisdv1.domain.entity.Configuration import com.shifthackz.aisdv1.domain.entity.HuggingFaceModel import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.entity.ServerSource +import com.shifthackz.aisdv1.domain.entity.StabilityAiEngine +import com.shifthackz.aisdv1.domain.entity.StableDiffusionModel +import com.shifthackz.aisdv1.domain.entity.SwarmUiModel import com.shifthackz.aisdv1.domain.feature.coreml.CoreMlModelSupport import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalCoreMlModelsUseCase @@ -24,6 +27,7 @@ import kotlinx.coroutines.flow.collectLatest import kotlinx.coroutines.flow.combine import kotlinx.coroutines.flow.map import kotlinx.coroutines.withTimeoutOrNull +import kotlin.time.Duration.Companion.milliseconds /** * Coordinates `EngineSelectionViewModel` behavior in the SDAI presentation layer. @@ -238,6 +242,7 @@ class EngineSelectionViewModel( ServerSource.HORDE, ServerSource.OPEN_AI, ServerSource.FAL_AI, + ServerSource.ARLI_AI, -> remoteOptions } } @@ -290,7 +295,7 @@ class EngineSelectionViewModel( } private suspend fun loadList(block: suspend () -> List): List = runCatching { - withTimeoutOrNull(REMOTE_OPTIONS_TIMEOUT_MILLIS) { + withTimeoutOrNull(REMOTE_OPTIONS_TIMEOUT_MILLIS.milliseconds) { block() } ?: emptyList() } @@ -375,23 +380,23 @@ private data class RemoteOptions( * * @author Dmitriy Moroz */ - val sdModels: List> = emptyList(), + val sdModels: List> = emptyList(), /** * Exposes the `swarmModels` value used by the SDAI presentation layer. * * @author Dmitriy Moroz */ - val swarmModels: List = emptyList(), + val swarmModels: List = emptyList(), /** * Exposes the `hfModels` value used by the SDAI presentation layer. * * @author Dmitriy Moroz */ - val hfModels: List = emptyList(), + val hfModels: List = emptyList(), /** * Exposes the `stEngines` value used by the SDAI presentation layer. * * @author Dmitriy Moroz */ - val stEngines: List = emptyList(), + val stEngines: List = emptyList(), ) diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputAdvancedOptions.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputAdvancedOptions.kt index 3ba8755cd..9ecc84be2 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputAdvancedOptions.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputAdvancedOptions.kt @@ -31,6 +31,7 @@ import com.shifthackz.aisdv1.presentation.theme.textFieldColors import com.shifthackz.aisdv1.presentation.widget.input.GenerationInputFormConstants.CFG_SCALE_RANGE_MAX import com.shifthackz.aisdv1.presentation.widget.input.GenerationInputFormConstants.CFG_SCALE_RANGE_MIN import com.shifthackz.aisdv1.presentation.widget.input.GenerationInputFormConstants.SAMPLING_STEPS_LOCAL_DIFFUSION_MAX +import com.shifthackz.aisdv1.presentation.widget.input.GenerationInputFormConstants.SAMPLING_STEPS_RANGE_ARLI_AI_MAX import com.shifthackz.aisdv1.presentation.widget.input.GenerationInputFormConstants.SAMPLING_STEPS_RANGE_MAX import com.shifthackz.aisdv1.presentation.widget.input.GenerationInputFormConstants.SAMPLING_STEPS_RANGE_MIN import com.shifthackz.aisdv1.presentation.widget.input.GenerationInputFormConstants.SAMPLING_STEPS_RANGE_STABILITY_AI_MAX @@ -64,6 +65,7 @@ internal fun GenerationInputAdvancedOptions( // Sampler selection only supported for A1111, STABILITY AI when (state.mode) { ServerSource.STABILITY_AI, + ServerSource.ARLI_AI, ServerSource.AUTOMATIC1111 -> DropdownTextField( modifier = Modifier .fillMaxWidth() @@ -297,6 +299,7 @@ internal fun GenerationInputAdvancedOptions( ServerSource.LOCAL_STABLE_DIFFUSION_CPP -> SAMPLING_STEPS_LOCAL_DIFFUSION_MAX ServerSource.STABILITY_AI -> SAMPLING_STEPS_RANGE_STABILITY_AI_MAX ServerSource.FAL_AI -> state.falAiModel.maxInferenceSteps + ServerSource.ARLI_AI -> SAMPLING_STEPS_RANGE_ARLI_AI_MAX else -> SAMPLING_STEPS_RANGE_MAX } val steps = state.samplingSteps.coerceIn(stepsMin, stepsMax) @@ -359,6 +362,7 @@ internal fun GenerationInputAdvancedOptions( ServerSource.AUTOMATIC1111, ServerSource.SWARM_UI, ServerSource.STABILITY_AI, + ServerSource.ARLI_AI, ServerSource.HORDE, ServerSource.LOCAL_APPLE_CORE_ML -> afterSlidersSection() diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt index c64404111..089259f23 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt @@ -91,6 +91,16 @@ fun GenerationInputForm( displayDelegate = { it.displayName.asUiText() }, ) + ServerSource.ARLI_AI -> DropdownTextField( + modifier = Modifier.padding(top = 8.dp), + label = Localization.string("hint_arli_ai_model").asUiText(), + value = state.arliAiModel.takeIf(String::isNotBlank), + items = state.arliAiModels.ifEmpty { + listOfNotNull(state.arliAiModel.takeIf(String::isNotBlank)) + }, + onItemSelected = { onEvent(GenerationInputFormEvent.UpdateArliAiModel(it)) }, + ) + else -> Unit } @@ -147,6 +157,7 @@ fun GenerationInputForm( ServerSource.SWARM_UI, ServerSource.HUGGING_FACE, ServerSource.STABILITY_AI, + ServerSource.ARLI_AI, ServerSource.LOCAL_MICROSOFT_ONNX, ServerSource.LOCAL_STABLE_DIFFUSION_CPP, ServerSource.LOCAL_APPLE_CORE_ML -> { @@ -225,7 +236,9 @@ fun GenerationInputForm( ServerSource.AUTOMATIC1111, ServerSource.SWARM_UI, - ServerSource.HUGGING_FACE -> { + ServerSource.HUGGING_FACE, + ServerSource.ARLI_AI, + -> { GenerationSizeTextFieldsComponent(modifier = localModifier, state = state, onEvent = onEvent) } diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormConstants.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormConstants.kt index f2774d31e..31c1f20da 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormConstants.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormConstants.kt @@ -37,6 +37,12 @@ internal object GenerationInputFormConstants { * @author Dmitriy Moroz */ const val SAMPLING_STEPS_RANGE_STABILITY_AI_MAX = 50 + /** + * Exposes the `SAMPLING_STEPS_RANGE_ARLI_AI_MAX` value used by the SDAI presentation layer. + * + * @author Dmitriy Moroz + */ + const val SAMPLING_STEPS_RANGE_ARLI_AI_MAX = 40 /** * Exposes the `SAMPLING_STEPS_RANGE_FAL_AI_MAX` value used by the SDAI presentation layer. * diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormEvent.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormEvent.kt index eede007c4..d4017f527 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormEvent.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormEvent.kt @@ -233,6 +233,13 @@ sealed interface GenerationInputFormEvent { * @author Dmitriy Moroz */ data class UpdateFalAiSyncMode(val value: Boolean) : GenerationInputFormEvent + /** + * Carries `UpdateArliAiModel` data through the SDAI presentation layer. + * + * @param value value consumed by the API. + * @author Dmitriy Moroz + */ + data class UpdateArliAiModel(val value: String) : GenerationInputFormEvent /** * Carries `UpdateStabilityAiStyle` data through the SDAI presentation layer. * diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormState.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormState.kt index f87ef99b2..06b2af7a3 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormState.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputFormState.kt @@ -196,12 +196,9 @@ interface GenerationInputFormState { * @author Dmitriy Moroz */ val falAiSyncMode: Boolean - /** - * Exposes the `sdxlBackend` value used by the SDAI presentation layer. - * - * @author Dmitriy Moroz - */ val sdxlBackend: SdxlBackend + val arliAiModels: List + val arliAiModel: String /** * Exposes the `widthValidationError` value used by the SDAI presentation layer. * diff --git a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt index 0704cb6e0..527ee5c56 100644 --- a/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt +++ b/presentation/src/commonMain/kotlin/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt @@ -32,6 +32,7 @@ fun ServerSource.getNameUiText(): UiText = Localization.string( ServerSource.OPEN_AI -> "srv_type_open_ai" ServerSource.STABILITY_AI -> "srv_type_stability_ai" ServerSource.FAL_AI -> "srv_type_fal_ai" + ServerSource.ARLI_AI -> "srv_type_arli_ai" ServerSource.SWARM_UI -> "srv_type_swarm_ui" }, ).asUiText() diff --git a/storage/schemas/com.shifthackz.aisdv1.storage.db.cache.CacheDatabase/1.json b/storage/schemas/com.shifthackz.aisdv1.storage.db.cache.CacheDatabase/1.json index 2f2735251..1b71c4c5b 100644 --- a/storage/schemas/com.shifthackz.aisdv1.storage.db.cache.CacheDatabase/1.json +++ b/storage/schemas/com.shifthackz.aisdv1.storage.db.cache.CacheDatabase/1.json @@ -2,7 +2,7 @@ "formatVersion": 1, "database": { "version": 1, - "identityHash": "5dc36f6910330fd3f65e65c4e15e56c2", + "identityHash": "5ee09647ab10b611e1d4e4cb867a8e19", "entities": [ { "tableName": "server_config", @@ -26,9 +26,7 @@ "columnNames": [ "server_id" ] - }, - "indices": [], - "foreignKeys": [] + } }, { "tableName": "sd_models", @@ -82,9 +80,7 @@ "columnNames": [ "id" ] - }, - "indices": [], - "foreignKeys": [] + } }, { "tableName": "sd_samplers", @@ -120,9 +116,7 @@ "columnNames": [ "id" ] - }, - "indices": [], - "foreignKeys": [] + } }, { "tableName": "loras", @@ -158,9 +152,7 @@ "columnNames": [ "id" ] - }, - "indices": [], - "foreignKeys": [] + } }, { "tableName": "hyper_networks", @@ -190,9 +182,7 @@ "columnNames": [ "id" ] - }, - "indices": [], - "foreignKeys": [] + } }, { "tableName": "embeddings", @@ -216,9 +206,7 @@ "columnNames": [ "id" ] - }, - "indices": [], - "foreignKeys": [] + } }, { "tableName": "swarm_models", @@ -254,15 +242,66 @@ "columnNames": [ "id" ] - }, - "indices": [], - "foreignKeys": [] + } + }, + { + "tableName": "arli_ai_models", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` TEXT NOT NULL, `title` TEXT NOT NULL, `name` TEXT NOT NULL, `hash` TEXT NOT NULL, `sha256` TEXT NOT NULL, `filename` TEXT NOT NULL, `config` TEXT NOT NULL, PRIMARY KEY(`id`))", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "title", + "columnName": "title", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "name", + "columnName": "name", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "hash", + "columnName": "hash", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "sha256", + "columnName": "sha256", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "filename", + "columnName": "filename", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "config", + "columnName": "config", + "affinity": "TEXT", + "notNull": true + } + ], + "primaryKey": { + "autoGenerate": false, + "columnNames": [ + "id" + ] + } } ], - "views": [], "setupQueries": [ "CREATE TABLE IF NOT EXISTS room_master_table (id INTEGER PRIMARY KEY,identity_hash TEXT)", - "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, '5dc36f6910330fd3f65e65c4e15e56c2')" + "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, '5ee09647ab10b611e1d4e4cb867a8e19')" ] } } \ No newline at end of file diff --git a/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/db/cache/CacheDatabase.kt b/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/db/cache/CacheDatabase.kt index 10e65190a..da82abdaf 100755 --- a/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/db/cache/CacheDatabase.kt +++ b/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/db/cache/CacheDatabase.kt @@ -8,6 +8,7 @@ import androidx.room.TypeConverters import com.shifthackz.aisdv1.storage.converters.ListConverters import com.shifthackz.aisdv1.storage.converters.MapConverters import com.shifthackz.aisdv1.storage.db.cache.CacheDatabase.Companion.DB_VERSION +import com.shifthackz.aisdv1.storage.db.cache.dao.ArliAiModelDao import com.shifthackz.aisdv1.storage.db.cache.dao.ServerConfigurationDao import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionEmbeddingDao import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionHyperNetworkDao @@ -15,6 +16,7 @@ import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionLoraDao import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionModelDao import com.shifthackz.aisdv1.storage.db.cache.dao.StableDiffusionSamplerDao import com.shifthackz.aisdv1.storage.db.cache.dao.SwarmUiModelDao +import com.shifthackz.aisdv1.storage.db.cache.entity.ArliAiModelEntity import com.shifthackz.aisdv1.storage.db.cache.entity.ServerConfigurationEntity import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionEmbeddingEntity import com.shifthackz.aisdv1.storage.db.cache.entity.StableDiffusionHyperNetworkEntity @@ -39,6 +41,7 @@ import com.shifthackz.aisdv1.storage.db.cache.entity.SwarmUiModelEntity StableDiffusionHyperNetworkEntity::class, StableDiffusionEmbeddingEntity::class, SwarmUiModelEntity::class, + ArliAiModelEntity::class, ], ) @TypeConverters( @@ -96,6 +99,13 @@ internal abstract class CacheDatabase : RoomDatabase() { * @author Dmitriy Moroz */ abstract fun swarmUiModelDao(): SwarmUiModelDao + /** + * Executes the `arliAiModelDao` step in the SDAI storage layer. + * + * @return Result produced by `arliAiModelDao`. + * @author Dmitriy Moroz + */ + abstract fun arliAiModelDao(): ArliAiModelDao /** * Provides the `companion object` singleton used by the SDAI storage layer. @@ -117,7 +127,7 @@ internal abstract class CacheDatabase : RoomDatabase() { * * @author Dmitriy Moroz */ -@Suppress("KotlinNoActualForExpect") +@Suppress("EXPECT_ACTUAL_CLASSIFIERS_ARE_IN_BETA_WARNING") internal expect object CacheDatabaseConstructor : RoomDatabaseConstructor { /** * Executes the `initialize` step in the SDAI storage layer. diff --git a/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/db/cache/contract/ArliAiModelContract.kt b/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/db/cache/contract/ArliAiModelContract.kt new file mode 100644 index 000000000..86d07bf01 --- /dev/null +++ b/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/db/cache/contract/ArliAiModelContract.kt @@ -0,0 +1,58 @@ +package com.shifthackz.aisdv1.storage.db.cache.contract + +/** + * Defines the Room table and columns for cached ArliAI checkpoints. + * + * @author Dmitriy Moroz + */ +internal object ArliAiModelContract { + /** + * Room table containing ArliAI checkpoint metadata. + * + * @author Dmitriy Moroz + */ + const val TABLE = "arli_ai_models" + + /** + * Stable checkpoint identifier used as the primary key. + * + * @author Dmitriy Moroz + */ + const val ID = "id" + /** + * Provider-facing checkpoint title. + * + * @author Dmitriy Moroz + */ + const val TITLE = "title" + /** + * Provider-facing checkpoint model name. + * + * @author Dmitriy Moroz + */ + const val NAME = "name" + /** + * Short provider hash when ArliAI returns one. + * + * @author Dmitriy Moroz + */ + const val HASH = "hash" + /** + * Full SHA-256 model hash when ArliAI returns one. + * + * @author Dmitriy Moroz + */ + const val SHA256 = "sha256" + /** + * Provider checkpoint filename. + * + * @author Dmitriy Moroz + */ + const val FILENAME = "filename" + /** + * Provider checkpoint config filename. + * + * @author Dmitriy Moroz + */ + const val CONFIG = "config" +} diff --git a/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/db/cache/dao/ArliAiModelDao.kt b/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/db/cache/dao/ArliAiModelDao.kt new file mode 100644 index 000000000..26e27fee8 --- /dev/null +++ b/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/db/cache/dao/ArliAiModelDao.kt @@ -0,0 +1,45 @@ +package com.shifthackz.aisdv1.storage.db.cache.dao + +import androidx.room.Dao +import androidx.room.Insert +import androidx.room.OnConflictStrategy +import androidx.room.Query +import com.shifthackz.aisdv1.storage.db.cache.contract.ArliAiModelContract +import com.shifthackz.aisdv1.storage.db.cache.entity.ArliAiModelEntity + +/** + * Provides Room access to cached ArliAI checkpoint metadata. + * + * @author Dmitriy Moroz + */ +@Dao +interface ArliAiModelDao { + + /** + * Reads all cached ArliAI checkpoints. + * + * @return rows currently stored in the ArliAI model cache table. + * + * @author Dmitriy Moroz + */ + @Query("SELECT * FROM ${ArliAiModelContract.TABLE}") + suspend fun queryAll(): List + + /** + * Inserts or replaces cached ArliAI checkpoints. + * + * @param items rows produced from the latest provider model list. + * + * @author Dmitriy Moroz + */ + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insertList(items: List) + + /** + * Clears all cached ArliAI checkpoints before a refresh writes the new list. + * + * @author Dmitriy Moroz + */ + @Query("DELETE FROM ${ArliAiModelContract.TABLE}") + suspend fun deleteAll() +} diff --git a/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/db/cache/entity/ArliAiModelEntity.kt b/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/db/cache/entity/ArliAiModelEntity.kt new file mode 100644 index 000000000..e8252a7ec --- /dev/null +++ b/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/db/cache/entity/ArliAiModelEntity.kt @@ -0,0 +1,65 @@ +package com.shifthackz.aisdv1.storage.db.cache.entity + +import androidx.room.ColumnInfo +import androidx.room.Entity +import androidx.room.PrimaryKey +import com.shifthackz.aisdv1.storage.db.cache.contract.ArliAiModelContract + +/** + * Stores one ArliAI checkpoint in the local cache database. + * + * @author Dmitriy Moroz + */ +@Entity(tableName = ArliAiModelContract.TABLE) +data class ArliAiModelEntity( + /** + * Exposes the `id` value used by the SDAI storage layer. + * + * @author Dmitriy Moroz + */ + @PrimaryKey(autoGenerate = false) + @ColumnInfo(name = ArliAiModelContract.ID) + val id: String, + /** + * Exposes the `title` value used by the SDAI storage layer. + * + * @author Dmitriy Moroz + */ + @ColumnInfo(name = ArliAiModelContract.TITLE) + val title: String, + /** + * Exposes the `name` value used by the SDAI storage layer. + * + * @author Dmitriy Moroz + */ + @ColumnInfo(name = ArliAiModelContract.NAME) + val name: String, + /** + * Exposes the `hash` value used by the SDAI storage layer. + * + * @author Dmitriy Moroz + */ + @ColumnInfo(name = ArliAiModelContract.HASH) + val hash: String, + /** + * Exposes the `sha256` value used by the SDAI storage layer. + * + * @author Dmitriy Moroz + */ + @ColumnInfo(name = ArliAiModelContract.SHA256) + val sha256: String, + /** + * Exposes the `filename` value used by the SDAI storage layer. + * + * @author Dmitriy Moroz + */ + @ColumnInfo(name = ArliAiModelContract.FILENAME) + val filename: String, + /** + * Exposes the `config` value used by the SDAI storage layer. + * + * @author Dmitriy Moroz + */ + @ColumnInfo(name = ArliAiModelContract.CONFIG) + val config: String, +) diff --git a/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/di/CacheDatabaseModule.kt b/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/di/CacheDatabaseModule.kt index fec6028e6..36d648fc2 100644 --- a/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/di/CacheDatabaseModule.kt +++ b/storage/src/commonMain/kotlin/com/shifthackz/aisdv1/storage/di/CacheDatabaseModule.kt @@ -41,6 +41,7 @@ val cacheDatabaseModule = module { single { get().sdEmbeddingDao() } single { get().serverConfigurationDao() } single { get().swarmUiModelDao() } + single { get().arliAiModelDao() } } /**