Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make preferred model form consistent with the other forms #309

Merged
merged 4 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { screen, waitFor } from "@testing-library/react";
import { WorkspacePreferredModel } from "../workspace-preferred-model";
import userEvent from "@testing-library/user-event";

test("render model overrides", () => {
test("render model overrides", async () => {
render(
<WorkspacePreferredModel
isArchived={false}
Expand All @@ -19,7 +19,10 @@ test("render model overrides", () => {
expect(
screen.getByRole("button", { name: /select the model/i }),
).toBeVisible();
expect(screen.getByRole("button", { name: /save/i })).toBeVisible();

await waitFor(() => {
expect(screen.getByRole("button", { name: /save/i })).toBeVisible();
});
});

test("submit preferred model", async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,11 @@ function useCustomInstructionsValue({
options: V1GetWorkspaceCustomInstructionsData;
queryClient: QueryClient;
}) {
const formState = useFormState({ prompt: initialValue });
const initialFormValues = useMemo(
() => ({ prompt: initialValue }),
[initialValue],
);
const formState = useFormState(initialFormValues);
const { values, updateFormValues } = formState;

// Subscribe to changes in the workspace system prompt value in the query cache
Expand Down
3 changes: 2 additions & 1 deletion src/features/workspace/components/workspace-name.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { useNavigate } from "react-router-dom";
import { twMerge } from "tailwind-merge";
import { useFormState } from "@/hooks/useFormState";
import { FormButtons } from "@/components/FormButtons";
import { FormEvent } from "react";

export function WorkspaceName({
className,
Expand All @@ -32,7 +33,7 @@ export function WorkspaceName({
const isDefault = workspaceName === "default";
const isUneditable = isArchived || isPending || isDefault;

const handleSubmit = (event: { preventDefault: () => void }) => {
const handleSubmit = (event: FormEvent) => {
event.preventDefault();

mutateAsync(
Expand Down
61 changes: 38 additions & 23 deletions src/features/workspace/components/workspace-preferred-model.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import {
Alert,
Button,
Card,
CardBody,
CardFooter,
Expand All @@ -16,6 +15,10 @@ import { FormEvent } from "react";
import { usePreferredModelWorkspace } from "../hooks/use-preferred-preferred-model";
import { Select, SelectButton } from "@stacklok/ui-kit";
import { useQueryListAllModelsForAllProviders } from "@/hooks/use-query-list-all-models-for-all-providers";
import { FormButtons } from "@/components/FormButtons";
import { invalidateQueries } from "@/lib/react-query-utils";
import { v1GetWorkspaceMuxesQueryKey } from "@/api/generated/@tanstack/react-query.gen";
import { useQueryClient } from "@tanstack/react-query";

function MissingProviderBanner() {
return (
Expand All @@ -39,30 +42,38 @@ export function WorkspacePreferredModel({
workspaceName: string;
isArchived: boolean | undefined;
}) {
const { preferredModel, setPreferredModel, isPending } =
usePreferredModelWorkspace(workspaceName);
const queryClient = useQueryClient();
const { formState, isPending } = usePreferredModelWorkspace(workspaceName);
const { mutateAsync } = useMutationPreferredModelWorkspace();
const { data: providerModels = [] } = useQueryListAllModelsForAllProviders();
const { model, provider_id } = preferredModel;
const isModelsEmpty = !isPending && providerModels.length === 0;

const handleSubmit = (event: FormEvent) => {
event.preventDefault();
mutateAsync({
path: { workspace_name: workspaceName },
body: [
{
matcher: "",
provider_id,
model,
matcher_type: MuxMatcherType.CATCH_ALL,
},
],
});
mutateAsync(
{
path: { workspace_name: workspaceName },
body: [
{
matcher: "",
matcher_type: MuxMatcherType.CATCH_ALL,
...formState.values.preferredModel,
},
],
},
{
onSuccess: () =>
invalidateQueries(queryClient, [v1GetWorkspaceMuxesQueryKey]),
},
);
};

return (
<Form onSubmit={handleSubmit} validationBehavior="aria">
<Form
onSubmit={handleSubmit}
validationBehavior="aria"
data-testid="preferred-model"
>
<Card className={twMerge(className, "shrink-0")}>
<CardBody className="flex flex-col gap-6">
<div className="flex flex-col justify-start">
Expand All @@ -84,16 +95,18 @@ export function WorkspacePreferredModel({
isRequired
isDisabled={isModelsEmpty}
className="w-full"
selectedKey={preferredModel?.model}
selectedKey={formState.values.preferredModel?.model}
placeholder="Select the model"
onSelectionChange={(model) => {
const preferredModelProvider = providerModels.find(
(item) => item.name === model,
);
if (preferredModelProvider) {
setPreferredModel({
model: preferredModelProvider.name,
provider_id: preferredModelProvider.provider_id,
formState.updateFormValues({
preferredModel: {
model: preferredModelProvider.name,
provider_id: preferredModelProvider.provider_id,
},
});
}
}}
Expand All @@ -109,9 +122,11 @@ export function WorkspacePreferredModel({
</div>
</CardBody>
<CardFooter className="justify-end">
<Button isDisabled={isArchived || isModelsEmpty} type="submit">
Save
</Button>
<FormButtons
isPending={isPending}
formState={formState}
canSubmit={!isArchived}
/>
</CardFooter>
</Card>
</Form>
Expand Down
17 changes: 7 additions & 10 deletions src/features/workspace/hooks/use-preferred-preferred-model.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { MuxRule, V1GetWorkspaceMuxesData } from "@/api/generated";
import { v1GetWorkspaceMuxesOptions } from "@/api/generated/@tanstack/react-query.gen";
import { useFormState } from "@/hooks/useFormState";
import { useQuery } from "@tanstack/react-query";
import { useEffect, useMemo, useState } from "react";
import { useMemo } from "react";

type ModelRule = Omit<MuxRule, "matcher_type" | "matcher"> & {};

Expand All @@ -21,8 +22,6 @@ const usePreferredModel = (options: {
};

export const usePreferredModelWorkspace = (workspaceName: string) => {
const [preferredModel, setPreferredModel] =
useState<ModelRule>(DEFAULT_STATE);
const options: V1GetWorkspaceMuxesData &
Omit<V1GetWorkspaceMuxesData, "body"> = useMemo(
() => ({
Expand All @@ -31,12 +30,10 @@ export const usePreferredModelWorkspace = (workspaceName: string) => {
[workspaceName],
);
const { data, isPending } = usePreferredModel(options);
const providerModel = data?.[0];
const formState = useFormState<{ preferredModel: ModelRule }>({
preferredModel: providerModel ?? DEFAULT_STATE,
});

useEffect(() => {
const providerModel = data?.[0];

setPreferredModel(providerModel ?? DEFAULT_STATE);
}, [data, setPreferredModel]);

return { preferredModel, setPreferredModel, isPending };
return { isPending, formState };
};
45 changes: 36 additions & 9 deletions src/hooks/useFormState.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { isEqual } from "lodash";
import { useState } from "react";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";

export type FormState<T> = {
values: T;
Expand All @@ -8,23 +8,50 @@ export type FormState<T> = {
isDirty: boolean;
};

function useDeepMemo<T>(value: T): T {
const ref = useRef<T>(value);
if (!isEqual(ref.current, value)) {
ref.current = value;
}
return ref.current;
}

export function useFormState<Values extends Record<string, unknown>>(
initialValues: Values,
): FormState<Values> {
const memoizedInitialValues = useDeepMemo(initialValues);

// this could be replaced with some form library later
const [values, setValues] = useState<Values>(initialValues);
const updateFormValues = (newState: Partial<Values>) => {
const [values, setValues] = useState<Values>(memoizedInitialValues);
const [originalValues, setOriginalValues] = useState<Values>(values);

useEffect(() => {
// this logic supports the use case when the initialValues change
// due to an async request for instance
setOriginalValues(memoizedInitialValues);
setValues(memoizedInitialValues);
}, [memoizedInitialValues]);

const updateFormValues = useCallback((newState: Partial<Values>) => {
setValues((prevState: Values) => ({
...prevState,
...newState,
}));
};
}, []);

const resetForm = useCallback(() => {
setValues(originalValues);
}, [originalValues]);

const resetForm = () => {
setValues(initialValues);
};
const isDirty = useMemo(
() => !isEqual(values, originalValues),
[values, originalValues],
);

const isDirty = !isEqual(values, initialValues);
const formState = useMemo(
() => ({ values, updateFormValues, resetForm, isDirty }),
[values, updateFormValues, resetForm, isDirty],
);

return { values, updateFormValues, resetForm, isDirty };
return formState;
}
Loading