Skip to content

Commit e3d20af

Browse files
Lora support in Model Ext (#8199)
* adding more enhancements * releasing v04 preview * updating snapshot * adding model to registry * Update CLI snapshots for registry.json changes Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add LoRA config support to models create and show commands - Add LoRAConfig struct (rank, alpha, targetModules, dropout) to CustomModel - Add --lora-rank, --lora-alpha, --lora-target-modules, --lora-dropout flags to 'models create' with validation when --weight-type is LoRA - Display LoRA Configuration section in 'models show' table output - Add LoRAConfig field to RegisterModelRequest for API serialization - Add WeightType column to 'models list' table view - Add tests for flag validation, JSON round-trip, and omitempty behavior Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Bump azure.ai.models to 0.0.7-preview with LoRA changelog Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Move LoRA flag validation before upload steps Validate --lora-* flags and --weight-type consistency before StartPendingUpload and AzCopy upload to fail fast on invalid input without creating unnecessary upload artifacts. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * chore: trigger check-enforcer re-evaluation --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 444ec22 commit e3d20af

8 files changed

Lines changed: 288 additions & 20 deletions

File tree

cli/azd/extensions/azure.ai.models/CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,14 @@
11
# Release History
22

33

4+
## 0.0.7-preview (Unreleased)
5+
6+
### Features
7+
8+
- Added LoRA adapter support to `create` command with `--lora-rank`, `--lora-alpha`, `--lora-target-modules`, and `--lora-dropout` flags for registering LoRA adapters (`--weight-type LoRA`)
9+
- `show` command now displays LoRA Configuration section (rank, alpha, target modules, dropout) for LoRA adapters
10+
- `list` command now shows Weight Type column to distinguish FullWeight and LoRA models
11+
412
## 0.0.6-preview (Unreleased)
513

614
### Features

cli/azd/extensions/azure.ai.models/extension.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@ namespace: ai.models
33
displayName: Foundry Custom Models (Preview)
44
description: Extension for managing custom models in Azure AI Foundry. (Preview)
55
usage: azd ai models <command> [options]
6-
version: 0.0.6-preview
6+
version: 0.0.7-preview
77
language: go
88
capabilities:
99
- custom-commands

cli/azd/extensions/azure.ai.models/internal/cmd/custom_create.go

Lines changed: 71 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -23,16 +23,20 @@ import (
2323
)
2424

2525
type customCreateFlags struct {
26-
Name string
27-
Version string
28-
Source string
29-
SourceFile string
30-
Description string
31-
BaseModel string
32-
WeightType string
33-
Publisher string
34-
AzcopyPath string
35-
NoWait bool
26+
Name string
27+
Version string
28+
Source string
29+
SourceFile string
30+
Description string
31+
BaseModel string
32+
WeightType string
33+
Publisher string
34+
AzcopyPath string
35+
NoWait bool
36+
LoRARank int
37+
LoRAAlpha int
38+
LoRATargetModules string
39+
LoRADropout float64
3640
}
3741

3842
func newCustomCreateCommand(parentFlags *customFlags) *cobra.Command {
@@ -77,7 +81,7 @@ provide a file containing the URL instead.`,
7781
return fmt.Errorf("either --source or --source-file is required")
7882
}
7983

80-
return runCustomCreate(ctx, parentFlags, flags)
84+
return runCustomCreate(ctx, cmd, parentFlags, flags)
8185
},
8286
}
8387

@@ -91,14 +95,19 @@ provide a file containing the URL instead.`,
9195
cmd.Flags().StringVar(&flags.Publisher, "publisher", "", "Model publisher ID for catalog info (e.g., Fireworks)")
9296
cmd.Flags().StringVar(&flags.AzcopyPath, "azcopy-path", "", "Path to azcopy binary (auto-detected if not provided)")
9397
cmd.Flags().BoolVar(&flags.NoWait, "no-wait", false, "Start async registration and return immediately with the operation URL")
98+
cmd.Flags().IntVar(&flags.LoRARank, "lora-rank", 0, "LoRA rank (r) — required when --weight-type is LoRA")
99+
cmd.Flags().IntVar(&flags.LoRAAlpha, "lora-alpha", 0, "LoRA scaling factor (alpha) — required when --weight-type is LoRA")
100+
cmd.Flags().StringVar(&flags.LoRATargetModules, "lora-target-modules", "",
101+
"Comma-separated list of target modules (e.g., \"q_proj,v_proj,k_proj,o_proj\")")
102+
cmd.Flags().Float64Var(&flags.LoRADropout, "lora-dropout", 0, "LoRA dropout rate used during training (informational)")
94103

95104
_ = cmd.MarkFlagRequired("name")
96105
_ = cmd.MarkFlagRequired("base-model")
97106

98107
return cmd
99108
}
100109

101-
func runCustomCreate(ctx context.Context, parentFlags *customFlags, flags *customCreateFlags) error {
110+
func runCustomCreate(ctx context.Context, cmd *cobra.Command, parentFlags *customFlags, flags *customCreateFlags) error {
102111
azdClient, err := azdext.NewAzdClient()
103112
if err != nil {
104113
return fmt.Errorf("failed to create azd client: %w", err)
@@ -132,6 +141,19 @@ func runCustomCreate(ctx context.Context, parentFlags *customFlags, flags *custo
132141
return err
133142
}
134143

144+
// Validate LoRA flags early, before any uploads
145+
var loraConfig *models.LoRAConfig
146+
if strings.EqualFold(flags.WeightType, "LoRA") {
147+
var err error
148+
loraConfig, err = buildLoRAConfig(cmd, flags)
149+
if err != nil {
150+
return err
151+
}
152+
} else if cmd.Flags().Changed("lora-rank") || cmd.Flags().Changed("lora-alpha") ||
153+
cmd.Flags().Changed("lora-target-modules") || cmd.Flags().Changed("lora-dropout") {
154+
return fmt.Errorf("--lora-* flags are only valid when --weight-type is LoRA")
155+
}
156+
135157
// ── Step 1: Start pending upload ──
136158
fmt.Printf("Creating custom model: %s (version %s)\n\n", flags.Name, flags.Version)
137159

@@ -240,6 +262,11 @@ func runCustomCreate(ctx context.Context, parentFlags *customFlags, flags *custo
240262
}
241263
}
242264

265+
// Attach pre-validated LoRA config
266+
if loraConfig != nil {
267+
regReq.LoRAConfig = loraConfig
268+
}
269+
243270
operationURL, err := foundryClient.RegisterModelAsync(ctx, flags.Name, flags.Version, regReq)
244271
_ = regSpinner.Stop(ctx)
245272
fmt.Println()
@@ -335,3 +362,35 @@ func extractVersionFromURI(uri string) string {
335362
}
336363
return ""
337364
}
365+
366+
// buildLoRAConfig validates LoRA flags and builds a LoRAConfig for the registration request.
367+
func buildLoRAConfig(cmd *cobra.Command, flags *customCreateFlags) (*models.LoRAConfig, error) {
368+
if flags.LoRARank <= 0 {
369+
return nil, fmt.Errorf("--lora-rank is required and must be a positive integer when --weight-type is LoRA")
370+
}
371+
if flags.LoRAAlpha <= 0 {
372+
return nil, fmt.Errorf("--lora-alpha is required and must be a positive integer when --weight-type is LoRA")
373+
}
374+
375+
config := &models.LoRAConfig{
376+
Rank: new(flags.LoRARank),
377+
Alpha: new(flags.LoRAAlpha),
378+
}
379+
380+
if flags.LoRATargetModules != "" {
381+
modules := strings.Split(flags.LoRATargetModules, ",")
382+
for i := range modules {
383+
modules[i] = strings.TrimSpace(modules[i])
384+
if modules[i] == "" {
385+
return nil, fmt.Errorf("--lora-target-modules contains an empty entry")
386+
}
387+
}
388+
config.TargetModules = modules
389+
}
390+
391+
if cmd.Flags().Changed("lora-dropout") {
392+
config.Dropout = new(flags.LoRADropout)
393+
}
394+
395+
return config, nil
396+
}

cli/azd/extensions/azure.ai.models/internal/cmd/custom_create_test.go

Lines changed: 173 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,12 @@
44
package cmd
55

66
import (
7+
"encoding/json"
78
"testing"
9+
10+
"azure.ai.models/pkg/models"
11+
12+
"github.com/spf13/cobra"
813
)
914

1015
func TestBuildDerivedModelURI(t *testing.T) {
@@ -100,3 +105,171 @@ func TestExtractVersionFromURI(t *testing.T) {
100105
})
101106
}
102107
}
108+
109+
// newTestCmd creates a cobra command with lora flags registered for testing buildLoRAConfig.
110+
func newTestCmd(flags *customCreateFlags) *cobra.Command {
111+
cmd := &cobra.Command{Use: "test"}
112+
cmd.Flags().IntVar(&flags.LoRARank, "lora-rank", 0, "")
113+
cmd.Flags().IntVar(&flags.LoRAAlpha, "lora-alpha", 0, "")
114+
cmd.Flags().StringVar(&flags.LoRATargetModules, "lora-target-modules", "", "")
115+
cmd.Flags().Float64Var(&flags.LoRADropout, "lora-dropout", 0, "")
116+
return cmd
117+
}
118+
119+
func TestBuildLoRAConfig(t *testing.T) {
120+
tests := []struct {
121+
name string
122+
args []string
123+
wantErr string
124+
wantRank int
125+
wantAlpha int
126+
wantModules []string
127+
wantDropout *float64
128+
}{
129+
{
130+
name: "missing rank",
131+
args: []string{"--lora-alpha", "32"},
132+
wantErr: "--lora-rank is required",
133+
},
134+
{
135+
name: "missing alpha",
136+
args: []string{"--lora-rank", "16"},
137+
wantErr: "--lora-alpha is required",
138+
},
139+
{
140+
name: "rank and alpha only",
141+
args: []string{"--lora-rank", "16", "--lora-alpha", "32"},
142+
wantRank: 16,
143+
wantAlpha: 32,
144+
},
145+
{
146+
name: "all fields",
147+
args: []string{"--lora-rank", "8", "--lora-alpha", "16", "--lora-target-modules", "q_proj,v_proj", "--lora-dropout", "0.05"},
148+
wantRank: 8,
149+
wantAlpha: 16,
150+
wantModules: []string{"q_proj", "v_proj"},
151+
wantDropout: new(0.05),
152+
},
153+
{
154+
name: "empty entry in target modules",
155+
args: []string{"--lora-rank", "16", "--lora-alpha", "32", "--lora-target-modules", "q_proj,,v_proj"},
156+
wantErr: "empty entry",
157+
},
158+
}
159+
160+
for _, tt := range tests {
161+
t.Run(tt.name, func(t *testing.T) {
162+
flags := &customCreateFlags{}
163+
cmd := newTestCmd(flags)
164+
if err := cmd.ParseFlags(tt.args); err != nil {
165+
t.Fatalf("ParseFlags: %v", err)
166+
}
167+
168+
got, err := buildLoRAConfig(cmd, flags)
169+
if tt.wantErr != "" {
170+
if err == nil {
171+
t.Fatalf("expected error containing %q, got nil", tt.wantErr)
172+
}
173+
if !contains(err.Error(), tt.wantErr) {
174+
t.Fatalf("expected error containing %q, got %q", tt.wantErr, err.Error())
175+
}
176+
return
177+
}
178+
if err != nil {
179+
t.Fatalf("unexpected error: %v", err)
180+
}
181+
if got.Rank == nil || *got.Rank != tt.wantRank {
182+
t.Errorf("Rank = %v, want %d", got.Rank, tt.wantRank)
183+
}
184+
if got.Alpha == nil || *got.Alpha != tt.wantAlpha {
185+
t.Errorf("Alpha = %v, want %d", got.Alpha, tt.wantAlpha)
186+
}
187+
if len(tt.wantModules) > 0 {
188+
if len(got.TargetModules) != len(tt.wantModules) {
189+
t.Errorf("TargetModules = %v, want %v", got.TargetModules, tt.wantModules)
190+
}
191+
for i := range tt.wantModules {
192+
if i < len(got.TargetModules) && got.TargetModules[i] != tt.wantModules[i] {
193+
t.Errorf("TargetModules[%d] = %q, want %q", i, got.TargetModules[i], tt.wantModules[i])
194+
}
195+
}
196+
}
197+
if tt.wantDropout != nil {
198+
if got.Dropout == nil || *got.Dropout != *tt.wantDropout {
199+
t.Errorf("Dropout = %v, want %v", got.Dropout, *tt.wantDropout)
200+
}
201+
}
202+
})
203+
}
204+
}
205+
206+
func contains(s, substr string) bool {
207+
return len(s) >= len(substr) && searchString(s, substr)
208+
}
209+
210+
func searchString(s, substr string) bool {
211+
for i := 0; i <= len(s)-len(substr); i++ {
212+
if s[i:i+len(substr)] == substr {
213+
return true
214+
}
215+
}
216+
return false
217+
}
218+
219+
func TestLoRAConfigJSONRoundTrip(t *testing.T) {
220+
original := &models.CustomModel{
221+
Name: "test-lora-adapter",
222+
Version: "1",
223+
WeightType: "LoRA",
224+
LoRAConfig: &models.LoRAConfig{
225+
Rank: new(16),
226+
Alpha: new(32),
227+
TargetModules: []string{"q_proj", "v_proj", "k_proj", "o_proj"},
228+
Dropout: new(0.05),
229+
},
230+
}
231+
232+
data, err := json.Marshal(original)
233+
if err != nil {
234+
t.Fatalf("Marshal: %v", err)
235+
}
236+
237+
var decoded models.CustomModel
238+
if err := json.Unmarshal(data, &decoded); err != nil {
239+
t.Fatalf("Unmarshal: %v", err)
240+
}
241+
242+
if decoded.LoRAConfig == nil {
243+
t.Fatal("LoRAConfig is nil after round-trip")
244+
}
245+
if decoded.LoRAConfig.Rank == nil || *decoded.LoRAConfig.Rank != 16 {
246+
t.Errorf("Rank = %v, want 16", decoded.LoRAConfig.Rank)
247+
}
248+
if decoded.LoRAConfig.Alpha == nil || *decoded.LoRAConfig.Alpha != 32 {
249+
t.Errorf("Alpha = %v, want 32", decoded.LoRAConfig.Alpha)
250+
}
251+
if len(decoded.LoRAConfig.TargetModules) != 4 {
252+
t.Errorf("TargetModules len = %d, want 4", len(decoded.LoRAConfig.TargetModules))
253+
}
254+
if decoded.LoRAConfig.Dropout == nil || *decoded.LoRAConfig.Dropout != 0.05 {
255+
t.Errorf("Dropout = %v, want 0.05", decoded.LoRAConfig.Dropout)
256+
}
257+
}
258+
259+
func TestLoRAConfigJSONOmittedWhenNil(t *testing.T) {
260+
model := &models.CustomModel{
261+
Name: "full-weight-model",
262+
Version: "1",
263+
WeightType: "FullWeight",
264+
}
265+
266+
data, err := json.Marshal(model)
267+
if err != nil {
268+
t.Fatalf("Marshal: %v", err)
269+
}
270+
271+
jsonStr := string(data)
272+
if contains(jsonStr, "loraConfig") {
273+
t.Errorf("loraConfig should be omitted for FullWeight models, got: %s", jsonStr)
274+
}
275+
}

cli/azd/extensions/azure.ai.models/internal/cmd/custom_show.go

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -165,6 +165,22 @@ func runCustomShow(ctx context.Context, parentFlags *customFlags, flags *customS
165165
fmt.Printf(" URI: %s\n", *model.DerivedModelInformation.BaseModel)
166166
}
167167

168+
if model.LoRAConfig != nil {
169+
fmt.Println("\nLoRA Configuration:")
170+
if model.LoRAConfig.Rank != nil {
171+
fmt.Printf(" Rank: %d\n", *model.LoRAConfig.Rank)
172+
}
173+
if model.LoRAConfig.Alpha != nil {
174+
fmt.Printf(" Alpha: %d\n", *model.LoRAConfig.Alpha)
175+
}
176+
if len(model.LoRAConfig.TargetModules) > 0 {
177+
fmt.Printf(" Target Modules: %s\n", strings.Join(model.LoRAConfig.TargetModules, ", "))
178+
}
179+
if model.LoRAConfig.Dropout != nil {
180+
fmt.Printf(" Dropout: %g\n", *model.LoRAConfig.Dropout)
181+
}
182+
}
183+
168184
if model.Source != nil {
169185
fmt.Println("\nSource:")
170186
fmt.Printf(" Type: %s\n", model.Source.SourceType)

0 commit comments

Comments (0)