Skip to content

ModelLoader

Manages machine learning models with capabilities for fetching, downloading, and loading.

This class provides a unified interface for working with ML models regardless of their storage location (local, remote) or format. It uses STAC metadata to describe model properties and capabilities.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| source | str | Location of the model (URL or local path) |
| scheme | str | Access scheme ('snippet', 'local', 'http', etc.) |
| item | Item | STAC metadata for the model |
| module | ModuleType | Python module containing model loading functions |

Examples:

>>> # Load a model from a snippet reference
>>> model = ModelLoader("resnet50")
>>>
>>> # Download the model locally (returns a loader for the local copy)
>>> local_model = model.download("./models")
>>>
>>> # Load the model for inference
>>> inference_model = local_model.compiled_model()
Source code in mlstac/main.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
class ModelLoader:
    """
    Manages machine learning models with capabilities for fetching, downloading, and loading.

    This class provides a unified interface for working with ML models regardless of
    their storage location (local, remote) or format. It uses STAC metadata to describe
    model properties and capabilities.

    Attributes:
        file (str | Path | list): What the loader was built from: a metadata file
            path/URL, the ``mlm.json`` inside a given directory, or a list of .pt2 files
        source (str | None): Local directory holding the model files, or None when
            the metadata was loaded from a non-local scheme
        scheme (str): Access scheme returned by ``get_scheme`` (e.g. 'local', 'pt2_list')
        item (pystac.Item): STAC metadata for the model
        module (types.ModuleType | None): Python module containing model loading
            functions; loaded lazily on first use
        status (str): Status label shown by ``print_schema``
        device: Initialised to None; not otherwise used within this class

    Examples:
        >>> # Load a model from a snippet reference
        >>> model = ModelLoader("resnet50")
        >>>
        >>> # Download the model locally (returns a loader for the local copy)
        >>> local_model = model.download("./models")
        >>>
        >>> # Load the model for inference
        >>> inference_model = local_model.compiled_model()
    """

    def __init__(self, file: str | list):
        """
        Initialize the model manager.

        Args:
            file: The JSON file that contains the model metadata, a directory path,
                or a list of .pt2 model files for ad-hoc ensemble creation

        Raises:
            ValueError: If the source cannot be resolved or the model cannot be loaded.
        """
        self.file = file
        self.scheme = get_scheme(file)

        if self.scheme == "pt2_list":
            self.item = self._create_minimal_stac_from_pt2s(file)
            self.source = str(Path(file[0]).parent)

        elif self.scheme == "local":
            # A directory is shorthand for the mlm.json it contains.
            if Path(file).is_dir():
                self.file = Path(file) / "mlm.json"
            self.item = self._load()
            self.source = Path(self.file).parent.as_posix()
        else:
            # Non-local metadata: nothing on disk yet, so there is no source dir.
            self.item = load_stac_item(file)
            self.source = None

        self.module = None
        self.status = "downloaded"
        self.device = None

    def download(self, output_dir: Path | str) -> ModelLoader:
        """
        Download this model's files into a local directory.

        Thin wrapper around the module-level download(): it resolves the
        assets from this loader's metadata source and returns a new loader
        pointing at the local copy.

        Args:
            output_dir: Target directory for the downloaded files

        Returns:
            A ModelLoader for the downloaded model
        """
        # `download` here is the module-level function, not this method.
        return download(self.file, output_dir)

    @property
    def thr(self) -> float | None:
        """
        Get the recommended threshold value for the model output.

        Returns:
            Recommended threshold value, or None if not available
        """
        mlm_output = self.item.properties.get("mlm:output", [{}])
        if mlm_output and mlm_output[0]:
            return mlm_output[0].get("recommended_threshold")
        return None

    @property
    def is_ensemble(self) -> bool:
        """
        Check if this is an ensemble model requiring runtime aggregation.

        An ensemble model is one that requires loading multiple .pt2 files
        and aggregating them at runtime (e.g. mean/median/max/min).

        A pre-fused ensemble (``custom:ensemble_fused`` set in the metadata) is
        NOT considered an ensemble for runtime purposes.
        """
        if self.item.properties.get("custom:ensemble_fused", False):
            return False

        pt2_count = sum(
            1 for asset in self.item.assets.values() if asset.href.endswith(".pt2")
        )

        return pt2_count > 1

    @property
    def model(self):
        """
        Convenience property to get the compiled model (single models only).
        For ensembles, use compiled_model(mode=...) instead.

        Raises:
            AttributeError: If the model is an ensemble.

        Example:
            >>> # Single model - quick access
            >>> model = loader.model
        """
        if self.is_ensemble:
            raise AttributeError(
                "Cannot use .model property for ensemble models. "
                "Use .compiled_model(mode='mean'|'max'|'min') instead."
            )
        return self.compiled_model()

    def _verify_local_access(self) -> None:
        """
        Verify model is available locally before attempting to load.

        Raises:
            ValueError: If model hasn't been downloaded locally
        """
        if self.scheme not in {"local", "pt2_list"}:
            raise ValueError(
                "The model must be downloaded locally first. "
                "Run .download(path) to download the model files."
            )

    def _create_minimal_stac_from_pt2s(self, model_paths: list) -> pystac.Item:
        """
        Create minimal STAC metadata from a list of .pt2 model files.

        Args:
            model_paths: Paths to existing .pt2 files (at least one).

        Returns:
            A STAC item with one asset per file, sorted by path, with a global
            bounding box.

        Raises:
            ValueError: If the list is empty or a file does not end in ``.pt2``.
            FileNotFoundError: If a listed file does not exist.
        """
        from datetime import datetime, timezone

        model_paths = [Path(p) for p in model_paths]
        # Fail early: an empty list would otherwise yield an asset-less item and
        # an IndexError later when the caller reads model_paths[0].
        if not model_paths:
            raise ValueError("At least one .pt2 model file is required")

        for p in model_paths:
            if not p.exists():
                raise FileNotFoundError(f"Model file not found: {p}")
            if p.suffix != ".pt2":
                raise ValueError(f"Expected .pt2 file, got: {p}")

        assets = {}
        for i, model_path in enumerate(sorted(model_paths), start=1):
            asset_key = f"model_{i}_{model_path.stem}"
            assets[asset_key] = {
                "href": str(model_path.absolute()),
                "type": "application/octet-stream",
                "title": f"Model {i}: {model_path.name}",
                "roles": ["mlm:model", "mlm:weights"],
            }

        stac_dict = {
            "type": "Feature",
            "stac_version": "1.1.0",
            "id": f"ENSEMBLE_ADHOC_{len(model_paths)}_MODELS",
            "geometry": {
                "type": "Polygon",
                "coordinates": [
                    [[-180, -90], [-180, 90], [180, 90], [180, -90], [-180, -90]]
                ],
            },
            "bbox": [-180, -90, 180, 90],
            "properties": {
                "datetime": datetime.now(timezone.utc).isoformat(),
                "title": f"Ad-hoc Ensemble ({len(model_paths)} models)",
                "description": f"Ensemble from {len(model_paths)} .pt2 files - no metadata available",
                "mlm:name": "adhoc_ensemble",
                "mlm:architecture": f"Ensemble of {len(model_paths)} models",
                "mlm:framework": "pytorch",
            },
            "links": [],
            "assets": assets,
        }

        return pystac.Item.from_dict(stac_dict)

    def print_schema(self) -> None:
        """
        Print a visually appealing schema of the model.

        Automatically detects if running in a Jupyter/Colab notebook (``ipykernel``
        is loaded) or a terminal and formats the output accordingly: an HTML card
        in a notebook, ANSI-coloured text otherwise.

        Missing, null or empty metadata fields fall back to placeholder values
        instead of raising. In the notebook view every metadata value is
        HTML-escaped, because the STAC item may come from a remote source.
        """
        in_notebook = "ipykernel" in sys.modules
        props = self.item.properties

        model_id = self.item.id
        title = props.get("title", "Untitled Model")
        description = props.get("description", "No description available")

        framework = props.get("mlm:framework", "Not specified")
        # `or ""` also covers an explicit null, which would break slicing below.
        framework_version = str(props.get("mlm:framework_version") or "")
        architecture = props.get("mlm:architecture", "Not specified")
        tasks = props.get("mlm:tasks") or []

        total_params = props.get("mlm:total_parameters", 0)
        params_m = f"{total_params / 1_000_000:.2f}M" if total_params else "Unknown"

        file_size = props.get("file:size", 0)
        file_size_mb = f"{file_size / (1024 * 1024):.2f} MB" if file_size else "Unknown"

        sensors = props.get("custom:sensors") or []
        spatial_res = props.get("custom:spatial_resolution", "Unknown")
        project = props.get("custom:project") or ""
        project_url = props.get("custom:project_url") or ""

        hyperparams = props.get("mlm:hyperparameters") or {}
        learning_rate = hyperparams.get("learning_rate", "N/A")
        batch_size = hyperparams.get("batch_size", "N/A")
        epochs = hyperparams.get("training_epochs", "N/A")
        val_loss = hyperparams.get("final_val_loss", "N/A")

        # mlm:input / mlm:output may be missing, null, or an empty list; fall back
        # to an empty spec in every case instead of raising IndexError.
        input_spec = (props.get("mlm:input") or [{}])[0] or {}
        input_shape = (input_spec.get("input") or {}).get("shape", [])
        input_bands = input_spec.get("bands") or []

        output_spec = (props.get("mlm:output") or [{}])[0] or {}
        output_shape = (output_spec.get("result") or {}).get("shape", [])
        standard_threshold = output_spec.get("standard_threshold", "N/A")
        recommended_threshold = output_spec.get("recommended_threshold", "N/A")

        links = {link.rel: link.href for link in self.item.links}

        dependencies = props.get("dependencies") or []
        if dependencies:
            # Show only the package names (strip ">=" pins) of the first three.
            deps_str = ", ".join(d.split(">=")[0] for d in dependencies[:3])
        else:
            deps_str = "None"

        if in_notebook:
            from html import escape

            from IPython.display import HTML, display

            def esc(value: Any) -> str:
                # Metadata may be untrusted: never inject it as raw HTML.
                return escape(str(value))

            sensor_badges = "".join(
                f'<span class="mlstac-badge">{esc(s)}</span>' for s in sensors
            )
            task_badges = "".join(
                f'<span class="mlstac-badge">{esc(t)}</span>' for t in tasks
            )

            # Only render the pieces that actually have data (no empty links).
            footer_parts = []
            if project:
                footer_parts.append(f"<strong>{esc(project)}</strong>")
            if project_url:
                footer_parts.append(
                    f'<a href="{esc(project_url)}" target="_blank">Project Info</a>'
                )
            if "license" in links:
                footer_parts.append(
                    f'<a href="{esc(links["license"])}" target="_blank">License</a>'
                )
            footer_parts.append(f"Source: <strong>{esc(self.scheme.upper())}</strong>")
            footer_parts.append(
                f"Status: <strong>{esc(self.status.capitalize())}</strong>"
            )
            footer_html = " | ".join(footer_parts)

            html_content = f"""
            <style>
                .mlstac-container {{
                    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
                    background: linear-gradient(135deg, #1e40af 0%, #3b82f6 100%);
                    padding: 15px;
                    border-radius: 10px;
                    color: white;
                    box-shadow: 0 4px 20px rgba(0,0,0,0.2);
                    margin: 15px 0;
                }}
                .mlstac-header {{
                    text-align: center;
                    margin-bottom: 15px;
                    padding-bottom: 10px;
                    border-bottom: 1px solid rgba(255,255,255,0.3);
                }}
                .mlstac-header h2 {{
                    margin: 0 0 5px 0;
                    font-size: 24px;
                    font-weight: 600;
                }}
                .mlstac-header p {{
                    margin: 0;
                    font-size: 13px;
                    opacity: 0.9;
                }}
                .mlstac-grid {{
                    display: grid;
                    grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
                    gap: 10px;
                    margin-bottom: 15px;
                }}
                .mlstac-card {{
                    background: rgba(255, 255, 255, 0.15);
                    backdrop-filter: blur(10px);
                    border-radius: 8px;
                    padding: 12px;
                    border: 1px solid rgba(255,255,255,0.2);
                    transition: transform 0.2s ease-out, box-shadow 0.2s ease-out;
                }}
                .mlstac-card:hover {{
                    transform: translateY(-5px);
                    box-shadow: 0 10px 20px rgba(0,0,0,0.3);
                }}
                .mlstac-card h3 {{
                    margin: 0 0 8px 0;
                    font-size: 14px;
                    font-weight: 600;
                    opacity: 0.95;
                    display: flex;
                    align-items: center;
                    gap: 6px;
                }}
                .mlstac-card-content {{
                    font-size: 13px;
                    line-height: 1.5;
                }}
                .mlstac-card-content p {{
                    margin: 4px 0;
                    display: flex;
                    justify-content: space-between;
                    align-items: center;
                }}
                .mlstac-card-content .label {{
                    opacity: 0.8;
                    font-weight: 500;
                }}
                .mlstac-card-content .value {{
                    background: rgba(255,255,255,0.2);
                    padding: 2px 6px;
                    border-radius: 5px;
                    font-weight: 600;
                    text-align: right;
                    font-size: 13px;
                }}
                .mlstac-badge {{
                    display: inline-block;
                    background: rgba(255,255,255,0.25);
                    padding: 3px 8px;
                    border-radius: 15px;
                    font-size: 11px;
                    margin: 2px;
                    font-weight: 500;
                }}
                .mlstac-description {{
                    background: rgba(255, 255, 255, 0.1);
                    padding: 10px;
                    border-radius: 8px;
                    margin-bottom: 10px;
                    font-size: 12px;
                    line-height: 1.5;
                    border-left: 3px solid rgba(255,255,255,0.4);
                }}
                .mlstac-footer {{
                    text-align: center;
                    margin-top: 15px;
                    padding-top: 10px;
                    border-top: 1px solid rgba(255,255,255,0.3);
                    font-size: 12px;
                    opacity: 0.9;
                }}
                .mlstac-footer a {{
                    color: white;
                    text-decoration: none;
                    font-weight: 600;
                    border-bottom: 1px solid rgba(255,255,255,0.5);
                    transition: border-color 0.2s;
                }}
                .mlstac-footer a:hover {{
                    border-bottom-color: white;
                }}
                .icon {{
                    font-size: 16px;
                }}
            </style>
            <div class="mlstac-container">
                <div class="mlstac-header">
                    <h2>🚀 {esc(title)}</h2>
                    <p>Model ID: <strong>{esc(model_id)}</strong></p>
                </div>
                <div class="mlstac-description">
                    {esc(description)}
                </div>
                <div class="mlstac-grid">
                    <div class="mlstac-card">
                        <h3><span class="icon">🛠️</span> Framework & arch.</h3>
                        <div class="mlstac-card-content">
                            <p><span class="label">Framework:</span> <span class="value">{esc(framework)} {esc(framework_version[:6])}</span></p>
                            <p><span class="label">Architecture:</span> <span class="value">{esc(architecture)}</span></p>
                            <p><span class="label">Parameters:</span> <span class="value">{esc(params_m)}</span></p>
                            <p><span class="label">Model Size:</span> <span class="value">{esc(file_size_mb)}</span></p>
                        </div>
                    </div>
                    <div class="mlstac-card">
                        <h3><span class="icon">🛰️</span> Data specs</h3>
                        <div class="mlstac-card-content">
                            <p><span class="label">Spatial Res:</span> <span class="value">{esc(spatial_res)}</span></p>
                            <p><span class="label">Input Shape:</span> <span class="value">{esc(input_shape)}</span></p>
                            <p><span class="label">Bands:</span> <span class="value">{len(input_bands)}</span></p>
                            <p><span class="label">Sensors:</span></p>
                            <div style="margin-left: 5px; text-align: right;">
                                {sensor_badges}
                            </div>
                        </div>
                    </div>
                    <div class="mlstac-card">
                        <h3><span class="icon">📊</span> Training metrics</h3>
                        <div class="mlstac-card-content">
                            <p><span class="label">Learning Rate:</span> <span class="value">{esc(learning_rate)}</span></p>
                            <p><span class="label">Batch Size:</span> <span class="value">{esc(batch_size)}</span></p>
                            <p><span class="label">Epochs:</span> <span class="value">{esc(epochs)}</span></p>
                            <p><span class="label">Val Loss:</span> <span class="value">{esc(val_loss)}</span></p>
                        </div>
                    </div>
                    <div class="mlstac-card">
                        <h3><span class="icon">🎯</span> Tasks & output</h3>
                        <div class="mlstac-card-content">
                            <p><span class="label">Output Shape:</span> <span class="value">{esc(output_shape)}</span></p>
                            <p><span class="label">Std Threshold:</span> <span class="value">{esc(standard_threshold)}</span></p>
                            <p><span class="label">Rec. Threshold:</span> <span class="value">{esc(recommended_threshold)}</span></p>
                            <p><span class="label">Dependencies:</span> <span class="value">{esc(deps_str)}</span></p>
                            <p><span class="label">Tasks:</span></p>
                            <div style="margin-left: 5px; text-align: right;">
                                {task_badges}
                            </div>
                        </div>
                    </div>
                </div>
                <div class="mlstac-footer">
                    {footer_html}
                </div>
            </div>
            """
            display(HTML(html_content))

        else:
            CYAN = "\033[96m"
            GREEN = "\033[92m"
            YELLOW = "\033[93m"
            BLUE = "\033[94m"
            BOLD = "\033[1m"
            RESET = "\033[0m"
            DIM = "\033[2m"

            print(f"\n{CYAN}{BOLD}🚀 {title}{RESET}")
            print(f"{DIM}   ID: {model_id}{RESET}")

            print(f"{BLUE}   {description}{RESET}\n")

            print(f"{GREEN}{BOLD}🛠️  Framework & architecture{RESET}")
            print(
                f"   Framework:    {YELLOW}{framework} {framework_version[:10]}{RESET}"
            )
            print(f"   Architecture: {YELLOW}{architecture}{RESET}")
            print(f"   Parameters:   {YELLOW}{params_m}{RESET}")
            print(f"   Model Size:   {YELLOW}{file_size_mb}{RESET}")

            print(f"{GREEN}{BOLD}🛰️  Data specifications{RESET}")
            print(f"   Sensors:      {YELLOW}{', '.join(map(str, sensors))}{RESET}")
            print(f"   Spatial Res:  {YELLOW}{spatial_res}{RESET}")
            print(f"   Input Shape:  {YELLOW}{input_shape}{RESET}")
            print(f"   Bands:        {YELLOW}{len(input_bands)} bands{RESET}")

            print(f"{GREEN}{BOLD}📊 Training metrics{RESET}")
            print(f"   Learning Rate: {YELLOW}{learning_rate}{RESET}")
            print(f"   Batch Size:    {YELLOW}{batch_size}{RESET}")
            print(f"   Epochs:        {YELLOW}{epochs}{RESET}")
            print(f"   Val Loss:     {YELLOW}{val_loss}{RESET}")

            print(f"{GREEN}{BOLD}🎯 Tasks & output{RESET}")
            print(f"   Tasks:            {YELLOW}{', '.join(map(str, tasks))}{RESET}")
            print(f"   Output Shape:     {YELLOW}{output_shape}{RESET}")
            print(f"   Std Threshold:    {YELLOW}{standard_threshold}{RESET}")
            print(f"   Rec. Threshold:   {YELLOW}{recommended_threshold}{RESET}")
            print(f"   Dependencies:     {YELLOW}{deps_str}{RESET}")

            # Avoid a dangling " | " when no project name is set.
            project_prefix = f"{project} | " if project else ""
            status_str = f"| Status: {self.status.capitalize()} "
            print(
                f"\n{DIM}   {project_prefix}Source: {self.scheme.upper()} {status_str}{RESET}"
            )
            if project_url:
                print(f"{DIM}   🔗 {project_url}{RESET}\n")

    def _load(self) -> pystac.Item:
        """
        Load model metadata from ``self.file`` and re-point assets at local files.

        Every asset href is replaced by the file of the same name in the
        directory that contains the metadata file.

        Returns:
            Updated STAC item with local file references

        Raises:
            FileNotFoundError: If the metadata file doesn't exist
            ValueError: If the metadata file is not valid JSON
        """
        try:
            with fsspec.open(self.file, "r", encoding="utf-8") as f:
                mlm_data = json.load(f)
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Model metadata file not found at {self.file}"
            ) from e
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid model metadata format: {e!s}") from e

        # Assets are expected next to mlm.json; keep only each href's filename.
        base_dir = Path(self.file).parent
        for asset in mlm_data["assets"].values():
            filename = Path(asset["href"]).name
            asset["href"] = (base_dir / filename).absolute().as_posix()

        return pystac.item.Item.from_dict(mlm_data)

    def example_data(self, *args, **kwargs) -> Any:
        """
        Load example data for model testing via the model module's ``example_data``.

        Returns:
            Whatever the model module's ``example_data`` returns

        Raises:
            ValueError: If model hasn't been downloaded locally
            AttributeError: If the model module doesn't define ``example_data``
            FileNotFoundError: If the hook reports a missing example data file
        """
        self._verify_local_access()

        if self.module is None:
            self.module = load_python_module(self.source)

        # Look the hook up explicitly so an AttributeError raised *inside* it is
        # not misreported as "not implemented".
        loader_fn = getattr(self.module, "example_data", None)
        if loader_fn is None:
            raise AttributeError(
                "Model loader module doesn't implement 'example_data' function"
            )

        try:
            return loader_fn(Path(self.source), *args, **kwargs)
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Example data file not found at {self.source}/example_data.safetensor"
            ) from e

    def trainable_model(self, *args, **kwargs) -> Any:
        """
        Load the trainable version of the model for fine-tuning.

        Returns:
            Trainable model instance

        Raises:
            ValueError: If model hasn't been downloaded locally
            KeyError: If the hook reports a missing trainable model asset
            AttributeError: If the model module doesn't define ``trainable_model``
        """
        self._verify_local_access()

        # Only the "local" scheme has an mlm.json to re-read; a pt2_list loader
        # keeps its synthesised metadata (self.file is a list there).
        if self.scheme == "local":
            self.item = self._load()

        if self.module is None:
            self.module = load_python_module(self.source)

        build = getattr(self.module, "trainable_model", None)
        if build is None:
            raise AttributeError(
                "Model loader module doesn't implement 'trainable_model' function"
            )

        try:
            return build(Path(self.source), *args, **kwargs)
        except KeyError as e:
            raise KeyError("Trainable model asset not found in metadata") from e

    def compiled_model(self, **kwargs):
        """
        Load the compiled model for inference.

        For single models: No parameters needed
        For ensembles: Accepts mode='mean'|'median'|'max'|'min' (default: 'max')

        Returns:
            Compiled model instance ready for inference

        Raises:
            ValueError: If model hasn't been downloaded locally
            AttributeError: If the model module doesn't define ``compiled_model``

        Example:
            >>> # Single model
            >>> model = loader.compiled_model()
            >>>
            >>> # Ensemble model
            >>> model = loader.compiled_model(mode="mean")
            >>> model = loader.compiled_model(mode="median")  # More robust to outliers
            >>> model = loader.compiled_model(mode="max")     # Conservative (more clouds)
        """
        self._verify_local_access()

        if self.scheme == "local":
            self.item = self._load()

        if self.module is None:
            self.module = load_python_module(self.source)

        if self.is_ensemble and "mode" not in kwargs:
            kwargs["mode"] = "max"

        build = getattr(self.module, "compiled_model", None)
        if build is None:
            raise AttributeError(
                "Model loader module doesn't implement 'compiled_model' function"
            )

        return build(Path(self.source), stac_item=self.item, **kwargs)

    def predict_large(
        self,
        image: np.ndarray,
        model: torch.nn.Module | None = None,
        **kwargs,
    ):
        """
        Predict on large arrays using overlapping tiles.

        Tiling itself is delegated to the model module's ``predict_large``; the
        keyword arguments below are forwarded to it unchanged.

        Args:
            image: Input array with shape (C, H, W)
            model: Pre-loaded model (optional, will load if not provided)
            chunk_size: Size of inference tiles (default: 512)
            overlap: Overlap between tiles (default: 64)
            device: 'cpu' or 'cuda' (default: 'cpu')
            nodata: No-data value (default: 0.0)

        Returns:
            - For ensembles: Tuple of (probabilities, uncertainty), both (1, H, W)
            - For single models: probabilities array (1, H, W)

        Example:
            >>> model = loader.compiled_model()
            >>> result = loader.predict_large(image, model=model, device="cuda")
        """
        self._verify_local_access()

        if self.module is None:
            self.module = load_python_module(self.source)

        if model is None:
            model = self.compiled_model()

        return self.module.predict_large(image=image, model=model, **kwargs)

    def display_results(self, *args, **kwargs) -> Any:
        """
        Display the model's results via the model module's ``display_results``.

        Returns:
            Whatever the model module's ``display_results`` returns

        Raises:
            ValueError: If model hasn't been downloaded locally
            KeyError: If the hook reports a missing asset in the metadata
            AttributeError: If the model module doesn't define ``display_results``
        """
        self._verify_local_access()

        if self.scheme == "local":
            self.item = self._load()

        if self.module is None:
            self.module = load_python_module(self.source)

        show = getattr(self.module, "display_results", None)
        if show is None:
            raise AttributeError(
                "Model loader module doesn't implement 'display_results' function"
            )

        try:
            return show(Path(self.source), *args, stac_item=self.item, **kwargs)
        except KeyError as e:
            raise KeyError("Compiled model asset not found in metadata") from e

    def get_model_summary(self) -> dict[str, Any]:
        """
        Returns a dictionary with key information about the model.

        Returns:
            Dictionary containing model metadata
        """
        return {
            "id": self.item.id,
            "source": self.file,
            "scheme": self.scheme,
            "framework": self.item.properties.get("mlm:framework"),
            "architecture": self.item.properties.get("mlm:architecture"),
            "tasks": self.item.properties.get("mlm:tasks", []),
            "dependencies": self.item.properties.get("dependencies"),
            "size_bytes": self.item.properties.get("file:size", 0),
        }

    def __repr__(self) -> str:
        """Print the model schema as a side effect and return an empty string."""
        self.print_schema()
        return ""

    def __str__(self) -> str:
        """Print the model schema as a side effect and return an empty string."""
        self.print_schema()
        return ""

is_ensemble property

Check if this is an ensemble model requiring runtime aggregation.

An ensemble model is one that requires loading multiple .pt2 files and aggregating them at runtime (e.g. mean/median/max/min).

A pre-fused ensemble (single .pt2 with embedded aggregation) is NOT considered an ensemble for runtime purposes.

model property

Convenience property to get the compiled model (single models only). For ensembles, use compiled_model(mode=...) instead.

Example:

>>> # Single model - quick access
>>> model = loader.model

thr property

Get the recommended threshold value for the model output.

Returns:

| Type | Description |
|------|-------------|
| float \| None | Recommended threshold value, or None if not available |

__init__(file)

Initialize the model manager.

Parameters:

Name Type Description Default
file str | list

The JSON file that contains the model metadata, a directory path, or a list of .pt2 model files for ad-hoc ensemble creation

required

Raises:

Type Description
ValueError

If the source cannot be resolved or the model cannot be loaded.

Source code in mlstac/main.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def __init__(self, file: str | list):
    """
    Create a loader for a model described by metadata or by raw .pt2 files.

    Args:
        file: Path to the model's JSON metadata file, a directory containing
            an ``mlm.json``, or a list of .pt2 files to combine into an
            ad-hoc ensemble

    Raises:
        ValueError: If the source cannot be resolved or the model cannot be loaded.
    """
    self.file = file
    scheme = get_scheme(file)
    self.scheme = scheme

    if scheme == "local":
        location = Path(file)
        # A directory is shorthand for the mlm.json metadata file inside it.
        if location.is_dir():
            self.file = location / "mlm.json"
        self.item = self._load()
        self.source = Path(self.file).parent.as_posix()
    elif scheme == "pt2_list":
        # No metadata file exists: build a minimal STAC item from the weights
        # and treat the first file's directory as the model directory.
        self.item = self._create_minimal_stac_from_pt2s(file)
        self.source = str(Path(file[0]).parent)
    else:
        # Any other scheme: fetch the STAC item; there is no local directory.
        self.item = load_stac_item(file)
        self.source = None

    # The loader module is imported lazily by the methods that need it.
    self.module = None
    self.status = "downloaded"
    self.device = None

__repr__()

Return string representation of the ModelLoader instance.

Source code in mlstac/main.py
744
745
746
747
def __repr__(self) -> str:
    """
    Show the model schema instead of a conventional repr.

    ``print_schema()`` renders the schema as a side effect and returns
    ``None``, so the ``or`` fallback makes the returned repr string empty.
    """
    return self.print_schema() or ""

__str__()

Return user-friendly string representation.

Source code in mlstac/main.py
749
750
751
752
def __str__(self) -> str:
    """
    Show the model schema when the loader is converted to text.

    The schema is emitted by ``print_schema()``; the string handed back
    to the caller is empty.
    """
    self.print_schema()
    return str()

compiled_model(**kwargs)

Load the compiled model for inference.

For single models: no parameters are needed. For ensembles: accepts mode='mean'|'median'|'max'|'min' (default: 'max').

Returns:

Type Description

Compiled model instance ready for inference

Example

Single model

model = loader.compiled_model()

Ensemble model

model = loader.compiled_model(mode="mean")
model = loader.compiled_model(mode="median")  # More robust to outliers
model = loader.compiled_model(mode="max")     # Conservative (more clouds)

Source code in mlstac/main.py
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
def compiled_model(self, **kwargs):
    """
    Load the compiled model for inference.

    Single models need no extra arguments. Ensembles accept a ``mode``
    keyword ('mean' | 'median' | 'max' | 'min'); when it is omitted,
    ``mode="max"`` is supplied before delegating to the loader module.

    Returns:
        Compiled model instance ready for inference, as produced by the
        loader module's ``compiled_model`` function

    Example:
        >>> # Single model
        >>> model = loader.compiled_model()
        >>>
        >>> # Ensemble model
        >>> model = loader.compiled_model(mode="mean")
        >>> model = loader.compiled_model(mode="median")  # More robust to outliers
        >>> model = loader.compiled_model(mode="max")     # Conservative (more clouds)
    """
    self._verify_local_access()

    # Metadata backed by a local file is re-read from disk on every call.
    if self.scheme == "local":
        self.item = self._load()

    # The loader module is imported lazily and cached on the instance.
    if self.module is None:
        self.module = load_python_module(self.source)

    options = dict(kwargs)
    if self.is_ensemble:
        options.setdefault("mode", "max")

    model_dir = Path(self.source)
    return self.module.compiled_model(model_dir, stac_item=self.item, **options)

display_results(*args, **kwargs)

Load the function to display the results of the model.

Returns:

Type Description
Any

Whatever the model loader module's display_results function returns

Raises:

Type Description
ValueError

If model hasn't been downloaded locally

FileNotFoundError

If compiled model file doesn't exist

AttributeError

If model loader doesn't implement required functions

Source code in mlstac/main.py
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
def display_results(self, *args, **kwargs) -> Any:
    """
    Run the model loader module's ``display_results`` function.

    The local model directory is passed first, followed by any positional
    and keyword arguments given here, plus the STAC item as ``stac_item``.

    Returns:
        Whatever the loader module's ``display_results`` function returns

    Raises:
        ValueError: If model hasn't been downloaded locally
        KeyError: If the loader's function raises KeyError (re-raised with a
            message pointing at missing metadata)
        AttributeError: If the loader module doesn't define ``display_results``
    """
    self._verify_local_access()

    # Metadata backed by a local file is re-read from disk before use.
    if self.scheme == "local":
        self.item = self._load()

    if self.module is None:
        self.module = load_python_module(self.source)

    # Resolve the entry point up front so that an AttributeError raised
    # *inside* the loader's function is not misreported as a missing function.
    display_fn = getattr(self.module, "display_results", None)
    if display_fn is None:
        raise AttributeError(
            "Model loader module doesn't implement 'display_results' function"
        )

    try:
        return display_fn(Path(self.source), *args, stac_item=self.item, **kwargs)
    except KeyError as e:
        raise KeyError(
            "Asset required by 'display_results' not found in metadata"
        ) from e

download(output_dir)

Download this model's files into a local directory.

Thin wrapper around the module-level download(): it resolves the assets from this loader's metadata source and returns a new loader pointing at the local copy.

Parameters:

Name Type Description Default
output_dir Path | str

Target directory for the downloaded files

required

Returns:

Type Description
ModelLoader

A ModelLoader for the downloaded model

Source code in mlstac/main.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def download(self, output_dir: Path | str) -> ModelLoader:
    """
    Download this model's files into a local directory.

    Delegates to the module-level ``download()`` function, handing it this
    loader's metadata source (``self.file``), and returns the loader that
    function produces for the local copy.

    Args:
        output_dir: Target directory for the downloaded files

    Returns:
        A ModelLoader for the downloaded model
    """
    metadata_source = self.file
    # Inside a method body the bare name ``download`` resolves to the
    # module-level function, not to this method, so this is not recursion.
    return download(metadata_source, output_dir)

example_data(*args, **kwargs)

Load example data for model testing.

Returns:

Type Description
Any

Processed example data in the format expected by the model

Raises:

Type Description
FileNotFoundError

If example data file doesn't exist

ValueError

If model hasn't been downloaded locally

Source code in mlstac/main.py
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
def example_data(self, *args, **kwargs) -> Any:
    """
    Load example data for model testing.

    The local model directory is passed first to the loader module's
    ``example_data`` function, followed by any extra arguments given here.

    Returns:
        Processed example data in the format expected by the model

    Raises:
        FileNotFoundError: If the loader's function raises FileNotFoundError
            (re-raised with the expected example-data path)
        ValueError: If model hasn't been downloaded locally
        AttributeError: If the loader module doesn't define ``example_data``
    """
    self._verify_local_access()

    # Import the loader module outside the try block, as the sibling methods
    # do: a failure while importing it must not be reported as a missing
    # example-data file or a missing function.
    if self.module is None:
        self.module = load_python_module(self.source)

    # Resolve the entry point up front so an AttributeError raised inside the
    # loader's own code propagates unchanged.
    example_fn = getattr(self.module, "example_data", None)
    if example_fn is None:
        raise AttributeError(
            "Model loader module doesn't implement 'example_data' function"
        )

    try:
        return example_fn(Path(self.source), *args, **kwargs)
    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"Example data file not found at {self.source}/example_data.safetensor"
        ) from e

get_model_summary()

Returns a dictionary with key information about the model.

Returns:

Type Description
dict[str, Any]

Dictionary containing model metadata

Source code in mlstac/main.py
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
def get_model_summary(self) -> dict[str, Any]:
    """
    Collect the headline facts about this model into a plain dictionary.

    Returns:
        Dictionary with the keys ``id``, ``source`` (the metadata location
        the loader was created from, i.e. ``self.file``), ``scheme``,
        ``framework``, ``architecture``, ``tasks``, ``dependencies`` and
        ``size_bytes``. Missing STAC properties map to ``None``, except
        ``tasks`` (``[]``) and ``size_bytes`` (``0``).
    """
    props = self.item.properties
    summary: dict[str, Any] = {
        "id": self.item.id,
        "source": self.file,
        "scheme": self.scheme,
    }
    # (summary key, STAC property name, fallback when the property is absent)
    for key, prop_name, fallback in (
        ("framework", "mlm:framework", None),
        ("architecture", "mlm:architecture", None),
        ("tasks", "mlm:tasks", []),
        ("dependencies", "dependencies", None),
        ("size_bytes", "file:size", 0),
    ):
        summary[key] = props.get(prop_name, fallback)
    return summary

predict_large(image, model=None, **kwargs)

Predict on large arrays using overlapping tiles.

Parameters:

Name Type Description Default
image ndarray

Input array with shape (C, H, W)

required
model Module | None

Pre-loaded model (optional, will load if not provided)

None
chunk_size

Size of inference tiles (default: 512)

optional
overlap

Overlap between tiles (default: 64)

optional
device

'cpu' or 'cuda' (default: 'cpu')

optional
nodata

No-data value (default: 0.0)

optional

Returns:

Type Description
  • For ensembles: Tuple of (probabilities, uncertainty), both (1, H, W)
  • For single models: probabilities array (1, H, W)
Example

model = loader.compiled_model()
result = loader.predict_large(image, model=model, device="cuda")

Source code in mlstac/main.py
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
def predict_large(
    self,
    image: np.ndarray,
    model: torch.nn.Module | None = None,
    **kwargs,
):
    """
    Run tiled inference over a large array via the loader module.

    Args:
        image: Input array with shape (C, H, W)
        model: Pre-loaded model; when omitted, ``self.compiled_model()`` is
            called with its defaults
        chunk_size: Size of inference tiles (default: 512)
        overlap: Overlap between tiles (default: 64)
        device: 'cpu' or 'cuda' (default: 'cpu')
        nodata: No-data value (default: 0.0)

    Extra keyword arguments (such as the four above) are forwarded unchanged
    to the loader module's ``predict_large`` function.

    Returns:
        - For ensembles: Tuple of (probabilities, uncertainty), both (1, H, W)
        - For single models: probabilities array (1, H, W)

    Example:
        >>> model = loader.compiled_model()
        >>> result = loader.predict_large(image, model=model, device="cuda")
    """
    self._verify_local_access()

    if self.module is None:
        self.module = load_python_module(self.source)

    tiled_model = self.compiled_model() if model is None else model
    return self.module.predict_large(image=image, model=tiled_model, **kwargs)

print_schema()

Prints a visually appealing schema of the model.

Automatically detects if running in a Jupyter/Colab notebook or terminal and formats the output accordingly.

Source code in mlstac/main.py
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
def print_schema(self) -> None:
    """
    Prints a visually appealing schema of the model.

    Automatically detects if running in a Jupyter/Colab notebook or terminal
    and formats the output accordingly. The notebook check looks for
    ``ipykernel`` in ``sys.modules``. If it is found, an HTML card is rendered
    with ``IPython.display``; otherwise an ANSI-coloured summary is printed
    to stdout.

    Metadata fields that are absent, ``null`` or empty fall back to
    placeholder values instead of raising. Every metadata value inserted
    into the HTML card is HTML-escaped.
    """
    from html import escape

    in_notebook = "ipykernel" in sys.modules
    props = self.item.properties

    model_id = self.item.id
    title = props.get("title", "Untitled Model")
    description = props.get("description", "No description available")

    framework = props.get("mlm:framework", "Not specified")
    # The `or` fallbacks below also cover keys that are present but null,
    # which would otherwise break slicing, joining, len() or concatenation.
    framework_version = str(props.get("mlm:framework_version") or "")
    architecture = props.get("mlm:architecture", "Not specified")
    tasks = [str(t) for t in props.get("mlm:tasks") or []]

    total_params = props.get("mlm:total_parameters", 0)
    params_m = f"{total_params / 1_000_000:.2f}M" if total_params else "Unknown"

    file_size = props.get("file:size", 0)
    file_size_mb = f"{file_size / (1024 * 1024):.2f} MB" if file_size else "Unknown"

    sensors = [str(s) for s in props.get("custom:sensors") or []]
    spatial_res = props.get("custom:spatial_resolution", "Unknown")
    project = props.get("custom:project") or ""
    project_url = props.get("custom:project_url") or ""

    hyperparams = props.get("mlm:hyperparameters") or {}
    learning_rate = hyperparams.get("learning_rate", "N/A")
    batch_size = hyperparams.get("batch_size", "N/A")
    epochs = hyperparams.get("training_epochs", "N/A")
    val_loss = hyperparams.get("final_val_loss", "N/A")

    # Only the first input and output specification is summarised.
    first_input = (props.get("mlm:input") or [{}])[0] or {}
    input_shape = (first_input.get("input") or {}).get("shape", [])
    input_bands = first_input.get("bands") or []

    first_output = (props.get("mlm:output") or [{}])[0] or {}
    output_shape = (first_output.get("result") or {}).get("shape", [])
    standard_threshold = first_output.get("standard_threshold", "N/A")
    recommended_threshold = first_output.get("recommended_threshold", "N/A")

    links = {link.rel: link.href for link in self.item.links}

    dependencies = props.get("dependencies") or []
    # Show at most three entries, dropping any ">=" version pin.
    deps_str = (
        ", ".join(d.split(">=")[0] for d in dependencies[:3])
        if dependencies
        else "None"
    )

    if in_notebook:
        from IPython.display import HTML, display

        def _esc(value) -> str:
            # Metadata may come from a remote STAC item; never inject it raw.
            return escape(str(value))

        sensor_badges = "".join(
            f'<span class="mlstac-badge">{_esc(s)}</span>' for s in sensors
        )
        task_badges = "".join(
            f'<span class="mlstac-badge">{_esc(t)}</span>' for t in tasks
        )

        footer_parts = []
        if project:
            footer_parts.append(f"<strong>{_esc(project)}</strong>")
        # Only link to the project when a URL exists (avoids an empty href).
        if project_url:
            footer_parts.append(
                f'<a href="{_esc(project_url)}" target="_blank">Project Info</a>'
            )
        if "license" in links:
            license_href = _esc(links["license"])
            footer_parts.append(
                f'<a href="{license_href}" target="_blank">License</a>'
            )
        footer_parts.append(f"Source: <strong>{_esc(self.scheme.upper())}</strong>")
        footer_parts.append(
            f"Status: <strong>{_esc(self.status.capitalize())}</strong>"
        )
        footer_html = " | ".join(footer_parts)

        html_content = f"""
        <style>
            .mlstac-container {{
                font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
                background: linear-gradient(135deg, #1e40af 0%, #3b82f6 100%);
                padding: 15px;
                border-radius: 10px;
                color: white;
                box-shadow: 0 4px 20px rgba(0,0,0,0.2);
                margin: 15px 0;
            }}
            .mlstac-header {{
                text-align: center;
                margin-bottom: 15px;
                padding-bottom: 10px;
                border-bottom: 1px solid rgba(255,255,255,0.3);
            }}
            .mlstac-header h2 {{
                margin: 0 0 5px 0;
                font-size: 24px;
                font-weight: 600;
            }}
            .mlstac-header p {{
                margin: 0;
                font-size: 13px;
                opacity: 0.9;
            }}
            .mlstac-grid {{
                display: grid;
                grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
                gap: 10px;
                margin-bottom: 15px;
            }}
            .mlstac-card {{
                background: rgba(255, 255, 255, 0.15);
                backdrop-filter: blur(10px);
                border-radius: 8px;
                padding: 12px;
                border: 1px solid rgba(255,255,255,0.2);
                transition: transform 0.2s ease-out, box-shadow 0.2s ease-out;
            }}
            .mlstac-card:hover {{
                transform: translateY(-5px);
                box-shadow: 0 10px 20px rgba(0,0,0,0.3);
            }}
            .mlstac-card h3 {{
                margin: 0 0 8px 0;
                font-size: 14px;
                font-weight: 600;
                opacity: 0.95;
                display: flex;
                align-items: center;
                gap: 6px;
            }}
            .mlstac-card-content {{
                font-size: 13px;
                line-height: 1.5;
            }}
            .mlstac-card-content p {{
                margin: 4px 0;
                display: flex;
                justify-content: space-between;
                align-items: center;
            }}
            .mlstac-card-content .label {{
                opacity: 0.8;
                font-weight: 500;
            }}
            .mlstac-card-content .value {{
                background: rgba(255,255,255,0.2);
                padding: 2px 6px;
                border-radius: 5px;
                font-weight: 600;
                text-align: right;
                font-size: 13px;
            }}
            .mlstac-badge {{
                display: inline-block;
                background: rgba(255,255,255,0.25);
                padding: 3px 8px;
                border-radius: 15px;
                font-size: 11px;
                margin: 2px;
                font-weight: 500;
            }}
            .mlstac-description {{
                background: rgba(255, 255, 255, 0.1);
                padding: 10px;
                border-radius: 8px;
                margin-bottom: 10px;
                font-size: 12px;
                line-height: 1.5;
                border-left: 3px solid rgba(255,255,255,0.4);
            }}
            .mlstac-footer {{
                text-align: center;
                margin-top: 15px;
                padding-top: 10px;
                border-top: 1px solid rgba(255,255,255,0.3);
                font-size: 12px;
                opacity: 0.9;
            }}
            .mlstac-footer a {{
                color: white;
                text-decoration: none;
                font-weight: 600;
                border-bottom: 1px solid rgba(255,255,255,0.5);
                transition: border-color 0.2s;
            }}
            .mlstac-footer a:hover {{
                border-bottom-color: white;
            }}
            .icon {{
                font-size: 16px;
            }}
        </style>
        <div class="mlstac-container">
            <div class="mlstac-header">
                <h2>🚀 {_esc(title)}</h2>
                <p>Model ID: <strong>{_esc(model_id)}</strong></p>
            </div>
            <div class="mlstac-description">
                {_esc(description)}
            </div>
            <div class="mlstac-grid">
                <div class="mlstac-card">
                    <h3><span class="icon">🛠️</span> Framework & arch.</h3>
                    <div class="mlstac-card-content">
                        <p><span class="label">Framework:</span> <span class="value">{_esc(framework)} {_esc(framework_version[:6])}</span></p>
                        <p><span class="label">Architecture:</span> <span class="value">{_esc(architecture)}</span></p>
                        <p><span class="label">Parameters:</span> <span class="value">{_esc(params_m)}</span></p>
                        <p><span class="label">Model Size:</span> <span class="value">{_esc(file_size_mb)}</span></p>
                    </div>
                </div>
                <div class="mlstac-card">
                    <h3><span class="icon">🛰️</span> Data specs</h3>
                    <div class="mlstac-card-content">
                        <p><span class="label">Spatial Res:</span> <span class="value">{_esc(spatial_res)}</span></p>
                        <p><span class="label">Input Shape:</span> <span class="value">{_esc(input_shape)}</span></p>
                        <p><span class="label">Bands:</span> <span class="value">{len(input_bands)}</span></p>
                        <p><span class="label">Sensors:</span></p>
                        <div style="margin-left: 5px; text-align: right;">
                            {sensor_badges}
                        </div>
                    </div>
                </div>
                <div class="mlstac-card">
                    <h3><span class="icon">📊</span> Training metrics</h3>
                    <div class="mlstac-card-content">
                        <p><span class="label">Learning Rate:</span> <span class="value">{_esc(learning_rate)}</span></p>
                        <p><span class="label">Batch Size:</span> <span class="value">{_esc(batch_size)}</span></p>
                        <p><span class="label">Epochs:</span> <span class="value">{_esc(epochs)}</span></p>
                        <p><span class="label">Val Loss:</span> <span class="value">{_esc(val_loss)}</span></p>
                    </div>
                </div>
                <div class="mlstac-card">
                    <h3><span class="icon">🎯</span> Tasks & output</h3>
                    <div class="mlstac-card-content">
                        <p><span class="label">Output Shape:</span> <span class="value">{_esc(output_shape)}</span></p>
                        <p><span class="label">Std Threshold:</span> <span class="value">{_esc(standard_threshold)}</span></p>
                        <p><span class="label">Rec. Threshold:</span> <span class="value">{_esc(recommended_threshold)}</span></p>
                        <p><span class="label">Dependencies:</span> <span class="value">{_esc(deps_str)}</span></p>
                        <p><span class="label">Tasks:</span></p>
                        <div style="margin-left: 5px; text-align: right;">
                            {task_badges}
                        </div>
                    </div>
                </div>
            </div>
            <div class="mlstac-footer">
                {footer_html}
            </div>
        </div>
        """
        display(HTML(html_content))

    else:
        CYAN = "\033[96m"
        GREEN = "\033[92m"
        YELLOW = "\033[93m"
        BLUE = "\033[94m"
        BOLD = "\033[1m"
        RESET = "\033[0m"
        DIM = "\033[2m"

        print(f"\n{CYAN}{BOLD}🚀 {title}{RESET}")
        print(f"{DIM}   ID: {model_id}{RESET}")

        print(f"{BLUE}   {description}{RESET}\n")

        print(f"{GREEN}{BOLD}🛠️  Framework & architecture{RESET}")
        print(
            f"   Framework:    {YELLOW}{framework} {framework_version[:10]}{RESET}"
        )
        print(f"   Architecture: {YELLOW}{architecture}{RESET}")
        print(f"   Parameters:   {YELLOW}{params_m}{RESET}")
        print(f"   Model Size:   {YELLOW}{file_size_mb}{RESET}")

        print(f"{GREEN}{BOLD}🛰️  Data specifications{RESET}")
        print(f"   Sensors:      {YELLOW}{', '.join(sensors)}{RESET}")
        print(f"   Spatial Res:  {YELLOW}{spatial_res}{RESET}")
        print(f"   Input Shape:  {YELLOW}{input_shape}{RESET}")
        print(f"   Bands:        {YELLOW}{len(input_bands)} bands{RESET}")

        print(f"{GREEN}{BOLD}📊 Training metrics{RESET}")
        print(f"   Learning Rate: {YELLOW}{learning_rate}{RESET}")
        print(f"   Batch Size:    {YELLOW}{batch_size}{RESET}")
        print(f"   Epochs:        {YELLOW}{epochs}{RESET}")
        print(f"   Val Loss:     {YELLOW}{val_loss}{RESET}")

        print(f"{GREEN}{BOLD}🎯 Tasks & output{RESET}")
        print(f"   Tasks:            {YELLOW}{', '.join(tasks)}{RESET}")
        print(f"   Output Shape:     {YELLOW}{output_shape}{RESET}")
        print(f"   Std Threshold:    {YELLOW}{standard_threshold}{RESET}")
        print(f"   Rec. Threshold:   {YELLOW}{recommended_threshold}{RESET}")
        print(f"   Dependencies:     {YELLOW}{deps_str}{RESET}")

        # Omit the project segment entirely (no dangling " | ") when unset.
        project_prefix = f"{project} | " if project else ""
        status_str = f"| Status: {self.status.capitalize()} "
        print(
            f"\n{DIM}   {project_prefix}Source: {self.scheme.upper()} {status_str}{RESET}"
        )
        if project_url:
            print(f"{DIM}   🔗 {project_url}{RESET}\n")

trainable_model(*args, **kwargs)

Load the trainable version of the model for fine-tuning.

Returns:

Type Description
Any

Trainable model instance

Raises:

Type Description
ValueError

If model hasn't been downloaded locally

FileNotFoundError

If trainable model file doesn't exist

AttributeError

If model loader doesn't implement required functions

Source code in mlstac/main.py
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
def trainable_model(self, *args, **kwargs) -> Any:
    """
    Load the trainable version of the model for fine-tuning.

    The local model directory is passed first to the loader module's
    ``trainable_model`` function, followed by any extra arguments given here.

    Returns:
        Trainable model instance

    Raises:
        ValueError: If model hasn't been downloaded locally
        KeyError: If the loader's function raises KeyError (re-raised with a
            message pointing at the missing trainable-model asset)
        AttributeError: If the loader module doesn't define ``trainable_model``
    """
    self._verify_local_access()

    # Re-read metadata only when it is backed by a local mlm.json, matching
    # compiled_model()/display_results(); an item synthesized for an ad-hoc
    # list of .pt2 files is left untouched.
    if self.scheme == "local":
        self.item = self._load()

    if self.module is None:
        self.module = load_python_module(self.source)

    # Resolve the entry point up front so an AttributeError raised inside the
    # loader's own code is not misreported as a missing function.
    trainable_fn = getattr(self.module, "trainable_model", None)
    if trainable_fn is None:
        raise AttributeError(
            "Model loader module doesn't implement 'trainable_model' function"
        )

    try:
        return trainable_fn(Path(self.source), *args, **kwargs)
    except KeyError as e:
        raise KeyError("Trainable model asset not found in metadata") from e